PullRequest: 62 [Patch v0.2.0] Move all CLI arguments into a single file and add pretty helper messages.

Merge branch fw/cli-args of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/62

Signed-off-by: 郭唯 <kira.gw@antgroup.com>


* format and test
* run
* add runtime helper message
This commit is contained in:
博惟 2025-03-28 12:00:37 +08:00
parent 71429c9655
commit 0bd9969ec4
40 changed files with 1485 additions and 1286 deletions

View File

@ -281,7 +281,7 @@ python3 -m realhf.apps.quickstart ppo-math option1=arg1 option2=arg2 ...
The command-line arguments like `option1=arg1` are parsed by [hydra](https://hydra.cc/), and each configuration item is a `dataclasses.dataclass` in the Python code. You can use the following command to view all the command-line arguments that can be passed in the experiment:
```bash
python3 -m realhf.apps.quickstart ppo-math --show-args
python3 -m realhf.apps.quickstart ppo-math --help
```
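For example, nested fields can be overridden directly from the command line. An illustrative invocation (the option names below are taken from the experiment docstrings elsewhere in this patch; the paths and values are placeholders):
```bash
python3 -m realhf.apps.quickstart ppo-math \
    experiment_name=my-exp trial_name=my_trial \
    dataset.path=/my/dataset.jsonl \
    actor.optimizer.lr=2e-5 \
    exp_ctrl.save_freq_steps=10
```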
The descriptions of the important parameters are as follows:

View File

@ -288,7 +288,7 @@ python3 -m realhf.apps.quickstart ppo-math option1=arg1 option2=arg2 ...
The command-line arguments like `option1=arg1` are parsed by [hydra](https://hydra.cc/), and each configuration item is a `dataclasses.dataclass` in the Python code. You can use the following command to view all the command-line arguments that can be passed in the experiment:
```bash
python3 -m realhf.apps.quickstart ppo-math --show-args
python3 -m realhf.apps.quickstart ppo-math --help
```
The descriptions of the important parameters are as follows:

View File

@ -5,12 +5,12 @@
"""
Dataset Toolkit - Process and validate code/math datasets with flexible input support
"""
import json
import argparse
import json
import logging
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)

View File

@ -4,7 +4,7 @@ import json
import random
import sys
from pathlib import Path
from typing import List, Dict, Optional
from typing import Dict, List, Optional
def load_jsonl(file_path: Path) -> List[Dict]:

View File

@ -2,37 +2,23 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
# Initialize preset config before all submodules.
from .base import prologue # isort: skip
from .api.cli_args import *
# Re-import these classes for clear documentation,
# otherwise the name will have a long prefix like
# realhf.api.quickstart.model.ModelTrainEvalConfig.
from .api.core.config import ModelFamily, ModelName, ModelShardID
# otherwise the name will have a long prefix.
from .api.core.config import ModelName, ModelShardID
from .api.core.data_api import SequenceSample
from .api.core.dfg import MFCDef
from .api.core.model_api import (
FinetuneSpec,
GenerationHyperparameters,
Model,
ModelBackend,
ModelInterface,
PipelinableEngine,
ReaLModelConfig,
)
from .api.quickstart.dataset import (
PairedComparisonDatasetConfig,
PromptAnswerDatasetConfig,
PromptOnlyDatasetConfig,
)
from .api.quickstart.device_mesh import MFCConfig
from .api.quickstart.model import (
ModelTrainEvalConfig,
OptimizerConfig,
ParallelismConfig,
)
# Initialize preset config before all submodules.
from .base import prologue
from .experiments.common.common import CommonExperimentConfig, ExperimentSaveEvalControl
from .experiments.common.ppo_math_exp import PPOHyperparameters, PPOMATHConfig
from .experiments.common.sft_exp import SFTConfig
__version__ = "0.3.0"

realhf/api/cli_args.py (new file, 1221 lines added)

File diff suppressed because it is too large.

View File

@ -4,7 +4,7 @@
import dataclasses
import enum
from typing import *
from typing import Any, Dict, List, Optional
import realhf.base.cluster as cluster
import realhf.base.topology as topology
@ -70,33 +70,6 @@ class ModelName:
return str(self)
@dataclasses.dataclass(unsafe_hash=True)
class ModelFamily:
"""An identifier for the HF model type, such as llama, gpt2, etc.
:param _class: The class of the model, e.g., "llama". This is the registered
name in the ``register_hf_family`` function. Please refer to the files
in ``realhf/api/from_hf`` for a list of all supported models.
:type _class: str
:param size: The size of the model. This parameter is only used by the ``search``
allocation mode and will be ignored otherwise.
:type size: int
:param is_critic: Indicates whether the model is a critic or reward model,
as opposed to a standard LLM.
:type is_critic: bool
"""
_class: str
size: int = 0
is_critic: bool = False
def __repr__(self):
s = f"{self._class}-{self.size}"
if self.is_critic:
s += "-critic"
return s
@dataclasses.dataclass
class ModelShardID:
"""The ID of a model shard in a specific model worker.

View File

@ -37,6 +37,7 @@ from pydantic import Field
from pydantic import dataclasses as pdclasses
from pydantic import field_validator, model_validator
from realhf.api.cli_args import MicroBatchSpec
from realhf.api.core import config as config_api
from realhf.base import constants, datapack, logging
from realhf.base.cluster import spec as cluster_spec
@ -100,33 +101,6 @@ class SequenceSplitSpec:
return self
@dataclasses.dataclass
class MicroBatchSpec:
"""The specification for splitting micro-batches.
:param n_mbs: The number of micro-batches, if max_tokens_per_mb is
None. The *minimum* number of micro-batches, if
max_tokens_per_mb is an integer. Defaults to 1.
:type n_mbs: int
:param max_tokens_per_mb: The maximum number of tokens per micro-
batch.
:type max_tokens_per_mb: Optional[int]
"""
n_mbs: int = 1
max_tokens_per_mb: int = int(1e12)
@classmethod
def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
# NOTE: Use classmethod to make the Omegaconf duck object happy.
fields = dict(
n_mbs=mb_spec.n_mbs,
max_tokens_per_mb=mb_spec.max_tokens_per_mb,
)
fields.update(kwargs)
return cls(**fields)
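A brief usage sketch of the `new` classmethod above (not part of the patch; it assumes the `MicroBatchSpec` class shown here, which this PR relocates to `realhf.api.cli_args`):
```python
# Derive a stricter spec from an existing one without mutating it.
base = MicroBatchSpec()  # n_mbs=1, effectively no per-micro-batch token limit
tight = MicroBatchSpec.new(base, max_tokens_per_mb=4096)
assert tight.n_mbs == 1 and tight.max_tokens_per_mb == 4096
```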
@pdclasses.dataclass(config=dict(arbitrary_types_allowed=True))
class SequenceSample:
"""The data structure used to represent sequence data.

View File

@ -4,14 +4,14 @@
import collections
import dataclasses
from typing import *
from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import networkx as nx
import realhf.base.logging as logging
from realhf.api.cli_args import ModelFamily
from realhf.api.core.config import (
ModelFamily,
ModelInterfaceAbstraction,
ModelInterfaceType,
ModelName,

View File

@ -6,16 +6,16 @@ import abc
import asyncio
import dataclasses
import keyword
from typing import *
from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple, Union
import aiohttp
import numpy as np
import torch
import torch.utils.data
import transformers
from packaging.version import Version
import realhf.base.logging as logging
from realhf.api.cli_args import GenerationHyperparameters
from realhf.api.core.config import (
ModelAbstraction,
ModelBackendAbstraction,
@ -33,87 +33,6 @@ class ZeroTotalLossWeightException(Exception):
pass
@dataclasses.dataclass
class GenerationHyperparameters:
"""Generation hyperparameters.
We implement a customized generation function instead of using
HuggingFace's to support pipelined generation. As a result, advanced
generation techniques like diversity-promoting sampling or
repetition penalty are not supported during PPO training. However,
we do not find this to be a problem in practice. Increasing the
sampling temperature and enabling top-k/top-p sampling can produce
effective models.
:param n: The number of sequences to generate for this prompt.
:type n: int
:param max_new_tokens: The maximum number of new tokens to generate.
:type max_new_tokens: int
:param min_new_tokens: The minimum number of new tokens to generate.
:type min_new_tokens: int
:param greedy: Whether to use greedy decoding.
:type greedy: bool
:param top_k: The number of highest probability tokens to keep.
:type top_k: int
:param top_p: The cumulative probability of the highest probability
tokens to keep.
:type top_p: float
:param temperature: The temperature of the sampling process.
:type temperature: float
:param use_cuda_graph: Whether to use CUDA graph to reduce kernel
launch overhead during generation.
:type use_cuda_graph: bool
:param force_cudagraph_recapture: Whether to capture the CUDA graph
every time `generate` is called, even if the graph has been captured
before. This will introduce minor overhead but will release the
kvcache when not running generation.
:type force_cudagraph_recapture: bool
:param force_no_logits_mask: Whether to omit the logits mask. The logits
mask is produced when using top-k or top-p sampling, marking tokens
that are filtered out. This mask is used by the reference model and
the actor model during training to align inferred logits with those
during generation and produce accurate KLs. Using the logits mask with
top-k/top-p sampling greatly improves the stability of PPO training
by narrowing the action space. However, this benefit comes at the cost
of additional GPU memory usage. If this option is set to True, the
logits mask will be omitted to save GPU memory, which may lead to a
decrease in learning performance.
:type force_no_logits_mask: bool
"""
n: int = 1
max_new_tokens: int = 256
min_new_tokens: int = 256
greedy: bool = False
top_p: float = 1.0
top_k: int = int(1e8)
temperature: float = 1.0
use_cuda_graph: bool = True
force_cudagraph_recapture: bool = True
force_no_logits_mask: bool = True
def __post_init__(self):
if self.temperature == 0.0:
self.greedy = True
self.temperature = 1.0
if self.top_p <= 0.0 or self.top_p > 1:
raise ValueError("top_p must be in (0.0, 1.0].")
if self.top_k <= 0:
raise ValueError("top_k must be a positive integer.")
if self.use_cuda_graph and Version(
Version(torch.__version__).base_version
) < Version("2.3.0"):
raise ValueError(
f"To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
)
def new(self, **kwargs):
args = dataclasses.asdict(self)
args.update(kwargs)
return GenerationHyperparameters(**args)
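A small usage sketch of the class removed here, which now lives in `realhf.api.cli_args` (illustrative values only; `use_cuda_graph` is disabled so the sketch does not trip the PyTorch version check in `__post_init__`):
```python
# Start from sampling defaults and derive a greedy variant via the `new` helper.
gen = GenerationHyperparameters(
    max_new_tokens=512, top_p=0.95, temperature=0.7, use_cuda_graph=False
)
greedy_gen = gen.new(greedy=True)  # copy with a single field overridden
assert greedy_gen.max_new_tokens == 512 and greedy_gen.greedy
```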
@dataclasses.dataclass
class APIGenerateInput:
qid: Hashable

View File

@ -7,6 +7,12 @@ import os
from typing import Dict, List, Optional, Tuple, Union
import realhf.api.core.dfg as dfg
from realhf.api.cli_args import (
AutomaticEvaluator,
ExperimentSaveEvalControl,
TensorBoardConfig,
WandBConfig,
)
from realhf.api.core.config import (
DatasetAbstraction,
ModelAbstraction,
@ -134,68 +140,6 @@ class ModelWorker:
)
@dataclasses.dataclass
class ExperimentSaveEvalControl:
"""Utility object for controlling the frequency of saving and evaluation
during training.
``Epoch`` refers to the number of times the training loop iterates over the entire dataset.
``Step`` refers to the number of iterations running the algorithm dataflow.
This object manages independent counters for epochs, steps, and seconds. The model will
be saved or evaluated when any of the following conditions are met.
:param total_train_epochs: The total number of epochs to train the model.
:type total_train_epochs: int
:param save_freq_epochs: Frequency in epochs at which to save the model. If None,
the model will not be saved based on epoch changes during training.
:type save_freq_epochs: Optional[int]
:param save_freq_steps: Frequency in steps at which to save the model. If None,
the model will not be saved based on step changes during training.
:type save_freq_steps: Optional[int]
:param save_freq_secs: Frequency in seconds at which to save the model. If None,
the model will not be saved based on time changes during training.
:type save_freq_secs: Optional[int]
:param ckpt_freq_epochs: Frequency in epochs at which to save the model for recover.
The previous checkpoint will be overwritten to reduce disk usage. If None, use save_freq_epochs.
:type ckpt_freq_epochs: Optional[int]
:param ckpt_freq_steps: Frequency in steps at which to save the model for recover. If None,
the model will not be saved based on step changes during training.
:type ckpt_freq_steps: Optional[int]
:param ckpt_freq_secs: Frequency in seconds at which to save the model for recover. If None,
the model will not be saved based on time changes during training.
:type ckpt_freq_secs: Optional[int]
:param eval_freq_epochs: Frequency in epochs at which to evaluate the model. If None,
the model will not be evaluated based on epoch changes during training.
:type eval_freq_epochs: Optional[int]
:param eval_freq_steps: Frequency in steps at which to evaluate the model. If None,
the model will not be evaluated based on step changes during training.
:type eval_freq_steps: Optional[int]
:param eval_freq_secs: Frequency in seconds at which to evaluate the model. If None,
the model will not be evaluated based on time changes during training.
:type eval_freq_secs: Optional[int]
:param benchmark_steps: Terminate training after this number of steps. Used for system
benchmarking only. Set to None for normal training.
:type benchmark_steps: Optional[int]
"""
total_train_epochs: int = 1
# save control
save_freq_epochs: Optional[int] = None
save_freq_steps: Optional[int] = None
save_freq_secs: Optional[int] = None
# checkpointing control, only used for recover
ckpt_freq_epochs: Optional[int] = None
ckpt_freq_steps: Optional[int] = None
ckpt_freq_secs: Optional[int] = None
# eval control
eval_freq_epochs: Optional[int] = None
eval_freq_steps: Optional[int] = None
eval_freq_secs: Optional[int] = None
# benchmark
benchmark_steps: Optional[int] = None
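A hedged construction sketch for the control object above (not part of the patch; values are illustrative):
```python
# Save every 50 steps, evaluate once per epoch, and refresh the recovery
# checkpoint every 30 minutes; any satisfied condition triggers the action.
ctrl = ExperimentSaveEvalControl(
    total_train_epochs=4,
    save_freq_steps=50,
    eval_freq_epochs=1,
    ckpt_freq_secs=1800,
)
```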
@dataclasses.dataclass
class MasterWorker:
base_seed: int
@ -227,54 +171,6 @@ class ExperimentScheduling:
controller_image: str = _LLM_CPU_IMAGE
@dataclasses.dataclass
class AutomaticEvaluator:
"""Configuration for automatic evaluation.
:param data_names: Datasets for evaluation, separated by commas. Currently supports datasets stored under ./evaluation/data,
including "aime24", "amc23" and "math_500". For example, if "aime24" and "amc23" are required for evaluation,
this field should be set to "aime24,amc23".
:type data_names: str
:param max_gen_tokens: Maximum number of tokens to be generated in evaluation.
:type max_gen_tokens: int
:param max_concurrent_jobs: Maximum number of concurrent evaluation jobs to submit. If number of existing jobs is equal to
`max_concurrent_jobs` and a new checkpoint is saved, the evaluation job will wait until former jobs complete.
:type max_concurrent_jobs: int
:param eval_job_image: Container image used to launch evaluation jobs. If set to None, evaluation jobs will use
the GPU image used for training.
:type eval_job_image: Optional[str]
:param initial_checkpoint_path: Initial checkpoint path to evaluate. If specified, this initial checkpoint will be evaluated,
results will be stored as global_step = 0.
:type initial_checkpoint_path: Optional[str]
:param prompt_type: Prompt format used in evaluation.
:type prompt_type: str
"""
data_names: str = "aime24"
max_gen_tokens: int = 32768
max_concurrent_jobs: int = 3
eval_job_image: Optional[str] = None
initial_checkpoint_path: Optional[str] = None
prompt_type: str = "deepscaler"
@dataclasses.dataclass
class WandBConfig:
mode: str = "disabled"
entity: Optional[str] = None
project: Optional[str] = None
name: Optional[str] = None
job_type: Optional[str] = None
group: Optional[str] = None
notes: Optional[str] = None
tags: Optional[List[str]] = None
config: Optional[Dict] = None
@dataclasses.dataclass
class TensorBoardConfig:
path: Optional[str] = None
@dataclasses.dataclass
class ExperimentConfig:
exp_ctrl: ExperimentSaveEvalControl

View File

@ -1,98 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
@dataclasses.dataclass
class PromptAnswerDatasetConfig:
"""Configuration for datasets used in Supervised Fine-Tuning (SFT).
The raw data must be in a JSON or JSONL file format, where each entry is a dictionary
with the keys `prompt` and `answer`. Both `prompt` and `answer` must be strings.
:param train_path: Path to the training dataset.
:type train_path: str
:param valid_path: Path to the validation dataset.
:type valid_path: str
:param max_seqlen: Maximum sequence length (prompt + answer). Sequences longer than
this will be truncated.
:type max_seqlen: int
:param train_bs_n_seqs: Number of sequences in each batch during training.
:type train_bs_n_seqs: int
:param valid_bs_n_seqs: Number of sequences in each batch during validation.
:type valid_bs_n_seqs: int
:param fill_to_max_length: Whether to fill sequences to the maximum length. If True,
prompts will be left-filled with non-pad tokens. Only used for testing.
:type fill_to_max_length: bool
"""
train_path: str = ""
valid_path: str = ""
max_seqlen: int = 1024
train_bs_n_seqs: int = 256
valid_bs_n_seqs: int = 256
fill_to_max_length: bool = False
@dataclasses.dataclass
class PairedComparisonDatasetConfig:
"""Configuration for datasets used in paired-comparison reward modeling,
DPO, and SimPO.
The raw data must be in a JSON or JSONL file format, where each entry is a dictionary
with the keys `prompt`, `pos_answers`, and `neg_answers`. `prompt` is a string, while
`pos_answers` and `neg_answers` are lists of strings. The lists must have the same length.
The raw dataset may contain multiple answer pairs for each prompt. In each epoch, we will
randomly sample `max_pairs_per_prompt` answer pairs for each prompt, so the maximum batch
size (in terms of the number of sequences) per step is `train_bs_n_seqs` multiplied by
`max_pairs_per_prompt`.
:param train_path: Path to the training dataset.
:type train_path: str
:param valid_path: Path to the evaluation dataset.
:type valid_path: str
:param max_pairs_per_prompt: Maximum number of answer pairs per prompt.
:type max_pairs_per_prompt: int
:param max_seqlen: Maximum sequence length (prompt + answers). Sequences longer than
this will be truncated.
:type max_seqlen: int
:param train_bs_n_seqs: Number of sequences in each batch during training.
:type train_bs_n_seqs: int
:param valid_bs_n_seqs: Number of sequences in each batch during validation.
:type valid_bs_n_seqs: int
"""
train_path: str = ""
valid_path: str = ""
max_pairs_per_prompt: int = 2
max_seqlen: int = 1024
train_bs_n_seqs: int = 256
valid_bs_n_seqs: int = 256
@dataclasses.dataclass
class PromptOnlyDatasetConfig:
"""Configuration for datasets used in PPO RLHF.
The raw data must be in a JSON or JSONL file format, where each entry is a dictionary
with a single key called `prompt`, which is a string.
:param path: Path to the dataset.
:type path: str
:param max_prompt_len: Maximum length of the prompt. Prompts longer than this will
be truncated.
:type max_prompt_len: int
:param train_bs_n_seqs: Number of prompts in each batch.
:type train_bs_n_seqs: int
:param fill_to_max_length: Whether to fill prompts to the maximum length. If True,
prompts will be left-filled with non-pad tokens. Only used for testing.
:type fill_to_max_length: bool
"""
path: str = ""
max_prompt_len: int = 256
train_bs_n_seqs: int = 256
fill_to_max_length: bool = False

View File

@ -9,8 +9,8 @@ from typing import List, Optional, Tuple, Union
import numpy as np
from realhf.api.core.dfg import MFCDef, MicroBatchSpec
from realhf.api.quickstart.model import ParallelismConfig
from realhf.api.cli_args import ParallelismConfig
from realhf.api.core.dfg import MFCDef
from realhf.base.cluster import spec as cluster_spec
from realhf.base.slurm_utils import (
are_ones_contiguous,
@ -344,29 +344,3 @@ class RPCAllocation:
device_mesh=DeviceMesh.from_dict(d["device_mesh"]),
parallel=ParallelismConfig(**d["parallel"]),
)
@dataclasses.dataclass
class MFCConfig:
"""Configuration for a single MFC.
:param mb_spec: Specifies how to split micro-batches when
executing this MFC. Refer to MicroBatchSpec for details.
:type mb_spec: MicroBatchSpec
:param parallel: Configuration for the parallelism strategy. This is
used only for manual allocation.
:type parallel: ParallelismConfig
:param device_mesh: String representation of the device mesh. If it
consists of multiple nodes, it should be formatted as a SLURM
nodelist, e.g., node[01-02] or node01,node02. If it represents a
slice on a single node, it should occupy 1, 2, 4, or 8
contiguous GPUs on the node. In this case, the string
representation is similar to an MPI hostfile, e.g.,
"node01:0,1,2,3" for the first 4 GPUs on node01. This is used
only for manual allocation.
:type device_mesh: Optional[str]
"""
mb_spec: MicroBatchSpec = dataclasses.field(default_factory=MicroBatchSpec)
parallel: ParallelismConfig = dataclasses.field(default_factory=ParallelismConfig)
device_mesh: Optional[str] = None
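A manual-allocation sketch using the fields documented above (not part of the patch; the node name and GPU indices are placeholders):
```python
# Place one MFC on the first 4 GPUs of node01 with 2-way tensor-model
# parallelism and 2-way data parallelism, capping micro-batches at 8192 tokens.
alloc = MFCConfig(
    mb_spec=MicroBatchSpec(max_tokens_per_mb=8192),
    parallel=ParallelismConfig(model_parallel_size=2, data_parallel_size=2),
    device_mesh="node01:0,1,2,3",
)
```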

View File

@ -18,6 +18,7 @@ from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
import realhf.api.core.system_api as system_api
from realhf.api.cli_args import print_runtime_helper
from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT, QUICKSTART_EXPR_CACHE_PATH
from realhf.base.ray_utils import check_ray_availability
from realhf.base.slurm_utils import check_slurm_availability
@ -68,6 +69,8 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
logger = logging.getLogger("quickstart", "colored")
print_runtime_helper(OmegaConf.to_object(args))
exp_name = args.experiment_name
if args.trial_name == MISSING:
args.trial_name = trial_name = (

View File

@ -1,336 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
from logging import disable
from typing import *
import realhf.base.logging as logging
from realhf.api.core.config import (
ModelAbstraction,
ModelFamily,
ModelWrapperAbstraction,
)
logger = logging.getLogger("Quickstart Model Config")
@dataclasses.dataclass(unsafe_hash=True)
class ParallelismConfig:
"""Configuration for 3D parallelism.
:param model_parallel_size: Size of tensor-model parallelism.
:type model_parallel_size: int
:param pipeline_parallel_size: Number of pipeline parallelism
stages.
:type pipeline_parallel_size: int
:param data_parallel_size: Data parallelism size for ZeRO
optimization.
:type data_parallel_size: int
:param use_sequence_parallel: Whether to use sequence parallelism in
Megatron in combination with tensor-model parallelism.
:type use_sequence_parallel: bool
"""
model_parallel_size: int = 1
pipeline_parallel_size: int = 1
data_parallel_size: int = 1
use_sequence_parallel: bool = False
def __str__(self):
return (
f"Parallel(mp={self.model_parallel_size},"
f"pp={self.pipeline_parallel_size},"
f"dp={self.data_parallel_size})"
)
def parallelism_eq(this, other):
# NOTE: We write this function because
# 1) we don't want to compare sequence_parallelism (it's irrelevant to parameter reallocation)
# 2) implementing this function as a method of ParallelismConfig would cause an OmegaConf bug
return (
(this.model_parallel_size == other.model_parallel_size)
and (this.pipeline_parallel_size == other.pipeline_parallel_size)
and (this.data_parallel_size == other.data_parallel_size)
)
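A short check of the comparison semantics described in the NOTE above (sketch, not part of the patch): sequence parallelism is deliberately ignored.
```python
a = ParallelismConfig(model_parallel_size=2, data_parallel_size=4, use_sequence_parallel=True)
b = ParallelismConfig(model_parallel_size=2, data_parallel_size=4, use_sequence_parallel=False)
assert parallelism_eq(a, b)  # equal for parameter-reallocation purposes
```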
@dataclasses.dataclass
class OptimizerConfig:
"""Configuration for the optimizer.
For models that will not be trained, the optimizer type should be
set to "empty".
:param type: Type of optimizer. Currently, only "adam" and "empty"
optimizers are supported.
:type type: str
:param lr: Learning rate.
:type lr: float
:param weight_decay: Weight decay.
:type weight_decay: float
:param beta1: Adam beta1 parameter.
:type beta1: float
:param beta2: Adam beta2 parameter.
:type beta2: float
:param eps: Adam epsilon parameter in the denominator.
:type eps: float
:param min_lr_ratio: Minimum learning rate ratio after learning rate
annealing. Should be in the interval [0.0, 1.0].
:type min_lr_ratio: float
:param lr_scheduler_type: Type of learning rate scheduler. One of
"linear", "cosine", or "constant".
:type lr_scheduler_type: str
:param warmup_steps_proportion: Proportion of total training steps
allocated for warming up. Should be in the interval [0.0, 1.0].
:type warmup_steps_proportion: float
"""
type: str = dataclasses.field(
metadata={"choices": ["adam", "empty"]},
default="adam",
)
lr: float = 1e-5
weight_decay: float = 0.05
beta1: float = 0.9
beta2: float = 0.95
eps: float = 1e-5
min_lr_ratio: float = 0.0
lr_scheduler_type: str = dataclasses.field(
metadata={"choices": ["linear", "cosine", "constant"]},
default="cosine",
)
warmup_steps_proportion: float = 0.02
offload: bool = False
initial_loss_scale: float = 2**32
min_loss_scale: float = 1.0
loss_scale_window: float = 5
hysteresis: int = 2
gradient_clipping: float = 1.0
@dataclasses.dataclass
class vLLMConfig:
max_num_seqs: int = 256
kv_cache_type: str = "auto"
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True
block_size: int = 16
swap_space: int = 4
cpu_offload_gb: float = 0
max_seq_len_to_capture: int = 32768
disable_sliding_window: bool = True
# NOTE: max_model_len defaults to 32k because a larger value
# will enable chunked prefill in vLLM, which causes
# evaluation performance degradation.
max_model_len: Optional[int] = 32768
enable_chunked_prefill: bool = False
# NOTE: enable_prefix_caching is set to False because cached
# blocks would be reused after model weights are updated.
# Using reset_prefix_cache from v0.7.2 will fix this issue.
enable_prefix_caching: bool = False
gpu_memory_utilization: float = 0.9
enforce_eager: bool = False
hybrid_train: bool = False
additional_engine_args: Dict = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class SGLangConfig:
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8
num_continuous_decode_steps: int = 1
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
# NOTE: to avoid the illegal memory access error
attention_backend: Optional[str] = "triton"
sampling_backend: Optional[str] = None
context_length: Optional[int] = None
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
hybrid_train: bool = False
@dataclasses.dataclass
class DistributedDataParallelConfig:
"""Configuration for Megatron DistributedDataParallel.
Some default options have been overwritten.
"""
grad_reduce_in_fp32: bool = False
overlap_grad_reduce: bool = True
overlap_param_gather: bool = False
align_param_gather: bool = False
use_distributed_optimizer: bool = True
check_for_nan_in_grad: bool = False
bucket_size: Optional[int] = None
average_in_collective: bool = False
fp8_param_gather: bool = False
@dataclasses.dataclass
class MegatronConfig:
"""When using the DistributedOptimizer of Megatron, parameters and
gradients will not be split across DP ranks, but optimizer states will
be. In other words, Megatron only supports ZeRO-1.
Megatron DDP will split the whole flattened parameter into buckets.
Buckets do not respect parameter boundaries and are dispatched to different DP ranks.
The optimizer on a specific DP rank will only manage its own bucket,
but parameters and gradients are held by all ranks and will not be further split.
(That's why only optimizer states are partitioned.) During backward, bucket gradients
will be scatter-reduced (controlled by the `use_distributed_optimizer` option
in Megatron DDP, otherwise all-reduce will be issued), and parameters will then
be updated locally. At this point, the parameters are not synced across DP ranks.
The DistributedOptimizer will then call all-gather on parameters.
Since Megatron allocates static tensors for scatter-reducing parameter gradients,
it does not decrease memory usage the way DeepSpeed ZeRO-2 does. To be more specific,
with dynamic allocation, we can allocate gradient memory layer-by-layer. When the
backward finishes at layer N, we can scatter-reduce gradients and release the memory
after scattering. As a result, given DP size K, layer number L, and parameter size P
for each layer, dynamic allocation requires P * (1 + L/K) memory for gradients,
but Megatron requires P * L. Memory is not freed after scattering in Megatron.
'use_distributed_optimizer' enables bucketing and scatter-reduce gradients.
When setting to False, optimizer states will not be partitioned.
'overlap_grad_reduce' enables issuing all-reduce/scatter-reduce on the fly
during backward once the gradient is ready, which should usually be enabled.
'overlap_param_gather' overlaps param all-gather with the next forward pass.
It creates a forward hook that waits for the previous parameter all-gather
after the optimizer step. While this sounds good, it can be problematic with
parameter reallocation, because the reallocated parameters do not have the hook.
Can be enabled for SFT, but should be disabled for PPO.
As a final note, Megatron is in an awkward place for PPO with param-realloc.
First, it does not minimize the memory usage of gradients (i.e., ZeRO-2).
Second, for functional correctness, we can't enable `overlap_param_gather`,
and a parameter update will be scatter-reduce grad + all-gather param, instead
of an all-reduce (running all-reduce requires setting `use_distributed_optimizer`
to False, but that will not partition optimizer states!), so it is not that
efficient, either. We use Megatron because it is the only backend that we can
make functionally correct. The DeepSpeed code is too hard to read and modify.
"""
ddp: DistributedDataParallelConfig = dataclasses.field(
default_factory=DistributedDataParallelConfig
)
# Don't use MegatronOptimizerConfig here because OmegaConf
# does not recognize the annotation "torch.dtype"
overlap_param_gather_with_optimizer_step: bool = False
use_precision_aware_optimizer: bool = False
main_grads_dtype: str = "float32"
main_params_dtype: str = "float32"
exp_avg_dtype: str = "float32"
exp_avg_sq_dtype: str = "float32"
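To make the gradient-memory comparison in the docstring above concrete, a back-of-the-envelope calculation with illustrative numbers (not part of the patch):
```python
# Gradient memory, in units of one layer's parameter size P.
K, L = 8, 32                 # data-parallel size and number of layers (illustrative)
dynamic = 1 + L / K          # layer-by-layer allocation, freed after scatter-reduce -> 5.0 P
megatron = L                 # static buffers kept for every layer's gradients       -> 32 P
print(f"dynamic ~ {dynamic:.1f} P vs. Megatron = {megatron} P")
```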
@dataclasses.dataclass
class ModelTrainEvalConfig:
"""Runtime configuration for models (or LLMs) in ReaL.
We use a customized model class instead of HuggingFace's. This customized model has
the following highlights:
1. Support for 3D parallelism and sequence parallelism.
2. Support for flash attention during both training and generation.
3. Input sequences are packed into a single 1D tensor to save GPU memory and improve efficiency.
Consequently, each HuggingFace model of interest needs to be manually converted to this
customized model. Implemented models can be found in the ``realhf/api/from_hf/`` directory.
:param type: Model family type, e.g., llama, qwen2, etc.
:type type: ModelFamily
:param backend: Backend for training. Currently, only "megatron" and "deepspeed" are supported.
Use "deepspeed" for offloading parameters or optimizer states, and "megatron" for
parameter reallocation.
:type backend: str
:param path: Path of the HuggingFace checkpoint.
:type path: str
:param gradient_checkpointing: Whether to use gradient checkpointing to save memory.
:type gradient_checkpointing: bool
:param bf16: Whether to use bf16 precision. Otherwise use fp16.
:type bf16: bool
:param parallel: Configuration for parallelism.
:type parallel: ParallelismConfig
:param optimizer: Configuration for the optimizer.
:type optimizer: Optional[OptimizerConfig]
:param init_critic_from_actor: Whether to initialize a critic/reward model from a saved LM checkpoint.
:type init_critic_from_actor: bool
"""
type: ModelFamily = dataclasses.field(default=ModelFamily("llama", 7, False))
backend: str = dataclasses.field(
default="megatron", metadata={"choices": ["megatron", "deepspeed"]}
)
path: str = ""
gradient_checkpointing: bool = True
bf16: bool = False
optimizer: Optional[OptimizerConfig] = dataclasses.field(
default_factory=OptimizerConfig
)
megatron: MegatronConfig = dataclasses.field(default_factory=MegatronConfig)
vllm: vLLMConfig = dataclasses.field(default_factory=vLLMConfig)
sglang: SGLangConfig = dataclasses.field(default_factory=SGLangConfig)
init_from_scratch: bool = False
init_critic_from_actor: bool = False
def get_real_model_config(
model_path: str,
hf_model_family: str,
is_critic: bool,
init_from_scratch: bool,
init_critic_from_actor: bool,
dtype: Optional[str] = None,
) -> ModelAbstraction:
"""Make a configuration to build model."""
model = ModelAbstraction(
"real_model",
args=dict(
model_path=model_path,
is_critic=is_critic,
init_critic_from_actor=init_critic_from_actor,
dtype=dtype,
hf_model_family=hf_model_family,
init_from_scratch=init_from_scratch,
),
)
return model

View File

@ -5,9 +5,9 @@
import dataclasses
from typing import List, Optional
from realhf.api.cli_args import ParallelismConfig
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.device_mesh import DeviceMesh
from realhf.api.quickstart.model import ParallelismConfig
@dataclasses.dataclass

View File

@ -311,7 +311,7 @@ def main_find_config(args):
def main_profile_layers(args):
from realhf.api.core.model_api import ModelFamily
from realhf.api.cli_args import ModelFamily
_main_profile_layers(
ModelFamily(args.model_class, args.model_size, args.is_critic),
@ -320,7 +320,7 @@ def main_profile_layers(args):
def _main_profile_layers(model_family, model_path):
from realhf.api.core.model_api import ModelFamily
from realhf.api.cli_args import ModelFamily
from realhf.base.slurm_utils import check_slurm_availability
from realhf.base.testing import clear_name_resolve

View File

@ -12,8 +12,10 @@ import sys
import hydra
from omegaconf import DictConfig, OmegaConf
from rich.panel import Panel
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
from realhf.api.cli_args import console, highlighter, print_config_help
from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES, QUICKSTART_FN
from realhf.base.cluster import spec as cluster_spec
from realhf.base.importing import import_module
from realhf.base.prologue import (
@ -32,29 +34,79 @@ import_module(
import realhf.experiments.benchmark.profile_exp
def print_help(exp_type):
"""Print comprehensive help with rich formatting"""
config_class = QUICKSTART_CONFIG_CLASSES[exp_type]()
# Main help panel
console.print(
Panel.fit(
f"[header]Configuration Help for {exp_type}[/header]", border_style="border"
)
)
# Configuration options section
console.print("\n[title]CONFIGURATION OPTIONS[/title]")
print_config_help(config_class)
# Usage section
console.print("\n[title]USAGE[/title]")
usage_code = f"python -m realhf.apps.quickstart {exp_type} --config ./your/config.yaml [OPTIONS]"
console.print(highlighter(usage_code))
# Examples section
console.print("\n[title]EXAMPLE OVERRIDES[/title]")
example_code = f"python -m realhf.apps.quickstart {exp_type} --config ./your/config.yaml dataset.path=/my/dataset.jsonl actor.optimizer.lr=2e-5"
console.print(highlighter(example_code))
# Footer
console.print("\n[dim]Use [bold]--help[/bold] to show this message again[/dim]")
def main():
parser = argparse.ArgumentParser(prog="ReaL Quickstart")
# Create parser with add_help=False to disable automatic --help
parser = argparse.ArgumentParser(prog="ReaL Quickstart", add_help=False)
# Add custom help argument that won't conflict
parser.add_argument(
"--help", action="store_true", help="Show this help message and exit"
)
subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
subparsers.required = True
for k, v in QUICKSTART_FN.items():
subparser = subparsers.add_parser(k)
# Create subparser with add_help=False
subparser = subparsers.add_parser(k, add_help=False)
# Add custom help to subparser
subparser.add_argument(
"--show-args",
action="store_true",
help="Show all legal CLI arguments for this experiment.",
"--help", action="store_true", help="Show help for this command"
)
subparser.add_argument(
PROLOGUE_FLAG_NAME,
type=str,
help="Set config (*.yaml) for this experiment.",
)
subparser.set_defaults(func=v)
# Parse known args first to check for help
args = vars(parser.parse_known_args()[0])
if args["show_args"]:
sys.argv = [sys.argv[0], "--help"]
QUICKSTART_FN[args["cmd"]]()
# Handle help at both main and subcommand levels
if args["help"]:
if args["cmd"]:
# Subcommand help
print_help(args["cmd"])
else:
# Main help
parser.print_help()
return
# Continue with normal execution
if not args["cmd"]:
parser.print_help()
experiment_name = ""
trial_name = ""

View File

@ -11,6 +11,12 @@ from typing import *
from omegaconf import OmegaConf
from realhf.api.cli_args import (
MFCConfig,
ModelTrainEvalConfig,
ParallelismConfig,
PromptOnlyDatasetConfig,
)
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
@ -18,10 +24,7 @@ from realhf.api.core.config import (
)
from realhf.api.core.dfg import MFCDef
from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.base import constants, logging
from realhf.base.topology import decompose_to_three_factors
from realhf.experiments.common.common import CommonExperimentConfig

View File

@ -2,13 +2,10 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
from importlib.metadata import version
from typing import List
from packaging.version import Version
from realhf.api.cli_args import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.api.quickstart.model import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
from realhf.base import logging
logger = logging.getLogger(__name__)

View File

@ -13,6 +13,12 @@ import transformers
from omegaconf import MISSING, OmegaConf
import realhf.base.logging as logging
from realhf.api.cli_args import (
BaseExperimentConfig,
MFCConfig,
ModelTrainEvalConfig,
ParallelismConfig,
)
from realhf.api.core.config import (
DatasetAbstraction,
ModelAbstraction,
@ -24,28 +30,18 @@ from realhf.api.core.config import (
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
from realhf.api.core.system_api import (
AutomaticEvaluator,
Experiment,
ExperimentConfig,
ExperimentSaveEvalControl,
ExperimentScheduling,
ModelWorker,
Scheduling,
TasksGroup,
TensorBoardConfig,
WandBConfig,
)
from realhf.api.quickstart.device_mesh import (
DeviceMesh,
MFCConfig,
RPCAllocation,
make_device_mesh_from_name,
)
from realhf.api.quickstart.model import (
ModelTrainEvalConfig,
ParallelismConfig,
get_real_model_config,
)
from realhf.base.cluster import spec as cluster_spec
from realhf.experiments.common.check import (
check_is_realhf_native_model_interface,
@ -58,6 +54,7 @@ from realhf.experiments.common.check import (
from realhf.experiments.common.utils import (
AllocationMode,
asdict,
get_real_model_config,
get_topo,
make_inf_backend_config,
make_train_backend_config,
@ -75,165 +72,7 @@ GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
@dataclasses.dataclass
class CommonExperimentConfig(Experiment):
"""Configuration for quickstart experiments.
All members can be modified via the command line. For example,
.. code-block:: shell
$ python3 -m realhf.apps.quickstart sft trial_name=my_trial seed=42 exp_ctrl.save_freq_steps=10 ...
This command changes the ``trial_name``, ``seed``, and the ``save_freq_steps`` attribute
of the ``exp_ctrl`` attribute in this class.
``recover_mode`` can be one of the following\:
- ``auto``\: Automatically recover the last failed run. If the checkpoint does not exist, run from scratch with fault tolerance.
- ``fault``\: Run from scratch with fault tolerance.
- ``resume``\: Resume from saved recovery states and then run it once without fault tolerance.
- ``disabled``\: Do nothing but raise an error if one occurs.
If you are not familiar with ReaL's recovery mechanism, set this to ``disabled``.
Normal checkpointing is usually sufficient in most cases.
``allocation_mode`` can be one of the following\:
- ``manual``\: Manually allocate resources using the specified command-line options.
- ``search``\: Allocate resources and configure parallel strategies using the search engine.
- ``heuristic``\: Allocate resources and configure parallel strategies using heuristic strategies obtained from a search.
- A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.
- A regex pattern like ``{vllm|sglang}.{IdentPara}+{IdentPara}``\: Decoupled generation and training allocations with corresponding identical parallelization strategies.
- Key-value pairs with MFC names and their parallel strategies in the whole cluster, e.g., ``actor_gen:d4m2p1,*:d2p2m2`` specifies a ``d4m2p1`` strategy for actor generation and ``d2p2m2`` for other MFCs in a world of 8 GPUs.
:param experiment_name: The name of the experiment.
An arbitrary string without "_" and "/", e.g., ``ultra-chat-llama``.
This parameter is required.
:type experiment_name: str
:param trial_name: The name of the trial.
An arbitrary string without "-" and "/", e.g., ``lr1e-3wd0.05``.
This parameter is required.
:type trial_name: str
:param mode: The experiment launching mode. Supported values are "local", "ray", or "slurm".
"ray" mode requires launching the Ray cluster via CLI.
"slurm" mode requires the Pyxis plugin with the Enroot container enabled.
"local" mode implies ``n_nodes=1``.
:type mode: str
:param debug: Whether to run in debug mode.
Setting this to `False` will disable all assertions, which will be faster but less safe.
:type debug: bool
:param partition: The SLURM partition for running the experiment.
:type partition: str
:param wandb: The WandB initialization config.
See https://docs.wandb.ai/ref/python/init/ for details.
:type wandb: WandbConfig
:param tensorboard: The tensorboard initialization config.
Only the field of `path` is needed to specify the directory of saving the tensorboard events.
:type tensorboard: TensorBoardConfig
:param image_name: The name of the Docker image used by the controller.
This parameter is only used in SLURM mode.
:type image_name: str or None
:param recover_mode: The recovery mode. See above for details.
:type recover_mode: str
:param recover_retries: The number of retries for recovery.
Effective only when ``recover_mode`` is set to "auto" or "fault".
:type recover_retries: int
:param recover_after: The time interval (seconds) for recovery.
Effective only when ``recover_mode`` is set to "auto" or "fault".
:type recover_after: int
:param ignore_worker_error: Whether to ignore errors raised by
workers during runtime. Only set this to `True` if you are certain that the error can be ignored.
Effective only when ``recover_mode`` is set to "disabled".
:type ignore_worker_error: bool
:param allocation_mode: The mode for GPU parallel strategy allocation. See above for details.
:type allocation_mode: str
:param allocation_use_cache: Whether to use cache in allocation search.
Effective only when ``allocation_mode`` is set to "search" and a cache is available in the log directory of the current experiment
name and trial.
:type allocation_use_cache: bool
:param n_nodes: The number of nodes to run the experiment.
:type n_nodes: int
:param n_gpus_per_node: The number of GPUs per node.
Thus, the total number of GPUs will be ``n_nodes * n_gpus_per_node``.
ReaL supports a world size of 1, 2, 4, 8, ... within a single node,
or multiple nodes with the same number of GPUs.
:type n_gpus_per_node: int
:param nodelist: Nodelist for the distributed setting in SLURM nodelist format.
Required for the ``manual`` allocation mode.
For multiple GPUs on a single node, it should be formatted as "NODE01:0,1,2,3",
indicating the use of the first 4 GPUs on ``NODE01``.
For multiple complete nodes, it should be formatted as "NODE[01-02,03,07],COM08",
indicating the use of all GPUs on these nodes: [NODE01, NODE02, NODE03, NODE07, COM08].
:type nodelist: str or None
:param seed: The random seed.
:type seed: int
:param cache_clear_freq: The cache of data transfer will be cleared after each ``cache_clear_freq`` steps.
If None, will not clear the cache. Set to a small number, e.g., 1, if OOM or CUDA OOM occurs.
:type cache_clear_freq: int or None
:param exp_ctrl: The control for saving and evaluating the experiment.
:type exp_ctrl: ExperimentSaveEvalControl
:param torch_cache_mysophobia: Whether to clean torch-allocated cache blocks with
torch.cuda.empty_cache() before each RPC in the model worker.
If enabled, there will be a ~0.1s overhead per RPC.
:type torch_cache_mysophobia: bool
:param auto_eval: Whether to run automatic evaluation during training. When enabled, an evaluation
job is submitted whenever a checkpoint is saved, and the result will be logged on disk and
on wandb if wandb is active.
:type auto_eval: bool
:param auto_eval_config: Configuration for automatic evaluation.
:type auto_eval_config: AutomaticEvaluator
:param cpus_per_master_worker: The number of CPUs for each master worker.
:param mem_per_master_worker: The size of memory for each master worker, measured in MB.
:param cpus_per_model_worker: The number of CPUs for each model worker.
:param mem_per_model_worker: The size of memory for each model worker, measured in MB.
"""
experiment_name: str = MISSING
trial_name: str = MISSING
mode: str = dataclasses.field(
metadata={"choices": ["slurm", "local", "ray"]}, default="slurm"
)
debug: bool = True
partition: str = "dev"
schedule_strategy: str = "empty_first"
wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig)
tensorboard: TensorBoardConfig = dataclasses.field(
default_factory=TensorBoardConfig
)
image_name: Optional[str] = None
recover_mode: str = "disabled"
recover_retries: int = 1
recover_after: int = 10
ignore_worker_error: bool = False
allocation_mode: str = ""
allocation_use_cache: bool = False
n_nodes: int = 1
n_gpus_per_node: int = cluster_spec.n_gpus_per_node
nodelist: Optional[str] = None
seed: int = 1
cache_clear_freq: Optional[int] = 10
exp_ctrl: ExperimentSaveEvalControl = dataclasses.field(
default_factory=ExperimentSaveEvalControl
)
torch_cache_mysophobia: bool = True
# Options for automatic evaluation
auto_eval: bool = False
auto_eval_config: AutomaticEvaluator = dataclasses.field(
default_factory=AutomaticEvaluator
)
# Options for worker resources
cpus_per_master_worker: int = 4
mem_per_master_worker: int = 20000
cpus_per_model_worker: int = 4
mem_per_model_worker: int = 90000
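The ``allocation_mode`` patterns described above compose with the other overrides on the command line; an illustrative launch (the experiment and trial names are the placeholder examples from the docstring, and the parallel degrees are arbitrary):
```bash
python3 -m realhf.apps.quickstart ppo-math \
    experiment_name=ultra-chat-llama trial_name=lr1e-3wd0.05 \
    mode=slurm n_nodes=2 nodelist='NODE[01-02]' \
    allocation_mode=d4p2m2 \
    exp_ctrl.total_train_epochs=4
```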
class CommonExperimentConfig(BaseExperimentConfig, Experiment):
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:

View File

@ -2,40 +2,25 @@
import dataclasses
from realhf.api.cli_args import (
MFCConfig,
ModelTrainEvalConfig,
NullPPOExperimentOptions,
PromptOnlyDatasetConfig,
SFTExperimentOptions,
)
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
ModelInterfaceType,
)
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.dataset import (
PromptAnswerDatasetConfig,
PromptOnlyDatasetConfig,
)
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.experiments.common.common import CommonExperimentConfig
@dataclasses.dataclass
class NullSFTConfig(CommonExperimentConfig):
"""Configuration for a null SFT experiment. Used for testing purposes.
:param allocation: Configuration for device allocation and
parallelism.
:type allocation: MFCConfig
:param dataset: Configuration for the dataset.
:type dataset: PromptAnswerDatasetConfig
"""
model: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
allocation: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptAnswerDatasetConfig = dataclasses.field(
default_factory=PromptAnswerDatasetConfig
)
class NullSFTConfig(CommonExperimentConfig, SFTExperimentOptions):
@property
def models(self):
@ -52,7 +37,7 @@ class NullSFTConfig(CommonExperimentConfig):
interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=ModelInterfaceAbstraction("null"),
model_name="default",
input_keys=["packed_input_ids", "prompt_mask"],
input_keys=("packed_input_ids", "prompt_mask"),
log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
@ -71,7 +56,6 @@ class NullSFTConfig(CommonExperimentConfig):
args=dict(
max_length=self.dataset.max_seqlen,
dataset_path=self.dataset.train_path,
fill_to_max_length=self.dataset.fill_to_max_length,
),
)
]
@ -85,22 +69,7 @@ register_quickstart_exp("null-sft", NullSFTConfig)
@dataclasses.dataclass
class NullPPOConfig(CommonExperimentConfig):
"""Configuration for a null PPO experiment.
Used for testing purposes.
"""
model: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
default_factory=PromptOnlyDatasetConfig
)
dataset_filter_threshold: float = 0.2
dataset_max_filter_percentage: float = 0.1
class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
@property
def models(self):
@ -117,8 +86,8 @@ class NullPPOConfig(CommonExperimentConfig):
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=ModelInterfaceAbstraction("null"),
model_name="default",
input_keys=["packed_prompts"],
output_keys=["rewards"],
input_keys=("packed_prompts",),
output_keys=("rewards",),
model_type=self.model.type,
model_path=self.model.path,
)
@ -129,7 +98,7 @@ class NullPPOConfig(CommonExperimentConfig):
interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=ModelInterfaceAbstraction("null"),
model_name="default",
input_keys=["packed_prompts", "rewards"],
input_keys=("packed_prompts", "rewards"),
log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
@ -144,11 +113,10 @@ class NullPPOConfig(CommonExperimentConfig):
def datasets(self):
return [
DatasetAbstraction(
"math_prompt",
"math_code_prompt",
args=dict(
max_length=self.dataset.max_prompt_len,
dataset_path=self.dataset.path,
fill_to_max_length=self.dataset.fill_to_max_length,
filter_threshold=self.dataset_filter_threshold,
max_filter_percentage=self.dataset_max_filter_percentage,
),

View File

@ -5,21 +5,18 @@ import copy
import dataclasses
import os
import pprint
from typing import *
from typing import Dict
import realhf.base.logging as logging
from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
ModelInterfaceType,
)
from realhf.api.core.dfg import MFCDef, ParamReallocHook
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import (
asdict,
@ -31,198 +28,11 @@ logger = logging.getLogger("PPO Math exp", "colored")
@dataclasses.dataclass
class PPOHyperparameters:
"""Configuration of PPO hyperparameters.
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
:param gen: Generation hyperparameters.
:type gen: GenerationHyperparameters
:param ppo_n_minibatches: Number of minibatches in each PPO update.
:type ppo_n_minibatches: int
:param kl_ctl: Coefficient of KL divergence rewards.
:type kl_ctl: float
:param discount: Discount factor.
:type discount: float
:param gae_lambda: Lambda factor in GAE.
:type gae_lambda: float
:param eps_clip: PPO actor probability ratio clipping factor.
:type eps_clip: float
:param value_eps_clip: PPO value clipping factor.
:type value_eps_clip: float
:param max_reward_clip: Maximum reward value.
:type max_reward_clip: float
:param reward_output_scaling: Scaling factor of the reward model output.
:type reward_output_scaling: float
:param reward_output_bias: Bias of the reward model output.
The number output by the reward model will be
CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip).
:type reward_output_bias: float
:param early_stop_imp_ratio: PPO update will be early stopped if importance ratio
exceeds this maximum value.
:type early_stop_imp_ratio: float
:param use_adaptive_kl_ctl: Whether to use adaptive KL divergence coefficient.
:type use_adaptive_kl_ctl: bool
:param adv_norm: Whether to use advantage normalization.
:type adv_norm: bool
:param value_norm: Whether to denormalize values and normalize return predictions.
:type value_norm: bool
:param value_norm_type: Type of value normalization.
Either exponential moving average ("exp") or moving average ("ma").
:type value_norm_type: str
:param value_norm_beta: Exponential decay factor
in exponential moving average.
:type value_norm_beta: float
:param value_norm_eps: Epsilon factor in the
denominator of exponential moving average.
:type value_norm_eps: float
:param disable_value: A shortcut option to disable the critic model.
:type disable_value: bool
"""
gen: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
)
ppo_n_minibatches: int = 4
kl_ctl: float = 0.1
discount: float = 1.0
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
max_reward_clip: float = 20.0
reward_output_scaling: float = 1.0
reward_output_bias: float = 0.0
early_stop_imp_ratio: float = 5.0
use_adaptive_kl_ctl: bool = False
adv_norm: bool = True
value_norm: bool = True
value_norm_type: str = dataclasses.field(
metadata={"choices": ["exp", "ma"]}, default="exp"
)
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
disable_value: bool = False
recompute_logprob: bool = False
fuse_rew_ref: bool = True
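A sketch of the reward post-processing described for `reward_output_scaling` / `reward_output_bias` above (not part of the patch; the actual interface implementation lives elsewhere):
```python
# CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip)
def postprocess_reward(x: float, ppo: "PPOHyperparameters") -> float:
    r = (x - ppo.reward_output_bias) * ppo.reward_output_scaling
    return max(-ppo.max_reward_clip, min(ppo.max_reward_clip, r))
```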
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig):
"""PPO experiment configuration.
It is a subclass of :class:`CommonExperimentConfig`,
so all CLI options in the base class are available.
We don't implement runtime evaluation for PPO.
We identify that the RLHF process is composed of four
distinct models with independent parameters and six
*model function calls* upon these models.
The four models are\:
- Actor\: The primary LLM that generates text.
- Critic\: The value function that estimates the value of a state.
- Ref\: The reference LLM that provides KL regularization.
- Rew\: The reward model that provides reward signals.
The four model function calls and their dependencies are\:
- Rollout\: Generate text from the actor model.
- InfReward\: Infer rewards from the reward model given generated text.
- InfRef\: Infer log probabilities from the reference model given generated text.
- InfValues\: Infer values from the critic model given generated text.
- TrainActor\: Train the actor model given generated text, rewards, values, and reference log probabilities.
- TrainCritic\: Train the critic model given generated text, rewards, values, and reference log probabilities.
This class resolves these dependencies under the hood.
What the users should specify are the runtime configurations
of models and allocations of *each model function call*.
:param actor: Runtime configuration of the primary LLM.
:type actor: ModelTrainEvalConfig
:param critic: Runtime configuration of the critic model of PPO.
:type critic: ModelTrainEvalConfig
:param ref: Runtime configuration of the reference LLM.
:type ref: ModelTrainEvalConfig
:param rew: Runtime configuration of the reward LLM.
:type rew: ModelTrainEvalConfig
:param actor_train: :class:`MFCConfig` for TrainActor.
:type actor_train: MFCConfig
:param critic_train: :class:`MFCConfig` for TrainCritic.
:type critic_train: MFCConfig
:param actor_gen: :class:`MFCConfig` for Rollout.
:type actor_gen: MFCConfig
:param critic_inf: :class:`MFCConfig` for InfValues.
:type critic_inf: MFCConfig
:param rew_inf: :class:`MFCConfig` for InfReward.
:type rew_inf: MFCConfig
:param ref_inf: :class:`MFCConfig` for InfRef.
:type ref_inf: MFCConfig
:param dataset: Dataset configuration.
:type dataset: PromptOnlyDatasetConfig
:param ppo: Configuration for the PPO algorithm.
:type ppo: PPOHyperparameters
:param group_size: The number of answers retained for each prompt.
:type group_size: int
:param generation_size: The number of answers sampled for each prompt.
Among them, only `group_size` samples are retained according to
the reward score, aka best-of-n sampling. If None, use `group_size`.
:type generation_size: Optional[int]
:param mask_no_eos_with_zero: Whether to mask out the reward if an
answer is truncated due to exceeding the length limit.
:type mask_no_eos_with_zero: bool
:param mask_too_long: Whether to mask out the PPO loss if an
answer is truncated due to exceeding the length limit.
:type mask_too_long: bool
:param check_verifier_status: If True, raise an error
when the reward is all-zero. This usually indicates a bug
in the verifier.
:type check_verifier_status: bool
:param group_adv_norm: Whether to use grouped advantage
normalization in GRPO.
:type group_adv_norm: bool
"""
actor: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
critic: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
# for manual allocation only
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
default_factory=PromptOnlyDatasetConfig
)
ppo: PPOHyperparameters = dataclasses.field(default_factory=PPOHyperparameters)
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
ref_ema_eta: Optional[float] = None
group_adv_norm: bool = False
mask_too_long: bool = False
rw_type: Optional[str] = "sparse"
check_xml_format: bool = False
check_verifier_status: bool = False
dataset_filter_threshold: float = 100.0
dataset_max_filter_percentage: float = 0.0
def __post_init__(self):
self.ppo_kwargs = dict(
@property
def ppo_kwargs(self):
return dict(
n_minibatches=self.ppo.ppo_n_minibatches,
kl_ctl=self.ppo.kl_ctl,
discount=self.ppo.discount,
@ -341,8 +151,8 @@ class PPOMATHConfig(CommonExperimentConfig):
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_prompts", "task_ids"],
output_keys=rollout_output_keys,
input_keys=("packed_prompts", "task_ids"),
output_keys=tuple(rollout_output_keys),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -354,8 +164,8 @@ class PPOMATHConfig(CommonExperimentConfig):
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_input_ids"],
output_keys=["packed_logprobs"],
input_keys=("packed_input_ids",),
output_keys=("packed_logprobs",),
output_key_remap=dict(logprobs="packed_logprobs"),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -366,8 +176,8 @@ class PPOMATHConfig(CommonExperimentConfig):
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
output_keys=["rewards"],
input_keys=("packed_input_ids", "packed_prompts", "task_ids"),
output_keys=("rewards",),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -387,8 +197,8 @@ class PPOMATHConfig(CommonExperimentConfig):
model_path=self.ref.path,
interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=inf_ref_inputs,
output_keys=inf_ref_outputs,
input_keys=tuple(inf_ref_inputs),
output_keys=tuple(inf_ref_outputs),
output_key_remap=dict(logprobs="packed_ref_logprobs"),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -402,8 +212,8 @@ class PPOMATHConfig(CommonExperimentConfig):
model_type=self.critic.type,
model_path=self.critic.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "seq_no_eos_mask"],
output_keys=["values"],
input_keys=("packed_input_ids", "seq_no_eos_mask"),
output_keys=("values",),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -426,7 +236,7 @@ class PPOMATHConfig(CommonExperimentConfig):
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=train_actor_inputs,
input_keys=tuple(train_actor_inputs),
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
@ -440,7 +250,7 @@ class PPOMATHConfig(CommonExperimentConfig):
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
input_keys=[
input_keys=(
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
@ -448,7 +258,7 @@ class PPOMATHConfig(CommonExperimentConfig):
"values",
"prompt_mask",
"seq_no_eos_mask",
],
),
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
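For orientation, below is a hedged sketch of how the configuration pieces documented in the docstring above might be wired together programmatically. The imports follow the updated test files later in this diff; all paths and batch sizes are placeholders, and any required fields inherited from `CommonExperimentConfig` are assumed to have defaults.

```python
from realhf.api.cli_args import (
    ModelTrainEvalConfig,
    PPOHyperparameters,
    PromptOnlyDatasetConfig,
)
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig

# Illustrative placeholders only, not values taken from this PR.
exp = PPOMATHConfig(
    actor=ModelTrainEvalConfig(path="/path/to/actor-ckpt"),
    critic=ModelTrainEvalConfig(path="/path/to/critic-ckpt"),
    ref=ModelTrainEvalConfig(path="/path/to/actor-ckpt"),
    dataset=PromptOnlyDatasetConfig(
        path="/path/to/prompts.jsonl",
        max_prompt_len=512,
        train_bs_n_seqs=256,
    ),
    ppo=PPOHyperparameters(kl_ctl=0.1, ppo_n_minibatches=4),
    # Best-of-n sampling: draw 16 answers per prompt, keep the best 4.
    group_size=4,
    generation_size=16,
)
```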

View File

@ -4,6 +4,7 @@
import dataclasses
from realhf.api.cli_args import SFTExperimentOptions
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
@ -11,35 +12,12 @@ from realhf.api.core.config import (
ModelName,
)
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.dataset import PromptAnswerDatasetConfig
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.experiments.common.common import CommonExperimentConfig
@dataclasses.dataclass
class SFTConfig(CommonExperimentConfig):
"""Configuration for SFT experiments.
This class is a subclass of :class:`CommonExperimentConfig`,
so all CLI options from the base class are available.
:param model: Configuration for model runtime.
:type model: ModelTrainEvalConfig
:param allocation: Configuration for device allocation and parallelism.
:type allocation: MFCConfig
:param dataset: Configuration for the dataset.
:type dataset: PromptAnswerDatasetConfig
"""
model: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
allocation: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptAnswerDatasetConfig = dataclasses.field(
default_factory=PromptAnswerDatasetConfig
)
class SFTConfig(CommonExperimentConfig, SFTExperimentOptions):
@property
def models(self):
@ -56,7 +34,7 @@ class SFTConfig(CommonExperimentConfig):
interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=ModelInterfaceAbstraction("sft"),
model_name="default",
input_keys=["packed_input_ids", "prompt_mask"],
input_keys=("packed_input_ids", "prompt_mask"),
log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
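A minimal, hypothetical construction of the new mixin-based config, assuming `SFTExperimentOptions` keeps the `model`/`allocation`/`dataset` fields that the removed docstring above documents:

```python
from realhf.api.cli_args import ModelTrainEvalConfig
from realhf.experiments.common.sft_exp import SFTConfig

# Hypothetical: only the model path is overridden; the allocation and
# dataset options are assumed to fall back to their declared defaults.
sft = SFTConfig(model=ModelTrainEvalConfig(path="/path/to/base-model"))
```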

View File

@ -7,23 +7,31 @@ import dataclasses
import enum
import itertools
import re
from typing import *
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import numpy as np
from omegaconf import DictConfig, OmegaConf
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.core.config import (
ModelAbstraction,
ModelBackendAbstraction,
ModelInterfaceType,
ModelName,
)
from realhf.api.core.dfg import OffloadHook, ParamReallocHook
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.api.quickstart.model import (
ModelTrainEvalConfig,
ParallelismConfig,
parallelism_eq,
)
from realhf.base import logging
from realhf.base.topology import (
DataPipeModelParallelTopology,
@ -34,6 +42,29 @@ from realhf.base.topology import (
logger = logging.getLogger("Experiment Common Utils", "benchmark")
def get_real_model_config(
model_path: str,
hf_model_family: str,
is_critic: bool,
init_from_scratch: bool,
init_critic_from_actor: bool,
dtype: Optional[str] = None,
) -> ModelAbstraction:
"""Make a configuration to build model."""
model = ModelAbstraction(
"real_model",
args=dict(
model_path=model_path,
is_critic=is_critic,
init_critic_from_actor=init_critic_from_actor,
dtype=dtype,
hf_model_family=hf_model_family,
init_from_scratch=init_from_scratch,
),
)
return model
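A hedged usage sketch for the helper introduced above; the checkpoint path and model family are placeholders rather than values from this PR.

```python
# Build the ModelAbstraction for a (hypothetical) HF checkpoint,
# reusing the actor weights as-is rather than training a critic.
actor_model = get_real_model_config(
    model_path="/path/to/hf-checkpoint",
    hf_model_family="llama",
    is_critic=False,
    init_from_scratch=False,
    init_critic_from_actor=False,
)
```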
def get_topo(
parallel: ParallelismConfig,
gradient_checkpointing: bool,
@ -72,7 +103,7 @@ def make_train_backend_config(
model_cfg: ModelTrainEvalConfig, parallel_cfg: ParallelismConfig
):
if model_cfg.backend == "megatron":
megatron_args: Dict[str, Any] = OmegaConf.to_container(model_cfg.megatron)
megatron_args: Dict[str, Any] = asdict(model_cfg.megatron)
return ModelBackendAbstraction(
"megatron",
args=dict(
@ -132,8 +163,11 @@ def resolve_replica_ids(
for alloc in allocs:
if alloc.rpc.name == main_alloc.rpc.name:
continue
same_alloc = alloc.device_mesh == main_alloc.device_mesh and parallelism_eq(
alloc.parallel, main_alloc.parallel
same_alloc = (
alloc.device_mesh == main_alloc.device_mesh
and ParallelismConfig.parallelism_eq(
alloc.parallel, main_alloc.parallel
)
)
if not same_alloc or (
alloc.rpc.is_generate()
@ -165,7 +199,7 @@ def resolve_rpc_hooks(
if rpc.role != other.rpc.role:
continue
if (
parallelism_eq(parallel, other.parallel)
ParallelismConfig.parallelism_eq(parallel, other.parallel)
and device_mesh == other.device_mesh
and not (
model_configs[rpc.role].vllm.hybrid_train

View File

@ -58,7 +58,9 @@ def load_metadata(path):
try:
if "task" not in d:
d["task"] = "math"
logger.warning(f'Key "task" not found in the dataset. Use math as default task type.')
logger.warning(
f'Key "task" not found in the dataset. Use math as default task type.'
)
if d["task"] == "math":
d = check_math_metadata_entries(d)
elif d["task"] == "code":
@ -206,7 +208,7 @@ else:
),
),
max_length=512,
dataset_path='/storage/datasets/full_prompts_for_r1_distilled.jsonl'
dataset_path="/storage/datasets/full_prompts_for_r1_distilled.jsonl",
)
dataloader = torch.utils.data.DataLoader(

View File

@ -16,9 +16,9 @@ import torch.distributed as dist
import transformers
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from realhf.api.cli_args import MegatronConfig, MicroBatchSpec, OptimizerConfig
from realhf.api.core import model_api
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.api.quickstart.model import MegatronConfig, OptimizerConfig
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, logging, pkg_version
from realhf.base.datapack import flat2d
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
@ -552,7 +552,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
DistributedDataParallelConfig,
)
else:
from realhf.api.quickstart.model import DistributedDataParallelConfig
from realhf.api.cli_args import DistributedDataParallelConfig
self.ddp = DistributedDataParallelConfig(**self.ddp)
with megatron_ctx():
if pkg_version.is_version_less("megatron.core", "0.7.0"):

View File

@ -17,6 +17,7 @@ import torch.multiprocessing as mp
import transformers
from tqdm.asyncio import tqdm
from realhf.api.cli_args import SGLangConfig
from realhf.api.core import data_api
from realhf.api.core.model_api import (
APIGenerateInput,
@ -29,7 +30,6 @@ from realhf.api.core.model_api import (
PipelinableEngine,
register_backend,
)
from realhf.api.quickstart.model import SGLangConfig
from realhf.base import (
cluster,
constants,

View File

@ -29,8 +29,8 @@ except ModuleNotFoundError:
pass
from realhf.api.cli_args import vLLMConfig
from realhf.api.core import data_api, model_api
from realhf.api.quickstart.model import vLLMConfig
from realhf.base import constants, logging, seeding
logger = logging.getLogger("vLLM backend")

View File

@ -18,9 +18,9 @@ import pandas as pd
import realhf.base.cluster
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.cli_args import ParallelismConfig
from realhf.api.core.dfg import MFCDef, ModelFamily, ModelInterfaceType
from realhf.api.core.model_api import ReaLModelConfig
from realhf.api.quickstart.model import ParallelismConfig
from realhf.search_engine.param_realloc import estimate_param_realloc_time_cost
from realhf.search_engine.utils import load_model_config

View File

@ -19,10 +19,10 @@ except ModuleNotFoundError:
mdm_search = None
import realhf.base.constants as constants
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.core.config import ModelInterfaceType
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.quickstart.search import RPCExecution

View File

@ -218,9 +218,10 @@ class RedistribPlanner:
return list(gather_plan.values()) + list(scatter_plan.values())
def derive_plan_bcast(
self, dests: Dict[int, List[Hashable]], keys: List[str]
self, dests: Dict[int, List[Hashable]], keys: List[str] | Tuple[str]
) -> List[RedistribStep]:
assert isinstance(keys, list), type(keys)
assert isinstance(keys, (list, tuple)), type(keys)
keys = list(keys)
self.dests = dests
# Get all required data IDs.

View File

@ -44,15 +44,16 @@ tabulate
aiofiles
pydantic
isort==5.13.2
clang-format
clang-format==19.1.7
ninja
paramiko
# To eliminate security risks
torch>2.0.0
black>=25.1.0
black==25.1.0
cookiecutter>2.1.1
asyncio
aiohttp
httpx>=0.28.1
etcd3
protobuf<3.21
protobuf<3.21
rich

View File

@ -15,7 +15,7 @@ import pytest
import torch
import torch.distributed as dist
from realhf.api.core.config import ModelFamily, ModelName, ModelShardID
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY, ReaLModelConfig
from realhf.base import constants, logging, testing, topology

View File

@ -11,10 +11,8 @@ from typing import *
import matplotlib.pyplot as plt
import networkx as nx
import pytest
import ray
from ray.util.queue import Queue as RayQueue
from realhf.api.core.config import ModelFamily, ModelInterfaceAbstraction, ModelName
from realhf.api.core.config import ModelInterfaceAbstraction, ModelName
from realhf.api.core.dfg import MFCDef, ModelInterfaceType, build_graph
from realhf.base import logging
@ -125,8 +123,6 @@ def _get_reinforce_rpcs():
@pytest.mark.parametrize("rpcs", [_get_ppo_rpcs(), _get_reinforce_rpcs()])
def test_build_graph(tmp_path: pathlib.Path, rpcs: List[MFCDef]):
if not ray.is_initialized():
ray.init()
G = build_graph(rpcs, verbose=True, graph_path=str(tmp_path / "dfg.png"))
assert nx.is_directed_acyclic_graph(G)
for node in rpcs:
@ -152,14 +148,3 @@ def test_build_graph(tmp_path: pathlib.Path, rpcs: List[MFCDef]):
if k.startswith("_"):
continue
assert v == dataclasses.asdict(node)[k]
# Ensure node can be passed into ray queue
queue = RayQueue(maxsize=8)
queue.put(node)
node_ = queue.get()
for k, v in dataclasses.asdict(node_).items():
if k.startswith("_"):
continue
assert v == dataclasses.asdict(node)[k]
if ray.is_initialized():
ray.shutdown()

View File

@ -7,6 +7,7 @@ import uuid
import pytest
import torch
from torch.utils.data import DataLoader
from realhf.api.core import config as config_api
from realhf.api.core import data_api
@ -20,16 +21,22 @@ def _validate_dataset(cfg: config_api.DatasetAbstraction, tokenizer):
dp_rank=0,
world_size=1,
tokenizer_or_tokenizer_name=tokenizer,
experiment_name=uuid.uuid4(),
trial_name=uuid.uuid4(),
experiment_name=str(uuid.uuid4()),
trial_name=str(uuid.uuid4()),
)
dataloader = DataLoader(
dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=10240,
shuffle=True,
)
dataloader = data_api.PackedDataLoader(dataset)
for x in dataloader:
assert isinstance(x, data_api.SequenceSample)
assert x.data is not None
for k, v in x.data.items():
assert v.device == torch.device("cpu")
bs = len(x.ids)
for k, vs in x.seqlens.items():
assert all(isinstance(v, list) for v in vs)
assert all(all(isinstance(vv, int) for vv in v) for v in vs)
@ -37,7 +44,7 @@ def _validate_dataset(cfg: config_api.DatasetAbstraction, tokenizer):
if x.metadata:
for k, v in x.metadata.items():
assert isinstance(v, list), k
xs = x.split(bs)
xs = x.unpack()
for xx in xs:
if xx.metadata:
for k, v in xx.metadata.items():

View File

@ -7,13 +7,14 @@ from typing import *
import pytest
from realhf.api.core.system_api import ExperimentSaveEvalControl
from realhf.api.quickstart.dataset import (
from realhf.api.cli_args import (
ExperimentSaveEvalControl,
MFCConfig,
ModelTrainEvalConfig,
ParallelismConfig,
PromptAnswerDatasetConfig,
PromptOnlyDatasetConfig,
)
from realhf.api.quickstart.device_mesh import MFCConfig, ParallelismConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.base import cluster, logging, name_resolve, testing
from realhf.experiments.common.null_exp import NullPPOConfig, NullSFTConfig
from tests.experiments.utils import run_test_exp
@ -28,40 +29,40 @@ def model_class(request):
@pytest.fixture(params=[300])
def dataset_with_size(request, save_path):
def math_code_dataset_with_size(request, save_path):
size = request.param
max_prompt_len = 8
max_resp_len = 8
dataset = []
for i in range(size):
prompt_len = random.randint(1, max_prompt_len)
n_pairs = random.randint(1, 5)
d = dict(
id=i,
prompt=generate_random_sentence(prompt_len),
answer=generate_random_sentence(random.randint(1, max_resp_len)),
pos_answers=[
generate_random_sentence(random.randint(1, max_resp_len))
for _ in range(n_pairs)
],
neg_answers=[
generate_random_sentence(random.randint(1, max_resp_len))
for _ in range(n_pairs)
],
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
task=random.choice(["math", "code"]),
)
if d["task"] == "math":
d["solutions"] = [generate_random_sentence(max_resp_len)]
elif d["task"] == "code":
d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
dataset.append(d)
with open(str(save_path / "_dataset.json"), "w") as f:
json.dump(dataset, f)
return dataset, size
with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
f.write(json.dumps(d) + "\n")
return dataset, len(dataset)
@pytest.mark.parametrize("dp", [4])
@pytest.mark.parametrize("bs", [63])
def test_buffer_recover(
bs, tmp_path_factory, dataset_with_size, tokenizer, save_path, cpu_hf_model, dp
bs,
tmp_path_factory,
math_code_dataset_with_size,
tokenizer,
save_path,
cpu_hf_model,
dp,
):
_, dataset_size = dataset_with_size
_, dataset_size = math_code_dataset_with_size
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("buffer-recover")
cluster.spec.fileroot = str(log_root)
@ -100,7 +101,7 @@ def test_buffer_recover(
backend="mock_train",
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=128,
train_bs_n_seqs=bs,
fill_to_max_length=False,

View File

@ -7,17 +7,18 @@ from typing import *
import pytest
from realhf.api.core.data_api import MicroBatchSpec
from realhf.api.core.system_api import ExperimentSaveEvalControl
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.base import cluster, testing
from realhf.experiments.common.ppo_math_exp import (
from realhf.api.cli_args import (
ExperimentSaveEvalControl,
GenerationHyperparameters,
MFCConfig,
MicroBatchSpec,
ModelTrainEvalConfig,
ParallelismConfig,
PPOHyperparameters,
PPOMATHConfig,
PromptOnlyDatasetConfig,
)
from realhf.base import cluster, testing
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
from tests.experiments.utils import run_test_exp
from tests.fixtures import *

View File

@ -5,8 +5,16 @@ from typing import *
import pytest
from realhf.api.quickstart.dataset import PromptAnswerDatasetConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.api.cli_args import (
ExperimentSaveEvalControl,
GenerationHyperparameters,
MFCConfig,
MicroBatchSpec,
ModelTrainEvalConfig,
ParallelismConfig,
PPOHyperparameters,
PromptOnlyDatasetConfig,
)
from realhf.base import cluster, testing
from realhf.experiments.common.sft_exp import SFTConfig
from tests.experiments.utils import run_test_exp

View File

@ -14,7 +14,7 @@ import torch
import torch.distributed as dist
import transformers
from realhf.api.core.config import ModelFamily
from realhf.api.cli_args import ModelFamily
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY, ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.testing import (