PullRequest: 340 [lite] Refactor trainer API into utilities and remove mb_spec in engine methods

Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/340

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* fix
博惟 2025-07-10 11:10:10 +08:00 committed by 晓雷
parent 7be4ab0d18
commit c38cffc023
18 changed files with 541 additions and 538 deletions
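The net effect of this PR on call sites: MicroBatchSpec is no longer threaded through every engine method but is set once on the engine's config. A minimal before/after sketch (hypothetical call site with an already-initialized engine, batch, and loss functions; the config-attribute style matches the updated tests below):

from arealite.api.cli_args import MicroBatchSpec

# Before: the spec was a per-call argument.
# stats = engine.train_batch(
#     input_=batch,
#     mb_spec=MicroBatchSpec(max_tokens_per_mb=4096),
#     loss_fn=loss_fn,
#     loss_weight_fn=loss_weight_fn,
# )

# After: the spec lives on TrainEngineConfig and is read internally.
engine.config.mb_spec = MicroBatchSpec(max_tokens_per_mb=4096)
stats = engine.train_batch(
    input_=batch,
    loss_fn=loss_fn,
    loss_weight_fn=loss_weight_fn,
)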

View File

@ -187,7 +187,6 @@ class TrainEngine(abc.ABC):
def train_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
@ -197,7 +196,6 @@ class TrainEngine(abc.ABC):
def eval_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
@ -207,7 +205,6 @@ class TrainEngine(abc.ABC):
def forward(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
@ -323,7 +320,7 @@ Extended engines (such as Actor in PPO) provide convenient organization and call
class Actor(Engine):
@torch.no_grad()
def compute_logps(self, input_: Dict[str, Tensor], mb_spec: MicroBatchSpec) -> torch.Tensor:
def compute_logps(self, input_: Dict[str, Tensor]) -> torch.Tensor:
... # unpad
logps = self.forward(xxx)
... # pad back
@ -332,8 +329,7 @@ class Actor(Engine):
def compute_advantages_and_returns(self, input_: Dict) -> Dict:
pass
def ppo_update(self, input_: Dict,
mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
...
all_stats = []
for _ in range(self.ppo_n_minibatches):
@ -344,11 +340,10 @@ class Actor(Engine):
class Critic(Engine):
@torch.no_grad()
def compute_values(self, input_: Dict, mb_spec: MicroBatchSpec) -> torch.Tensor:
def compute_values(self, input_: Dict) -> torch.Tensor:
pass
def ppo_update(self, input_: Dict,
mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
...
all_stats = []
for _ in range(self.ppo_n_minibatches):

View File

@ -1,5 +1,4 @@
import argparse
import getpass
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
@ -9,6 +8,7 @@ from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf
from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig
@ -103,8 +103,8 @@ class FSDPEngineConfig:
@dataclass
class TrainEngineConfig:
experiment_name: str
trial_name: str
experiment_name: str = MISSING
trial_name: str = MISSING
path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
attn_impl: str = field(
default="flash_attention_2",
@ -120,6 +120,8 @@ class TrainEngineConfig:
default=False,
metadata={"help": "Initialize critic/reward model from LM checkpoint"},
)
# Runtime microbatch limit
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
# Training Backend Configuration
gradient_checkpointing: bool = field(
@ -307,92 +309,38 @@ class SGLangEngineConfig:
@dataclass
class ExperimentSaveEvalControl:
"""Controls the frequency of model saving and evaluation during training.
class _Timer:
experiment_name: str = MISSING
trial_name: str = MISSING
fileroot: str = MISSING
freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Trigger frequency in epochs. None disables epoch-based saving."
},
)
freq_steps: Optional[int] = field(
default=None,
metadata={
"help": "Trigger frequency in steps. None disables step-based saving."
},
)
freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Trigger frequency in seconds. None disables time-based saving."
},
)
Manages independent counters for epochs, steps, and seconds. The model will be saved
or evaluated when any specified frequency condition is met.
Note:
- Epoch: Number of full passes through the training dataset
- Step: Number of individual training iterations
- Seconds: Wall-clock time duration
"""
@dataclass
class EvaluatorConfig(_Timer):
pass
total_train_epochs: int = field(
default=1, metadata={"help": "Total number of epochs to train the model."}
)
# Save control
save_freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Save frequency in epochs. None disables epoch-based saving."
},
)
save_freq_steps: Optional[int] = field(
default=None,
metadata={"help": "Save frequency in steps. None disables step-based saving."},
)
save_freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Save frequency in seconds. None disables time-based saving."
},
)
# Checkpointing control
ckpt_freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Checkpoint frequency in epochs. None uses save_freq_epochs. "
"Checkpointing is used for recover. Previous checkpoint is overwritten to save space."
},
)
ckpt_freq_steps: Optional[int] = field(
default=None,
metadata={
"help": "Checkpoint frequency in steps. None disables step-based checkpointing."
},
)
ckpt_freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Checkpoint frequency in seconds. None disables time-based checkpointing."
},
)
# Evaluation control
eval_freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Evaluation frequency in epochs. None disables epoch-based evaluation."
},
)
eval_freq_steps: Optional[int] = field(
default=None,
metadata={
"help": "Evaluation frequency in steps. None disables step-based evaluation."
},
)
eval_freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Evaluation frequency in seconds. None disables time-based evaluation."
},
)
# Benchmark control
benchmark_steps: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after this number of steps. "
"For benchmarking purposes only. None indicates normal training."
},
)
benchmark_n_seqs: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after consuming this number of samples. "
"For benchmarking purposes only. None indicates normal training."
},
)
@dataclass
class SaverConfig(_Timer):
pass
@dataclass
@ -423,11 +371,23 @@ class TensorBoardConfig:
path: Optional[str] = None
def get_user_tmp():
user = getpass.getuser()
user_tmp = os.path.join("/home", user, ".cache", "realhf")
os.makedirs(user_tmp, exist_ok=True)
return user_tmp
@dataclass
class StatsLoggerConfig:
experiment_name: str = MISSING
trial_name: str = MISSING
fileroot: str = MISSING
wandb: WandBConfig = field(
default_factory=WandBConfig,
metadata={"help": "Weights & Biases configuration."},
)
swanlab: SwanlabConfig = field(
default_factory=SwanlabConfig,
metadata={"help": "SwanLab configuration."},
)
tensorboard: TensorBoardConfig = field(
default_factory=TensorBoardConfig,
metadata={"help": "TensorBoard configuration. Only 'path' field required."},
)
@dataclass
@ -498,7 +458,9 @@ class ClusterSpecConfig:
@dataclass
class DatasetConfig:
type: str = field(default="", metadata={"help": "Type of implemented dataset"})
type: Optional[str] = field(
default=None, metadata={"help": "Type of implemented dataset"}
)
batch_size: int = field(
default=1, metadata={"help": "Batch size of the dataloader"}
)
@ -517,51 +479,6 @@ class DatasetConfig:
drop_last: bool = field(default=True)
@dataclass
class TrainerConfig:
experiment_name: str = field(
default=MISSING,
metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
)
trial_name: str = field(
default=MISSING,
metadata={"help": "Name of the trial (no '-' or '/'). Required."},
)
fileroot: str = field(
default=get_user_tmp(),
metadata={
"help": "Root for logs and checkpoints. Should be available to all nodes."
},
)
wandb: WandBConfig = field(
default_factory=WandBConfig,
metadata={"help": "Weights & Biases configuration."},
)
swanlab: SwanlabConfig = field(
default_factory=SwanlabConfig,
metadata={"help": "SwanLab configuration."},
)
tensorboard: TensorBoardConfig = field(
default_factory=TensorBoardConfig,
metadata={"help": "TensorBoard configuration. Only 'path' field required."},
)
allocation_mode: str = field(
default="",
metadata={
"help": "GPU parallel strategy allocation mode. "
"Options: manual/heuristic or pattern-based."
},
)
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
exp_ctrl: ExperimentSaveEvalControl = field(
default_factory=ExperimentSaveEvalControl,
metadata={"help": "Experiment save/evaluation control configuration."},
)
tokenizer_path: str = field(default="")
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
@dataclass
class BaseExperimentConfig:
# NOTE: we need this unified config class because different experiments
@ -585,14 +502,45 @@ class BaseExperimentConfig:
n_gpus_per_node: int = field(
default=8, metadata={"help": "Number of GPUs per node for this experiment."}
)
allocation_mode: str = field(
default="",
metadata={
"help": "GPU parallel strategy allocation mode. "
"Options: manual/heuristic or pattern-based."
},
)
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
total_train_epochs: int = field(
default=1, metadata={"help": "Total number of epochs to train the model."}
)
total_train_steps: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after this number of steps. "
"For benchmarking purposes only. None indicates normal training."
},
)
total_train_n_seqs: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after consuming this number of samples. "
"For benchmarking purposes only. None indicates normal training."
},
)
tokenizer_path: str = field(default="")
train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
saver: SaverConfig = field(default_factory=SaverConfig)
checkpointer: SaverConfig = field(default_factory=SaverConfig)
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
@dataclass
class SFTConfig(BaseExperimentConfig):
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
trainer: TrainerConfig = field(default_factory=TrainerConfig)
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
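Both new utility configs inherit the same three frequency knobs from _Timer and feed them into a timeutil.EpochStepTimeFreqCtl, as the Saver/Evaluator implementations below show. A sketch of the pattern with hypothetical values:

from arealite.api.cli_args import SaverConfig
from realhf.base import timeutil

cfg = SaverConfig(
    experiment_name="my-exp",
    trial_name="trial0",
    fileroot="/tmp/areal",
    freq_epochs=1,     # fire at each epoch boundary
    freq_steps=None,   # step-based trigger disabled
    freq_secs=3600,    # also fire every hour of wall-clock time
)
freq_ctl = timeutil.EpochStepTimeFreqCtl(
    freq_epoch=cfg.freq_epochs, freq_step=cfg.freq_steps, freq_sec=cfg.freq_secs
)
# check() returns True when any enabled counter is due.
should_fire = freq_ctl.check(epochs=0, steps=1)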

View File

@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec
from arealite.api.io_struct import (
FinetuneSpec,
LLMRequest,
@ -79,7 +78,6 @@ class TrainEngine(abc.ABC):
def train_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
@ -90,7 +88,6 @@ class TrainEngine(abc.ABC):
def eval_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
@ -101,7 +98,6 @@ class TrainEngine(abc.ABC):
def forward(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,

View File

@ -1,130 +0,0 @@
import getpass
import os
from typing import Dict
import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import TrainerConfig
from arealite.api.engine_api import InferenceEngine, TrainEngine
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, timeutil
class Trainer:
def __init__(
self,
config: TrainerConfig,
train_dataloader: StatefulDataLoader,
valid_dataloader: StatefulDataLoader,
engine: TrainEngine,
inf_engine: InferenceEngine | None = None,
):
self.config = config
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.engine = engine
self.inf_engine = inf_engine
self.tokenizer = load_hf_tokenizer(config.tokenizer_path)
self.save_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.save_freq_epochs,
freq_step=config.exp_ctrl.save_freq_steps,
freq_sec=config.exp_ctrl.save_freq_secs,
)
self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.eval_freq_epochs,
freq_step=config.exp_ctrl.eval_freq_steps,
freq_sec=config.exp_ctrl.eval_freq_secs,
)
self.logger = logging.getLogger(self.__class__.__name__)
self.init_stats_logging()
def init_stats_logging(self):
"""
Initialize wandb and/or tensorboard according to config.
Only runs on rank 0 when torch.distributed is initialized.
Sets:
    self.summary_writer to a tensorboard SummaryWriter if
    self.config.tensorboard.path is not None, else None.
"""
if dist.is_initialized() and dist.get_rank() != 0:
return
# wandb init, connect to remote wandb host
if self.config.wandb.mode != "disabled":
wandb.login()
wandb.init(
mode=self.config.wandb.mode,
entity=self.config.wandb.entity,
project=self.config.wandb.project or self.config.experiment_name,
name=self.config.wandb.name or self.config.trial_name,
job_type=self.config.wandb.job_type,
group=self.config.wandb.group
or f"{self.config.experiment_name}_{self.config.trial_name}",
notes=self.config.wandb.notes,
tags=self.config.wandb.tags,
config=self.config.wandb.config,
dir=Trainer.get_log_path(self.config),
force=True,
id=f"{self.config.experiment_name}_{self.config.trial_name}_train",
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
# tensorboard logging
self.summary_writer = None
if self.config.tensorboard.path is not None:
self.summary_writer = SummaryWriter(log_dir=self.config.tensorboard.path)
def log_wandb_tensorboard(self, step: int, data: Dict):
if dist.is_initialized() and dist.get_rank() != 0:
return
wandb.log(data, step=step)
if self.summary_writer is not None:
for key, val in data.items():
self.summary_writer.add_scalar(f"{key}", val, step)
def close_wandb_tensorboard(self):
if dist.is_initialized() and dist.get_rank() != 0:
return
wandb.finish()
if self.summary_writer is not None:
self.summary_writer.close()
@staticmethod
def get_save_checkpoint_path(
config: TrainerConfig,
epoch: int,
step: int,
globalstep: int,
name: str = "default",
):
path = os.path.join(
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
name,
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
)
os.makedirs(path, exist_ok=True)
return path
@staticmethod
def get_log_path(config: TrainerConfig):
path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}"
os.makedirs(path, exist_ok=True)
return path
def log(self, msg: str, level="info"):
if dist.is_initialized() and dist.get_rank() > 0:
return
log_fn = getattr(self.logger, level, self.logger.info)
return log_fn(msg)
def train(self):
raise NotImplementedError()

View File

@ -21,7 +21,6 @@ from transformers import (
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
MicroBatchSpec,
SaveLoadMeta,
TrainEngine,
WeightUpdateMeta,
@ -319,15 +318,15 @@ class FSDPEngine(TrainEngine):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def _prepare_mb_list(
self, input_: TensorDict, mb_spec: MicroBatchSpec
) -> MicroBatchList:
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
mb_spec,
self.config.mb_spec,
)
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
@ -338,7 +337,6 @@ class FSDPEngine(TrainEngine):
def train_batch(
self,
input_: TensorDict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
@ -349,7 +347,7 @@ class FSDPEngine(TrainEngine):
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self._prepare_mb_list(input_, mb_spec)
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
@ -398,12 +396,12 @@ class FSDPEngine(TrainEngine):
def eval_batch(
self,
input_: TensorDict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_list = self._prepare_mb_list(input_, mb_spec)
input_ = input_.to(self.device)
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
@ -431,14 +429,14 @@ class FSDPEngine(TrainEngine):
def forward(
self,
input_: TensorDict,
mb_spec: MicroBatchSpec,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
input_ = input_.to(self.device)
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_, mb_spec)
mb_list = self._prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

View File

@ -9,12 +9,7 @@ import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from arealite.api.cli_args import (
EngineConfig,
MicroBatchSpec,
ParallelismConfig,
TrainingArgs,
)
from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
@ -150,7 +145,6 @@ class HFEngine(TrainEngine):
def train_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict:
@ -192,7 +186,6 @@ class HFEngine(TrainEngine):
def eval_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
@ -221,7 +214,6 @@ class HFEngine(TrainEngine):
def forward(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),

View File

@ -0,0 +1,94 @@
from typing import Dict
import torch
import torch.utils.data
from tensordict import TensorDict
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.functional import gather_logprobs
from realhf.base import stats_tracker
class LMEngine:
def __init__(self, engine: TrainEngine):
self.engine = engine
def train_lm(self, data: TensorDict):
self.engine.train()
return self.engine.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
)
def evaluate_lm(self, data):
self.engine.eval()
return self.engine.eval_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
)
class FSDPLMEngine(FSDPEngine):
def __init__(self, config: TrainEngineConfig):
super().__init__(config)
self.lm_engine = LMEngine(self)
def train_lm(self, data):
return self.lm_engine.train_lm(data)
def evaluate_lm(self, data):
return self.lm_engine.evaluate_lm(data)
def compute_packed_sft_loss(
logits: torch.Tensor, input_: Dict[str, torch.Tensor]
) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
prompt_mask = input_["prompt_mask"].bool()
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
logprobs = torch.where(prompt_mask, 0, logprobs)
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
m.numel() - m.count_nonzero()
)
## Logging stats
stats_tracker.denominator(
n_seqs=torch.ones(
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
),
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=prompt_mask.logical_not(),
prompt_tokens=prompt_mask,
)
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
return loss
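A condensed usage sketch of the new engine (this mirrors the gsm8k SFT example further down; config and ft_spec are assumed to be built as in that script):

from arealite.engine.sft.lm_engine import FSDPLMEngine

engine = FSDPLMEngine(config=config.model)
engine.initialize(None, ft_spec)
stats = engine.train_lm(batch)  # train_batch with the packed SFT loss above
engine.step_lr_scheduler()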

View File

@ -81,22 +81,10 @@ def engine():
@torch.no_grad()
def test_forward_microbatch(engine, mock_input):
engine.eval()
x2 = (
engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
)
.squeeze(0)
.mean(-1)
)
x1 = (
engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
)
.squeeze(0)
.mean(-1)
)
engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
x2 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
x1 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
input_ids = mock_input["input_ids"]
assert x1.shape[:1] == input_ids.shape[:1]
assert x2.shape[:1] == input_ids.shape[:1]
@ -106,9 +94,9 @@ def test_forward_microbatch(engine, mock_input):
@torch.no_grad()
def test_eval_batch(engine, mock_input):
engine.eval()
engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
eval_result = engine.eval_batch(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
loss_fn=mock_loss_fn,
loss_weight_fn=lambda x: x["cu_seqlens"][-1],
)
@ -120,9 +108,9 @@ def test_eval_batch(engine, mock_input):
def test_train_batch(engine, mock_input):
engine.train()
engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
train_result = engine.train_batch(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
loss_fn=mock_loss_fn,
loss_weight_fn=lambda x: x["cu_seqlens"][-1],
)
@ -144,18 +132,13 @@ def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
base_model_path=None,
)
old = engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
)
engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
old = engine.forward(input_=mock_input)
engine.save(save_load_meta)
for name, param in engine.model.named_parameters():
param.zero_()
engine.load(save_load_meta)
new = engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
)
new = engine.forward(input_=mock_input)
assert torch.allclose(old, new)

View File

@ -1,156 +0,0 @@
import time
from typing import Dict
import torch
import torch.distributed as dist
import torch.utils.data
from arealite.api.io_struct import SaveLoadMeta
from arealite.api.trainer_api import Trainer
from arealite.utils.functional import gather_logprobs
from arealite.utils.logging import record_timing
from realhf.api.core.data_api import tabulate_stats
from realhf.base import logging, stats_tracker
def compute_packed_sft_loss(
logits: torch.Tensor,
input_: Dict[str, torch.Tensor],
) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
prompt_mask = input_["prompt_mask"].bool()
logits = logits.float()
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
logprobs = torch.where(prompt_mask, 0, logprobs)
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
m.numel() - m.count_nonzero()
)
## Logging stats
stats_tracker.denominator(
n_seqs=torch.ones(
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
),
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=prompt_mask.logical_not(),
prompt_tokens=prompt_mask,
)
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
return loss
class SFTTrainer(Trainer):
def train(self):
total_epochs = self.config.exp_ctrl.total_train_epochs
steps_per_epoch = len(self.train_dataloader)
self.log(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
global_step = 0
start_time = time.monotonic()
for epoch in range(total_epochs):
for step, data in enumerate(self.train_dataloader):
self.engine.train()
timing_stats = {}
with record_timing("timeperf/train_step", timing_stats):
with stats_tracker.scope("sft"):
stats = self.engine.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"]
.logical_not()
.count_nonzero(),
mb_spec=self.config.mb_spec,
)
self.engine.step_lr_scheduler()
stats_tracker.scalar(**stats)
if self.save_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
self.log("Saving model ...")
with record_timing("timeperf/save", timing_stats):
save_path = self.get_save_checkpoint_path(
self.config, epoch, step, global_step
)
meta = SaveLoadMeta(
path=save_path,
weight_format="hf",
with_optim=False,
tokenizer=self.tokenizer,
base_model_path=self.config.tokenizer_path,
)
self.engine.save(meta)
if self.eval_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
if dist.get_rank() == 0:
self.log("Running evaluation ...")
with record_timing("timeperf/eval", timing_stats):
self.evaluate()
stats = stats_tracker.export()
stats.update(timing_stats)
self.log_wandb_tensorboard(global_step, stats)
self.log(
f"Epoch {epoch+1}/{total_epochs} "
f"Step {step+1}/{steps_per_epoch} "
f"Train step {global_step + 1}/{total_epochs * steps_per_epoch} done."
)
self.log(
f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}"
)
self.log(f"SFT training stats:\n{tabulate_stats(stats)}")
global_step += 1
self.log(
f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
)
self.close_wandb_tensorboard()
def evaluate(self):
if self.valid_dataloader is None:
return
self.engine.eval()
for data in self.valid_dataloader:
with stats_tracker.scope("sft-eval"):
# No need to log anything. Logging will be handled outside
# via stats_tracker.export().
self.engine.eval_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"]
.logical_not()
.count_nonzero(),
mb_spec=self.config.mb_spec,
)

View File

@ -277,7 +277,7 @@ def pad_and_stack_tensors_along_first_dim(tensor_list: List[torch.Tensor]):
@dataclass
class MicroBatchList:
data: Dict[str, Any]
data: TensorDict
mb_spec: MicroBatchSpec
mbs: List[TensorDict]
forward_indices: List[int]
@ -379,7 +379,7 @@ def pad_packed_tensor_dict(
data: TensorDict,
pad_to_length: int,
pad_value: float = 0.0,
) -> Tuple[Dict[str, Any], int]:
) -> Tuple[TensorDict, int]:
"""Pad a packed tensor dict to a specified length.
This function assumes that the input data contains "cu_seqlens" and "max_seqlen" key,
and all other tensors of shape [total_length, ] will be padded to `pad_to_length`.
@ -391,7 +391,7 @@ def pad_packed_tensor_dict(
pad_to_length (int): The length to pad the tensors to. All tensors
Returns:
Dict[str, torch.Tensor]: Dictionary with padded tensors and modified "cu_seqlens" and
TensorDict: Dictionary with padded tensors and modified "cu_seqlens" and
"max_seqlen".
int: The pad length.
"""
@ -420,7 +420,7 @@ def pad_packed_tensor_dict(
padded_data[key] = padded_tensor
else:
padded_data[key] = value
return padded_data, pad_length
return TensorDict(padded_data, batch_size=data.batch_size), pad_length
def pad_mb_list(
@ -453,13 +453,17 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key."
total_length = data["cu_seqlens"][-1].item()
new_data = {}
for key, value in data.items():
if key == "cu_seqlens" or key == "max_seqlen":
continue
if (
key not in ["cu_seqlens", "max_seqlen"]
and torch.is_tensor(value)
and value.numel() == total_length
):
new_data[key] = value.unsqueeze(dim=0)
else:
if torch.is_tensor(value) and value.numel() == total_length:
data[key] = value.unsqueeze(dim=0)
return data
new_data[key] = value
return TensorDict(new_data, batch_size=data.batch_size)
def unsqueeze_mb_list(
@ -472,7 +476,6 @@ def unsqueeze_mb_list(
new_mbs.append(unsqueeze_packed_tensor_dict(mb))
if mb_list.padded_mbs is not None:
new_padded_mbs.append(unsqueeze_packed_tensor_dict(mb_list.padded_mbs[i]))
mb_list.mbs = new_mbs
mb_list.padded_mbs = new_padded_mbs if mb_list.padded_mbs is not None else None
return mb_list
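These helpers are the pipeline that FSDPEngine._prepare_mb_list chains together. A sketch of that flow (the arealite.utils.data module path is assumed from the imports in the gsm8k example below):

from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec
from arealite.utils.data import (  # module path assumed
    amend_position_ids, pack_tensor_dict,
    split_packed_tensor_dict_into_mb_list, pad_mb_list, unsqueeze_mb_list,
)

td = TensorDict(batch, batch_size=[batch["input_ids"].shape[0]])
td = amend_position_ids(td)
packed = pack_tensor_dict(td)  # flat [total_tokens] tensors + cu_seqlens/max_seqlen
mb_list = split_packed_tensor_dict_into_mb_list(
    packed, MicroBatchSpec(max_tokens_per_mb=4096)
)
mb_list = pad_mb_list(mb_list, pad_value=0.0)  # pad each microbatch
mb_list = unsqueeze_mb_list(mb_list)  # add the batch dim HF models expect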

View File

@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Any, Callable
from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.base import timeutil
if TYPE_CHECKING:
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
class Evaluator:
def __init__(self, config: EvaluatorConfig, ft_spec: FinetuneSpec):
self.config = config
self.ft_spec = ft_spec
self.freq_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.freq_epochs,
freq_step=config.freq_steps,
freq_sec=config.freq_secs,
)
def evaluate(
self,
valid_dataloader: "StatefulDataLoader",
evaluate_fn: Callable[["TensorDict"], Any],
epoch: int,
step: int,
global_step: int,
):
if not self.freq_ctl.check(
epochs=int(step == self.ft_spec.steps_per_epoch - 1), steps=1
):
return
for data in valid_dataloader:
evaluate_fn(data)
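Usage sketch, as wired up in the gsm8k example below; evaluate() is a no-op unless one of the freq_* conditions fires:

evaluator = Evaluator(config.evaluator, ft_spec)
evaluator.evaluate(valid_dataloader, engine.evaluate_lm, epoch, step, global_step)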

arealite/utils/fs.py Normal file
View File

@ -0,0 +1,9 @@
import getpass
import os
def get_user_tmp():
user = getpass.getuser()
user_tmp = os.path.join("/home", user, ".cache", "realhf")
os.makedirs(user_tmp, exist_ok=True)
return user_tmp

View File

@ -1,9 +0,0 @@
import time
from contextlib import contextmanager
@contextmanager
def record_timing(name, timing_stats):
start_time = time.perf_counter()
yield
timing_stats[name] = time.perf_counter() - start_time

arealite/utils/saver.py Normal file
View File

@ -0,0 +1,68 @@
import getpass
import os
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import SaverConfig
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
from realhf.base import timeutil
class Saver:
def __init__(self, config: SaverConfig, ft_spec: FinetuneSpec, for_recover: bool):
self.config = config
self.ft_spec = ft_spec
self.for_recover = for_recover
self.freq_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.freq_epochs,
freq_step=config.freq_steps,
freq_sec=config.freq_secs,
)
@staticmethod
def get_save_checkpoint_path(
config: SaverConfig,
epoch: int,
step: int,
globalstep: int,
name: str = "default",
):
path = os.path.join(
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
name,
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
)
os.makedirs(path, exist_ok=True)
return path
def save(
self,
engine: TrainEngine,
epoch: int,
step: int,
global_step: int,
name: str = "default",
tokenizer: PreTrainedTokenizerFast | None = None,
base_model_path: str | None = None,
):
if not self.freq_ctl.check(
epochs=int(step == self.ft_spec.steps_per_epoch - 1), steps=1
):
return
path = self.get_save_checkpoint_path(self.config, epoch, step, global_step, name)
weight_format = "hf"
with_optim = False
if self.for_recover:
weight_format = "dcp"
with_optim = True
meta = SaveLoadMeta(
path=path,
weight_format=weight_format,
with_optim=with_optim,
tokenizer=tokenizer,
base_model_path=base_model_path,
)
engine.save(meta)
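Usage sketch. The for_recover flag selects the output format: False writes HF weights without optimizer state (for serving/eval), True writes dcp plus optimizer state (for restart), which is how saver and checkpointer differ in the config below:

saver = Saver(config.saver, ft_spec, for_recover=False)
checkpointer = Saver(config.checkpointer, ft_spec, for_recover=True)
# Both are no-ops unless their freq_* conditions fire.
saver.save(engine, epoch, step, global_step)
checkpointer.save(engine, epoch, step, global_step)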

View File

@ -0,0 +1,111 @@
import getpass
import os
import time
from typing import Dict
import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter
from arealite.api.cli_args import StatsLoggerConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.api.core.data_api import tabulate_stats
from realhf.base import logging
class StatsLogger:
def __init__(self, config: StatsLoggerConfig, ft_spec: FinetuneSpec):
self.logger = logging.getLogger("StatsLogger", "system")
self.config = config
self.ft_spec = ft_spec
self.init()
def init(self):
if dist.is_initialized() and dist.get_rank() != 0:
return
self.start_time = time.perf_counter()
# wandb init, connect to remote wandb host
if self.config.wandb.mode != "disabled":
wandb.login()
wandb.init(
mode=self.config.wandb.mode,
entity=self.config.wandb.entity,
project=self.config.wandb.project or self.config.experiment_name,
name=self.config.wandb.name or self.config.trial_name,
job_type=self.config.wandb.job_type,
group=self.config.wandb.group
or f"{self.config.experiment_name}_{self.config.trial_name}",
notes=self.config.wandb.notes,
tags=self.config.wandb.tags,
config=self.config.wandb.config,
dir=self.get_log_path(self.config),
force=True,
id=f"{self.config.experiment_name}_{self.config.trial_name}_train",
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
# tensorboard logging
self.summary_writer = None
if self.config.tensorboard.path is not None:
self.summary_writer = SummaryWriter(log_dir=self.config.tensorboard.path)
def close(self):
if dist.is_initialized() and dist.get_rank() != 0:
return
self.info(
f"Training completes! Total time elapsed {time.monotonic() - self.start_time:.2f}."
)
wandb.finish()
if self.summary_writer is not None:
self.summary_writer.close()
def commit(self, epoch: int, step: int, global_step: int, data: Dict):
if dist.is_initialized() and dist.get_rank() != 0:
return
self.info(
f"Epoch {epoch+1}/{self.ft_spec.total_train_epochs} "
f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
)
self.info("Stats:")
self.print_stats(data)
wandb.log(data, step=global_step)
if self.summary_writer is not None:
for key, val in data.items():
self.summary_writer.add_scalar(f"{key}", val, global_step)
def print_stats(self, stats: Dict[str, float]):
self.info("\n" + tabulate_stats(stats))
@staticmethod
def get_log_path(config: StatsLoggerConfig):
path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}"
os.makedirs(path, exist_ok=True)
return path
def info(self, msg: str, *args, **kwargs):
if dist.is_initialized() and dist.get_rank() > 0:
return
self.logger.info(msg, *args, **kwargs)
def debug(self, msg: str, *args, **kwargs):
if dist.is_initialized() and dist.get_rank() > 0:
return
self.logger.debug(msg, *args, **kwargs)
def critical(self, msg: str, *args, **kwargs):
if dist.is_initialized() and dist.get_rank() > 0:
return
self.logger.critical(msg, *args, **kwargs)
def warning(self, msg: str, *args, **kwargs):
if dist.is_initialized() and dist.get_rank() > 0:
return
self.logger.warning(msg, *args, **kwargs)
def error(self, msg: str, *args, **kwargs):
if dist.is_initialized() and dist.get_rank() > 0:
return
self.logger.error(msg, *args, **kwargs)
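Usage sketch, mirroring the gsm8k loop further down (epoch/step counters assumed from that loop); the logger is rank-0-only and pairs with stats_tracker.export():

logger = StatsLogger(config.stats_logger, ft_spec)
logger.info(f"total_epochs={total_epochs} steps_per_epoch={steps_per_epoch}")
logger.commit(epoch, step, global_step, stats_tracker.export())
logger.close()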

View File

@ -6,13 +6,19 @@ cluster:
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 1
tokenizer_path: ${model.path}
model:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
pad_mbs_to_max_tokens: true
gradient_checkpointing: false
bf16: true
mb_spec:
max_tokens_per_mb: 4096
optimizer:
type: adam
lr: 2e-5
@ -24,30 +30,46 @@ model:
gradient_clipping: 1.0
backend: fsdp
trainer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled
seed: 1
exp_ctrl:
total_train_epochs: 1
eval_freq_steps: 1
tokenizer_path: ${model.path}
mb_spec:
max_tokens_per_mb: 4096
train_dataset:
type: gsm8k-sft
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
valid_dataset:
type: gsm8k-sft
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
checkpointer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: null
freq_steps: 1
freq_secs: null
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled
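This YAML resolves into the SFTConfig dataclass tree. A loading sketch (the exact entrypoint invocation is outside this diff, so the argv form is an assumption):

import sys
from arealite.api.cli_args import SFTConfig, load_expr_config

config, _ = load_expr_config(sys.argv[1:], SFTConfig)
assert config.model.mb_spec.max_tokens_per_mb == 4096
assert config.saver.freq_epochs == 1 and config.checkpointer.freq_secs == 3600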

View File

@ -1,18 +1,19 @@
import os
import sys
import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import DataCollatorWithPadding
from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.trainer.sft import SFTTrainer
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
@ -42,10 +43,9 @@ def main_sft():
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.trainer.tokenizer_path)
tokenizer = load_hf_tokenizer(config.tokenizer_path)
# Create dataset and dataloaders
assert config.train_dataset == "gsm8k-sft"
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", tokenizer, rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
@ -54,7 +54,6 @@ def main_sft():
collate_fn=pad_sequences_to_tensors,
drop_last=config.train_dataset.drop_last,
)
assert config.valid_dataset == "gsm8k-sft"
valid_dataloader = StatefulDataLoader(
get_gsm8k_dataset("test", tokenizer, rank, world_size),
batch_size=config.valid_dataset.batch_size // world_size,
@ -66,22 +65,52 @@ def main_sft():
# Initialize engine
ft_spec = FinetuneSpec(
total_train_epochs=config.trainer.exp_ctrl.total_train_epochs,
dataset_size=len(train_dataloader),
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
engine = FSDPEngine(config=config.model)
engine = FSDPLMEngine(config=config.model)
engine.initialize(None, ft_spec)
# Run training.
trainer = SFTTrainer(
config=config.trainer,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
engine=engine,
inf_engine=None,
)
trainer.train()
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
evaluator = Evaluator(config.evaluator, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
global_step = 0
for epoch in range(total_epochs):
for step, data in enumerate(train_dataloader):
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("sft"),
):
stats = engine.train_lm(data)
engine.step_lr_scheduler()
stats_tracker.scalar(**stats)
with stats_tracker.record_timing("save"):
saver.save(engine, epoch, step, global_step)
with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
# No need to log anything. Logging will be handled outside
# via stats_tracker.export().
evaluator.evaluate(
valid_dataloader,
engine.evaluate_lm,
epoch,
step,
global_step,
)
logger.commit(epoch, step, global_step, stats_tracker.export())
global_step += 1
engine.destroy()
logger.close()
if __name__ == "__main__":

View File

@ -1,4 +1,6 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict
@ -49,6 +51,17 @@ class DistributedStatsTracker:
return key
return "/".join(self.scope_stack + [key])
@contextmanager
def record_timing(self, key):
start_time = time.perf_counter()
try:
yield
finally:
# NOTE: timing records are fixed under the "timeperf" scope
full_key = f"timeperf/{key}"
self._set_reduce_type(full_key, ReduceType.SCALAR)
self.stats[full_key].append(time.perf_counter() - start_time)
def denominator(self, **kwargs):
for key, value in kwargs.items():
if not isinstance(value, torch.Tensor) or value.dtype != torch.bool:
@ -252,3 +265,4 @@ denominator = DEFAULT_TRACKER.denominator
export = DEFAULT_TRACKER.export
scope = DEFAULT_TRACKER.scope
scalar = DEFAULT_TRACKER.scalar
record_timing = DEFAULT_TRACKER.record_timing
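Usage sketch of the new helper; timings land under the fixed "timeperf" scope and come out with the rest of the stats via export():

from realhf.base import stats_tracker

with stats_tracker.record_timing("train_step"):
    stats = engine.train_lm(data)  # any timed work; hypothetical engine call
timings = stats_tracker.export()   # includes the "timeperf/train_step" entry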