This commit is contained in:
bowei.fw 2025-07-10 11:05:08 +08:00
parent a78fd2dd24
commit 347bcc07a6
4 changed files with 222 additions and 13 deletions

View File

@@ -135,6 +135,85 @@ class TrainEngineConfig:
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
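A minimal usage sketch for the new fields (not part of the commit; assumes the inherited TrainEngineConfig fields all have defaults):

actor_cfg = PPOActorConfig(
    group_size=8,
    group_adv_norm=True,
    eps_clip=0.2,
    c_clip=3.0,                 # dual clipping, must be > 1.0
    recompute_logprob=True,     # required when use_decoupled_loss=True
    use_decoupled_loss=True,
    behav_imp_weight_cap=5.0,   # must be > 1.0
)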
@dataclass
class SGLangConfig:
"""Configuration for SGLang runtime. Refer to:
@@ -262,8 +341,8 @@ class SGLangConfig:
@dataclass
class InferenceEngineConfig:
experiment_name: str
trial_name: str
experiment_name: str = MISSING
trial_name: str = MISSING
max_concurrent_rollouts: None | int = field(
default=None,
metadata={
@@ -550,6 +629,17 @@ class SFTConfig(BaseExperimentConfig):
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
@dataclass
class GRPOConfig(BaseExperimentConfig):
async_training: bool = field(default=True)
gconfig: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
)
rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
actor: PPOActorConfig = field(default_factory=PPOActorConfig)
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
parser = argparse.ArgumentParser()
parser.add_argument(
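The hunk ends inside load_expr_config; going only by the signature shown above, a hedged call-site sketch (the second return value is assumed to be a run identifier):

import sys
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
assert isinstance(config, GRPOConfig)
print(config.actor.group_size, config.rollout.max_concurrent_rollouts)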

View File

@@ -77,7 +77,7 @@ class TrainEngine(abc.ABC):
def train_batch(
self,
input_: Dict,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
@@ -87,7 +87,7 @@ class TrainEngine(abc.ABC):
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
@@ -97,7 +97,7 @@ class TrainEngine(abc.ABC):
@torch.no_grad()
def forward(
self,
input_: Dict,
input_: TensorDict,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
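The train/eval/forward signatures now take a TensorDict instead of a plain Dict. Assuming this is tensordict.TensorDict, an input batch might be packaged as below (key names are illustrative, not taken from this diff):

import torch
from tensordict import TensorDict  # assuming the tensordict package

bs, seqlen = 4, 512
input_ = TensorDict(
    {
        "input_ids": torch.zeros(bs, seqlen, dtype=torch.long),
        "attention_mask": torch.ones(bs, seqlen, dtype=torch.bool),
    },
    batch_size=[bs],
)
# engine, loss_fn and loss_weight_fn are placeholders for concrete implementations:
# stats = engine.train_batch(input_, loss_fn=loss_fn, loss_weight_fn=loss_weight_fn)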

View File

@@ -6,3 +6,115 @@ def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
@torch.compile
@torch.no_grad()
def calc_entropy(logits, cu_seqlens):
probs = torch.nn.functional.softmax(logits.detach().float(), dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1)
return entropy
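A quick sanity check for calc_entropy (illustrative, not in the commit): uniform logits over a vocabulary of V tokens should give entropy close to log(V) at every position; note that cu_seqlens is unused by this helper.

import math
logits = torch.zeros(2, 8, 1000)              # [batch, seqlen, vocab]
ent = calc_entropy(logits, cu_seqlens=None)
print(ent.shape, float(ent.mean()), math.log(1000))  # both values ~= 6.9077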
@torch.no_grad()
def masked_normalization(
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
dim=None,
unbiased=False,
eps=1e-5,
high_precision=True,
all_reduce=True,
reduce_group=None,
):
dtype = torch.float64 if high_precision else torch.float32
x = x.to(dtype)
if dim is None:
dim = tuple(range(len(x.shape)))
if mask is None:
factor = torch.tensor(
np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device
)
else:
mask = mask.to(dtype)
x = x * mask
factor = mask.sum(dim, keepdim=True)
x_sum = x.sum(dim=dim, keepdim=True)
x_sum_sq = x.square().sum(dim=dim, keepdim=True)
if dist.is_initialized() and all_reduce:
dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group)
dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group)
dist.all_reduce(
x_sum_sq,
op=dist.ReduceOp.SUM,
group=reduce_group,
)
mean = x_sum / factor
meansq = x_sum_sq / factor
var = meansq - mean**2
if unbiased:
var *= factor / (factor - 1)
return ((x - mean) / (var.sqrt() + eps)).float()
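Illustrative use of masked_normalization (not part of the commit): normalize per-token advantages over valid positions only, ignoring padding.

adv = torch.randn(4, 512)                      # [batch, seqlen] raw advantages
mask = torch.ones(4, 512, dtype=torch.bool)
mask[:, 400:] = False                          # e.g. padded tail excluded
norm_adv = masked_normalization(adv, mask)     # ~zero mean / unit variance over masked entries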
def ppo_actor_loss_fn(
logprobs: torch.Tensor,
old_logprobs: torch.Tensor,
advantages: torch.Tensor,
eps_clip: float,
loss_mask: torch.Tensor,
c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.Tensor] = None,
behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
denorm_logprobs = (
proximal_logprobs if proximal_logprobs is not None else old_logprobs
)
loss_mask_count = loss_mask.count_nonzero() or 1
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * clipped_ratio
clip_mask = pg_loss1.detach() < pg_loss2.detach()
pg_loss = torch.max(pg_loss1, pg_loss2)
if c_clip is not None:
assert c_clip > 1.0, c_clip
pg_loss3 = torch.sign(advantages) * c_clip * advantages
dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
pg_loss = torch.min(pg_loss, pg_loss3)
else:
dual_clip_mask = torch.zeros_like(clip_mask)
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
behav_mask = (
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
if behav_imp_weight_cap is not None
else loss_mask
)
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
pg_loss = pg_loss * behav_imp_weight
logging_loss = pg_loss.detach()
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
clip_mask.logical_and_(loss_mask)
dual_clip_mask.logical_and_(loss_mask)
stat = dict(
loss=logging_loss,
importance_weight=ratio.detach(),
approx_kl=(logprobs - denorm_logprobs).detach(),
clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask,
)
if proximal_logprobs is not None:
stat["behave_imp_weight"] = behav_imp_weight
stat["behave_approx_kl"] = behav_kl
stat["behave_mask"] = behav_mask
return pg_loss, stat
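A minimal sketch exercising the decoupled-loss path of ppo_actor_loss_fn (shapes and values are illustrative):

T = 16
logprobs = (torch.randn(T) * 0.1).requires_grad_()   # current policy
old_logprobs = logprobs.detach() - 0.05              # behavior policy (from the inference engine)
proximal_logprobs = logprobs.detach() - 0.01         # recomputed "proximal" log-probs
advantages = torch.randn(T)
loss_mask = torch.ones(T, dtype=torch.bool)

loss, stat = ppo_actor_loss_fn(
    logprobs=logprobs,
    old_logprobs=old_logprobs,
    advantages=advantages,
    eps_clip=0.2,
    loss_mask=loss_mask,
    c_clip=None,
    proximal_logprobs=proximal_logprobs,   # enables the behavior importance weight
    behav_imp_weight_cap=5.0,
)
loss.backward()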

View File

@@ -1,7 +1,7 @@
import getpass
import os
import time
from typing import Dict
from typing import Dict, List
import torch.distributed as dist
import wandb
@@ -21,6 +21,8 @@ class StatsLogger:
self.ft_spec = ft_spec
self.init()
self._last_commit_step = 0
def init(self):
if dist.is_initialized() and dist.get_rank() != 0:
return
@@ -61,7 +63,7 @@ class StatsLogger:
if self.summary_writer is not None:
self.summary_writer.close()
def commit(self, epoch: int, step: int, global_step: int, data: Dict):
def commit(self, epoch: int, step: int, global_step: int, data: Dict | List[Dict]):
if dist.is_initialized() and dist.get_rank() != 0:
return
self.info(
@@ -69,12 +71,17 @@ class StatsLogger:
f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
)
self.info("Stats:")
self.print_stats(data)
wandb.log(data, step=global_step)
if self.summary_writer is not None:
for key, val in data.items():
self.summary_writer.add_scalar(f"{key}", val, global_step)
if isinstance(data, Dict):
data = [data]
log_step = max(global_step, self._last_commit_step)
for i, item in enumerate(data):
self.info(f"Stats ({i+1}/{len(data)}):")
self.print_stats(item)
wandb.log(item, step=log_step + i)
if self.summary_writer is not None:
for key, val in item.items():
self.summary_writer.add_scalar(f"{key}", val, log_step + i)
self._last_commit_step = log_step + len(data) - 1
def print_stats(self, stats: Dict[str, float]):
self.info("\n" + tabulate_stats(stats))