# Mirror of https://github.com/inclusionAI/AReaL
import functools
from typing import Dict, List, Optional

import torch
from tensordict import TensorDict

from arealite.api.cli_args import MicroBatchSpec, PPOActorConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.data import split_padded_tensor_dict_into_mb_list
from arealite.utils.functional import (
    gather_logprobs,
    gather_logprobs_entropy,
    masked_normalization,
    ppo_actor_loss_fn,
)
from realhf.base import stats_tracker


class PPOActor:

    def __init__(self, config: PPOActorConfig, engine: TrainEngine):
        self.config = config
        self.engine = engine

        self.reward_bias = config.reward_bias
        self.reward_scaling = config.reward_scaling
        self.reward_clip = config.reward_clip

        self.group_reward_norm = config.group_reward_norm
        self.group_adv_norm = config.group_adv_norm
        self.group_size = config.group_size

        self.kl_ctl = config.kl_ctl

        self.adv_norm = config.adv_norm
        self.discount = config.discount
        self.gae_lambda = config.gae_lambda
        self.mask_no_eos_with_zero = config.mask_no_eos_with_zero

        self.temperature = config.temperature
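
    # `compute_logp` re-evaluates per-token log-probabilities of the given
    # sequences under the current policy. Labels are `input_ids` rolled left
    # by one position (next-token prediction); the sampling temperature is
    # forwarded to `gather_logprobs` (defaulting to 1.0 when not provided).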
    @torch.no_grad()
    def compute_logp(
        self,
        data: TensorDict,
        temperature: Optional[float] = None,
    ) -> torch.Tensor | None:

        def calc_logprobs(logits, input_data):
            labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1)
            logprobs = gather_logprobs(logits, labels, temperature or 1.0)
            return logprobs

        self.engine.eval()
        return self.engine.forward(
            input_=data,
            post_hook=calc_logprobs,
            aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
        )

    def compute_advantages(self, data: TensorDict) -> None:
        bs = data["input_ids"].shape[0]
        max_seqlen = data["input_ids"].shape[1]
        batch_indices = torch.arange(
            bs, device=data["input_ids"].device, dtype=torch.long
        )

        # Compute rewards using the reward function of the synchronous RLVR pipeline.
        reward_score = data["rewards"]
        reward_score = (reward_score + self.reward_bias) * self.reward_scaling
        reward_score = torch.clip(
            reward_score, max=self.reward_clip, min=-self.reward_clip
        )
        if self.group_reward_norm:
            for i in range(bs // self.group_size):
                s = slice(i * self.group_size, (i + 1) * self.group_size)
                r = reward_score[s]
                reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)

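        # `loss_mask` (and the log-probs handled below) are rolled left by one
        # so that position t holds the value tied to predicting token t + 1,
        # matching the label shift used in `compute_logp` and `grpo_loss_fn`.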
        loss_mask = data["loss_mask"].float()
        loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
        # Apply the mask to log probabilities.
        if not self.config.use_decoupled_loss and self.config.recompute_logprob:
            # Overwrite logprobs produced by the inference engine.
            old_logp = data["logprobs"] = data["prox_logp"]
        else:
            old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1)
            if not self.config.use_decoupled_loss:
                # The proximal logp is not available; fall back to the logp
                # returned by the inference engine.
                data["prox_logp"] = old_logp
        ref_logp = data.get("ref_logp", torch.zeros_like(old_logp))
        ref_logp *= loss_mask
        old_logp *= loss_mask

        # Compute KL-regularized rewards.
        attn_mask = data["attention_mask"]
        seqlens = attn_mask.sum(-1).long()
        seq_no_eos_mask = seqlens == attn_mask.shape[1]
        rewards = -self.kl_ctl * (old_logp - ref_logp)
        kl_rewards = rewards.clone()
        # The KL reward at the token right after EOS is zeroed out.
        rewards[batch_indices, seqlens - 1] = 0
        indices = torch.clip(seqlens - 2, min=0)
        if self.mask_no_eos_with_zero:
            rewards[batch_indices, indices] += torch.where(
                seq_no_eos_mask, 0, reward_score
            )
        else:
            rewards[batch_indices, indices] += reward_score

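        # GAE recursion over the (padded) sequence dimension:
        #   delta_t = r_t + discount * V_{t+1} - V_t
        #   A_t = delta_t + discount * gae_lambda * A_{t+1}
        # When no critic is available ("values" missing from `data`), V is all
        # zeros and the advantage reduces to a discounted sum of the
        # KL-regularized rewards computed above.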
        # Compute GAE.
        if "values" not in data:
            values = torch.zeros_like(rewards)
        else:
            values = data["values"]
        advantages_reversed = [
            torch.zeros(bs, dtype=torch.float32, device=values.device)
        ]
        lastgaelam = 0
        for t in reversed(range(max_seqlen - 1)):
            nextvalues = values[:, t + 1]
            if t == max_seqlen - 2:
                nextvalues *= seq_no_eos_mask
            delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
            lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        # Optionally perform advantage normalization.
        if self.adv_norm or self.group_adv_norm:
            if self.group_adv_norm:
                adv_list = []
                for i in range(0, bs, self.group_size):
                    # NOTE: `i` already advances by `group_size`, so slice
                    # directly from `i` to `i + group_size`.
                    s = slice(i, i + self.group_size)
                    adv = advantages[s]
                    m = loss_mask[s]
                    adv_list.append(masked_normalization(adv, m, all_reduce=False))
                advantages = torch.cat(adv_list, 0)
            else:
                advantages = masked_normalization(advantages, loss_mask)

        # Store data in the dict.
        data["advantages"] = advantages
        data["kl_rewards"] = kl_rewards
        data["tot_rewards"] = rewards
        data["loss_mask"] = loss_mask
        # old_logp was rolled by -1 above, so store the shifted version.
        data["logprobs"] = old_logp

    def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
        attn_mask = data["attention_mask"]
        loss_mask = data["loss_mask"]
        reward_score = data["rewards"]
        seqlens = attn_mask.sum(-1)

        all_stats = []
        ########## Logging code starts ##########
        result_denominators = {
            "correct_n_seqs": (reward_score > 0).bool(),
            "incorrect_n_seqs": (reward_score <= 0).bool(),
        }
        global_denominators = dict(
            n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
            n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
            n_valid_tokens=loss_mask.bool(),
            **result_denominators,
        )
        stats_tracker.denominator(**global_denominators)
        stats_tracker.stat(
            correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
        )
        stats_tracker.stat(
            incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
        )

        stats = dict(
            advantages=data["advantages"],
            kl_rewards=data["kl_rewards"],
            final_reward=data["tot_rewards"],
        )
        stats_tracker.stat(**stats, denominator="n_valid_tokens")

        prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1)
        seq_stats = dict(
            no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(),
            task_reward=reward_score.float(),
            prompt_len=prompt_lens.float(),
            seq_len=seqlens.float(),
        )
        stats_tracker.stat(**seq_stats, denominator="n_seqs")
        scalars = dict(
            mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
            eps_clip=self.config.eps_clip,
        )
        if self.config.c_clip is not None:
            scalars["c_clip"] = self.config.c_clip
            scalars["use_dual_clip"] = 1
        else:
            scalars["use_dual_clip"] = 0
        if self.config.behav_imp_weight_cap is not None:
            scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
        stats_tracker.scalar(**scalars)

        global_stats = stats_tracker.export(reduce_group=self.engine.parallelism_group)
        for k in global_denominators:
            keys = list(global_stats.keys())
            for k2 in keys:
                if k2.endswith(k):
                    global_stats.pop(k2)
        ########## Logging code ends ##########

        for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
            data.pop(key, None)
        # NOTE: calling engine.train() is critical to enabling gradient checkpointing
        self.engine.train()
        mb_inputs = split_padded_tensor_dict_into_mb_list(
            data,
            mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
        )
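        # One optimizer step per micro-batch. `loss_weight_fn` reports the
        # number of valid (unmasked) tokens per micro-batch, presumably used by
        # the engine to weight micro-batch losses of unequal size.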
        for mb in mb_inputs.mbs:
            train_stat = self.engine.train_batch(
                mb,
                loss_fn=functools.partial(
                    grpo_loss_fn,
                    temperature=self.temperature,
                    eps_clip=self.config.eps_clip,
                    c_clip=self.config.c_clip,
                    behav_imp_weight_cap=self.config.behav_imp_weight_cap,
                ),
                loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
            )
            stats_tracker.scalar(**train_stat)
            all_stats.append(
                stats_tracker.export(reduce_group=self.engine.parallelism_group)
            )
        all_stats[0].update(global_stats)
        return all_stats


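# Thin wrapper that runs the PPOActor algorithm logic on top of an FSDP-backed
# training engine, re-exporting the actor's public methods.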
class FSDPPPOActor(FSDPEngine):

    def __init__(self, config: PPOActorConfig):
        super().__init__(config)
        self.actor = PPOActor(config, self)

    @torch.no_grad()
    def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
        return self.actor.compute_logp(*args, **kwargs)

    @torch.no_grad()
    def compute_advantages(self, *args, **kwargs) -> None:
        self.actor.compute_advantages(*args, **kwargs)

    def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
        return self.actor.ppo_update(*args, **kwargs)


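# `grpo_loss_fn` wraps `ppo_actor_loss_fn`: a clipped PPO surrogate computed
# against the proximal log-probs ("prox_logp"), with an optional dual clip
# (`c_clip`) and an optional cap on the behavior importance weight
# (`behav_imp_weight_cap`) used by the decoupled loss. This summary is inferred
# from the parameter names; see `arealite.utils.functional.ppo_actor_loss_fn`
# for the exact formula.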
def grpo_loss_fn(
    logits: torch.Tensor,
    input_data: Dict,
    temperature: float,
    eps_clip: float,
    c_clip: float | None,
    behav_imp_weight_cap: float | None,
):
    """Loss function for the actor step.

    All inputs should already be split into pipeline micro-batches.
    Returns the loss; logging statistics are recorded via ``stats_tracker``.
    """
    input_ids = input_data["input_ids"]
    old_logp = input_data["logprobs"]
    advantages = input_data["advantages"]
    loss_mask = input_data["loss_mask"].bool()
    prox_logp = input_data["prox_logp"]

    logprobs, entropy = gather_logprobs_entropy(
        logits, torch.roll(input_ids, shifts=-1, dims=-1), temperature
    )
    entropy = entropy.detach()
    loss, stat = ppo_actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=loss_mask,
        c_clip=c_clip,
        proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )

    # Log training statistics
    stats_tracker.denominator(
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=loss_mask.bool(),
        clipped_tokens=stat["clip_mask"],
        dual_clipped_tokens=stat["dual_clip_mask"],
    )

    stats_tracker.stat(
        importance_weight=stat["importance_weight"],
        approx_kl=stat["approx_kl"],
        new_logp=logprobs.detach(),
        old_logp=old_logp,
        entropy=entropy.float(),
        actor_loss=stat["loss"],
        clip_ratio=stat["clip_mask"].float(),
        dual_clip_ratio=stat["dual_clip_mask"].float(),
        denominator="n_valid_tokens",
    )
    if "behave_imp_weight" in stat:
        stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
        stats_tracker.stat(
            behave_imp_weight=stat["behave_imp_weight"],
            behave_approx_kl=stat["behave_approx_kl"],
            denominator="unclipped_behave_tokens",
        )
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    clip_mask = stat["clip_mask"]
    clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
    clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
    stats_tracker.stat(
        clipped_new_logp=clipped_new_logp,
        clipped_old_logp=clipped_old_logp,
        denominator="clipped_tokens",
    )
    return loss
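

# Illustrative usage sketch (not part of this module; the surrounding setup is
# an assumption and engine initialization is omitted). A rollout batch
# `data: TensorDict` is expected to provide at least "input_ids",
# "attention_mask", "loss_mask", "logprobs", and "rewards", plus optionally
# "ref_logp" and "values":
#
#   actor = FSDPPPOActor(config)    # config: PPOActorConfig
#   logp = actor.compute_logp(data)
#   data["prox_logp"] = logp        # recomputed proximal log-probs
#   actor.compute_advantages(data)  # adds "advantages", rolls masks/log-probs
#   stats = actor.ppo_update(data)  # minibatch PPO/GRPO updates + stats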