# AReaL/arealite/api/cli_args.py

import argparse
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import uvloop

uvloop.install()
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf

from arealite.utils.fs import get_user_tmp


@dataclass
class MicroBatchSpec:
    """Specification for splitting micro-batches during training."""

    n_mbs: int = field(
        default=1,
        metadata={
            "help": "Number of micro-batches. When max_tokens_per_mb is set, this "
            "instead acts as the minimum number of micro-batches.",
        },
    )
    max_tokens_per_mb: Optional[int] = field(
        default=None,
        metadata={
            "help": "Maximum tokens per micro-batch. When set, n_mbs becomes the "
            "minimum number of micro-batches.",
        },
    )

    @classmethod
    def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
        """Create a new spec with updated fields while keeping OmegaConf compatibility."""
        fields = dict(
            n_mbs=mb_spec.n_mbs,
            max_tokens_per_mb=mb_spec.max_tokens_per_mb,
        )
        fields.update(kwargs)
        return cls(**fields)
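
# Illustrative usage sketch (not part of the config schema): `MicroBatchSpec.new`
# copies an existing spec and overrides selected fields, which keeps the result
# a plain dataclass that OmegaConf can still structure.
#
#     base = MicroBatchSpec(n_mbs=2)
#     capped = MicroBatchSpec.new(base, max_tokens_per_mb=8192)
#     # capped == MicroBatchSpec(n_mbs=2, max_tokens_per_mb=8192)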


@dataclass
class GenerationHyperparameters:
    """Controls text generation behavior for RL training."""

    n_samples: int = field(
        default=1, metadata={"help": "Number of sequences to generate per prompt."}
    )
    max_new_tokens: int = field(
        default=16384, metadata={"help": "Maximum number of tokens to generate."}
    )
    min_new_tokens: int = field(
        default=0, metadata={"help": "Minimum number of tokens to generate."}
    )
    greedy: bool = field(
        default=False,
        metadata={"help": "Whether to use greedy decoding (max probability)."},
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Nucleus sampling probability threshold (0.0, 1.0]."},
    )
    top_k: int = field(
        default=int(1e8),
        metadata={"help": "Number of highest-probability tokens to consider."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Sampling temperature. Higher values increase diversity."},
    )
    stop_token_ids: List[int] = field(
        default_factory=list,
        metadata={"help": "Stop generation when encountering these token ids."},
    )

    def new(self, **kwargs):
        args = asdict(self)
        args.update(kwargs)
        return GenerationHyperparameters(**args)
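
# Illustrative sketch: like MicroBatchSpec.new, `new` returns an updated copy
# without mutating the original, e.g. to switch to deterministic decoding for
# evaluation:
#
#     gconfig = GenerationHyperparameters(max_new_tokens=1024)
#     eval_gconfig = gconfig.new(greedy=True, temperature=1.0)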


# Train Engine Configs
@dataclass
class OptimizerConfig:
    """Configuration for model optimization during training.

    Note:
        Set type to "empty" for models that won't be trained.
    """

    type: str = field(
        default="adam",
        metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
    )
    lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
    weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
    beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
    beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
    eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
    min_lr_ratio: float = field(
        default=0.0,
        metadata={
            "help": "Minimum learning rate ratio after annealing",
        },
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={
            "help": "Learning rate scheduler type",
            "choices": ["linear", "cosine", "constant"],
        },
    )
    warmup_steps_proportion: float = field(
        default=0.001,
        metadata={
            "help": "Proportion of training steps for warmup",
        },
    )
    offload: bool = field(
        default=False, metadata={"help": "Enable optimizer state offloading"}
    )
    initial_loss_scale: float = field(
        default=2**32, metadata={"help": "Initial loss scaling factor"}
    )
    min_loss_scale: float = field(
        default=1.0, metadata={"help": "Minimum loss scaling factor"}
    )
    loss_scale_window: float = field(
        default=5, metadata={"help": "Window size for loss scaling adjustment"}
    )
    hysteresis: int = field(
        default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
    )
    gradient_clipping: float = field(
        default=1.0, metadata={"help": "Gradient clipping threshold"}
    )


@dataclass
class FSDPWrapPolicy:
    transformer_layer_cls_to_wrap: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of transformer layer names for FSDP to wrap."},
    )


@dataclass
class FSDPEngineConfig:
    wrap_policy: Optional[FSDPWrapPolicy] = field(
        default=None,
        metadata={"help": "FSDP wrap policy, specifying model layers to wrap."},
    )
    offload_params: bool = field(
        default=False,
        metadata={"help": "Whether to offload FSDP parameters to CPU."},
    )
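
# Illustrative sketch only: the wrap policy names the transformer block classes
# that FSDP should treat as wrapping units. The class name below is an example
# for Llama-style models, not a default of this config.
#
#     fsdp_cfg = FSDPEngineConfig(
#         wrap_policy=FSDPWrapPolicy(
#             transformer_layer_cls_to_wrap=["LlamaDecoderLayer"]
#         ),
#         offload_params=False,
#     )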


@dataclass
class TrainEngineConfig:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
    attn_impl: str = field(
        default="flash_attention_2",
        metadata={
            "help": "Attention implementation for huggingface transformers model.",
            "choices": ["flash_attention_2"],
        },
    )
    init_from_scratch: bool = field(
        default=False, metadata={"help": "Initialize model weights randomly"}
    )
    init_critic_from_actor: bool = field(
        default=False,
        metadata={"help": "Initialize critic/reward model from LM checkpoint"},
    )
    # Runtime micro-batch limit
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
    # Training backend configuration
    disable_dropout: bool = field(default=False)
    gradient_checkpointing: bool = field(
        default=True, metadata={"help": "Enable gradient checkpointing"}
    )
    dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
    optimizer: Optional[OptimizerConfig] = field(
        default=None, metadata={"help": "Optimizer configuration"}
    )
    backend: str = ""
    fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)


@dataclass
class PPOActorConfig(TrainEngineConfig):
    # Core PPO/GRPO parameters
    group_size: int = field(
        default=1, metadata={"help": "Number of sequences in each group"}
    )
    group_adv_norm: bool = field(
        default=False,
        metadata={
            "help": "Normalize advantages within each prompt group rather than globally"
        },
    )
    ppo_n_minibatches: int = field(
        default=4, metadata={"help": "Number of minibatches for each PPO update"}
    )
    eps_clip: float = field(
        default=0.2, metadata={"help": "Clipping factor for the policy ratio"}
    )
    c_clip: Optional[float] = field(
        default=None,
        metadata={
            "help": "Dual clipping factor for the policy ratio; must be > 1.0. "
            "None disables dual clipping."
        },
    )
    temperature: float = field(
        default=1.0, metadata={"help": "Temperature during generation."}
    )
    # Reward
    group_reward_norm: bool = field(
        default=False,
        metadata={
            "help": "Normalize the final reward of each sequence (GRPO-style) to reduce length bias"
        },
    )
    reward_scaling: float = field(
        default=1.0, metadata={"help": "Reward scaling factor"}
    )
    reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
    reward_clip: float = field(
        default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
    )
    mask_no_eos_with_zero: bool = field(
        default=False,
        metadata={
            "help": "Mask truncated generations (no EOS token) and exclude them from training"
        },
    )
    # Advantage estimation
    discount: float = field(
        default=1.0, metadata={"help": "Discount factor for future rewards"}
    )
    gae_lambda: float = field(
        default=1.0, metadata={"help": "Lambda parameter for GAE"}
    )
    adv_norm: bool = field(
        default=True, metadata={"help": "Enable advantage normalization"}
    )
    # KL control
    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
    # Asynchronous RL
    recompute_logprob: bool = field(
        default=False,
        metadata={"help": "Recompute logp and replace the logp returned by inference."},
    )
    use_decoupled_loss: bool = field(
        default=False,
        metadata={"help": "Use the decoupled loss. Requires recompute_logprob=True."},
    )
    behav_imp_weight_cap: Optional[float] = field(
        default=None,
        metadata={
            "help": "Filter out tokens whose behavior importance weight exceeds this cap "
            "when computing the loss. Must be > 1.0 and requires use_decoupled_loss=True. "
            "None disables the cap."
        },
    )
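
# Illustrative validation sketch (hypothetical helper, not part of this module):
# the help strings above imply a few cross-field constraints that a training
# script may want to check before building the engine.
#
#     def check_ppo_actor_config(cfg: PPOActorConfig) -> None:
#         if cfg.use_decoupled_loss and not cfg.recompute_logprob:
#             raise ValueError("use_decoupled_loss requires recompute_logprob=True")
#         if cfg.behav_imp_weight_cap is not None:
#             assert cfg.use_decoupled_loss, "behav_imp_weight_cap needs the decoupled loss"
#             assert cfg.behav_imp_weight_cap > 1.0
#         if cfg.c_clip is not None:
#             assert cfg.c_clip > 1.0, "dual clipping factor must be > 1.0"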


@dataclass
class SGLangConfig:
    """Configuration for SGLang runtime. Refer to:
    https://github.com/sgl-project/sglang for detailed documentation.
    """

    model_path: str = ""
    random_seed: int = 1
    skip_tokenizer_init: bool = False
    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    # NOTE: to avoid the illegal memory access error
    attention_backend: Optional[str] = "flashinfer"
    sampling_backend: Optional[str] = None
    context_length: Optional[int] = 32768
    mem_fraction_static: Optional[float] = 0.9
    max_running_requests: Optional[int] = None
    # NOTE: chunked_prefill_size is by default 8192 on GPUs with 80GB mem in SGLang,
    # but we disable it to avoid precision issues
    chunked_prefill_size: Optional[int] = -1
    max_prefill_tokens: int = 32768
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    dtype: str = "float16"
    kv_cache_dtype: str = "auto"
    # Logging
    log_level: str = "warning"
    log_level_http: Optional[str] = "warning"
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = True  # Exports Prometheus-like metrics
    # The interval (in decoding iterations) to log throughput
    # and update Prometheus metrics
    decode_log_interval: int = 1

    # Use staticmethod to make OmegaConf happy.
    @staticmethod
    def build_cmd(
        sglang_config: "SGLangConfig",
        tp_size,
        base_gpu_id,
        host,
        port,
        dist_init_addr: Optional[str] = None,
        sglang_version: Optional[str] = None,
    ):
        from realhf.base import pkg_version
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)
        args = dict(
            host=host,
            port=port,
            # Model and tokenizer
            tokenizer_path=sglang_config.model_path,
            tokenizer_mode="auto",
            load_format="auto",
            trust_remote_code=True,
            device="cuda",
            is_embedding=False,
            # Other runtime options
            tp_size=tp_size,
            # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
            base_gpu_id=base_gpu_id,
            nnodes=1,
            node_rank=0,
            # Initialization addresses and ports
            dist_init_addr=dist_init_addr,
            **args,
        )
        if sglang_version:
            version_less_than_0_4_4 = (
                pkg_version.compare_versions(sglang_version, "0.4.4") < 0
            )
            version_less_than_0_4_3 = (
                pkg_version.compare_versions(sglang_version, "0.4.3") < 0
            )
        elif pkg_version.is_available("sglang"):
            version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
            version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
        else:
            raise ValueError(
                "An installed SGLang package or an explicit SGLang version must be "
                "provided to build the SGLang server command."
            )
        # Drop arguments that older SGLang versions do not recognize.
        if version_less_than_0_4_4:
            args.pop("log_requests_level")
        if version_less_than_0_4_3:
            args.pop("enable_nccl_nvls")
            args.pop("triton_attention_num_kv_splits")
            args.pop("cuda_graph_bs")
            args.pop("enable_memory_saver")
            args.pop("allow_auto_truncate")
            # Not a field of this dataclass; pop defensively in case it was
            # injected upstream.
            args.pop("file_storage_path", None)
        # Convert the argument dict into CLI flags, skipping unset/disabled values.
        flags = []
        for k, v in args.items():
            if v is None or v is False or v == "":
                continue
            if v is True:
                flags.append(f"--{k.replace('_', '-')}")
                continue
            if isinstance(v, list):
                values = " ".join(map(str, v))
                flags.append(f"--{k.replace('_', '-')} {values}")
                continue
            flags.append(f"--{k.replace('_', '-')} {v}")
        flags = " ".join(flags)
        return f"python3 -m sglang.launch_server {flags}"
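
# Illustrative usage sketch (paths and ports are placeholders; assumes the
# sglang package is installed or that sglang_version is passed explicitly):
#
#     cmd = SGLangConfig.build_cmd(
#         SGLangConfig(model_path="/path/to/model"),
#         tp_size=1,
#         base_gpu_id=0,
#         host="127.0.0.1",
#         port=30000,
#     )
#     # -> "python3 -m sglang.launch_server --host 127.0.0.1 --port 30000 ..."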


@dataclass
class InferenceEngineConfig:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    max_concurrent_rollouts: None | int = field(
        default=None,
        metadata={
            "help": "Maximum number of concurrent rollouts to the inference engine. "
            "Defaults to consumer_batch_size."
        },
    )
    queue_size: None | int = field(
        default=None,
        metadata={"help": "Input/Output queue size for async rollout."},
    )
    consumer_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size for consuming rollouts from the queue."},
    )
    max_head_offpolicyness: int = field(
        default=0,
        metadata={
            "help": "Maximum off-policyness for the head. "
            "If the current version is more than this many versions behind, "
            "the request will not be accepted.",
        },
    )
    enable_rollout_tracing: bool = field(default=False)
    schedule_policy: str = field(
        default="round_robin",
        metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
    )
    setup_timeout: float = field(default=90.0)
    request_timeout: float = field(
        default=3600, metadata={"help": "Timeout for HTTP requests."}
    )
    request_retries: int = field(
        default=3, metadata={"help": "Number of retries for failed requests."}
    )


@dataclass
class _Timer:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    fileroot: str = MISSING
    freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Trigger frequency in epochs. None disables epoch-based saving."
        },
    )
    freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Trigger frequency in steps. None disables step-based saving."
        },
    )
    freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Trigger frequency in seconds. None disables time-based saving."
        },
    )


@dataclass
class EvaluatorConfig(_Timer):
    pass


@dataclass
class SaverConfig(_Timer):
    pass


@dataclass
class WandBConfig:
    mode: str = "disabled"
    entity: Optional[str] = None
    project: Optional[str] = None
    name: Optional[str] = None
    job_type: Optional[str] = None
    group: Optional[str] = None
    notes: Optional[str] = None
    tags: Optional[List[str]] = None
    config: Optional[Dict] = None


@dataclass
class SwanlabConfig:
    project: Optional[str] = None
    name: Optional[str] = None
    config: Optional[Dict] = None
    logdir: Optional[str] = None
    mode: Optional[str] = "local"
    api_key: Optional[str] = os.getenv("SWANLAB_API_KEY", None)


@dataclass
class TensorBoardConfig:
    path: Optional[str] = None


@dataclass
class StatsLoggerConfig:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    fileroot: str = MISSING
    wandb: WandBConfig = field(
        default_factory=WandBConfig,
        metadata={"help": "Weights & Biases configuration."},
    )
    swanlab: SwanlabConfig = field(
        default_factory=SwanlabConfig,
        metadata={"help": "SwanLab configuration."},
    )
    tensorboard: TensorBoardConfig = field(
        default_factory=TensorBoardConfig,
        metadata={"help": "TensorBoard configuration. Only 'path' field required."},
    )


@dataclass
class NameResolveConfig:
    type: str = field(
        default="nfs",
        metadata={
            "help": "Type of the distributed KV store for name resolving.",
            "choices": ["nfs", "etcd3", "ray"],
        },
    )
    nfs_record_root: str = field(
        default="/tmp/areal/name_resolve",
        metadata={
            "help": "Record root for NFS name resolving. Should be available on all nodes."
        },
    )
    etcd3_addr: str = field(
        default="localhost:2379", metadata={"help": "Address of the ETCD3 server."}
    )
    ray_actor_name: str = field(
        default="ray_kv_store",
        metadata={"help": "Name of the distributed Ray KV store."},
    )


@dataclass
class ClusterSpecConfig:
    name_resolve: NameResolveConfig = field(
        default_factory=NameResolveConfig,
        metadata={"help": "Name resolving configuration."},
    )
    cluster_name: str = field(
        default="local",
        metadata={"help": "Name of the cluster. Used to set specific environs."},
    )
    fileroot: str = field(
        default=get_user_tmp(),
        metadata={
            "help": "Root for logs and checkpoints. Should be available to all nodes."
        },
    )
    gpu_type: str = field(
        default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
    )
    mount: str = field(
        default="/storage:/storage", metadata={"help": "Mount path for slurm."}
    )
    gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
    cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
    gpu_infer_image: str = field(
        default="", metadata={"help": "slurm image for LLM inference."}
    )
    node_name_prefix: str = field(
        default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
    )
    n_nodes: int = field(
        default=32,
        metadata={
            "help": "The size of the cluster. Used to decide the slurm hostname suffix."
        },
    )
    n_gpus_per_node: int = field(
        default=8,
        metadata={"help": "GPUs per node (physically)."},
    )


@dataclass
class DatasetConfig:
    type: Optional[str] = field(
        default=None, metadata={"help": "Type of implemented dataset"}
    )
    batch_size: int = field(
        default=1, metadata={"help": "Batch size of the dataloader"}
    )
    shuffle: bool = field(
        default=True, metadata={"help": "Whether to shuffle the dataset"}
    )
    pin_memory: bool = field(
        default=False,
        metadata={
            "help": "Pin memory for faster data loading (set True for GPU training)"
        },
    )
    num_workers: int = field(
        default=0, metadata={"help": "Number of worker processes for data loading"}
    )
    drop_last: bool = field(default=True)
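
# Illustrative sketch: most of these fields mirror torch.utils.data.DataLoader
# arguments, while `type` selects the dataset implementation elsewhere in the
# codebase. The helper below is hypothetical.
#
#     from torch.utils.data import DataLoader
#
#     def build_dataloader(dataset, cfg: DatasetConfig) -> DataLoader:
#         return DataLoader(
#             dataset,
#             batch_size=cfg.batch_size,
#             shuffle=cfg.shuffle,
#             pin_memory=cfg.pin_memory,
#             num_workers=cfg.num_workers,
#             drop_last=cfg.drop_last,
#         )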


@dataclass
class LauncherConfig:
    """Configuration for launching the SGLang server and trainer processes."""

    inference_server_cpus_per_gpu: int = field(
        default=4,
        metadata={"help": "Number of CPUs allocated per GPU for the inference server."},
    )
    inference_server_mem_per_gpu: int = field(
        default=32 * 1024,
        metadata={"help": "Memory allocated per GPU for the inference server, in MB."},
    )
    trainer_cpus_per_gpu: int = field(
        default=4,
        metadata={"help": "Number of CPUs allocated per GPU for training."},
    )
    trainer_mem_per_gpu: int = field(
        default=32 * 1024,
        metadata={"help": "Memory allocated per GPU for training, in MB."},
    )
    inference_server_env_vars: str = field(
        default="",
        metadata={
            "help": "Environment variables for the inference server, separated by commas. "
            "Example: 'ENV1=val1,ENV2=val2'."
        },
    )
    trainer_env_vars: str = field(
        default="",
        metadata={
            "help": "Environment variables for training, separated by commas. "
            "Example: 'ENV1=val1,ENV2=val2'."
        },
    )
    trainer_port: int = field(
        default=27015,
        metadata={"help": "Trainer port used for torch.distributed initialization."},
    )
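
# Illustrative sketch of the documented 'ENV1=val1,ENV2=val2' format; the helper
# below is hypothetical and only shows how such a string could be parsed.
#
#     def parse_env_vars(s: str) -> dict:
#         return dict(kv.split("=", 1) for kv in s.split(",") if kv)
#
#     parse_env_vars("ENV1=val1,ENV2=val2")  # {'ENV1': 'val1', 'ENV2': 'val2'}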


@dataclass
class BaseExperimentConfig:
    # NOTE: we need this unified config class because different experiments
    # have different config structures, e.g., GRPO has two engine configs,
    # but SFT only has a single one. We use subclasses to represent these structures.
    experiment_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
    )
    trial_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the trial (no '-' or '/'). Required."},
    )
    cluster: ClusterSpecConfig = field(
        default_factory=ClusterSpecConfig,
        metadata={"help": "Cluster specification. Mainly used by slurm."},
    )
    n_nodes: int = field(
        default=1, metadata={"help": "Number of nodes for experiment."}
    )
    n_gpus_per_node: int = field(
        default=8, metadata={"help": "Number of GPUs per node for this experiment."}
    )
    allocation_mode: str = field(
        default="",
        metadata={
            "help": "GPU parallel strategy allocation mode. "
            "Options: manual/heuristic or pattern-based."
        },
    )
    seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
    total_train_epochs: int = field(
        default=1, metadata={"help": "Total number of epochs to train the model."}
    )
    total_train_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after this number of steps. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    total_train_n_seqs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after consuming this number of samples. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    tokenizer_path: str = field(default="")
    train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
    valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
    saver: SaverConfig = field(default_factory=SaverConfig)
    checkpointer: SaverConfig = field(default_factory=SaverConfig)
    evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
    stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
    server_only: bool = False
    sglang: SGLangConfig = field(default_factory=SGLangConfig)
    launcher: LauncherConfig = field(default_factory=LauncherConfig)


@dataclass
class SFTConfig(BaseExperimentConfig):
    model: TrainEngineConfig = field(default_factory=TrainEngineConfig)


@dataclass
class GRPOConfig(BaseExperimentConfig):
    async_training: bool = field(default=True)
    gconfig: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters
    )
    rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
    actor: PPOActorConfig = field(default_factory=PPOActorConfig)
    ref: PPOActorConfig = field(default_factory=PPOActorConfig)


def parse_cli_args(argv: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", help="The path of the main configuration file", required=True
    )
    args, overrides = parser.parse_known_args(argv)
    # Initialize the hydra config.
    config_file = Path(args.config).absolute()
    assert config_file.exists()
    # hydra only recognizes relative paths
    relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
    hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
    cfg = hydra_compose(
        config_name=str(relpath.name).split(".yaml")[0],
        overrides=overrides,
    )
    return cfg, config_file


def to_structured_cfg(cfg, config_cls):
    # Merge with the default configuration.
    # The YAML file and command line can omit default values defined in the Python dataclasses.
    default_cfg = OmegaConf.structured(config_cls)
    cfg = OmegaConf.merge(default_cfg, cfg)
    return cfg


def load_expr_config(argv: List[str], config_cls):
    cfg, config_file = parse_cli_args(argv)
    cfg = to_structured_cfg(cfg, config_cls=config_cls)
    cfg = OmegaConf.to_object(cfg)
    assert isinstance(cfg, BaseExperimentConfig)
    # Set up the environment.
    from realhf.base import constants, name_resolve, names

    constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
    name_resolve.reconfigure(cfg.cluster.name_resolve)
    name_resolve.clear_subtree(
        names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
    )
    return cfg, str(config_file)
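
# Illustrative entry-point sketch (the __main__ guard and printed fields are
# placeholders, not part of this module): a training script would typically
# load its structured config like this.
#
#     if __name__ == "__main__":
#         import sys
#
#         config, config_file = load_expr_config(sys.argv[1:], GRPOConfig)
#         # `config` is now a fully structured GRPOConfig; YAML values and
#         # command-line overrides have been merged with the dataclass defaults.
#         print(config.experiment_name, config.trial_name)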