mirror of https://github.com/inclusionAI/AReaL
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
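"""Supervised fine-tuning (SFT) interface for ReaL.

Implements the packed-sequence, prompt-masked cross-entropy loss and the
``SFTInterface`` used for training, evaluation, and HuggingFace-format saving.
"""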
from typing import Dict, Literal, Tuple

import torch
import torch.distributed as dist
import torch.utils.data
import tqdm

import realhf.api.core.model_api as model_api
import realhf.base.constants as constants
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
    build_shift_one_indices,
    gather_packed_shifted_log_probs,
)


def compute_packed_sft_loss(
    logits: torch.Tensor,
    input_: SequenceSample,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
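    """Token-level cross-entropy loss over a packed batch of concatenated sequences.

    Positions covered by ``prompt_mask`` are excluded, so only answer tokens
    contribute. Returns the loss (summed NLL divided by the number of answer
    tokens) together with detached statistics for logging.
    """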
    packed_input_ids: torch.Tensor = input_.data["packed_input_ids"]
    input_lens = torch.tensor(flat2d(input_.seqlens["packed_input_ids"]))
    cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
    prompt_mask = input_.data["prompt_mask"]

    shift_one_indices = build_shift_one_indices(logits, cu_seqlens)
    logprobs = gather_packed_shifted_log_probs(
        logits, cu_seqlens, packed_input_ids
    ).float()
    prompt_mask = prompt_mask[shift_one_indices]
    logprobs = torch.where(prompt_mask, 0, logprobs)

    loss_sum = -logprobs.sum()
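
    # Per-sequence perplexity and token/sequence counts are computed without
    # gradients; they are only used for logging.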
    with torch.no_grad():
        seqlogp = torch.zeros(
            cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
        )
        for i in range(cu_seqlens.shape[0] - 1):
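            # The shifted log-probs drop the final position of every sequence,
            # so sequence i occupies indices [cu_seqlens[i] - i, cu_seqlens[i + 1] - i - 1).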
            m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
            logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
            assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
                cu_seqlens,
                logprobs.shape,
            )
            seqlogp[i] = torch.where(m, 0.0, logp).sum() / (
                m.numel() - m.count_nonzero()
            )

        logging_ppl = (-seqlogp).exp().sum()
        token_denorm = prompt_mask.numel() - prompt_mask.count_nonzero()
        seq_denorm = torch.tensor(
            [cu_seqlens.shape[0] - 1], dtype=torch.float32, device=logits.device
        )

        # Logging loss and perplexity.
        logging_loss = loss_sum.detach().clone()
        logging_token_denorm = token_denorm.detach().clone().float()
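        # Sum the logging statistics across the data-parallel group so that
        # every rank reports the same global values.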
        dist.all_reduce(
            logging_ppl, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
        )
        dist.all_reduce(
            logging_loss,
            op=dist.ReduceOp.SUM,
            group=constants.data_parallel_group(),
        )
        dist.all_reduce(
            seq_denorm, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
        )
        dist.all_reduce(
            logging_token_denorm,
            op=dist.ReduceOp.SUM,
            group=constants.data_parallel_group(),
        )
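
    # Normalize the summed loss by the number of answer tokens in this batch.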
    loss = loss_sum / token_denorm
    return loss, {
        "loss": logging_loss,
        "ppl": logging_ppl,
        "n_tokens": logging_token_denorm,
        "n_seqs": seq_denorm,
    }


class SFTInterface(model_api.ModelInterface):
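    """Model interface for supervised fine-tuning (registered as "sft").

    ``token_normalize_scope`` selects whether the token count used to
    normalize the loss is accumulated over all data-parallel ranks
    ("global") or only over the local rank ("dp").
    """
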
    token_normalize_scope: Literal["global", "dp"] = "global"

    def train_step(
        self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec
    ) -> Dict:
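        """Run one SFT training step and return aggregated statistics."""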
        module = model.module

        module.train()

        stat = module.train_batch(
            input_=data,
            loss_fn=compute_packed_sft_loss,
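            # The loss weight of a micro-batch is its number of answer
            # (non-prompt) tokens.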
            loss_weight_fn=lambda x: x.data["prompt_mask"]
            .logical_not()
            .count_nonzero(),
            token_normalize_scope=self.token_normalize_scope,
            mb_spec=mb_spec,
            version_steps=model.version.global_step,
        )

        model.inc_version()

        res = dict()
        global_stats = constants.log_global_stats_tracker(
            return_dict=True, clear_stats_after_logging=True
        )
        if stat:
            res = dict(
                loss=float(stat["loss"]) / int(stat["n_tokens"]),
                ppl=float(stat["ppl"]) / int(stat["n_seqs"]),
                n_tokens=int(stat["n_tokens"]),
                n_seqs=int(stat["n_seqs"]),
                **global_stats,
            )
        return res

    def save(self, model: model_api.Model, save_dir: str):
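        """Save the model in HuggingFace format under ``save_dir``."""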
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        module.save_to_hf(
            tokenizer=model.tokenizer,
            save_dir=save_dir,
        )

    @torch.no_grad()
    def evaluate(
        self,
        model_: model_api.Model,
        eval_dataloader: torch.utils.data.DataLoader,
    ) -> Dict:
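        """Compute average SFT loss and perplexity over ``eval_dataloader``."""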
        device = model_.device
        module = model_.module

        module.eval()
        losses = n_seqs = ppl = n_tokens = 0
        res = None

        for step, x in enumerate(tqdm.tqdm(eval_dataloader)):
            x: SequenceSample

            res = module.eval_batch(
                input_=x.to_device(device),
                loss_fn=compute_packed_sft_loss,
                mb_spec=MicroBatchSpec(
                    n_mbs=constants.pipe_parallel_world_size(),
                    balanced_seqs=True,
                ),
            )
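
            # eval_batch may return None on ranks that do not produce the loss
            # (e.g., non-final pipeline-parallel stages); only aggregate real stats.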
            if res is not None:
                _, stat = res
                losses += stat["loss"]
                n_tokens += stat["n_tokens"]
                n_seqs += stat["n_seqs"]
                ppl += stat["ppl"]

        global_stats = constants.log_global_stats_tracker(
            return_dict=True, clear_stats_after_logging=True
        )
        if res is not None:
            return dict(
                loss=float(losses / n_tokens),
                ppl=float(ppl / n_seqs),
                n_tokens=int(n_tokens),
                n_seqs=int(n_seqs),
                **global_stats,
            )
        return dict()


model_api.register_interface("sft", SFTInterface)
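

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the ReaL interface): a toy, single-
# process re-derivation of the packed, prompt-masked SFT loss computed by
# `compute_packed_sft_loss`, using plain PyTorch instead of
# `build_shift_one_indices` / `gather_packed_shifted_log_probs` and skipping
# the distributed logging. The masking convention (mask a position whose label
# token belongs to the prompt) is an assumption made for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, seqlens = 11, [4, 3]  # two sequences packed back to back
    total_len = sum(seqlens)
    demo_logits = torch.randn(total_len, vocab_size)
    demo_input_ids = torch.randint(0, vocab_size, (total_len,))
    # The first two tokens of each sequence are treated as prompt tokens.
    demo_prompt_mask = torch.tensor([True, True, False, False, True, True, False])

    logprob_all = demo_logits.log_softmax(-1)
    offsets = [0] + torch.tensor(seqlens).cumsum(0).tolist()
    nll, n_answer_tokens = 0.0, 0
    for start, end in zip(offsets[:-1], offsets[1:]):
        # Token t is predicted from position t - 1, so the last position of each
        # sequence produces no target and the first token is never a target.
        lp = logprob_all[start : end - 1].gather(
            1, demo_input_ids[start + 1 : end, None]
        )[:, 0]
        answer = ~demo_prompt_mask[start + 1 : end]
        nll = nll - lp[answer].sum()
        n_answer_tokens += int(answer.sum())
    print("per-answer-token SFT loss:", (nll / n_answer_tokens).item())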