mirror of https://github.com/inclusionAI/AReaL
1517 lines
48 KiB
Python
import getpass
import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Dict, List, Optional, Tuple, Type, Union

from omegaconf import MISSING

from realhf.base import pkg_version

## Data and datasets. ##


@dataclass
class MicroBatchSpec:
    """Specification for splitting micro-batches during training."""

    n_mbs: int = field(
        default=1,
        metadata={
            "help": "Number of micro-batches. Used as the exact count when "
            "max_tokens_per_mb is None, or as the minimum count otherwise.",
        },
    )
    max_tokens_per_mb: int = field(
        default=int(1e12),
        metadata={
            "help": "Maximum number of tokens per micro-batch. When set, n_mbs "
            "becomes the minimum number of micro-batches.",
        },
    )

    @classmethod
    def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
        """Create a new spec with updated fields while maintaining OmegaConf compatibility."""
        new_fields = dict(
            n_mbs=mb_spec.n_mbs,
            max_tokens_per_mb=mb_spec.max_tokens_per_mb,
        )
        new_fields.update(kwargs)
        return cls(**new_fields)
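
    # Usage sketch (hypothetical values): `new` copies an existing spec and
    # overrides selected fields, which keeps OmegaConf-structured configs
    # intact rather than mutating them in place:
    #
    #     spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=4096)
    #     spec = MicroBatchSpec.new(spec, n_mbs=4)
    #     assert spec.n_mbs == 4 and spec.max_tokens_per_mb == 4096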


@dataclass
class PromptAnswerDatasetConfig:
    """Configuration for Supervised Fine-Tuning (SFT) datasets.

    Dataset format requirements:
    - JSON/JSONL files
    - Each entry: {"prompt": str, "answer": str}
    """

    train_path: str = field(default="", metadata={"help": "Path to training dataset"})
    valid_path: str = field(default="", metadata={"help": "Path to validation dataset"})
    max_seqlen: int = field(
        default=1024, metadata={"help": "Maximum sequence length (prompt + answer)"}
    )
    train_bs_n_seqs: int = field(
        default=256, metadata={"help": "Training batch size in number of sequences"}
    )
    valid_bs_n_seqs: int = field(
        default=256, metadata={"help": "Validation batch size in number of sequences"}
    )
    fill_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Pad sequences to the maximum length. For testing only; "
            "left-fills with non-pad tokens",
        },
    )
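
    # Example dataset entry (one JSON object per line in a JSONL file):
    #
    #     {"prompt": "What is 1 + 1?", "answer": "2"}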


@dataclass
class PromptOnlyDatasetConfig:
    """Configuration for PPO RLHF datasets.

    Dataset format requirements:
    - JSON/JSONL files
    - Each entry: {"prompt": str}
    """

    path: str = field(default="", metadata={"help": "Path to dataset"})
    max_prompt_len: int = field(
        default=256, metadata={"help": "Maximum prompt length (truncated if longer)"}
    )
    train_bs_n_seqs: int = field(
        default=256, metadata={"help": "Batch size in number of prompts"}
    )
    fill_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Pad sequences to the maximum length. For testing only; "
            "left-fills with non-pad tokens",
        },
    )


## Model, optimizer, and backends. ##


@dataclass(unsafe_hash=True)
class ModelFamily:
    """Identifier for HuggingFace model types (e.g., llama, gpt2).

    Used for model registration and allocation.
    """

    _class: str = field(
        metadata={
            "help": "Model class name (e.g., 'llama'). Must be registered in "
            "`register_hf_family`. See `realhf/api/from_hf` for supported models.",
        }
    )
    is_critic: bool = field(
        default=False,
        metadata={
            "help": "Whether this is a critic/reward model. False indicates a standard LLM",
        },
    )

    def __repr__(self):
        """Returns formatted string representation: '{class}[-critic]'."""
        s = f"{self._class}"
        if self.is_critic:
            s += "-critic"
        return s
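
    # For example (illustrative):
    #
    #     repr(ModelFamily("llama"))                  # -> "llama"
    #     repr(ModelFamily("llama", is_critic=True))  # -> "llama-critic"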


@dataclass(unsafe_hash=True)
class ParallelismConfig:
    """Configuration for 3D parallelism (tensor, pipeline, and data parallelism).

    Note:
        Sequence parallelism is only used in combination with tensor-model parallelism.
    """

    tensor_parallel_size: int = field(
        default=1, metadata={"help": "Size of tensor-model parallelism"}
    )
    pipeline_parallel_size: int = field(
        default=1, metadata={"help": "Number of pipeline parallel stages"}
    )
    data_parallel_size: int = field(
        default=1, metadata={"help": "Data parallelism size for ZeRO optimization"}
    )
    use_sequence_parallel: bool = field(
        default=False,
        metadata={
            "help": "Enable sequence parallelism. Only used with tensor-model parallelism in Megatron",
        },
    )

    def __str__(self):
        """Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'."""
        return (
            f"Parallel(mp={self.tensor_parallel_size},"
            f"pp={self.pipeline_parallel_size},"
            f"dp={self.data_parallel_size})"
        )

    @staticmethod
    def parallelism_eq(this, other):
        """Compare parallelism configurations (excluding sequence parallelism).

        Note:
            Implemented as a static method to avoid OmegaConf compatibility issues.
        """
        return (
            (this.tensor_parallel_size == other.tensor_parallel_size)
            and (this.pipeline_parallel_size == other.pipeline_parallel_size)
            and (this.data_parallel_size == other.data_parallel_size)
        )
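
    # A strategy's GPU footprint is the product of the three degrees. A minimal
    # sketch (hypothetical sizes):
    #
    #     p = ParallelismConfig(
    #         tensor_parallel_size=2, pipeline_parallel_size=2, data_parallel_size=4
    #     )
    #     n_gpus = (
    #         p.tensor_parallel_size * p.pipeline_parallel_size * p.data_parallel_size
    #     )
    #     assert n_gpus == 16 and str(p) == "Parallel(mp=2,pp=2,dp=4)"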


@dataclass
class OptimizerConfig:
    """Configuration for model optimization during training.

    Note:
        Set type to "empty" for models that won't be trained.
    """

    type: str = field(
        default="adam",
        metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
    )
    lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
    weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
    beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
    beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
    eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
    min_lr_ratio: float = field(
        default=0.0,
        metadata={
            "help": "Minimum learning rate ratio after annealing",
        },
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={
            "help": "Learning rate scheduler type",
            "choices": ["linear", "cosine", "constant"],
        },
    )
    warmup_steps_proportion: float = field(
        default=0.001,
        metadata={
            "help": "Proportion of training steps used for warmup",
        },
    )
    offload: bool = field(
        default=False, metadata={"help": "Enable optimizer state offloading"}
    )
    initial_loss_scale: float = field(
        default=2**32, metadata={"help": "Initial loss scaling factor"}
    )
    min_loss_scale: float = field(
        default=1.0, metadata={"help": "Minimum loss scaling factor"}
    )
    loss_scale_window: float = field(
        default=5, metadata={"help": "Window size for loss scaling adjustment"}
    )
    hysteresis: int = field(
        default=2, metadata={"help": "Hysteresis for dynamic loss scaling"}
    )
    gradient_clipping: float = field(
        default=1.0, metadata={"help": "Gradient clipping threshold"}
    )
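
    # Schedule sketch (assumed semantics, following common warmup-then-anneal
    # schedulers): with total_steps training steps,
    #
    #     warmup_steps = int(warmup_steps_proportion * total_steps)
    #     final_lr = lr * min_lr_ratio  # floor reached after annealing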


@dataclass
class vLLMConfig:
    """Configuration for the vLLM inference engine. Refer to
    https://github.com/vllm-project/vllm for detailed documentation.
    """

    max_num_seqs: int = 256
    dtype: str = "float16"
    kv_cache_type: str = "auto"
    num_scheduler_steps: int = 1
    multi_step_stream_outputs: bool = True
    block_size: int = 16
    swap_space: int = 4
    cpu_offload_gb: float = 0
    max_seq_len_to_capture: int = 32768

    disable_sliding_window: bool = True

    # NOTE: max_model_len defaults to 32k because a larger value
    # would enable chunked prefill in vLLM, which degrades
    # evaluation performance.
    max_model_len: Optional[int] = 32768
    enable_chunked_prefill: bool = False

    # NOTE: enable_prefix_caching is set to False because cached
    # blocks would be reused after model weights are updated.
    # reset_prefix_cache in vLLM v0.7.2 fixes this issue.
    enable_prefix_caching: bool = False

    gpu_memory_utilization: float = 0.9

    enforce_eager: bool = False
    hybrid_train: bool = False
    additional_engine_args: Dict = field(default_factory=dict)


@dataclass
class SGLangConfig:
    """Configuration for the SGLang runtime. Refer to
    https://github.com/sgl-project/sglang for detailed documentation.
    """

    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    # NOTE: flashinfer is used to avoid illegal memory access errors.
    attention_backend: Optional[str] = "flashinfer"
    sampling_backend: Optional[str] = None
    context_length: Optional[int] = 32768
    mem_fraction_static: Optional[float] = 0.9
    max_running_requests: Optional[int] = None
    # NOTE: chunked_prefill_size defaults to 8192 on GPUs with 80GB memory
    # in SGLang, but we disable it to avoid precision issues.
    chunked_prefill_size: Optional[int] = -1
    max_prefill_tokens: int = 32768
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    hybrid_train: bool = False
    dtype: str = "float16"
    kv_cache_dtype: str = "auto"

    # Logging
    log_level: str = "warning"
    log_level_http: Optional[str] = "warning"
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = True  # Exports Prometheus-like metrics
    # The interval (in decoding iterations) to log throughput
    # and update Prometheus metrics.
    decode_log_interval: int = 1

    # Use a staticmethod to make OmegaConf happy.
    @staticmethod
    def build_cmd(
        sglang_config: "SGLangConfig",
        model_path,
        tp_size,
        server_index,
        base_gpu_id,
        dist_init_addr: Optional[str] = None,
    ):
        from realhf.base import constants, network, pkg_version, seeding
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)
        args.pop("hybrid_train")
        args["random_seed"] = seeding.get_seed()

        host_ip = network.gethostip()
        host = "localhost" if not sglang_config.enable_metrics else host_ip
        args = dict(
            host=host,
            model_path=model_path,
            # Model and tokenizer
            tokenizer_path=model_path,
            tokenizer_mode="auto",
            load_format="auto",
            trust_remote_code=True,
            device="cuda",
            served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
            is_embedding=False,
            skip_tokenizer_init=True,
            # Other runtime options
            tp_size=tp_size,
            # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
            base_gpu_id=base_gpu_id,
            file_storage_path=os.path.join(
                constants.SGLANG_CACHE_PATH,
                f"sglang_storage{server_index}",
            ),
            # Data parallelism
            dp_size=1,  # TODO: check whether we require SGLang dp
            load_balance_method="round_robin",
            # Expert parallelism
            ep_size=1,  # TODO: check
            nnodes=1,
            node_rank=0,
            dist_init_addr=dist_init_addr,
            **args,
        )

        if pkg_version.is_version_less("sglang", "0.4.4"):
            args.pop("log_requests_level")
        if pkg_version.is_version_less("sglang", "0.4.3"):
            args.pop("enable_nccl_nvls")
            args.pop("triton_attention_num_kv_splits")
            args.pop("cuda_graph_bs")
            args.pop("enable_memory_saver")
            args.pop("allow_auto_truncate")
            args.pop("file_storage_path")

        flags = []
        for k, v in args.items():
            if v is None or v is False or v == "":
                continue
            if v is True:
                flags.append(f"--{k.replace('_', '-')}")
                continue
            if isinstance(v, list):
                values = " ".join(map(str, v))
                flags.append(f"--{k.replace('_', '-')} {values}")
                continue
            flags.append(f"--{k.replace('_', '-')} {v}")
        flags = " ".join(flags)
        return f"python3 -m sglang.launch_server {flags}"
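
    # Flag construction rules, illustrated with a hypothetical argument dict:
    # True values become bare flags, None/False/"" entries are dropped, and
    # lists are space-joined, so
    #
    #     {"tp_size": 2, "enable_metrics": True, "cuda_graph_bs": [1, 2, 4]}
    #
    # renders as "--tp-size 2 --enable-metrics --cuda-graph-bs 1 2 4".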


@dataclass
class DistributedDataParallelConfig:
    """Configuration for Megatron's DistributedDataParallel.

    Refer to the Megatron-LM documentation for details.
    """

    grad_reduce_in_fp32: bool = True
    overlap_grad_reduce: bool = True
    overlap_param_gather: bool = False
    align_param_gather: bool = False
    use_distributed_optimizer: bool = True
    check_for_nan_in_grad: bool = False
    bucket_size: Optional[int] = None
    average_in_collective: bool = False
    fp8_param_gather: bool = False


@dataclass
class MegatronConfig:
    """Configuration for the Megatron-LM training framework.

    Refer to the Megatron-LM documentation for implementation details.
    """

    # Distributed Training Configuration
    ddp: DistributedDataParallelConfig = field(
        default_factory=DistributedDataParallelConfig
    )
    # Don't use MegatronOptimizerConfig here because OmegaConf
    # does not recognize the annotation "torch.dtype".
    overlap_param_gather_with_optimizer_step: bool = False

    # Precision Configuration
    use_precision_aware_optimizer: bool = False
    main_grads_dtype: str = "float32"
    main_params_dtype: str = "float32"
    exp_avg_dtype: str = "float32"
    exp_avg_sq_dtype: str = "float32"


@dataclass
class ModelTrainEvalConfig:
    """Runtime configuration for LLMs in the ReaL framework.

    Uses a custom model implementation supporting:
    - 3D and sequence parallelism
    - Flash attention for training/generation
    - Packed 1D tensor inputs for memory efficiency

    Note: Requires manual conversion from HuggingFace models.
    Implemented conversions are in `realhf/api/from_hf/`.
    """

    # Model Architecture Configuration
    type: ModelFamily = field(
        default=ModelFamily("llama", False),
        metadata={"help": "Model family specification"},
    )
    path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
    init_from_scratch: bool = field(
        default=False, metadata={"help": "Initialize model weights randomly"}
    )
    init_critic_from_actor: bool = field(
        default=False,
        metadata={"help": "Initialize critic/reward model from LM checkpoint"},
    )

    # Training Backend Configuration
    backend: str = field(
        default="megatron",
        metadata={"help": "Training backend", "choices": ["megatron"]},
    )
    gradient_checkpointing: bool = field(
        default=True, metadata={"help": "Enable memory-saving gradient checkpointing"}
    )
    bf16: bool = field(
        default=False, metadata={"help": "Use bf16 precision (otherwise fp16)"}
    )

    # Backend-Specific Configurations
    optimizer: Optional[OptimizerConfig] = field(
        default_factory=OptimizerConfig, metadata={"help": "Optimizer configuration"}
    )
    megatron: MegatronConfig = field(
        default_factory=MegatronConfig,
        metadata={
            "help": "Megatron-specific configuration. Can be ignored if this model is not trained."
        },
    )
    vllm: vLLMConfig = field(
        default_factory=vLLMConfig,
        metadata={
            "help": "vLLM inference configuration. Can be ignored if this model doesn't use vLLM."
        },
    )
    sglang: SGLangConfig = field(
        default_factory=SGLangConfig,
        metadata={
            "help": "SGLang runtime configuration. Can be ignored if this model doesn't use SGLang."
        },
    )


@dataclass
class MFCConfig:
    """Configuration for a single Model Function Call (MFC).

    Contains specifications for micro-batch splitting and parallel execution.

    The device_mesh format depends on scope:
    - Multi-node: SLURM nodelist (e.g., 'node[01-02]' or 'node01,node02')
    - Single-node: MPI-style hostfile format (e.g., 'node01:0,1,2,3' for the first
      4 GPUs); must use 1, 2, 4, or 8 contiguous GPUs on a single node
    """

    mb_spec: MicroBatchSpec = field(
        default_factory=MicroBatchSpec,
        metadata={
            "help": "Micro-batch splitting specification",
        },
    )
    parallel: ParallelismConfig = field(
        default_factory=ParallelismConfig,
        metadata={
            "help": "Parallelism strategy.",
        },
    )
    device_mesh: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device mesh specification for manual allocation",
        },
    )


## RL related. ##


@dataclass
class GenerationHyperparameters:
    """Controls text generation behavior for PPO training."""

    n: int = field(
        default=1, metadata={"help": "Number of sequences to generate per prompt."}
    )
    max_new_tokens: int = field(
        default=16384, metadata={"help": "Maximum number of tokens to generate."}
    )
    min_new_tokens: int = field(
        default=0, metadata={"help": "Minimum number of tokens to generate."}
    )
    greedy: bool = field(
        default=False,
        metadata={"help": "Whether to use greedy decoding (max probability)."},
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Nucleus sampling probability threshold in (0.0, 1.0]."},
    )
    top_k: int = field(
        default=int(1e8),
        metadata={"help": "Number of highest-probability tokens to consider."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "Sampling temperature. Higher values increase diversity."},
    )

    # Deprecated parameters
    use_cuda_graph: bool = field(
        default=True,
        metadata={"help": "[Deprecated] Whether to use CUDA graph optimization."},
    )
    force_cudagraph_recapture: bool = field(
        default=True,
        metadata={"help": "[Deprecated] Force CUDA graph recapture to release memory."},
    )
    force_no_logits_mask: bool = field(
        default=True,
        metadata={
            "help": "[Deprecated] Disable logits masking (reduces stability but saves memory)."
        },
    )

    def __post_init__(self):
        if self.temperature == 0.0:
            self.greedy = True
            self.temperature = 1.0
        if self.top_p <= 0.0 or self.top_p > 1:
            raise ValueError("top_p must be in (0.0, 1.0].")
        if self.top_k <= 0:
            raise ValueError("top_k must be a positive integer.")

        if self.use_cuda_graph and pkg_version.is_version_less("torch", "2.3.0"):
            raise ValueError(
                "To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
            )

    def new(self, **kwargs):
        args = asdict(self)
        args.update(kwargs)
        return GenerationHyperparameters(**args)
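
    # Usage sketch: derive a greedy variant for evaluation without mutating
    # the original config (values are illustrative):
    #
    #     gen = GenerationHyperparameters(max_new_tokens=1024)
    #     eval_gen = gen.new(greedy=True, n=1)
    #     assert not gen.greedy and eval_gen.greedy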


@dataclass
class PPOHyperparameters:
    """Configuration for Proximal Policy Optimization (PPO) training parameters."""

    # Generation Configuration
    gen: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters,
        metadata={"help": "Text generation hyperparameters"},
    )

    # Core PPO Parameters
    ppo_n_minibatches: int = field(
        default=4, metadata={"help": "Number of minibatches for each PPO update"}
    )
    eps_clip: float = field(
        default=0.2, metadata={"help": "Clipping factor for the policy ratio"}
    )
    c_clip: Optional[float] = field(
        default=None,
        metadata={
            "help": "Dual clipping factor for the policy ratio; must be > 1.0. None disables dual clipping."
        },
    )
    value_eps_clip: float = field(
        default=0.2, metadata={"help": "Clipping factor for value updates"}
    )
    early_stop_imp_ratio: float = field(
        default=5.0, metadata={"help": "Early-stop threshold for the importance ratio"}
    )
    actor_sample_reuse: int = field(
        default=1, metadata={"help": "The data reuse (aka PPO epochs) for the actor."}
    )
    critic_sample_reuse: int = field(
        default=1, metadata={"help": "The data reuse (aka PPO epochs) for the critic."}
    )

    # Reward Processing
    max_reward_clip: float = field(
        default=20.0, metadata={"help": "Maximum absolute value for clipped rewards"}
    )
    reward_output_scaling: float = field(
        default=1.0, metadata={"help": "Scaling factor for reward model outputs"}
    )
    reward_output_bias: float = field(
        default=0.0, metadata={"help": "Bias term for reward model outputs"}
    )
    fuse_rew_ref: bool = field(
        default=True,
        metadata={"help": "Whether to fuse reward and reference model computations"},
    )

    # Advantage Estimation
    discount: float = field(
        default=1.0, metadata={"help": "Discount factor for future rewards"}
    )
    gae_lambda: float = field(
        default=1.0, metadata={"help": "Lambda parameter for GAE"}
    )
    adv_norm: bool = field(
        default=True, metadata={"help": "Enable advantage normalization"}
    )

    # KL Control
    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
    use_adaptive_kl_ctl: bool = field(
        default=False, metadata={"help": "Use adaptive KL coefficient control"}
    )

    # Value Function Configuration
    disable_value: bool = field(
        default=False, metadata={"help": "Disable the value/critic model"}
    )
    value_norm: bool = field(
        default=True, metadata={"help": "Enable value normalization"}
    )
    value_norm_type: str = field(
        default="exp",
        metadata={"help": "Type of value normalization", "choices": ["exp", "ma"]},
    )
    value_norm_beta: float = field(
        default=0.99995,
        metadata={"help": "Decay factor for the exponential moving average"},
    )
    value_norm_eps: float = field(
        default=1e-5, metadata={"help": "Epsilon term for numerical stability"}
    )
    recompute_logprob: bool = field(
        default=False,
        metadata={"help": "Recompute logp and replace the logp returned by inference."},
    )
    use_decoupled_loss: bool = field(
        default=False,
        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
    )
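
    # For reference, eps_clip and c_clip parameterize the (dual-)clipped PPO
    # objective. With probability ratio r = pi_theta / pi_old and advantage A,
    # the standard clipped surrogate is
    #
    #     min(r * A, clip(r, 1 - eps_clip, 1 + eps_clip) * A)
    #
    # and dual clipping (when c_clip is not None) additionally bounds the
    # surrogate from below by c_clip * A for negative advantages.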


## Experiment utilities. ##


@dataclass
class ExperimentSaveEvalControl:
    """Controls the frequency of model saving and evaluation during training.

    Manages independent counters for epochs, steps, and seconds. The model will be saved
    or evaluated when any specified frequency condition is met.

    Note:
        - Epoch: number of full passes through the training dataset
        - Step: number of individual training iterations
        - Seconds: wall-clock time duration
    """

    total_train_epochs: int = field(
        default=1, metadata={"help": "Total number of epochs to train the model."}
    )
    # Save control
    save_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Save frequency in epochs. None disables epoch-based saving."
        },
    )
    save_freq_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Save frequency in steps. None disables step-based saving."},
    )
    save_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Save frequency in seconds. None disables time-based saving."
        },
    )
    # Checkpointing control
    ckpt_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in epochs. None uses save_freq_epochs. "
            "Checkpointing is used for recovery. The previous checkpoint is "
            "overwritten to save space."
        },
    )
    ckpt_freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in steps. None disables step-based checkpointing."
        },
    )
    ckpt_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in seconds. None disables time-based checkpointing."
        },
    )
    # Evaluation control
    eval_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in epochs. None disables epoch-based evaluation."
        },
    )
    eval_freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in steps. None disables step-based evaluation."
        },
    )
    eval_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in seconds. None disables time-based evaluation."
        },
    )
    # Benchmark control
    benchmark_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after this number of steps. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    benchmark_n_seqs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after consuming this number of samples. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
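
    # Example (illustrative): save every 2 epochs and also every half hour;
    # a save is triggered whenever either counter fires:
    #
    #     ctl = ExperimentSaveEvalControl(
    #         total_train_epochs=4, save_freq_epochs=2, save_freq_secs=1800
    #     )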


@dataclass
class AutomaticEvaluator:
    """Configuration for automatic model evaluation during training.

    Controls how and when evaluation jobs are launched to assess model performance
    on specified datasets.
    """

    data_names: str = field(
        default="aime24",
        metadata={
            "help": "Comma-separated dataset names for evaluation. "
            "Supported datasets: 'aime24', 'amc23', 'math_500'."
        },
    )
    max_gen_tokens: int = field(
        default=32768,
        metadata={"help": "Maximum number of tokens to generate during evaluation."},
    )
    max_concurrent_jobs: int = field(
        default=3,
        metadata={
            "help": "Maximum number of concurrent evaluation jobs. "
            "New jobs wait when this limit is reached."
        },
    )
    eval_job_image: Optional[str] = field(
        default=None,
        metadata={
            "help": "Container image for evaluation jobs. "
            "None uses the training GPU image."
        },
    )
    initial_checkpoint_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Initial checkpoint to evaluate. "
            "Results are stored as global_step=0 if specified."
        },
    )
    prompt_type: str = field(
        default="deepscaler",
        metadata={"help": "Prompt format to use during evaluation."},
    )


@dataclass
class WandBConfig:
    mode: str = "disabled"
    entity: Optional[str] = None
    project: Optional[str] = None
    name: Optional[str] = None
    job_type: Optional[str] = None
    group: Optional[str] = None
    notes: Optional[str] = None
    tags: Optional[List[str]] = None
    config: Optional[Dict] = None


@dataclass
class TensorBoardConfig:
    path: Optional[str] = None


def get_user_tmp():
    """Return (and create if needed) the per-user cache directory for ReaL."""
    user = getpass.getuser()
    user_tmp = os.path.join("/home", user, ".cache", "realhf")
    os.makedirs(user_tmp, exist_ok=True)
    return user_tmp


@dataclass
class ClusterSpecConfig:
    config_path: str = field(
        default="",
        metadata={
            "help": "JSON config path. If not given, use the following CLI args."
        },
    )
    cluster_name: str = field(
        default="local",
        metadata={
            "help": "Name of the cluster. Used to set cluster-specific environment variables."
        },
    )
    fileroot: str = field(
        default=get_user_tmp(),
        metadata={
            "help": "Root for logs and checkpoints. Should be available to all nodes."
        },
    )
    gpu_type: str = field(
        default="tesla", metadata={"help": "GPU type of the cluster. Used by SLURM."}
    )
    mount: str = field(
        default="/storage:/storage", metadata={"help": "Mount path for SLURM."}
    )
    gpu_image: str = field(default="", metadata={"help": "SLURM image for trainers."})
    cpu_image: str = field(default="", metadata={"help": "SLURM image for CPU jobs."})
    gpu_infer_image: str = field(
        default="", metadata={"help": "SLURM image for LLM inference."}
    )
    node_name_prefix: str = field(
        default="slurmd-", metadata={"help": "Node prefix for a SLURM cluster."}
    )
    n_nodes: int = field(
        default=32,
        metadata={
            "help": "The size of the cluster. Used to decide the SLURM hostname suffix."
        },
    )
    n_gpus_per_node: int = field(
        default=8,
        metadata={"help": "GPUs per node (physically)."},
    )


@dataclass
class BaseExperimentConfig:
    """Configuration for quickstart experiments.

    All parameters can be modified via command-line arguments. Supports various
    recovery modes and parallelization strategies.

    Note:
        - Recovery modes: auto, fault, resume, disabled
        - Allocation modes: manual, heuristic, or pattern-based
    """

    experiment_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
    )
    trial_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the trial (no '-' or '/'). Required."},
    )
    mode: str = field(
        default="slurm",
        metadata={
            "help": "Experiment launching mode.",
            "choices": ["slurm", "local", "ray"],
        },
    )
    debug: bool = field(
        default=True,
        metadata={
            "help": "Debug mode. False disables assertions for better performance."
        },
    )
    metric_discovery_port: int = field(
        default=0,
        metadata={"help": "Discovery port for Prometheus metrics service discovery."},
    )
    partition: str = field(
        default="dev", metadata={"help": "SLURM partition for running the experiment."}
    )
    schedule_strategy: str = field(
        default="empty_first", metadata={"help": "Resource scheduling strategy."}
    )
    wandb: WandBConfig = field(
        default_factory=WandBConfig,
        metadata={"help": "Weights & Biases configuration."},
    )
    tensorboard: TensorBoardConfig = field(
        default_factory=TensorBoardConfig,
        metadata={"help": "TensorBoard configuration. Only the 'path' field is required."},
    )
    image_name: Optional[str] = field(
        default=None,
        metadata={"help": "Docker image name for the controller (SLURM mode only)."},
    )
    recover_mode: str = field(
        default="disabled",
        metadata={
            "help": "Recovery mode (auto/fault/resume/disabled). "
            "Use 'disabled' if unfamiliar with the recovery mechanism."
        },
    )
    recover_retries: int = field(
        default=1,
        metadata={"help": "Number of recovery retries (auto/fault modes only)."},
    )
    recover_after: int = field(
        default=10,
        metadata={"help": "Recovery interval in seconds (auto/fault modes only)."},
    )
    ignore_worker_error: bool = field(
        default=False,
        metadata={
            "help": "Ignore worker runtime errors (disabled mode only). "
            "Only enable if certain errors can be safely ignored."
        },
    )
    allocation_mode: str = field(
        default="",
        metadata={
            "help": "GPU parallel strategy allocation mode. "
            "Options: manual/heuristic or pattern-based."
        },
    )
    n_nodes: int = field(
        default=1, metadata={"help": "Number of nodes for the experiment."}
    )
    n_gpus_per_node: int = field(
        default=8, metadata={"help": "Number of GPUs per node for this experiment."}
    )
    nodelist: Optional[str] = field(
        default=None,
        metadata={
            "help": "SLURM nodelist for manual allocation. "
            "Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'."
        },
    )
    exclude: Optional[str] = field(
        default=None,
        metadata={
            "help": "SLURM nodelist to exclude from allocation. "
            "Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'."
        },
    )
    seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
    cache_clear_freq: Optional[int] = field(
        default=10,
        metadata={
            "help": "Clear the data transfer cache every N steps. "
            "Set lower if OOM occurs. None disables clearing."
        },
    )
    exp_ctrl: ExperimentSaveEvalControl = field(
        default_factory=ExperimentSaveEvalControl,
        metadata={"help": "Experiment save/evaluation control configuration."},
    )
    torch_cache_mysophobia: bool = field(
        default=True,
        metadata={
            "help": "Clear the torch cache before each RPC (~0.1s overhead per RPC)."
        },
    )
    auto_eval: bool = field(
        default=False,
        metadata={
            "help": "Enable automatic evaluation during training. "
            "Results are logged to disk and WandB (if active)."
        },
    )
    auto_eval_config: AutomaticEvaluator = field(
        default_factory=AutomaticEvaluator,
        metadata={"help": "Automatic evaluation configuration."},
    )
    cpus_per_master_worker: int = field(
        default=4, metadata={"help": "CPU cores per master worker."}
    )
    mem_per_master_worker: int = field(
        default=20000, metadata={"help": "Memory per master worker (MB)."}
    )
    cpus_per_model_worker: int = field(
        default=4, metadata={"help": "CPU cores per model worker."}
    )
    mem_per_model_worker: int = field(
        default=90000, metadata={"help": "Memory per model worker (MB)."}
    )
    shuffle_dataset: bool = field(
        default=True, metadata={"help": "Shuffle the dataset in each epoch."}
    )
    ray_temp_path: str = field(
        default="/tmp/ray", metadata={"help": "Absolute path for Ray's log."}
    )
    cluster: ClusterSpecConfig = field(
        default_factory=ClusterSpecConfig,
        metadata={"help": "Cluster specification. Mainly used by SLURM."},
    )
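
    # Minimal construction sketch (names are placeholders; note the naming
    # constraints documented above):
    #
    #     cfg = BaseExperimentConfig(
    #         experiment_name="my-exp", trial_name="trial0", mode="local"
    #     )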


## Configuration options of asynchronous experiments. ##


@dataclass
class AsyncRLOptions:
    schedule_policy: str = field(
        default="round_robin",
        metadata={
            "help": "The request schedule policy during generation. Available options: [round_robin]."
        },
    )
    new_tokens_per_chunk: int = field(
        default=int(1e10),
        metadata={
            "help": "The length of chunked generation. Only valid if inference can't be interrupted."
        },
    )
    max_head_offpolicyness: int = field(
        default=0,
        metadata={"help": "Maximum off-policyness tolerance for the first token."},
    )

    n_rollout_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of rollout workers. None defaults to the train world size."
        },
    )
    max_concurrent_rollouts: Optional[int] = field(
        default=None,
        metadata={
            "help": "Max concurrent rollouts globally. Defaults to the train batch size."
        },
    )
    flush_request_timeout: int = field(
        default=300,
        metadata={"help": "The timeout for flushing requests upon weight update."},
    )

    cpus_per_generation_server: int = field(
        default=4, metadata={"help": "Generation server CPUs."}
    )
    mem_per_generation_server: int = field(
        default=60 * 1024, metadata={"help": "Generation server CPU memory in MB."}
    )
    cpus_per_gserver_manager: int = field(
        default=4, metadata={"help": "Generation manager CPUs."}
    )
    mem_per_gserver_manager: int = field(
        default=10 * 1024, metadata={"help": "Generation manager CPU memory in MB."}
    )
    cpus_per_rollout_worker: int = field(
        default=4, metadata={"help": "Rollout worker CPUs."}
    )
    mem_per_rollout_worker: int = field(
        default=20 * 1024, metadata={"help": "Rollout worker CPU memory in MB."}
    )


## Configurations for practical experiments. ##


@dataclass
class NullPPOExperimentOptions:
    """Configuration for a null PPO experiment (testing purposes only)."""

    model: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Model configuration for testing."},
    )
    inf: MFCConfig = field(
        default_factory=MFCConfig,
        metadata={"help": "Inference model function call configuration."},
    )
    train: MFCConfig = field(
        default_factory=MFCConfig,
        metadata={"help": "Training model function call configuration."},
    )
    dataset: PromptOnlyDatasetConfig = field(
        default_factory=PromptOnlyDatasetConfig,
        metadata={"help": "Dataset configuration for testing."},
    )
    dataset_filter_threshold: float = field(
        default=0.2,
        metadata={"help": "Threshold value for dataset filtering in tests."},
    )
    dataset_max_filter_percentage: float = field(
        default=0.1,
        metadata={"help": "Maximum percentage of the dataset to filter in tests."},
    )


@dataclass
class SFTExperimentOptions:
    """Configuration for supervised fine-tuning (SFT) experiments."""

    model: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Model runtime configuration."},
    )
    allocation: MFCConfig = field(
        default_factory=MFCConfig,
        metadata={"help": "Device allocation and parallelism configuration."},
    )
    dataset: PromptAnswerDatasetConfig = field(
        default_factory=PromptAnswerDatasetConfig,
        metadata={"help": "Dataset configuration."},
    )


@dataclass
class PPOMATHExperimentOptions:
    """Configuration for PPO (Proximal Policy Optimization) experiments.

    Manages four distinct models and their interactions through model function calls.

    Note:
        Models:
        - Actor: Primary LLM for text generation
        - Critic: Value function estimator
        - Ref: Reference model for KL regularization
        - Rew: Reward model (or function) for reward signals
    """

    # Model configurations
    actor: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Primary LLM configuration."},
    )
    critic: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Critic model configuration."},
    )
    ref: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Reference model configuration."},
    )
    rew: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Reward model configuration."},
    )

    # Model function call configurations
    actor_train: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "TrainActor MFC configuration."}
    )
    critic_train: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "TrainCritic MFC configuration."}
    )
    actor_gen: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "Rollout MFC configuration."}
    )
    critic_inf: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "InfValues MFC configuration."}
    )
    rew_inf: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "InfReward MFC configuration."}
    )
    ref_inf: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "InfRef MFC configuration."}
    )
    actor_inf: MFCConfig = field(
        default_factory=MFCConfig,
        metadata={"help": "Actor inference MFC configuration."},
    )

    # Dataset and algorithm configurations
    dataset: PromptOnlyDatasetConfig = field(
        default_factory=PromptOnlyDatasetConfig,
        metadata={"help": "Dataset configuration."},
    )
    ppo: PPOHyperparameters = field(
        default_factory=PPOHyperparameters,
        metadata={"help": "PPO algorithm hyperparameters."},
    )

    # Sampling and reward processing
    group_size: int = field(
        default=1,
        metadata={"help": "Number of answers retained per prompt (best-of-n)."},
    )
    generation_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of answers sampled per prompt. None uses group_size."
        },
    )
    mask_no_eos_with_zero: bool = field(
        default=False,
        metadata={"help": "Mask the reward for truncated answers (no EOS token)."},
    )
    mask_too_long: bool = field(
        default=False,
        metadata={"help": "Mask the PPO loss for length-truncated answers."},
    )
    check_verifier_status: bool = field(
        default=False,
        metadata={"help": "Raise an error if the reward is all-zero (verifier bug check)."},
    )
    group_adv_norm: bool = field(
        default=False,
        metadata={"help": "Use grouped advantage normalization as in GRPO."},
    )
    ref_ema_eta: Optional[float] = field(
        default=None,
        metadata={
            "help": "EMA decay rate for reference model updates. 1.0 means a full update."
        },
    )
    rw_type: Optional[str] = field(
        default="sparse",
        metadata={
            "help": "Type of reward processing. Only `sparse` is valid for now.",
            "choices": ["sparse"],
        },
    )
    check_xml_format: bool = field(
        default=False,
        metadata={"help": "Validate the XML format of generated responses."},
    )

    # Dataset filtering
    dataset_filter_threshold: float = field(
        default=100.0,
        metadata={
            "help": "Prompts with rewards higher than this value are filtered out after each epoch's training."
        },
    )
    dataset_max_filter_percentage: float = field(
        default=0.0,
        metadata={
            "help": "Maximum percentage of the dataset to remove in each filtering pass."
        },
    )

    success_rate_ub: float = field(
        default=1.0,
        metadata={
            "help": "Prompts with a success rate higher than this value are filtered out after generation. Valid for async training."
        },
    )
    success_rate_lb: float = field(
        default=0.0,
        metadata={
            "help": "Prompts with a success rate lower than this value are filtered out after generation. Valid for async training."
        },
    )

    # Testing only
    no_training: bool = field(
        default=False,
        metadata={"help": "Run without training. Test-only."},
    )


@dataclass
class MathCodeEvalOptions:
    gen_config: GenerationHyperparameters = field(
        default_factory=GenerationHyperparameters
    )

    actor: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Primary LLM configuration."},
    )
    rew: ModelTrainEvalConfig = field(
        default_factory=ModelTrainEvalConfig,
        metadata={"help": "Reward model configuration."},
    )

    actor_gen: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "Rollout MFC configuration."}
    )
    rew_inf: MFCConfig = field(
        default_factory=MFCConfig, metadata={"help": "InfReward MFC configuration."}
    )

    dataset: PromptOnlyDatasetConfig = field(
        default_factory=PromptOnlyDatasetConfig,
        metadata={"help": "Dataset configuration."},
    )

    group_size: int = field(
        default=1,
        metadata={"help": "Number of answers retained per prompt (best-of-n)."},
    )
    rw_type: Optional[str] = field(
        default="sparse",
        metadata={
            "help": "Type of reward processing. Only `sparse` is valid for now.",
            "choices": ["sparse"],
        },
    )
    check_xml_format: bool = field(
        default=False,
        metadata={"help": "Validate the XML format of generated responses."},
    )

    check_verifier_status: bool = field(
        default=False,
        metadata={"help": "Raise an error if the reward is all-zero (verifier bug check)."},
    )


## Helpers to visualize the help messages. ##
from rich.console import Console
from rich.highlighter import RegexHighlighter
from rich.panel import Panel
from rich.rule import Rule
from rich.theme import Theme

# Custom theme for colors
help_theme = Theme(
    {
        "title": "bold cyan",
        "header": "bold green",
        "field": "bold yellow",
        "type": "italic blue",
        "help": "dim white",
        "default": "dim green",
        "value": "white",  # style referenced by the value printer below
        "example": "italic cyan",
        "border": "dim blue",
    }
)

console = Console(theme=help_theme)


class CliHighlighter(RegexHighlighter):
    base_style = "example."
    highlights = [r"(python -m .+?)(?=\s|$)"]


highlighter = CliHighlighter()


def print_config_help(
    config, prefix: str = "", parent_name: str = "", indent: int = 0
) -> None:
    """Print help for a structured config with proper indentation and smart default display."""
    if not is_dataclass(config):
        return

    for field in fields(config):
        value = getattr(config, field.name)
        full_name = f"{parent_name}.{field.name}" if parent_name else field.name
        indent_space = " " * indent

        # Field type handling
        type_name = (
            field.type.__name__ if isinstance(field.type, Type) else str(field.type)
        )

        # Create help text components
        help_parts = []
        if "help" in field.metadata:
            help_parts.append(field.metadata["help"])

        # Only show the default for leaf nodes (non-dataclass fields)
        if not is_dataclass(value):
            default_value = field.default if hasattr(field, "default") else MISSING
            help_parts.append(f"[default]Default: {default_value}[/default]")

        # Print the field info
        console.print(f"{indent_space}[field]{full_name}[/field]", end=" ")
        console.print(f"[type]({type_name})[/type]", end=" ")
        if help_parts:
            console.print("- " + " ".join(help_parts))
        else:
            console.print()

        # Handle nested dataclasses with increased indentation
        if is_dataclass(value):
            print_config_help(value, prefix, full_name, indent + 1)
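
# Usage sketch:
#
#     print_config_help(MicroBatchSpec())
#
# walks the dataclass tree and prints one "name (type) - help Default: ..."
# line per field, indenting one level per nested dataclass.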


def print_config_values(
    config,
    prefix: str = "",
    parent_name: str = "",
    indent: int = 0,
    show_types: bool = True,
) -> None:
    """Print current values with clean indentation and subtle separation."""
    console.print()  # Add space before

    top_rule = Rule("Current Configuration Begin", style="bold cyan", align="center")
    bottom_rule = Rule("Current Configuration End", style="bold cyan", align="center")

    # Title with a subtle underline
    console.print(top_rule)

    # Print the config directly to the main console
    _print_config_values_internal(
        console, config, prefix, parent_name, indent, show_types
    )

    # Closing rule
    console.print(bottom_rule)
    console.print()  # Add space after


def _print_config_values_internal(
    console: Console,
    config,
    prefix: str,
    parent_name: str,
    indent: int,
    show_types: bool,
) -> None:
    """Internal recursive function that does the actual printing."""
    if not is_dataclass(config):
        return

    for field in fields(config):
        value = getattr(config, field.name)
        full_name = f"{parent_name}.{field.name}" if parent_name else field.name
        indent_space = " " * indent

        # Field type handling
        type_name = (
            field.type.__name__ if isinstance(field.type, Type) else str(field.type)
        )

        # Create help text components
        help_parts = []

        # Print the field info
        console.print(f"{indent_space}[field]{full_name}[/field]", end=" ")
        if show_types:
            console.print(f"[type]({type_name})[/type]", end=" ")

        # Always show the current value
        value_str = str(value)
        if isinstance(value, (list, dict)):
            value_str = f"{type(value).__name__}(len={len(value)})"
        if not is_dataclass(value):
            help_parts.append(f"[value]{value_str}[/value]")

        if help_parts:
            console.print("- " + " ".join(help_parts))
        else:
            console.print()

        # Handle nested dataclasses
        if is_dataclass(value):
            _print_config_values_internal(
                console, value, prefix, full_name, indent + 1, show_types
            )


def print_runtime_helper(args):
    """Print comprehensive help with rich formatting."""

    exp_type = args.__class__.__name__
    # Main help panel
    console.print(
        Panel.fit(
            f"[header]Setting {exp_type} with the Following Values[/header]",
            border_style="border",
        ),
        justify="center",
    )

    # Configuration options section
    print_config_values(args)