PullRequest: 332 [lite] Support FSDP engines

Merge branch mzy/lite/engines of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/332

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fsdp2 engine
* fix utils
* add fsdp engine test
* .
* fsdp engine test passed
* unsqueeze immediately before model inputs and after model outputs
* add optimizer save/load, add position id calculation for input
* .
* format
* not to squeeze
* add train and eval api
* .
* .
* improve fsdp engine data preprocessing
* format
* PullRequest: 337 [lite] Add SFT trainer example.
* trainer log
* minor changes
* add update weights from disk
* fix type annotation
This commit is contained in:
博惟 2025-07-09 16:24:25 +08:00
parent 7a438c0650
commit 15dfbe837c
24 changed files with 2389 additions and 18 deletions

View File

@ -1,5 +1,15 @@
import argparse
import getpass
import os
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf
from realhf.api.cli_args import OptimizerConfig
@dataclass
@ -12,8 +22,8 @@ class MicroBatchSpec:
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
},
)
max_tokens_per_mb: int = field(
default=int(1e12),
max_tokens_per_mb: Optional[int] = field(
default=None,
metadata={
"help": "Maximum tokens per micro-batch. When set, n_mbs becomes the minimum number of micro-batches",
},
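# [Editor's illustrative sketch, not part of this diff] How the updated field is
# meant to be used: with max_tokens_per_mb left as None the splitter produces
# n_mbs micro-batches; when it is set, n_mbs becomes a lower bound and more
# micro-batches are created so that none exceeds the token budget.
spec_fixed = MicroBatchSpec(n_mbs=4)
spec_budgeted = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=4096)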
@ -70,6 +80,59 @@ class GenerationHyperparameters:
return GenerationHyperparameters(**args)
# Train Engine Configs
@dataclass
class FSDPWrapPolicy:
transformer_layer_cls_to_wrap: Optional[List[str]] = field(
default=None,
metadata={"help": "A list of transformer layer names for FSDP to wrap."},
)
@dataclass
class FSDPEngineConfig:
wrap_policy: Optional[FSDPWrapPolicy] = field(
default=None,
metadata={"help": "FSDP wrap policy, specifying model layers to wrap."},
)
offload_params: bool = field(
default=False,
metadata={"help": "Whether to offload FSDP parameters to CPU."},
)
@dataclass
class TrainEngineConfig:
experiment_name: str
trial_name: str
path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
attn_impl: str = field(
default="flash_attention_2",
metadata={
"help": "Attention implementation for huggingface transformers model.",
"choices": ["flash_attention_2"],
},
)
init_from_scratch: bool = field(
default=False, metadata={"help": "Initialize model weights randomly"}
)
init_critic_from_actor: bool = field(
default=False,
metadata={"help": "Initialize critic/reward model from LM checkpoint"},
)
# Training Backend Configuration
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)
bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
optimizer: Optional[OptimizerConfig] = field(
default=None, metadata={"help": "Optimizer configuration"}
)
backend: str = ""
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
@dataclass
class SGLangConfig:
"""Configuration for SGLang runtime. Refer to:
@ -236,3 +299,332 @@ class InferenceEngineConfig:
request_retries: int = field(
default=3, metadata={"help": "Number of retries for failed requests."}
)
@dataclass
class SGLangEngineConfig:
pass
@dataclass
class ExperimentSaveEvalControl:
"""Controls the frequency of model saving and evaluation during training.
Manages independent counters for epochs, steps, and seconds. The model will be saved
or evaluated when any specified frequency condition is met.
Note:
- Epoch: Number of full passes through the training dataset
- Step: Number of individual training iterations
- Seconds: Wall-clock time duration
"""
total_train_epochs: int = field(
default=1, metadata={"help": "Total number of epochs to train the model."}
)
# Save control
save_freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Save frequency in epochs. None disables epoch-based saving."
},
)
save_freq_steps: Optional[int] = field(
default=None,
metadata={"help": "Save frequency in steps. None disables step-based saving."},
)
save_freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Save frequency in seconds. None disables time-based saving."
},
)
# Checkpointing control
ckpt_freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Checkpoint frequency in epochs. None uses save_freq_epochs. "
"Checkpointing is used for recover. Previous checkpoint is overwritten to save space."
},
)
ckpt_freq_steps: Optional[int] = field(
default=None,
metadata={
"help": "Checkpoint frequency in steps. None disables step-based checkpointing."
},
)
ckpt_freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Checkpoint frequency in seconds. None disables time-based checkpointing."
},
)
# Evaluation control
eval_freq_epochs: Optional[int] = field(
default=None,
metadata={
"help": "Evaluation frequency in epochs. None disables epoch-based evaluation."
},
)
eval_freq_steps: Optional[int] = field(
default=None,
metadata={
"help": "Evaluation frequency in steps. None disables step-based evaluation."
},
)
eval_freq_secs: Optional[int] = field(
default=None,
metadata={
"help": "Evaluation frequency in seconds. None disables time-based evaluation."
},
)
# Benchmark control
benchmark_steps: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after this number of steps. "
"For benchmarking purposes only. None indicates normal training."
},
)
benchmark_n_seqs: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after consuming this number of samples. "
"For benchmarking purposes only. None indicates normal training."
},
)
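# [Editor's illustrative sketch, not part of this diff] How these frequencies are
# consumed by the Trainer added in this PR: each group of counters feeds a
# timeutil.EpochStepTimeFreqCtl, and an action fires when any enabled counter
# (epoch, step, or wall-clock seconds) trips.
from realhf.base import timeutil

exp_ctrl = ExperimentSaveEvalControl(save_freq_steps=100, save_freq_secs=3600)
save_ctl = timeutil.EpochStepTimeFreqCtl(
    freq_epoch=exp_ctrl.save_freq_epochs,  # None: epoch-based saving disabled
    freq_step=exp_ctrl.save_freq_steps,    # every 100 training steps
    freq_sec=exp_ctrl.save_freq_secs,      # or every hour, whichever trips first
)
if save_ctl.check(epochs=0, steps=1):  # called once per training step
    pass  # save a checkpoint here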
@dataclass
class WandBConfig:
mode: str = "disabled"
entity: Optional[str] = None
project: Optional[str] = None
name: Optional[str] = None
job_type: Optional[str] = None
group: Optional[str] = None
notes: Optional[str] = None
tags: Optional[List[str]] = None
config: Optional[Dict] = None
@dataclass
class SwanlabConfig:
project: Optional[str] = None
name: Optional[str] = None
config: Optional[Dict] = None
logdir: Optional[str] = None
mode: Optional[str] = "local"
api_key: Optional[str] = os.getenv("SWANLAB_API_KEY", None)
@dataclass
class TensorBoardConfig:
path: Optional[str] = None
def get_user_tmp():
user = getpass.getuser()
user_tmp = os.path.join("/home", user, ".cache", "realhf")
os.makedirs(user_tmp, exist_ok=True)
return user_tmp
@dataclass
class NameResolveConfig:
type: str = field(
default="nfs",
metadata={
"help": "Type of the distributed KV store for name resolving.",
"choices": ["nfs", "etcd3", "ray"],
},
)
nfs_record_root: str = field(
default="/tmp/areal/name_resolve",
metadata={
"help": "Record root for NFS name resolving. Should be available in all nodes."
},
)
etcd3_addr: str = field(
default="localhost:2379", metadata={"help": "Address of the ETCD3 server."}
)
ray_actor_name: str = field(
default="ray_kv_store",
metadata={"help": "Name of the distributed Ray KV store."},
)
@dataclass
class ClusterSpecConfig:
name_resolve: NameResolveConfig = field(
default_factory=NameResolveConfig,
metadata={"help": "Name resolving configuration."},
)
cluster_name: str = field(
default="local",
metadata={"help": "Name of the cluster. Used to set specific environs."},
)
fileroot: str = field(
default=get_user_tmp(),
metadata={
"help": "Root for logs and checkpoints. Should be available to all nodes."
},
)
gpu_type: str = field(
default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
gpu_infer_image: str = field(
default="", metadata={"help": "slurm image for LLM inference."}
)
node_name_prefix: str = field(
default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
)
n_nodes: int = field(
default=32,
metadata={
"help": "The size of the cluster. Used to decide slurm hostname suffix."
},
)
n_gpus_per_node: int = field(
default=8,
metadata={"help": "GPUs per node (physically)."},
)
@dataclass
class DatasetConfig:
type: str = field(default="", metadata={"help": "Type of implemented dataset"})
batch_size: int = field(
default=1, metadata={"help": "Batch size of the dataloader"}
)
shuffle: bool = field(
default=True, metadata={"help": "Whether to shuffle the dataset"}
)
pin_memory: bool = field(
default=False,
metadata={
"help": "Pin memory for faster data loading (set True for GPU training)"
},
)
num_workers: int = field(
default=0, metadata={"help": "Number of worker processes for data loading"}
)
drop_last: bool = field(default=True)
@dataclass
class TrainerConfig:
experiment_name: str = field(
default=MISSING,
metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
)
trial_name: str = field(
default=MISSING,
metadata={"help": "Name of the trial (no '-' or '/'). Required."},
)
fileroot: str = field(
default=get_user_tmp(),
metadata={
"help": "Root for logs and checkpoints. Should be available to all nodes."
},
)
wandb: WandBConfig = field(
default_factory=WandBConfig,
metadata={"help": "Weights & Biases configuration."},
)
swanlab: SwanlabConfig = field(
default_factory=SwanlabConfig,
metadata={"help": "SwanLab configuration."},
)
tensorboard: TensorBoardConfig = field(
default_factory=TensorBoardConfig,
metadata={"help": "TensorBoard configuration. Only 'path' field required."},
)
allocation_mode: str = field(
default="",
metadata={
"help": "GPU parallel strategy allocation mode. "
"Options: manual/heuristic or pattern-based."
},
)
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
exp_ctrl: ExperimentSaveEvalControl = field(
default_factory=ExperimentSaveEvalControl,
metadata={"help": "Experiment save/evaluation control configuration."},
)
tokenizer_path: str = field(default="")
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
@dataclass
class BaseExperimentConfig:
# NOTE: we need this unified config class because different experiments
# have different config structures, e.g., GRPO has two engine configs,
# but SFT only has a single one. We use subclasses to represent these structures.
experiment_name: str = field(
default=MISSING,
metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
)
trial_name: str = field(
default=MISSING,
metadata={"help": "Name of the trial (no '-' or '/'). Required."},
)
cluster: ClusterSpecConfig = field(
default_factory=ClusterSpecConfig,
metadata={"help": "Cluster specification. Mainly used by slurm."},
)
n_nodes: int = field(
default=1, metadata={"help": "Number of nodes for experiment."}
)
n_gpus_per_node: int = field(
default=8, metadata={"help": "Number of GPUs per node for this experiment."}
)
train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
@dataclass
class SFTConfig(BaseExperimentConfig):
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
trainer: TrainerConfig = field(default_factory=TrainerConfig)
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", help="The path of the main configuration file", required=True
)
args, overrides = parser.parse_known_args(argv)
# Initialize hydra config
config_file = Path(args.config).absolute()
assert config_file.exists()
# hydra only recognize relative paths
relpath = Path(
os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
)
hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
cfg = hydra_compose(
config_name=str(relpath.name).rstrip(".yaml"),
overrides=overrides,
)
# Merge with the default configuration.
# The yaml and commandline can omit some default values defined in python dataclasses.
default_cfg = OmegaConf.structured(config_cls)
cfg = OmegaConf.merge(default_cfg, cfg)
cfg = OmegaConf.to_object(cfg)
assert isinstance(cfg, BaseExperimentConfig)
# Setup environment
from realhf.base import constants, name_resolve
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
name_resolve.reconfigure(cfg.cluster.name_resolve)
return cfg, str(config_file)
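# [Editor's illustrative sketch, not part of this diff] A training entry point
# forwards its CLI arguments to load_expr_config, e.g.
#   python train_sft.py --config examples/sft.yaml trainer.seed=2
# (script name, YAML path, and override key are placeholders). The YAML and
# command-line overrides are merged on top of the dataclass defaults above.
import sys

cfg, config_file = load_expr_config(sys.argv[1:], SFTConfig)
assert isinstance(cfg, SFTConfig)
print(cfg.experiment_name, cfg.trial_name, config_file)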

View File

@ -47,7 +47,14 @@ class TrainEngine(abc.ABC):
def destroy(self):
"""Destroy the engine and release GPU memory."""
pass
def train(self, mode: bool = True):
"""Set the engine to the train mode."""
raise NotImplementedError()
def eval(self):
"""Set the engine to the eval mode."""
return self.train(False)
def upload_weights(self, meta: WeightUpdateMeta):
"""Upload weights to the inference engine."""
@ -111,7 +118,6 @@ class InferenceEngine(abc.ABC):
def destroy(self):
"""Destroy the engine and release GPU memory."""
pass
def update_weights(self, meta: WeightUpdateMeta) -> Future:
"""Update weights in the inference engine."""

View File

@ -1,5 +1,5 @@
import abc
from typing import Any, Callable, Dict, List
from typing import Any, Dict, List
class Environment(abc.ABC):
@ -11,7 +11,6 @@ class Environment(abc.ABC):
For stateful environments, this is where resources are created and
prepared (e.g., launching a browser).
"""
pass
def list_tools(self) -> List[Dict[str, Any]]:
"""Lists all available tools in the environment."""
@ -27,4 +26,3 @@ class Environment(abc.ABC):
This method is critical for stateful environments (e.g., a browser session).
"""
pass

View File

@ -6,7 +6,7 @@ import itertools
import re
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional
from transformers import PreTrainedTokenizerFast
@ -161,6 +161,7 @@ class WeightUpdateMeta:
path: str | None
alloc_mode: AllocationMode | None
comm_backend: str | None
model_version: int = 0
@dataclass

View File

@ -21,4 +21,3 @@ def reward_fn(
Any other attributes in the dataset will be passed as keyword arguments to this function.
:rtype: float
"""
pass

130
arealite/api/trainer_api.py Normal file
View File

@ -0,0 +1,130 @@
import getpass
import os
from typing import Dict
import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import TrainerConfig
from arealite.api.engine_api import InferenceEngine, TrainEngine
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, timeutil
class Trainer:
def __init__(
self,
config: TrainerConfig,
train_dataloader: StatefulDataLoader,
valid_dataloader: StatefulDataLoader,
engine: TrainEngine,
inf_engine: InferenceEngine | None = None,
):
self.config = config
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.engine = engine
self.inf_engine = inf_engine
self.tokenizer = load_hf_tokenizer(config.tokenizer_path)
self.save_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.save_freq_epochs,
freq_step=config.exp_ctrl.save_freq_steps,
freq_sec=config.exp_ctrl.save_freq_secs,
)
self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.eval_freq_epochs,
freq_step=config.exp_ctrl.eval_freq_steps,
            freq_sec=config.exp_ctrl.eval_freq_secs,
)
self.logger = logging.getLogger(self.__class__.__name__)
self.init_stats_logging()
    def init_stats_logging(self):
        """Initialize wandb and/or tensorboard according to config.

        Only rank 0 initializes the loggers when torch.distributed is initialized.
        Sets `self.summary_writer` to a tensorboard SummaryWriter when
        `self.config.tensorboard.path` is not None.
        """
if dist.is_initialized() and dist.get_rank() != 0:
return
# wandb init, connect to remote wandb host
if self.config.wandb.mode != "disabled":
wandb.login()
wandb.init(
mode=self.config.wandb.mode,
entity=self.config.wandb.entity,
project=self.config.wandb.project or self.config.experiment_name,
name=self.config.wandb.name or self.config.trial_name,
job_type=self.config.wandb.job_type,
group=self.config.wandb.group
or f"{self.config.experiment_name}_{self.config.trial_name}",
notes=self.config.wandb.notes,
tags=self.config.wandb.tags,
config=self.config.wandb.config,
dir=Trainer.get_log_path(self.config),
force=True,
id=f"{self.config.experiment_name}_{self.config.trial_name}_train",
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
# tensorboard logging
self.summary_writer = None
if self.config.tensorboard.path is not None:
self.summary_writer = SummaryWriter(log_dir=self.config.tensorboard.path)
def log_wandb_tensorboard(self, step: int, data: Dict):
if dist.is_initialized() and dist.get_rank() != 0:
return
wandb.log(data, step=step)
if self.summary_writer is not None:
for key, val in data.items():
self.summary_writer.add_scalar(f"{key}", val, step)
def close_wandb_tensorboard(self):
if dist.is_initialized() and dist.get_rank() != 0:
return
wandb.finish()
if self.summary_writer is not None:
self.summary_writer.close()
@staticmethod
def get_save_checkpoint_path(
config: TrainerConfig,
epoch: int,
step: int,
globalstep: int,
name: str = "default",
):
path = os.path.join(
f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
name,
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
)
os.makedirs(path, exist_ok=True)
return path
@staticmethod
def get_log_path(config: TrainerConfig):
path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}"
os.makedirs(path, exist_ok=True)
return path
def log(self, msg: str, level="info"):
if dist.is_initialized() and dist.get_rank() > 0:
return
        log_fn = getattr(self.logger, level, self.logger.info)
return log_fn(msg)
def train(self):
raise NotImplementedError()
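# [Editor's illustrative sketch, not part of this diff] Concrete trainers
# subclass Trainer and implement train(); the SFT trainer added later in this
# PR follows this pattern. The loss functions below are placeholders.
class MinimalTrainer(Trainer):
    def train(self):
        global_step = 0
        for epoch in range(self.config.exp_ctrl.total_train_epochs):
            for step, data in enumerate(self.train_dataloader):
                self.engine.train()
                stats = self.engine.train_batch(
                    input_=data,
                    loss_fn=lambda logits, inp: logits.float().mean(),  # placeholder
                    loss_weight_fn=lambda inp: 1.0,  # placeholder
                    mb_spec=self.config.mb_spec,
                )
                self.engine.step_lr_scheduler()
                self.log_wandb_tensorboard(global_step, stats)
                global_step += 1
        self.close_wandb_tensorboard()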

View File

@ -0,0 +1,464 @@
import gc
import os
import time
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import transformers
from tensordict import TensorDict
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from transformers import (
AutoConfig,
AutoModelForCausalLM,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
MicroBatchSpec,
SaveLoadMeta,
TrainEngine,
WeightUpdateMeta,
)
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from arealite.utils.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
apply_fsdp2,
create_fsdp_device_mesh,
fsdp2_clip_grad_norm_,
fsdp2_load_full_state_dict,
get_cosine_schedule_with_warmup,
)
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("FSDPEngine")
class FSDPEngine(TrainEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model = None
self.optimizer = None
self.tokenizer = None
# huggingface model config
self.model_config = None
# FSDP options
self.mixed_precision_policy = None
self.device_mesh = None
self.cpu_offload = None
# initialization
self.initialized = False
self.weight_update_group_initialized = False
# TODO: Handle the case when WORLD_SIZE is not set in launcher
self.world_size = int(os.environ["WORLD_SIZE"])
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
        """Initialize distributed environments and load the model."""
        assert addr is None, "FSDPEngine does not support remote initialization."
        assert pkg_version.is_version_greater_or_equal(
            "torch", "2.4.0"
        ), "arealite only supports FSDP2, which requires torch>=2.4.0"
if not dist.is_initialized():
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
dist.init_process_group(backend="nccl")
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
cast_forward_inputs=True,
)
self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
# sharding_strategy = ShardingStrategy.FULL_SHARD
self.cpu_offload = (
CPUOffloadPolicy() if self.config.fsdp.offload_params else None
)
fsdp_kwargs = {
"mesh": self.device_mesh,
"mp_policy": self.mixed_precision_policy,
"offload_policy": self.cpu_offload,
"reshard_after_forward": True,
}
# Wrap with FSDP2
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
self.model = model
if not self.config.init_from_scratch:
            # Load model from an initial checkpoint path,
# which should only be a huggingface checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
)
self.load(load_meta)
# Set up optimizer
if self.optimizer_config is not None:
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
self.initialized = True
def destroy(self):
"""Destroy the engine and release GPU memory."""
self.model = None
self.optimizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
self.initialized = False
def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for FSDP
raise NotImplementedError("DCP format saving is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._load_model_from_hf(meta.path)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for FSDP
raise NotImplementedError("DCP format loading is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._load_optimizer_state(meta.path)
def _save_optimizer_state(self, path: str):
# Save FSDP sharded state dict on each rank
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, shard_path)
dist.barrier()
def _load_optimizer_state(self, path: str):
# Load FSDP sharded state dict
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
optimizer_state_dict = torch.load(shard_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
dist.barrier()
def _save_model_to_hf(
self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast]
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
os.makedirs(path, exist_ok=True)
# FSDP2 checkpoint saving
# Get full state dict with FSDP2
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
# save huggingface model on rank 0
if dist.get_rank() == 0:
os.makedirs(path, exist_ok=True)
self.model.save_pretrained(path, state_dict=state_dict)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
dist.barrier()
def _load_model_from_hf(self, path: str):
"""Load model from HuggingFace format."""
if dist.get_rank() == 0:
full_state = get_state_dict_from_repo_id_or_path(path)
else:
full_state = {}
fsdp2_load_full_state_dict(
self.model,
full_state,
self.cpu_offload,
tie_word_embeddings=self.model_config.tie_word_embeddings,
)
def upload_weights(self, meta: WeightUpdateMeta):
if meta.type == "nccl":
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
self._update_weights_from_distributed()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer)
# dist.barrier() are called when _save_model_to_hf finished
if dist.get_rank() == 0:
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
else:
raise ValueError(f"Unknown weight update type {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
raise NotImplementedError(
"Distributed weight update is not implemented for FSDPEngine yet. "
)
def _update_weights_from_distributed(self):
raise NotImplementedError(
"Distributed weight update is not implemented for FSDPEngine yet. "
)
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def _prepare_mb_list(
self, input_: TensorDict, mb_spec: MicroBatchSpec
) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
mb_spec,
)
mb_list = pad_mb_list(mb_list, pad_value=0.0)
        # NOTE: We unsqueeze here because huggingface transformer models require
        # packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
return mb_list
def train_batch(
self,
input_: TensorDict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self._prepare_mb_list(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
# Scale loss for accumulation
# Revert gradient averaging across dp ranks
loss_scale *= self.world_size
loss *= loss_scale
loss.backward()
grad_norm = fsdp2_clip_grad_norm_(
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
)
        if not torch.isfinite(grad_norm):
            self.optimizer.zero_grad()
            update_successful = False
        else:
            self.optimizer.step()
            update_successful = True
        current_lr = self.lr_scheduler.get_last_lr()[0]
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_list = self._prepare_mb_list(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
mb_spec: MicroBatchSpec,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_, mb_spec)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)

View File

@ -0,0 +1,6 @@
from arealite.api.cli_args import SGLangEngineConfig
class SGLangEngine:
def __init__(self, config: SGLangEngineConfig):
pass

View File

@ -1,7 +1,6 @@
import asyncio
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@ -18,7 +17,7 @@ from arealite.api.io_struct import (
RolloutStat,
WeightUpdateMeta,
)
from realhf.base import logging, pkg_version
from realhf.base import logging, name_resolve, names, pkg_version
if TYPE_CHECKING:
from arealite.api.workflow_api import RolloutWorkflow
@ -372,26 +371,41 @@ class RemoteSGLangEngine(InferenceEngine):
def _update_weights(self, meta: WeightUpdateMeta):
if meta.type == "disk":
# Update weights from disk
# Wait for model checkpoints of meta.version
update_name = names.update_weights_from_disk(
self.config.experiment_name, self.config.trial_name, meta.model_version
)
save_timestamp = int(name_resolve.wait(update_name, timeout=120))
load_timestamp = time.time_ns()
logger.info(
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
)
try:
jobs = [
self.aupdate_weights_from_disk(addr, meta.path)
for addr in self.addresses
]
loop = asyncio.new_event_loop()
# asyncio event loop should be manually set when running asyncio stuff in another thread
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*jobs))
finally:
loop.close()
logger.info(
f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
)
self.set_version(meta.model_version)
else:
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
async def aupdate_weights_from_disk(self, addr, path: str):
response, _ = await self.arequest_with_retry(
response = await self.arequest_with_retry(
endpoint="/update_weights_from_disk",
payload=dict(model_path=path, allow_interrupt=True),
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
target_server=addr,
target_addr=addr,
)
res = await response.json()
assert res["success"]

View File

@ -0,0 +1,161 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
"""Test script for HF Engine implementation."""
import os
from typing import Dict
import pytest
import torch
from tensordict import TensorDict
from transformers import AutoTokenizer
from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
from arealite.engine.fsdp_engine import FSDPEngine
VOCAB_SIZE = 100
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
MODEL_PATH = "Qwen/Qwen2-0.5B"
@pytest.fixture(scope="module")
def mock_input(
batch_size=5,
min_seqlen=10,
max_seqlen=20,
device="cuda:0",
) -> Dict:
"""Create mock padded input data (same format for huggingface) for testing.
Returns a dict with input_ids, attention_mask, and position_ids.
"""
pad_token_id = 0
seqlens = torch.randint(
min_seqlen, max_seqlen, (batch_size,), dtype=torch.int, device=device
)
max_seqlen = int(max(seqlens))
input_ids = torch.randint(
0, VOCAB_SIZE, (batch_size, max_seqlen), dtype=torch.long, device=device
)
attn_mask = torch.zeros((batch_size, max_seqlen), dtype=torch.bool, device=device)
attn_mask[
torch.arange(0, max_seqlen, device=device).unsqueeze(0) < seqlens.unsqueeze(1)
] = 1
input_ids.masked_fill_(~attn_mask, pad_token_id)
return TensorDict(
input_ids=input_ids,
attention_mask=attn_mask,
)
def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
"""Mock loss function for testing."""
return torch.mean(logits)
@pytest.fixture(scope="module")
def engine():
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7777"
engine_config = TrainEngineConfig(
experiment_name="test-fsdp-engine",
trial_name="test0",
path=MODEL_PATH,
optimizer=OptimizerConfig(),
)
engine = FSDPEngine(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
print("✓ Engine created successfully")
yield engine
@torch.no_grad()
def test_forward_microbatch(engine, mock_input):
engine.eval()
x2 = (
engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
)
.squeeze(0)
.mean(-1)
)
x1 = (
engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
)
.squeeze(0)
.mean(-1)
)
input_ids = mock_input["input_ids"]
assert x1.shape[:1] == input_ids.shape[:1]
assert x2.shape[:1] == input_ids.shape[:1]
assert torch.allclose(x1, x2, atol=1e-1, rtol=1e-2), (x1 - x2).abs().max().item()
@torch.no_grad()
def test_eval_batch(engine, mock_input):
engine.eval()
eval_result = engine.eval_batch(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
loss_fn=mock_loss_fn,
loss_weight_fn=lambda x: x["cu_seqlens"][-1],
)
assert isinstance(eval_result, torch.Tensor), "Evaluation should return a tensor"
assert eval_result.is_cuda, "Evaluation tensor should be on CUDA device"
assert eval_result is not None, "Evaluation should return a loss value"
print(f"✓ Evaluation successful, loss: {eval_result.item()}")
def test_train_batch(engine, mock_input):
engine.train()
train_result = engine.train_batch(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
loss_fn=mock_loss_fn,
loss_weight_fn=lambda x: x["cu_seqlens"][-1],
)
assert isinstance(train_result, dict), "Training should return a dictionary"
assert train_result["grad_norm"] is not None
assert train_result["lr"] is not None
print("✓ Training successful")
@torch.no_grad()
def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
path = tmp_path_factory.mktemp("hf_engine_test")
save_load_meta = SaveLoadMeta(
path=path,
weight_format="hf",
tokenizer=tokenizer,
with_optim=True,
base_model_path=None,
)
old = engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
)
engine.save(save_load_meta)
for name, param in engine.model.named_parameters():
param.zero_()
engine.load(save_load_meta)
new = engine.forward(
input_=mock_input,
mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
)
assert torch.allclose(old, new)

View File

@ -14,9 +14,9 @@ from arealite.api.cli_args import (
InferenceEngineConfig,
SGLangConfig,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import name_resolve, network, seeding
from realhf.base import network
EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
@ -186,3 +186,50 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
# exit
engine.destroy()
def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
# setup FSDP engine
from arealite.api.cli_args import OptimizerConfig, TrainEngineConfig
from arealite.api.io_struct import FinetuneSpec
from arealite.engine.fsdp_engine import FSDPEngine
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7777"
engine_config = TrainEngineConfig(
experiment_name=EXPR_NAME,
trial_name=TRIAL_NAME,
path=MODEL_PATH,
optimizer=OptimizerConfig(),
)
engine = FSDPEngine(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
# setup name resolve
import realhf.base.name_resolve as name_resolve
from realhf.api.cli_args import NameResolveConfig
nfs_record_root = tmp_path_factory.mktemp("nfs_record_path")
name_resolve_config = NameResolveConfig(type="nfs", nfs_record_root=nfs_record_root)
name_resolve.reconfigure(name_resolve_config)
# initialize SGLang remote engine
from arealite.api.cli_args import InferenceEngineConfig
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
inf_engine = RemoteSGLangEngine(config)
# test update weights
path = tmp_path_factory.mktemp("upload_weights_from_disk")
update_weight_meta = WeightUpdateMeta(
type="disk", path=path, alloc_mode=None, comm_backend=None, model_version=100
)
future = inf_engine.update_weights(update_weight_meta)
engine.upload_weights(update_weight_meta)
future.result()
assert inf_engine.get_version() == 100

View File

@ -0,0 +1,69 @@
import pytest
import torch
from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec
from arealite.utils.data import (
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_sequences_to_tensors,
reorder_list,
split_packed_tensor_dict_into_mbs,
unpack_sequence,
)
BS = 16
MAX_ANSWER_LEN = 16
MAX_PROMPT_LEN = 8
VOCAB_SIZE = 100
@pytest.fixture
def mock_padded_data():
prompt_lens = torch.randint(1, MAX_PROMPT_LEN, size=(BS,))
answer_lens = torch.randint(1, MAX_ANSWER_LEN, size=(BS,))
all_data = []
for prompt_len, ans_len in zip(prompt_lens, answer_lens):
prompt_len = int(prompt_len)
ans_len = int(ans_len)
seq = dict(
input_ids=torch.randint(0, VOCAB_SIZE, size=(prompt_len + ans_len,)),
prompt_mask=torch.tensor([1] * prompt_len + [0] * ans_len),
logprobs=torch.randn(prompt_len + ans_len),
position_ids=torch.arange(prompt_len + ans_len),
)
all_data.append(TensorDict(seq))
return pad_sequences_to_tensors(all_data)
@pytest.mark.parametrize("max_tokens_per_mb", [24, 36, 48, 100])
@pytest.mark.parametrize("n_mbs", [1, 2, 4, 8])
def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
mb_spec = MicroBatchSpec(n_mbs, max_tokens_per_mb)
# Unpad and split to microbatches
packed_data = pack_tensor_dict(mock_padded_data)
original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
split_result = split_packed_tensor_dict_into_mbs(packed_data, mb_spec)
reordered_lens = [original_lens[i] for i in split_result.forward_indices]
# assert microbatch split result does not violate requirements
assert len(split_result.mbs) >= n_mbs
# test reorder back
for key in split_result.mbs[0].keys():
if key in ["cu_seqlens", "max_seqlen"]:
continue
# assert microbatch split result does not violate requirements
for mb in split_result.mbs:
assert mb[key].shape[0] <= max_tokens_per_mb
x = torch.cat([mb[key] for mb in split_result.mbs])
xs = unpack_sequence(x, lens=reordered_lens)
xs = reorder_list(xs, split_result.backward_indices)
x = torch.cat(xs)
assert torch.allclose(x, packed_data[key])
y = pad_and_stack_tensors_along_first_dim(xs)
assert torch.allclose(mock_padded_data[key], y)

156
arealite/trainer/sft.py Normal file
View File

@ -0,0 +1,156 @@
import time
from typing import Dict
import torch
import torch.distributed as dist
import torch.utils.data
from arealite.api.io_struct import SaveLoadMeta
from arealite.api.trainer_api import Trainer
from arealite.utils.functional import gather_logprobs
from arealite.utils.logging import record_timing
from realhf.api.core.data_api import tabulate_stats
from realhf.base import logging, stats_tracker
def compute_packed_sft_loss(
logits: torch.Tensor,
input_: Dict[str, torch.Tensor],
) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"]
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
prompt_mask = input_["prompt_mask"].bool()
logits = logits.float()
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
logprobs = torch.where(prompt_mask, 0, logprobs)
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
m.numel() - m.count_nonzero()
)
    # Logging stats
stats_tracker.denominator(
n_seqs=torch.ones(
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
),
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=prompt_mask.logical_not(),
prompt_tokens=prompt_mask,
)
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
return loss
class SFTTrainer(Trainer):
def train(self):
total_epochs = self.config.exp_ctrl.total_train_epochs
steps_per_epoch = len(self.train_dataloader)
self.log(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
global_step = 0
start_time = time.monotonic()
for epoch in range(total_epochs):
for step, data in enumerate(self.train_dataloader):
self.engine.train()
timing_stats = {}
with record_timing("timeperf/train_step", timing_stats):
with stats_tracker.scope("sft"):
stats = self.engine.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"]
.logical_not()
.count_nonzero(),
mb_spec=self.config.mb_spec,
)
self.engine.step_lr_scheduler()
stats_tracker.scalar(**stats)
if self.save_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
self.log("Saving model ...")
with record_timing("timeperf/save", timing_stats):
save_path = self.get_save_checkpoint_path(
self.config, epoch, step, global_step
)
meta = SaveLoadMeta(
path=save_path,
weight_format="hf",
with_optim=False,
tokenizer=self.tokenizer,
base_model_path=self.config.tokenizer_path,
)
self.engine.save(meta)
if self.eval_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
if dist.get_rank() == 0:
self.log("Running evaluation ...")
with record_timing("timeperf/eval", timing_stats):
self.evaluate()
stats = stats_tracker.export()
stats.update(timing_stats)
self.log_wandb_tensorboard(global_step, stats)
self.log(
f"Epoch {epoch+1}/{total_epochs} "
f"Step {step+1}/{steps_per_epoch} "
f"Train step {global_step + 1}/{total_epochs * steps_per_epoch} done."
)
self.log(
f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}"
)
self.log(f"SFT training stats:\n{tabulate_stats(stats)}")
global_step += 1
self.log(
f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
)
self.close_wandb_tensorboard()
def evaluate(self):
if self.valid_dataloader is None:
return
self.engine.eval()
for data in self.valid_dataloader:
with stats_tracker.scope("sft-eval"):
# No need to log anything. Logging will be handled outside
# via stats_tracker.export().
self.engine.eval_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"]
.logical_not()
.count_nonzero(),
mb_spec=self.config.mb_spec,
)
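# [Editor's illustrative sketch, not part of this diff] How compute_packed_sft_loss
# aligns targets and masks prompts for a single packed sequence (values made up):
# logits at position t score the token at position t + 1, so both input_ids and
# prompt_mask are rolled left by one; positions whose next token is a prompt
# token (or the wrapped-around final position) drop out of the loss.
import torch

input_ids = torch.tensor([11, 12, 13, 21, 22])  # 3 prompt tokens + 2 answer tokens
prompt_mask = torch.tensor([1, 1, 1, 0, 0]).bool()
targets = torch.roll(input_ids, shifts=-1, dims=-1)  # [12, 13, 21, 22, 11]
shifted_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)  # [T, T, F, F, T]
loss_positions = ~shifted_mask  # only positions 2 and 3 contribute to the loss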

491
arealite/utils/data.py Normal file
View File

@ -0,0 +1,491 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
# Pad/unpad operations are modified from flash-attention under BSD-3 license.
# Copyright (c) 2023, Tri Dao.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from tensordict import TensorDict
from arealite.api.cli_args import MicroBatchSpec
from realhf.base import datapack, logging
logger = logging.getLogger("data utils")
def reorder_list(xs: List, indices: List[int]) -> List:
assert len(set(indices)) == len(xs)
return [xs[i] for i in indices]
def dict_map(x: Dict, fn: Callable) -> Dict:
return {k: fn(v) for k, v in x.items()}
def dict_of_list2list_of_dict(
dict_of_lists: Dict[str, List[Any]],
) -> List[Dict[str, Any]]:
if not dict_of_lists:
return []
keys = list(dict_of_lists.keys())
length = len(dict_of_lists[keys[0]])
for key, value_list in dict_of_lists.items():
if len(value_list) != length:
raise ValueError(
f"All lists must have the same length. Key '{key}' has length {len(value_list)}, expected {length}"
)
return [{key: dict_of_lists[key][i] for key in keys} for i in range(length)]
def list_of_dict2dict_of_list(
list_of_dicts: List[Dict[str, Any]],
) -> Dict[str, List[Any]]:
if not list_of_dicts:
return {}
keys = list(list_of_dicts[0].keys())
for i, dict_item in enumerate(list_of_dicts):
if set(dict_item.keys()) != set(keys):
raise ValueError(
f"All dictionaries must have the same keys. Dictionary at index {i} has keys {set(dict_item.keys())}, expected {set(keys)}"
)
return {key: [dict_item[key] for dict_item in list_of_dicts] for key in keys}
def pad_sequences_to_tensors(
sequence_list: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
if not sequence_list:
return TensorDict()
max_length = max(len(seq) for item in sequence_list for seq in item.values())
result = {}
for key in sequence_list[0].keys():
padded = []
for item in sequence_list:
x = item[key]
if not torch.is_tensor(x):
x = torch.tensor(x)
padded.append(
torch.nn.functional.pad(
x, (0, max_length - len(item[key])), value=pad_value
)
)
result[key] = torch.stack(padded)
attention_mask = [
[1] * len(next(iter(item.values())))
+ [0] * (max_length - len(next(iter(item.values()))))
for item in sequence_list
]
result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
return TensorDict(result, batch_size=[result["attention_mask"].shape[0]])
def unpad_input(
hidden_states, attention_mask
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
return (
rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
output = hidden_states.new_zeros(batch * seqlen)
output[indices] = hidden_states
return rearrange(output, "(b s) ... -> b s ...", b=batch)
def concat_padded_tensors(
tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
) -> Dict[str, torch.Tensor]:
"""Concatenate and pad tensors from multiple padded tensor dictionaries."""
if not tensor_dicts:
return {}
# Find max sequence length across all dictionaries
lens = []
for tensor_dict in tensor_dicts:
for key, tensor in tensor_dict.items():
if key != "attention_mask" and len(tensor.shape) == 2:
lens.append(tensor.shape[1])
break
max_length = max(lens)
attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1)
result = {}
# Process each key
for key in tensor_dicts[0].keys():
tensors_to_concat = []
for tensor_dict in tensor_dicts:
tensor = tensor_dict[key]
# Skip 1D tensors like rewards
if len(tensor.shape) == 1:
tensors_to_concat.append(tensor)
continue
current_length = tensor.shape[1]
if current_length < max_length:
# Pad tensor to max_length
pad_width = max_length - current_length
if key == "attention_mask":
# Pad attention mask with 0s
padding = torch.zeros(
(tensor.shape[0], pad_width), dtype=tensor.dtype
)
else:
# Pad feature tensors with pad_value
padding = torch.full(
(tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
)
tensor = torch.cat([tensor, padding], dim=1)
tensors_to_concat.append(tensor)
result[key] = torch.cat(tensors_to_concat, dim=0)
if "attention_mask" not in result:
result["attention_mask"] = attn_mask
return result
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
"""Move tensors in a dictionary to the specified device."""
return {
key: value.to(device) if torch.is_tensor(value) else value
for key, value in data.items()
}
def unpack_sequence(
x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
lens: Optional[List[int]] = None,
dim: int = 0,
):
"""Unpack a sequence tensor into a list of tensors based on cumulative sequence lengths."""
if lens is not None:
return torch.split(x, lens, dim=dim)
if cu_seqlens is not None:
return torch.split(
x, (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist(), dim=dim
)
raise ValueError("Either cu_seqlens or input_lens must be provided.")
def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: List[int]) -> List[List[int]]:
group_indices = datapack.ffd_allocate(
lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
)
group_indices = sorted([sorted(g) for g in group_indices])
return group_indices
def allocate_balanced_mbs_synced(
mb_spec: MicroBatchSpec,
lens: List[int],
group: Optional[dist.ProcessGroup] = None,
) -> List[List[int]]:
group_indices = allocate_balanced_mbs(mb_spec, lens)
if not dist.is_initialized():
return group_indices
all_n_mbs = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(all_n_mbs, len(group_indices), group=group)
if all(mbs == len(group_indices) for mbs in all_n_mbs):
return group_indices
return allocate_balanced_mbs_synced(
MicroBatchSpec.new(mb_spec, n_mbs=max(all_n_mbs)), lens
)
def pack_tensor_dict(data: TensorDict):
"""Pack a tensordict of shape [B, S, ...] into [total_length, ...], leaving other keys unchanged.
Args:
data (Dict[str, Any]): Dictionary containing tensors to be packed. Should contain key "attention_mask" with shape [B, S].
Returns:
Dict[str, Any]: Dictionary with packed tensors. The "attention_mask" key will be replaced by "cu_seqlens" with shape [B+1].
"""
assert "attention_mask" in data, "Input data must contain 'attention_mask' key."
attention_mask = data["attention_mask"]
assert attention_mask.ndim == 2, "Attention mask must be a 2D tensor."
bs = attention_mask.shape[0]
seq_len = attention_mask.shape[1]
# Calculate cumulative sequence lengths
lens = attention_mask.sum(dim=1, dtype=torch.int32)
max_seqlen = lens.max().item()
cu_seqlens = torch.cumsum(lens, dim=0)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
total_length = cu_seqlens[-1].item()
# Pack tensors
packed_data = {}
for key, value in data.items():
if key == "attention_mask":
packed_data["cu_seqlens"] = cu_seqlens
packed_data["max_seqlen"] = max_seqlen
        # tensors of shape [B, S, ...]
elif (
torch.is_tensor(value)
and value.ndim >= 2
and value.shape[0] == bs
and value.shape[1] == seq_len
):
packed_tensor = torch.empty(
total_length, *value.shape[2:], dtype=value.dtype, device=value.device
)
# Fill the packed tensor with values from the original tensor
for i in range(bs):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
packed_tensor[start:end] = value[i][: end - start]
packed_data[key] = packed_tensor
else:
packed_data[key] = value
return TensorDict(**packed_data)
def pad_and_stack_tensors_along_first_dim(tensor_list: List[torch.Tensor]):
max_length = max(tensor.shape[0] for tensor in tensor_list)
n_dim = tensor_list[0].ndim
assert all(
tensor.ndim == n_dim for tensor in tensor_list
), "All tensors must have the same number of dimensions."
padded_tensors = []
for tensor in tensor_list:
pad_mode = (0,) * (2 * (n_dim - 1)) + (0, max_length - tensor.shape[0])
padded_tensor = F.pad(tensor, pad_mode, value=0.0)
padded_tensors.append(padded_tensor)
return torch.stack(padded_tensors, dim=0)
@dataclass
class MicroBatchList:
data: Dict[str, Any]
mb_spec: MicroBatchSpec
mbs: List[TensorDict]
forward_indices: List[int]
backward_indices: List[int]
group_lens: List[int]
padded_mbs: Optional[List[TensorDict]] = None
padding_lengths: Optional[List[int]] = None
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
def split_packed_tensor_dict_into_mb_list(
data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
) -> MicroBatchList:
"""Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
Args:
data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
Returns:
MicroBatchList: A structure containing the split micro-batches and metadata.
"""
assert (
"cu_seqlens" in data
), "Input data must be packed and contain 'cu_seqlens' key."
if mb_spec.max_tokens_per_mb is None:
mb_spec = MicroBatchSpec.new(
mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
)
cu_seqlens = data["cu_seqlens"]
bs = cu_seqlens.shape[0] - 1
total_lens = int(cu_seqlens[-1])
input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
# check tensor shape, split only 1d tensors with length "total_lens"
to_split = {}
not_to_split = {}
for key, value in data.items():
if key == "cu_seqlens" or key == "max_seqlen":
continue
if not torch.is_tensor(value) or value.numel() != total_lens:
not_to_split[key] = value
else:
to_split[key] = value
# split
group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group)
splitted_lens = [
[input_lens[i] for i in group_index] for group_index in group_indices
]
group_lens = [sum(x) for x in splitted_lens]
forward_indices = datapack.flat2d(group_indices)
backward_indices = np.zeros(bs, dtype=np.int64)
backward_indices[forward_indices] = np.arange(bs)
def _split(tensor):
"""Split and pad a tensor based on forward indices and lens."""
# Unpack the sequence
unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
# Reorder according to forward indices
reordered = reorder_list(unpacked, forward_indices)
reordered = torch.cat(reordered)
# Unpack again according to split lens
splitted = unpack_sequence(reordered, lens=group_lens)
return splitted
to_split = dict_map(to_split, lambda x: _split(x))
mbs = dict_of_list2list_of_dict(to_split)
results = []
# organize splitted micro batches
assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
max_seqlen = max(lens)
lens = torch.tensor(lens, device="cuda")
batch_cu_seqlens = torch.nn.functional.pad(
lens.cumsum(0, dtype=torch.int), (1, 0)
)
results.append(
TensorDict(
**mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
)
)
return MicroBatchList(
data=data,
mbs=results,
mb_spec=mb_spec,
forward_indices=forward_indices,
backward_indices=backward_indices,
group_lens=group_lens,
)
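# Illustrative sketch (not part of the diff): how forward_indices and backward_indices
# relate. If three sequences are regrouped in the order [2, 0, 1], backward_indices
# records where each original sequence ended up, so outputs can later be restored to
# the original order.
import numpy as np

_demo_forward = [2, 0, 1]
_demo_backward = np.zeros(3, dtype=np.int64)
_demo_backward[_demo_forward] = np.arange(3)
# _demo_backward == array([1, 2, 0])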
def pad_packed_tensor_dict(
data: TensorDict,
pad_to_length: int,
pad_value: float = 0.0,
) -> Tuple[Dict[str, Any], int]:
"""Pad a packed tensor dict to a specified length.
This function assumes that the input data contains "cu_seqlens" and "max_seqlen" key,
and all other tensors of shape [total_length, ] will be padded to `pad_to_length`.
This function will pad a new sequence filled with `pad_value` to the end of each tensor,
and update the "cu_seqlens" and "max_seqlen" keys accordingly.
Args:
data (TensorDict): Dictionary containing tensors to be packed.
        pad_to_length (int): The length to pad all packed 1-D tensors to.
Returns:
Dict[str, torch.Tensor]: Dictionary with padded tensors and modified "cu_seqlens" and
"max_seqlen".
int: The pad length.
"""
assert "cu_seqlens" in data, "Input data must contain 'cu_seqlens' key."
assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key."
total_length = data["cu_seqlens"][-1].item()
pad_length = pad_to_length - total_length
assert (
pad_length >= 0
), f"pad_to_length {pad_to_length} must be greater than or equal to total length {total_length}."
cu_seqlens = data["cu_seqlens"]
max_seqlen = data["max_seqlen"]
new_cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_to_length)
new_max_seqlen = max(max_seqlen, pad_length)
padded_data = {}
for key, value in data.items():
if key == "cu_seqlens":
padded_data[key] = new_cu_seqlens
elif key == "max_seqlen":
padded_data[key] = new_max_seqlen
elif torch.is_tensor(value) and value.numel() == total_length:
# Pad the tensor to the new total length
padded_tensor = torch.nn.functional.pad(
value, (0, pad_length), value=pad_value
)
padded_data[key] = padded_tensor
else:
padded_data[key] = value
return padded_data, pad_length
def pad_mb_list(
mb_list: MicroBatchList,
pad_value: float = 0.0,
) -> MicroBatchList:
padded_mb_inputs, pad_lengths = [], []
pad_to_lengths = []
for mb, l in zip(mb_list.mbs, mb_list.group_lens):
        # NOTE: The GPU page size is 2 MiB. With hidden size 4096 and bf16
        # activations, one page holds 2 MiB / (4096 * 2 B) = 256 hidden vectors,
        # so each micro-batch length is rounded up to a multiple of 256.
        pad_to_length = (l + 255) // 256 * 256
padded_mb, pad_len = pad_packed_tensor_dict(
mb, pad_to_length, pad_value=pad_value
)
padded_mb_inputs.append(padded_mb)
pad_lengths.append(pad_len)
pad_to_lengths.append(pad_to_length)
logger.debug(
f"Microbatch original lengths: {mb_list.group_lens}, padded to {pad_to_lengths}."
)
mb_list.padded_mbs = padded_mb_inputs
mb_list.padding_lengths = pad_lengths
return mb_list
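# Illustrative arithmetic (not part of the diff): the rounding used in pad_mb_list.
_demo_len = 1000
_demo_padded = (_demo_len + 255) // 256 * 256   # 1024, the next multiple of 256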
def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
assert "cu_seqlens" in data, "Input data must contain 'cu_seqlens' key."
assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key."
total_length = data["cu_seqlens"][-1].item()
for key, value in data.items():
if key == "cu_seqlens" or key == "max_seqlen":
continue
        elif torch.is_tensor(value) and value.numel() == total_length:
            data[key] = value.unsqueeze(dim=0)
return data
def unsqueeze_mb_list(
mb_list: MicroBatchList,
) -> MicroBatchList:
"""Unsqueeze the packed tensordict in the micro-batch list."""
new_mbs = []
new_padded_mbs = []
for i, mb in enumerate(mb_list.mbs):
new_mbs.append(unsqueeze_packed_tensor_dict(mb))
if mb_list.padded_mbs is not None:
new_padded_mbs.append(unsqueeze_packed_tensor_dict(mb_list.padded_mbs[i]))
mb_list.mbs = new_mbs
mb_list.padded_mbs = new_padded_mbs if mb_list.padded_mbs is not None else None
return mb_list
def amend_position_ids(data: TensorDict) -> TensorDict:
assert "attention_mask" in data, "Input data must contain 'attention_mask' key."
attn_mask = data["attention_mask"]
bs, seqlen = attn_mask.shape[:2]
position_ids = (
torch.arange(0, seqlen, dtype=torch.long, device=attn_mask.device)
.unsqueeze(0)
.expand(bs, -1)
)
    position_ids = position_ids.masked_fill(~attn_mask.bool(), 0)
data["position_ids"] = position_ids
return data
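# Usage sketch (illustrative, not part of the diff): position ids produced for a
# right-padded batch; padded positions are zeroed by the mask-fill above.
import torch

_demo_attn = torch.tensor([[1, 1, 1], [1, 1, 0]])
_demo_pos = torch.arange(3).unsqueeze(0).expand(2, -1)
_demo_pos = _demo_pos.masked_fill(~_demo_attn.bool(), 0)
# _demo_pos == tensor([[0, 1, 2], [0, 1, 0]])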

arealite/utils/fsdp.py Normal file

@ -0,0 +1,189 @@
import math
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from transformers import PreTrainedModel
from realhf.base import logging, pkg_version
logger = logging.getLogger("FSDPEngine")
if pkg_version.is_version_greater_or_equal("torch", "2.6.0"):
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
elif pkg_version.is_version_greater_or_equal("torch", "2.4.0"):
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
else:
CPUOffloadPolicy = None
FSDPModule = None
MixedPrecisionPolicy = None
fully_shard = None
    logger.warning("PyTorch versions below 2.4.0 are not supported by FSDPEngine.")
def fsdp2_clip_grad_norm_(
parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None
):
"""torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor"""
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
# prevent generators from being exhausted
parameters = list(parameters)
grads = [p.grad for p in parameters if p.grad is not None]
total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
_clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm
def create_fsdp_device_mesh(shard_size, world_size):
if shard_size < 0 or shard_size >= world_size:
device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)
)
else:
device_mesh = init_device_mesh(
"cuda",
mesh_shape=(world_size // shard_size, shard_size),
mesh_dim_names=("ddp", "fsdp"),
)
return device_mesh
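# Illustrative sketch (not part of the diff; requires an initialized CUDA/distributed
# environment, so shown as comments): resulting mesh shapes.
#   create_fsdp_device_mesh(shard_size=-1, world_size=8) -> ("fsdp",) mesh of shape (8,)
#   create_fsdp_device_mesh(shard_size=4,  world_size=8) -> ("ddp", "fsdp") mesh of shape (2, 4),
#     i.e. parameters are sharded within groups of 4 ranks and replicated across 2 groups.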
def apply_fsdp2(model, fsdp_kwargs, wrap_policy):
"""model: AutoModelForCausalLM"""
assert (
CPUOffloadPolicy is not None
), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", list())
fsdp_transformer_layer_cls_to_wrap = (
wrap_policy.transformer_layer_cls_to_wrap if wrap_policy is not None else list()
)
if not fsdp_transformer_layer_cls_to_wrap:
fsdp_transformer_layer_cls_to_wrap = default_transformer_cls_names_to_wrap
if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]
assert (
len(fsdp_transformer_layer_cls_to_wrap) > 0
and fsdp_transformer_layer_cls_to_wrap[0] is not None
)
modules = []
for name, module in model.named_modules():
if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (
isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
):
modules.append(module)
    for module in modules:
        fully_shard(module, **fsdp_kwargs)
fully_shard(
model, **fsdp_kwargs
) # fsdp2 will not reshard_after_forward for root module
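# Usage sketch (illustrative, not part of the diff; the dtype/policy choices below are
# assumptions for the example, not values mandated by this file):
#   model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
#   mesh = create_fsdp_device_mesh(shard_size=-1, world_size=dist.get_world_size())
#   fsdp_kwargs = {
#       "mesh": mesh,
#       "mp_policy": MixedPrecisionPolicy(
#           param_dtype=torch.bfloat16, reduce_dtype=torch.float32
#       ),
#   }
#   apply_fsdp2(model, fsdp_kwargs, wrap_policy=None)  # wraps decoder layers, then the root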
def fsdp2_load_full_state_dict(
model: PreTrainedModel,
full_state: dict,
cpu_offload=None,
tie_word_embeddings=False,
):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args:
model (`torch.nn.Module`): The model to load the state dict into
full_state (`dict`): The full state dict to load, can only be on rank 0
"""
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)
device = torch.cuda.current_device()
model = model.to(device=device, non_blocking=True)
cpu_offload = cpu_offload is not None
options = StateDictOptions(
full_state_dict=True,
cpu_offload=cpu_offload,
broadcast_from_rank0=True,
strict=not tie_word_embeddings,
)
set_model_state_dict(model, full_state, options=options)
if tie_word_embeddings:
model.tie_weights()
# rotary_emb is not in state_dict, so we need to broadcast it manually
for name, buf in model.named_buffers():
dist.broadcast(buf, src=0)
if cpu_offload:
model.to("cpu", non_blocking=True)
for buf in model.buffers():
buf.data = buf.data.to(device)
def get_cosine_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
    Create a schedule with a learning rate that follows a cosine curve from the initial lr set
    in the optimizer down to ``min_lr_ratio`` times that lr, after a warmup period during which
    it increases linearly from ``min_lr_ratio`` times the initial lr up to the initial lr.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the
            max value to the minimum following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return min_lr_ratio + (1.0 - min_lr_ratio) * (
float(current_step) / float(max(1, num_warmup_steps))
)
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(min_lr_ratio, x * coef + intercept)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
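# Usage sketch (illustrative, not part of the diff): warm up for 100 steps, then decay
# toward min_lr_ratio * base_lr over 1000 total steps.
_demo_param = torch.nn.Parameter(torch.zeros(1))
_demo_opt = torch.optim.AdamW([_demo_param], lr=2e-5)
_demo_sched = get_cosine_schedule_with_warmup(
    _demo_opt, num_warmup_steps=100, num_training_steps=1000, min_lr_ratio=0.1
)
# Call _demo_sched.step() after each optimizer step; _demo_sched.get_last_lr() tracks the lr.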


@ -0,0 +1,8 @@
import torch
@torch.compile
def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
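# Usage sketch (illustrative, not part of the diff): per-token log-probs of the next
# token for a toy batch with vocabulary size 32.
_demo_logits = torch.randn(1, 8, 32)                # [batch, seq, vocab]
_demo_ids = torch.randint(0, 32, (1, 8))
_demo_logp = gather_logprobs(_demo_logits[:, :-1], _demo_ids[:, 1:])  # shape [1, 7]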


@ -0,0 +1,9 @@
import time
from contextlib import contextmanager
@contextmanager
def record_timing(name, timing_stats):
start_time = time.perf_counter()
yield
timing_stats[name] = time.perf_counter() - start_time
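# Usage sketch (illustrative, not part of the diff):
_demo_stats = {}
with record_timing("sleep", _demo_stats):
    time.sleep(0.01)
# _demo_stats["sleep"] is roughly 0.01 (seconds).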


@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import List
import torch
from tensordict import TensorDict


@ -0,0 +1,84 @@
import os
from typing import Dict
import torch
def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
"""
Obtain a state dictionary from either a Hugging Face repo ID or a local path.
Args:
repo_id_or_path (str): Either a Hugging Face repo ID (e.g., 'username/model-name')
or a local path to a directory containing model weights.
Returns:
Dict: The combined state dictionary from all .safetensors and .bin files.
"""
from safetensors.torch import load_file as safetensors_load
state_dict = {}
# Step 1: Identify if the input is a Hugging Face repo ID or local path
try:
from huggingface_hub.utils import HFValidationError, validate_repo_id
try:
validate_repo_id(repo_id_or_path)
is_hf_repo = True
except HFValidationError:
is_hf_repo = False
except ImportError:
is_hf_repo = False
if is_hf_repo:
from huggingface_hub import snapshot_download
# Step 2: Download the repo if it's a Hugging Face repo ID
local_path = snapshot_download(
repo_id=repo_id_or_path,
)
else:
# Assume it's a local path
local_path = repo_id_or_path
if not os.path.isdir(local_path):
raise ValueError(
f"Local path {local_path} does not exist or is not a directory, "
f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
)
# Step 3: Load all .safetensors and .bin files
file_paths_to_load = []
for filename in os.listdir(local_path):
filepath = os.path.join(local_path, filename)
if filename.endswith(".safetensors") or filename.endswith(".bin"):
file_paths_to_load.append(filepath)
def _load(filepath: str):
if filepath.endswith(".safetensors"):
state_dict = safetensors_load(filepath)
elif filepath.endswith(".bin"):
state_dict = torch.load(filepath, map_location="cpu")
else:
raise ValueError(f"{filepath} is not a torch bin or safetensor file.")
return state_dict
state_dict = {}
from concurrent.futures import ThreadPoolExecutor, as_completed
with ThreadPoolExecutor(
        max_workers=min(4, max(1, (os.cpu_count() or 1) // 8))
) as executor:
future_to_checkpoint = {
executor.submit(_load, path): path for path in file_paths_to_load
}
for future in as_completed(future_to_checkpoint):
path = future_to_checkpoint[future]
try:
sd = future.result()
state_dict.update(sd)
except Exception as e:
raise RuntimeError(f"Error loading checkpoint from {path}: {e}")
return state_dict
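# Usage sketch (illustrative, not part of the diff; the paths below are examples only):
#   state_dict = get_state_dict_from_repo_id_or_path(
#       "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
#   )
# or, with huggingface_hub installed, a hub repo id such as "Qwen/Qwen2.5-1.5B".
# The returned dict merges all *.safetensors / *.bin shards and can be fed to
# fsdp2_load_full_state_dict on rank 0.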


@ -0,0 +1,53 @@
experiment_name: gsm8k-sft
trial_name: trial0
n_nodes: 1
n_gpus_per_node: 8
cluster:
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve
model:
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
pad_mbs_to_max_tokens: true
gradient_checkpointing: false
bf16: true
optimizer:
type: adam
lr: 2e-5
weight_decay: 0.05
beta1: 0.9
beta2: 0.95
eps: 1e-5
lr_scheduler_type: cosine
gradient_clipping: 1.0
backend: fsdp
trainer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled
seed: 1
exp_ctrl:
total_train_epochs: 1
eval_freq_steps: 1
tokenizer_path: ${model.path}
mb_spec:
max_tokens_per_mb: 4096
train_dataset:
type: gsm8k-sft
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
valid_dataset:
type: gsm8k-sft
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4


@ -0,0 +1,88 @@
import os
import sys
import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import DataCollatorWithPadding
from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.trainer.sft import SFTTrainer
from arealite.utils.data import pad_sequences_to_tensors
from realhf.api.core.data_api import load_hf_tokenizer
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
def process(sample):
seq_token = tokenizer.encode(
sample["question"] + sample["answer"] + tokenizer.eos_token
)
prompt_token = tokenizer.encode(sample["question"])
prompt_mask = [1] * len(prompt_token) + [0] * (
len(seq_token) - len(prompt_token)
)
return {"input_ids": seq_token, "prompt_mask": prompt_mask}
dataset = dataset.map(process).remove_columns(["question", "answer"])
return dataset
def get_gsm8k_dataset(split, tokenizer, rank, world_size):
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
return process_gsm8k_sft_dataset(dataset, tokenizer)
def main_sft():
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
config: SFTConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.trainer.tokenizer_path)
# Create dataset and dataloaders
    assert config.train_dataset.type == "gsm8k-sft"
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", tokenizer, rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=pad_sequences_to_tensors,
drop_last=config.train_dataset.drop_last,
)
    assert config.valid_dataset.type == "gsm8k-sft"
valid_dataloader = StatefulDataLoader(
get_gsm8k_dataset("test", tokenizer, rank, world_size),
batch_size=config.valid_dataset.batch_size // world_size,
shuffle=config.valid_dataset.shuffle,
num_workers=config.valid_dataset.num_workers,
collate_fn=pad_sequences_to_tensors,
drop_last=config.valid_dataset.drop_last,
)
# Initialize engine
ft_spec = FinetuneSpec(
total_train_epochs=config.trainer.exp_ctrl.total_train_epochs,
dataset_size=len(train_dataloader),
train_batch_size=config.train_dataset.batch_size,
)
engine = FSDPEngine(config=config.model)
engine.initialize(None, ft_spec)
# Run training.
trainer = SFTTrainer(
config=config.trainer,
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
engine=engine,
inf_engine=None,
)
trainer.train()
if __name__ == "__main__":
main_sft()
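# Launch sketch (an assumption, not part of the diff): the script reads RANK and
# WORLD_SIZE from the environment, so a torchrun-style launcher is the natural fit, e.g.
#   torchrun --nproc-per-node 8 gsm8k_sft.py <config arguments accepted by load_expr_config>
# with the YAML config above supplied through those config arguments.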


@ -54,6 +54,7 @@ dependencies = [
"packaging",
"tabulate",
"torchdata",
"autoflake",
"gymnasium",
"tensordict",


@ -107,3 +107,7 @@ def training_samples(experiment_name, trial_name):
def experiment_status(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/experiment_status"
def update_weights_from_disk(experiment_name, trial_name, model_version):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/update_weights_from_disk/{model_version}"


@ -71,5 +71,6 @@ timeout-decorator
prettytable
swanlab[dashboard]
torchdata
autoflake
gymnasium
tensordict