mirror of https://github.com/inclusionAI/AReaL
PullRequest: 332 [lite] Support FSDP engines
Merge branch mzy/lite/engines of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/332 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fsdp2 engine * fix utils * add fsdp engine test * . * fsdp engine test passed * unsqueeze immediately before model inputs and after model outputs * add optimizer save/load, add position id calculation for input * . * format * not to squeeze * add train and eval api * . * . * improve fsdp engine data preprocessing * format * PullRequest: 337 [lite] Add SFT trainer example. * trainer log * minor changes * add update weights from disk * fix type annotation
This commit is contained in:
parent
7a438c0650
commit
15dfbe837c
|
@ -1,5 +1,15 @@
|
|||
import argparse
|
||||
import getpass
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from hydra import compose as hydra_compose
|
||||
from hydra import initialize as hydra_init
|
||||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
from realhf.api.cli_args import OptimizerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -12,8 +22,8 @@ class MicroBatchSpec:
|
|||
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
|
||||
},
|
||||
)
|
||||
max_tokens_per_mb: int = field(
|
||||
default=int(1e12),
|
||||
max_tokens_per_mb: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum tokens per micro-batch. When set, n_mbs becomes the minimum number of micro-batches",
|
||||
},
|
||||
|
@ -70,6 +80,59 @@ class GenerationHyperparameters:
|
|||
return GenerationHyperparameters(**args)
|
||||
|
||||
|
||||
# Train Engine Configs
|
||||
@dataclass
class FSDPWrapPolicy:
    """Names the transformer layer classes that FSDP should wrap."""

    # None means no explicit layer list is given.
    transformer_layer_cls_to_wrap: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of transformer layer names for FSDP to wrap."},
    )
|
||||
|
||||
|
||||
@dataclass
class FSDPEngineConfig:
    """FSDP-specific options for the training engine."""

    wrap_policy: Optional[FSDPWrapPolicy] = field(
        default=None,
        metadata={"help": "FSDP wrap policy, specifying model layers to wrap."},
    )
    offload_params: bool = field(
        default=False,
        metadata={"help": "Whether to offload FSDP parameters to CPU."},
    )
|
||||
|
||||
|
||||
@dataclass
class TrainEngineConfig:
    """Configuration of a training engine and the model it loads.

    ``experiment_name`` and ``trial_name`` default to OmegaConf's ``MISSING``
    marker. Without a default the class cannot be instantiated with no
    arguments, which is what ``field(default_factory=TrainEngineConfig)``
    does when this config is nested in an experiment config.
    """

    experiment_name: str = MISSING
    trial_name: str = MISSING
    path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
    attn_impl: str = field(
        default="flash_attention_2",
        metadata={
            "help": "Attention implementation for huggingface transformers model.",
            "choices": ["flash_attention_2"],
        },
    )
    init_from_scratch: bool = field(
        default=False, metadata={"help": "Initialize model weights randomly"}
    )
    init_critic_from_actor: bool = field(
        default=False,
        metadata={"help": "Initialize critic/reward model from LM checkpoint"},
    )

    # Training Backend Configuration
    gradient_checkpointing: bool = field(
        default=True, metadata={"help": "Enable gradient checkpointing"}
    )
    bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
    optimizer: Optional[OptimizerConfig] = field(
        default=None, metadata={"help": "Optimizer configuration"}
    )
    backend: str = ""
    fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SGLangConfig:
|
||||
"""Configuration for SGLang runtime. Refer to:
|
||||
|
@ -236,3 +299,332 @@ class InferenceEngineConfig:
|
|||
request_retries: int = field(
|
||||
default=3, metadata={"help": "Number of retries for failed requests."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class SGLangEngineConfig:
    """Placeholder config for the SGLang engine; it declares no fields yet."""
|
||||
|
||||
|
||||
@dataclass
class ExperimentSaveEvalControl:
    """Controls the frequency of model saving and evaluation during training.

    Manages independent counters for epochs, steps, and seconds. The model will be saved
    or evaluated when any specified frequency condition is met.

    Note:
        - Epoch: Number of full passes through the training dataset
        - Step: Number of individual training iterations
        - Seconds: Wall-clock time duration
    """

    total_train_epochs: int = field(
        default=1, metadata={"help": "Total number of epochs to train the model."}
    )

    # Save control
    save_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Save frequency in epochs. None disables epoch-based saving."
        },
    )
    save_freq_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Save frequency in steps. None disables step-based saving."},
    )
    save_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Save frequency in seconds. None disables time-based saving."
        },
    )

    # Checkpointing control
    ckpt_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in epochs. None uses save_freq_epochs. "
            "Checkpointing is used for recover. Previous checkpoint is overwritten to save space."
        },
    )
    ckpt_freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in steps. None disables step-based checkpointing."
        },
    )
    ckpt_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in seconds. None disables time-based checkpointing."
        },
    )

    # Evaluation control
    eval_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in epochs. None disables epoch-based evaluation."
        },
    )
    eval_freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in steps. None disables step-based evaluation."
        },
    )
    eval_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in seconds. None disables time-based evaluation."
        },
    )

    # Benchmark control
    benchmark_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after this number of steps. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    benchmark_n_seqs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after consuming this number of samples. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
|
||||
|
||||
|
||||
@dataclass
class WandBConfig:
    """Weights & Biases settings; the fields mirror ``wandb.init`` keyword arguments."""

    # "disabled" is the default mode.
    mode: str = "disabled"
    entity: Optional[str] = None
    project: Optional[str] = None
    name: Optional[str] = None
    job_type: Optional[str] = None
    group: Optional[str] = None
    notes: Optional[str] = None
    tags: Optional[List[str]] = None
    config: Optional[Dict] = None
|
||||
|
||||
|
||||
@dataclass
class SwanlabConfig:
    """SwanLab logging settings."""

    project: Optional[str] = None
    name: Optional[str] = None
    config: Optional[Dict] = None
    logdir: Optional[str] = None
    mode: Optional[str] = "local"
    # Read once, when the class is defined, from the SWANLAB_API_KEY env var.
    api_key: Optional[str] = os.getenv("SWANLAB_API_KEY", None)
|
||||
|
||||
|
||||
@dataclass
class TensorBoardConfig:
    """TensorBoard settings; ``path`` is the log directory, None by default."""

    path: Optional[str] = None
|
||||
|
||||
|
||||
def get_user_tmp():
    """Return ``/home/<user>/.cache/realhf``, creating the directory if needed.

    The user name comes from ``getpass.getuser()``.
    """
    user = getpass.getuser()
    # NOTE(review): the "/home" prefix is hard-coded; confirm this is intended
    # for users whose home directory lives elsewhere.
    user_tmp = os.path.join("/home", user, ".cache", "realhf")
    os.makedirs(user_tmp, exist_ok=True)
    return user_tmp
|
||||
|
||||
|
||||
@dataclass
class NameResolveConfig:
    """Settings of the distributed KV store used for name resolving."""

    type: str = field(
        default="nfs",
        metadata={
            "help": "Type of the distributed KV store for name resolving.",
            "choices": ["nfs", "etcd3", "ray"],
        },
    )
    nfs_record_root: str = field(
        default="/tmp/areal/name_resolve",
        metadata={
            "help": "Record root for NFS name resolving. Should be available in all nodes."
        },
    )
    etcd3_addr: str = field(
        default="localhost:2379", metadata={"help": "Address of the ETCD3 server."}
    )
    ray_actor_name: str = field(
        default="ray_kv_store",
        metadata={"help": "Name of the distributed Ray KV store."},
    )
|
||||
|
||||
|
||||
@dataclass
class ClusterSpecConfig:
    """Cluster description: name resolving, file root, and slurm-related settings."""

    name_resolve: NameResolveConfig = field(
        default_factory=NameResolveConfig,
        metadata={"help": "Name resolving configuration."},
    )
    cluster_name: str = field(
        default="local",
        metadata={"help": "Name of the cluster. Used to set specific environs."},
    )
    # The default is computed once, when the class is defined.
    fileroot: str = field(
        default=get_user_tmp(),
        metadata={
            "help": "Root for logs and checkpoints. Should be available to all nodes."
        },
    )
    gpu_type: str = field(
        default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
    )
    mount: str = field(
        default="/storage:/storage", metadata={"help": "Mount path for slurm."}
    )
    gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
    cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
    gpu_infer_image: str = field(
        default="", metadata={"help": "slurm image for LLM inference."}
    )
    node_name_prefix: str = field(
        default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
    )
    n_nodes: int = field(
        default=32,
        metadata={
            "help": "The size of the cluster. Used to decide slurm hostname suffix."
        },
    )
    n_gpus_per_node: int = field(
        default=8,
        metadata={"help": "GPUs per node (physically)."},
    )
|
||||
|
||||
|
||||
@dataclass
class DatasetConfig:
    """Dataset type plus the dataloader options applied to it."""

    type: str = field(default="", metadata={"help": "Type of implemented dataset"})
    batch_size: int = field(
        default=1, metadata={"help": "Batch size of the dataloader"}
    )
    shuffle: bool = field(
        default=True, metadata={"help": "Whether to shuffle the dataset"}
    )
    pin_memory: bool = field(
        default=False,
        metadata={
            "help": "Pin memory for faster data loading (set True for GPU training)"
        },
    )
    num_workers: int = field(
        default=0, metadata={"help": "Number of worker processes for data loading"}
    )
    drop_last: bool = field(default=True)
|
||||
|
||||
|
||||
@dataclass
class TrainerConfig:
    """Trainer settings: experiment identity, logging backends, and run control."""

    experiment_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
    )
    trial_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the trial (no '-' or '/'). Required."},
    )
    # The default is computed once, when the class is defined.
    fileroot: str = field(
        default=get_user_tmp(),
        metadata={
            "help": "Root for logs and checkpoints. Should be available to all nodes."
        },
    )
    wandb: WandBConfig = field(
        default_factory=WandBConfig,
        metadata={"help": "Weights & Biases configuration."},
    )
    swanlab: SwanlabConfig = field(
        default_factory=SwanlabConfig,
        metadata={"help": "SwanLab configuration."},
    )
    tensorboard: TensorBoardConfig = field(
        default_factory=TensorBoardConfig,
        metadata={"help": "TensorBoard configuration. Only 'path' field required."},
    )
    allocation_mode: str = field(
        default="",
        metadata={
            "help": "GPU parallel strategy allocation mode. "
            "Options: manual/heuristic or pattern-based."
        },
    )
    seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
    exp_ctrl: ExperimentSaveEvalControl = field(
        default_factory=ExperimentSaveEvalControl,
        metadata={"help": "Experiment save/evaluation control configuration."},
    )

    tokenizer_path: str = field(default="")
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
|
||||
|
||||
|
||||
@dataclass
class BaseExperimentConfig:
    """Fields shared by every experiment config; subclasses add engine/trainer parts."""

    # NOTE: we need this unified config class because different experiments
    # have different config structures, e.g., GRPO has two engine configs,
    # but SFT only has a single one. We use subclasses to represent these structures.
    experiment_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
    )
    trial_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the trial (no '-' or '/'). Required."},
    )
    cluster: ClusterSpecConfig = field(
        default_factory=ClusterSpecConfig,
        metadata={"help": "Cluster specification. Mainly used by slurm."},
    )
    n_nodes: int = field(
        default=1, metadata={"help": "Number of nodes for experiment."}
    )
    n_gpus_per_node: int = field(
        default=8, metadata={"help": "Number of GPUs per node for this experiment."}
    )
    train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
    valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
|
||||
|
||||
|
||||
@dataclass
class SFTConfig(BaseExperimentConfig):
    """SFT experiment config: one train engine (``model``) plus a trainer."""

    model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
|
||||
|
||||
|
||||
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
    """Load an experiment config from a YAML file plus command-line overrides.

    Args:
        argv: Command-line arguments. ``--config <path>`` is required; every
            argument argparse does not recognise is passed to hydra as an
            override.
        config_cls: Dataclass type providing the default values; the composed
            config is merged on top of it.

    Returns:
        A tuple of the merged config object and the absolute path of the
        config file as a string.

    Side effects:
        Initialises hydra, sets the experiment/trial names in
        ``realhf.base.constants`` and reconfigures ``name_resolve``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", help="The path of the main configuration file", required=True
    )
    args, overrides = parser.parse_known_args(argv)

    # Initialize hydra config
    config_file = Path(args.config).absolute()
    assert config_file.exists(), f"Config file {config_file} does not exist."
    # hydra only recognize relative paths
    relpath = Path(
        os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
    )
    hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)

    # Strip the ".yaml" suffix only. str.rstrip(".yaml") would remove any
    # trailing run of the characters ".", "y", "a", "m", "l", which turns
    # e.g. "llama.yaml" into an empty name.
    config_name = relpath.name
    if config_name.endswith(".yaml"):
        config_name = config_name[: -len(".yaml")]
    cfg = hydra_compose(config_name=config_name, overrides=overrides)

    # Merge with the default configuration.
    # The yaml and commandline can omit some default values defined in python dataclasses.
    default_cfg = OmegaConf.structured(config_cls)
    cfg = OmegaConf.merge(default_cfg, cfg)
    cfg = OmegaConf.to_object(cfg)
    assert isinstance(cfg, BaseExperimentConfig)

    # Setup environment
    from realhf.base import constants, name_resolve

    constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
    name_resolve.reconfigure(cfg.cluster.name_resolve)
    return cfg, str(config_file)
|
||||
|
|
|
@ -47,7 +47,14 @@ class TrainEngine(abc.ABC):
|
|||
|
||||
def destroy(self):
|
||||
"""Destroy the engine and release GPU memory."""
|
||||
pass
|
||||
|
||||
def train(self, mode: bool = True):
    """Set the engine to the train mode.

    This base implementation always raises NotImplementedError.
    """
    raise NotImplementedError()
|
||||
|
||||
def eval(self):
    """Set the engine to the eval mode by calling ``self.train(False)``."""
    return self.train(False)
|
||||
|
||||
def upload_weights(self, meta: WeightUpdateMeta):
|
||||
"""Upload weights to the inference engine."""
|
||||
|
@ -111,7 +118,6 @@ class InferenceEngine(abc.ABC):
|
|||
|
||||
def destroy(self):
|
||||
"""Destroy the engine and release GPU memory."""
|
||||
pass
|
||||
|
||||
def update_weights(self, meta: WeightUpdateMeta) -> Future:
|
||||
"""Update weights in the inference engine."""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import abc
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
|
@ -11,7 +11,6 @@ class Environment(abc.ABC):
|
|||
For stateful environments, this is where resources are created and
|
||||
prepared (e.g., launching a browser).
|
||||
"""
|
||||
pass
|
||||
|
||||
def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Lists all available tools in the environment."""
|
||||
|
@ -27,4 +26,3 @@ class Environment(abc.ABC):
|
|||
|
||||
This method is critical for stateful environments (e.g., a browser session).
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -6,7 +6,7 @@ import itertools
|
|||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
|
@ -161,6 +161,7 @@ class WeightUpdateMeta:
|
|||
path: str | None
|
||||
alloc_mode: AllocationMode | None
|
||||
comm_backend: str | None
|
||||
model_version: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -21,4 +21,3 @@ def reward_fn(
|
|||
Any other attributes in the dataset will be passed as keyword arguments to this function.
|
||||
:rtype: float
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
import getpass
import os
from typing import Dict

import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import TrainerConfig
from arealite.api.engine_api import InferenceEngine, TrainEngine
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, timeutil
|
||||
|
||||
|
||||
class Trainer:
    """Base trainer: holds dataloaders and engines and sets up stats logging.

    Subclasses implement :meth:`train`.
    """

    def __init__(
        self,
        config: TrainerConfig,
        train_dataloader: StatefulDataLoader,
        valid_dataloader: StatefulDataLoader,
        engine: TrainEngine,
        inf_engine: InferenceEngine | None = None,
    ):
        """Store the components, load the tokenizer and build frequency controllers.

        Args:
            config: Trainer configuration.
            train_dataloader: Dataloader for training data.
            valid_dataloader: Dataloader for validation data.
            engine: The training engine.
            inf_engine: Optional inference engine.
        """
        self.config = config

        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader

        self.engine = engine
        self.inf_engine = inf_engine

        self.tokenizer = load_hf_tokenizer(config.tokenizer_path)

        self.save_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.exp_ctrl.save_freq_epochs,
            freq_step=config.exp_ctrl.save_freq_steps,
            freq_sec=config.exp_ctrl.save_freq_secs,
        )
        self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.exp_ctrl.eval_freq_epochs,
            freq_step=config.exp_ctrl.eval_freq_steps,
            # The time-based frequency must come from eval_freq_secs, not
            # eval_freq_steps.
            freq_sec=config.exp_ctrl.eval_freq_secs,
        )
        self.logger = logging.getLogger(self.__class__.__name__)
        self.init_stats_logging()

    def init_stats_logging(self):
        """Initialize wandb and/or tensorboard according to config.

        When torch.distributed is initialized, only rank 0 sets up logging.
        ``self.summary_writer`` is set to a tensorboard ``SummaryWriter`` if
        ``self.config.tensorboard.path`` is not None, otherwise to None.
        """
        # Defined on every rank so later attribute access never fails.
        self.summary_writer = None
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        # wandb init, connect to remote wandb host
        if self.config.wandb.mode != "disabled":
            wandb.login()
            wandb.init(
                mode=self.config.wandb.mode,
                entity=self.config.wandb.entity,
                project=self.config.wandb.project or self.config.experiment_name,
                name=self.config.wandb.name or self.config.trial_name,
                job_type=self.config.wandb.job_type,
                group=self.config.wandb.group
                or f"{self.config.experiment_name}_{self.config.trial_name}",
                notes=self.config.wandb.notes,
                tags=self.config.wandb.tags,
                config=self.config.wandb.config,
                dir=Trainer.get_log_path(self.config),
                force=True,
                id=f"{self.config.experiment_name}_{self.config.trial_name}_train",
                resume="allow",
                settings=wandb.Settings(start_method="fork"),
            )
        # tensorboard logging
        if self.config.tensorboard.path is not None:
            self.summary_writer = SummaryWriter(log_dir=self.config.tensorboard.path)

    def log_wandb_tensorboard(self, step: int, data: Dict):
        """Log ``data`` at ``step`` to wandb and tensorboard (rank 0 only)."""
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        # wandb.init() is skipped in "disabled" mode, and wandb.log() raises
        # if no run was initialised, so apply the same condition here.
        if self.config.wandb.mode != "disabled":
            wandb.log(data, step=step)
        if self.summary_writer is not None:
            for key, val in data.items():
                self.summary_writer.add_scalar(f"{key}", val, step)

    def close_wandb_tensorboard(self):
        """Finish the wandb run and close the tensorboard writer (rank 0 only)."""
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        wandb.finish()
        if self.summary_writer is not None:
            self.summary_writer.close()

    @staticmethod
    def get_save_checkpoint_path(
        config: TrainerConfig,
        epoch: int,
        step: int,
        globalstep: int,
        name: str = "default",
    ):
        """Return the checkpoint directory for this position, creating it if needed."""
        path = os.path.join(
            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
            name,
            f"epoch{epoch}epochstep{step}globalstep{globalstep}",
        )
        os.makedirs(path, exist_ok=True)
        return path

    @staticmethod
    def get_log_path(config: TrainerConfig):
        """Return the log directory for this experiment/trial, creating it if needed."""
        path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}"
        os.makedirs(path, exist_ok=True)
        return path

    def log(self, msg: str, level="info"):
        """Log ``msg`` through the logger method named ``level`` (rank 0 only).

        Falls back to ``self.logger.info`` when the logger has no attribute
        named ``level``.
        """
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        # The fallback must be a callable; a plain "info" string would be
        # called below and raise TypeError.
        log_fn = getattr(self.logger, level, self.logger.info)
        return log_fn(msg)

    def train(self):
        """Run training; must be implemented by subclasses."""
        raise NotImplementedError()
|
|
@ -0,0 +1,464 @@
|
|||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from tensordict import TensorDict
|
||||
from torch.distributed.checkpoint.state_dict import (
|
||||
StateDictOptions,
|
||||
get_model_state_dict,
|
||||
)
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
from arealite.api.cli_args import TrainEngineConfig
|
||||
from arealite.api.engine_api import (
|
||||
FinetuneSpec,
|
||||
MicroBatchSpec,
|
||||
SaveLoadMeta,
|
||||
TrainEngine,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from arealite.utils.data import (
|
||||
MicroBatchList,
|
||||
amend_position_ids,
|
||||
pack_tensor_dict,
|
||||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_mb_list,
|
||||
reorder_list,
|
||||
split_packed_tensor_dict_into_mb_list,
|
||||
unpack_sequence,
|
||||
unsqueeze_mb_list,
|
||||
)
|
||||
from arealite.utils.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
MixedPrecisionPolicy,
|
||||
apply_fsdp2,
|
||||
create_fsdp_device_mesh,
|
||||
fsdp2_clip_grad_norm_,
|
||||
fsdp2_load_full_state_dict,
|
||||
get_cosine_schedule_with_warmup,
|
||||
)
|
||||
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("FSDPEngine")
|
||||
|
||||
|
||||
class FSDPEngine(TrainEngine):
|
||||
def __init__(self, config: TrainEngineConfig):
    """Record the config and set every lazily-built component to None.

    The model, optimizer and FSDP objects are created in ``initialize``.
    Reads ``WORLD_SIZE`` from the environment and raises KeyError if it is
    not set.
    """
    self.config = config
    self.optimizer_config = config.optimizer

    self.model = None
    self.optimizer = None
    # Defined here so step_lr_scheduler can test it even when no optimizer
    # is configured.
    self.lr_scheduler = None
    self.tokenizer = None
    # huggingface model config
    self.model_config = None
    # FSDP options
    self.mixed_precision_policy = None
    self.device_mesh = None
    self.cpu_offload = None
    # initialization
    self.initialized = False
    self.weight_update_group_initialized = False

    # TODO: Handle the case when WORLD_SIZE is not set in launcher
    self.world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
def train(self, mode: bool = True):
    """Put the wrapped model in train (or eval, if ``mode`` is False) mode.

    Returns ``self``. The model must already exist.
    """
    assert self.model is not None
    self.model.train(mode=mode)
    return self
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
    """Initialize distributed communication, the model, and the optimizer.

    Args:
        addr: Must be None; remote initialization is not supported.
        ft_spec: Finetune spec. Its ``total_train_steps`` is used to build
            the LR schedule, so it is required when an optimizer is
            configured.

    Raises:
        AssertionError: If ``addr`` is not None, torch is older than 2.4.0,
            or the optimizer type is not "adam".
        ValueError: If an optimizer is configured but ``ft_spec`` is None,
            or the LR scheduler type is unknown.
    """
    # Initialize distributed environments and load model.
    assert addr is None, "FSDPEngine does not support remote initialization."

    assert pkg_version.is_version_greater_or_equal(
        "torch", "2.4.0"
    ), "arealite only supports FSDP2, which requires torch>=2.4.0"

    # Fail before the expensive model setup rather than with an
    # AttributeError on ft_spec afterwards.
    if self.optimizer_config is not None and ft_spec is None:
        raise ValueError(
            "ft_spec is required to build the LR schedule when an optimizer is configured."
        )

    if not dist.is_initialized():
        # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
        dist.init_process_group(backend="nccl")

    # TODO: Handle the condition when LOCAL_RANK is not set in launcher
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    self.device = torch.device(local_rank)

    dtype = torch.bfloat16 if self.config.bf16 else torch.float16
    self.model_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=self.config.path,
        trust_remote_code=True,
    )
    self.tokenizer = load_hf_tokenizer(self.config.path)
    with torch.device("cuda"):
        # initialize scratch model from config
        model = AutoModelForCausalLM.from_config(
            self.model_config,
            torch_dtype=dtype,
            attn_implementation=self.config.attn_impl,
        )

    # Simple auto wrap policy
    # NOTE(review): param_dtype is fixed to bfloat16 regardless of
    # config.bf16, while the model above is created in float16 when bf16
    # is False. Confirm whether this mismatch is intended.
    self.mixed_precision_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        cast_forward_inputs=True,
    )
    self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
    # sharding_strategy = ShardingStrategy.FULL_SHARD
    self.cpu_offload = (
        CPUOffloadPolicy() if self.config.fsdp.offload_params else None
    )

    fsdp_kwargs = {
        "mesh": self.device_mesh,
        "mp_policy": self.mixed_precision_policy,
        "offload_policy": self.cpu_offload,
        "reshard_after_forward": True,
    }

    # Wrap with FSDP2
    apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
    self.model = model

    if not self.config.init_from_scratch:
        # Load model from a initial checkpoint path,
        # which should only be a huggingface checkpoint.
        load_meta = SaveLoadMeta(
            path=self.config.path,
            weight_format="hf",
            with_optim=False,
            tokenizer=None,
            base_model_path=self.config.path,
        )
        self.load(load_meta)

    # Set up optimizer
    self.lr_scheduler = None
    if self.optimizer_config is not None:
        assert (
            self.optimizer_config.type == "adam"
        ), "Only AdamW optimizer is supported in this engine."
        lr = self.optimizer_config.lr
        weight_decay = self.optimizer_config.weight_decay
        beta1 = self.optimizer_config.beta1
        beta2 = self.optimizer_config.beta2
        eps = self.optimizer_config.eps

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            betas=(beta1, beta2),
            eps=eps,
        )
        total_train_steps = ft_spec.total_train_steps
        num_warmup_steps = int(
            self.optimizer_config.warmup_steps_proportion * total_train_steps
        )

        if self.optimizer_config.lr_scheduler_type == "cosine":
            self.lr_scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps,
                total_train_steps,
                min_lr_ratio=self.optimizer_config.min_lr_ratio,
            )
        elif self.optimizer_config.lr_scheduler_type == "linear":
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps,
                total_train_steps,
            )
        elif self.optimizer_config.lr_scheduler_type == "constant":
            self.lr_scheduler = get_constant_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps,
            )
        else:
            raise ValueError(
                f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
            )

    self.initialized = True
|
||||
|
||||
def destroy(self):
    """Destroy the engine and release GPU memory."""
    self.model = None
    self.optimizer = None
    # The scheduler holds a reference to the optimizer (and through it the
    # parameters), so it must be dropped too for memory to be released.
    self.lr_scheduler = None
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
    self.initialized = False
|
||||
|
||||
def save(self, meta: SaveLoadMeta):
    """Save model weights and, if ``meta.with_optim``, the optimizer state.

    Only the "hf" weight format is implemented.

    Raises:
        NotImplementedError: For the "dcp" weight format.
        ValueError: For any other unknown weight format.
    """
    if meta.weight_format == "hf":
        self._save_model_to_hf(meta.path, meta.tokenizer)
    elif meta.weight_format == "dcp":
        # TODO: implement DCP save/load for FSDP
        raise NotImplementedError("DCP format saving is not implemented yet. ")
    else:
        raise ValueError(f"Unknown weight format {meta.weight_format}. ")

    if meta.with_optim:
        self._save_optimizer_state(meta.path)
|
||||
|
||||
def load(self, meta: SaveLoadMeta):
    """Load model weights and, if ``meta.with_optim``, the optimizer state.

    Only the "hf" weight format is implemented.

    Raises:
        NotImplementedError: For the "dcp" weight format.
        ValueError: For any other unknown weight format.
    """
    if meta.weight_format == "hf":
        self._load_model_from_hf(meta.path)
    elif meta.weight_format == "dcp":
        # TODO: implement DCP save/load for FSDP
        raise NotImplementedError("DCP format loading is not implemented yet. ")
    else:
        raise ValueError(f"Unknown weight format {meta.weight_format}. ")

    if meta.with_optim:
        self._load_optimizer_state(meta.path)
|
||||
|
||||
def _save_optimizer_state(self, path: str):
    """Write this rank's optimizer state dict to a per-rank file under ``path``.

    The file name encodes the world size and rank. All ranks synchronise on
    a barrier after writing.
    """
    # Save FSDP sharded state dict on each rank
    assert self.optimizer is not None
    assert dist.is_initialized()
    rank = dist.get_rank()
    # Make sure the target directory exists before torch.save writes to it.
    os.makedirs(path, exist_ok=True)
    shard_path = os.path.join(
        path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
    )
    state_dict = self.optimizer.state_dict()
    torch.save(state_dict, shard_path)
    dist.barrier()
|
||||
|
||||
def _load_optimizer_state(self, path: str):
    """Load this rank's optimizer state dict from its per-rank file under ``path``.

    The file name must match the current world size and rank. All ranks
    synchronise on a barrier after loading.
    """
    # Load FSDP sharded state dict
    assert self.optimizer is not None
    assert dist.is_initialized()
    rank = dist.get_rank()
    shard_path = os.path.join(
        path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
    )
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted locations.
    optimizer_state_dict = torch.load(shard_path, weights_only=False)
    self.optimizer.load_state_dict(optimizer_state_dict)
    dist.barrier()
|
||||
|
||||
def _save_model_to_hf(
    self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast]
):
    """Save model in HuggingFace format.

    Every rank takes part in gathering the full state dict; only rank 0
    writes the model, the model config and (if given) the tokenizer. All
    ranks synchronise on a barrier before returning.

    Raises:
        RuntimeError: If the model has not been initialized.
    """
    if self.model is None:
        raise RuntimeError("Model not initialized")
    os.makedirs(path, exist_ok=True)

    # FSDP2 checkpoint saving
    # Get full state dict with FSDP2
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    state_dict = get_model_state_dict(self.model, options=options)

    # save huggingface model on rank 0
    if dist.get_rank() == 0:
        self.model.save_pretrained(path, state_dict=state_dict)
        self.model_config.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    dist.barrier()
|
||||
|
||||
def _load_model_from_hf(self, path: str):
    """Load model from HuggingFace format.

    Only rank 0 reads the full state dict from ``path``; other ranks pass an
    empty dict to ``fsdp2_load_full_state_dict``.
    """
    # Presumably the helper distributes rank 0's weights to the other
    # ranks; verify against arealite.utils.fsdp.
    if dist.get_rank() == 0:
        full_state = get_state_dict_from_repo_id_or_path(path)
    else:
        full_state = {}

    fsdp2_load_full_state_dict(
        self.model,
        full_state,
        self.cpu_offload,
        tie_word_embeddings=self.model_config.tie_word_embeddings,
    )
|
||||
|
||||
def upload_weights(self, meta: WeightUpdateMeta):
    """Publish the current weights according to ``meta.type``.

    ``"nccl"`` initialises the weight-update group on first use and then runs
    the distributed update. ``"disk"`` saves a HuggingFace checkpoint to
    ``meta.path`` and has rank 0 record a nanosecond timestamp under the
    name-resolve key for ``meta.model_version``.

    Raises:
        ValueError: for any other ``meta.type``.
    """
    update_kind = meta.type
    if update_kind not in ("nccl", "disk"):
        raise ValueError(f"Unknown weight update type {meta.type}")

    if update_kind == "nccl":
        if not self.weight_update_group_initialized:
            self._init_distributed_weight_update(meta)
        self._update_weights_from_distributed()
        return

    # _save_model_to_hf ends with a dist.barrier(), so the checkpoint is
    # complete before rank 0 announces it.
    self._save_model_to_hf(meta.path, self.tokenizer)
    if dist.get_rank() != 0:
        return
    update_name = names.update_weights_from_disk(
        self.config.experiment_name,
        self.config.trial_name,
        meta.model_version,
    )
    name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
|
||||
|
||||
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
    """Placeholder: always raises ``NotImplementedError``."""
    message = "Distributed weight update is not implemented for FSDPEngine yet. "
    raise NotImplementedError(message)
|
||||
|
||||
def _update_weights_from_distributed(self):
|
||||
raise NotImplementedError(
|
||||
"Distributed weight update is not implemented for FSDPEngine yet. "
|
||||
)
|
||||
|
||||
def step_lr_scheduler(self):
    """Advance the learning-rate scheduler by one step."""
    scheduler = self.lr_scheduler
    assert scheduler is not None
    scheduler.step()
|
||||
|
||||
def _prepare_mb_list(
    self, input_: TensorDict, mb_spec: MicroBatchSpec
) -> MicroBatchList:
    """Turn a padded batch into a list of model-ready micro-batches.

    Steps, in order: add position ids, pack, split according to ``mb_spec``,
    pad each micro-batch (pad value 0.0) and unsqueeze.
    """
    assert "attention_mask" in input_ and "input_ids" in input_
    packed = pack_tensor_dict(amend_position_ids(input_))
    micro_batches = split_packed_tensor_dict_into_mb_list(packed, mb_spec)
    micro_batches = pad_mb_list(micro_batches, pad_value=0.0)
    # NOTE: HuggingFace transformer models require packed input to have shape
    # [1, total_seqlen], hence the unsqueeze as the final step.
    return unsqueeze_mb_list(micro_batches)
|
||||
|
||||
def train_batch(
    self,
    input_: TensorDict,
    mb_spec: MicroBatchSpec,
    loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
    loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
    """Run one optimizer update over ``input_`` using gradient accumulation.

    Args:
        input_: padded batch; must contain ``input_ids`` and ``attention_mask``.
        mb_spec: how to split the batch into micro-batches.
        loss_fn: maps (logits with padding removed, unpadded micro-batch) to a
            scalar loss.
        loss_weight_fn: weight of a micro-batch; losses are combined in
            proportion to it, normalised by the all-reduced total.

    Returns:
        ``update_successful`` (1.0 or 0.0), ``grad_norm`` and ``lr``.
    """
    input_ = input_.to(self.device)
    assert self.optimizer is not None
    assert self.optimizer_config is not None
    assert self.lr_scheduler is not None

    self.optimizer.zero_grad()
    mb_list = self._prepare_mb_list(input_, mb_spec)

    total_loss_weight = torch.tensor(
        sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
    )
    assert total_loss_weight != 0
    dist.all_reduce(total_loss_weight)

    # Process microbatches with gradient accumulation
    for pad_length, padded_mb_input, mb_input in zip(
        mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
    ):
        outputs = self.model(**padded_mb_input)

        logits = outputs.logits.squeeze(0)
        logits = logits[:-pad_length] if pad_length > 0 else logits
        loss = loss_fn(logits, mb_input)
        loss_scale = loss_weight_fn(mb_input) / total_loss_weight

        # Scale loss for accumulation
        # Revert gradient averaging across dp ranks
        loss_scale *= self.world_size

        loss *= loss_scale
        loss.backward()

    grad_norm = fsdp2_clip_grad_norm_(
        self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
    )
    # Step exactly once, and only when the gradient norm is finite. The
    # previous code stepped again unconditionally after this branch, applying
    # every successful update twice.
    if not torch.isfinite(grad_norm):
        self.optimizer.zero_grad()
        update_successful = False
    else:
        self.optimizer.step()
        update_successful = True

    current_lr = self.lr_scheduler.get_last_lr()[0]
    return dict(
        update_successful=float(update_successful),
        grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
        lr=current_lr,
    )
|
||||
|
||||
@torch.no_grad()
def eval_batch(
    self,
    input_: TensorDict,
    mb_spec: MicroBatchSpec,
    loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
    loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
    """Compute the weighted mean loss of ``input_`` without gradients.

    Args:
        input_: padded batch; must contain ``input_ids`` and ``attention_mask``.
        mb_spec: how to split the batch into micro-batches.
        loss_fn: maps (logits with padding removed, unpadded micro-batch) to a
            scalar loss.
        loss_weight_fn: weight of a micro-batch in the average.

    Returns:
        The loss averaged over this rank's micro-batches, as a tensor.
    """
    # Move the batch next to the model, as train_batch does; batches coming
    # from a dataloader would otherwise reach the model on the wrong device.
    input_ = input_.to(self.device)
    mb_list = self._prepare_mb_list(input_, mb_spec)
    total_loss_weight = torch.tensor(
        sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
    )
    assert total_loss_weight != 0

    total_loss = 0.0
    total_weight = 0.0

    for pad_length, padded_mb_input, mb_input in zip(
        mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
    ):
        outputs = self.model(**padded_mb_input)
        logits = outputs.logits.squeeze(0)
        logits = logits[:-pad_length] if pad_length > 0 else logits
        loss = loss_fn(logits, mb_input)

        # Simple weight calculation (could be improved)
        loss_scale = loss_weight_fn(mb_input) / total_loss_weight
        total_loss += loss.item() * loss_scale
        total_weight += loss_scale

    return torch.tensor(total_loss / total_weight)
|
||||
|
||||
@torch.no_grad()
def forward(
    self,
    input_: TensorDict,
    mb_spec: MicroBatchSpec,
    output_seqlens: List[int] | None = None,
    post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
    aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
    """Run the model over ``input_`` without gradients and gather the outputs.

    Args:
        input_: padded batch; must contain ``input_ids`` and ``attention_mask``.
        mb_spec: how to split the batch into micro-batches.
        output_seqlens: per-sequence output lengths in input order; defaults to
            the input sequence lengths taken from ``cu_seqlens``.
        post_hook: optional transform applied to (logits, micro-batch); its
            result replaces the logits.
        aggregate_fn: joins the per-micro-batch results (default ``torch.cat``).

    Returns:
        Per-sequence outputs reordered with ``mb_list.backward_indices``, then
        padded and stacked along the first dimension.
    """
    # Move the batch next to the model, as train_batch does.
    input_ = input_.to(self.device)
    cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
    mb_list = self._prepare_mb_list(input_, mb_spec)

    if output_seqlens is None:
        output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

    results = []
    for pad_length, padded_mb_input, mb_input in zip(
        mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
    ):
        outputs = self.model(**padded_mb_input)
        logits = outputs.logits.squeeze(0)
        logits = logits[:-pad_length] if pad_length > 0 else logits

        if post_hook:
            result = post_hook(logits, mb_input)
            results.append(result)
        else:
            results.append(logits)

    res = aggregate_fn(results)
    # Results are in micro-batch order, so reorder the lengths the same way
    # before splitting, then map the pieces back with backward_indices.
    output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
    unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
    reordered = reorder_list(unpacked, mb_list.backward_indices)
    return pad_and_stack_tensors_along_first_dim(reordered)
|
|
@ -0,0 +1,6 @@
|
|||
from arealite.api.cli_args import SGLangEngineConfig
|
||||
|
||||
|
||||
class SGLangEngine:
    """Stub for an SGLang engine; construction currently does nothing."""

    def __init__(self, config: SGLangEngineConfig):
        """Accept ``config`` without storing or using it."""
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Empty, Full, Queue
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
@ -18,7 +17,7 @@ from arealite.api.io_struct import (
|
|||
RolloutStat,
|
||||
WeightUpdateMeta,
|
||||
)
|
||||
from realhf.base import logging, pkg_version
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arealite.api.workflow_api import RolloutWorkflow
|
||||
|
@ -372,26 +371,41 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
def _update_weights(self, meta: WeightUpdateMeta):
    """Ask every server in ``self.addresses`` to reload weights from disk.

    Waits up to 120 seconds for the name-resolve entry that announces the
    checkpoint for ``meta.model_version``, sends the reload request to all
    servers concurrently on a private event loop, then records the new
    version.

    Raises:
        NotImplementedError: if ``meta.type`` is not ``"disk"``.
    """
    if meta.type != "disk":
        raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

    # Update weights from disk
    # Wait for model checkpoints of meta.version
    update_name = names.update_weights_from_disk(
        self.config.experiment_name, self.config.trial_name, meta.model_version
    )
    save_timestamp = int(name_resolve.wait(update_name, timeout=120))
    load_timestamp = time.time_ns()
    logger.info(
        f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
    )
    # Create the loop before entering the try block so that `finally` never
    # touches an unbound name and the original exception is not masked.
    loop = asyncio.new_event_loop()
    try:
        # asyncio event loop should be manually set when running asyncio stuff in another thread
        asyncio.set_event_loop(loop)
        jobs = [
            self.aupdate_weights_from_disk(addr, meta.path)
            for addr in self.addresses
        ]
        loop.run_until_complete(asyncio.gather(*jobs))
    finally:
        loop.close()
    logger.info(
        f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
    )
    self.set_version(meta.model_version)
|
||||
|
||||
async def aupdate_weights_from_disk(self, addr, path: str):
    """POST an ``/update_weights_from_disk`` request to the server at ``addr``.

    ``path`` is converted with ``str()`` before it is put into the payload.
    The call is retried up to three times and the response must report
    success.
    """
    response = await self.arequest_with_retry(
        endpoint="/update_weights_from_disk",
        payload=dict(model_path=str(path), allow_interrupt=True),
        method="POST",
        max_retries=3,
        timeout=self.config.request_timeout,
        target_addr=addr,
    )
    res = await response.json()
    assert res["success"]
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Test script for HF Engine implementation."""
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
|
||||
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
|
||||
from arealite.engine.fsdp_engine import FSDPEngine
|
||||
|
||||
VOCAB_SIZE = 100
# Prefer the checkpoint on local test storage; otherwise fall back to a hub id.
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
MODEL_PATH = MODEL_PATH if os.path.exists(MODEL_PATH) else "Qwen/Qwen2-0.5B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def mock_input(
    batch_size=5,
    min_seqlen=10,
    max_seqlen=20,
    device="cuda:0",
) -> Dict:
    """Build a random right-padded batch in the HuggingFace input layout.

    Sequence lengths are drawn from ``[min_seqlen, max_seqlen)``. The result
    is a TensorDict holding ``input_ids`` (padding positions set to 0) and a
    boolean ``attention_mask``.
    """
    pad_token_id = 0
    seqlens = torch.randint(
        min_seqlen, max_seqlen, (batch_size,), dtype=torch.int, device=device
    )
    max_seqlen = int(max(seqlens))
    input_ids = torch.randint(
        0, VOCAB_SIZE, (batch_size, max_seqlen), dtype=torch.long, device=device
    )

    # A position is valid when its index is below the sequence's own length.
    positions = torch.arange(0, max_seqlen, device=device).unsqueeze(0)
    attn_mask = positions < seqlens.unsqueeze(1)
    input_ids.masked_fill_(~attn_mask, pad_token_id)

    return TensorDict(
        input_ids=input_ids,
        attention_mask=attn_mask,
    )
|
||||
|
||||
|
||||
def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
    """Stand-in loss for tests: the mean of all logits (``input_data`` is ignored)."""
    return logits.mean()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def engine():
    """Module-scoped, initialised single-process ``FSDPEngine`` for the tests."""
    # Single-rank rendezvous settings read from the environment.
    os.environ.update(
        {
            "WORLD_SIZE": "1",
            "RANK": "0",
            "LOCAL_RANK": "0",
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "7777",
        }
    )

    fsdp_engine = FSDPEngine(
        TrainEngineConfig(
            experiment_name="test-fsdp-engine",
            trial_name="test0",
            path=MODEL_PATH,
            optimizer=OptimizerConfig(),
        )
    )
    fsdp_engine.initialize(
        None,
        FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2),
    )
    print("✓ Engine created successfully")
    yield fsdp_engine
|
||||
|
||||
|
||||
@torch.no_grad()
def test_forward_microbatch(engine, mock_input):
    """Forward outputs must agree whether the batch is run as 1 or 2 micro-batches."""
    engine.eval()

    def _mean_logits(n_mbs):
        out = engine.forward(
            input_=mock_input,
            mb_spec=MicroBatchSpec(n_mbs=n_mbs, max_tokens_per_mb=100),
        )
        return out.squeeze(0).mean(-1)

    x2 = _mean_logits(2)
    x1 = _mean_logits(1)
    batch_dim = mock_input["input_ids"].shape[:1]
    assert x1.shape[:1] == batch_dim
    assert x2.shape[:1] == batch_dim
    assert torch.allclose(x1, x2, atol=1e-1, rtol=1e-2), (x1 - x2).abs().max().item()
|
||||
|
||||
|
||||
@torch.no_grad()
def test_eval_batch(engine, mock_input):
    """``eval_batch`` must return a CUDA tensor holding the loss."""
    engine.eval()

    def _total_tokens(mb):
        return mb["cu_seqlens"][-1]

    eval_result = engine.eval_batch(
        input_=mock_input,
        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
        loss_fn=mock_loss_fn,
        loss_weight_fn=_total_tokens,
    )
    assert isinstance(eval_result, torch.Tensor), "Evaluation should return a tensor"
    assert eval_result.is_cuda, "Evaluation tensor should be on CUDA device"
    assert eval_result is not None, "Evaluation should return a loss value"
    print(f"✓ Evaluation successful, loss: {eval_result.item()}")
|
||||
|
||||
|
||||
def test_train_batch(engine, mock_input):
    """``train_batch`` must return a stats dict with ``grad_norm`` and ``lr``."""
    engine.train()

    def _total_tokens(mb):
        return mb["cu_seqlens"][-1]

    train_result = engine.train_batch(
        input_=mock_input,
        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
        loss_fn=mock_loss_fn,
        loss_weight_fn=_total_tokens,
    )
    assert isinstance(train_result, dict), "Training should return a dictionary"
    assert train_result["grad_norm"] is not None
    assert train_result["lr"] is not None
    print("✓ Training successful")
|
||||
|
||||
|
||||
@torch.no_grad()
def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
    """Saving then loading an HF checkpoint must reproduce the forward outputs."""
    checkpoint_dir = tmp_path_factory.mktemp("hf_engine_test")
    save_load_meta = SaveLoadMeta(
        path=checkpoint_dir,
        weight_format="hf",
        tokenizer=AutoTokenizer.from_pretrained(MODEL_PATH),
        with_optim=True,
        base_model_path=None,
    )

    def _run_forward():
        return engine.forward(
            input_=mock_input,
            mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
        )

    old = _run_forward()
    engine.save(save_load_meta)

    # Wipe the weights so that a passing comparison proves the load worked.
    for param in engine.model.parameters():
        param.zero_()

    engine.load(save_load_meta)
    new = _run_forward()
    assert torch.allclose(old, new)
|
|
@ -14,9 +14,9 @@ from arealite.api.cli_args import (
|
|||
InferenceEngineConfig,
|
||||
SGLangConfig,
|
||||
)
|
||||
from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse
|
||||
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import name_resolve, network, seeding
|
||||
from realhf.base import network
|
||||
|
||||
EXPR_NAME = "test_sglang_engine"
|
||||
TRIAL_NAME = "trial_0"
|
||||
|
@ -186,3 +186,50 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
|
|||
|
||||
# exit
|
||||
engine.destroy()
|
||||
|
||||
|
||||
def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
    """An FSDP engine saves weights to disk and the remote SGLang engine picks them up."""
    # setup FSDP engine
    from arealite.api.cli_args import OptimizerConfig, TrainEngineConfig
    from arealite.api.io_struct import FinetuneSpec
    from arealite.engine.fsdp_engine import FSDPEngine

    os.environ.update(
        {
            "WORLD_SIZE": "1",
            "RANK": "0",
            "LOCAL_RANK": "0",
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "7777",
        }
    )

    engine = FSDPEngine(
        TrainEngineConfig(
            experiment_name=EXPR_NAME,
            trial_name=TRIAL_NAME,
            path=MODEL_PATH,
            optimizer=OptimizerConfig(),
        )
    )
    engine.initialize(
        None,
        FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2),
    )

    # setup name resolve
    import realhf.base.name_resolve as name_resolve
    from realhf.api.cli_args import NameResolveConfig

    record_root = tmp_path_factory.mktemp("nfs_record_path")
    name_resolve.reconfigure(
        NameResolveConfig(type="nfs", nfs_record_root=record_root)
    )

    # initialize SGLang remote engine
    from arealite.api.cli_args import InferenceEngineConfig
    from arealite.engine.sglang_remote import RemoteSGLangEngine

    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    inf_engine = RemoteSGLangEngine(config)

    # test update weights: the inference side starts waiting first, then the
    # training side writes the checkpoint and announces it.
    weights_dir = tmp_path_factory.mktemp("upload_weights_from_disk")
    update_weight_meta = WeightUpdateMeta(
        type="disk",
        path=weights_dir,
        alloc_mode=None,
        comm_backend=None,
        model_version=100,
    )
    future = inf_engine.update_weights(update_weight_meta)
    engine.upload_weights(update_weight_meta)
    future.result()
    assert inf_engine.get_version() == 100
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import pytest
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
||||
from arealite.api.cli_args import MicroBatchSpec
|
||||
from arealite.utils.data import (
|
||||
pack_tensor_dict,
|
||||
pad_and_stack_tensors_along_first_dim,
|
||||
pad_sequences_to_tensors,
|
||||
reorder_list,
|
||||
split_packed_tensor_dict_into_mbs,
|
||||
unpack_sequence,
|
||||
)
|
||||
|
||||
BS = 16
|
||||
MAX_ANSWER_LEN = 16
|
||||
MAX_PROMPT_LEN = 8
|
||||
VOCAB_SIZE = 100
|
||||
|
||||
|
||||
@pytest.fixture
def mock_padded_data():
    """Random prompt+answer sequences padded into one batch.

    Builds ``BS`` sequences with random prompt and answer lengths and pads
    them together with ``pad_sequences_to_tensors``.
    """
    prompt_lens = torch.randint(1, MAX_PROMPT_LEN, size=(BS,))
    answer_lens = torch.randint(1, MAX_ANSWER_LEN, size=(BS,))
    all_data = []
    for prompt_len, ans_len in zip(prompt_lens.tolist(), answer_lens.tolist()):
        total_len = prompt_len + ans_len
        all_data.append(
            TensorDict(
                dict(
                    input_ids=torch.randint(0, VOCAB_SIZE, size=(total_len,)),
                    prompt_mask=torch.tensor([1] * prompt_len + [0] * ans_len),
                    logprobs=torch.randn(total_len),
                    position_ids=torch.arange(total_len),
                )
            )
        )
    return pad_sequences_to_tensors(all_data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_tokens_per_mb", [24, 36, 48, 100])
@pytest.mark.parametrize("n_mbs", [1, 2, 4, 8])
def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
    """Splitting into micro-batches must respect the spec and be reversible."""
    mb_spec = MicroBatchSpec(n_mbs, max_tokens_per_mb)

    # Unpad and split to microbatches
    packed_data = pack_tensor_dict(mock_padded_data)
    cu_seqlens = packed_data["cu_seqlens"]
    original_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
    split_result = split_packed_tensor_dict_into_mbs(packed_data, mb_spec)
    reordered_lens = [original_lens[i] for i in split_result.forward_indices]

    # At least the requested number of micro-batches must be produced.
    assert len(split_result.mbs) >= n_mbs

    metadata_keys = ["cu_seqlens", "max_seqlen"]
    data_keys = [k for k in split_result.mbs[0].keys() if k not in metadata_keys]
    for key in data_keys:
        # No micro-batch may exceed the token budget.
        assert all(mb[key].shape[0] <= max_tokens_per_mb for mb in split_result.mbs)

        # Undo the split: concatenate, cut per sequence, restore input order.
        joined = torch.cat([mb[key] for mb in split_result.mbs])
        pieces = unpack_sequence(joined, lens=reordered_lens)
        pieces = reorder_list(pieces, split_result.backward_indices)
        assert torch.allclose(torch.cat(pieces), packed_data[key])
        repadded = pad_and_stack_tensors_along_first_dim(pieces)
        assert torch.allclose(mock_padded_data[key], repadded)
|
|
@ -0,0 +1,156 @@
|
|||
import time
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
|
||||
from arealite.api.io_struct import SaveLoadMeta
|
||||
from arealite.api.trainer_api import Trainer
|
||||
from arealite.utils.functional import gather_logprobs
|
||||
from arealite.utils.logging import record_timing
|
||||
from realhf.api.core.data_api import tabulate_stats
|
||||
from realhf.base import logging, stats_tracker
|
||||
|
||||
|
||||
def compute_packed_sft_loss(
    logits: torch.Tensor,
    input_: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Next-token SFT loss over a packed batch, with stats logging.

    Args:
        logits: one row of vocabulary logits per packed token.
        input_: packed micro-batch with ``input_ids``, ``cu_seqlens`` and
            ``prompt_mask`` (truthy on prompt tokens).

    Returns:
        Negative log-likelihood summed over non-prompt targets and divided by
        their count.

    Labels and the prompt mask are both rolled left by one so that position
    ``t`` is scored against token ``t + 1``. Per-sequence perplexity and
    logit-range statistics are reported through ``stats_tracker``.
    """
    packed_input_ids: torch.Tensor = input_["input_ids"]
    cu_seqlens: torch.Tensor = input_["cu_seqlens"]
    prompt_mask = input_["prompt_mask"].bool()
    logits = logits.float()

    logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
    prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
    logprobs = torch.where(prompt_mask, 0, logprobs)

    loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()

    with torch.no_grad():
        n_seqs = cu_seqlens.shape[0] - 1
        seqlogp = torch.zeros(n_seqs, device=logits.device, dtype=torch.float64)
        for i in range(n_seqs):
            # `logprobs` keeps every packed token, so sequence i occupies
            # exactly [cu_seqlens[i], cu_seqlens[i + 1]). The former `- i`
            # offsets assumed one token had been dropped per sequence and
            # shifted every slice after the first into its neighbour.
            start, end = cu_seqlens[i], cu_seqlens[i + 1]
            assert end <= logprobs.shape[0], (
                cu_seqlens,
                logprobs.shape,
            )
            m = prompt_mask[start:end]
            logp = logprobs[start:end]
            seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
                m.numel() - m.count_nonzero()
            )

    ## Logging stats
    stats_tracker.denominator(
        n_seqs=torch.ones(
            cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
        ),
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=prompt_mask.logical_not(),
        prompt_tokens=prompt_mask,
    )
    stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
    stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    return loss
|
||||
|
||||
|
||||
class SFTTrainer(Trainer):
    """Supervised fine-tuning loop built on the ``Trainer`` base class."""

    def train(self):
        """Train for the configured number of epochs.

        Each step runs one ``train_batch`` and one scheduler step, then saves
        and evaluates when the respective controllers ask for it, and finally
        exports and logs the collected statistics.
        """
        total_epochs = self.config.exp_ctrl.total_train_epochs
        steps_per_epoch = len(self.train_dataloader)
        total_steps = total_epochs * steps_per_epoch

        def count_answer_tokens(mb):
            # Loss weight of a micro-batch: its number of non-prompt tokens.
            return mb["prompt_mask"].logical_not().count_nonzero()

        self.log(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
        global_step = 0
        start_time = time.monotonic()
        for epoch in range(total_epochs):
            for step, data in enumerate(self.train_dataloader):
                self.engine.train()
                timing_stats = {}
                # 1 on the last step of an epoch, else 0; fed to both controllers.
                epoch_finished = int(step == steps_per_epoch - 1)

                with record_timing(
                    "timeperf/train_step", timing_stats
                ), stats_tracker.scope("sft"):
                    stats = self.engine.train_batch(
                        input_=data,
                        loss_fn=compute_packed_sft_loss,
                        loss_weight_fn=count_answer_tokens,
                        mb_spec=self.config.mb_spec,
                    )
                    self.engine.step_lr_scheduler()
                    stats_tracker.scalar(**stats)

                if self.save_ctl.check(epochs=epoch_finished, steps=1):
                    self.log("Saving model ...")

                    with record_timing("timeperf/save", timing_stats):
                        checkpoint_path = self.get_save_checkpoint_path(
                            self.config, epoch, step, global_step
                        )
                        self.engine.save(
                            SaveLoadMeta(
                                path=checkpoint_path,
                                weight_format="hf",
                                with_optim=False,
                                tokenizer=self.tokenizer,
                                base_model_path=self.config.tokenizer_path,
                            )
                        )

                if self.eval_ctl.check(epochs=epoch_finished, steps=1):
                    if dist.get_rank() == 0:
                        self.log("Running evaluation ...")
                    with record_timing("timeperf/eval", timing_stats):
                        self.evaluate()

                stats = stats_tracker.export()
                stats.update(timing_stats)
                self.log_wandb_tensorboard(global_step, stats)

                self.log(
                    f"Epoch {epoch+1}/{total_epochs} "
                    f"Step {step+1}/{steps_per_epoch} "
                    f"Train step {global_step + 1}/{total_steps} done."
                )
                self.log(
                    f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}"
                )
                self.log(f"SFT training stats:\n{tabulate_stats(stats)}")
                global_step += 1

        self.log(
            f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
        )

        self.close_wandb_tensorboard()

    def evaluate(self):
        """Run ``eval_batch`` over the validation set, if there is one.

        Nothing is returned or logged here; the statistics recorded under the
        ``"sft-eval"`` scope are picked up by the caller through
        ``stats_tracker.export()``.
        """
        if self.valid_dataloader is None:
            return

        def count_answer_tokens(mb):
            return mb["prompt_mask"].logical_not().count_nonzero()

        self.engine.eval()
        for data in self.valid_dataloader:
            with stats_tracker.scope("sft-eval"):
                self.engine.eval_batch(
                    input_=data,
                    loss_fn=compute_packed_sft_loss,
                    loss_weight_fn=count_answer_tokens,
                    mb_spec=self.config.mb_spec,
                )
|
|
@ -0,0 +1,491 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
# Pad/unpad operations are modified from flash-attention under BSD-3 license.
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from tensordict import TensorDict
|
||||
|
||||
from arealite.api.cli_args import MicroBatchSpec
|
||||
from realhf.base import datapack, logging
|
||||
|
||||
logger = logging.getLogger("data utils")
|
||||
|
||||
|
||||
def reorder_list(xs: List, indices: List[int]) -> List:
    """Return the items of ``xs`` picked in the order given by ``indices``.

    ``indices`` must contain exactly ``len(xs)`` distinct values.
    """
    assert len(set(indices)) == len(xs)
    reordered = []
    for position in indices:
        reordered.append(xs[position])
    return reordered
|
||||
|
||||
|
||||
def dict_map(x: Dict, fn: Callable) -> Dict:
    """Return a new dict with the same keys and ``fn`` applied to every value."""
    mapped = {}
    for key, value in x.items():
        mapped[key] = fn(value)
    return mapped
|
||||
|
||||
|
||||
def dict_of_list2list_of_dict(
    dict_of_lists: Dict[str, List[Any]],
) -> List[Dict[str, Any]]:
    """Convert ``{key: [v0, v1, ...]}`` into ``[{key: v0}, {key: v1}, ...]``.

    An empty mapping gives an empty list.

    Raises:
        ValueError: if the lists do not all have the same length.
    """
    if not dict_of_lists:
        return []
    keys = list(dict_of_lists.keys())
    length = len(dict_of_lists[keys[0]])
    for key, value_list in dict_of_lists.items():
        if len(value_list) != length:
            raise ValueError(
                f"All lists must have the same length. Key '{key}' has length {len(value_list)}, expected {length}"
            )
    rows = []
    for i in range(length):
        row = {}
        for key in keys:
            row[key] = dict_of_lists[key][i]
        rows.append(row)
    return rows
|
||||
|
||||
|
||||
def list_of_dict2dict_of_list(
    list_of_dicts: List[Dict[str, Any]],
) -> Dict[str, List[Any]]:
    """Convert ``[{key: v0}, {key: v1}, ...]`` into ``{key: [v0, v1, ...]}``.

    An empty list gives an empty dict. Key order follows the first dict.

    Raises:
        ValueError: if any dict has a different key set from the first one.
    """
    if not list_of_dicts:
        return {}
    keys = list(list_of_dicts[0].keys())
    for i, dict_item in enumerate(list_of_dicts):
        if set(dict_item.keys()) != set(keys):
            raise ValueError(
                f"All dictionaries must have the same keys. Dictionary at index {i} has keys {set(dict_item.keys())}, expected {set(keys)}"
            )
    columns = {}
    for key in keys:
        column = []
        for dict_item in list_of_dicts:
            column.append(dict_item[key])
        columns[key] = column
    return columns
|
||||
|
||||
|
||||
def pad_sequences_to_tensors(
    sequence_list: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
    """Right-pad per-sequence entries to a common length and stack them.

    Every key of the first item is padded with ``pad_value`` up to the longest
    value found in any item, then stacked into a batch. A boolean
    ``attention_mask`` is added, derived from the length of each item's first
    value. An empty input gives an empty ``TensorDict``.
    """
    if not sequence_list:
        return TensorDict()
    max_length = max(len(seq) for item in sequence_list for seq in item.values())

    result = {}
    for key in sequence_list[0].keys():
        rows = []
        for item in sequence_list:
            raw = item[key]
            as_tensor = raw if torch.is_tensor(raw) else torch.tensor(raw)
            missing = max_length - len(raw)
            rows.append(
                torch.nn.functional.pad(as_tensor, (0, missing), value=pad_value)
            )
        result[key] = torch.stack(rows)

    attention_mask = []
    for item in sequence_list:
        n_valid = len(next(iter(item.values())))
        attention_mask.append([1] * n_valid + [0] * (max_length - n_valid))
    result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
    return TensorDict(result, batch_size=[result["attention_mask"].shape[0]])
|
||||
|
||||
|
||||
def unpad_input(
    hidden_states, attention_mask
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Drop padded positions from a ``[batch, seq, ...]`` tensor.

    Returns:
        A 4-tuple: the kept rows of the flattened input, the flat indices of
        those rows, the int32 cumulative sequence lengths with a leading 0,
        and the longest sequence length in the batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    flat_mask = attention_mask.flatten()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    running_total = torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(running_total, (1, 0))
    flat_states = rearrange(hidden_states, "b s ... -> (b s) ...")
    return flat_states[indices], indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
|
||||
def pad_input(hidden_states, indices, batch, seqlen):
    """Scatter packed rows back into a zero-filled ``[batch, seqlen, ...]`` tensor.

    Args:
        hidden_states: packed values, one row per kept position; any trailing
            feature dimensions are preserved.
        indices: flat positions (into ``batch * seqlen``) of those rows.
        batch: number of sequences in the padded layout.
        seqlen: padded sequence length.

    Returns:
        A tensor of shape ``[batch, seqlen, *hidden_states.shape[1:]]`` with
        the same dtype and device as ``hidden_states``; positions not listed
        in ``indices`` are zero.
    """
    # Allocate the trailing dimensions too: a flat 1-D buffer only worked for
    # 1-D input and broke on anything with feature dimensions.
    trailing = tuple(hidden_states.shape[1:])
    output = hidden_states.new_zeros((batch * seqlen, *trailing))
    output[indices] = hidden_states
    return output.reshape(batch, seqlen, *trailing)
|
||||
|
||||
|
||||
def concat_padded_tensors(
    tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
) -> Dict[str, torch.Tensor]:
    """Concatenate several padded batches along the batch dimension.

    Tensors with two or more dimensions are right-padded along dim 1 to the
    widest batch (``attention_mask`` with zeros, everything else with
    ``pad_value``); 1-D tensors are concatenated as they are. When the inputs
    carry no ``attention_mask``, one is synthesised with one row per sequence,
    marking each batch's own width as valid.

    Args:
        tensor_dicts: batches sharing the keys of the first one.
        pad_value: fill value for the padded feature positions.

    Returns:
        The merged dict; empty when ``tensor_dicts`` is empty.
    """
    if not tensor_dicts:
        return {}

    # Width and row count of each batch, read from its first 2-D feature tensor.
    lens = []
    batch_sizes = []
    for tensor_dict in tensor_dicts:
        for key, tensor in tensor_dict.items():
            if key != "attention_mask" and len(tensor.shape) == 2:
                batch_sizes.append(tensor.shape[0])
                lens.append(tensor.shape[1])
                break
    # default=0 keeps inputs that hold only 1-D tensors from crashing here.
    max_length = max(lens, default=0)

    result = {}
    for key in tensor_dicts[0].keys():
        tensors_to_concat = []
        for tensor_dict in tensor_dicts:
            tensor = tensor_dict[key]
            # Skip 1D tensors like rewards
            if len(tensor.shape) == 1:
                tensors_to_concat.append(tensor)
                continue
            pad_width = max_length - tensor.shape[1]
            if pad_width > 0:
                pad_shape = (tensor.shape[0], pad_width)
                # Allocate the padding on the tensor's own device so that
                # torch.cat also works for tensors that are not on the CPU.
                if key == "attention_mask":
                    padding = torch.zeros(
                        pad_shape, dtype=tensor.dtype, device=tensor.device
                    )
                else:
                    padding = torch.full(
                        pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device
                    )
                tensor = torch.cat([tensor, padding], dim=1)
            tensors_to_concat.append(tensor)

        result[key] = torch.cat(tensors_to_concat, dim=0)

    if "attention_mask" not in result and lens:
        # One mask row per sequence: repeat each batch's width once per row.
        row_lens = torch.repeat_interleave(
            torch.tensor(lens), torch.tensor(batch_sizes)
        )
        result["attention_mask"] = torch.arange(max_length).unsqueeze(
            0
        ) < row_lens.unsqueeze(1)
    return result
|
||||
|
||||
|
||||
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
|
||||
"""Move tensors in a dictionary to the specified device."""
|
||||
return {
|
||||
key: value.to(device) if torch.is_tensor(value) else value
|
||||
for key, value in data.items()
|
||||
}
|
||||
|
||||
|
||||
def unpack_sequence(
    x: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    lens: Optional[List[int]] = None,
    dim: int = 0,
):
    """Split a packed tensor into per-sequence pieces along ``dim``.

    ``lens`` takes precedence when both arguments are given; otherwise the
    lengths are the differences of consecutive ``cu_seqlens`` entries.

    Raises:
        ValueError: if neither ``lens`` nor ``cu_seqlens`` is provided.
    """
    if lens is None and cu_seqlens is None:
        raise ValueError("Either cu_seqlens or input_lens must be provided.")
    if lens is not None:
        split_sizes = lens
    else:
        split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return torch.split(x, split_sizes, dim=dim)
|
||||
|
||||
|
||||
def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: List[int]) -> List[List[int]]:
    """Assign sequence indices to micro-batches via ``datapack.ffd_allocate``.

    ``mb_spec.max_tokens_per_mb`` is passed as the capacity and
    ``mb_spec.n_mbs`` as the minimum number of groups. The result is put in a
    canonical order: indices ascending inside each group, groups sorted.
    """
    allocation = datapack.ffd_allocate(
        lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
    )
    canonical = [sorted(members) for members in allocation]
    canonical.sort()
    return canonical
|
||||
|
||||
|
||||
def allocate_balanced_mbs_synced(
    mb_spec: MicroBatchSpec,
    lens: List[int],
    group: Optional[dist.ProcessGroup] = None,
) -> List[List[int]]:
    """Allocate micro-batches so that every rank in ``group`` uses the same count.

    Each rank first allocates locally. If ``torch.distributed`` is initialized,
    the local micro-batch counts are gathered across ``group``; when they
    differ, the allocation is retried with ``n_mbs`` raised to the largest
    count observed.

    Args:
        mb_spec: Micro-batch specification.
        lens: Sequence lengths of the local batch.
        group: Process group to synchronize over (``None`` = default group).

    Returns:
        Groups of sequence indices, one list per micro-batch.
    """
    group_indices = allocate_balanced_mbs(mb_spec, lens)
    if not dist.is_initialized():
        return group_indices

    all_n_mbs = [None for _ in range(dist.get_world_size(group))]
    dist.all_gather_object(all_n_mbs, len(group_indices), group=group)
    if all(mbs == len(group_indices) for mbs in all_n_mbs):
        return group_indices
    # Forward `group` on retry: otherwise the next round would gather over the
    # default process group instead of the one the caller asked for.
    return allocate_balanced_mbs_synced(
        MicroBatchSpec.new(mb_spec, n_mbs=max(all_n_mbs)), lens, group=group
    )
|
||||
|
||||
|
||||
def pack_tensor_dict(data: TensorDict):
    """Flatten padded ``[B, S, ...]`` tensors into packed ``[total_length, ...]`` ones.

    Args:
        data: Mapping that must contain a 2-D ``"attention_mask"`` of shape
            ``[B, S]``.

    Returns:
        A ``TensorDict`` in which ``"attention_mask"`` is replaced by
        ``"cu_seqlens"`` (length ``B + 1``) and ``"max_seqlen"``. Every tensor
        whose first two dims are ``(B, S)`` is packed; all other entries are
        passed through unchanged.
    """
    assert "attention_mask" in data, "Input data must contain 'attention_mask' key."
    mask = data["attention_mask"]
    assert mask.ndim == 2, "Attention mask must be a 2D tensor."
    bs, seq_len = mask.shape[0], mask.shape[1]

    # Per-row valid-token counts and their cumulative offsets (leading 0).
    lens = mask.sum(dim=1, dtype=torch.int32)
    max_seqlen = lens.max().item()
    cu_seqlens = F.pad(torch.cumsum(lens, dim=0), (1, 0), value=0)
    offsets = cu_seqlens.tolist()
    total_length = offsets[-1]

    packed = {}
    for key, value in data.items():
        if key == "attention_mask":
            packed["cu_seqlens"] = cu_seqlens
            packed["max_seqlen"] = max_seqlen
            continue

        is_batched = (
            torch.is_tensor(value)
            and value.ndim >= 2
            and value.shape[0] == bs
            and value.shape[1] == seq_len
        )
        if not is_batched:
            packed[key] = value
            continue

        flat = torch.empty(
            total_length, *value.shape[2:], dtype=value.dtype, device=value.device
        )
        # Copy the leading `end - start` entries of every row into its slot.
        for row, (start, end) in enumerate(zip(offsets[:-1], offsets[1:])):
            flat[start:end] = value[row][: end - start]
        packed[key] = flat

    return TensorDict(**packed)
|
||||
|
||||
|
||||
def pad_and_stack_tensors_along_first_dim(tensor_list: List[torch.Tensor]):
    """Zero-pad tensors along dim 0 to a common length and stack them.

    All tensors must have the same number of dimensions. The result has one
    extra leading dimension of size ``len(tensor_list)``.
    """
    n_dim = tensor_list[0].ndim
    max_length = max(t.shape[0] for t in tensor_list)
    assert all(
        t.ndim == n_dim for t in tensor_list
    ), "All tensors must have the same number of dimensions."

    # F.pad lists (left, right) pairs starting from the LAST dim, so the pair
    # for dim 0 comes after 2 * (n_dim - 1) zeros.
    no_pad_trailing = (0,) * (2 * (n_dim - 1))
    padded = [
        F.pad(t, no_pad_trailing + (0, max_length - t.shape[0]), value=0.0)
        for t in tensor_list
    ]
    return torch.stack(padded, dim=0)
|
||||
|
||||
|
||||
@dataclass
class MicroBatchList:
    """Result of splitting one packed batch into micro-batches."""

    # The original (unsplit) packed batch.
    data: Dict[str, Any]
    # Specification the split was performed with.
    mb_spec: MicroBatchSpec
    # One packed tensordict per micro-batch.
    mbs: List[TensorDict]
    # Sequence order after grouping, and the permutation that undoes it.
    forward_indices: List[int]
    backward_indices: List[int]
    # Total token count of each micro-batch.
    group_lens: List[int]
    # Filled in by `pad_mb_list`; `None` until padding has been applied.
    padded_mbs: Optional[List[TensorDict]] = None
    padding_lengths: Optional[List[int]] = None
|
||||
|
||||
|
||||
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
|
||||
|
||||
|
||||
def split_packed_tensor_dict_into_mb_list(
    data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
) -> MicroBatchList:
    """Split a packed tensordict into micro-batches based on the cumulative sequence lengths.

    Args:
        data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
        mb_spec (MicroBatchSpec): Specification for micro-batch splitting. When
            ``max_tokens_per_mb`` is ``None`` it is replaced by
            ``DEFAULT_MAX_TOKENS_PER_MB``.
        group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.

    Returns:
        MicroBatchList: A structure containing the split micro-batches and metadata.
    """
    assert (
        "cu_seqlens" in data
    ), "Input data must be packed and contain 'cu_seqlens' key."
    if mb_spec.max_tokens_per_mb is None:
        mb_spec = MicroBatchSpec.new(
            mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
        )
    cu_seqlens = data["cu_seqlens"]
    bs = cu_seqlens.shape[0] - 1
    total_lens = int(cu_seqlens[-1])
    input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()

    # Only tensors holding exactly `total_lens` elements are treated as packed
    # and split; everything else is copied into every micro-batch as-is.
    to_split = {}
    not_to_split = {}
    for key, value in data.items():
        if key == "cu_seqlens" or key == "max_seqlen":
            continue
        if not torch.is_tensor(value) or value.numel() != total_lens:
            not_to_split[key] = value
        else:
            to_split[key] = value

    # Decide which sequences go into which micro-batch.
    group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group)
    splitted_lens = [
        [input_lens[i] for i in group_index] for group_index in group_indices
    ]
    group_lens = [sum(x) for x in splitted_lens]

    # forward_indices: sequence order after grouping.
    # backward_indices: inverse permutation, to restore the original order.
    forward_indices = datapack.flat2d(group_indices)
    backward_indices = np.zeros(bs, dtype=np.int64)
    backward_indices[forward_indices] = np.arange(bs)

    def _split(tensor):
        """Reorder the packed sequences by group and cut them at group boundaries."""
        unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
        reordered = torch.cat(reorder_list(unpacked, forward_indices))
        return unpack_sequence(reordered, lens=group_lens)

    to_split = dict_map(to_split, _split)
    mbs = dict_of_list2list_of_dict(to_split)

    results = []
    # organize splitted micro batches
    assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
    for mb, lens in zip(mbs, splitted_lens):
        max_seqlen = max(lens)
        # Keep the per-micro-batch boundaries on the same device as the input
        # boundaries instead of hard-coding "cuda", so CPU inputs work and the
        # result never lands on a different device than the data.
        lens = torch.tensor(lens, device=cu_seqlens.device)
        batch_cu_seqlens = torch.nn.functional.pad(
            lens.cumsum(0, dtype=torch.int), (1, 0)
        )
        results.append(
            TensorDict(
                **mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
            )
        )
    return MicroBatchList(
        data=data,
        mbs=results,
        mb_spec=mb_spec,
        forward_indices=forward_indices,
        backward_indices=backward_indices,
        group_lens=group_lens,
    )
|
||||
|
||||
|
||||
def pad_packed_tensor_dict(
    data: TensorDict,
    pad_to_length: int,
    pad_value: float = 0.0,
) -> Tuple[Dict[str, Any], int]:
    """Pad a packed tensor dict to a specified total length.

    The padding is modelled as one extra sequence appended after the real ones:
    a new boundary equal to ``pad_to_length`` is appended to "cu_seqlens",
    "max_seqlen" becomes ``max(max_seqlen, pad_length)``, and every tensor with
    exactly ``total_length`` elements is extended along its last dimension with
    ``pad_value``. All other entries are passed through unchanged.

    Args:
        data (TensorDict): Packed data containing "cu_seqlens" and "max_seqlen".
        pad_to_length (int): Target total length; must be >= the current total.
        pad_value (float): Fill value for the appended positions.

    Returns:
        Dict[str, Any]: New dictionary with padded tensors and updated
            "cu_seqlens" / "max_seqlen".
        int: The pad length.
    """
    assert "cu_seqlens" in data, "Input data must contain 'cu_seqlens' key."
    assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key."
    cu_seqlens = data["cu_seqlens"]
    total_length = cu_seqlens[-1].item()
    pad_length = pad_to_length - total_length
    assert (
        pad_length >= 0
    ), f"pad_to_length {pad_to_length} must be greater than or equal to total length {total_length}."

    extended_cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_to_length)
    extended_max_seqlen = max(data["max_seqlen"], pad_length)

    result = {}
    for key, value in data.items():
        if key == "cu_seqlens":
            result[key] = extended_cu_seqlens
            continue
        if key == "max_seqlen":
            result[key] = extended_max_seqlen
            continue
        if torch.is_tensor(value) and value.numel() == total_length:
            # Extend the packed tensor up to the new total length.
            result[key] = torch.nn.functional.pad(
                value, (0, pad_length), value=pad_value
            )
        else:
            result[key] = value
    return result, pad_length
|
||||
|
||||
|
||||
def pad_mb_list(
    mb_list: MicroBatchList,
    pad_value: float = 0.0,
) -> MicroBatchList:
    """Pad every micro-batch up to the next multiple of 256 tokens.

    The padded micro-batches and the per-micro-batch pad lengths are stored on
    ``mb_list`` (``padded_mbs`` / ``padding_lengths``), which is modified in
    place and returned.
    """
    padded_mb_inputs = []
    pad_lengths = []
    pad_to_lengths = []
    for micro_batch, n_tokens in zip(mb_list.mbs, mb_list.group_lens):
        # NOTE (original author's rationale): GPU page size is 2MB. With hidden
        # size 4096 and bf16 as an example, one page holds 256 tokens, hence
        # the rounding unit.
        target = (n_tokens + 255) // 256 * 256
        padded, n_pad = pad_packed_tensor_dict(
            micro_batch, target, pad_value=pad_value
        )
        padded_mb_inputs.append(padded)
        pad_lengths.append(n_pad)
        pad_to_lengths.append(target)
    logger.debug(
        f"Microbatch original lengths: {mb_list.group_lens}, padded to {pad_to_lengths}."
    )
    mb_list.padded_mbs = padded_mb_inputs
    mb_list.padding_lengths = pad_lengths
    return mb_list
|
||||
|
||||
|
||||
def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
    """Add a leading dimension of size 1 to every packed tensor in ``data``.

    A tensor counts as packed when its element count equals the total length
    recorded in ``cu_seqlens[-1]``. "cu_seqlens" and "max_seqlen" are never
    touched. ``data`` is updated in place and returned.
    """
    assert "cu_seqlens" in data, "Input data must contain 'cu_seqlens' key."
    assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key."

    total_length = data["cu_seqlens"][-1].item()
    for key, value in data.items():
        if key in ("cu_seqlens", "max_seqlen"):
            continue
        if torch.is_tensor(value) and value.numel() == total_length:
            data[key] = value.unsqueeze(dim=0)
    return data
|
||||
|
||||
|
||||
def unsqueeze_mb_list(
    mb_list: MicroBatchList,
) -> MicroBatchList:
    """Apply ``unsqueeze_packed_tensor_dict`` to every micro-batch in ``mb_list``.

    Padded micro-batches are processed too when present. ``mb_list`` is
    modified in place and returned.
    """
    has_padded = mb_list.padded_mbs is not None
    unsqueezed = []
    unsqueezed_padded = []
    for idx, micro_batch in enumerate(mb_list.mbs):
        unsqueezed.append(unsqueeze_packed_tensor_dict(micro_batch))
        if has_padded:
            unsqueezed_padded.append(
                unsqueeze_packed_tensor_dict(mb_list.padded_mbs[idx])
            )
    mb_list.mbs = unsqueezed
    mb_list.padded_mbs = unsqueezed_padded if has_padded else None
    return mb_list
|
||||
|
||||
|
||||
def amend_position_ids(data: TensorDict) -> TensorDict:
    """Add a "position_ids" entry derived from "attention_mask".

    Position ids are ``0 .. seqlen - 1`` for every row, with positions where
    the attention mask is zero set to 0. ``data`` is updated in place and
    returned.

    Args:
        data: Mapping containing "attention_mask" whose first two dims are
            ``(batch, seqlen)``.

    Returns:
        The same ``data`` object, now containing a long "position_ids" tensor
        of shape ``(batch, seqlen)`` on the mask's device.
    """
    assert "attention_mask" in data, "Input data must contain 'attention_mask' key."
    attn_mask = data["attention_mask"]
    bs, seqlen = attn_mask.shape[:2]
    position_ids = (
        torch.arange(0, seqlen, dtype=torch.long, device=attn_mask.device)
        .unsqueeze(0)
        .expand(bs, -1)
    )
    # `masked_fill` is out-of-place, so its result must be kept; previously it
    # was discarded and padded positions were never zeroed. The in-place
    # variant is not an option because `expand` returns a view whose rows
    # share memory.
    position_ids = position_ids.masked_fill(~attn_mask.bool(), 0)
    data["position_ids"] = position_ids
    return data
|
|
@ -0,0 +1,189 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from realhf.base import logging, pkg_version
|
||||
|
||||
logger = logging.getLogger("FSDPEngine")
|
||||
|
||||
if pkg_version.is_version_greater_or_equal("torch", "2.6.0"):
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
elif pkg_version.is_version_greater_or_equal("torch", "2.4.0"):
|
||||
from torch.distributed._composable.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
FSDPModule,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
else:
|
||||
CPUOffloadPolicy = None
|
||||
FSDPModule = None
|
||||
MixedPrecisionPolicy = None
|
||||
fully_shard = None
|
||||
logger.warning("Current PyTorch version < 2.4.0 is not supported for FSDPEngine.")
|
||||
|
||||
|
||||
def fsdp2_clip_grad_norm_(
    parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None
):
    """Clip gradients by total norm and return that norm.

    Replacement for ``torch.nn.utils.clip_grad_norm_``, which (per the original
    author's note) can't run on CPU parameter DTensors. The total norm is moved
    to the current CUDA device before the gradients are scaled.
    """
    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm

    # Materialise the parameters: they are traversed twice below, which would
    # exhaust a generator.
    if isinstance(parameters, torch.Tensor):
        params = [parameters]
    else:
        params = list(parameters)

    grads = [p.grad for p in params if p.grad is not None]
    norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
    norm = norm.to(torch.cuda.current_device(), non_blocking=True)
    _clip_grads_with_norm_(params, max_norm, norm, foreach)
    return norm
|
||||
|
||||
|
||||
def create_fsdp_device_mesh(shard_size, world_size):
    """Build the CUDA device mesh used for FSDP.

    A negative ``shard_size``, or one that is at least ``world_size``, yields a
    1-D mesh named ("fsdp",) over all ranks. Otherwise a 2-D
    ("ddp", "fsdp") mesh of shape
    ``(world_size // shard_size, shard_size)`` is created.
    """
    shard_everything = shard_size < 0 or shard_size >= world_size
    if shard_everything:
        return init_device_mesh(
            "cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)
        )
    return init_device_mesh(
        "cuda",
        mesh_shape=(world_size // shard_size, shard_size),
        mesh_dim_names=("ddp", "fsdp"),
    )
|
||||
|
||||
|
||||
def apply_fsdp2(model, fsdp_kwargs, wrap_policy):
    """Apply ``fully_shard`` (FSDP2) to selected submodules and then to ``model``.

    model: AutoModelForCausalLM

    Submodules are selected by class name. The names come from
    ``wrap_policy.transformer_layer_cls_to_wrap`` when that is non-empty,
    otherwise from ``model._no_split_modules``. ``nn.Embedding`` modules are
    additionally selected unless ``model.config.tie_word_embeddings`` is set.
    """
    assert (
        CPUOffloadPolicy is not None
    ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

    if wrap_policy is not None:
        cls_names = wrap_policy.transformer_layer_cls_to_wrap
    else:
        cls_names = list()
    if not cls_names:
        # Fall back to the model's own declaration.
        cls_names = getattr(model, "_no_split_modules", list())
    if isinstance(cls_names, str):
        cls_names = [cls_names]

    assert len(cls_names) > 0 and cls_names[0] is not None

    targets = [
        submodule
        for _, submodule in model.named_modules()
        if submodule.__class__.__name__ in cls_names
        or (
            isinstance(submodule, nn.Embedding)
            and not model.config.tie_word_embeddings
        )
    ]

    for submodule in targets:
        fully_shard(submodule, **fsdp_kwargs)
    # fsdp2 will not reshard_after_forward for root module
    fully_shard(model, **fsdp_kwargs)
|
||||
|
||||
|
||||
def fsdp2_load_full_state_dict(
    model: PreTrainedModel,
    full_state: dict,
    cpu_offload=None,
    tie_word_embeddings=False,
):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        model (`torch.nn.Module`): The model to load the state dict into
        full_state (`dict`): The full state dict to load, can only be on rank 0
        cpu_offload: Any non-``None`` value enables CPU offload: the model is moved
            back to CPU after loading while its buffers are kept on the GPU.
        tie_word_embeddings (`bool`): When set, loading is non-strict and
            ``model.tie_weights()`` is called afterwards.
    """
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    target_device = torch.cuda.current_device()
    model = model.to(device=target_device, non_blocking=True)
    offload_enabled = cpu_offload is not None
    load_options = StateDictOptions(
        full_state_dict=True,
        cpu_offload=offload_enabled,
        broadcast_from_rank0=True,
        strict=not tie_word_embeddings,
    )
    set_model_state_dict(model, full_state, options=load_options)

    if tie_word_embeddings:
        model.tie_weights()

    # rotary_emb is not in state_dict, so we need to broadcast it manually
    for _, buffer in model.named_buffers():
        dist.broadcast(buffer, src=0)

    if not offload_enabled:
        return
    model.to("cpu", non_blocking=True)
    for buffer in model.buffers():
        buffer.data = buffer.data.to(target_device)
|
||||
|
||||
|
||||
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Build a ``LambdaLR`` schedule with linear warmup followed by cosine decay.

    During warmup the multiplier rises linearly from ``min_lr_ratio`` to 1.
    Afterwards it follows a cosine curve and is never allowed to fall below
    ``min_lr_ratio``.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
            The minimum lr ratio w.r.t the maximum. Must lie in [0, 1].
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is a
            single half-cosine from the maximum down to the minimum).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    assert 0 <= min_lr_ratio <= 1.0
    # Map cos() from [-1, 1] onto [min_lr_ratio, 1].
    coef = (1 - min_lr_ratio) * 0.5
    intercept = (1 + min_lr_ratio) * 0.5
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
            return max(min_lr_ratio, x * coef + intercept)
        return min_lr_ratio + (1.0 - min_lr_ratio) * (
            float(current_step) / warmup_span
        )

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
@ -0,0 +1,8 @@
|
|||
import torch
|
||||
|
||||
|
||||
@torch.compile
|
||||
def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor):
|
||||
log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
|
||||
return log_probs_labels
|
|
@ -0,0 +1,9 @@
|
|||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
def record_timing(name, timing_stats):
    """Time the wrapped block and store the elapsed seconds in ``timing_stats[name]``.

    Args:
        name: Key under which the duration is stored.
        timing_stats: Mutable mapping that receives the measurement.

    The duration is recorded even if the wrapped block raises; the exception
    still propagates to the caller.
    """
    start_time = time.perf_counter()
    try:
        yield
    finally:
        # `finally` keeps the measurement when the body fails, so timing of a
        # failed step is not silently lost.
        timing_stats[name] = time.perf_counter() - start_time
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
    """
    Obtain a state dictionary from either a Hugging Face repo ID or a local path.

    An existing local directory always takes precedence; only when no such
    directory exists is the argument treated as a Hugging Face repo ID and
    downloaded with ``snapshot_download``.

    Args:
        repo_id_or_path (str): Either a Hugging Face repo ID (e.g., 'username/model-name')
            or a local path to a directory containing model weights.

    Returns:
        Dict: The combined state dictionary from all .safetensors and .bin files.
            Files are merged in sorted filename order, so on duplicate keys the
            later filename wins deterministically.

    Raises:
        ValueError: If the argument is neither an existing directory nor a
            usable Hugging Face repo ID.
        RuntimeError: If loading or merging any checkpoint file fails.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Step 1: Resolve the argument to a local directory. Checking the
    # filesystem first keeps relative paths such as "models/my-model" (which
    # also look like valid repo IDs) from being sent to the hub.
    if os.path.isdir(repo_id_or_path):
        local_path = repo_id_or_path
    else:
        try:
            from huggingface_hub.utils import HFValidationError, validate_repo_id
        except ImportError:
            is_hf_repo = False
        else:
            try:
                validate_repo_id(repo_id_or_path)
                is_hf_repo = True
            except HFValidationError:
                is_hf_repo = False

        if not is_hf_repo:
            local_path = repo_id_or_path
            raise ValueError(
                f"Local path {local_path} does not exist or is not a directory, "
                f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
            )

        # Step 2: Download the repo if it's a Hugging Face repo ID
        from huggingface_hub import snapshot_download

        local_path = snapshot_download(
            repo_id=repo_id_or_path,
        )

    # Step 3: Load all .safetensors and .bin files (sorted for determinism)
    file_paths_to_load = sorted(
        os.path.join(local_path, filename)
        for filename in os.listdir(local_path)
        if filename.endswith((".safetensors", ".bin"))
    )

    def _load(filepath: str):
        if filepath.endswith(".safetensors"):
            # Imported lazily so that .bin-only checkpoints do not require
            # safetensors to be installed.
            from safetensors.torch import load_file as safetensors_load

            return safetensors_load(filepath)
        if filepath.endswith(".bin"):
            return torch.load(filepath, map_location="cpu")
        raise ValueError(f"{filepath} is not a torch bin or safetensor file.")

    state_dict = {}
    # os.cpu_count() may return None when the count cannot be determined.
    max_workers = min(4, max(1, (os.cpu_count() or 1) // 8))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        pending = [(path, executor.submit(_load, path)) for path in file_paths_to_load]
        # Consume in submission order rather than completion order so the
        # merged result does not depend on thread timing.
        for path, future in pending:
            try:
                sd = future.result()
                state_dict.update(sd)
            except Exception as e:
                raise RuntimeError(
                    f"Error loading checkpoint from {path}: {e}"
                ) from e
    return state_dict
|
|
@ -0,0 +1,53 @@
|
|||
experiment_name: gsm8k-sft
|
||||
trial_name: trial0
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
cluster:
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /tmp/areal/name_resolve
|
||||
|
||||
model:
|
||||
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
pad_mbs_to_max_tokens: true
|
||||
gradient_checkpointing: false
|
||||
bf16: true
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2e-5
|
||||
weight_decay: 0.05
|
||||
beta1: 0.9
|
||||
beta2: 0.95
|
||||
eps: 1e-5
|
||||
lr_scheduler_type: cosine
|
||||
gradient_clipping: 1.0
|
||||
backend: fsdp
|
||||
|
||||
trainer:
|
||||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
wandb:
|
||||
mode: disabled
|
||||
seed: 1
|
||||
exp_ctrl:
|
||||
total_train_epochs: 1
|
||||
eval_freq_steps: 1
|
||||
tokenizer_path: ${model.path}
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 4096
|
||||
|
||||
train_dataset:
|
||||
type: gsm8k-sft
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
|
||||
valid_dataset:
|
||||
type: gsm8k-sft
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
|
@ -0,0 +1,88 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
from arealite.api.cli_args import SFTConfig, load_expr_config
|
||||
from arealite.api.io_struct import FinetuneSpec
|
||||
from arealite.engine.fsdp_engine import FSDPEngine
|
||||
from arealite.trainer.sft import SFTTrainer
|
||||
from arealite.utils.data import pad_sequences_to_tensors
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
|
||||
|
||||
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
    """Tokenize GSM8K samples for supervised fine-tuning.

    Each sample becomes ``input_ids`` (question + answer + EOS, tokenized as
    one string) and ``prompt_mask`` (1 for the first ``len(encode(question))``
    positions, 0 afterwards). The raw "question" and "answer" columns are
    dropped.
    """

    def tokenize_sample(sample):
        full_text = sample["question"] + sample["answer"] + tokenizer.eos_token
        seq_token = tokenizer.encode(full_text)
        # NOTE(review): assumes the question tokenized alone is a prefix of the
        # full tokenization - confirm for the tokenizer in use.
        n_prompt = len(tokenizer.encode(sample["question"]))
        n_rest = len(seq_token) - n_prompt
        return {
            "input_ids": seq_token,
            "prompt_mask": [1] * n_prompt + [0] * n_rest,
        }

    tokenized = dataset.map(tokenize_sample)
    return tokenized.remove_columns(["question", "answer"])
|
||||
|
||||
|
||||
def get_gsm8k_dataset(split, tokenizer, rank, world_size):
    """Load one split of openai/gsm8k ("main"), shard it by rank, and tokenize it."""
    full_split = load_dataset(path="openai/gsm8k", name="main", split=split)
    # Each rank only keeps its own shard of the split.
    local_shard = split_dataset_by_node(
        full_split, rank=rank, world_size=world_size
    )
    return process_gsm8k_sft_dataset(local_shard, tokenizer)
|
||||
|
||||
|
||||
def main_sft():
    """Entry point of the GSM8K SFT example.

    Loads the experiment config from the command line, builds per-rank train
    and validation dataloaders, initializes an ``FSDPEngine`` and runs
    ``SFTTrainer.train()``.

    Raises:
        KeyError: If the RANK or WORLD_SIZE environment variable is not set.
        ValueError: If a dataset type other than "gsm8k-sft" is configured.
    """
    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
    config: SFTConfig

    # Indexing os.environ fails with the name of the missing variable, instead
    # of the opaque TypeError raised by int(None).
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    tokenizer = load_hf_tokenizer(config.trainer.tokenizer_path)

    # Create dataset and dataloaders.
    # The dataset name lives in the `type` field of the dataset config;
    # comparing the config object itself to a string can never succeed.
    if config.train_dataset.type != "gsm8k-sft":
        raise ValueError(
            f"Unsupported train dataset type: {config.train_dataset.type!r}"
        )
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", tokenizer, rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=pad_sequences_to_tensors,
        drop_last=config.train_dataset.drop_last,
    )
    if config.valid_dataset.type != "gsm8k-sft":
        raise ValueError(
            f"Unsupported valid dataset type: {config.valid_dataset.type!r}"
        )
    valid_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("test", tokenizer, rank, world_size),
        batch_size=config.valid_dataset.batch_size // world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,
        collate_fn=pad_sequences_to_tensors,
        drop_last=config.valid_dataset.drop_last,
    )

    # Initialize engine
    # NOTE(review): `len(train_dataloader)` is the number of per-rank batches,
    # not the number of samples - confirm this is what `dataset_size` expects.
    ft_spec = FinetuneSpec(
        total_train_epochs=config.trainer.exp_ctrl.total_train_epochs,
        dataset_size=len(train_dataloader),
        train_batch_size=config.train_dataset.batch_size,
    )
    engine = FSDPEngine(config=config.model)
    engine.initialize(None, ft_spec)

    # Run training.
    trainer = SFTTrainer(
        config=config.trainer,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        engine=engine,
        inf_engine=None,
    )
    trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_sft()
|
|
@ -54,6 +54,7 @@ dependencies = [
|
|||
"packaging",
|
||||
"tabulate",
|
||||
"torchdata",
|
||||
"autoflake",
|
||||
"gymnasium",
|
||||
"tensordict",
|
||||
|
||||
|
|
|
@ -107,3 +107,7 @@ def training_samples(experiment_name, trial_name):
|
|||
|
||||
def experiment_status(experiment_name, trial_name):
    """Return the name-resolve key holding the status of the given experiment trial."""
    key = f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/experiment_status"
    return key
|
||||
|
||||
|
||||
def update_weights_from_disk(experiment_name, trial_name, model_version):
    """Return the name-resolve key for a disk weight update of ``model_version``."""
    key = f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/update_weights_from_disk/{model_version}"
    return key
|
||||
|
|
|
@ -71,5 +71,6 @@ timeout-decorator
|
|||
prettytable
|
||||
swanlab[dashboard]
|
||||
torchdata
|
||||
autoflake
|
||||
gymnasium
|
||||
tensordict
|
Loading…
Reference in New Issue