AReaL/realhf/impl/model/backend/mock_train.py

# Copyright 2025 Ant Group Inc.
import collections
import dataclasses
import math
from contextlib import contextmanager
from typing import *
import torch
import torch.distributed as dist
import transformers
from realhf.api.core import model_api
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base import constants, logging
from realhf.base.datapack import flat2d
from realhf.impl.model.backend.inference import PipelinableInferenceEngine
from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet
from realhf.impl.model.modules.mlp import get_activation_fn
from realhf.impl.model.nn.flatten_param import ContiguousParamSpec
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_base import ReaLModelBlock
from realhf.impl.model.parallelism.pipeline_parallel.tensor_storage import TensorBuffer

logger = logging.getLogger("Mock Train Backend", "benchmark")

@dataclasses.dataclass
class MockPipeTrainInstrSet(PipeTrainInstrSet):
    """A trivial pipelined instruction set for training.

    Used for testing only.
    """

    optim: torch.optim.Optimizer

    def _exec_backward_pass(
        self,
        module: ReaLModel,
        tensor_buffer: TensorBuffer,
        stage_id: int,
        micro_batch_id: int,
        step_id: int,
    ):
        output_x = tensor_buffer.get("batch_output_x", micro_batch_id, remove=True)
        is_last_stage = constants.is_last_pipe_stage()
        if is_last_stage:
            # The last pipeline stage owns the scalar loss and starts backprop.
            loss: torch.Tensor = tensor_buffer.get(
                "losses", micro_batch_id, remove=True
            )
            loss.backward()
            tensor_buffer.put("losses", micro_batch_id, loss.detach().clone())
            return

        # Earlier stages backpropagate through their stage output using the
        # gradient received from the next stage.
        grad = tensor_buffer.get("grad", micro_batch_id, remove=True)
        output_tensor = output_x.pp_output
        torch.autograd.backward(tensors=output_tensor, grad_tensors=grad)
    def _exec_reduce_grads(
        self,
        module: ReaLModel,
        tensor_buffer: TensorBuffer,
        stage_id: int,
        micro_batch_id: int,
        step_id: int,
    ):
        for p in module.parameters():
            if not p.requires_grad:
                continue
            dist.all_reduce(p.grad, group=constants.data_parallel_group())

    def _exec_optimizer_step(
        self,
        module: ReaLModel,
        tensor_buffer: TensorBuffer,
        stage_id: int,
        micro_batch_id: int,
        step_id: int,
    ):
        self.optim.step()


class AdamWithLossScale(torch.optim.Adam):
    """Plain Adam with a trivial `get_loss_scale`; the mock always reports 1.0."""

    def get_loss_scale(self) -> torch.Tensor:
        return torch.tensor([1.0], device=constants.current_device())


class MockTrainEngine(model_api.PipelinableEngine):
    def __init__(self, module: ReaLModel, optimizer: AdamWithLossScale):
        self.module = module
        self.optim = optimizer

        self.inf_engine = PipelinableInferenceEngine(module)
        if constants.pipe_parallel_world_size() > 1:
            self.pipe_runner = self.inf_engine.pipe_runner

        self.device = module.device
        self.dtype = module.dtype

    def train(self, mode: bool = True):
        self.module.train(mode)
        return self

    def eval(self):
        self.module.eval()
        return self

    def train_batch(
        self,
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable,
        loss_weight_fn: Callable,
        token_normalize_scope: str,
        version_steps: int,
    ):
        self.optim.zero_grad()
        if constants.pipe_parallel_world_size() > 1:
            # Fuse the minibatched forward-backward passes into a pipelined
            # training schedule.
            instr_set = MockPipeTrainInstrSet(self.optim)
            # NOTE: When training with pipeline parallelism, the number of micro
            # batches should be larger than 2 x num_pipeline_stages to avoid
            # idle time.
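            # Illustration (assumed figures, based on the note above): with
            # pp_size == 4, mb_spec should yield at least 2 * 4 = 8 micro
            # batches per call so that no pipeline stage sits idle.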
            return self.pipe_runner.train_batch(
                instr_set=instr_set,
                input_=input_,
                mb_spec=mb_spec,
                loss_fn=loss_fn,
                loss_weight_fn=loss_weight_fn,
                token_normalize_scope=token_normalize_scope,
                version_steps=version_steps,
            )

        mb_inputs = input_.synced_data_parallel_split(mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
        if token_normalize_scope == "global":
            dist.all_reduce(total_loss_weight, group=constants.data_parallel_group())
        if total_loss_weight == 0:
            raise model_api.ZeroTotalLossWeightException(
                "The sum of loss weights of all micro batches is zero."
            )

        if constants.parallelism_rank() == 0:
            logger.info(
                f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
                f"#tokens: {input_.data['packed_input_ids'].shape[0]}, "
                f"pp_size={constants.pipe_parallel_world_size()}, "
                f"#tokens per mbs: {[mb.data['packed_input_ids'].shape[0] for mb in mb_inputs]}"
            )

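        # Worked example of the scaling applied in the loop below (illustrative
        # numbers, assuming loss_weight_fn counts tokens): with
        # token_normalize_scope == "global", dp_world_size == 2, and this rank's
        # micro batch carrying 30 of the 100 tokens summed across all data
        # parallel ranks, its loss is multiplied by 30 / 100 * 2 = 0.6.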
        stat = collections.defaultdict(int)
        for mb_input in mb_inputs:
            input_lens = torch.tensor(
                flat2d(mb_input.seqlens["packed_input_ids"]),
                dtype=torch.int32,
                device=self.device,
            )
            max_seqlen = int(max(input_lens))
            cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
            model_output = self.module(
                packed_input_ids=mb_input.data["packed_input_ids"],
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            ).logits
            loss, _stat = loss_fn(model_output, mb_input)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            if token_normalize_scope == "global":
                loss_scale *= constants.data_parallel_world_size()
            loss *= loss_scale
            # Accumulate gradients across micro batches.
            loss.backward()
            for k, v in _stat.items():
                stat[k] += v

        # Mirror the pipelined path: synchronize gradients across data parallel
        # ranks, then step the optimizer.
        for p in self.module.parameters():
            if p.requires_grad and p.grad is not None:
                dist.all_reduce(p.grad, group=constants.data_parallel_group())
        self.optim.step()
        return stat

    @torch.no_grad()
    def forward(
        self,
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
    ):
        return self.inf_engine.forward(
            input_=input_,
            mb_spec=mb_spec,
            post_hook=post_hook,
            aggregate_fn=aggregate_fn,
        )

    @torch.no_grad()
    def generate(
        self,
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
        tokenizer: transformers.PreTrainedTokenizerFast,
        gconfig: model_api.GenerationHyperparameters = dataclasses.field(
            default_factory=model_api.GenerationHyperparameters
        ),
    ):
        return self.inf_engine.generate(
            input_=input_,
            mb_spec=mb_spec,
            tokenizer=tokenizer,
            gconfig=gconfig,
        )


@dataclasses.dataclass
class MockTrainBackend(model_api.ModelBackend):
    optimizer_name: str = dataclasses.field(
        metadata={"choices": ["adam"]},
        default="adam",
    )
    optimizer_config: dict = dataclasses.field(
        default_factory=lambda: dict(
            lr=1e-5, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-5
        )
    )

    def _initialize(
        self, model: model_api.Model, spec: model_api.FinetuneSpec
    ) -> model_api.Model:
        module = model.module
        if not isinstance(module, ReaLModel):
            raise ValueError("MockTrainBackend only supports ReaLModel.")
        if self.optimizer_name == "adam":
            optimizer = AdamWithLossScale(module.parameters(), **self.optimizer_config)
        else:
            raise NotImplementedError(
                f"Optimizer {self.optimizer_name} not implemented for testing."
            )
        model.module = MockTrainEngine(module, optimizer)
        model.backend_name = "mock_train"
        return model


model_api.register_backend("mock_train", MockTrainBackend)
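
# Usage sketch (illustrative, not part of the original module): how the mock
# backend defined above can be constructed directly. The `model` and `spec`
# objects are assumed to be supplied by the surrounding worker code; only the
# backend construction and the `_initialize` call are grounded in this file.
#
#   backend = MockTrainBackend(
#       optimizer_name="adam",
#       optimizer_config=dict(lr=1e-5, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-5),
#   )
#   model = backend._initialize(model, spec)
#   # model.module is now a MockTrainEngine wrapping the original ReaLModel.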