mirror of https://github.com/inclusionAI/AReaL
1310 lines
47 KiB
Python
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
from typing import Dict, List, Literal, Optional

import torch
import torch.distributed as dist

import realhf.api.core.model_api as model_api
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import (
    RL_TASKS,
    MicroBatchSpec,
    SequenceSample,
    SequenceSplitSpec,
)
from realhf.base import constants, logging, stats_tracker
from realhf.base.datapack import flat2d
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_generate import concat_prompt_to_generation_output
from realhf.impl.model.utils.functional import (
    gather_packed_shifted_log_probs,
    masked_normalization,
)

logger = logging.getLogger("PackedPPOInterface")

def get_score(prompt_ids, generated, query_ids, tokenizer):
|
|
prompt_strs = tokenizer.batch_decode(
|
|
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
|
)
|
|
seq_strs = tokenizer.batch_decode(
|
|
generated, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
|
)
|
|
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
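    # Query ids are assumed to look like "<problem_id>@<suffix>"; only the part
    # before "@" is passed to the parser. Note that `prompt_strs` above is
    # decoded but not used for scoring.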
|
|
return parse_lines_in_parallel(seq_strs, query_id_strs)
|
|
|
|
|
|
def topk(scores, gen_lengths, k) -> list:
|
|
indexed = list(enumerate(zip(scores, gen_lengths)))
|
|
|
|
sorted_indices = sorted(indexed, key=lambda x: (x[1][0], x[1][1]), reverse=True)[:k]
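    # Illustrative example: scores=[1, 0, 1], gen_lengths=[5, 9, 7], k=2.
    # Sorting by (score, length) in descending order gives indices [2, 0, 1],
    # so the function returns [2, 0]: ties in score are broken in favor of
    # longer generations.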
|
|
|
|
return [idx for idx, _ in sorted_indices]
|
|
|
|
|
|
def _ppo_actor_loss_from_model_outputs(
|
|
logits: torch.FloatTensor, # [tot_seqlen, vocab_size]
|
|
input_: SequenceSample,
|
|
kl_adapter: ppo_functional.KLController, # const
|
|
eps_clip: float, # const
|
|
c_clip: float | None,
|
|
early_stop_imp_ratio: Optional[float], # const
|
|
early_stop_kl: Optional[float], # const
|
|
temperature: Optional[float] = 1,
|
|
) -> torch.Tensor:
|
|
"""Loss function for ppo actor step, all inputs should be splitted into
|
|
pipeline micro batches, returns loss and logging stats."""
|
|
packed_input_ids = input_.data["packed_input_ids"]
|
|
cu_seqlens = (
|
|
torch.nn.functional.pad(
|
|
torch.tensor(flat2d(input_.seqlens["packed_input_ids"])).cumsum(0),
|
|
(1, 0),
|
|
)
|
|
.int()
|
|
.to(logits.device)
|
|
)
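    # cu_seqlens holds the cumulative boundaries of the packed batch,
    # e.g. seqlens [3, 5, 2] -> cu_seqlens [0, 3, 8, 10], so sequence i spans
    # packed positions cu_seqlens[i]:cu_seqlens[i + 1].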
|
|
ppo_loss_mask = input_.data["ppo_loss_mask"]
|
|
advantages = input_.data["advantages"]
|
|
old_logp = input_.data["old_logp"]
|
|
kl_rewards = input_.data["kl_rewards"]
|
|
|
|
if temperature is not None:
|
|
logits /= temperature
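    # Dividing by the sampling temperature keeps the log-probs computed here
    # consistent with `old_logp`, which was computed under the same temperature
    # at generation/inference time.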
|
|
logprobs = gather_packed_shifted_log_probs(
|
|
logits, cu_seqlens, packed_input_ids
|
|
).float()
|
|
loss, ppo_stat = ppo_functional.actor_loss_fn(
|
|
logprobs=logprobs,
|
|
old_logprobs=old_logp,
|
|
advantages=advantages,
|
|
eps_clip=eps_clip,
|
|
loss_mask=ppo_loss_mask,
|
|
c_clip=c_clip,
|
|
proximal_logprobs=input_.data.get("prox_logp", None),
|
|
)
|
|
|
|
# Log training statistics
|
|
stats_tracker.denominator(
|
|
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
|
|
n_valid_tokens=ppo_loss_mask.bool(),
|
|
clipped_tokens=ppo_stat["clip_mask"],
|
|
dual_clipped_tokens=ppo_stat["dual_clip_mask"],
|
|
)
|
|
|
|
stats_tracker.stat(
|
|
importance_weight=ppo_stat["importance_weight"],
|
|
approx_kl=ppo_stat["approx_kl"],
|
|
new_logp=logprobs.detach(),
|
|
old_logp=old_logp,
|
|
actor_loss=ppo_stat["loss"],
|
|
clip_ratio=ppo_stat["clip_mask"].float(),
|
|
dual_clip_ratio=ppo_stat["dual_clip_mask"].float(),
|
|
denominator="n_valid_tokens",
|
|
)
|
|
if "behave_imp_weight" in ppo_stat:
|
|
stats_tracker.denominator(unclipped_behave_tokens=ppo_stat["behave_mask"])
|
|
stats_tracker.stat(
|
|
behave_imp_weight=ppo_stat["behave_imp_weight"],
|
|
behave_approx_kl=ppo_stat["behave_approx_kl"],
|
|
denominator="unclipped_behave_tokens",
|
|
)
|
|
vocab_min_logits = logits.detach().min(-1).values.float()
|
|
vocab_max_logits = logits.detach().max(-1).values.float()
|
|
dist.all_reduce(
|
|
vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
|
|
)
|
|
dist.all_reduce(
|
|
vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
|
|
)
|
|
stats_tracker.stat(
|
|
vocab_min_logits=vocab_min_logits,
|
|
vocab_max_logits=vocab_max_logits,
|
|
denominator="n_tokens",
|
|
)
|
|
|
|
clip_mask = ppo_stat["clip_mask"]
|
|
dual_clip_mask = ppo_stat["dual_clip_mask"]
|
|
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
|
|
dual_clipped_new_logp = torch.where(dual_clip_mask, logprobs.detach(), 0.0)
|
|
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
|
|
dual_clipped_old_logp = torch.where(dual_clip_mask, old_logp, 0.0)
|
|
stats_tracker.stat(
|
|
clipped_new_logp=clipped_new_logp,
|
|
clipped_old_logp=clipped_old_logp,
|
|
denominator="clipped_tokens",
|
|
)
|
|
stats_tracker.stat(
|
|
dual_clipped_new_logp=dual_clipped_new_logp,
|
|
dual_clipped_old_logp=dual_clipped_old_logp,
|
|
denominator="dual_clipped_tokens",
|
|
)
|
|
|
|
# Logging and early stopping according to KL (logp vs ref) or importance ratio (new logp vs old logp).
|
|
mean_ref_kl = (kl_rewards.detach().float() * ppo_loss_mask).sum()
|
|
dist.all_reduce(mean_ref_kl, group=constants.data_parallel_group())
|
|
_imp = (ppo_stat["importance_weight"].float() * ppo_loss_mask).sum()
|
|
dist.all_reduce(_imp, group=constants.data_parallel_group())
|
|
_kl = (ppo_stat["approx_kl"].float() * ppo_loss_mask).sum()
|
|
dist.all_reduce(_kl, group=constants.data_parallel_group())
|
|
_n_valid_tokens = ppo_loss_mask.count_nonzero().clone()
|
|
dist.all_reduce(_n_valid_tokens, group=constants.data_parallel_group())
|
|
mean_ref_kl /= _n_valid_tokens
|
|
_imp /= _n_valid_tokens
|
|
_kl /= _n_valid_tokens
|
|
# Early stopping.
|
|
kl_adapter.update(mean_ref_kl, n_steps=cu_seqlens.shape[0] - 1)
|
|
if early_stop_imp_ratio is not None and _imp > early_stop_imp_ratio:
|
|
logger.warning(
|
|
f"Current importance ratio {_imp.item():.4f} is larger "
|
|
f"than early stop threshold {early_stop_imp_ratio}. Abandon this minibatch."
|
|
)
|
|
loss = loss * 0.0
|
|
if early_stop_kl is not None and _kl > early_stop_kl:
|
|
logger.warning(
|
|
f"Current approximate KL divergence {_kl.item():.4f} is larger "
|
|
f"than early stop threshold {early_stop_kl}. Abort actor update."
|
|
)
|
|
loss = loss * 0.0
|
|
|
|
return loss
|
|
|
|
|
|
def splited_sum_bool_tensor(t: torch.BoolTensor, chunk_size=256 * 1024 * 1024) -> int:
|
|
"""Sum a boolean tensor by splitting them into chunks and sum the chunks
|
|
separately.
|
|
|
|
to avoid memory overhead introduced by torch default sum method
|
|
(which will apply for a block of memory of size `8 * t.numel()`
|
|
bytes.)
|
|
"""
|
|
flatten = t.flatten()
|
|
splitted = flatten.split(chunk_size // 8, dim=0)
|
|
r = 0
|
|
for chunk in splitted:
|
|
r += chunk.sum()
|
|
return r
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class PPOActorInterface(model_api.ModelInterface):
|
|
n_minibatches: int = 4
|
|
|
|
    # Use a dict here so that arguments can be passed through the command line.
|
|
generation_config: Dict = dataclasses.field(default_factory=dict)
|
|
|
|
kl_ctl: float = 0.1
|
|
|
|
adv_norm: bool = True
|
|
discount: float = 1.0
|
|
gae_lambda: float = 1.0
|
|
|
|
eps_clip: float = 0.2
|
|
c_clip: Optional[float] = None
|
|
value_eps_clip: float = 0.2
|
|
max_reward_clip: float = 5.0
|
|
|
|
disable_value: bool = False
|
|
|
|
early_stop_kl: Optional[float] = None # e.g. 0.1
|
|
early_stop_imp_ratio: Optional[float] = None # e.g., 10.0
|
|
|
|
adaptive_kl_ctl: bool = False
|
|
adaptive_kl_target: Optional[float] = 6
|
|
adaptive_kl_horizon: Optional[float] = 10000
|
|
|
|
enable_save: bool = True
|
|
|
|
value_norm: bool = False
|
|
value_norm_type: str = dataclasses.field(
|
|
metadata={"choices": ["exp", "ma"]}, default="exp"
|
|
)
|
|
value_norm_beta: float = 0.99995
|
|
value_norm_eps: float = 1e-5
|
|
|
|
group_size: int = 1
|
|
generation_size: Optional[int] = None
|
|
mask_no_eos_with_zero: bool = False
|
|
group_adv_norm: bool = False
|
|
mask_too_long: bool = False
|
|
use_dense_reward: bool = False
|
|
reward_delta: bool = True
|
|
token_normalize_scope: Literal["global", "dp"] = "global"
|
|
|
|
sample_reuse: int = 1
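    # How the sampling knobs interact (see `generate` below): the actor samples
    # `generation_size` completions per prompt (defaulting to `group_size` when
    # unset) and, if `generation_size > group_size`, keeps only the top
    # `group_size` completions per prompt by parsed score (best-of-k).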
|
|
|
|
def __post_init__(self):
|
|
if self.adaptive_kl_ctl:
|
|
assert self.adaptive_kl_target is not None
|
|
assert self.adaptive_kl_horizon is not None
|
|
self.kl_adapter = ppo_functional.AdaptiveKLController(
|
|
self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
|
|
)
|
|
else:
|
|
self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
|
|
if self.value_norm:
|
|
from realhf.impl.model.modules import (
|
|
ExponentialRunningMeanStd,
|
|
MovingAverageRunningMeanStd,
|
|
)
|
|
|
|
if self.value_norm_type == "exp":
|
|
self.rms = ExponentialRunningMeanStd(
|
|
beta=self.value_norm_beta, epsilon=self.value_norm_eps
|
|
)
|
|
elif self.value_norm_type == "ma":
|
|
self.rms = MovingAverageRunningMeanStd()
|
|
else:
|
|
raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
|
|
self.kl_ctl = None
|
|
|
|
self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
|
|
if self.generation_size is not None:
|
|
assert self.generation_size >= self.group_size
|
|
else:
|
|
self.generation_size = self.group_size
|
|
self.gconfig.n = self.generation_size
|
|
|
|
def save(self, model: model_api.Model, save_dir: str):
|
|
if not self.enable_save:
|
|
return
|
|
module = model.module
|
|
if not isinstance(module, ReaLModel):
|
|
module = module.module
|
|
module.save_to_hf(
|
|
tokenizer=model.tokenizer,
|
|
save_dir=save_dir,
|
|
)
|
|
|
|
@torch.no_grad()
|
|
def generate(
|
|
self,
|
|
model: model_api.Model,
|
|
input_: SequenceSample,
|
|
mb_spec: MicroBatchSpec,
|
|
) -> SequenceSample:
|
|
module = model.module
|
|
|
|
module.eval()
|
|
|
|
# Remap the key `packed_prompts` to `packed_input_ids`,
|
|
# because the pipe runner only recognizes `packed_input_ids`.
|
|
# x = SequenceSample.from_default(
|
|
# ids=input_.ids,
|
|
# seqlens=input_.seqlens["packed_prompts"],
|
|
# data=dict(packed_input_ids=input_.data["packed_prompts"]),
|
|
# )
|
|
|
|
packed_input_ids = input_.data["packed_prompts"]
|
|
new_input_ids = []
|
|
offset = 0
|
|
for x in input_.seqlens["packed_prompts"]:
|
|
new_input_ids += [
|
|
packed_input_ids[offset : offset + x[0]]
|
|
] * self.generation_size
|
|
offset += x[0]
|
|
assert offset == sum(x[0] for x in input_.seqlens["packed_prompts"])
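        # Example of the replication above: with 2 prompts and
        # generation_size=4, `new_input_ids` holds 8 prompt slices ordered as
        # [p0, p0, p0, p0, p1, p1, p1, p1].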
|
|
|
|
if model.backend_name not in ["vllm", "sglang"]:
|
|
# Replicate prompts
|
|
grouped_input = SequenceSample.from_default(
|
|
ids=list(range(input_.bs * self.generation_size)),
|
|
seqlens=[
|
|
x[0]
|
|
for x in input_.seqlens["packed_prompts"]
|
|
for _ in range(self.generation_size)
|
|
],
|
|
data=dict(packed_input_ids=torch.cat(new_input_ids)),
|
|
)
|
|
else:
|
|
grouped_input = SequenceSample(
|
|
ids=input_.ids,
|
|
seqlens=dict(packed_input_ids=input_.seqlens["packed_prompts"]),
|
|
keys=["packed_input_ids"],
|
|
dtypes=dict(packed_input_ids=torch.long),
|
|
trailing_shapes=dict(packed_input_ids=()),
|
|
data=dict(packed_input_ids=input_.data["packed_prompts"]),
|
|
)
|
|
|
|
res = module.generate(
|
|
input_=grouped_input,
|
|
tokenizer=model.tokenizer,
|
|
gconfig=self.gconfig,
|
|
mb_spec=mb_spec,
|
|
)
|
|
if res is None or res[0] is None:
|
|
return None
|
|
|
|
gen_tokens, logprobs, _ = res
|
|
|
|
pad_token_id = model.tokenizer.pad_token_id
|
|
eos_token_id = model.tokenizer.eos_token_id
|
|
seq_no_eos_mask = (gen_tokens[:, -1] != eos_token_id).logical_and(
|
|
gen_tokens[:, -1] != pad_token_id
|
|
)
|
|
        # gen_lengths should also include the EOS token, since the reward model
        # outputs the score for the sequence at that position.
|
|
gen_lengths = (gen_tokens != pad_token_id).logical_and(
|
|
gen_tokens != eos_token_id
|
|
).sum(dim=-1) + 1
|
|
gen_lengths = gen_lengths.clip(max=gen_tokens.shape[-1])
|
|
input_seq_lens = [
|
|
x for x in input_.seqlens["packed_prompts"] for _ in range(self.group_size)
|
|
]
|
|
input_token_ids = torch.cat(new_input_ids)
|
|
|
|
if self.generation_size is not None and self.generation_size > self.group_size:
|
|
|
|
            # Best-of-k: keep the top `group_size` of the `generation_size` samples.
|
|
query_ids = [
|
|
query_id for query_id in input_.ids for _ in range(self.generation_size)
|
|
]
|
|
scores = get_score(new_input_ids, gen_tokens, query_ids, model.tokenizer)
|
|
input_ids_topk, gen_tokens_topk, logprobs_topk, gen_lengths_topk = (
|
|
[],
|
|
[],
|
|
[],
|
|
[],
|
|
)
|
|
for data_idx in range(0, len(gen_tokens), self.generation_size):
|
|
topk_idx = topk(
|
|
scores[data_idx : data_idx + self.generation_size],
|
|
gen_lengths[data_idx : data_idx + self.generation_size],
|
|
self.group_size,
|
|
)
|
|
topk_idx = [data_idx + x for x in topk_idx]
|
|
gen_tokens_topk += gen_tokens[topk_idx]
|
|
logprobs_topk += logprobs[topk_idx]
|
|
gen_lengths_topk += gen_lengths[topk_idx]
|
|
input_ids_topk += [new_input_ids[x] for x in topk_idx]
|
|
|
|
input_token_ids = torch.cat(input_ids_topk)
|
|
|
|
gen_tokens = torch.stack(gen_tokens_topk)
|
|
logprobs = torch.stack(logprobs_topk)
|
|
gen_lengths = torch.stack(gen_lengths_topk)
|
|
seq_no_eos_mask = (gen_tokens[:, -1] != eos_token_id).logical_and(
|
|
gen_tokens[:, -1] != pad_token_id
|
|
)
|
|
|
|
(
|
|
packed_input_ids,
|
|
packed_logprobs,
|
|
_,
|
|
seq_lengths,
|
|
prompt_mask,
|
|
) = concat_prompt_to_generation_output(
|
|
packed_prompts=input_token_ids,
|
|
prompt_lengths=torch.tensor(flat2d(input_seq_lens)).to(model.device),
|
|
gen_tokens=gen_tokens,
|
|
logprobs=logprobs,
|
|
logits_mask=None,
|
|
gen_lengths=gen_lengths,
|
|
)
|
|
|
|
# Partition generated data into groups.
|
|
seqlens = [
|
|
seq_lengths[i * self.group_size : (i + 1) * self.group_size]
|
|
.cpu()
|
|
.numpy()
|
|
.tolist()
|
|
for i in range(input_.bs)
|
|
]
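        # Each entry of `seqlens` collects the lengths of the `group_size`
        # sequences generated for one original prompt, so downstream code can
        # treat them as a group.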
|
|
|
|
data = dict(
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
packed_input_ids=packed_input_ids,
|
|
packed_logprobs=packed_logprobs,
|
|
prompt_mask=prompt_mask,
|
|
)
|
|
|
|
res = SequenceSample(
|
|
keys=[
|
|
"packed_input_ids",
|
|
"prompt_mask",
|
|
"packed_logprobs",
|
|
"seq_no_eos_mask",
|
|
],
|
|
trailing_shapes=dict(
|
|
packed_input_ids=(),
|
|
prompt_mask=(),
|
|
packed_logprobs=(),
|
|
seq_no_eos_mask=(),
|
|
),
|
|
dtypes=dict(
|
|
packed_input_ids=torch.long,
|
|
prompt_mask=torch.bool,
|
|
packed_logprobs=torch.float,
|
|
seq_no_eos_mask=torch.bool,
|
|
),
|
|
seqlens=dict(
|
|
packed_input_ids=seqlens,
|
|
packed_logprobs=[[x - 1 for x in slens] for slens in seqlens],
|
|
prompt_mask=seqlens,
|
|
seq_no_eos_mask=[[1] * self.group_size for _ in seqlens],
|
|
),
|
|
data=data,
|
|
            ids=input_.ids,
        )
|
|
|
|
return res
|
|
|
|
@torch.no_grad()
|
|
def inference(
|
|
self,
|
|
model: model_api.Model,
|
|
input_: SequenceSample,
|
|
mb_spec: MicroBatchSpec,
|
|
) -> SequenceSample:
|
|
module = model.module
|
|
module.eval()
|
|
|
|
# This post_hook will gather log probabilities in mini-batches,
|
|
# reducing peak memory usage.
|
|
def calc_logprobs(logits, input_):
|
|
logits /= self.gconfig.temperature
|
|
|
|
input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
|
|
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
|
|
|
|
logprobs = gather_packed_shifted_log_probs(
|
|
logits, cu_seqlens, input_.data["packed_input_ids"]
|
|
)
|
|
return logprobs
|
|
|
|
input_flattend = SequenceSample.from_default(
|
|
ids=list(range(input_.bs * self.group_size)),
|
|
seqlens=flat2d(input_.seqlens["packed_input_ids"]),
|
|
data=dict(packed_input_ids=input_.data["packed_input_ids"]),
|
|
)
|
|
# add posthook to avoid storing full logits
|
|
logprobs = module.forward(
|
|
input_=input_flattend,
|
|
post_hook=calc_logprobs,
|
|
output_seqlens=[
|
|
[x - 1 for x in slens]
|
|
for slens in input_flattend.seqlens["packed_input_ids"]
|
|
],
|
|
mb_spec=mb_spec,
|
|
)
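        # `output_seqlens` are one token shorter than the input lengths because
        # log-probs are only defined for next-token predictions (positions
        # 1..L-1 of each sequence).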
|
|
|
|
res = SequenceSample(
|
|
keys=["logprobs"],
|
|
ids=input_.ids,
|
|
dtypes=dict(logprobs=model.module.dtype),
|
|
trailing_shapes=dict(logprobs=()),
|
|
data=dict(logprobs=logprobs),
|
|
seqlens=dict(
|
|
logprobs=[
|
|
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
|
|
]
|
|
),
|
|
)
|
|
|
|
return res
|
|
|
|
def train_step(
|
|
self,
|
|
model: model_api.Model,
|
|
input_: SequenceSample,
|
|
mb_spec: MicroBatchSpec,
|
|
) -> Dict | List[Dict]:
|
|
module = model.module
|
|
        # We call module.eval() because dropout would make the computed log probs incorrect.
|
|
module.eval()
|
|
|
|
prompt_mask = input_.data["prompt_mask"]
|
|
input_lens = torch.tensor(
|
|
flat2d(input_.seqlens["packed_input_ids"]), device=model.device
|
|
)
|
|
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
|
|
prompt_lens = []
|
|
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
|
prompt_lens.append(prompt_mask[s:e].sum())
|
|
prompt_lens = torch.tensor(prompt_lens, device=model.device)
|
|
reward_score = input_.data["rewards"].float()
|
|
task_ids = input_.data["task_ids"]
|
|
task_ids = task_ids.repeat(self.group_size, 1).transpose(0, 1).reshape(-1)
|
|
|
|
if "dense_rewards" in input_.data:
|
|
dense_reward_score = input_.data["dense_rewards"].float()
|
|
if not self.disable_value:
|
|
values = input_.data["values"].float()
|
|
else:
|
|
values = torch.zeros_like(
|
|
input_.data["packed_input_ids"], dtype=torch.float32
|
|
)
|
|
seq_no_eos_mask = input_.data["seq_no_eos_mask"]
|
|
if self.kl_adapter.value == 0:
|
|
ref_logp: torch.FloatTensor = reward_score.new_zeros(
|
|
int(input_lens.sum()) - len(input_lens)
|
|
)
|
|
else:
|
|
ref_logp: torch.FloatTensor = input_.data["packed_ref_logprobs"].float()
|
|
old_logp: torch.FloatTensor = input_.data["packed_logprobs"].float()
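        # Both `ref_logp` and `old_logp` have one entry per *shifted* token,
        # i.e. `sum(input_lens) - len(input_lens)` entries in total, which is
        # why the zero-filled fallback above is allocated with that size.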
|
|
|
|
if not self.disable_value:
|
|
if self.value_norm:
|
|
denormalized_values = self.rms.denormalize(values)
|
|
else:
|
|
denormalized_values = values
|
|
else:
|
|
denormalized_values = values
|
|
|
|
for i in range(seq_no_eos_mask.shape[0]):
|
|
if not seq_no_eos_mask[i]:
|
|
# Set value at the EOS token to be zero.
|
|
denormalized_values[cu_seqlens[i + 1] - 1] = 0.0
|
|
values[cu_seqlens[i + 1] - 1] = 0.0
|
|
|
|
        # Shift the loss mask by one token for each packed sequence.
|
|
short1cu_seqlens = cu_seqlens.clone()
|
|
short1cu_seqlens[1:] -= torch.ones_like(cu_seqlens[1:]).cumsum(0)
|
|
loss_mask = prompt_mask.logical_not()
|
|
|
|
if self.mask_too_long:
|
|
for i in range(seq_no_eos_mask.shape[0]):
|
|
if seq_no_eos_mask[i]:
|
|
loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]] = False
|
|
|
|
shift_one_indices = torch.cat(
|
|
[
|
|
torch.arange(
|
|
cu_seqlens[i] + 1,
|
|
cu_seqlens[i + 1],
|
|
dtype=torch.long,
|
|
device=cu_seqlens.device,
|
|
)
|
|
for i in range(cu_seqlens.shape[0] - 1)
|
|
]
|
|
)
|
|
loss_mask = loss_mask[shift_one_indices]
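        # `shift_one_indices` drops the first position of every sequence so
        # that the mask lines up with the shifted log-probs: the label for
        # token t is token t+1, hence per-token quantities live on positions
        # 1..L-1 of each packed sequence.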
|
|
|
|
# Apply the mask to log probabilities.
|
|
ref_logp *= loss_mask
|
|
old_logp *= loss_mask
|
|
|
|
# Compute rewards and GAEs.
|
|
if self.use_dense_reward:
|
|
kl_rewards, rewards = ppo_functional.get_packed_reward_dense(
|
|
kl_ctl=self.kl_adapter.value,
|
|
clip_reward_value=self.max_reward_clip,
|
|
log_probs=old_logp,
|
|
ref_log_probs=ref_logp,
|
|
dense_reward_score=dense_reward_score,
|
|
short1cu_seqlens=short1cu_seqlens,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
reward_delta=self.reward_delta,
|
|
)
|
|
else:
|
|
kl_rewards, rewards = ppo_functional.get_packed_rewards(
|
|
kl_ctl=self.kl_adapter.value,
|
|
clip_reward_value=self.max_reward_clip,
|
|
log_probs=old_logp,
|
|
ref_log_probs=ref_logp,
|
|
reward_score=(reward_score),
|
|
short1cu_seqlens=short1cu_seqlens,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
|
)
|
|
advantages, returns = ppo_functional.get_packed_advantages_and_returns(
|
|
gamma=self.discount,
|
|
lam=self.gae_lambda,
|
|
values=(
|
|
denormalized_values
|
|
if not self.disable_value
|
|
else denormalized_values.new_zeros(denormalized_values.shape)
|
|
),
|
|
rewards=rewards,
|
|
short1cu_seqlens=short1cu_seqlens,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
)
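        # A minimal sketch of the GAE recursion assumed to be implemented by
        # `get_packed_advantages_and_returns` (standard formulation):
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lam * A_{t+1}
        #   R_t     = A_t + V(s_t)
        # with gamma=self.discount and lam=self.gae_lambda.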
|
|
|
|
# Optionally perform normalization.
|
|
if self.value_norm:
|
|
self.rms.update(returns, mask=loss_mask)
|
|
if self.adv_norm:
|
|
            if not self.group_adv_norm:
|
|
advantages = masked_normalization(advantages, loss_mask)
|
|
else:
|
|
logger.info(f"adv_shape: {advantages.shape}")
|
|
logger.info(f"prompt_mask_shape: {prompt_mask.shape}")
|
|
n_samples = len(cu_seqlens) - 1
|
|
assert n_samples % self.group_size == 0
|
|
adv_list = []
|
|
for i in range(0, n_samples, self.group_size):
|
|
for j in range(1, self.group_size):
|
|
assert (
|
|
prompt_mask[cu_seqlens[i] : cu_seqlens[i + 1]].sum()
|
|
== prompt_mask[
|
|
cu_seqlens[i + j] : cu_seqlens[i + j + 1]
|
|
].sum()
|
|
)
|
|
adv_list.append(
|
|
masked_normalization(
|
|
advantages[
|
|
short1cu_seqlens[i] : short1cu_seqlens[
|
|
i + self.group_size
|
|
]
|
|
],
|
|
loss_mask[
|
|
short1cu_seqlens[i] : short1cu_seqlens[
|
|
i + self.group_size
|
|
]
|
|
],
|
|
all_reduce=False,
|
|
)
|
|
)
|
|
|
|
advantages = torch.cat(adv_list, 0)
|
|
|
|
        # Prepare data to be split into mini-batches.
|
|
flat_data = dict(
|
|
advantages=advantages,
|
|
old_logp=old_logp,
|
|
ppo_loss_mask=loss_mask,
|
|
packed_input_ids=input_.data["packed_input_ids"],
|
|
kl_rewards=kl_rewards,
|
|
)
|
|
use_prox_logp = "proximal_logprobs" in input_.data
|
|
if use_prox_logp:
|
|
flat_data["prox_logp"] = input_.data["proximal_logprobs"].float()
|
|
|
|
flat_input = SequenceSample.from_default(
|
|
ids=list(range(input_.bs * self.group_size)),
|
|
data=flat_data,
|
|
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
|
|
)
|
|
|
|
if self.use_dense_reward:
|
|
dense_reward_score = dense_reward_score[shift_one_indices]
|
|
|
|
### Logging code starts. ###
|
|
all_stats = []
|
|
with stats_tracker.scope("ppo_actor"):
|
|
assert (
|
|
task_ids.shape == reward_score.shape
|
|
), f"task_ids ({task_ids.shape}) and reward_score ({reward_score.shape}) must have the same shape"
|
|
|
|
task_denominators = {
|
|
f"{task}_n_seqs": (task_ids == idx).bool()
|
|
for idx, task in enumerate(RL_TASKS)
|
|
}
|
|
|
|
global_denominators = dict(
|
|
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
|
|
n_tokens=torch.ones_like(prompt_mask, dtype=torch.bool),
|
|
n_valid_tokens=loss_mask.bool(),
|
|
**task_denominators,
|
|
)
|
|
stats_tracker.denominator(**global_denominators)
|
|
|
|
for task in RL_TASKS:
|
|
stats_tracker.stat(
|
|
**{f"{task}_reward": reward_score}, denominator=f"{task}_n_seqs"
|
|
)
|
|
|
|
stats = dict(
|
|
advantages=advantages,
|
|
kl_rewards=kl_rewards,
|
|
final_reward=rewards,
|
|
)
|
|
if self.use_dense_reward:
|
|
stats["dense_reward"] = dense_reward_score
|
|
stats_tracker.stat(**stats, denominator="n_valid_tokens")
|
|
|
|
seq_stats = dict(
|
|
no_eos_ratios=seq_no_eos_mask.float(),
|
|
task_reward=reward_score,
|
|
prompt_len=prompt_lens.float(),
|
|
seq_len=input_lens.float(),
|
|
)
|
|
if "version_start" in input_.data:
|
|
seq_stats["head_offpolicyness"] = (
|
|
model.version.global_step - input_.data["version_start"]
|
|
).float()
|
|
if "version_end" in input_.data:
|
|
seq_stats["tail_offpolicyness"] = (
|
|
model.version.global_step - input_.data["version_end"]
|
|
).float()
|
|
stats_tracker.stat(
|
|
**seq_stats,
|
|
denominator="n_seqs",
|
|
)
|
|
scalars = dict(
|
|
disable_value=self.disable_value,
|
|
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
|
eps_clip=self.eps_clip,
|
|
use_prox_logp=use_prox_logp,
|
|
)
|
|
if self.c_clip is not None:
|
|
scalars["c_clip"] = self.c_clip
|
|
scalars["use_dual_clip"] = 1
|
|
else:
|
|
scalars["use_dual_clip"] = 0
|
|
stats_tracker.scalar(**scalars)
|
|
|
|
global_stats = stats_tracker.export()
|
|
for k in global_denominators:
|
|
global_stats.pop(f"ppo_actor/{k}")
|
|
|
|
# Run mini-batched PPO training!
|
|
def _loss_fn(logits, input_):
|
|
return _ppo_actor_loss_from_model_outputs(
|
|
logits,
|
|
input_,
|
|
kl_adapter=self.kl_adapter,
|
|
eps_clip=self.eps_clip,
|
|
early_stop_imp_ratio=self.early_stop_imp_ratio,
|
|
early_stop_kl=self.early_stop_kl,
|
|
c_clip=self.c_clip,
|
|
temperature=self.gconfig.temperature,
|
|
)
|
|
|
|
for reuse in range(self.sample_reuse):
|
|
# NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
|
|
flat_input = SequenceSample.shuffled(flat_input)
|
|
bs = flat_input.bs
|
|
sizes = [0 for _ in range(self.n_minibatches)]
|
|
for idx in range(bs):
|
|
sizes[idx % self.n_minibatches] += 1
|
|
spec = SequenceSplitSpec(sizes=sizes)
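            # Example of the round-robin split above: bs=10, n_minibatches=4
            # gives sizes [3, 3, 2, 2].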
|
|
datas = flat_input.split_with_spec(spec)
|
|
logger.info(
|
|
f"PPO minibatch split (size {self.n_minibatches}): "
|
|
f"#seqs: {[s.bs for s in datas]}, "
|
|
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
|
|
)
|
|
for mb_i, data in enumerate(datas):
|
|
train_stat = module.train_batch(
|
|
input_=data,
|
|
mb_spec=mb_spec,
|
|
version_steps=model.version.global_step,
|
|
loss_fn=_loss_fn,
|
|
loss_weight_fn=lambda x: x.data[
|
|
"ppo_loss_mask"
|
|
].count_nonzero(),
|
|
token_normalize_scope=self.token_normalize_scope,
|
|
)
|
|
stats_tracker.scalar(**train_stat)
|
|
all_stats.append(stats_tracker.export())
|
|
|
|
model.inc_version()
|
|
all_stats[0].update(global_stats)
|
|
|
|
return all_stats
|
|
|
|
# Mock methods for profiling only.
|
|
def _mock_inference(
|
|
self,
|
|
model: model_api.Model,
|
|
dataset_input: SequenceSample,
|
|
) -> SequenceSample:
|
|
prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
|
|
seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
|
|
module = model.module
|
|
if not isinstance(module, ReaLModel):
|
|
module = module.module
|
|
mconfig = module.config
|
|
packed_input_ids = torch.randint(
|
|
0,
|
|
mconfig.vocab_size,
|
|
(sum(seqlens),),
|
|
dtype=torch.long,
|
|
device=model.device,
|
|
)
|
|
|
|
return SequenceSample.from_default(
|
|
seqlens=seqlens,
|
|
ids=dataset_input.ids,
|
|
data=dict(packed_input_ids=packed_input_ids),
|
|
)
|
|
|
|
# Mock methods for profiling only.
|
|
def _mock_train_step(
|
|
self,
|
|
model: model_api.Model,
|
|
dataset_input: SequenceSample,
|
|
) -> Dict:
|
|
prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
|
|
bs = len(prompt_lens)
|
|
seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
|
|
module = model.module
|
|
if not isinstance(module, ReaLModel):
|
|
module = module.module
|
|
mconfig = module.config
|
|
mdtype = module.dtype
|
|
short1_seqlens = [x - 1 for x in seqlens]
|
|
|
|
packed_logprobs = torch.randn(
|
|
(sum(short1_seqlens),), dtype=mdtype, device=model.device
|
|
)
|
|
packed_ref_logprobs = torch.randn_like(packed_logprobs)
|
|
prompt_mask = torch.zeros(
|
|
(sum(seqlens),), dtype=torch.bool, device=model.device
|
|
)
|
|
packed_input_ids = torch.randint(
|
|
0,
|
|
mconfig.vocab_size,
|
|
(sum(seqlens),),
|
|
dtype=torch.long,
|
|
device=model.device,
|
|
)
|
|
rewards = torch.randn(bs, dtype=mdtype, device=model.device)
|
|
seq_no_eos_mask = torch.randint(
|
|
0, 2, (bs,), dtype=torch.bool, device=model.device
|
|
)
|
|
values = torch.randn(
|
|
(sum(seqlens),),
|
|
dtype=mdtype,
|
|
device=model.device,
|
|
)
|
|
|
|
return SequenceSample.from_default(
|
|
seqlens=seqlens,
|
|
ids=dataset_input.ids,
|
|
data=dict(
|
|
packed_logprobs=packed_logprobs,
|
|
packed_ref_logprobs=packed_ref_logprobs,
|
|
prompt_mask=prompt_mask,
|
|
packed_input_ids=packed_input_ids,
|
|
rewards=rewards,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
values=values,
|
|
),
|
|
)
|
|
|
|
|
|
def _ppo_critic_loss_from_model_outputs(
|
|
new_values: torch.FloatTensor,
|
|
input_: SequenceSample,
|
|
value_eps_clip: float,
|
|
kl_adapter: ppo_functional.KLController,
|
|
rms=None,
|
|
) -> torch.Tensor:
|
|
|
|
cu_seqlens = (
|
|
torch.nn.functional.pad(
|
|
torch.tensor(flat2d(input_.seqlens["packed_input_ids"])).cumsum(0),
|
|
(1, 0),
|
|
)
|
|
.int()
|
|
.to(new_values.device)
|
|
)
|
|
ppo_loss_mask = input_.data["ppo_loss_mask"]
|
|
returns = input_.data["returns"]
|
|
values = input_.data["values"]
|
|
kl_rewards = input_.data["kl_rewards"]
|
|
|
|
leave_one_indices = torch.cat(
|
|
[
|
|
torch.arange(
|
|
cu_seqlens[i],
|
|
cu_seqlens[i + 1] - 1,
|
|
dtype=torch.long,
|
|
device=cu_seqlens.device,
|
|
)
|
|
for i in range(cu_seqlens.shape[0] - 1)
|
|
]
|
|
)
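    # `leave_one_indices` drops the last position of every sequence so that the
    # predicted values align with the per-token returns, which have length
    # seqlen - 1 per sequence.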
|
|
new_values = new_values[leave_one_indices].view(-1).float()
|
|
values = values[leave_one_indices].view(-1).float()
|
|
|
|
loss, loss_stat = ppo_functional.critic_loss_fn(
|
|
value=new_values,
|
|
old_value=values,
|
|
target_value=returns,
|
|
value_eps_clip=value_eps_clip,
|
|
loss_mask=ppo_loss_mask,
|
|
)
|
|
|
|
if rms is not None:
|
|
denormalized_values = rms.denormalize(new_values)
|
|
else:
|
|
denormalized_values = new_values
|
|
|
|
# Logging.
|
|
stats_tracker.denominator(n_valid_tokens=ppo_loss_mask.bool())
|
|
stats_tracker.stat(
|
|
value_loss=loss_stat["loss"],
|
|
clip_ratio=loss_stat["clip_mask"].float(),
|
|
denormalized_values=denormalized_values.detach().float(),
|
|
denominator="n_valid_tokens",
|
|
)
|
|
|
|
# Update KL coefficient to be consistent with actor.
|
|
mean_ref_kl = (kl_rewards.detach().float() * ppo_loss_mask).sum()
|
|
dist.all_reduce(mean_ref_kl, group=constants.data_parallel_group())
|
|
_n_valid_tokens = ppo_loss_mask.count_nonzero().clone()
|
|
dist.all_reduce(_n_valid_tokens, group=constants.data_parallel_group())
|
|
mean_ref_kl /= _n_valid_tokens
|
|
kl_adapter.update(mean_ref_kl, n_steps=cu_seqlens.shape[0] - 1)
|
|
|
|
return loss
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class PPOCriticInterface(model_api.ModelInterface):
|
|
n_minibatches: int = 4
|
|
enable_save: bool = True
|
|
kl_ctl: float = 0.1
|
|
discount: float = 1.0
|
|
gae_lambda: float = 0.95
|
|
value_eps_clip: float = 0.2
|
|
max_reward_clip: float = 5.0
|
|
adaptive_kl_ctl: bool = False
|
|
adaptive_kl_target: Optional[float] = 6
|
|
adaptive_kl_horizon: Optional[float] = 10000
|
|
value_norm: bool = False
|
|
value_norm_type: str = dataclasses.field(
|
|
metadata={"choices": ["exp", "ma"]}, default="exp"
|
|
)
|
|
value_norm_beta: float = 0.99995
|
|
value_norm_eps: float = 1e-5
|
|
disable_value: bool = False
|
|
|
|
group_size: int = 1
|
|
mask_no_eos_with_zero: bool = False
|
|
mask_too_long: bool = False
|
|
use_dense_reward: bool = False
|
|
reward_delta: bool = True
|
|
token_normalize_scope: Literal["global", "dp"] = "global"
|
|
|
|
sample_reuse: int = 1
|
|
|
|
def __post_init__(self):
|
|
if self.adaptive_kl_ctl:
|
|
assert self.adaptive_kl_target is not None
|
|
assert self.adaptive_kl_horizon is not None
|
|
self.kl_adapter = ppo_functional.AdaptiveKLController(
|
|
self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
|
|
)
|
|
else:
|
|
self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
|
|
if self.value_norm:
|
|
from realhf.impl.model.modules import (
|
|
ExponentialRunningMeanStd,
|
|
MovingAverageRunningMeanStd,
|
|
)
|
|
|
|
if self.value_norm_type == "exp":
|
|
self.rms = ExponentialRunningMeanStd(
|
|
beta=self.value_norm_beta, epsilon=self.value_norm_eps
|
|
)
|
|
elif self.value_norm_type == "ma":
|
|
self.rms = MovingAverageRunningMeanStd()
|
|
else:
|
|
raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
|
|
self.kl_ctl = None
|
|
|
|
def save(self, model: model_api.Model, save_dir: str):
|
|
if not self.enable_save:
|
|
return
|
|
module = model.module
|
|
if not isinstance(module, ReaLModel):
|
|
module = module.module
|
|
module.save_to_hf(
|
|
tokenizer=model.tokenizer,
|
|
save_dir=save_dir,
|
|
)
|
|
|
|
@torch.no_grad()
|
|
def inference(
|
|
self,
|
|
model: model_api.Model,
|
|
input_: SequenceSample,
|
|
mb_spec: MicroBatchSpec,
|
|
) -> SequenceSample:
|
|
assert model.module.module.config.is_critic
|
|
module = model.module
|
|
module.eval()
|
|
|
|
input_flattend = SequenceSample.from_default(
|
|
ids=list(range(input_.bs * self.group_size)),
|
|
seqlens=flat2d(input_.seqlens["packed_input_ids"]),
|
|
data=dict(packed_input_ids=input_.data["packed_input_ids"]),
|
|
)
|
|
if self.disable_value:
|
|
scores = input_.data["packed_input_ids"].new_zeros(dtype=module.dtype)
|
|
else:
|
|
scores = module.forward(input_=input_flattend, mb_spec=mb_spec)
|
|
|
|
if scores is None:
|
|
return None
|
|
scores = scores.view(-1)
|
|
# res = SequenceSample.from_default(
|
|
# ids=input_.ids,
|
|
# data=dict(values=scores),
|
|
# seqlens=input_.seqlens["packed_input_ids"],
|
|
# )
|
|
res = SequenceSample(
|
|
keys=["values"],
|
|
ids=input_.ids,
|
|
dtypes=dict(values=module.dtype),
|
|
trailing_shapes=dict(values=()),
|
|
data=dict(values=scores),
|
|
seqlens=dict(values=input_.seqlens["packed_input_ids"]),
|
|
)
|
|
|
|
return res
|
|
|
|
def train_step(
|
|
self,
|
|
model: model_api.Model,
|
|
input_: SequenceSample,
|
|
mb_spec: MicroBatchSpec,
|
|
) -> Dict | List[Dict]:
|
|
assert model.module.module.config.is_critic
|
|
|
|
if self.disable_value:
|
|
return dict()
|
|
|
|
module = model.module
|
|
tokenizer = model.tokenizer
|
|
        # We call module.eval() because dropout would make the computed log probs incorrect.
|
|
module.eval()
|
|
|
|
prompt_mask = input_.data["prompt_mask"]
|
|
input_lens = torch.tensor(
|
|
flat2d(input_.seqlens["packed_input_ids"]), device=model.device
|
|
)
|
|
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
|
|
reward_score = input_.data["rewards"].float()
|
|
if "dense_rewards" in input_.data:
|
|
dense_reward_score = input_.data["dense_rewards"].float()
|
|
values = input_.data["values"].float()
|
|
seq_no_eos_mask = input_.data["seq_no_eos_mask"]
|
|
if self.kl_adapter.value == 0:
|
|
ref_logp: torch.FloatTensor = reward_score.new_zeros(
|
|
int(input_lens.sum()) - len(input_lens)
|
|
)
|
|
else:
|
|
ref_logp: torch.FloatTensor = input_.data["packed_ref_logprobs"].float()
|
|
old_logp: torch.FloatTensor = input_.data["packed_logprobs"].float()
|
|
|
|
if self.value_norm:
|
|
denormalized_values = self.rms.denormalize(values)
|
|
else:
|
|
denormalized_values = values
|
|
|
|
for i in range(seq_no_eos_mask.shape[0]):
|
|
if not seq_no_eos_mask[i]:
|
|
# Set value at the EOS token to be zero.
|
|
denormalized_values[cu_seqlens[i + 1] - 1] = 0.0
|
|
values[cu_seqlens[i + 1] - 1] = 0.0
|
|
|
|
        # Shift the loss mask by one token for each packed sequence.
|
|
input_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
|
short1cu_seqlens = cu_seqlens.clone()
|
|
short1cu_seqlens[1:] -= torch.ones_like(cu_seqlens[1:]).cumsum(0)
|
|
loss_mask = prompt_mask.logical_not()
|
|
|
|
if self.mask_too_long:
|
|
for i in range(seq_no_eos_mask.shape[0]):
|
|
if seq_no_eos_mask[i]:
|
|
loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]] = False
|
|
|
|
shift_one_indices = torch.cat(
|
|
[
|
|
torch.arange(
|
|
cu_seqlens[i] + 1,
|
|
cu_seqlens[i + 1],
|
|
dtype=torch.long,
|
|
device=cu_seqlens.device,
|
|
)
|
|
for i in range(cu_seqlens.shape[0] - 1)
|
|
]
|
|
)
|
|
loss_mask = loss_mask[shift_one_indices]
|
|
|
|
# Apply the mask to log probabilities.
|
|
ref_logp *= loss_mask
|
|
old_logp *= loss_mask
|
|
|
|
# Compute rewards and GAEs.
|
|
if self.use_dense_reward:
|
|
kl_rewards, rewards = ppo_functional.get_packed_reward_dense(
|
|
kl_ctl=self.kl_adapter.value,
|
|
clip_reward_value=self.max_reward_clip,
|
|
log_probs=old_logp,
|
|
ref_log_probs=ref_logp,
|
|
dense_reward_score=dense_reward_score,
|
|
short1cu_seqlens=short1cu_seqlens,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
reward_delta=self.reward_delta,
|
|
)
|
|
else:
|
|
kl_rewards, rewards = ppo_functional.get_packed_rewards(
|
|
kl_ctl=self.kl_adapter.value,
|
|
clip_reward_value=self.max_reward_clip,
|
|
log_probs=old_logp,
|
|
ref_log_probs=ref_logp,
|
|
reward_score=(reward_score),
|
|
short1cu_seqlens=short1cu_seqlens,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
|
)
|
|
_, returns = ppo_functional.get_packed_advantages_and_returns(
|
|
gamma=self.discount,
|
|
lam=self.gae_lambda,
|
|
values=denormalized_values,
|
|
rewards=rewards,
|
|
short1cu_seqlens=short1cu_seqlens,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
)
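        # Only the returns are needed here; the advantages computed alongside
        # them are used by the actor interface, not by the critic.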
|
|
|
|
# Optionally perform normalization.
|
|
if self.value_norm:
|
|
self.rms.update(returns, mask=loss_mask)
|
|
normalized_returns = self.rms.normalize(returns)
|
|
else:
|
|
normalized_returns = returns
|
|
|
|
        # Prepare data to be split into mini-batches.
|
|
flat_input = SequenceSample.from_default(
|
|
ids=list(range(input_.bs * self.group_size)),
|
|
data=dict(
|
|
returns=normalized_returns,
|
|
values=values,
|
|
ppo_loss_mask=loss_mask,
|
|
packed_input_ids=input_.data["packed_input_ids"],
|
|
kl_rewards=kl_rewards,
|
|
),
|
|
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
|
|
)
|
|
|
|
# Logging.
|
|
with stats_tracker.scope("ppo_critic"):
|
|
stats_tracker.denominator(n_valid_tokens=loss_mask)
|
|
stats_tracker.stat(returns=returns, denominator="n_valid_tokens")
|
|
|
|
def _loss_fn(out, inp):
|
|
return _ppo_critic_loss_from_model_outputs(
|
|
out,
|
|
inp,
|
|
value_eps_clip=self.value_eps_clip,
|
|
kl_adapter=self.kl_adapter,
|
|
rms=None if not self.value_norm else self.rms,
|
|
)
|
|
|
|
# Run mini-batched PPO training!
|
|
for reuse in range(self.sample_reuse):
|
|
with stats_tracker.scope(f"reuse{reuse}"):
|
|
# NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
|
|
flat_input = SequenceSample.shuffled(flat_input)
|
|
bs = flat_input.bs
|
|
sizes = [0 for _ in range(self.n_minibatches)]
|
|
for idx in range(bs):
|
|
sizes[idx % self.n_minibatches] += 1
|
|
spec = SequenceSplitSpec(sizes=sizes)
|
|
datas = flat_input.split_with_spec(spec)
|
|
logger.info(
|
|
f"PPO minibatch split (size {self.n_minibatches}): "
|
|
f"#seqs: {[s.bs for s in datas]}, "
|
|
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
|
|
)
|
|
for mb_i, data in enumerate(datas):
|
|
with stats_tracker.scope(f"mb{mb_i}"):
|
|
stats = module.train_batch(
|
|
input_=data,
|
|
mb_spec=mb_spec,
|
|
version_steps=model.version.global_step,
|
|
loss_fn=_loss_fn,
|
|
loss_weight_fn=lambda x: x.data[
|
|
"ppo_loss_mask"
|
|
].count_nonzero(),
|
|
token_normalize_scope=self.token_normalize_scope,
|
|
)
|
|
stats_tracker.scalar(**stats)
|
|
|
|
model.inc_version()
|
|
|
|
return stats_tracker.export()
|
|
|
|
# Mock methods for profiling only.
|
|
def _mock_inference(
|
|
self,
|
|
model: model_api.Model,
|
|
dataset_input: SequenceSample,
|
|
) -> SequenceSample:
|
|
seqlens = flat2d(dataset_input.seqlens["packed_prompts"])
|
|
module = model.module
|
|
if not isinstance(module, ReaLModel):
|
|
module = module.module
|
|
mconfig = module.config
|
|
packed_input_ids = torch.randint(
|
|
0,
|
|
mconfig.vocab_size,
|
|
(sum(seqlens),),
|
|
dtype=torch.long,
|
|
device=model.device,
|
|
)
|
|
|
|
return SequenceSample.from_default(
|
|
seqlens=seqlens,
|
|
ids=dataset_input.ids,
|
|
data=dict(packed_input_ids=packed_input_ids),
|
|
)
|
|
|
|
# Mock methods for profiling only.
|
|
def _mock_train_step(
|
|
self,
|
|
model: model_api.Model,
|
|
dataset_input: SequenceSample,
|
|
) -> Dict:
|
|
seqlens = flat2d(dataset_input.seqlens["packed_prompts"])
|
|
bs = len(seqlens)
|
|
module = model.module
|
|
if not isinstance(module, ReaLModel):
|
|
module = module.module
|
|
mconfig = module.config
|
|
mdtype = module.dtype
|
|
short1_seqlens = [x - 1 for x in seqlens]
|
|
|
|
packed_logprobs = torch.randn(
|
|
(sum(short1_seqlens),), dtype=mdtype, device=model.device
|
|
)
|
|
packed_ref_logprobs = torch.randn_like(packed_logprobs)
|
|
prompt_mask = torch.zeros(
|
|
(sum(seqlens),), dtype=torch.bool, device=model.device
|
|
)
|
|
packed_input_ids = torch.randint(
|
|
0,
|
|
mconfig.vocab_size,
|
|
(sum(seqlens),),
|
|
dtype=torch.long,
|
|
device=model.device,
|
|
)
|
|
rewards = torch.randn(bs, dtype=mdtype, device=model.device)
|
|
seq_no_eos_mask = torch.randint(
|
|
0, 2, (bs,), dtype=torch.bool, device=model.device
|
|
)
|
|
values = torch.randn(
|
|
(sum(seqlens),),
|
|
dtype=mdtype,
|
|
device=model.device,
|
|
)
|
|
|
|
return SequenceSample.from_default(
|
|
seqlens=seqlens,
|
|
ids=dataset_input.ids,
|
|
data=dict(
|
|
packed_logprobs=packed_logprobs,
|
|
packed_ref_logprobs=packed_ref_logprobs,
|
|
prompt_mask=prompt_mask,
|
|
packed_input_ids=packed_input_ids,
|
|
rewards=rewards,
|
|
seq_no_eos_mask=seq_no_eos_mask,
|
|
values=values,
|
|
),
|
|
)
|
|
|
|
|
|
model_api.register_interface("ppo_actor", PPOActorInterface)
|
|
model_api.register_interface("ppo_critic", PPOCriticInterface)
|