mirror of https://github.com/inclusionAI/AReaL
Commit 347bcc07a6 (parent a78fd2dd24)
@@ -135,6 +135,85 @@ class TrainEngineConfig:
     fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
 
 
+@dataclass
+class PPOActorConfig(TrainEngineConfig):
+    # Core PPO/GRPO Parameters
+    group_size: int = field(
+        default=1, metadata={"help": "Number of sequences in each group"}
+    )
+    group_adv_norm: bool = field(
+        default=False,
+        metadata={
+            "help": "Normalize advantages within each prompt group rather than globally"
+        },
+    )
+    ppo_n_minibatches: int = field(
+        default=4, metadata={"help": "Number of minibatches for each PPO update"}
+    )
+    eps_clip: float = field(
+        default=0.2, metadata={"help": "Clipping factor for policy ratio"}
+    )
+    c_clip: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Dual clipping factor for policy ratio; must be > 1.0. None disables dual clipping."
+        },
+    )
+    temperature: float = field(
+        default=1.0, metadata={"help": "Temperature during generation."}
+    )
+    # Reward
+    group_reward_norm: bool = field(
+        default=False,
+        metadata={
+            "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
+        },
+    )
+    reward_scaling: float = field(
+        default=1.0, metadata={"help": "Reward scaling factor"}
+    )
+    reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
+    reward_clip: float = field(
+        default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
+    )
+    mask_no_eos_with_zero: bool = field(
+        default=False,
+        metadata={
+            "help": "Mask truncated generations (no EOS token) and exclude from training"
+        },
+    )
+
+    # Advantage Estimation
+    discount: float = field(
+        default=1.0, metadata={"help": "Discount factor for future rewards"}
+    )
+    gae_lambda: float = field(
+        default=1.0, metadata={"help": "Lambda parameter for GAE"}
+    )
+    adv_norm: bool = field(
+        default=True, metadata={"help": "Enable advantage normalization"}
+    )
+
+    # KL Control
+    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
+
+    # Asynchronous RL
+    recompute_logprob: bool = field(
+        default=False,
+        metadata={"help": "Recompute logp and replace the logp returned by inference."},
+    )
+    use_decoupled_loss: bool = field(
+        default=False,
+        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
+    )
+    behav_imp_weight_cap: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Filter out tokens whose behav_imp_weight exceeds behav_imp_weight_cap when computing the loss; must be > 1.0 and requires use_decoupled_loss=True."
+        },
+    )
+
+
 @dataclass
 class SGLangConfig:
     """Configuration for SGLang runtime. Refer to:
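The reward fields above (`reward_bias`, `reward_scaling`, `reward_clip`, `group_reward_norm`) describe a shaping pipeline applied to sequence-level rewards before advantage estimation. The sketch below illustrates one plausible reading of those knobs; it is not the repository's implementation, and the order of bias, scaling, and clipping as well as the per-group statistics are assumptions made for the example.

import torch

def shape_rewards_sketch(
    rewards: torch.Tensor,        # flat [n_prompts * group_size] sequence rewards
    group_size: int,
    reward_bias: float = 0.0,
    reward_scaling: float = 1.0,
    reward_clip: float = 20.0,
    group_reward_norm: bool = False,
) -> torch.Tensor:
    # Affine shaping followed by symmetric clipping, as the field docs describe.
    r = (rewards + reward_bias) * reward_scaling
    r = torch.clamp(r, -reward_clip, reward_clip)
    if group_reward_norm:
        # GRPO-style: normalize each prompt group against its own statistics.
        grouped = r.view(-1, group_size)
        grouped = (grouped - grouped.mean(dim=1, keepdim=True)) / (
            grouped.std(dim=1, keepdim=True) + 1e-6
        )
        r = grouped.view(-1)
    return r

# Toy usage: two prompts, group_size=4.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 3.0, 3.0, 2.0, 4.0])
print(shape_rewards_sketch(rewards, group_size=4, group_reward_norm=True))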
@@ -262,8 +341,8 @@ class SGLangConfig:
 
 @dataclass
 class InferenceEngineConfig:
-    experiment_name: str
-    trial_name: str
+    experiment_name: str = MISSING
+    trial_name: str = MISSING
     max_concurrent_rollouts: None | int = field(
         default=None,
         metadata={
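`experiment_name` and `trial_name` now default to `MISSING` instead of being bare annotations, so `InferenceEngineConfig()` stays constructible while the fields remain mandatory. Assuming the sentinel is omegaconf's `MISSING` (which the identifier suggests), the behavior looks like this sketch:

# Hedged sketch: assumes omegaconf's MISSING sentinel; the repo's config wiring may differ.
from dataclasses import dataclass
from omegaconf import MISSING, OmegaConf

@dataclass
class DemoConfig:
    experiment_name: str = MISSING
    trial_name: str = MISSING

cfg = OmegaConf.structured(DemoConfig)
print(OmegaConf.is_missing(cfg, "experiment_name"))  # True
cfg.experiment_name = "grpo-demo"  # in practice filled from the CLI or a YAML file
# Reading a still-missing field (cfg.trial_name) raises MissingMandatoryValue.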
@@ -550,6 +629,17 @@ class SFTConfig(BaseExperimentConfig):
     model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
 
 
+@dataclass
+class GRPOConfig(BaseExperimentConfig):
+    async_training: bool = field(default=True)
+    gconfig: GenerationHyperparameters = field(
+        default_factory=GenerationHyperparameters
+    )
+    rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
+    actor: PPOActorConfig = field(default_factory=PPOActorConfig)
+    ref: PPOActorConfig = field(default_factory=PPOActorConfig)
+
+
 def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
     parser = argparse.ArgumentParser()
     parser.add_argument(
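The new `GRPOConfig` nests generation, rollout, actor, and reference-model configs, and `load_expr_config` builds it from CLI arguments. How overrides reach the nested fields depends on code not shown in this diff; the sketch below only illustrates, with omegaconf, how dotted overrides could be merged into such a nested structure (field names come from the dataclasses above, values are invented).

from omegaconf import OmegaConf

# Base values mirror the dataclass defaults shown above.
base = OmegaConf.create(
    {
        "async_training": True,
        "actor": {
            "group_size": 1,
            "eps_clip": 0.2,
            "group_adv_norm": False,
            "recompute_logprob": False,
            "use_decoupled_loss": False,
        },
    }
)
# Dotted overrides, e.g. as they might arrive from a command line.
overrides = OmegaConf.from_dotlist(
    [
        "actor.group_size=16",
        "actor.group_adv_norm=true",
        "actor.recompute_logprob=true",
        "actor.use_decoupled_loss=true",
    ]
)
cfg = OmegaConf.merge(base, overrides)
print(cfg.actor.group_size, cfg.actor.use_decoupled_loss)  # 16 True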
@@ -77,7 +77,7 @@ class TrainEngine(abc.ABC):
 
     def train_batch(
         self,
-        input_: Dict,
+        input_: TensorDict,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:
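`train_batch` takes a loss callable and a loss-weight callable alongside the batch. A hedged sketch of callables matching those signatures follows; the exact contract (what `logits` covers, which keys exist in the batch) is an assumption, and `gather_logprobs` is the helper from the functional hunk further down.

import torch

def sft_loss_fn(logits: torch.Tensor, batch: dict) -> torch.Tensor:
    # Score the target labels with gather_logprobs and average over unmasked tokens.
    logp = gather_logprobs(logits, batch["labels"])
    mask = batch["loss_mask"].bool()
    return -torch.where(mask, logp, torch.zeros_like(logp)).sum() / mask.sum()

def token_count_weight_fn(batch: dict) -> float:
    # Weight each micro-batch by how many tokens actually contribute to the loss.
    return float(batch["loss_mask"].sum())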
@@ -87,7 +87,7 @@ class TrainEngine(abc.ABC):
     @torch.no_grad()
     def eval_batch(
         self,
-        input_: Dict,
+        input_: TensorDict,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
@@ -97,7 +97,7 @@ class TrainEngine(abc.ABC):
     @torch.no_grad()
     def forward(
         self,
-        input_: Dict,
+        input_: TensorDict,
         output_seqlens: List[List[int]] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
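All three engine methods now take `input_` as a `TensorDict` rather than a plain `Dict`. A minimal sketch of building such a batch with the `tensordict` package follows; the key names ("input_ids", "attention_mask", "loss_mask") are illustrative assumptions, not necessarily what the engines expect.

import torch
from tensordict import TensorDict

batch_size, seq_len = 4, 16
batch = TensorDict(
    {
        "input_ids": torch.randint(0, 32000, (batch_size, seq_len)),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.bool),
        "loss_mask": torch.ones(batch_size, seq_len, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
# Unlike a plain dict, a TensorDict can be sliced, moved, and reshaped as a unit:
print(batch[:2].batch_size)   # torch.Size([2])
batch = batch.to("cpu")       # or .to("cuda") on a GPU machine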
@@ -6,3 +6,115 @@ def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
     log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
     log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
     return log_probs_labels
+
+
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
+@torch.compile
+@torch.no_grad()
+def calc_entropy(logits, cu_seqlens):
+    probs = torch.nn.functional.softmax(logits.detach().float(), dim=-1)
+    entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1)
+    return entropy
+
+
+@torch.no_grad()
+def masked_normalization(
+    x: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    dim=None,
+    unbiased=False,
+    eps=1e-5,
+    high_precision=True,
+    all_reduce=True,
+    reduce_group=None,
+):
+    dtype = torch.float64 if high_precision else torch.float32
+    x = x.to(dtype)
+    if dim is None:
+        dim = tuple(range(len(x.shape)))
+    if mask is None:
+        factor = torch.tensor(
+            np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
+        )
+    else:
+        mask = mask.to(dtype)
+        x = x * mask
+        factor = mask.sum(dim, keepdim=True)
+    x_sum = x.sum(dim=dim, keepdim=True)
+    x_sum_sq = x.square().sum(dim=dim, keepdim=True)
+    if dist.is_initialized() and all_reduce:
+        dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group)
+        dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group)
+        dist.all_reduce(
+            x_sum_sq,
+            op=dist.ReduceOp.SUM,
+            group=reduce_group,
+        )
+    mean = x_sum / factor
+    meansq = x_sum_sq / factor
+    var = meansq - mean**2
+    if unbiased:
+        var *= factor / (factor - 1)
+    return ((x - mean) / (var.sqrt() + eps)).float()
+
+
+def ppo_actor_loss_fn(
+    logprobs: torch.Tensor,
+    old_logprobs: torch.Tensor,
+    advantages: torch.Tensor,
+    eps_clip: float,
+    loss_mask: torch.Tensor,
+    c_clip: Optional[float] = None,
+    proximal_logprobs: Optional[torch.Tensor] = None,
+    behav_imp_weight_cap: Optional[float] = None,
+) -> Tuple[torch.Tensor, Dict]:
+    denorm_logprobs = (
+        proximal_logprobs if proximal_logprobs is not None else old_logprobs
+    )
+    loss_mask_count = loss_mask.count_nonzero() or 1
+    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
+    clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
+    pg_loss1 = -advantages * ratio
+    pg_loss2 = -advantages * clipped_ratio
+    clip_mask = pg_loss1.detach() < pg_loss2.detach()
+    pg_loss = torch.max(pg_loss1, pg_loss2)
+    if c_clip is not None:
+        assert c_clip > 1.0, c_clip
+        pg_loss3 = torch.sign(advantages) * c_clip * advantages
+        dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
+        pg_loss = torch.min(pg_loss, pg_loss3)
+    else:
+        dual_clip_mask = torch.zeros_like(clip_mask)
+    if proximal_logprobs is not None:
+        behav_kl = proximal_logprobs - old_logprobs
+        behav_imp_weight = behav_kl.exp()
+        behav_mask = (
+            (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
+            if behav_imp_weight_cap is not None
+            else loss_mask
+        )
+        behav_kl = torch.where(behav_mask, behav_kl, 0.0)
+        behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
+        pg_loss = pg_loss * behav_imp_weight
+    logging_loss = pg_loss.detach()
+    pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
+    clip_mask.logical_and_(loss_mask)
+    dual_clip_mask.logical_and_(loss_mask)
+    stat = dict(
+        loss=logging_loss,
+        importance_weight=ratio.detach(),
+        approx_kl=(logprobs - denorm_logprobs).detach(),
+        clip_mask=clip_mask,
+        dual_clip_mask=dual_clip_mask,
+    )
+    if proximal_logprobs is not None:
+        stat["behave_imp_weight"] = behav_imp_weight
+        stat["behave_approx_kl"] = behav_kl
+        stat["behave_mask"] = behav_mask
+    return pg_loss, stat
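In the new `ppo_actor_loss_fn`, the clipped surrogate is computed against `denorm_logprobs`, which is the recomputed (proximal) log-probabilities when they are supplied and the rollout-time `old_logprobs` otherwise. With the decoupled loss, each token's surrogate is additionally reweighted by the behavioral importance weight exp(proximal − old), and tokens whose weight exceeds `behav_imp_weight_cap` are masked out; dual clipping (`c_clip`) caps the surrogate's magnitude, which matters for negative-advantage tokens. The snippet below is a toy invocation with made-up numbers; it only illustrates the shapes and the decoupled path, not a real training step.

import torch

T = 6
logprobs = (torch.randn(T) * 0.1 - 1.0).requires_grad_()  # current policy
old_logprobs = logprobs.detach() + 0.05                    # behavior policy (rollout)
proximal_logprobs = logprobs.detach() + 0.01               # recomputed before the update
advantages = torch.tensor([0.5, -0.2, 0.3, 0.0, 0.1, -0.4])
loss_mask = torch.tensor([True, True, True, True, True, False])

loss, stat = ppo_actor_loss_fn(
    logprobs=logprobs,
    old_logprobs=old_logprobs,
    advantages=advantages,
    eps_clip=0.2,
    loss_mask=loss_mask,
    c_clip=5.0,
    proximal_logprobs=proximal_logprobs,   # enables the decoupled-loss path
    behav_imp_weight_cap=10.0,
)
print(loss.shape, stat["clip_mask"].sum())
loss.backward()  # gradients flow only through `logprobs`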
@@ -1,7 +1,7 @@
 import getpass
 import os
 import time
-from typing import Dict
+from typing import Dict, List
 
 import torch.distributed as dist
 import wandb
@@ -21,6 +21,8 @@ class StatsLogger:
         self.ft_spec = ft_spec
         self.init()
 
+        self._last_commit_step = 0
+
     def init(self):
         if dist.is_initialized() and dist.get_rank() != 0:
             return
@@ -61,7 +63,7 @@ class StatsLogger:
         if self.summary_writer is not None:
             self.summary_writer.close()
 
-    def commit(self, epoch: int, step: int, global_step: int, data: Dict):
+    def commit(self, epoch: int, step: int, global_step: int, data: Dict | List[Dict]):
         if dist.is_initialized() and dist.get_rank() != 0:
             return
         self.info(
@@ -69,12 +71,17 @@ class StatsLogger:
             f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
             f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
         )
-        self.info("Stats:")
-        self.print_stats(data)
-        wandb.log(data, step=global_step)
-        if self.summary_writer is not None:
-            for key, val in data.items():
-                self.summary_writer.add_scalar(f"{key}", val, global_step)
+        if isinstance(data, Dict):
+            data = [data]
+        log_step = max(global_step, self._last_commit_step)
+        for i, item in enumerate(data):
+            self.info(f"Stats ({i+1}/{len(data)}):")
+            self.print_stats(item)
+            wandb.log(item, step=log_step + i)
+            if self.summary_writer is not None:
+                for key, val in item.items():
+                    self.summary_writer.add_scalar(f"{key}", val, log_step + i)
+        self._last_commit_step = log_step + len(data) - 1
 
     def print_stats(self, stats: Dict[str, float]):
         self.info("\n" + tabulate_stats(stats))
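`commit` now accepts either a single stat dict or a list of them and keeps the logging step monotonic through `_last_commit_step`, so several rollout batches produced by one training step land on consecutive wandb/TensorBoard steps. Isolated from the logging calls, the step arithmetic behaves like this sketch (values made up):

# Hedged sketch of the step bookkeeping commit() performs for a list of stat dicts;
# the wandb/TensorBoard calls are elided.
def demo_commit_steps(global_step: int, last_commit_step: int, data: list) -> int:
    log_step = max(global_step, last_commit_step)
    for i, item in enumerate(data):
        print(f"log at step {log_step + i}: {item}")
    return log_step + len(data) - 1  # becomes the new _last_commit_step

# One training step that produced three rollout stat dicts:
last = demo_commit_steps(
    global_step=10,
    last_commit_step=9,
    data=[{"reward": 0.4}, {"reward": 0.5}, {"reward": 0.6}],
)
print(last)  # 12 -- the next commit starts from max(its global_step, 12)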