mirror of https://github.com/inclusionAI/AReaL
PullRequest: 62 [Patch v0.2.0] Move all CLI arguments into a single file and add pretty helper messages.
Merge branch fw/cli-args of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/62 Signed-off-by: 郭唯 <kira.gw@antgroup.com> * . * format and test * . * . * . * . * run * . * . * add runtime helper message * . * .
This commit is contained in:
parent
71429c9655
commit
0bd9969ec4
|
@ -281,7 +281,7 @@ python3 -m realhf.apps.quickstart ppo-math option1=arg1 option2=arg2 ...
|
|||
The command-line arguments like `option1=arg1` are parsed by [hydra](https://hydra.cc/), and each configuration item is a `dataclasses.dataclass` in the Python code. You can use the following command to view all the command-line arguments that can be passed in the experiment:
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart ppo-math --show-args
|
||||
python3 -m realhf.apps.quickstart ppo-math --help
|
||||
```
|
||||
|
||||
The descriptions of the important parameters are as follows:
|
||||
|
|
|
@ -288,7 +288,7 @@ python3 -m realhf.apps.quickstart ppo-math option1=arg1 option2=arg2 ...
|
|||
其中`option1=arg1`这些命令行参数是通过[hydra](https://hydra.cc/)进行解析的,其中每一条配置项都是python代码中的`dataclasses.dataclass`。用以下命令可以查看实验中所有可以传递的命令行参数:
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart ppo-math --show-args
|
||||
python3 -m realhf.apps.quickstart ppo-math --help
|
||||
```
|
||||
|
||||
其中重要的参数的说明如下:
|
||||
|
|
|
@ -5,12 +5,12 @@
|
|||
"""
|
||||
Dataset Toolkit - Process and validate code/math datasets with flexible input support
|
||||
"""
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
# Configure console logging
|
||||
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def load_jsonl(file_path: Path) -> List[Dict]:
|
||||
|
|
|
@ -2,37 +2,23 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
# Initialize preset config before all submodules.
|
||||
from .base import prologue # isort: skip
|
||||
|
||||
from .api.cli_args import *
|
||||
|
||||
# Re-import these classes for clear documentation,
|
||||
# otherwise the name will have a long prefix like
|
||||
# realhf.api.quickstart.model.ModelTrainEvalConfig.
|
||||
from .api.core.config import ModelFamily, ModelName, ModelShardID
|
||||
# otherwise the name will have a long prefix.
|
||||
from .api.core.config import ModelName, ModelShardID
|
||||
from .api.core.data_api import SequenceSample
|
||||
from .api.core.dfg import MFCDef
|
||||
from .api.core.model_api import (
|
||||
FinetuneSpec,
|
||||
GenerationHyperparameters,
|
||||
Model,
|
||||
ModelBackend,
|
||||
ModelInterface,
|
||||
PipelinableEngine,
|
||||
ReaLModelConfig,
|
||||
)
|
||||
from .api.quickstart.dataset import (
|
||||
PairedComparisonDatasetConfig,
|
||||
PromptAnswerDatasetConfig,
|
||||
PromptOnlyDatasetConfig,
|
||||
)
|
||||
from .api.quickstart.device_mesh import MFCConfig
|
||||
from .api.quickstart.model import (
|
||||
ModelTrainEvalConfig,
|
||||
OptimizerConfig,
|
||||
ParallelismConfig,
|
||||
)
|
||||
|
||||
# Initialize preset config before all submodules.
|
||||
from .base import prologue
|
||||
from .experiments.common.common import CommonExperimentConfig, ExperimentSaveEvalControl
|
||||
from .experiments.common.ppo_math_exp import PPOHyperparameters, PPOMATHConfig
|
||||
from .experiments.common.sft_exp import SFTConfig
|
||||
|
||||
__version__ = "0.3.0"
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -4,7 +4,7 @@
|
|||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import *
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import realhf.base.cluster as cluster
|
||||
import realhf.base.topology as topology
|
||||
|
@ -70,33 +70,6 @@ class ModelName:
|
|||
return str(self)
|
||||
|
||||
|
||||
@dataclasses.dataclass(unsafe_hash=True)
|
||||
class ModelFamily:
|
||||
"""An identifier for the HF model type, such as llama, gpt2, etc.
|
||||
|
||||
:param _class: The class of the model, e.g., "llama". This is the registered
|
||||
name in the ``register_hf_family`` function. Please refer to the files
|
||||
in ``realhf/api/from_hf`` for a list of all supported models.
|
||||
:type _class: str
|
||||
:param size: The size of the model. This parameter is only used by the ``search``
|
||||
allocation mode and will be ignored otherwise.
|
||||
:type size: int
|
||||
:param is_critic: Indicates whether the model is a critic or reward model,
|
||||
as opposed to a standard LLM.
|
||||
:type is_critic: bool
|
||||
"""
|
||||
|
||||
_class: str
|
||||
size: int = 0
|
||||
is_critic: bool = False
|
||||
|
||||
def __repr__(self):
|
||||
s = f"{self._class}-{self.size}"
|
||||
if self.is_critic:
|
||||
s += "-critic"
|
||||
return s
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelShardID:
|
||||
"""The ID of a model shard in a specific model worker.
|
||||
|
|
|
@ -37,6 +37,7 @@ from pydantic import Field
|
|||
from pydantic import dataclasses as pdclasses
|
||||
from pydantic import field_validator, model_validator
|
||||
|
||||
from realhf.api.cli_args import MicroBatchSpec
|
||||
from realhf.api.core import config as config_api
|
||||
from realhf.base import constants, datapack, logging
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
|
@ -100,33 +101,6 @@ class SequenceSplitSpec:
|
|||
return self
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MicroBatchSpec:
|
||||
"""The specification for splitting micro-batches.
|
||||
|
||||
:param n_mbs: The number of micro-batches, if max_tokens_per_mb is
|
||||
None. The *minimum* number of micro-batches, if
|
||||
max_tokens_per_mb is an integer. Defaults to 1.
|
||||
:type n_mbs: int
|
||||
:param max_tokens_per_mb: The maximum number of tokens per micro-
|
||||
batch.
|
||||
:type max_tokens_per_mb: Optional[int]
|
||||
"""
|
||||
|
||||
n_mbs: int = 1
|
||||
max_tokens_per_mb: int = int(1e12)
|
||||
|
||||
@classmethod
|
||||
def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
|
||||
# NOTE: Use classmethod to make the Omegaconf duck object happy.
|
||||
fields = dict(
|
||||
n_mbs=mb_spec.n_mbs,
|
||||
max_tokens_per_mb=mb_spec.max_tokens_per_mb,
|
||||
)
|
||||
fields.update(kwargs)
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
@pdclasses.dataclass(config=dict(arbitrary_types_allowed=True))
|
||||
class SequenceSample:
|
||||
"""The data structure used to represent sequence data.
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
|
||||
import collections
|
||||
import dataclasses
|
||||
from typing import *
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
from realhf.api.core.config import (
|
||||
ModelFamily,
|
||||
ModelInterfaceAbstraction,
|
||||
ModelInterfaceType,
|
||||
ModelName,
|
||||
|
|
|
@ -6,16 +6,16 @@ import abc
|
|||
import asyncio
|
||||
import dataclasses
|
||||
import keyword
|
||||
from typing import *
|
||||
from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import GenerationHyperparameters
|
||||
from realhf.api.core.config import (
|
||||
ModelAbstraction,
|
||||
ModelBackendAbstraction,
|
||||
|
@ -33,87 +33,6 @@ class ZeroTotalLossWeightException(Exception):
|
|||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GenerationHyperparameters:
|
||||
"""Generation hyperparameters.
|
||||
|
||||
We implement a customized generation function instead of using
|
||||
HuggingFace's to support pipelined generation. As a result, advanced
|
||||
generation techniques like diversity-promoting sampling or
|
||||
repetition penalty are not supported during PPO training. However,
|
||||
we do not find this to be a problem in practice. Increasing the
|
||||
sampling temperature and enabling top-k/top-p sampling can produce
|
||||
effective models.
|
||||
|
||||
:param n: The number of sequences to generate for this prompt.
|
||||
:type n: int
|
||||
:param max_new_tokens: The maximum number of new tokens to generate.
|
||||
:type max_new_tokens: int
|
||||
:param min_new_tokens: The minimum number of new tokens to generate.
|
||||
:type min_new_tokens: int
|
||||
:param greedy: Whether to use greedy decoding.
|
||||
:type greedy: bool
|
||||
:param top_k: The number of highest probability tokens to keep.
|
||||
:type top_k: int
|
||||
:param top_p: The cumulative probability of the highest probability
|
||||
tokens to keep.
|
||||
:type top_p: float
|
||||
:param temperature: The temperature of the sampling process.
|
||||
:type temperature: float
|
||||
:param use_cuda_graph: Whether to use CUDA graph to reduce kernel
|
||||
launch overhead during generation.
|
||||
:type use_cuda_graph: bool
|
||||
:param force_cudagraph_recapture: Whether to capture the CUDA graph
|
||||
every time `generate` is called, even if the graph has been captured
|
||||
before. This will introduce minor overhead but will release the
|
||||
kvcache when not running generation.
|
||||
:type force_cudagraph_recapture: bool
|
||||
:param force_no_logits_mask: Whether to omit the logits mask. The logits
|
||||
mask is produced when using top-k or top-p sampling, marking tokens
|
||||
that are filtered out. This mask is used by the reference model and
|
||||
the actor model during training to align inferred logits with those
|
||||
during generation and produce accurate KLs. Using the logits mask with
|
||||
top-k/top-p sampling greatly improves the stability of PPO training
|
||||
by narrowing the action space. However, this benefit comes at the cost
|
||||
of additional GPU memory usage. If this option is set to True, the
|
||||
logits mask will be omitted to save GPU memory, which may lead to a
|
||||
decrease in learning performance.
|
||||
:type force_no_logits_mask: bool
|
||||
"""
|
||||
|
||||
n: int = 1
|
||||
max_new_tokens: int = 256
|
||||
min_new_tokens: int = 256
|
||||
greedy: bool = False
|
||||
top_p: float = 1.0
|
||||
top_k: int = int(1e8)
|
||||
temperature: float = 1.0
|
||||
use_cuda_graph: bool = True
|
||||
force_cudagraph_recapture: bool = True
|
||||
force_no_logits_mask: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.temperature == 0.0:
|
||||
self.greedy = True
|
||||
self.temperature = 1.0
|
||||
if self.top_p <= 0.0 or self.top_p > 1:
|
||||
raise ValueError("top_p must be in (0.0, 1.0].")
|
||||
if self.top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
if self.use_cuda_graph and Version(
|
||||
Version(torch.__version__).base_version
|
||||
) < Version("2.3.0"):
|
||||
raise ValueError(
|
||||
f"To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
|
||||
)
|
||||
|
||||
def new(self, **kwargs):
|
||||
args = dataclasses.asdict(self)
|
||||
args.update(kwargs)
|
||||
return GenerationHyperparameters(**args)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class APIGenerateInput:
|
||||
qid: Hashable
|
||||
|
|
|
@ -7,6 +7,12 @@ import os
|
|||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import realhf.api.core.dfg as dfg
|
||||
from realhf.api.cli_args import (
|
||||
AutomaticEvaluator,
|
||||
ExperimentSaveEvalControl,
|
||||
TensorBoardConfig,
|
||||
WandBConfig,
|
||||
)
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelAbstraction,
|
||||
|
@ -134,68 +140,6 @@ class ModelWorker:
|
|||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExperimentSaveEvalControl:
|
||||
"""Utility object for controlling the frequency of saving and evaluation
|
||||
during training.
|
||||
|
||||
``Epoch`` refers to the number of times the training loop iterates over the entire dataset.
|
||||
``Step`` refers to the number of iterations running the algorithm dataflow.
|
||||
|
||||
This object manages independent counters for epochs, steps, and seconds. The model will
|
||||
be saved or evaluated when any of the following conditions are met.
|
||||
|
||||
:param total_train_epochs: The total number of epochs to train the model.
|
||||
:type total_train_epochs: int
|
||||
:param save_freq_epochs: Frequency in epochs at which to save the model. If None,
|
||||
the model will not be saved based on epoch changes during training.
|
||||
:type save_freq_epochs: Optional[int]
|
||||
:param save_freq_steps: Frequency in steps at which to save the model. If None,
|
||||
the model will not be saved based on step changes during training.
|
||||
:type save_freq_steps: Optional[int]
|
||||
:param save_freq_secs: Frequency in seconds at which to save the model. If None,
|
||||
the model will not be saved based on time changes during training.
|
||||
:type save_freq_secs: Optional[int]
|
||||
:param ckpt_freq_epochs: Frequency in epochs at which to save the model for recovery.
|
||||
The previous checkpoint will be overwritten to reduce disk usage. If None, use save_freq_epochs.
|
||||
:type ckpt_freq_epochs: Optional[int]
|
||||
:param ckpt_freq_steps: Frequency in steps at which to save the model for recovery. If None,
|
||||
the model will not be saved based on step changes during training.
|
||||
:type ckpt_freq_steps: Optional[int]
|
||||
:param ckpt_freq_secs: Frequency in seconds at which to save the model for recovery. If None,
|
||||
the model will not be saved based on time changes during training.
|
||||
:type ckpt_freq_secs: Optional[int]
|
||||
:param eval_freq_epochs: Frequency in epochs at which to evaluate the model. If None,
|
||||
the model will not be evaluated based on epoch changes during training.
|
||||
:type eval_freq_epochs: Optional[int]
|
||||
:param eval_freq_steps: Frequency in steps at which to evaluate the model. If None,
|
||||
the model will not be evaluated based on step changes during training.
|
||||
:type eval_freq_steps: Optional[int]
|
||||
:param eval_freq_secs: Frequency in seconds at which to evaluate the model. If None,
|
||||
the model will not be evaluated based on time changes during training.
|
||||
:type eval_freq_secs: Optional[int]
|
||||
:param benchmark_steps: Terminate training after this number of steps. Used for system
|
||||
benchmarking only. Set to None for normal training.
|
||||
:type benchmark_steps: Optional[int]
|
||||
"""
|
||||
|
||||
total_train_epochs: int = 1
|
||||
# save control
|
||||
save_freq_epochs: Optional[int] = None
|
||||
save_freq_steps: Optional[int] = None
|
||||
save_freq_secs: Optional[int] = None
|
||||
# checkpointing control, only used for recover
|
||||
ckpt_freq_epochs: Optional[int] = None
|
||||
ckpt_freq_steps: Optional[int] = None
|
||||
ckpt_freq_secs: Optional[int] = None
|
||||
# eval control
|
||||
eval_freq_epochs: Optional[int] = None
|
||||
eval_freq_steps: Optional[int] = None
|
||||
eval_freq_secs: Optional[int] = None
|
||||
# benchmark
|
||||
benchmark_steps: Optional[int] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MasterWorker:
|
||||
base_seed: int
|
||||
|
@ -227,54 +171,6 @@ class ExperimentScheduling:
|
|||
controller_image: str = _LLM_CPU_IMAGE
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AutomaticEvaluator:
|
||||
"""Configuration for automatic evaluation.
|
||||
:param data_names: Datasets for evaluation, separated by commas. Currently supports datasets stored under ./evaluation/data,
|
||||
including "aime24", "amc23" and "math_500". For example, if "aime24" and "amc23" are required for evaluation,
|
||||
this field should be set to "aime24,amc23".
|
||||
:type data_names: str
|
||||
:param max_gen_tokens: Maximum number of tokens to be generated in evaluation.
|
||||
:type max_gen_tokens: int
|
||||
:param max_concurrent_jobs: Maximum number of concurrent evaluation jobs to submit. If the number of existing jobs is equal to
|
||||
`max_concurrent_jobs` and a new checkpoint is saved, the evaluation job will wait until former jobs complete.
|
||||
:type max_concurrent_jobs: int
|
||||
:param eval_job_image: Container image used to launch evaluation job. If set to None, evaluation jobs will use
|
||||
GPU image for training.
|
||||
:type eval_job_image: Optional[str]
|
||||
:param initial_checkpoint_path: Initial checkpoint path to evaluate. If specified, this initial checkpoint will be evaluated,
|
||||
results will be stored as global_step = 0.
|
||||
:type initial_checkpoint_path: Optional[str]
|
||||
:param prompt_type: Prompt format used in evaluation.
|
||||
:type prompt_type: str
|
||||
"""
|
||||
|
||||
data_names: str = "aime24"
|
||||
max_gen_tokens: int = 32768
|
||||
max_concurrent_jobs: int = 3
|
||||
eval_job_image: Optional[str] = None
|
||||
initial_checkpoint_path: Optional[str] = None
|
||||
prompt_type: str = "deepscaler"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class WandBConfig:
|
||||
mode: str = "disabled"
|
||||
entity: Optional[str] = None
|
||||
project: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
job_type: Optional[str] = None
|
||||
group: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
config: Optional[Dict] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TensorBoardConfig:
|
||||
path: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExperimentConfig:
|
||||
exp_ctrl: ExperimentSaveEvalControl
|
||||
|
|
|
@ -1,98 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptAnswerDatasetConfig:
|
||||
"""Configuration for datasets used in Supervised Fine-Tuning (SFT).
|
||||
|
||||
The raw data must be in a JSON or JSONL file format, where each entry is a dictionary
|
||||
with the keys `prompt` and `answer`. Both `prompt` and `answer` must be strings.
|
||||
|
||||
:param train_path: Path to the training dataset.
|
||||
:type train_path: str
|
||||
:param valid_path: Path to the validation dataset.
|
||||
:type valid_path: str
|
||||
:param max_seqlen: Maximum sequence length (prompt + answer). Sequences longer than
|
||||
this will be truncated.
|
||||
:type max_seqlen: int
|
||||
:param train_bs_n_seqs: Number of sequences in each batch during training.
|
||||
:type train_bs_n_seqs: int
|
||||
:param valid_bs_n_seqs: Number of sequences in each batch during validation.
|
||||
:type valid_bs_n_seqs: int
|
||||
:param fill_to_max_length: Whether to fill sequences to the maximum length. If True,
|
||||
prompts will be left-filled with non-pad tokens. Only used for testing.
|
||||
:type fill_to_max_length: bool
|
||||
"""
|
||||
|
||||
train_path: str = ""
|
||||
valid_path: str = ""
|
||||
max_seqlen: int = 1024
|
||||
train_bs_n_seqs: int = 256
|
||||
valid_bs_n_seqs: int = 256
|
||||
fill_to_max_length: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PairedComparisonDatasetConfig:
|
||||
"""Configuration for datasets used in paired-comparison reward modeling,
|
||||
DPO, and SimPO.
|
||||
|
||||
The raw data must be in a JSON or JSONL file format, where each entry is a dictionary
|
||||
with the keys `prompt`, `pos_answers`, and `neg_answers`. `prompt` is a string, while
|
||||
`pos_answers` and `neg_answers` are lists of strings. The lists must have the same length.
|
||||
|
||||
The raw dataset may contain multiple answer pairs for each prompt. In each epoch, we will
|
||||
randomly sample `max_pairs_per_prompt` answer pairs for each prompt, so the maximum batch
|
||||
size (in terms of the number of sequences) per step is `train_bs_n_seqs` multiplied by
|
||||
`max_pairs_per_prompt`.
|
||||
|
||||
:param train_path: Path to the training dataset.
|
||||
:type train_path: str
|
||||
:param valid_path: Path to the evaluation dataset.
|
||||
:type valid_path: str
|
||||
:param max_pairs_per_prompt: Maximum number of answer pairs per prompt.
|
||||
:type max_pairs_per_prompt: int
|
||||
:param max_seqlen: Maximum sequence length (prompt + answers). Sequences longer than
|
||||
this will be truncated.
|
||||
:type max_seqlen: int
|
||||
:param train_bs_n_seqs: Number of sequences in each batch during training.
|
||||
:type train_bs_n_seqs: int
|
||||
:param valid_bs_n_seqs: Number of sequences in each batch during validation.
|
||||
:type valid_bs_n_seqs: int
|
||||
"""
|
||||
|
||||
train_path: str = ""
|
||||
valid_path: str = ""
|
||||
max_pairs_per_prompt: int = 2
|
||||
max_seqlen: int = 1024
|
||||
train_bs_n_seqs: int = 256
|
||||
valid_bs_n_seqs: int = 256
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptOnlyDatasetConfig:
|
||||
"""Configuration for datasets used in PPO RLHF.
|
||||
|
||||
The raw data must be in a JSON or JSONL file format, where each entry is a dictionary
|
||||
with a single key called `prompt`, which is a string.
|
||||
|
||||
:param path: Path to the dataset.
|
||||
:type path: str
|
||||
:param max_prompt_len: Maximum length of the prompt. Prompts longer than this will
|
||||
be truncated.
|
||||
:type max_prompt_len: int
|
||||
:param train_bs_n_seqs: Number of prompts in each batch.
|
||||
:type train_bs_n_seqs: int
|
||||
:param fill_to_max_length: Whether to fill prompts to the maximum length. If True,
|
||||
prompts will be left-filled with non-pad tokens. Only used for testing.
|
||||
:type fill_to_max_length: bool
|
||||
"""
|
||||
|
||||
path: str = ""
|
||||
max_prompt_len: int = 256
|
||||
train_bs_n_seqs: int = 256
|
||||
fill_to_max_length: bool = False
|
|
@ -9,8 +9,8 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import numpy as np
|
||||
|
||||
from realhf.api.core.dfg import MFCDef, MicroBatchSpec
|
||||
from realhf.api.quickstart.model import ParallelismConfig
|
||||
from realhf.api.cli_args import ParallelismConfig
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base.slurm_utils import (
|
||||
are_ones_contiguous,
|
||||
|
@ -344,29 +344,3 @@ class RPCAllocation:
|
|||
device_mesh=DeviceMesh.from_dict(d["device_mesh"]),
|
||||
parallel=ParallelismConfig(**d["parallel"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MFCConfig:
|
||||
"""Configuration for a single MFC.
|
||||
|
||||
:param mb_spec: Specifies how to split micro-batches when
|
||||
executing this MFC. Refer to MicroBatchSpec for details.
|
||||
:type mb_spec: MicroBatchSpec
|
||||
:param parallel: Configuration for the parallelism strategy. This is
|
||||
used only for manual allocation.
|
||||
:type parallel: ParallelismConfig
|
||||
:param device_mesh: String representation of the device mesh. If it
|
||||
consists of multiple nodes, it should be formatted as a SLURM
|
||||
nodelist, e.g., node[01-02] or node01,node02. If it represents a
|
||||
slice on a single node, it should occupy 1, 2, 4, or 8
|
||||
contiguous GPUs on the node. In this case, the string
|
||||
representation is similar to an MPI hostfile, e.g.,
|
||||
"node01:0,1,2,3" for the first 4 GPUs on node01. This is used
|
||||
only for manual allocation.
|
||||
:type device_mesh: Optional[str]
|
||||
"""
|
||||
|
||||
mb_spec: MicroBatchSpec = dataclasses.field(default_factory=MicroBatchSpec)
|
||||
parallel: ParallelismConfig = dataclasses.field(default_factory=ParallelismConfig)
|
||||
device_mesh: Optional[str] = None
|
||||
|
|
|
@ -18,6 +18,7 @@ from hydra.core.config_store import ConfigStore
|
|||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
import realhf.api.core.system_api as system_api
|
||||
from realhf.api.cli_args import print_runtime_helper
|
||||
from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT, QUICKSTART_EXPR_CACHE_PATH
|
||||
from realhf.base.ray_utils import check_ray_availability
|
||||
from realhf.base.slurm_utils import check_slurm_availability
|
||||
|
@ -68,6 +69,8 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
|
|||
|
||||
logger = logging.getLogger("quickstart", "colored")
|
||||
|
||||
print_runtime_helper(OmegaConf.to_object(args))
|
||||
|
||||
exp_name = args.experiment_name
|
||||
if args.trial_name == MISSING:
|
||||
args.trial_name = trial_name = (
|
||||
|
|
|
@ -1,336 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import dataclasses
|
||||
from logging import disable
|
||||
from typing import *
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.core.config import (
|
||||
ModelAbstraction,
|
||||
ModelFamily,
|
||||
ModelWrapperAbstraction,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("Quickstart Model Config")
|
||||
|
||||
|
||||
@dataclasses.dataclass(unsafe_hash=True)
|
||||
class ParallelismConfig:
|
||||
"""Configuration for 3D parallelism.
|
||||
|
||||
:param model_parallel_size: Size of tensor-model parallelism.
|
||||
:type model_parallel_size: int
|
||||
:param pipeline_parallel_size: Number of pipeline parallelism
|
||||
stages.
|
||||
:type pipeline_parallel_size: int
|
||||
:param data_parallel_size: Data parallelism size for ZeRO
|
||||
optimization.
|
||||
:type data_parallel_size: int
|
||||
:param use_sequence_parallel: Whether to use sequence parallelism in
|
||||
Megatron in combination with tensor-model parallelism.
|
||||
:type use_sequence_parallel: bool
|
||||
"""
|
||||
|
||||
model_parallel_size: int = 1
|
||||
pipeline_parallel_size: int = 1
|
||||
data_parallel_size: int = 1
|
||||
use_sequence_parallel: bool = False
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"Parallel(mp={self.model_parallel_size},"
|
||||
f"pp={self.pipeline_parallel_size},"
|
||||
f"dp={self.data_parallel_size})"
|
||||
)
|
||||
|
||||
|
||||
def parallelism_eq(this, other):
|
||||
# NOTE: We write this function because
|
||||
# 1) we don't want to compare sequence_parallelism (it's irrelevant to parameter reallocation)
|
||||
# 2) implementing this function as a method of ParallelismConfig would cause an OmegaConf bug
|
||||
return (
|
||||
(this.model_parallel_size == other.model_parallel_size)
|
||||
and (this.pipeline_parallel_size == other.pipeline_parallel_size)
|
||||
and (this.data_parallel_size == other.data_parallel_size)
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class OptimizerConfig:
|
||||
"""Configuration for the optimizer.
|
||||
|
||||
For models that will not be trained, the optimizer type should be
|
||||
set to "empty".
|
||||
|
||||
:param type: Type of optimizer. Currently, only "adam" and "empty"
|
||||
optimizers are supported.
|
||||
:type type: str
|
||||
:param lr: Learning rate.
|
||||
:type lr: float
|
||||
:param weight_decay: Weight decay.
|
||||
:type weight_decay: float
|
||||
:param beta1: Adam beta1 parameter.
|
||||
:type beta1: float
|
||||
:param beta2: Adam beta2 parameter.
|
||||
:type beta2: float
|
||||
:param eps: Adam epsilon parameter in the denominator.
|
||||
:type eps: float
|
||||
:param min_lr_ratio: Minimum learning rate ratio after learning rate
|
||||
annealing. Should be in the interval [0.0, 1.0].
|
||||
:type min_lr_ratio: float
|
||||
:param lr_scheduler_type: Type of learning rate scheduler. One of
|
||||
"linear", "cosine", or "constant".
|
||||
:type lr_scheduler_type: str
|
||||
:param warmup_steps_proportion: Proportion of total training steps
|
||||
allocated for warming up. Should be in the interval [0.0, 1.0].
|
||||
:type warmup_steps_proportion: float
|
||||
"""
|
||||
|
||||
type: str = dataclasses.field(
|
||||
metadata={"choices": ["adam", "empty"]},
|
||||
default="adam",
|
||||
)
|
||||
lr: float = 1e-5
|
||||
weight_decay: float = 0.05
|
||||
beta1: float = 0.9
|
||||
beta2: float = 0.95
|
||||
eps: float = 1e-5
|
||||
min_lr_ratio: float = 0.0
|
||||
lr_scheduler_type: str = dataclasses.field(
|
||||
metadata={"choices": ["linear", "cosine", "constant"]},
|
||||
default="cosine",
|
||||
)
|
||||
warmup_steps_proportion: float = 0.02
|
||||
offload: bool = False
|
||||
initial_loss_scale: float = 2**32
|
||||
min_loss_scale: float = 1.0
|
||||
loss_scale_window: float = 5
|
||||
hysteresis: int = 2
|
||||
gradient_clipping: float = 1.0
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class vLLMConfig:
|
||||
max_num_seqs: int = 256
|
||||
kv_cache_type: str = "auto"
|
||||
num_scheduler_steps: int = 1
|
||||
multi_step_stream_outputs: bool = True
|
||||
block_size: int = 16
|
||||
swap_space: int = 4
|
||||
cpu_offload_gb: float = 0
|
||||
max_seq_len_to_capture: int = 32768
|
||||
|
||||
disable_sliding_window: bool = True
|
||||
|
||||
# NOTE: Defaults max_model_len to 32k because a larger value
|
||||
# will enable chunked prefill in vLLM, which will cause
|
||||
# evaluation performance degradation.
|
||||
max_model_len: Optional[int] = 32768
|
||||
enable_chunked_prefill: bool = False
|
||||
|
||||
# NOTE: Setting enable_prefix_caching to False
|
||||
# because it will reuse the block after
|
||||
# model weights are updated. Using v0.7.2 reset_prefix_cache
|
||||
# will fix this issue.
|
||||
enable_prefix_caching: bool = False
|
||||
|
||||
gpu_memory_utilization: float = 0.9
|
||||
|
||||
enforce_eager: bool = False
|
||||
hybrid_train: bool = False
|
||||
additional_engine_args: Dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SGLangConfig:
|
||||
disable_cuda_graph: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
disable_cuda_graph_padding: bool = False
|
||||
enable_nccl_nvls: bool = False
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
disable_mla: bool = False
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
enable_ep_moe: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
cuda_graph_max_bs: Optional[int] = None
|
||||
cuda_graph_bs: Optional[List[int]] = None
|
||||
torchao_config: str = ""
|
||||
enable_nan_detection: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
triton_attention_num_kv_splits: int = 8
|
||||
num_continuous_decode_steps: int = 1
|
||||
enable_memory_saver: bool = False
|
||||
allow_auto_truncate: bool = False
|
||||
# NOTE: to avoid the illegal memory access error
|
||||
attention_backend: Optional[str] = "triton"
|
||||
sampling_backend: Optional[str] = None
|
||||
context_length: Optional[int] = None
|
||||
mem_fraction_static: Optional[float] = None
|
||||
max_running_requests: Optional[int] = None
|
||||
max_total_tokens: Optional[int] = None
|
||||
chunked_prefill_size: Optional[int] = None
|
||||
max_prefill_tokens: int = 16384
|
||||
schedule_policy: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
hybrid_train: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DistributedDataParallelConfig:
    """Settings for Megatron's DistributedDataParallel.

    Some default options have been overwritten (presumably relative to
    Megatron's own defaults -- the upstream values are not visible here).
    """

    grad_reduce_in_fp32: bool = False
    # Enabled by default.
    overlap_grad_reduce: bool = True
    overlap_param_gather: bool = False
    align_param_gather: bool = False
    # Enabled by default.
    use_distributed_optimizer: bool = True
    check_for_nan_in_grad: bool = False
    # ``None`` leaves the bucket size unset in this config.
    bucket_size: Optional[int] = None
    average_in_collective: bool = False
    fp8_param_gather: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MegatronConfig:
    """Options for the Megatron training backend.

    With Megatron's DistributedOptimizer, parameters and gradients are not
    split across DP ranks, but optimizer states are. In other words,
    Megatron only supports ZeRO-1.

    Megatron DDP splits the whole flattened parameter into buckets. Buckets
    do not respect parameter boundaries and are dispatched to different DP
    ranks. The optimizer on a given DP rank manages only its own bucket,
    while parameters and gradients are held by all ranks and are not split
    any further (which is why only optimizer states are partitioned).
    During backward, bucket gradients are scatter-reduced (controlled by the
    `use_distributed_optimizer` option of Megatron DDP; otherwise an
    all-reduce is issued) and parameters are then updated locally. At that
    point the parameters are not synced across DP ranks, so the
    DistributedOptimizer calls all-gather on the parameters afterwards.

    Because Megatron allocates static tensors for scatter-reducing parameter
    gradients, it does not reduce memory usage the way DeepSpeed ZeRO-2
    does. More specifically, with dynamic allocation gradient memory can be
    allocated layer by layer: once backward finishes at layer N, gradients
    are scatter-reduced and the memory is released after scattering. Given
    DP size K, layer number L, and parameter size P for each layer, dynamic
    allocation requires P * (1 + L/K) memory for gradients, whereas Megatron
    requires P * L. Memory is not freed after scattering in Megatron.

    'use_distributed_optimizer' enables bucketing and scatter-reducing
    gradients. When set to False, optimizer states are not partitioned.

    'overlap_grad_reduce' issues all-reduce/scatter-reduce on the fly during
    backward as soon as a gradient is ready; it should usually be enabled.

    'overlap_param_gather' overlaps the parameter all-gather with the next
    forward pass. It creates a forward hook that waits for the previous
    parameter all-gather after the optimizer step. While this sounds good,
    it can be problematic with parameter reallocation, because the
    reallocated parameters do not have the hook. It can be enabled for SFT,
    but should be disabled for PPO.

    As a final note, Megatron is in an awkward place for PPO with
    param-realloc. First, it does not minimize the memory usage of gradients
    (i.e., ZeRO-2). Second, for functional correctness we cannot enable
    `overlap_param_gather`, so a parameter update is a scatter-reduce of
    gradients plus an all-gather of parameters instead of a single
    all-reduce (running all-reduce requires setting
    `use_distributed_optimizer` to False, but that will not partition
    optimizer states!), so it is not that efficient either. We use Megatron
    because it is the only backend that we can make functionally correct.
    The DeepSpeed code is too hard to read and modify.
    """

    ddp: DistributedDataParallelConfig = dataclasses.field(
        default_factory=DistributedDataParallelConfig
    )

    # MegatronOptimizerConfig is deliberately not used here: OmegaConf does
    # not recognize the annotation "torch.dtype", so dtypes are kept as
    # plain strings below.
    overlap_param_gather_with_optimizer_step: bool = False

    use_precision_aware_optimizer: bool = False
    main_grads_dtype: str = "float32"
    main_params_dtype: str = "float32"
    exp_avg_dtype: str = "float32"
    exp_avg_sq_dtype: str = "float32"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ModelTrainEvalConfig:
    """Runtime configuration for models (or LLMs) in ReaL.

    ReaL uses its own model class rather than HuggingFace's. Highlights of
    the customized model:

    1. Support for 3D parallelism and sequence parallelism.

    2. Support for flash attention during both training and generation.

    3. Input sequences are packed into a single 1D tensor to save GPU memory
       and improve efficiency.

    As a consequence, every HuggingFace model of interest has to be manually
    converted to this customized model. Implemented models can be found in
    the ``realhf/api/from_hf/`` directory.

    :param type: Model family type, e.g., llama, qwen2, etc.
    :type type: ModelFamily
    :param backend: Backend for training. Currently, only "megatron" and
        "deepspeed" are supported. Use "deepspeed" for offloading parameters
        or optimizer states, and "megatron" for parameter reallocation.
    :type backend: str
    :param path: Path of the HuggingFace checkpoint.
    :type path: str
    :param gradient_checkpointing: Whether to use gradient checkpointing to
        save memory.
    :type gradient_checkpointing: bool
    :param bf16: Whether to use bf16 precision. Otherwise use fp16.
    :type bf16: bool
    :param optimizer: Configuration for the optimizer.
    :type optimizer: Optional[OptimizerConfig]
    :param megatron: Megatron-specific options.
    :type megatron: MegatronConfig
    :param vllm: vLLM-specific options.
    :type vllm: vLLMConfig
    :param sglang: SGLang-specific options.
    :type sglang: SGLangConfig
    :param init_from_scratch: Presumably builds the model without loading
        weights from ``path`` -- TODO confirm against the model loader.
    :type init_from_scratch: bool
    :param init_critic_from_actor: Whether to initialize a critic/reward
        model from a saved LM checkpoint.
    :type init_critic_from_actor: bool
    """

    type: ModelFamily = dataclasses.field(default=ModelFamily("llama", 7, False))
    backend: str = dataclasses.field(
        default="megatron", metadata={"choices": ["megatron", "deepspeed"]}
    )
    path: str = ""
    gradient_checkpointing: bool = True
    bf16: bool = False
    optimizer: Optional[OptimizerConfig] = dataclasses.field(
        default_factory=OptimizerConfig
    )

    # Backend-specific sub-configurations, each built from its own defaults.
    megatron: MegatronConfig = dataclasses.field(default_factory=MegatronConfig)
    vllm: vLLMConfig = dataclasses.field(default_factory=vLLMConfig)
    sglang: SGLangConfig = dataclasses.field(default_factory=SGLangConfig)

    init_from_scratch: bool = False
    init_critic_from_actor: bool = False
|
||||
|
||||
|
||||
def get_real_model_config(
    model_path: str,
    hf_model_family: str,
    is_critic: bool,
    init_from_scratch: bool,
    init_critic_from_actor: bool,
    dtype: Optional[str] = None,
) -> ModelAbstraction:
    """Make the ``ModelAbstraction`` used to build a ``"real_model"``.

    All arguments are forwarded unchanged as the abstraction's ``args``
    dictionary; nothing is validated or loaded here.

    :param model_path: Stored under the ``model_path`` key.
    :param hf_model_family: Stored under the ``hf_model_family`` key.
    :param is_critic: Stored under the ``is_critic`` key.
    :param init_from_scratch: Stored under the ``init_from_scratch`` key.
    :param init_critic_from_actor: Stored under the
        ``init_critic_from_actor`` key.
    :param dtype: Stored under the ``dtype`` key; defaults to ``None``.
    :return: ``ModelAbstraction("real_model", args=...)``.
    """
    # Key order matches the original keyword order.
    build_args = {
        "model_path": model_path,
        "is_critic": is_critic,
        "init_critic_from_actor": init_critic_from_actor,
        "dtype": dtype,
        "hf_model_family": hf_model_family,
        "init_from_scratch": init_from_scratch,
    }
    return ModelAbstraction("real_model", args=build_args)
|
|
@ -5,9 +5,9 @@
|
|||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
from realhf.api.cli_args import ParallelismConfig
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.quickstart.device_mesh import DeviceMesh
|
||||
from realhf.api.quickstart.model import ParallelismConfig
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
|
@ -311,7 +311,7 @@ def main_find_config(args):
|
|||
|
||||
|
||||
def main_profile_layers(args):
|
||||
from realhf.api.core.model_api import ModelFamily
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
|
||||
_main_profile_layers(
|
||||
ModelFamily(args.model_class, args.model_size, args.is_critic),
|
||||
|
@ -320,7 +320,7 @@ def main_profile_layers(args):
|
|||
|
||||
|
||||
def _main_profile_layers(model_family, model_path):
|
||||
from realhf.api.core.model_api import ModelFamily
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
from realhf.base.slurm_utils import check_slurm_availability
|
||||
from realhf.base.testing import clear_name_resolve
|
||||
|
||||
|
|
|
@ -12,8 +12,10 @@ import sys
|
|||
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from rich.panel import Panel
|
||||
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_FN
|
||||
from realhf.api.cli_args import console, highlighter, print_config_help
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES, QUICKSTART_FN
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base.importing import import_module
|
||||
from realhf.base.prologue import (
|
||||
|
@ -32,29 +34,79 @@ import_module(
|
|||
import realhf.experiments.benchmark.profile_exp
|
||||
|
||||
|
||||
def print_help(exp_type):
    """Print the rich-formatted help screen for one quickstart experiment.

    The output consists of a header panel, the configuration options of the
    experiment's config object (rendered by ``print_config_help``), a usage
    line, an example with overrides, and a short footer.

    :param exp_type: Key into ``QUICKSTART_CONFIG_CLASSES``. An unregistered
        key is not handled here; the lookup error propagates to the caller.
    """

    def _section(heading):
        # Every section title shares the same "[title]" markup.
        console.print(f"\n[title]{heading}[/title]")

    # Build the config object first, so a bad ``exp_type`` fails before
    # anything has been written to the console.
    config = QUICKSTART_CONFIG_CLASSES[exp_type]()

    # Main help panel
    banner = Panel.fit(
        f"[header]Configuration Help for {exp_type}[/header]",
        border_style="border",
    )
    console.print(banner)

    # Configuration options section
    _section("CONFIGURATION OPTIONS")
    print_config_help(config)

    # Usage section
    _section("USAGE")
    usage_code = (
        f"python -m realhf.apps.quickstart {exp_type} "
        "--config ./your/config.yaml [OPTIONS]"
    )
    console.print(highlighter(usage_code))

    # Examples section
    _section("EXAMPLE OVERRIDES")
    example_code = (
        f"python -m realhf.apps.quickstart {exp_type} "
        "--config ./your/config.yaml "
        "dataset.path=/my/dataset.jsonl actor.optimizer.lr=2e-5"
    )
    console.print(highlighter(example_code))

    # Footer
    console.print("\n[dim]Use [bold]--help[/bold] to show this message again[/dim]")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="ReaL Quickstart")
|
||||
# Create parser with add_help=False to disable automatic --help
|
||||
parser = argparse.ArgumentParser(prog="ReaL Quickstart", add_help=False)
|
||||
|
||||
# Add custom help argument that won't conflict
|
||||
parser.add_argument(
|
||||
"--help", action="store_true", help="Show this help message and exit"
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
|
||||
subparsers.required = True
|
||||
|
||||
for k, v in QUICKSTART_FN.items():
|
||||
subparser = subparsers.add_parser(k)
|
||||
# Create subparser with add_help=False
|
||||
subparser = subparsers.add_parser(k, add_help=False)
|
||||
|
||||
# Add custom help to subparser
|
||||
subparser.add_argument(
|
||||
"--show-args",
|
||||
action="store_true",
|
||||
help="Show all legal CLI arguments for this experiment.",
|
||||
"--help", action="store_true", help="Show help for this command"
|
||||
)
|
||||
subparser.add_argument(
|
||||
PROLOGUE_FLAG_NAME,
|
||||
type=str,
|
||||
help="Set config (*.yaml) for this experiment.",
|
||||
)
|
||||
|
||||
subparser.set_defaults(func=v)
|
||||
|
||||
# Parse known args first to check for help
|
||||
args = vars(parser.parse_known_args()[0])
|
||||
if args["show_args"]:
|
||||
sys.argv = [sys.argv[0], "--help"]
|
||||
QUICKSTART_FN[args["cmd"]]()
|
||||
|
||||
# Handle help at both main and subcommand levels
|
||||
if args["help"]:
|
||||
if args["cmd"]:
|
||||
# Subcommand help
|
||||
print_help(args["cmd"])
|
||||
else:
|
||||
# Main help
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Continue with normal execution
|
||||
if not args["cmd"]:
|
||||
parser.print_help()
|
||||
experiment_name = ""
|
||||
trial_name = ""
|
||||
|
||||
|
|
|
@ -11,6 +11,12 @@ from typing import *
|
|||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from realhf.api.cli_args import (
|
||||
MFCConfig,
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
PromptOnlyDatasetConfig,
|
||||
)
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelInterfaceAbstraction,
|
||||
|
@ -18,10 +24,7 @@ from realhf.api.core.config import (
|
|||
)
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.core.system_api import ExperimentConfig
|
||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.topology import decompose_to_three_factors
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
|
|
|
@ -2,13 +2,10 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from typing import List
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
from realhf.api.cli_args import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
|
||||
from realhf.api.quickstart.device_mesh import RPCAllocation
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
|
||||
from realhf.base import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -13,6 +13,12 @@ import transformers
|
|||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import (
|
||||
BaseExperimentConfig,
|
||||
MFCConfig,
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
)
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelAbstraction,
|
||||
|
@ -24,28 +30,18 @@ from realhf.api.core.config import (
|
|||
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
|
||||
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
|
||||
from realhf.api.core.system_api import (
|
||||
AutomaticEvaluator,
|
||||
Experiment,
|
||||
ExperimentConfig,
|
||||
ExperimentSaveEvalControl,
|
||||
ExperimentScheduling,
|
||||
ModelWorker,
|
||||
Scheduling,
|
||||
TasksGroup,
|
||||
TensorBoardConfig,
|
||||
WandBConfig,
|
||||
)
|
||||
from realhf.api.quickstart.device_mesh import (
|
||||
DeviceMesh,
|
||||
MFCConfig,
|
||||
RPCAllocation,
|
||||
make_device_mesh_from_name,
|
||||
)
|
||||
from realhf.api.quickstart.model import (
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
get_real_model_config,
|
||||
)
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.experiments.common.check import (
|
||||
check_is_realhf_native_model_interface,
|
||||
|
@ -58,6 +54,7 @@ from realhf.experiments.common.check import (
|
|||
from realhf.experiments.common.utils import (
|
||||
AllocationMode,
|
||||
asdict,
|
||||
get_real_model_config,
|
||||
get_topo,
|
||||
make_inf_backend_config,
|
||||
make_train_backend_config,
|
||||
|
@ -75,165 +72,7 @@ GEN_HYBRID_TRAIN_DECOUPLE_ALLOC_WARN = False
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommonExperimentConfig(Experiment):
|
||||
"""Configuration for quickstart experiments.
|
||||
|
||||
All members can be modified via the command line. For example,
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
$ python3 -m realhf.apps.quickstart sft trial_name=my_trial seed=42 exp_ctrl.save_freq_steps=10 ...
|
||||
|
||||
This command changes the ``trial_name``, ``seed``, and the ``save_freq_steps`` attribute
|
||||
of the ``exp_ctrl`` attribute in this class.
|
||||
|
||||
``recover_mode`` can be one of the following\:
|
||||
|
||||
- ``auto``\: Automatically recover the last failed run. If the checkpoint does not exist, run from scratch with fault tolerance.
|
||||
|
||||
- ``fault``\: Run from scratch with fault tolerance.
|
||||
|
||||
- ``resume``\: Resume from saved recovery states and then run it once without fault tolerance.
|
||||
|
||||
- ``disabled``\: Do nothing but raise an error if one occurs.
|
||||
|
||||
If you are not familiar with ReaL's recovery mechanism, set this to ``disabled``.
|
||||
Normal checkpointing is usually sufficient in most cases.
|
||||
|
||||
``allocation_mode`` can be one of the following\:
|
||||
|
||||
- ``manual``\: Manually allocate resources using the specified command-line options.
|
||||
|
||||
- ``search``\: Allocate resources and configure parallel strategies using the search engine.
|
||||
|
||||
- ``heuristic``\: Allocate resources and configure parallel strategies using heuristic strategies obtained from a search.
|
||||
|
||||
- A regex pattern like ``d${DP}p${PP}m${TP}``\: Identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.
|
||||
|
||||
- A regex pattern like ``{vllm|sglang}.{IdentPara}+{IdentPara}``\: Decoupled generation and training allocations with corresponding identical parallelization strategies.
|
||||
|
||||
- Key-value pairs with MFC names and their parallel strategies in the whole cluster, e.g., ``actor_gen:d4m2p1,*:d2p2m2`` specifies a ``d4m2p1`` strategy for actor generation and ``d2p2m2`` for other MFCs in a world of 8 GPUs.
|
||||
|
||||
:param experiment_name: The name of the experiment.
|
||||
An arbitrary string without "_" and "/", e.g., ``ultra-chat-llama``.
|
||||
This parameter is required.
|
||||
:type experiment_name: str
|
||||
:param trial_name: The name of the trial.
|
||||
An arbitrary string without "-" and "/", e.g., ``lr1e-3wd0.05``.
|
||||
This parameter is required.
|
||||
:type trial_name: str
|
||||
:param mode: The experiment launching mode. Supported values are "local", "ray", or "slurm".
|
||||
"ray" mode requires launching the Ray cluster via CLI.
|
||||
"slurm" mode requires the Pyxis plugin with the Enroot container enabled.
|
||||
"local" mode implies ``n_nodes=1``.
|
||||
:type mode: str
|
||||
:param debug: Whether to run in debug mode.
|
||||
Setting this to `False` will disable all assertions, which will be faster but less safe.
|
||||
:type debug: bool
|
||||
:param partition: The SLURM partition for running the experiment.
|
||||
:type partition: str
|
||||
:param wandb: The WandB initialization config.
|
||||
See https://docs.wandb.ai/ref/python/init/ for details.
|
||||
:type wandb: WandBConfig
|
||||
:param tensorboard: The tensorboard initialization config.
|
||||
Only the field of `path` is needed to specify the directory of saving the tensorboard events.
|
||||
:type tensorboard: TensorBoardConfig
|
||||
:param image_name: The name of the Docker image used by the controller.
|
||||
This parameter is only used in SLURM mode.
|
||||
:type image_name: str or None
|
||||
:param recover_mode: The recovery mode. See above for details.
|
||||
:type recover_mode: str
|
||||
:param recover_retries: The number of retries for recovery.
|
||||
Effective only when ``recover_mode`` is set to "auto" or "fault".
|
||||
:type recover_retries: int
|
||||
:param recover_after: The time interval (seconds) for recovery.
|
||||
Effective only when ``recover_mode`` is set to "auto" or "fault".
|
||||
:type recover_after: int
|
||||
:param ignore_worker_error: Whether to ignore errors raised by
|
||||
workers during runtime. Only set this to `True` if you are certain that the error can be ignored.
|
||||
Effective only when ``recover_mode`` is set to "disabled".
|
||||
:type ignore_worker_error: bool
|
||||
:param allocation_mode: The mode for GPU parallel strategy allocation. See above for details.
|
||||
:type allocation_mode: str
|
||||
:param allocation_use_cache: Whether to use cache in allocation search.
|
||||
Effective only when ``allocation_mode`` is set to "search" and a cache is available in the log directory of the current experiment
|
||||
name and trial.
|
||||
:type allocation_use_cache: bool
|
||||
:param n_nodes: The number of nodes to run the experiment.
|
||||
:type n_nodes: int
|
||||
:param n_gpus_per_node: The number of GPUs per node.
|
||||
Thus, the total number of GPUs will be ``n_nodes * n_gpus_per_node``.
|
||||
ReaL supports a world size of 1, 2, 4, 8, ... within a single node,
|
||||
or multiple nodes with the same number of GPUs.
|
||||
:type n_gpus_per_node: int
|
||||
:param nodelist: Nodelist for the distributed setting in SLURM nodelist format.
|
||||
Required for the ``manual`` allocation mode.
|
||||
For multiple GPUs on a single node, it should be formatted as "NODE01:0,1,2,3",
|
||||
indicating the use of the first 4 GPUs on ``NODE01``.
|
||||
For multiple complete nodes, it should be formatted as "NODE[01-02,03,07],COM08",
|
||||
indicating the use of all GPUs on these nodes: [NODE01, NODE02, NODE03, NODE07, COM08].
|
||||
:type nodelist: str or None
|
||||
:param seed: The random seed.
|
||||
:type seed: int
|
||||
:param cache_clear_freq: The cache of data transfer will be cleared after each ``cache_clear_freq`` steps.
|
||||
If None, will not clear the cache. Set to a small number, e.g., 1, if OOM or CUDA OOM occurs.
|
||||
:type cache_clear_freq: int or None
|
||||
:param exp_ctrl: The control for saving and evaluating the experiment.
|
||||
:type exp_ctrl: ExperimentSaveEvalControl
|
||||
:param torch_cache_mysophobia: Whether to clean torch-allocated cache blocks with
|
||||
torch.cuda.empty_cache() before each RPC in model worker
|
||||
If enabled, there will be a ~0.1s overhead per RPC.
|
||||
:type torch_cache_mysophobia: bool
|
||||
:param auto_eval: Whether to run automatic evaluation during training. When enabled, an evaluation
|
||||
job is submitted whenever a checkpoint is saved, and the result will be logged on disk and
|
||||
on wandb if wandb is active.
|
||||
:type auto_eval: bool
|
||||
:param auto_eval_config: Configuration for automatic evaluation.
|
||||
:type auto_eval_config: AutomaticEvaluator
|
||||
:param cpus_per_master_worker: The number of CPUs for each master worker.
|
||||
:param mem_per_master_worker: The size of memory for each master worker, measured in MB.
|
||||
:param cpus_per_model_worker: The number of CPUs for each model worker.
|
||||
:param mem_per_model_worker: The size of memory for each model worker, measured in MB.
|
||||
"""
|
||||
|
||||
experiment_name: str = MISSING
|
||||
trial_name: str = MISSING
|
||||
mode: str = dataclasses.field(
|
||||
metadata={"choices": ["slurm", "local", "ray"]}, default="slurm"
|
||||
)
|
||||
debug: bool = True
|
||||
partition: str = "dev"
|
||||
schedule_strategy: str = "empty_first"
|
||||
wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig)
|
||||
tensorboard: TensorBoardConfig = dataclasses.field(
|
||||
default_factory=TensorBoardConfig
|
||||
)
|
||||
image_name: Optional[str] = None
|
||||
recover_mode: str = "disabled"
|
||||
recover_retries: int = 1
|
||||
recover_after: int = 10
|
||||
ignore_worker_error: bool = False
|
||||
allocation_mode: str = ""
|
||||
allocation_use_cache: bool = False
|
||||
n_nodes: int = 1
|
||||
n_gpus_per_node: int = cluster_spec.n_gpus_per_node
|
||||
nodelist: Optional[str] = None
|
||||
seed: int = 1
|
||||
cache_clear_freq: Optional[int] = 10
|
||||
exp_ctrl: ExperimentSaveEvalControl = dataclasses.field(
|
||||
default_factory=ExperimentSaveEvalControl
|
||||
)
|
||||
torch_cache_mysophobia: bool = True
|
||||
# Options for automatic evaluation
|
||||
auto_eval: bool = False
|
||||
auto_eval_config: AutomaticEvaluator = dataclasses.field(
|
||||
default_factory=AutomaticEvaluator
|
||||
)
|
||||
# Options for worker resources
|
||||
cpus_per_master_worker: int = 4
|
||||
mem_per_master_worker: int = 20000
|
||||
cpus_per_model_worker: int = 4
|
||||
mem_per_model_worker: int = 90000
|
||||
class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
||||
|
||||
@property
|
||||
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
||||
|
|
|
@ -2,40 +2,25 @@
|
|||
|
||||
import dataclasses
|
||||
|
||||
from realhf.api.cli_args import (
|
||||
MFCConfig,
|
||||
ModelTrainEvalConfig,
|
||||
NullPPOExperimentOptions,
|
||||
PromptOnlyDatasetConfig,
|
||||
SFTExperimentOptions,
|
||||
)
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelInterfaceAbstraction,
|
||||
ModelInterfaceType,
|
||||
)
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.quickstart.dataset import (
|
||||
PromptAnswerDatasetConfig,
|
||||
PromptOnlyDatasetConfig,
|
||||
)
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NullSFTConfig(CommonExperimentConfig):
|
||||
"""Configuration for a null SFT experiment. Used for testing purposes.
|
||||
|
||||
:param allocation: Configuration for device allocation and
|
||||
parallelism.
|
||||
:type allocation: MFCConfig
|
||||
:param dataset: Configuration for the dataset.
|
||||
:type dataset: PromptAnswerDatasetConfig
|
||||
"""
|
||||
|
||||
model: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
allocation: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
dataset: PromptAnswerDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptAnswerDatasetConfig
|
||||
)
|
||||
class NullSFTConfig(CommonExperimentConfig, SFTExperimentOptions):
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
|
@ -52,7 +37,7 @@ class NullSFTConfig(CommonExperimentConfig):
|
|||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
interface_impl=ModelInterfaceAbstraction("null"),
|
||||
model_name="default",
|
||||
input_keys=["packed_input_ids", "prompt_mask"],
|
||||
input_keys=("packed_input_ids", "prompt_mask"),
|
||||
log_return_value=True,
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
|
@ -71,7 +56,6 @@ class NullSFTConfig(CommonExperimentConfig):
|
|||
args=dict(
|
||||
max_length=self.dataset.max_seqlen,
|
||||
dataset_path=self.dataset.train_path,
|
||||
fill_to_max_length=self.dataset.fill_to_max_length,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
@ -85,22 +69,7 @@ register_quickstart_exp("null-sft", NullSFTConfig)
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NullPPOConfig(CommonExperimentConfig):
|
||||
"""Configuration for a null PPO experiment.
|
||||
|
||||
Used for testing purposes.
|
||||
"""
|
||||
|
||||
model: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptOnlyDatasetConfig
|
||||
)
|
||||
dataset_filter_threshold: float = 0.2
|
||||
dataset_max_filter_percentage: float = 0.1
|
||||
class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
|
@ -117,8 +86,8 @@ class NullPPOConfig(CommonExperimentConfig):
|
|||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=ModelInterfaceAbstraction("null"),
|
||||
model_name="default",
|
||||
input_keys=["packed_prompts"],
|
||||
output_keys=["rewards"],
|
||||
input_keys=("packed_prompts",),
|
||||
output_keys=("rewards",),
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
)
|
||||
|
@ -129,7 +98,7 @@ class NullPPOConfig(CommonExperimentConfig):
|
|||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
interface_impl=ModelInterfaceAbstraction("null"),
|
||||
model_name="default",
|
||||
input_keys=["packed_prompts", "rewards"],
|
||||
input_keys=("packed_prompts", "rewards"),
|
||||
log_return_value=True,
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
|
@ -144,11 +113,10 @@ class NullPPOConfig(CommonExperimentConfig):
|
|||
def datasets(self):
|
||||
return [
|
||||
DatasetAbstraction(
|
||||
"math_prompt",
|
||||
"math_code_prompt",
|
||||
args=dict(
|
||||
max_length=self.dataset.max_prompt_len,
|
||||
dataset_path=self.dataset.path,
|
||||
fill_to_max_length=self.dataset.fill_to_max_length,
|
||||
filter_threshold=self.dataset_filter_threshold,
|
||||
max_filter_percentage=self.dataset_max_filter_percentage,
|
||||
),
|
||||
|
|
|
@ -5,21 +5,18 @@ import copy
|
|||
import dataclasses
|
||||
import os
|
||||
import pprint
|
||||
from typing import *
|
||||
from typing import Dict
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelInterfaceAbstraction,
|
||||
ModelInterfaceType,
|
||||
)
|
||||
from realhf.api.core.dfg import MFCDef, ParamReallocHook
|
||||
from realhf.api.core.model_api import GenerationHyperparameters
|
||||
from realhf.api.core.system_api import ExperimentConfig
|
||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
from realhf.experiments.common.utils import (
|
||||
asdict,
|
||||
|
@ -31,198 +28,11 @@ logger = logging.getLogger("PPO Math exp", "colored")
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PPOHyperparameters:
|
||||
"""Configuration of PPO hyperparameters.
|
||||
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
||||
|
||||
:param gen: Generation hyperparameters.
|
||||
:type gen: GenerationHyperparameters
|
||||
:param ppo_n_minibatches: Number of minibatches in each PPO update.
|
||||
:type ppo_n_minibatches: int
|
||||
:param kl_ctl: Coefficient of KL divergence rewards.
|
||||
:type kl_ctl: float
|
||||
:param discount: Discount factor.
|
||||
:type discount: float
|
||||
:param gae_lambda: Lambda factor in GAE.
|
||||
:type gae_lambda: float
|
||||
:param eps_clip: PPO actor probability ratio clipping factor.
|
||||
:type eps_clip: float
|
||||
:param value_eps_clip: PPO value clipping factor.
|
||||
:type value_eps_clip: float
|
||||
:param max_reward_clip: Maximum reward value.
|
||||
:type max_reward_clip: float
|
||||
:param reward_output_scaling: Scaling factor of the reward model output.
|
||||
:type reward_output_scaling: float
|
||||
:param reward_output_bias: Bias of the reward model output.
|
||||
The number output by the reward model will be
|
||||
CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip).
|
||||
:type reward_output_bias: float
|
||||
:param early_stop_imp_ratio: PPO update will be early stopped if importance ratio
|
||||
exceeds this maximum value.
|
||||
:type early_stop_imp_ratio: float
|
||||
:param use_adaptive_kl_ctl: Whether to use adaptive KL divergence coefficient.
|
||||
:type use_adaptive_kl_ctl: bool
|
||||
:param adv_norm: Whether to use advantage normalization.
|
||||
:type adv_norm: bool
|
||||
:param value_norm: Whether to denormalize values and normalize return predictions.
|
||||
:type value_norm: bool
|
||||
:param value_norm_type: Type of value normalization.
|
||||
Either exponential moving average ("exp") or moving average ("ma").
|
||||
:type value_norm_type: str
|
||||
:param value_norm_beta: Exponential decay factor
|
||||
in exponential moving average.
|
||||
:type value_norm_beta: float
|
||||
:param value_norm_eps: Epsilon factor in the
|
||||
denominator of exponential moving average.
|
||||
:type value_norm_eps: float
|
||||
:param disable_value: A shortcut option to disable the critic model.
|
||||
:type disable_value: bool
|
||||
"""
|
||||
|
||||
gen: GenerationHyperparameters = dataclasses.field(
|
||||
default_factory=GenerationHyperparameters
|
||||
)
|
||||
ppo_n_minibatches: int = 4
|
||||
kl_ctl: float = 0.1
|
||||
discount: float = 1.0
|
||||
gae_lambda: float = 1.0
|
||||
eps_clip: float = 0.2
|
||||
value_eps_clip: float = 0.2
|
||||
max_reward_clip: float = 20.0
|
||||
reward_output_scaling: float = 1.0
|
||||
reward_output_bias: float = 0.0
|
||||
early_stop_imp_ratio: float = 5.0
|
||||
use_adaptive_kl_ctl: bool = False
|
||||
adv_norm: bool = True
|
||||
value_norm: bool = True
|
||||
value_norm_type: str = dataclasses.field(
|
||||
metadata={"choices": ["exp", "ma"]}, default="exp"
|
||||
)
|
||||
value_norm_beta: float = 0.99995
|
||||
value_norm_eps: float = 1e-5
|
||||
|
||||
disable_value: bool = False
|
||||
recompute_logprob: bool = False
|
||||
fuse_rew_ref: bool = True
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PPOMATHConfig(CommonExperimentConfig):
|
||||
"""PPO experiment configuration.
|
||||
|
||||
It is a subclass of :class:`CommonExperimentConfig`,
|
||||
so all CLI options in the base class are available.
|
||||
|
||||
We don't implement runtime evaluation for PPO.
|
||||
|
||||
We identify that the RLHF process is composed of four
|
||||
distinct models with independent parameters and six
|
||||
*model function calls* upon these models.
|
||||
|
||||
The four models are\:
|
||||
|
||||
- Actor\: The primary LLM that generates text.
|
||||
- Critic\: The value function that estimates the value of a state.
|
||||
- Ref\: The reference LLM that provides KL regularization.
|
||||
- Rew\: The reward model that provides reward signals.
|
||||
|
||||
The six model function calls and their dependencies are\:
|
||||
|
||||
- Rollout\: Generate text from the actor model.
|
||||
- InfReward\: Infer rewards from the reward model given generated text.
|
||||
- InfRef\: Infer log probabilities from the reference model given generated text.
|
||||
- InfValues\: Infer values from the critic model given generated text.
|
||||
- TrainActor\: Train the actor model given generated text, rewards, values, and reference log probabilities.
|
||||
- TrainCritic\: Train the critic model given generated text, rewards, values, and reference log probabilities.
|
||||
|
||||
This class resolves these dependencies under the hood.
|
||||
What the users should specify are the runtime configurations
|
||||
of models and allocations of *each model function call*.
|
||||
|
||||
:param actor: Runtime configuration of the primary LLM.
|
||||
:type actor: ModelTrainEvalConfig
|
||||
:param critic: Runtime configuration of the critic model of PPO.
|
||||
:type critic: ModelTrainEvalConfig
|
||||
:param ref: Runtime configuration of the reference LLM.
|
||||
:type ref: ModelTrainEvalConfig
|
||||
:param rew: Runtime configuration of the reward LLM.
|
||||
:type rew: ModelTrainEvalConfig
|
||||
:param actor_train: :class:`MFCConfig` for TrainActor.
|
||||
:type actor_train: MFCConfig
|
||||
:param critic_train: :class:`MFCConfig` for TrainCritic.
|
||||
:type critic_train: MFCConfig
|
||||
:param actor_gen: :class:`MFCConfig` for Rollout.
|
||||
:type actor_gen: MFCConfig
|
||||
:param critic_inf: :class:`MFCConfig` for InfValues.
|
||||
:type critic_inf: MFCConfig
|
||||
:param rew_inf: :class:`MFCConfig` for InfReward.
|
||||
:type rew_inf: MFCConfig
|
||||
:param ref_inf: :class:`MFCConfig` for InfRef.
|
||||
:type ref_inf: MFCConfig
|
||||
:param dataset: Dataset configuration.
|
||||
:type dataset: PromptOnlyDatasetConfig
|
||||
:param ppo: Configuration for the PPO algorithm.
|
||||
:type ppo: PPOHyperparameters
|
||||
:param group_size: The number of answers retained for each prompt.
|
||||
:type group_size: int
|
||||
:param generation_size: The number of answers sampled for each prompt.
|
||||
Among them, only `group_size` samples are retained according to
|
||||
the reward score, aka best-of-n sampling. If None, use `group_size`.
|
||||
:type generation_size: Optional[int]
|
||||
:param mask_no_eos_with_zero: Whether to mask out the reward if an
|
||||
answer is truncated due to exceeding the length limit.
|
||||
:type mask_no_eos_with_zero: bool
|
||||
:param mask_too_long: Whether to mask out the PPO loss if an
|
||||
answer is truncated due to exceeding the length limit.
|
||||
:type mask_too_long: bool
|
||||
:param check_verifier_status: If True, raise an error
|
||||
when the reward is all-zero. This usually indicates a bug
|
||||
of the verifier.
|
||||
:type check_verifier_status: bool
|
||||
:param group_adv_norm: Whether to use grouped advantage
|
||||
normalization in GRPO.
|
||||
:type group_adv_norm: bool
|
||||
"""
|
||||
|
||||
actor: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
critic: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
|
||||
# for manual allocation only
|
||||
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
|
||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptOnlyDatasetConfig
|
||||
)
|
||||
|
||||
ppo: PPOHyperparameters = dataclasses.field(default_factory=PPOHyperparameters)
|
||||
|
||||
group_size: int = 1
|
||||
generation_size: Optional[int] = None
|
||||
mask_no_eos_with_zero: bool = False
|
||||
ref_ema_eta: Optional[float] = None
|
||||
group_adv_norm: bool = False
|
||||
mask_too_long: bool = False
|
||||
rw_type: Optional[str] = "sparse"
|
||||
check_xml_format: bool = False
|
||||
|
||||
check_verifier_status: bool = False
|
||||
|
||||
dataset_filter_threshold: float = 100.0
|
||||
dataset_max_filter_percentage: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
self.ppo_kwargs = dict(
|
||||
@property
|
||||
def ppo_kwargs(self):
|
||||
return dict(
|
||||
n_minibatches=self.ppo.ppo_n_minibatches,
|
||||
kl_ctl=self.ppo.kl_ctl,
|
||||
discount=self.ppo.discount,
|
||||
|
@ -341,8 +151,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_prompts", "task_ids"],
|
||||
output_keys=rollout_output_keys,
|
||||
input_keys=("packed_prompts", "task_ids"),
|
||||
output_keys=tuple(rollout_output_keys),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -354,8 +164,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_input_ids"],
|
||||
output_keys=["packed_logprobs"],
|
||||
input_keys=("packed_input_ids",),
|
||||
output_keys=("packed_logprobs",),
|
||||
output_key_remap=dict(logprobs="packed_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
@ -366,8 +176,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=rw_interface,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
|
||||
output_keys=["rewards"],
|
||||
input_keys=("packed_input_ids", "packed_prompts", "task_ids"),
|
||||
output_keys=("rewards",),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -387,8 +197,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_path=self.ref.path,
|
||||
interface_impl=ref_interface,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=inf_ref_inputs,
|
||||
output_keys=inf_ref_outputs,
|
||||
input_keys=tuple(inf_ref_inputs),
|
||||
output_keys=tuple(inf_ref_outputs),
|
||||
output_key_remap=dict(logprobs="packed_ref_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
@ -402,8 +212,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_type=self.critic.type,
|
||||
model_path=self.critic.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=["packed_input_ids", "seq_no_eos_mask"],
|
||||
output_keys=["values"],
|
||||
input_keys=("packed_input_ids", "seq_no_eos_mask"),
|
||||
output_keys=("values",),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -426,7 +236,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=train_actor_inputs,
|
||||
input_keys=tuple(train_actor_inputs),
|
||||
log_return_value=True,
|
||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
|
@ -440,7 +250,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
interface_impl=critic_interface,
|
||||
model_type=self.critic.type,
|
||||
model_path=self.critic.path,
|
||||
input_keys=[
|
||||
input_keys=(
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"packed_ref_logprobs",
|
||||
|
@ -448,7 +258,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
"values",
|
||||
"prompt_mask",
|
||||
"seq_no_eos_mask",
|
||||
],
|
||||
),
|
||||
log_return_value=True,
|
||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import dataclasses
|
||||
|
||||
from realhf.api.cli_args import SFTExperimentOptions
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelInterfaceAbstraction,
|
||||
|
@ -11,35 +12,12 @@ from realhf.api.core.config import (
|
|||
ModelName,
|
||||
)
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.quickstart.dataset import PromptAnswerDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SFTConfig(CommonExperimentConfig):
|
||||
"""Configuration for SFT experiments.
|
||||
|
||||
This class is a subclass of :class:`CommonExperimentConfig`,
|
||||
so all CLI options from the base class are available.
|
||||
|
||||
:param model: Configuration for model runtime.
|
||||
:type model: ModelTrainEvalConfig
|
||||
:param allocation: Configuration for device allocation and parallelism.
|
||||
:type allocation: MFCConfig
|
||||
:param dataset: Configuration for the dataset.
|
||||
:type dataset: PromptAnswerDatasetConfig
|
||||
"""
|
||||
|
||||
model: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
allocation: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
dataset: PromptAnswerDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptAnswerDatasetConfig
|
||||
)
|
||||
class SFTConfig(CommonExperimentConfig, SFTExperimentOptions):
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
|
@ -56,7 +34,7 @@ class SFTConfig(CommonExperimentConfig):
|
|||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
interface_impl=ModelInterfaceAbstraction("sft"),
|
||||
model_name="default",
|
||||
input_keys=["packed_input_ids", "prompt_mask"],
|
||||
input_keys=("packed_input_ids", "prompt_mask"),
|
||||
log_return_value=True,
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
|
|
|
@ -7,23 +7,31 @@ import dataclasses
|
|||
import enum
|
||||
import itertools
|
||||
import re
|
||||
from typing import *
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Hashable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.api.core.config import (
|
||||
ModelAbstraction,
|
||||
ModelBackendAbstraction,
|
||||
ModelInterfaceType,
|
||||
ModelName,
|
||||
)
|
||||
from realhf.api.core.dfg import OffloadHook, ParamReallocHook
|
||||
from realhf.api.quickstart.device_mesh import RPCAllocation
|
||||
from realhf.api.quickstart.model import (
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
parallelism_eq,
|
||||
)
|
||||
from realhf.base import logging
|
||||
from realhf.base.topology import (
|
||||
DataPipeModelParallelTopology,
|
||||
|
@ -34,6 +42,29 @@ from realhf.base.topology import (
|
|||
logger = logging.getLogger("Experiment Common Utils", "benchmark")
|
||||
|
||||
|
||||
def get_real_model_config(
    model_path: str,
    hf_model_family: str,
    is_critic: bool,
    init_from_scratch: bool,
    init_critic_from_actor: bool,
    dtype: Optional[str] = None,
) -> ModelAbstraction:
    """Build the ``ModelAbstraction`` used to construct a "real_model".

    Every argument is forwarded unchanged as an entry of the ``args``
    mapping of the returned abstraction; nothing is validated here.
    """
    # Keep the key order of the mapping stable: it mirrors the order in
    # which the arguments have always been handed to the abstraction.
    build_args = {
        "model_path": model_path,
        "is_critic": is_critic,
        "init_critic_from_actor": init_critic_from_actor,
        "dtype": dtype,
        "hf_model_family": hf_model_family,
        "init_from_scratch": init_from_scratch,
    }
    return ModelAbstraction("real_model", args=build_args)
|
||||
|
||||
|
||||
def get_topo(
|
||||
parallel: ParallelismConfig,
|
||||
gradient_checkpointing: bool,
|
||||
|
@ -72,7 +103,7 @@ def make_train_backend_config(
|
|||
model_cfg: ModelTrainEvalConfig, parallel_cfg: ParallelismConfig
|
||||
):
|
||||
if model_cfg.backend == "megatron":
|
||||
megatron_args: Dict[str, Any] = OmegaConf.to_container(model_cfg.megatron)
|
||||
megatron_args: Dict[str, Any] = asdict(model_cfg.megatron)
|
||||
return ModelBackendAbstraction(
|
||||
"megatron",
|
||||
args=dict(
|
||||
|
@ -132,8 +163,11 @@ def resolve_replica_ids(
|
|||
for alloc in allocs:
|
||||
if alloc.rpc.name == main_alloc.rpc.name:
|
||||
continue
|
||||
same_alloc = alloc.device_mesh == main_alloc.device_mesh and parallelism_eq(
|
||||
alloc.parallel, main_alloc.parallel
|
||||
same_alloc = (
|
||||
alloc.device_mesh == main_alloc.device_mesh
|
||||
and ParallelismConfig.parallelism_eq(
|
||||
alloc.parallel, main_alloc.parallel
|
||||
)
|
||||
)
|
||||
if not same_alloc or (
|
||||
alloc.rpc.is_generate()
|
||||
|
@ -165,7 +199,7 @@ def resolve_rpc_hooks(
|
|||
if rpc.role != other.rpc.role:
|
||||
continue
|
||||
if (
|
||||
parallelism_eq(parallel, other.parallel)
|
||||
ParallelismConfig.parallelism_eq(parallel, other.parallel)
|
||||
and device_mesh == other.device_mesh
|
||||
and not (
|
||||
model_configs[rpc.role].vllm.hybrid_train
|
||||
|
|
|
@ -58,7 +58,9 @@ def load_metadata(path):
|
|||
try:
|
||||
if "task" not in d:
|
||||
d["task"] = "math"
|
||||
logger.warning(f'Key "task" not found in the dataset. Use math as default task type.')
|
||||
logger.warning(
|
||||
f'Key "task" not found in the dataset. Use math as default task type.'
|
||||
)
|
||||
if d["task"] == "math":
|
||||
d = check_math_metadata_entries(d)
|
||||
elif d["task"] == "code":
|
||||
|
@ -206,7 +208,7 @@ else:
|
|||
),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path='/storage/datasets/full_prompts_for_r1_distilled.jsonl'
|
||||
dataset_path="/storage/datasets/full_prompts_for_r1_distilled.jsonl",
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
|
|
@ -16,9 +16,9 @@ import torch.distributed as dist
|
|||
import transformers
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from realhf.api.cli_args import MegatronConfig, MicroBatchSpec, OptimizerConfig
|
||||
from realhf.api.core import model_api
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.api.quickstart.model import MegatronConfig, OptimizerConfig
|
||||
from realhf.api.core.data_api import SequenceSample
|
||||
from realhf.base import constants, logging, pkg_version
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
|
||||
|
@ -552,7 +552,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
|
|||
DistributedDataParallelConfig,
|
||||
)
|
||||
else:
|
||||
from realhf.api.quickstart.model import DistributedDataParallelConfig
|
||||
from realhf.api.cli_args import DistributedDataParallelConfig
|
||||
self.ddp = DistributedDataParallelConfig(**self.ddp)
|
||||
with megatron_ctx():
|
||||
if pkg_version.is_version_less("megatron.core", "0.7.0"):
|
||||
|
|
|
@ -17,6 +17,7 @@ import torch.multiprocessing as mp
|
|||
import transformers
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from realhf.api.cli_args import SGLangConfig
|
||||
from realhf.api.core import data_api
|
||||
from realhf.api.core.model_api import (
|
||||
APIGenerateInput,
|
||||
|
@ -29,7 +30,6 @@ from realhf.api.core.model_api import (
|
|||
PipelinableEngine,
|
||||
register_backend,
|
||||
)
|
||||
from realhf.api.quickstart.model import SGLangConfig
|
||||
from realhf.base import (
|
||||
cluster,
|
||||
constants,
|
||||
|
|
|
@ -29,8 +29,8 @@ except ModuleNotFoundError:
|
|||
pass
|
||||
|
||||
|
||||
from realhf.api.cli_args import vLLMConfig
|
||||
from realhf.api.core import data_api, model_api
|
||||
from realhf.api.quickstart.model import vLLMConfig
|
||||
from realhf.base import constants, logging, seeding
|
||||
|
||||
logger = logging.getLogger("vLLM backend")
|
||||
|
|
|
@ -18,9 +18,9 @@ import pandas as pd
|
|||
import realhf.base.cluster
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import ParallelismConfig
|
||||
from realhf.api.core.dfg import MFCDef, ModelFamily, ModelInterfaceType
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.api.quickstart.model import ParallelismConfig
|
||||
from realhf.search_engine.param_realloc import estimate_param_realloc_time_cost
|
||||
from realhf.search_engine.utils import load_model_config
|
||||
|
||||
|
|
|
@ -19,10 +19,10 @@ except ModuleNotFoundError:
|
|||
mdm_search = None
|
||||
|
||||
import realhf.base.constants as constants
|
||||
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.api.core.config import ModelInterfaceType
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.api.quickstart.search import RPCExecution
|
||||
|
||||
|
||||
|
|
|
@ -218,9 +218,10 @@ class RedistribPlanner:
|
|||
return list(gather_plan.values()) + list(scatter_plan.values())
|
||||
|
||||
def derive_plan_bcast(
|
||||
self, dests: Dict[int, List[Hashable]], keys: List[str]
|
||||
self, dests: Dict[int, List[Hashable]], keys: List[str] | Tuple[str]
|
||||
) -> List[RedistribStep]:
|
||||
assert isinstance(keys, list), type(keys)
|
||||
assert isinstance(keys, (list, tuple)), type(keys)
|
||||
keys = list(keys)
|
||||
self.dests = dests
|
||||
|
||||
# Get all required data IDs.
|
||||
|
|
|
@ -44,15 +44,16 @@ tabulate
|
|||
aiofiles
|
||||
pydantic
|
||||
isort==5.13.2
|
||||
clang-format
|
||||
clang-format==19.1.7
|
||||
ninja
|
||||
paramiko
|
||||
# To eliminate security risks
|
||||
torch>2.0.0
|
||||
black>=25.1.0
|
||||
black==25.1.0
|
||||
cookiecutter>2.1.1
|
||||
asyncio
|
||||
aiohttp
|
||||
httpx>=0.28.1
|
||||
etcd3
|
||||
protobuf<3.21
|
||||
protobuf<3.21
|
||||
rich
|
|
@ -15,7 +15,7 @@ import pytest
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from realhf.api.core.config import ModelFamily, ModelName, ModelShardID
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY, ReaLModelConfig
|
||||
from realhf.base import constants, logging, testing, topology
|
||||
|
|
|
@ -11,10 +11,8 @@ from typing import *
|
|||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import pytest
|
||||
import ray
|
||||
from ray.util.queue import Queue as RayQueue
|
||||
|
||||
from realhf.api.core.config import ModelFamily, ModelInterfaceAbstraction, ModelName
|
||||
from realhf.api.core.config import ModelInterfaceAbstraction, ModelName
|
||||
from realhf.api.core.dfg import MFCDef, ModelInterfaceType, build_graph
|
||||
from realhf.base import logging
|
||||
|
||||
|
@ -125,8 +123,6 @@ def _get_reinforce_rpcs():
|
|||
|
||||
@pytest.mark.parametrize("rpcs", [_get_ppo_rpcs(), _get_reinforce_rpcs()])
|
||||
def test_build_graph(tmp_path: pathlib.Path, rpcs: List[MFCDef]):
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
G = build_graph(rpcs, verbose=True, graph_path=str(tmp_path / "dfg.png"))
|
||||
assert nx.is_directed_acyclic_graph(G)
|
||||
for node in rpcs:
|
||||
|
@ -152,14 +148,3 @@ def test_build_graph(tmp_path: pathlib.Path, rpcs: List[MFCDef]):
|
|||
if k.startswith("_"):
|
||||
continue
|
||||
assert v == dataclasses.asdict(node)[k]
|
||||
|
||||
# Ensure node can be passed into ray queue
|
||||
queue = RayQueue(maxsize=8)
|
||||
queue.put(node)
|
||||
node_ = queue.get()
|
||||
for k, v in dataclasses.asdict(node_).items():
|
||||
if k.startswith("_"):
|
||||
continue
|
||||
assert v == dataclasses.asdict(node)[k]
|
||||
if ray.is_initialized():
|
||||
ray.shutdown()
|
||||
|
|
|
@ -7,6 +7,7 @@ import uuid
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from realhf.api.core import config as config_api
|
||||
from realhf.api.core import data_api
|
||||
|
@ -20,16 +21,22 @@ def _validate_dataset(cfg: config_api.DatasetAbstraction, tokenizer):
|
|||
dp_rank=0,
|
||||
world_size=1,
|
||||
tokenizer_or_tokenizer_name=tokenizer,
|
||||
experiment_name=uuid.uuid4(),
|
||||
trial_name=uuid.uuid4(),
|
||||
experiment_name=str(uuid.uuid4()),
|
||||
trial_name=str(uuid.uuid4()),
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
collate_fn=data_api.SequenceSample.gather,
|
||||
# NOTE: This is *NOT* the actual batch size for training.
|
||||
# It is just a proper size to load data to workers.
|
||||
batch_size=10240,
|
||||
shuffle=True,
|
||||
)
|
||||
dataloader = data_api.PackedDataLoader(dataset)
|
||||
for x in dataloader:
|
||||
assert isinstance(x, data_api.SequenceSample)
|
||||
assert x.data is not None
|
||||
for k, v in x.data.items():
|
||||
assert v.device == torch.device("cpu")
|
||||
bs = len(x.ids)
|
||||
for k, vs in x.seqlens.items():
|
||||
assert all(isinstance(v, list) for v in vs)
|
||||
assert all(all(isinstance(vv, int) for vv in v) for v in vs)
|
||||
|
@ -37,7 +44,7 @@ def _validate_dataset(cfg: config_api.DatasetAbstraction, tokenizer):
|
|||
if x.metadata:
|
||||
for k, v in x.metadata.items():
|
||||
assert isinstance(v, list), k
|
||||
xs = x.split(bs)
|
||||
xs = x.unpack()
|
||||
for xx in xs:
|
||||
if xx.metadata:
|
||||
for k, v in xx.metadata.items():
|
||||
|
|
|
@ -7,13 +7,14 @@ from typing import *
|
|||
|
||||
import pytest
|
||||
|
||||
from realhf.api.core.system_api import ExperimentSaveEvalControl
|
||||
from realhf.api.quickstart.dataset import (
|
||||
from realhf.api.cli_args import (
|
||||
ExperimentSaveEvalControl,
|
||||
MFCConfig,
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
PromptAnswerDatasetConfig,
|
||||
PromptOnlyDatasetConfig,
|
||||
)
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig, ParallelismConfig
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.base import cluster, logging, name_resolve, testing
|
||||
from realhf.experiments.common.null_exp import NullPPOConfig, NullSFTConfig
|
||||
from tests.experiments.utils import run_test_exp
|
||||
|
@ -28,40 +29,40 @@ def model_class(request):
|
|||
|
||||
|
||||
@pytest.fixture(params=[300])
|
||||
def dataset_with_size(request, save_path):
|
||||
def math_code_dataset_with_size(request, save_path):
|
||||
size = request.param
|
||||
max_prompt_len = 8
|
||||
max_resp_len = 8
|
||||
dataset = []
|
||||
for i in range(size):
|
||||
prompt_len = random.randint(1, max_prompt_len)
|
||||
n_pairs = random.randint(1, 5)
|
||||
d = dict(
|
||||
id=i,
|
||||
prompt=generate_random_sentence(prompt_len),
|
||||
answer=generate_random_sentence(random.randint(1, max_resp_len)),
|
||||
pos_answers=[
|
||||
generate_random_sentence(random.randint(1, max_resp_len))
|
||||
for _ in range(n_pairs)
|
||||
],
|
||||
neg_answers=[
|
||||
generate_random_sentence(random.randint(1, max_resp_len))
|
||||
for _ in range(n_pairs)
|
||||
],
|
||||
query_id=str(uuid.uuid4()),
|
||||
prompt=generate_random_sentence(prompt_len),
|
||||
task=random.choice(["math", "code"]),
|
||||
)
|
||||
if d["task"] == "math":
|
||||
d["solutions"] = [generate_random_sentence(max_resp_len)]
|
||||
elif d["task"] == "code":
|
||||
d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
|
||||
dataset.append(d)
|
||||
with open(str(save_path / "_dataset.json"), "w") as f:
|
||||
json.dump(dataset, f)
|
||||
return dataset, size
|
||||
with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
|
||||
f.write(json.dumps(d) + "\n")
|
||||
return dataset, len(dataset)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dp", [4])
|
||||
@pytest.mark.parametrize("bs", [63])
|
||||
def test_buffer_recover(
|
||||
bs, tmp_path_factory, dataset_with_size, tokenizer, save_path, cpu_hf_model, dp
|
||||
bs,
|
||||
tmp_path_factory,
|
||||
math_code_dataset_with_size,
|
||||
tokenizer,
|
||||
save_path,
|
||||
cpu_hf_model,
|
||||
dp,
|
||||
):
|
||||
_, dataset_size = dataset_with_size
|
||||
_, dataset_size = math_code_dataset_with_size
|
||||
# Setup experiment env. Should be done before any other operations.
|
||||
log_root = tmp_path_factory.mktemp("buffer-recover")
|
||||
cluster.spec.fileroot = str(log_root)
|
||||
|
@ -100,7 +101,7 @@ def test_buffer_recover(
|
|||
backend="mock_train",
|
||||
),
|
||||
dataset=PromptOnlyDatasetConfig(
|
||||
path=str(save_path / "_dataset.json"),
|
||||
path=str(save_path / "math_code_dataset.jsonl"),
|
||||
max_prompt_len=128,
|
||||
train_bs_n_seqs=bs,
|
||||
fill_to_max_length=False,
|
||||
|
|
|
@ -7,17 +7,18 @@ from typing import *
|
|||
|
||||
import pytest
|
||||
|
||||
from realhf.api.core.data_api import MicroBatchSpec
|
||||
from realhf.api.core.system_api import ExperimentSaveEvalControl
|
||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.base import cluster, testing
|
||||
from realhf.experiments.common.ppo_math_exp import (
|
||||
from realhf.api.cli_args import (
|
||||
ExperimentSaveEvalControl,
|
||||
GenerationHyperparameters,
|
||||
MFCConfig,
|
||||
MicroBatchSpec,
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
PPOHyperparameters,
|
||||
PPOMATHConfig,
|
||||
PromptOnlyDatasetConfig,
|
||||
)
|
||||
from realhf.base import cluster, testing
|
||||
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
|
||||
from tests.experiments.utils import run_test_exp
|
||||
from tests.fixtures import *
|
||||
|
||||
|
|
|
@ -5,8 +5,16 @@ from typing import *
|
|||
|
||||
import pytest
|
||||
|
||||
from realhf.api.quickstart.dataset import PromptAnswerDatasetConfig
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.api.cli_args import (
|
||||
ExperimentSaveEvalControl,
|
||||
GenerationHyperparameters,
|
||||
MFCConfig,
|
||||
MicroBatchSpec,
|
||||
ModelTrainEvalConfig,
|
||||
ParallelismConfig,
|
||||
PPOHyperparameters,
|
||||
PromptOnlyDatasetConfig,
|
||||
)
|
||||
from realhf.base import cluster, testing
|
||||
from realhf.experiments.common.sft_exp import SFTConfig
|
||||
from tests.experiments.utils import run_test_exp
|
||||
|
|
|
@ -14,7 +14,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import transformers
|
||||
|
||||
from realhf.api.core.config import ModelFamily
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY, ReaLModelConfig
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.testing import (
|
||||
|
|
Loading…
Reference in New Issue