[Feature & Doc & Bug Fix] Add docs, simplified ray-based scripts, and fix issues to stabilize asynchronous experiments (#52)

* feat: one buffer for each task

* feat: support "one buffer for each task" for async

* make kv_cache_dtype configurable

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>

* style: use plural form

fix: use _seed_from_key to set different seeds for data loaders
fix: call load_data for one buffer each time

* PullRequest: 125 Support running async experiments in the 2407 image.

Merge branch fw/async2407 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/125

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* fix: handle multiple datasets in recover indices
fix: `isinstance(self.__datasets, PullerStreamDataset)`
feat: use the "spec" request to obtain the number of datasets
fix: revert rollout worker

* fix: revert async_rl_exp.py

* fix flag for list (cuda_graph_bs)

* format

* [FIX] fix async task reward [sglang bf16-> fp16]

* fix: define `self.__datasets` in advance

* PullRequest: 130 [Refactor] Remove deprecated search related code

Merge branch mzy/remove-search of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/130

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* remove search related

* PullRequest: 131 [Refactor] Change terminology "model parallel" into "tensor parallel" to align with megatron.

Merge branch mzy/mp-to-tp of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/131?tab=comment

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* change mp to tp
* .
* .

* PullRequest: 142 Fix an error for megatron backend destroy

Merge branch fw/fix-meagatron-destroy of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/142

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 143 Fix the port conflict issue of generation servers

Merge branch fw/fix-gen-port of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/143?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* somehow fix the port issue
* add clearance period
* .
* .

* PullRequest: 145 Add code environment

Merge branch fw/code-env of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/145?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add code env
* somehow fix the port issue
* fix

* PullRequest: 144 Add decoupled PPO loss

Merge branch fw/decoupled-ppo-loss of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/144?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* fix typo

* PullRequest: 146 Merge SLURM logs and save experiment configs in yaml format.

Merge branch fw/better-logging of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/146

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* merge all slurm logs into one
* write config to yaml

* PullRequest: 141 Merge changes during NeurIPS submission

Merge branch fw/async-dev of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/141

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* .
* .
* .
* update script
* .
* .
* .
* .
* [ADD] add least req scheduling
* fix test genreq
* .
* .
* fix stats tracker nan
* .
* .
* .
* .
* .
* .
* .
* upper clip decoupled objective
* add throughput exp script
* .
* remove behav upper clip param
* .
* .
* .
* plot curve
* update thpt script
* .
* master worker raise error when exiting
* update script
* add gen throughput logging
* .
* .
* add decoupled wandb data
* .
* fix port issue and add no training option
* .
* enlarge ttl
* remove gserver manager await staled
* update weights in groups
* .
* .
* .
* add port clearance period
* .
* .
* .
* add plot script
* add sft throughput eval
* .
* log tokens in null interface
* ablation experiments and interruptible generation
* plotting scripts / run scripts / data results
* .
* remove scripts
* add port test
* remove force_sync_reward
* revert some changes
* .
* revert
* revert fix
* fix
* revert
* fix typo

* support qwen3 training

* PullRequest: 147 Support interruption in SGLang and fix a KeyError in gather-scatter communication

Merge branch fw/sglang046-with-abort-request of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/147?tab=diff

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* initial commit
* add interrupt request
* fix data transfer issue
* max concurrent rollouts defaults to train batch size
* merge main
* add patch
* fix patch typo
* revert sglang
* fix typo
* fix minor typo
* .
* pip show editable sglang path

* PullRequest: 149 fix: code faas max_retries

Merge branch xss/fix_code_verifier of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/149

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix: code faas max_retries

* PullRequest: 150 [Bug Fix] Fix key errors in `_run_scatter` in data transfer

Merge branch mzy/fix-scatter-groups of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/150

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix scatter groups key error

* fix test

* .

* PullRequest: 151 Fix Qwen3 import error when using transformers with a lower version

Merge branch fw/fix-qwen3 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/151

Reviewed-by: 温差 <xushusheng.xss@antgroup.com>


* merge all slurm logs into one
* write config to yaml
* .

* PullRequest: 152 Support sglang0.4.6 and fix master_worker import error

Merge branch adopt_sglang046 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/152

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* Support sglang0.4.6 and fix master_worker import error
* remove disable_mla option

* PullRequest: 155 [FIX] reduce port conflicts

Merge branch sxj/reduce_port_conflict of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/155

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* [FIX] reduce port conflicts

* PullRequest: 153 Fix stuck and recover issues for async experiments

Merge branch fw/stable-async of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/153

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix sample cnt stuck
* fix recover
* code cleanup
* merge all slurm logs into one
* write config to yaml
* .
* .
* .
* revert birth time change
* .
* enlarge sock connect timeout

* PullRequest: 158 [Fix] Fix the error where "accepted" is not defined

Merge branch fw/fix-rollout-accepted of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/158

Reviewed-by: 温差 <xushusheng.xss@antgroup.com>


* .

* PullRequest: 154 Fix unit tests and simplify package installation

Merge branch fw/v0.3.0-tests of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/154?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix some tests
* fix tests except for experiments
* fix tests
* fix tests
* .
* .

* PullRequest: 159 [fix] Enlarge the default aiohttp connection timeout and fix a recover error in model worker

Merge branch fw/stable-async of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/159

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix sample cnt stuck
* fix recover
* code cleanup
* merge all slurm logs into one
* write config to yaml
* .
* .
* .
* revert birth time change
* .
* enlarge sock connect timeout
* .

* PullRequest: 160 set sock_connect as rollout_request_timeout in partial_rollout.py

Merge branch xss/rollout_timeout of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/160

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* set sock_connect as rollout_request_timeout in partial_rollout.py

* PullRequest: 161 Prioritize rollouts that are submitted earlier rather than arrived earlier

Merge branch fw/birth-time of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/161

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* blocking push

* PullRequest: 163 [bugfix] Fix synchronized training when birth time is absent

Merge branch fw/fix-sync-birthtime of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/163

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 164 [Refactor] Move cluster spec into CLI args

Merge branch fw/refactor-cluster-spec of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/164?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* set cluster spec path in args
* .
* fix
* add default cluster spec

* PullRequest: 165 Normally exit all workers after experiment completion

Merge branch fw/exit-all-workers of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/165

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 167 [Feature] Use chunked logits computation to alleviate SGLang OOM

Merge branch fw/patch-sglang-oom of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/167

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 166 [Feature] Support single-script experiment launch with Ray

Merge branch fw/turbolaunch of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/166?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add training script without ray name resolve
* add ray name resolve
* ray worker
* run
* run async
* local run
* set cluster spec path in args
* .
* .
* fix
* .
* .
* .
* .
* .
* update config
* .
* minor renaming

* PullRequest: 169 [Doc] Add v0.3.0 docs based on jupyter-book

Merge branch fw/doc of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/169

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add docs
* refine doc
* refine doc

---------

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com>
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: kira.gw <kira.gw@antgroup.com>
Co-authored-by: shenxujie.sxj <shenxujie.sxj@antgroup.com>
Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: sam.gjx <sam.gjx@antgroup.com>
Co-authored-by: 温差 <xushusheng.xss@antgroup.com>
Co-authored-by: 履渊 <yuhong.gyh@antgroup.com>
Committed by Wei Fu, 2025-05-28 19:18:05 +08:00 (via GitHub)
parent 89cc3a7400
commit cf46993a30
85 changed files with 3085 additions and 1894 deletions

docs/_config.yml (new file, +32 lines)
# Book settings
# Learn more at https://jupyterbook.org/customize/config.html
title: AReaL Documentation
author: Wei Fu
logo: figures/logo.png
# Force re-execution of notebooks on each build.
# See https://jupyterbook.org/content/execute.html
execute:
  execute_notebooks: force
# Define the name of the latex output file for PDF builds
latex:
  latex_documents:
    targetname: book.tex
# Add a bibtex file so that we can create citations
bibtex_bibfiles:
  - references.bib
# Information about where the book exists on the web
repository:
  url: https://github.com/inclusionAI/AReaL  # Online location of your book
  path_to_book: docs  # Optional path to your book, relative to the repository root
  branch: main  # Which branch of the repository should be used when creating links (optional)
# Add GitHub buttons to your book
# See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository
html:
  use_issues_button: true
  use_repository_button: true

docs/_toc.yml (new file, +22 lines)
# Table of contents
# Learn more at https://jupyterbook.org/customize/toc.html
format: jb-book
root: intro
parts:
  - caption: Tutorial
    chapters:
      - file: installation
      - file: training
      - file: eval
      - file: troubleshooting
  - caption: Developer Manual
    chapters:
      - file: developer/exp_launch
      - file: developer/master_worker
      - file: developer/model_worker
      - file: developer/algo_interface
      - file: developer/allocation_parallel
  - caption: Contributing
    chapters:
      - file: contrib

docs/contrib.md (new file, +74 lines)
# Contribution Guide
Thank you for your interest in contributing to AReaL! We welcome contributions from everyone, whether you're fixing bugs, improving documentation, or adding new system and algorithmic features.
## Setting Up Your Development Environment
New contributors do not have write permissions to the official repository. Please fork the repository and clone your fork locally. AReaL is fully Python-based, making installation straightforward.
```bash
git clone https://github.com/${your-username}/AReaL
cd AReaL
pip3 install -r requirements.txt
pip3 install -e .
```
## Issue Guidelines
### Issue Templates
Please follow the [issue template on GitHub](https://github.com/inclusionAI/AReaL/tree/main/.github/ISSUE_TEMPLATE). Issues can be:
- Bug reports
- Feature requests
- Refactor requests
The required fields in the template help reduce communication overhead when resolving issues. **Issues with arbitrary formatting may be ignored.**
## Pull Request Guidelines
There are no specific PR templates, but **pull requests should be related to a well-templated issue**. Your PR should:
- Explain how the issue is resolved
- Describe the benefits this change will provide
- Reference the related issue number
## Code Quality
### Code Formatting
Please format your code before opening a PR:
```bash
isort . && black .
```
### Running Tests
AReaL's unit tests are based on the `pytest` framework:
```bash
# Run all tests (excluding GPU tests)
pytest -m "not gpu"
# Run a specific test case
pytest tests/test_something.py
```
**Note**: Running all tests may take several hours to complete.
## Documentation
Writing documentation is an excellent starting point for new contributors. The documentation is located in the `docs` folder and built using [Jupyter Book](https://jupyterbook.org/en/stable/intro.html).
### Adding New Documentation
1. Create your documentation files in the `docs` folder
2. Add the file path to `docs/_toc.yml`
3. Build the documentation:
```bash
jb build docs
```
4. Preview your changes by opening the HTML files in `docs/_build/html`
This process allows you to see how your documentation will appear before submitting your contribution.

docs/developer/algo_interface.md (new file, +103 lines)
# Algorithm, Interface & Backends
## Overview
![](algo_interface.png)
Model Interfaces define the computations that can be performed, such as training, inference, and generation. They provide abstract classes and implementations that decouple specific algorithms (e.g., PPO, SFT) from model backends (Megatron, SGLang, vLLM). Algorithm developers will likely be most interested in adding customized model interfaces.
Model backends integrate external libraries to wrap the model as a `PipelinableEngine`, so that it can provide efficient distributed training and inference capabilities.
## Registration
Backends and interfaces follow similar registration protocols:
```python
# Registration (at the end of each interface implementation):
model_api.register_interface("ppo", PPOActorInterface)
# Configuration (in experiment config file):
interface_config = ModelInterfaceAbstraction(
type_="ppo",
args=dict(eps_clip=0.2)
)
# Instantiation (in model worker):
interface = make_interface(interface_config)
```
## Customization
### Interfaces
An interface implementation essentially processes the data and loss function (e.g., reward clipping, computing GAEs) required by a `PipelinableEngine`, calls the actual execution method such as `PipelinableEngine.train_step`, and then runs some post-processing according to the data protocol.
Custom interfaces can be created by subclassing the `ModelInterface` class and implementing the required methods for the desired training paradigm.
Example:
```python
@dataclass
class CustomInterface(model_api.ModelInterface):
    # Custom parameters
    custom_param: float = 1.0

    def train_step(self, model, data, mb_spec):
        module = model.module
        module.train()
        # Custom training logic
        stats = module.train_batch(
            input_=data,
            loss_fn=custom_loss_function,
            loss_weight_fn=lambda x: x.data["mask"].count_nonzero(),
            token_normalize_scope="global",
            mb_spec=mb_spec,
            version_steps=model.version.global_step,
        )
        model.inc_version()
        return stats

    def save(self, model, save_dir):
        module = model.module
        module.save_to_hf(tokenizer=model.tokenizer, save_dir=save_dir)

# Register the interface
model_api.register_interface("custom", CustomInterface)
```
Required methods vary based on the interface purpose:
+ For training interfaces: `train_step()` and `save()`
+ For inference-only interfaces: `inference()`
+ For generation interfaces: `generate()`
The interface can be configured in the experiment configuration file, e.g., `ppo_math_exp.py`. Please refer to xxx for how to run unit tests on your implementation.
### Backends
A backend requires implementing the `_initialize` method. Example:
```python
class FSDPEngine(PipelinableEngine):
    def train_step(self, ...):
        ...

class FSDPBackend(ModelBackend):
    def _initialize(self, model):
        module = model.module
        model.module: PipelinableEngine = FSDPEngine(module)
        return model

register_backend("fsdp", FSDPBackend)
```
## Existing Implementations
### Interfaces
+ `ppo_interface.py`: Implementation of the PPO actor and critic.
+ `sft_interface.py`: Implementation of SFT.
### Backends
+ `megatron.py`: Training wrapper based on Megatron Core's `DistributedDataParallel`
+ `sglang.py`: A wrapper over an SGLang HTTP server for batched generation.
+ `vllm.py`: Deprecated SPMD vLLM backend.

Binary image file added (338 KiB).

docs/developer/allocation_parallel.md (new file, +67 lines)
# Allocation & Parallelism
## GPU Allocation
GPU allocation is controlled by the `allocation_mode` CLI parameter. The most common pattern looks like `"sglang.d2t2p1+d1t4p1"`, which means:
+ The first 4 GPUs are allocated to SGLang for inference with:
- 2-way tensor parallelism
- 2-way data parallelism
+ The remaining GPUs are allocated for training with 4-way tensor parallelism
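To make the pattern concrete, here is a rough parsing sketch (an illustration only; AReaL's actual parser and the full set of supported patterns may differ):
```python
# Illustrative sketch: split an allocation string such as "sglang.d2t2p1+d1t4p1"
# into GPU counts for inference and training. Not AReaL's actual parser.
import re

def gpus_of(spec: str) -> int:
    """spec like 'd2t2p1': data/tensor/pipeline parallel degrees."""
    d, t, p = (int(x) for x in re.match(r"d(\d+)t(\d+)p(\d+)", spec).groups())
    return d * t * p

def split_allocation(mode: str):
    gen_spec, train_spec = mode.split("+")
    gen_spec = gen_spec.removeprefix("sglang.")
    return gpus_of(gen_spec), gpus_of(train_spec)

print(split_allocation("sglang.d2t2p1+d1t4p1"))  # (4, 4): 4 GPUs for SGLang, 4 for training
```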
## Parallelism Strategies
### Training
AReaL supports three parallelism strategies for dense models, similar to Megatron:
+ Data Parallelism: Uses Megatron's DistributedDataParallel with AReaL's balanced DP partitioning algorithm (`SequenceSample.split`)
+ Tensor Parallelism: Fully replicates Megatron's `ColumnParallelLinear` and `RowParallelLinear`
+ Pipeline Parallelism: Developed in-house with 1F1B scheduling (planned to be replaced with an open-source implementation due to maintenance challenges)
### Inference
AReaL supports SGLang inference with intra-node tensor parallelism and customized data parallelism.
### Parameter Partitioning
Each model worker holds multiple model shards based on the allocation configuration.
Example: With 4 GPUs configured as:
+ Actor model: First half GPUs with tensor parallelism
+ Critic model: Second half GPUs with pipeline parallelism
+ Reference model: All GPUs with tensor and pipeline parallelism
The parameter distribution would be:
![](param_shard.png)
## Torch NCCL Communication Groups
During experiments, the following NCCL communication groups are created:
1. Global group: Includes all experiment GPUs (created in `global_comm.py`)
2. Parallelism Group: 3D parallel communication groups for a specific model (may match global group or be a subset, created in `topology.py`)
3. Data transfer groups: Groups between all data-parallel processes of any two models for data transfer (created in `data_manager.py`)
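As a generic illustration of how such sub-groups are derived (this is not AReaL's actual code; the layout below is an assumption), every process enumerates all groups and keeps the one it belongs to:
```python
# Generic sketch: derive tensor-parallel NCCL sub-groups from the global group.
# Assumes torch.distributed has already been initialized with the NCCL backend.
import torch.distributed as dist

def build_tp_group(world_size: int, tp_size: int):
    """All processes must call this with identical arguments; torch.distributed
    requires every rank to participate in every new_group call."""
    my_group = None
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks=ranks, backend="nccl")
        if dist.get_rank() in ranks:
            my_group = group
    return my_group
```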
## Parallelism Ranks
Each model worker has a unique GPU index, but may have different parallel strategy coordinates under different model names (actor, critic, etc.).
Example: GPU 2 might have:
+ TP rank 1 for actor model
+ TP rank 0 for reference model
Parallel strategy coordinates are maintained in `realhf.base.constants` and accessed via:
```python
with constants.model_scope(ModelName("actor", 0)):
    dp_rank1 = constants.data_parallel_rank()
with constants.model_scope(ModelName("ref", 0)):
    dp_rank2 = constants.data_parallel_rank()
```
Note: Interface and backend methods are automatically called within a model scope, so the context manager can be omitted in those implementations.

Binary image file added (162 KiB).

docs/developer/exp_launch.md (new file, +3 lines)
# Launching Procedure
![Illustration of Experiment Launching](launch.png)

docs/developer/launch.png (binary image, 130 KiB)

Binary image file added (389 KiB).

Binary image file added (260 KiB).

docs/developer/master_worker.md (new file, +29 lines)
# Master Worker
## Overview
![](master_arch.png)
The worker architecture of AReaL consists of a single master worker coordinating multiple model workers.
An RL algorithm typically contains several model function calls (MFCs) that need to be executed in a certain order. For example in PPO,
1. `actor_gen` generates responses given a batch of user prompts;
2. `ref_inf` computes the log-probabilities of the tokens under the reference policy;
3. `rew_inf` computes the rewards of the responses;
4. `actor_train` updates the policy with the PPO learning objective.
Here, model function calls 2 and 3 depend on the output of 1. Model function call 4 depends on the outputs of 1, 2, and 3.
The MFCs are coordinated by a `FunctionExecutor` instance. It creates a `ModelFunctionCall` instance for each MFC. The actual computation is performed on model workers via remote procedure call.
## Buffer and MFC Execution Order
![](buffer_arch.png)
The master worker creates an `AsyncIOSequenceBuffer`, which is referenced by the `FunctionExecutor` and the `ModelFunctionCall` instances. The buffer is responsible for managing (meta)data and deciding the execution order of the MFCs.
Each datapoint can be seen as a `dict` of tensors. For example, the keys may include `packed_prompts` and `task_ids`. Recall that some MFCs rely on the outputs of others. For example, in PPO the MFC `ref_inf` requires `packed_input_ids`, which is not present initially; instead, `packed_input_ids` appears as one of the results of the MFC `actor_gen`.
The buffer keeps track of the available keys of each datapoint. Each `ModelFunctionCall` instance obtains its next batch via `self.get_batch_for_rpc`, which waits for enough datapoints that have all the required keys, so an MFC does not start execution until all of its required keys are ready. After executing the model function call, it calls `self.amend_batch` to update the corresponding datapoints with the new keys.
While some keys are the results of MFCs, some are loaded from the dataset via `FunctionExecutor.load_data`. Also note that instead of the actual data, the buffer stores only metadata (data indices, keys, etc.) to reduce the cost of data transfer.
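The following is a deliberately simplified sketch of this key-tracking idea (hypothetical class and method names, not the actual `AsyncIOSequenceBuffer` API): an MFC only receives datapoints whose required keys are all available, and new keys are amended when an MFC finishes.
```python
# Hypothetical sketch of key tracking; the real AsyncIOSequenceBuffer is asynchronous
# and stores only metadata such as data indices and sequence lengths.
from typing import List, Set

class KeyTrackingBuffer:
    def __init__(self):
        self._keys: List[Set[str]] = []  # available keys per datapoint

    def put(self, keys: Set[str]) -> int:
        self._keys.append(set(keys))     # e.g. {"packed_prompts", "task_ids"} after loading data
        return len(self._keys) - 1

    def amend(self, idx: int, new_keys: Set[str]) -> None:
        self._keys[idx] |= new_keys      # e.g. actor_gen adds "packed_input_ids"

    def ready_indices(self, required: Set[str], batch_size: int) -> List[int]:
        ready = [i for i, ks in enumerate(self._keys) if required <= ks]
        return ready[:batch_size] if len(ready) >= batch_size else []
```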

docs/developer/model_worker.md (new file, +73 lines)
# Model Worker
## Master-Model Worker Interaction
The master worker sends remote procedure calls (RPCs) to model workers to execute actual computations like `actor_gen` and `actor_train`. The figure below illustrates their interaction throughout an experiment:
![](master-model-interaction.png)
Model worker "compute" involves running a model interface with a specific backend (covered in detail later). For PPO algorithms, model workers sequentially execute:
+ `actor_gen`: `actor` model with SGLang backend + `PPOActorInterface.generate`
+ `rew_inf`: `reward` model (can be null for RLVR) + `MultiTaskRewardInterface.inference`
+ `actor_train`: `actor` model with Megatron backend + `PPOActorInterface.train_step`
## Communication Protocol
### Request-Reply Pattern
The master worker and model workers communicate through a `request_reply_stream` channel that handles requests and metadata responses (actual data like `input_ids` transfers through other channels).
Master (client) can send these requests to model workers (servers):
+ **fetch**: Worker loads local dataset data and sends metadata (e.g., sequence length) to master for buffer storage
+ **spec**: Worker returns dataset specifications for master to calculate experiment steps
+ **model_config**: Worker provides transformer model configuration
+ **clear_data_cache**: Worker clears data transfer and GPU caches
+ **initialize**: Worker initializes parameters, gradient buffers, and optimizer states
+ **generate/inference/train_step**: Worker executes the corresponding computation (note: "inference" refers to a single forward pass)
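A minimal sketch of this request-reply flow is shown below; the class and field names are illustrative assumptions and do not mirror AReaL's actual `request_reply_stream` implementation.
```python
# Illustrative only: the master sends typed requests, workers reply with metadata.
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class Request:
    handle_name: str                      # "fetch", "spec", "generate", "train_step", ...
    meta: Dict[str, Any] = field(default_factory=dict)

@dataclass
class Reply:
    handle_name: str
    meta: Dict[str, Any] = field(default_factory=dict)  # metadata only; tensors travel elsewhere

def handle_request(req: Request) -> Reply:
    # A model worker dispatches on the request type and returns only metadata.
    if req.handle_name == "spec":
        return Reply("spec", {"dataset_size": 106_000})       # made-up value
    if req.handle_name == "fetch":
        return Reply("fetch", {"seqlens": [512, 734, 1024]})  # made-up values
    raise NotImplementedError(req.handle_name)
```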
### Request Hooks
Computation requests ("generate"/"inference"/"train_step") support pre- and post-hooks for:
+ Data transfer (pre-hook)
+ Evaluation
+ Offloading
+ Parameter reallocation
+ Checkpointing (post-hooks)
These hooks often require NCCL communication/synchronization between workers. Implementing them as dedicated hooks prevents deadlocks that could occur if these operations interleaved with other NCCL communications.
### Request Types
+ **Blocking requests**: Long-running operations requiring NCCL synchronization. Workers can't execute immediately since concurrent blocking requests may need coordinated data transfers. Master sends a "flush" request to indicate all concurrent requests have been sent.
+ **Non-blocking requests**: Shorter operations without NCCL requirements that can execute immediately.
## Data Management
### Distributed Dataset Storage
Datasets distribute across model workers without overlap. For each model:
+ Processes with PP rank = -1 and TP rank = 0 serve as DP heads
+ Data stores on DP heads of the model used in the first MFC (e.g., actor model DP heads for PPO)
During "fetch" requests:
1. DP head worker loads data into local buffer
2. Sends metadata to master
3. Master tracks metadata and later instructs workers which data to use for each MFC via computation request hooks
### Data Transfer Process
For each MFC, the master:
1. Specifies which data to use
2. Provides data locations across workers
3. Workers redistribute data using:
- `Redistributor`: Generates NCCL broadcast/gather/scatter communication plan
- `DataManager`: Executes the plan
After redistribution, workers with the same DP rank receive identical input data.
### MFC Output Handling
Only workers with PP rank=-1 and TP rank=0 produce output data. These workers:
1. Store data locally
2. Notify master of data locations
3. Master generates new redistribution plans for subsequent MFCs based on this layout information

Binary image file added (24 KiB).

docs/eval.md (new file, +83 lines)
# Evaluation
The evaluation code is located in the `evaluation` folder of the repository. Following the previous tutorial, trained checkpoints will be saved under `/storage/ray/experiments/checkpoints/root/`.
## Setup Evaluation Environment
Start a new container to execute the evaluation script. **Note**: Evaluation requires updates to certain Python libraries, so avoid using the training container for this task.
```bash
docker run -d --name areal-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "tail -f /dev/null"
docker exec -it areal-eval bash
```
## Install Dependencies and Run Evaluation
Execute the following commands inside the Docker container:
```bash
cd /storage/codes/AReaL/evaluation
cd latex2sympy
pip install -e .
cd ..
pip install -r requirements.txt
pip install vllm==0.8.5 --no-build-isolation
pip install transformers==4.51.1
pip install prettytable timeout_decorator
mkdir -p /storage/ray/eval_output/
nohup python eval_and_aggregate.py \
--model_path /storage/ray/experiments/checkpoints/root/my-exp/my-trial/epoch1epochstep20globalstep20/ \
--output_path /storage/ray/eval_output/ \
--data_names "math_500,aime24,amc23" \
--max_gen_tokens 32768 &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
```
### Command Line Parameters
- **`--model_path`**: Path to the saved model parameters
- **`--output_path`**: Path to store generated answers and log files during evaluation
- **`--data_names`**: Dataset(s) to evaluate. Multiple datasets can be separated by commas. Available options: `math_500`, `math`, `gsm8k`, `train_amc_aime`, `aime24`, `amc23`
- **`--max_gen_tokens`**: Maximum length of generated answers (default: 32768)
## Evaluation Results
The evaluation script will output a results table in the terminal:
```
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
| dataset | num_questions | greedy_length | sample_length | greedy_acc | sample_pass@1 | pass@8 | pass@16 |
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
| math_500 | 500 | 6757.4 | 4139.5 | 84.4 | 92.7 | 97.3 | 97.7 |
| aime24 | 30 | 19328.0 | 13663.5 | 50.0 | 50.4 | 77.3 | 80.0 |
| amc23 | 40 | 8850.0 | 6526.2 | 80.0 | 90.5 | 96.8 | 98.8 |
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
```
### Metrics Explanation
- **`{greedy|sample}_length`**: Average answer length under greedy or random sampling strategy
- **`greedy_acc`**: Average accuracy under greedy sampling
- **`sample_pass@{k}`**: Probability of generating a correct answer within `k` attempts under random sampling
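For reference, the standard unbiased pass@k estimator (popularized by the Codex paper) is sketched below; the evaluation script may implement it differently, so treat this only as a reading aid.
```python
# pass@k = 1 - C(n - c, k) / C(n, k), where n samples are drawn per question
# and c of them are correct; the result is averaged over all questions.
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(32, 20, 8))  # e.g. 32 samples, 20 correct, budget of 8 attempts
```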
## Configuration Details
### Sampling Parameters
- The evaluation script defaults to averaging 32 samples with temperature 0.6
- We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance
- When `enforce_eager=True`, we can reproduce the model performance reported in previous work
- Without this setting, evaluation results may fall below reported performance
- Therefore, we enforce `enforce_eager=True` during evaluation
### Runtime Expectations
Due to the sampling requirements and `enforce_eager` setting, the evaluation process typically takes considerable time.
Runtime depends on several factors:
- Maximum generation length
- Number of questions in the dataset
- Model size
**Performance benchmarks** (on 8x H100 GPUs):
- **AIME dataset**: ~80 minutes
- **MATH_500 dataset**: ~160 minutes

docs/figures/logo.png (binary image, 320 KiB)

docs/installation.md (new file, +78 lines)
# Installation
## Prerequisites
### Hardware Requirements
The following hardware configuration has been extensively tested:
- **GPU**: 8x H800 per node
- **CPU**: 64 cores per node
- **Memory**: 1TB per node
- **Network**: NVSwitch + RoCE 3.2 Tbps
- **Storage**:
- 1TB local storage for single-node experiments
- 10TB shared storage (NAS) for distributed experiments
### Software Requirements
| Component | Version |
|---|:---:|
| Operating System | CentOS 7 / Ubuntu 22.04 or any system meeting the requirements below |
| NVIDIA Driver | 550.127.08 |
| CUDA | 12.8 |
| Git LFS | Required for downloading models, datasets, and AReaL code. See [installation guide](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) |
| Docker | 27.5.1 |
| NVIDIA Container Toolkit | See [installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) |
| AReaL Image | `ghcr.io/inclusionai/areal-runtime:v0.3.0` (includes runtime dependencies and Ray components) |
**Note**: This tutorial does not cover the installation of NVIDIA Drivers, CUDA, or shared storage mounting, as these depend on your specific node configuration and system version. Please complete these installations independently.
## Runtime Environment
We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository.
Pull the Docker image:
```bash
docker pull ghcr.io/inclusionai/areal-runtime:v0.3.0
```
This image includes all training requirements for AReaL.
**For multi-node training**: Ensure shared storage is mounted to the `/storage` directory on every node. All downloads and resources will be stored in this directory, and the AReaL container will mount this directory to `/storage` within the container.
## Code Setup
Clone the AReaL project code to `/storage/codes`:
```bash
mkdir -p /storage/codes
cd /storage/codes/
git clone https://github.com/inclusionAI/AReaL
pip install -r AReaL/requirements.txt
```
## Dataset
Download the provided training dataset and place it in `/storage/datasets/`:
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
```
## Model
We train using open-source models available on Hugging Face Hub. Here's an example using Qwen3 (ensure Git LFS is installed):
```bash
mkdir -p /storage/models
cd /storage/models
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Qwen/Qwen3-1.7B
cd Qwen3-1.7B
git lfs pull
```
**Alternative**: You can also use the Hugging Face CLI to download models after installing the `huggingface_hub` package. Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.

docs/intro.md (new file, +6 lines)
# Overview
## Welcome to AReaL's documentation!
```{tableofcontents}
```

docs/requirements.txt (new file, +3 lines)
jupyter-book
matplotlib
numpy

docs/training.md (new file, +108 lines)
# Training
## Launch the Ray Cluster
### Start the Ray Head Node
On the first node, start the Ray Head with the following command:
```bash
docker run -d --name r1-ray-head --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "ray start --head --port=6379 && tail -f /dev/null"
```
### Start Ray Worker Nodes
On all other nodes, start the Ray Worker with the following command (skip this step for single-node setups):
```bash
# Replace with the actual IP address of the first node
RAY_HEAD_IP=xxx.xxx.xxx.xxx
docker run -d --name r1-ray-worker --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "ray start --address=$RAY_HEAD_IP:6379 && tail -f /dev/null"
```
### Verify Cluster Status
Once all nodes are running, check the Ray cluster status by entering the container on the first node:
```bash
docker exec -it r1-ray-head bash
ray status
```
You should see the Ray resource status displayed.
## Launch an Experiment
On the first node (where the Ray Head is located), run the following to launch an asynchronous PPO experiment:
```bash
docker exec -it r1-ray-head bash
cd /storage/codes/AReaL
pip3 install -e .
python3 training/main_async_ppo.py --config-name=async-ppo-1.7b-gpu8
```
This command will locate the YAML configuration file `async-ppo-1.7b-gpu8.yaml` in the `training/configs/async-ppo` folder. The meaning of each configuration entry can be found in `realhf/api/cli_args.py`. You can run asynchronous PPO, synchronous PPO, or SFT depending on the script you execute.
After starting, you'll see training launch information like this:
```
20250528-17:12:16.804 quickstart INFO: Running async-ppo-math experiment.
20250528-17:12:16.804 quickstart INFO: Logs will be dumped to /storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial
20250528-17:12:16.804 quickstart INFO: Experiment configs will be dumped to /storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial/config.yaml
20250528-17:12:16.804 quickstart INFO: Model checkpoints will be saved to /storage/experiments/checkpoints/admin/async-ppo-1.7b-gpu8/my-trial
20250528-17:12:19.261 quickstart INFO: Launching experiments with RAY...
```
**Note**: The saved YAML configuration at `/storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial/config.yaml` can be used to reproduce previous experiments.
## Command Line Options
To view all available options:
```bash
python3 -m realhf.apps.quickstart async-ppo-math --help
```
### Important Parameters
- **`mode`**: Always set to `ray`. Do not change this value when following this tutorial.
- **`{actor|critic|ref}.path`**: The path to the model files.
- **`dataset.path`**: The path to the dataset JSONL file.
- **`cluster.fileroot`**: The root path for saving training outputs.
- **`n_nodes`**: The number of nodes in the cluster.
- **`n_gpus_per_node`**: The number of GPUs per node.
- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration for the experiment. Format:
- `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures parallel strategies for SGLang generation and training respectively. Generation and training use separate GPU sets, and the total GPU count must equal: DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs.
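For example (hypothetical values), `allocation_mode=sglang.d4m2p1+d2m2p2` would allocate 4×2×1 = 8 GPUs to SGLang generation and 2×2×2 = 8 GPUs to training, i.e., 16 GPUs in total, matching `n_nodes=2` with `n_gpus_per_node=8`.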
### Training Control Parameters
- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset iterations).
- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to persistent storage. Set to null to disable saving.
- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary parameters for restart capability.
- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per training iteration).
- **`group_size`**: Number of responses sampled per prompt.
- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per mini-batch for forward/backward passes during reference model inference and actor model training. Reduce to avoid OOM errors.
- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO update.
- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for training.
- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training.
- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt (default: 16k).
- **`ppo.gen.min_new_tokens`**: Minimum tokens to generate per prompt (default: 0).
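As background for `ppo.use_decoupled_loss`, one common way to write a decoupled PPO objective is sketched below; this is an illustrative formulation and may not match AReaL's exact implementation. The rollout is generated by a behavior policy $\pi_{\mathrm{behav}}$, log-probabilities are recomputed under a proximal policy $\pi_{\mathrm{prox}}$ (cf. `ppo.recompute_logprob`), and the clipped PPO ratio is taken with respect to $\pi_{\mathrm{prox}}$:

$$
\mathcal{L}(\theta) = -\,\mathbb{E}\left[\frac{\pi_{\mathrm{prox}}(a \mid s)}{\pi_{\mathrm{behav}}(a \mid s)}\,
\min\!\Big(r_\theta \hat{A},\ \mathrm{clip}\big(r_\theta,\, 1-\varepsilon,\, 1+\varepsilon\big)\,\hat{A}\Big)\right],
\qquad r_\theta = \frac{\pi_\theta(a \mid s)}{\pi_{\mathrm{prox}}(a \mid s)}.
$$

Under this reading, the logged `importance_weight` statistic would correspond to the $\pi_{\mathrm{prox}}/\pi_{\mathrm{behav}}$ factor, which is expected to stay close to 1.0.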
## Monitoring the Training Process
We recommend using Weights & Biases (wandb) for monitoring. Run `wandb login` or set the `WANDB_API_KEY` environment variable. Set `wandb.mode=True` in your configuration to upload training statistics.
The main log will be saved to `/storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial/main.log` and contains the statistics uploaded to wandb.
### Key Training Statistics
- **`Epoch 1/5`**: Indicates total epochs required and current epoch being trained.
- **`step 6/19`**: Shows current epoch has 19 steps, with the 6th step just completed.
- **`global step 6`**: Step count across all epochs.
- **`task_reward`**: Average reward value of all sampled responses in this step. Should steadily increase during training and eventually stabilize.
- **`importance_weight`**: Average importance sampling ratio across all tokens in the PPO loss. Typically close to 1.0.
- **`actor_clip_ratio`**: Ratio of clipped tokens in PPO loss to total tokens. Usually less than 0.1.
- **`actor_loss`**: PPO loss value. **Does not show clear trends during training** and should not be used as a performance indicator.
- **`avg_seq_len`**: Average length of all sequences (prompts with sampled responses) in this step.
- **`no_eos_ratio`**: Ratio of sampled responses truncated due to exceeding maximum generation length. An increase indicates longer average response lengths.

docs/troubleshooting.md (new file, +56 lines)
# Troubleshooting
If the following content does not address your issue, feel free to raise a GitHub Issue.
## Automatic Recovery
When setting `recover_mode=auto` and the experiment configuration remains unchanged, AReaL will attempt to discover previous checkpoints and recover the experiment from them.
### Recovery Failure Causes
If automatic recovery fails, check the following possibilities:
**Configuration Changes:**
- The `experiment_name` and `trial_name` in the training script differ from the previous run
- Changes in batch size (`dataset.train_bs_n_seqs` parameter)
- Changes in group size (`group_size` parameter)
- Changes in number of nodes (`n_nodes` parameter)
**Missing Recovery Checkpoints:**
Recovery checkpoints are generated under two conditions by default:
- After completion of the second step
- When a step completes and more than 600 seconds have passed since the last recovery checkpoint (controlled by `exp_ctrl.ckpt_freq_secs=600`)
### Verify Recovery Checkpoint Creation
You can confirm if a recovery checkpoint was generated by searching for the following message in the logs:
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.760 master worker INFO: Dumped recover info to file.
```
## Memory Issues
### torch.cuda.CudaOutOfMemoryError
The key to resolving this issue is identifying the phase where the error occurs:
#### During Initialization
- Check for idle processes on the GPU
- **Distributed scenarios**: Restart the Ray cluster
- **Single-machine scenarios**: Use `pkill` to terminate processes
#### During SGLang Generation
- Decrease the `actor.sglang.mem_fraction_static` parameter
- Increase the tensor parallelism degree
#### During `actor_inf` or `actor_train`
- **Adjust microbatch size**: Set parameters like `actor_train.mb_spec.max_tokens_per_mb=20480`. This parameter limits tokens per forward/backward pass and can be set as low as the maximum sequence length (including prompt)
- **Modify parallelism strategy**: Adjust `allocation_mode` by:
- Reducing data parallelism
- Increasing tensor or pipeline parallelism
- Preferring pipeline parallelism over tensor parallelism
### CUDA Error: Out of Memory
This issue may occur during data transfer. Try increasing `mem_per_xx_worker` in the CLI arguments.

(deleted file, -427 lines)
# Improving LLM's Reasoning Capabilities with AReaL: A Complete Guide
# Prerequisites
## Hardware Requirements
Check if your hardware meets these minimum requirements:
|**Model Size**| **1.5B** |**1.5B**|**1.5B**| **7B** | **7B** | **32B** |
|---|:---:|:---:|:---:|:-------------------------:|:---:|:---:|
| **Nodes** | **1** | **4** | **16** | **4** | **16** | **16** |
| GPU | 8x H800 |8x H800 per node| 8x H800 per node | 8x H800 per node | 8x H800 per node | 8x H800 per node |
| CPU | 48 cores |48 cores per node|48 cores per node| 48 cores per node | 48 cores per node| 48 cores per node|
| Memory | 1 TB |1 TB per node|1 TB per node| 1 TB per node | 1 TB per node| 1 TB per node|
| Network | NVSwitch |NVSwitch + RoCE 3.2 Tbps|NVSwitch + RoCE 3.2 Tbps| NVSwitch + RoCE 3.2 Tbps | NVSwitch + RoCE 3.2 Tbps| NVSwitch + RoCE 3.2 Tbps|
| Storage | 1TB |Shared storage (NAS) 10TB|Shared storage (NAS) 10TB| Shared storage (NAS) 10TB |Shared storage (NAS) 10TB| Shared storage (NAS) 10TB|
| BatchSize x GroupSize | 512x16 | 512x16 | 512x16 | 512x16 | 512x16 | 512x32|
| **Single-step Time (seconds)** | **3461** | **997** | **391** | **2275** | **815** | **6707**|
| **#Steps Until Convergence** | **~250** |**~250** |**~250** |**~400** |**~400** | - |
| **Total Time (Hours)** | **~240** | **~69** | **~27** | **~252** | **~90** | - |
Notes:
- GPUs need to have 80GB memory. Other GPU models with similar specs are acceptable.
- Single-node training can use local storage, but multi-node training requires shared storage.
- We haven't successfully trained a powerful 32B model yet, so we cannot estimate the required steps and time.
## Software Requirements
This tutorial provides a Docker image. Below are the tested software versions:
| | Version |
|---|:---:|
| OS | CentOS 7 / Ubuntu 22.04 or any other system that meets the software requirements below |
| NVIDIA Driver | 550.127.08 |
| CUDA | 12.5 |
| Git LFS | Refer to: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage. Mainly used for downloading models, datasets, and AReaL project code. |
| Docker | 27.5.1 |
|NVIDIA Container Toolkit|[Installing the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)|
| AReaL Image | `ghcr.io/inclusionai/areal-runtime:v0.2.0`. This image includes AReaL's runtime dependencies and Ray components. |
Since the installation of NVIDIA Drivers and CUDA, as well as the mounting of shared storage, depends on node configurations and system versions, please complete these installations independently. This tutorial does not cover their setup.
For multi-node training, ensure that the shared storage is mounted to the `/storage` directory on every node. All subsequent downloads and resources will be stored in this directory. The AReaL container will also mount this directory to `/storage` within the container, enabling seamless access during training.
# One-Click Environment Setup and Training Launch
This section provides a one-click setup script to automatically configure the node environment:
1. Install Docker, Git LFS, and NVIDIA Container Toolkit
2. Pull the AReaL image on each node
3. Download AReaL code, models, and datasets
4. Set up a Ray cluster
5. [Optional] Launch a training task within the Ray cluster
Please perform the following operations on any chosen node:
```bash
mkdir -p /storage/codes
cd /storage/codes/
git clone https://github.com/inclusionAI/AReaL.git
cd /storage/codes/AReaL
python ./examples/env/setup_env_and_start_train.py setup --private_key_file /path/to/ssh_key --ssh_port 22 --username root --hostnames NODE_IP_1 NODE_IP_2 NODE_IP_3 NODE_IP_4 --train_param 1.5B_n1
```
`setup_env_and_start_train.py setup` arguments
- `private_key_file`: SSH private key used to connect to the nodes.
- `ssh_port`: SSH port
- `username`: SSH username
- `hostnames`: List of node IPs separated by spaces; can contain 1, 4, or 16 node IPs
- `train_param`: [Optional] Training parameters used to launch a training task immediately after environment setup. Valid options are: `1.5B_n1`, `1.5B_n4`, `1.5B_n16`, `7B_n4`, `7B_n16`
If the script in this section fails to execute or encounters errors due to environmental discrepancies, you may manually configure the environment and launch training by following the instructions in the subsequent sections of this tutorial.
# Environment Setup
Since shared storage is used, downloading only needs to be done on one node.
## Code
Clone the AReaL project code to `/storage/codes`:
```bash
mkdir -p /storage/codes
cd /storage/codes/
git clone https://github.com/inclusionAI/AReaL
```
## Dataset
We provide a dataset for training. Download the dataset and place it in `/storage/datasets/`:
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/orz-zero_56k_0319.jsonl?download=true
```
## Model
We train based on open-source models, which can be downloaded directly from HuggingFaceHub (Please ensure that Git LFS is installed):
```
mkdir -p /storage/models
cd /storage/models
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
cd DeepSeek-R1-Distill-Qwen-7B
git lfs pull
```
You can also use the Hugging Face CLI to download models after installing the `huggingface_hub` package from PyPI. Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.
## Launch the Ray Cluster
Before proceeding, pull the AReaL environment image, which already includes Ray components.
On the first node, start the Ray Head with the following command:
```bash
docker run -d --name r1-ray-head --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --head --port=6379 && tail -f /dev/null"
```
On all other nodes, start the Ray Worker with the following command (skip this step if you only have one node):
```bash
# RAY_HEAD_IP is the IP of the first node
RAY_HEAD_IP=xxx.xxx.xxx.xxx
docker run -d --name r1-ray-worker --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --address=$RAY_HEAD_IP:6379 && tail -f /dev/null"
```
Once all nodes are up, check the Ray cluster status by entering the container on the first node:
```bash
docker exec -it r1-ray-head bash
ray status
```
You should see the Ray resource status. The output will vary depending on your node count (e.g., a 16-node, 128-GPU cluster will show the following results).
```
======== Autoscaler status: 2025-02-22 14:08:51.061250 ========
Node status
---------------------------------------------------------------
Active:
1 node_d5634ae61bfe6732d957811bed65c8a39f13ece07e0326f941acbc4e
1 node_23b0c08045c9a39bc4c454cae298ee531d9a474215ac5e77a5b01e74
1 node_bc1016320658e92645f29cecb8aaf51c0b7e01a44e8ac9c814dfee59
1 node_4e7d15e9cee9ee0da5d65e45f1e346228c52bc0c557511c6eeab40dc
1 node_c5bcf15e28a00515be5d2a7e8e33d71f0f57cdfaf1003db9e0c74788
1 node_ec3f6ee8f6fdf3a5392bb4dac244668da75d094e084dcbb520ce2525
1 node_dc2f1eef88126ae4ac7902574714af9ab74b78ba037217e73e063639
1 node_a4728608c1fda187dc33bb24e831c42fe5c8a582ad428b6e595933bc
1 node_970379a3ba750ee3b13e31612b6a6b758d50bd4943555b2a13d1bd61
1 node_bf6b658bea9e437fcb642a2d881425662a689d668c92fe1545899b36
1 node_2c69511f410d9360f1d05893fde2c97dd32240e0315afea9b2d286a3
1 node_e4c90c17cc48ad469d123041d3302dcff1f7a82a4805279300812b19
1 node_3f772cbffb206c30b6ccedade83789d78397804bab874ee59563cb96
1 node_429bd5115b5590b612590bb455f2d3ed4f77055d746a184baf807655
1 node_75071820f2c16dc51fa271316b72cd45335ec877c06450d292ab7d54
1 node_6f4323f9038248d82b91321e2c4ca5fa99e65efa2d976c0b896a8964
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/2128.0 CPU
0.0/128.0 GPU
0B/21.08TiB memory
0B/2.91TiB object_store_memory
Demands:
(no resource demands)
```
# RL Training
Before starting distributed training, ensure the Ray cluster is up and running properly.
Then, on the first node (where the Ray Head is located), enter the container:
```
docker exec -it r1-ray-head bash
cd /storage/codes/AReaL
```
Choose a config file that matches your hardware environment and run it:
```bash
python3 -m realhf.apps.quickstart ppo-math --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-128.yaml
```
After starting, check the training launch information:
```
╭─────────────────────────────────────────────────╮
│ Setting PPOMATHConfig with the Following Values │
╰─────────────────────────────────────────────────╯
───────────────────────── Current Configuration Begin ──────────────────────────
actor (ModelTrainEvalConfig)
actor.type (ModelFamily)
actor.type._class (str) - qwen2
actor.type.size (int) - 7
actor.type.is_critic (bool) - False
...
────────────────────────── Current Configuration End ───────────────────────────
20250222-10:26:34.877 quickstart INFO: Running ppo-math experiment.
20250222-10:44:15.581 quickstart INFO: Logs will be dumped to /storage/ray/experiments/logs/root/ppo-7B-distill-gpus-128/512x16
20250222-10:44:15.581 quickstart INFO: Model checkpoints will be saved to /storage/ray/experiments/checkpoints/root/ppo-7B-distill-gpus-128/512x16
20250222-10:26:36.408 quickstart INFO: Launching experiments with RAY...
```
If errors occur during execution (e.g., keywords like "Error" appear), refer to the troubleshooting section.
## Commandline Options
```bash
python3 -m realhf.apps.quickstart ppo-math --help
```
The descriptions of the important parameters are as follows:
+ `mode`: Always set to `ray`; do not change it to other values when following this tutorial.
+ `{actor|critic|ref}.path`: The path of the model.
+ `dataset.path`: The path of the dataset jsonl file
+ `external_configs.cluster_config`: Cluster configuration, e.g., `fileroot` is the root path for saving training outputs.
+ `n_nodes`: The number of nodes
+ `n_gpus_per_node`: The number of GPUs per node
+ `allocation_mode`: The GPU allocation and 3D parallel strategy of the model in the experiment, mainly in the following form:
+ `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures the parallel strategies for SGLang generation and training respectively. Generation and training use disjoint sets of GPUs, and their GPU counts must sum to the total number of GPUs, i.e., DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs.
+ `exp_ctrl.total_train_epochs`: The number of training epochs (i.e., the number of times to iterate over the entire dataset)
+ `exp_ctrl.save_freq_{epochs|steps|secs}`: The frequency of saving the model parameters in persistent storage. If it is set to null, the model will not be saved.
+ `exp_ctrl.ckpt_freq_{epochs|steps|secs}`: The frequency of saving temporary parameters for restart
+ `dataset.train_bs_n_seqs`: The training batch size, that is, the number of prompts to be sampled each time during training
+ `group_size`: The number of answers to be sampled for each prompt
+ `{actor_train|ref_inf}.mb_spec.max_tokens_per_mb`: The maximum number of tokens in the data for each forward/backward pass during the inference of the reference model and the training of the actor model. It can be reduced to avoid OOM errors. These data will accumulate gradients for a single parameter update.
+ `ppo.ppo_n_minibatches`: The number of parts into which all the data will be divided for each PPO update to calculate the loss and update the parameters.
+ `ppo.gen.max_new_tokens`: The maximum number of tokens to be generated for a single prompt, default to 16k.
+ `ppo.gen.min_new_tokens`: The minimum number of tokens to be generated for a single prompt, default to 0.
## Monitoring the Training Process
Here, we use the logs from a 16-node run (the same applies to 1-node and 4-node setups) to explain several methods for observing training progress and results.
### Training Progress
Search for the keyword `Epoch` in the logs to see the total number of Epochs and Steps:
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.997 master worker INFO: Epoch 1/1 step 1/19 (global step 1) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2124.429*s. Total time consumption: 2283.862s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.719 master worker INFO: Epoch 1/1 step 2/19 (global step 2) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2405.716*s. Total time consumption: 4689.584s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.084 master worker INFO: Epoch 1/1 step 3/19 (global step 3) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2122.318*s. Total time consumption: 6811.949s. Estimated remaining time: 33957.093s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.246 master worker INFO: Epoch 1/1 step 4/19 (global step 4) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2313.134*s. Total time consumption: 9125.111s. Estimated remaining time: 33265.891s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.349 master worker INFO: Epoch 1/1 step 5/19 (global step 5) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2296.076*s. Total time consumption: 11421.214s. Estimated remaining time: 31413.800s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.864 master worker INFO: Epoch 1/1 step 6/19 (global step 6) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2299.448*s. Total time consumption: 13720.729s. Estimated remaining time: 29350.673s.
```
Six log entries are found. We explain the meaning of each field based on the last entry:
- `Epoch 1/1`: Indicates that a total of 1 Epoch is required, and the first Epoch is currently being trained. This example only trains for 1 Epoch. Normally, training should run for 10 Epochs or more.
- `step 6/19`: Indicates that the current Epoch has 19 Steps, and the 6th Step has just finished.
- `global step 6`: Represents the step count across all Epochs.
- `#End to end# execution time: *2299.448*s`: Indicates that the current Step took 2299.448 seconds to complete.
- `Total time consumption: 13720.729s`: The total time elapsed since training started is 13720.729 seconds.
- `Estimated remaining time: 29350.673s`: The estimated time remaining to complete training is 29350.673 seconds.
### Model Performance
Search for the keyword `task_reward` in the logs.
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.991 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.2640759198111482e-05, 'actor_loss': 1.1128166761409375e-06, 'actor_clip_ratio': 2.1122002635820536e-07, 'importance_weight': 1.0000014305114746, 'task_reward': -0.2996826171875, 'kl_reward': -2.27004832709099e-07, 'final_reward': -0.30145370960235596, 'advantage': 0.003593671601265669, 'avg_seq_len': 7907.8955078125, 'avg_prompt_len': 105.845703125, 'n_tokens': 127828786.0, 'n_valid_tokens': 127828786.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.122802734375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.712 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.493159263394773e-05, 'actor_loss': -3.846728588996484e-07, 'actor_clip_ratio': 3.16789424914532e-07, 'importance_weight': 0.9999996423721313, 'task_reward': -0.6793212890625, 'kl_reward': -2.536311853873485e-07, 'final_reward': -0.6813737154006958, 'advantage': 0.004844569601118565, 'avg_seq_len': 8203.9453125, 'avg_prompt_len': 111.892578125, 'n_tokens': 132580185.0, 'n_valid_tokens': 132580185.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.13812255859375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.077 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.572356243035756e-05, 'actor_loss': -5.036404786551429e-07, 'actor_clip_ratio': 1.8960582792715286e-07, 'importance_weight': 0.9999992251396179, 'task_reward': -0.6280517578125, 'kl_reward': -2.988609537624143e-07, 'final_reward': -0.6303607225418091, 'advantage': 0.004505862481892109, 'avg_seq_len': 7834.6328125, 'avg_prompt_len': 108.900390625, 'n_tokens': 126578395.0, 'n_valid_tokens': 126578395.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.11761474609375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.239 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4861981728463434e-05, 'actor_loss': 1.3935685672095133e-07, 'actor_clip_ratio': 3.02603467616791e-07, 'importance_weight': 0.9999998807907104, 'task_reward': -0.78857421875, 'kl_reward': -3.672174671009998e-07, 'final_reward': -0.791388750076294, 'advantage': 0.005053278990089893, 'avg_seq_len': 7773.39404296875, 'avg_prompt_len': 108.7890625, 'n_tokens': 125576883.0, 'n_valid_tokens': 125576883.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.117919921875, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.342 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.516058702894952e-05, 'actor_loss': -7.665488510610885e-07, 'actor_clip_ratio': 1.9505058901359007e-07, 'importance_weight': 0.9999997615814209, 'task_reward': -0.6158447265625, 'kl_reward': -4.6867208425283025e-07, 'final_reward': -0.6195111274719238, 'advantage': 0.004475570283830166, 'avg_seq_len': 7928.50830078125, 'avg_prompt_len': 105.517578125, 'n_tokens': 128171874.0, 'n_valid_tokens': 128171874.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12353515625, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.857 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4821250917739235e-05, 'actor_loss': -3.922649227661168e-07, 'actor_clip_ratio': 3.323623900541861e-07, 'importance_weight': 1.0000001192092896, 'task_reward': -0.7025146484375, 'kl_reward': -5.863367960046162e-07, 'final_reward': -0.7071446776390076, 'advantage': 0.004277692176401615, 'avg_seq_len': 8002.4873046875, 'avg_prompt_len': 105.951171875, 'n_tokens': 129376851.0, 'n_valid_tokens': 129376851.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12286376953125, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
```
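To eyeball the reward trend without a dashboard, the values can be extracted from the same logs with standard shell tools (the log path is again an assumption):
```bash
# Print one task_reward value per training step.
LOG_DIR="/storage/ray/experiments/logs/root/<experiment_name>/<trial_name>"  # replace the placeholders
grep -rho "'task_reward': [-0-9.e+]*" "$LOG_DIR" | awk '{print $2}'
```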
Taking the last entry as the example again, the key fields mean the following:
- `task_reward`: The average reward of all sampled answers in this step. When training proceeds steadily, this value should rise and eventually plateau.
- `importance_weight`: The average importance-sampling ratio over all tokens in the PPO loss. It is typically close to 1.
- `actor_clip_ratio`: The fraction of tokens clipped in the PPO loss, out of all tokens. It is usually below 0.1.
- `actor_loss`: The PPO loss. **It does not show a clear upward or downward trend during training** and should not be used as an indicator of model performance.
- `avg_seq_len`: The average length of all sequences (i.e., prompts concatenated with their sampled answers) in this step. In a full multi-stage training run, this value first decreases and then increases.
- `no_eos_ratio`: The fraction of sampled answers truncated for exceeding the maximum generation length. An increase in this value also indicates that the average answer length is growing.
# Evaluation
## Evaluation Process
The evaluation code is located in the `evaluation` folder of the repository. Following the steps above, the trained checkpoints are saved under `/storage/ray/experiments/checkpoints/root/`, for example `/storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/`.
Start a new container to execute the evaluation script (note: evaluation requires updates to certain Python libraries; avoid using the training container for this task):
```bash
docker run -d --name r1-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "tail -f /dev/null"
docker exec -it r1-eval bash
```
Run the following script inside the Docker container to evaluate:
```bash
cd /storage/codes/AReaL/evaluation
cd latex2sympy
pip install -e .
cd ..
pip install -r requirements.txt
pip install vllm --no-build-isolation
pip install transformers==4.47.0
pip install prettytable timeout_decorator
mkdir -p /storage/ray/eval_output/
nohup python eval_and_aggregate.py \
--model_path /storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/ \
--output_path /storage/ray/eval_output/ \
--data_names "math_500,aime24,amc23" \
--max_gen_tokens 32768 &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
```
+ `--model_path`: Path to the saved model parameters.
+ `--output_path`: Path to store the generated answers and log files during evaluation.
+ `--data_names`: The dataset(s) to evaluate; separate multiple datasets with commas. Default is `math_500, math, gsm8k, train_amc_aime, aime24, amc23`.
+ `--max_gen_tokens`: Maximum length of generated answers. Default is `32768`.
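Because the command above runs under `nohup` with its output redirected, progress messages end up in the log file rather than the terminal; you can follow them with:
```bash
tail -f /storage/ray/eval_output/eval_and_aggregate_parallel.log
```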
## Evaluation Results
When the run finishes, the evaluation script outputs a summary table at the end of the same log file, for example:
```
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
| dataset | num_questions | greedy_length | sample_length | greedy_acc | sample_pass@1 | pass@8 | pass@16 |
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
| math_500 | 500 | 6757.4 | 4139.5 | 84.4 | 92.7 | 97.3 | 97.7 |
| aime24 | 30 | 19328.0 | 13663.5 | 50.0 | 50.4 | 77.3 | 80.0 |
| amc23 | 40 | 8850.0 | 6526.2 | 80.0 | 90.5 | 96.8 | 98.8 |
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
```
+ `{greedy|sample}_length`: Average answer length under greedy decoding or random sampling, respectively.
+ `greedy_acc`: Average accuracy under greedy decoding.
+ `sample_pass@{k}`: The probability of obtaining at least one correct answer among `k` sampled attempts, averaged over questions.
## Additional Notes
### Key Parameters
+ The evaluation script defaults to averaging over 32 samples per question with sampling temperature 0.6.
+ We observed that vLLM's `enforce_eager` parameter significantly affects evaluation results: only with `enforce_eager=True` can we reproduce the model performance reported in previous work; otherwise the scores may fall below it. The evaluation script therefore forces `enforce_eager=True`.
Because it takes 32 samples per question and runs in eager mode, the evaluation process typically takes a considerable amount of time.
### Runtime
The runtime of the evaluation depends on factors such as the maximum generation length, the number of questions in the dataset, and the model size. On a machine with 8x H100 GPUs, evaluating `aime` and `math_500` takes approximately 80 minutes and 160 minutes, respectively.
# Troubleshooting
If the following content does not address your issue, feel free to raise a GitHub Issue.
## Automatic Recovery
When `recover_mode=auto` is set and the experiment configuration is unchanged, AReaL will try to discover checkpoints from the previous run and resume the experiment from them.
If the automatic recover fails, please check the following possibilities:
* The `experiment_name` and `trial_name` in the training script differ from the previous run.
* The batch size (`dataset.train_bs_n_seqs`), group size (`group_size`), or number of nodes (`n_nodes`) has changed.
* No recover checkpoint was created in the previous run. By default, recover checkpoints are generated under two conditions:
* After the completion of the second Step.
* When a Step completes and more than 600 seconds have passed since the last recover checkpoint. This interval is configured in `./examples/configs/*/*.yaml` via `exp_ctrl.ckpt_freq_secs=600`.
You can confirm whether a recover checkpoint was generated by searching the logs for `Dumped recover`:
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.760 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.105 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.264 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.411 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.883 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:59:44.925 master worker INFO: Dumped recover info to file.
```
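To check for these lines and then relaunch with recovery enabled, something along the following lines works; the log path is an example, and passing `recover_mode=auto` as a command-line override assumes the same `key=value` override style used elsewhere in this tutorial:
```bash
# Count the recover checkpoints dumped so far (replace the placeholders with your names).
grep -r "Dumped recover info" "/storage/ray/experiments/logs/root/<experiment_name>/<trial_name>" | wc -l

# Relaunch with the config you originally launched with; AReaL resumes from the latest recover checkpoint.
python3 -m realhf.apps.quickstart ppo-math \
  --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-16.yaml \
  recover_mode=auto
```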
## Series of OutOfMemory Errors
While our scripts are designed to minimize OOM (out-of-memory) errors, they can still occur occasionally, especially as memory fragmentation accumulates and sequence lengths grow during training. Automatic restarts usually resolve these issues, but if restarts become frequent, the following targeted solutions may help.
### torch.cuda.CudaOutOfMemoryError
The key to resolving this issue is identifying the phase in which the error occurs.
- **If it occurs during initialization (before `actor_gen`):**
  - Check whether leftover processes are still occupying GPU memory. In distributed scenarios, restart the Ray cluster; in single-machine scenarios, kill the leftover processes with `pkill`.
- **This error typically does not occur during the `actor_gen` phase.**
- **If it occurs during `ref_inf` or `actor_train`:**
  - Reduce the microbatch size of the corresponding computation. For example, set `actor_train.mb_spec.max_tokens_per_mb=20480`. This parameter caps the number of tokens per forward/backward pass and can be set as low as the maximum sequence length (prompt included).
  - Modify the parallelism strategy (`allocation_mode`): try reducing data parallelism and increasing tensor or pipeline parallelism (see the sketch after this list).
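As a concrete illustration, both adjustments can be passed as overrides when relaunching. The `20480` value comes from the text above; the alternative `allocation_mode` is only a hypothetical example of trading data parallelism for tensor/pipeline parallelism and must still satisfy the GPU-count equation:
```bash
# Relaunch with smaller microbatches and a hypothetical alternative parallel layout
# (use the config file you originally launched with).
python3 -m realhf.apps.quickstart ppo-math \
  --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-16.yaml \
  actor_train.mb_spec.max_tokens_per_mb=20480 \
  ref_inf.mb_spec.max_tokens_per_mb=20480 \
  allocation_mode=sglang.d4m2p1+d1m4p2
```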
### CUDA error: out of memory
This issue may occur while vLLM initializes the CPU KV cache (the swap space), indicating insufficient host memory on the machine. To resolve this, reduce the value of `actor.vllm.swap_space`.
### RuntimeError: Aborted due to the lack of CPU swap space.
This error arises when sequences are long and the KV cache no longer fits in GPU memory, so vLLM offloads it to CPU memory, but the configured swap space is too small. It is closely related to [preemption errors](https://docs.vllm.ai/en/latest/performance/optimization.html). To resolve it, increase `actor.vllm.swap_space`. If the error persists, reduce `actor.vllm.max_num_seqs` and refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/performance/optimization.html).
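For example, the two knobs above might be adjusted together when relaunching (the values are illustrative, not recommendations):
```bash
# Illustrative values only; use the config file you originally launched with.
python3 -m realhf.apps.quickstart ppo-math \
  --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-16.yaml \
  actor.vllm.swap_space=32 \
  actor.vllm.max_num_seqs=128
```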
### CUDA error: an illegal memory access was encountered
This error typically occurs during the vLLM generation phase and is another symptom of insufficient GPU memory. Solutions include:
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).
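One way to do the latter, assuming the generation workers inherit environment variables from the shell that launches the experiment, is to set vLLM's backend switch before starting:
```bash
# VLLM_ATTENTION_BACKEND is a standard vLLM environment variable; whether the generation
# workers inherit it from this shell depends on your launch setup, so treat this as a sketch.
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m realhf.apps.quickstart ppo-math --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-16.yaml
```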

View File

@ -1,423 +0,0 @@
# 利用AReaL提升大语言模型推理能力中文教程
# 前置要求
## 硬件要求
为了能正常完成训练流程,请参照下表确认你的硬件是否满足要求:
| **模型大小** | **1.5B** | **1.5B** |**1.5B** | **7B** |**7B** | **32B** |
|---------------------|---|---|---|---------------------------|---|---|
| 节点 | 1 | 4 | 16 | 4 | 16 | 16 |
| GPU | 8 张 H800 | 每节点 8 张 H800 |每节点 8 张 H800 | 每节点 8 张 H800 |每节点 8 张 H800 |每节点 8 张 H800 |
| CPU | 48 核 | 每节点 48 核 |每节点 48 核 | 每节点 48 核 |每节点 48 核 | 每节点 48 核 |
| 内存 | 1 TB |每节点 1 TB|每节点 1 TB | 每节点 1 TB |每节点 1 TB | 每节点 1 TB |
| 通信 | NVSwitch |NVSwitch+RoCE 带宽 3.2 Tbps|NVSwitch+RoCE 带宽 3.2 Tbps| NVSwitch+RoCE 带宽 3.2 Tbps |NVSwitch+RoCE 带宽 3.2 Tbps| NVSwitch+RoCE 带宽 3.2 Tbps|
| 存储 | 1TB |共享存储NAS10TB |共享存储NAS10TB | 共享存储NAS10TB |共享存储NAS10TB | 共享存储NAS10TB |
| BatchSize x GroupSize | 512x16 | 512x16 | 512x16 | 512x16 | 512x16 | 512x32|
| 单步训练时间(秒) | **3461** | **997** | **391** | **2275** | **815** | **6707**|
| 训练至收敛需要步数 | **~250** |**~250** |**~250** |**~400** |**~400** | - |
| 总训练时间(小时) | **~240** | **~69** | **~27** | **~252** | **~90** | - |
关于硬件要求的说明:
- GPU 需要 80GB 显存,可以选择同级别其他 GPU 型号。
- 单节点训练时可以使用本地存储,但多节点训练必须要提供共享存储,否则无法进行训练。
- 目前32B模型没有训练出有意义的结果所以无法估计训练到收敛需要的步数和时间。
## 软件要求
本教程提供 Docker镜像。以下是经过测试的软件版本可以参考如下软件版本进行配置。
||版本说明|
|---|---|
|OS|CentOS 7 / Ubuntu 22.04 或其他满足下方软件运行的系统|
|NVIDIA Driver|版本550.127.08|
|CUDA|版本12.5|
|Git LFS|参考:[Git LFS 安装指南](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) 主要用于下载模型数据集AReaL 工程代码|
|Docker|版本27.5.1|
|NVIDIA Container Toolkit|[NVIDIA Container Toolkit 安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)|
|镜像|ghcr.io/inclusionai/areal-runtime:v0.2.0 这个镜像中包含运行依赖和 Ray 的相关组件|
由于 NVIDIA Driver 和 CUDA 的安装以及共享存储的挂载与节点和系统版本有关,请自行完成安装,本教程不进行介绍。
如果是多节点训练,请先将共享存储挂载到每个节点的 `/storage` 目录上,后续下载的内容都将放在这个目录下,并且 AReaL 容器也会将该目录挂载到容器的 `/storage`,以便训练时访问。
# 一键搭建环境并启动训练
本节提供一个一键安装脚本,自动完成节点的环境配置工作:
1. 安装 DockerGit LFSNVIDIA Container Toolkit
2. 在每个节点上拉取 AReaL 镜像
3. 下载 AReaL 代码,模型,数据集
4. 搭建 Ray 集群
5. 【可选】在 Ray 集群中启动一个训练任务
请选择任意一个节点执行如下操作:
```bash
mkdir -p /storage/codes
cd /storage/codes/
git clone https://github.com/inclusionAI/AReaL.git
cd /storage/codes/AReaL
python ./examples/env/setup_env_and_start_train.py setup --private_key_file /path/to/ssh_key --ssh_port 22 --username root --hostnames NODE_IP_1 NODE_IP_2 NODE_IP_3 NODE_IP_4 --train_param 1.5B_n1
```
`setup_env_and_start_train.py setup` 参数说明:
- `private_key_file`SSH 私钥文件,用于连接节点
- `ssh_port`SSH 端口
- `username`SSH 用户名
- `hostnames`IP 列表,用空格分割。可以是 1/4/16 个节点 IP
- `train_param`:【可选】训练参数,用于在完成环境搭建后直接启动一个训练任务。可选值为 `1.5B_n1``1.5B_n4``1.5B_n16``7B_n4``7B_n16`
如果因为环境差异,无法运行本节中的脚本或运行出现错误,也可以按照本教程后续章节的内容手动完成环境配置和启动训练。
# 环境配置
由于使用了共享存储,下载操作只需要在一个节点上完成。
## 代码
将 AReaL 项目代码克隆到 `/storage/codes` 中:
```bash
mkdir -p /storage/codes
cd /storage/codes/
git clone https://github.com/inclusionAI/AReaL.git
```
## 数据集
我们提供了用于训练的数据集,请下载数据集并放置在 /storage/datasets/
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/orz-zero_56k_0319.jsonl?download=true
```
## 模型
我们基于开源模型进行训练,该模型可以从 HuggingFace Hub 直接下载(请确保已经安装了 Git LFS
```
mkdir -p /storage/models
cd /storage/models
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
cd DeepSeek-R1-Distill-Qwen-7B
git lfs pull
```
你也可以在安装 PyPI 和 huggingface_hub 后利用 huggingface CLI 进行下载,具体请参考[官方文档](https://huggingface.co/docs/huggingface_hub/guides/cli)
## 启动 Ray 集群
在执行这一步之前,请先拉取 AReaL 环境镜像,这个镜像中已经包含了 Ray 相关的组件。
在第一个节点上执行如下命令启动 Ray Head
```bash
docker run -d --name r1-ray-head --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --head --port=6379 && tail -f /dev/null"
```
在除了第一个节点以外的每个节点上执行如下命令启动 Ray Worker如果只有一个节点这一步就不用执行了
```bash
# RAY_HEAD_IP 是第一个节点的 IP
RAY_HEAD_IP=xxx.xxx.xxx.xxx
docker run -d --name r1-ray-worker --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --address=$RAY_HEAD_IP:6379 && tail -f /dev/null"
```
全部启动完成后,在第一个节点上通过 docker exec 进入容器,查看 Ray 集群的状态:
```bash
docker exec -it r1-ray-head bash
ray status
```
可以看到 Ray 的资源情况,输出如下(这是一个 16 节点 128 卡的集群,根据你的节点数量,这里的输出会有所不同):
```
======== Autoscaler status: 2025-02-22 14:08:51.061250 ========
Node status
---------------------------------------------------------------
Active:
1 node_d5634ae61bfe6732d957811bed65c8a39f13ece07e0326f941acbc4e
1 node_23b0c08045c9a39bc4c454cae298ee531d9a474215ac5e77a5b01e74
1 node_bc1016320658e92645f29cecb8aaf51c0b7e01a44e8ac9c814dfee59
1 node_4e7d15e9cee9ee0da5d65e45f1e346228c52bc0c557511c6eeab40dc
1 node_c5bcf15e28a00515be5d2a7e8e33d71f0f57cdfaf1003db9e0c74788
1 node_ec3f6ee8f6fdf3a5392bb4dac244668da75d094e084dcbb520ce2525
1 node_dc2f1eef88126ae4ac7902574714af9ab74b78ba037217e73e063639
1 node_a4728608c1fda187dc33bb24e831c42fe5c8a582ad428b6e595933bc
1 node_970379a3ba750ee3b13e31612b6a6b758d50bd4943555b2a13d1bd61
1 node_bf6b658bea9e437fcb642a2d881425662a689d668c92fe1545899b36
1 node_2c69511f410d9360f1d05893fde2c97dd32240e0315afea9b2d286a3
1 node_e4c90c17cc48ad469d123041d3302dcff1f7a82a4805279300812b19
1 node_3f772cbffb206c30b6ccedade83789d78397804bab874ee59563cb96
1 node_429bd5115b5590b612590bb455f2d3ed4f77055d746a184baf807655
1 node_75071820f2c16dc51fa271316b72cd45335ec877c06450d292ab7d54
1 node_6f4323f9038248d82b91321e2c4ca5fa99e65efa2d976c0b896a8964
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/2128.0 CPU
0.0/128.0 GPU
0B/21.08TiB memory
0B/2.91TiB object_store_memory
Demands:
(no resource demands)
```
# RL训练
在进行分布式训练之前,请确保已经启动了 Ray 集群,并且集群状态正常。
然后在第一个节点Ray Head 所在节点),进入容器:
```
docker exec -it r1-ray-head bash
cd /storage/codes/AReaL
```
选择匹配硬件环境的一个配置运行即可:
```bash
python3 -m realhf.apps.quickstart ppo-math --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-128.yaml
```
启动后,在终端可以看到启动日志:
```
╭─────────────────────────────────────────────────╮
│ Setting PPOMATHConfig with the Following Values │
╰─────────────────────────────────────────────────╯
───────────────────────── Current Configuration Begin ──────────────────────────
actor (ModelTrainEvalConfig)
actor.type (ModelFamily)
actor.type._class (str) - qwen2
actor.type.size (int) - 7
actor.type.is_critic (bool) - False
...
────────────────────────── Current Configuration End ───────────────────────────
20250222-10:26:34.877 quickstart INFO: Running ppo-math experiment.
20250222-10:44:15.581 quickstart INFO: Logs will be dumped to /storage/ray/experiments/logs/root/ppo-7B-distill-gpus-128/512x16
20250222-10:44:15.581 quickstart INFO: Model checkpoints will be saved to /storage/ray/experiments/checkpoints/root/ppo-7B-distill-gpus-128/512x16
20250222-10:26:36.408 quickstart INFO: Launching experiments with RAY...
```
如果运行过程中出现错误(比如出现 Error 关键字请参考Troubleshooting解决。
## Commandline Options
```bash
python3 -m realhf.apps.quickstart ppo-math --help
```
其中重要的参数的说明如下:
+ mode总是为 ray参考本教程进行训练时不要改成其他值。
+ {actor|critic|ref}.path模型的路径
+ dataset.path数据集 jsonl 文件的路径
+ external_configs.cluster_config设置 cluster_config 的配置,比如 fileroot 是存放训练输出的根目录。
+ n_nodes节点数量
+ n_gpus_per_node每个节点的 GPU 数量
+ allocation_mode实验中模型的 GPU 分配和 3D 并行策略,推荐的策略有以下形式:
+ `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: 分别配置 SGLang 生成和训练的并行策略,生成和训练分离,使用两部分不同的 GPU。二者所用的GPU数量相加要等于总的 GPU 数量,即 DP1xTP1xPP1+DP2xTP2xPP2=#GPUs。
+ exp_ctrl.total_train_epochs训练的 epoch 数量(即迭代整个数据集的次数)
+ exp_ctrl.save_freq_{epochs|steps|secs}:保存持久化存储模型参数的频率,如果设成 null 会不保存模型
+ exp_ctrl.ckpt_freq_{epochs|steps|secs}:保存临时参数用于重启的频率
+ dataset.train_bs_n_seqs训练的批量大小即每次训练需要采样的 prompt 数量
+ group_size每个 prompt 需要采样的答案数量
+ {actor_train|ref_inf}.mb_spec.max_tokens_per_mbreference模型推理和actor模型训练每次forward/backward数据中最大的token数量可以减小以避免OOM错误。这些数据会累积梯度进行一次参数更新。
+ ppo.ppo_n_minibatches每次PPO更新中会把所有数据划分成多少份以此进行loss计算和参数更新。
+ ppo.gen.max_new_tokens每条prompt生成的最大token数默认训练脚本中为16k。
+ ppo.gen.min_new_tokens每条prompt生成的最小token数默认为0。
## 过程观测
这里以 16 节点的运行日志为例1 节点和 4 节点也一样),说明几个观察训练进度和效果的方法。
### 查看训练进度
搜索日志中的 Epoch 关键字,查看总的 Epoch 数量和 Step 数量:
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.997 master worker INFO: Epoch 1/1 step 1/19 (global step 1) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2124.429*s. Total time consumption: 2283.862s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.719 master worker INFO: Epoch 1/1 step 2/19 (global step 2) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2405.716*s. Total time consumption: 4689.584s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.084 master worker INFO: Epoch 1/1 step 3/19 (global step 3) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2122.318*s. Total time consumption: 6811.949s. Estimated remaining time: 33957.093s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.246 master worker INFO: Epoch 1/1 step 4/19 (global step 4) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2313.134*s. Total time consumption: 9125.111s. Estimated remaining time: 33265.891s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.349 master worker INFO: Epoch 1/1 step 5/19 (global step 5) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2296.076*s. Total time consumption: 11421.214s. Estimated remaining time: 31413.800s.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.864 master worker INFO: Epoch 1/1 step 6/19 (global step 6) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2299.448*s. Total time consumption: 13720.729s. Estimated remaining time: 29350.673s.
```
出现了 6 条日志信息,以最后一条信息的内容说明各个字段的含义:
+ `Epoch 1/1`:表示总共需要训练 1 个 Epochs当前在训练第 1 个。这里作为例子总共只训练 1 个 Epoch正常训练应该是 10 个 Epochs 或者更多。
+ `step 6/19`:表示当前 Epoch 有 19 个 Steps当前在训练第 6 个
+ `global step 6` 表示当前 Step 在所有 Epochs 的 Steps 里的序号
+ `#End to end# execution time: *2299.448*s`:表示当前 Step 训练耗费了 2299.448 秒
+ `Total time consumption: 13720.729s`:从训练启动开始一共耗费了 13720.729 秒
+ `Estimated remaining time: 29350.673s`:预计完成训练还需要 29350.673 秒
### 查看训练的效果
搜索日志中的 `task_reward` 关键字
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.991 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.2640759198111482e-05, 'actor_loss': 1.1128166761409375e-06, 'actor_clip_ratio': 2.1122002635820536e-07, 'importance_weight': 1.0000014305114746, 'task_reward': -0.2996826171875, 'kl_reward': -2.27004832709099e-07, 'final_reward': -0.30145370960235596, 'advantage': 0.003593671601265669, 'avg_seq_len': 7907.8955078125, 'avg_prompt_len': 105.845703125, 'n_tokens': 127828786.0, 'n_valid_tokens': 127828786.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.122802734375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.712 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.493159263394773e-05, 'actor_loss': -3.846728588996484e-07, 'actor_clip_ratio': 3.16789424914532e-07, 'importance_weight': 0.9999996423721313, 'task_reward': -0.6793212890625, 'kl_reward': -2.536311853873485e-07, 'final_reward': -0.6813737154006958, 'advantage': 0.004844569601118565, 'avg_seq_len': 8203.9453125, 'avg_prompt_len': 111.892578125, 'n_tokens': 132580185.0, 'n_valid_tokens': 132580185.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.13812255859375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.077 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.572356243035756e-05, 'actor_loss': -5.036404786551429e-07, 'actor_clip_ratio': 1.8960582792715286e-07, 'importance_weight': 0.9999992251396179, 'task_reward': -0.6280517578125, 'kl_reward': -2.988609537624143e-07, 'final_reward': -0.6303607225418091, 'advantage': 0.004505862481892109, 'avg_seq_len': 7834.6328125, 'avg_prompt_len': 108.900390625, 'n_tokens': 126578395.0, 'n_valid_tokens': 126578395.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.11761474609375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.239 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4861981728463434e-05, 'actor_loss': 1.3935685672095133e-07, 'actor_clip_ratio': 3.02603467616791e-07, 'importance_weight': 0.9999998807907104, 'task_reward': -0.78857421875, 'kl_reward': -3.672174671009998e-07, 'final_reward': -0.791388750076294, 'advantage': 0.005053278990089893, 'avg_seq_len': 7773.39404296875, 'avg_prompt_len': 108.7890625, 'n_tokens': 125576883.0, 'n_valid_tokens': 125576883.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.117919921875, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.342 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.516058702894952e-05, 'actor_loss': -7.665488510610885e-07, 'actor_clip_ratio': 1.9505058901359007e-07, 'importance_weight': 0.9999997615814209, 'task_reward': -0.6158447265625, 'kl_reward': -4.6867208425283025e-07, 'final_reward': -0.6195111274719238, 'advantage': 0.004475570283830166, 'avg_seq_len': 7928.50830078125, 'avg_prompt_len': 105.517578125, 'n_tokens': 128171874.0, 'n_valid_tokens': 128171874.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12353515625, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.857 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4821250917739235e-05, 'actor_loss': -3.922649227661168e-07, 'actor_clip_ratio': 3.323623900541861e-07, 'importance_weight': 1.0000001192092896, 'task_reward': -0.7025146484375, 'kl_reward': -5.863367960046162e-07, 'final_reward': -0.7071446776390076, 'advantage': 0.004277692176401615, 'avg_seq_len': 8002.4873046875, 'avg_prompt_len': 105.951171875, 'n_tokens': 129376851.0, 'n_valid_tokens': 129376851.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12286376953125, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
```
以最后一条说明其中几个重点字段的含义:
+ `task_reward`这个step中采样的所有答案的平均奖励值训练稳步进行的话这个值会持续上升最终维持不变
+ `importance_weight`: PPO loss中重要性采样比率在所有token上的平均值通常接近1。
+ `actor_clip_ratio`: PPO loss中被clip掉的token占所有token的比率通常小于0.1。
+ `actor_loss`: PPO loss**不会随着训练过程有明显的上升或下降趋势**,不应作为模型表现的参考。
+ `avg_seq_len`: 这一步中采样的所有序列(即提示词和答案相加)的平均长度。在完整的多阶段训练中,这个值会先下降再上升。
+ `no_eos_ratio`: 这一步中采样的所有答案因为超出最大生成长度被截断的比率。这个值上升也代表了答案的平均长度在上升。
# 评估
## 评估流程
评估代码包含在仓库的`evaluation`文件夹中。按照以上的教程训练得到的checkpoint会保存在`/storage/ray/experiments/checkpoints/root/`路径下,例如`/storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/`。
启动一个新的容器用于运行评估脚本(评估需要更新部分 python 库,请不要在训练容器中进行):
```
docker run -d --name r1-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "tail -f /dev/null"
docker exec -it r1-eval bash
```
在docker容器内部运行以下脚本进行评估
```bash
cd /storage/codes/AReaL/evaluation
cd latex2sympy
pip install -e .
cd ..
pip install -r requirements.txt
pip install vllm --no-build-isolation
pip install transformers==4.47.0
pip install prettytable timeout_decorator
mkdir /storage/ray/eval_output/
nohup python eval_and_aggregate.py \
--model_path /storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/ \
--output_path /storage/ray/eval_output/ \
--data_names "math_500,aime24,amc23" \
--max_gen_tokens 32768 &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
```
+ `--model_path`:模型参数的保存路径
+ `--output_path`:评估过程中生成的答案和日志文件路径
+ `--data_names`: 可以指定评测某个数据,多个数据集用逗号隔开,默认为 math_500, aime24, amc23
+ `--max_gen_tokens`:最长的答案生成长度,默认值 32768
## 评估结果
评估脚本运行完后会在 /storage/ray/eval_output/eval_and_aggregate_parallel.log 日志文件输出一个表格,例如:
```
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
| dataset | num_questions | greedy_length | sample_length | greedy_acc | sample_pass@1 | pass@8 | pass@16 |
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
| math_500 | 500 | 6757.4 | 4139.5 | 84.4 | 92.7 | 97.3 | 97.7 |
| aime24 | 30 | 19328.0 | 13663.5 | 50.0 | 50.4 | 77.3 | 80.0 |
| amc23 | 40 | 8850.0 | 6526.2 | 80.0 | 90.5 | 96.8 | 98.8 |
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
```
+ `{greedy|sample}_length`: 在greedy或随机采样策略下生成的平均答案长度
+ `greedy_acc`在greedy采样下的平均准确率
+ `sample_pass@{k}`在随机采样下平均每k个答案产生正确答案的概率
## 额外说明
### 关键参数
+ 我们提供的评估脚本默认采样32次取平均值采样温度值为0.6
+ 我们发现vLLM的`enforce_eager`参数很大程度影响评估性能,当`enforce_eager=True`时我们才能够复现先前工作汇报的模型表现,否则评估结果会低于先前工作汇报的结果,因此我们会在执行 `eval_and_aggregate_parallel.py` 时将`enforce_eager`强制开启。
由于以上原因,评估过程通常会消耗较长时间。
### 运行时间
评估的运行时间取决于最长生成长度、数据集的题目数量和模型大小等等。在1台8*H100机器上7B模型数据集为`math_500,aime24,amc23`生成长度为32768评估脚本运行时间为 5 个小时。
# Troubleshooting
如果以下内容没有解答你的问题,欢迎在 GitHub Issue 中进行提问。
## 自动恢复
当设置了 `recover_mode=auto` 并且训练配置和之前相同AReaL 会尝试找到之前生成的 checkpoints 并且从这个 checkpoints 恢复训练。
如果自动恢复失败,有这些可能性:
+ 训练配置里的 `experiment_name``trial_name` 与之前的不一样
+ Batch Size参数里的 `dataset.train_bs_n_seqs`Group Size参数里的 `group_size`),节点数(参数里的 `n_nodes`)三个值发生了变化
+ 之前的训练没有创建过 recover checkpoint 。默认的 recover checkpoint 规则有 2 个:
+ 从第 2 个 step 完成后才生成 recover checkpoint
+ 一个 step 训练完成,且距离上次 recover checkpoint 时间超过 600s则生成一个新的 recover checkpoint。这个参数在 `./examples/configs/*/*.yaml` 文件里,参数名为 `exp_ctrl.ckpt_freq_secs=600`。
可以通过搜索 `Dumped recover` 确认是否生成过 recover checkpoint
```bash
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.760 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.105 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.264 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.411 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.883 master worker INFO: Dumped recover info to file.
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:59:44.925 master worker INFO: Dumped recover info to file.
```
## 一系列OutOfMemory错误
我们提供的脚本已经尽最大努力避免了OOM错误的发生但是OOM问题仍然会随着训练进行在内存碎片增加和生成序列长度越来越长时偶尔发生。虽然这些问题通常可以通过自动重启解决当重启频繁时用户还可以尝试以下针对性的解决方式。
### torch.cuda.CudaOutOfMemoryError
解决这个问题的关键是定位错误发生的阶段。
- 如果发生在初始化阶段在进入到actor_gen之前:
- 检查当前GPU上是否存在残留进程。在分布式场景下可以通过重启ray cluster解决在单机场景下可以通过pkill解决。
- 该错误通常不会发生在actor_gen阶段。
- 如果发生在ref_inf或actor_train阶段
- 改变相应计算任务的microbatch大小例如`actor_train.mb_spec.max_tokens_per_mb=20480`这个参数代表每次模型forward/backward的数据最多只会包含20480个token这个值最小可以设为生成序列的最长长度包括prompt
- 改变模型的并行策略,即`allocation_mode`,可以尝试减少数据并行的大小,增加张量或流水线并行的大小。
### CUDA error: out of memory
这个问题可能会发生在vLLM初始化CPU KV cache时表示每台机器的内存不够了。可以减小`actor.vllm.swap_space`解决。
### RuntimeError: Aborted due to the lack of CPU swap space.
问题的原因是序列长、对KV cache需求大在GPU显存不够时KV cache会被卸载到内存而内存中设置的swap space不够。这个问题和[Preemption的报错](https://docs.vllm.ai/en/latest/performance/optimization.html)紧密相关。解决方案是增加`actor.vllm.swap_space`,如果同样的错误出现,请减少`actor.vllm.max_num_seqs`并参考[vLLM官方文档](https://docs.vllm.ai/en/latest/performance/optimization.html)。
### CUDA error: an illegal memory access was encountered
通常会在vLLM生成阶段出现同样是显存不足的一种表现。解决方案包括
+ 减小训练batch size或者每个prompt生成的答案数量但减小后会降低样本效率、延长训练时间
+ [将vLLM的attention backend换成xformers](https://github.com/vllm-project/vllm/issues/5376)

View File

@ -1,3 +1,501 @@
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 60091b9a..7a1c856b 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -15,6 +15,7 @@
import dataclasses
import logging
+import os
from typing import List, Optional, Union
import torch
@@ -44,6 +45,15 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.utils import dump_to_file
+# When compute the input and output tokens logprobs, if the rows of the
+# logprobs are too large, the peak memory usage will be too high. For example,
+# if the logprobs are [10000, 150000], the peak memory usage will be greater
+# than 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB. (4 is the size of float)
+# So we split the logprobs into multiple chunks.
+LOGITS_PROCESSER_CHUNK_SIZE = int(
+ os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048")
+)
+
logger = logging.getLogger(__name__)
@@ -286,15 +296,17 @@ class LogitsProcessor(nn.Module):
input_logprob_indices = None
else:
# Input logprobs are required.
- # Find 3 different indices.
+ # Find 4 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
+ # 4. sequence_index_mapping: map pruned_states indices to top_logprobs_nums and token_ids_logprobs indices.
sample_index_pt = -1
sample_indices = []
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
+ idx, sequence_index_mapping = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
@@ -310,7 +322,11 @@ class LogitsProcessor(nn.Module):
# by a caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+ # sequence_index_mapping, repeat this for loop index
+ sequence_index_mapping.extend([idx] * (extend_len - start_len))
+ idx += 1
pt += extend_len
+
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
input_logprob_indices.extend(
@@ -321,6 +337,7 @@ class LogitsProcessor(nn.Module):
)
input_logprob_indices_pt += extend_len - start_len
+ sequence_index_mapping.append(idx - 1)
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
@@ -329,12 +346,6 @@ class LogitsProcessor(nn.Module):
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
- # Compute logits for both input and sampled tokens.
- logits = self._get_logits(pruned_states, lm_head, logits_metadata)
- sampled_logits = (
- logits[sample_indices] if sample_indices is not None else logits
- )
-
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather
@@ -370,67 +381,176 @@ class LogitsProcessor(nn.Module):
else:
assert False, "Should never reach"
+ del hidden_states
+
if not logits_metadata.extend_return_logprob:
+ # Compute logits for both input and sampled tokens.
+ logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+ sampled_logits = (
+ logits[sample_indices] if sample_indices is not None else logits
+ )
+
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
)
else:
- input_logprobs = logits[input_logprob_indices]
- del hidden_states, logits
+ # Compute logprobs requires lot of memory, so we split pruned_states
+ # into chunks of rows to compute input_logprobs separately, then
+ # concatenate the results.
+ return self._compute_output_by_chunk(
+ pruned_states,
+ sample_indices,
+ hidden_states_to_store,
+ input_logprob_indices,
+ sequence_index_mapping,
+ lm_head,
+ logits_metadata,
+ )
- # Normalize the logprob w/o temperature, top-p
- pruned_lens = torch.tensor(
- logits_metadata.extend_logprob_pruned_lens_cpu,
- device=input_logprobs.device,
+ def _compute_output_by_chunk(
+ self,
+ pruned_states: torch.Tensor,
+ sample_indices: torch.Tensor,
+ hidden_states_to_store: Optional[torch.Tensor],
+ input_logprob_indices: torch.Tensor,
+ index_mapping: list[int],
+ lm_head: VocabParallelEmbedding,
+ logits_metadata: LogitsMetadata,
+ ) -> LogitsProcessorOutput:
+ """
+ compute logprobs for the output token from the hidden states.
+ To avoid using too much memory, we split pruned_states into chunks of
+ rows to compute input_logprobs separately, then concatenate the results.
+
+ Returns:
+ LogitsProcessorOutput: logits processor output class
+ """
+
+ # Normalize the logprob w/o temperature, top-p
+ pruned_lens = torch.tensor(
+ logits_metadata.extend_logprob_pruned_lens_cpu,
+ device=pruned_states.device,
+ )
+ if logits_metadata.temp_scaled_logprobs:
+ logits_metadata.temperature = torch.repeat_interleave(
+ logits_metadata.temperature.view(-1),
+ pruned_lens,
+ ).view(-1, 1)
+ if logits_metadata.top_p_normalized_logprobs:
+ logits_metadata.top_p = torch.repeat_interleave(
+ logits_metadata.top_p,
+ pruned_lens,
)
- if logits_metadata.temp_scaled_logprobs:
- logits_metadata.temperature = torch.repeat_interleave(
- logits_metadata.temperature.view(-1),
- pruned_lens,
- ).view(-1, 1)
- if logits_metadata.top_p_normalized_logprobs:
- logits_metadata.top_p = torch.repeat_interleave(
- logits_metadata.top_p,
- pruned_lens,
- )
- input_logprobs = self.compute_temp_top_p_normalized_logprobs(
- input_logprobs, logits_metadata
+
+ # The peak memory usage is proportional to the chunk size.
+ chunk_size = LOGITS_PROCESSER_CHUNK_SIZE
+ num_chunks = (pruned_states.shape[0] + chunk_size - 1) // chunk_size
+
+ input_token_logprobs = []
+ if logits_metadata.extend_return_top_logprob:
+ input_top_logprobs_val = []
+ input_top_logprobs_idx = []
+ else:
+ input_top_logprobs_val = None
+ input_top_logprobs_idx = None
+
+ if logits_metadata.extend_token_ids_logprob:
+ input_token_ids_logprobs_val = []
+ input_token_ids_logprobs_idx = []
+ else:
+ input_token_ids_logprobs_val = None
+ input_token_ids_logprobs_idx = None
+
+ # It a single sequence is split into multiple chunks, we need to keep track
+ # of the pruned length of the sequences in the previous chunks.
+ split_len_topk = 0
+ split_len_token_ids = 0
+
+ for i in range(num_chunks):
+ start_idx = i * chunk_size
+ end_idx = min((i + 1) * chunk_size, pruned_states.shape[0])
+
+ # Get indices for this chunk
+ chunk_mask = (input_logprob_indices >= start_idx) & (
+ input_logprob_indices < end_idx
+ )
+
+ global_indices = input_logprob_indices[chunk_mask]
+ chunk_indices = global_indices - start_idx
+ chunk_states = pruned_states[start_idx:end_idx]
+ chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)
+
+ if chunk_indices.numel() == 0:
+ continue
+
+ # Compute the logprobs of the chunk
+ chunk_input_logprobs = chunk_logits[chunk_indices]
+ chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+ chunk_input_logprobs, global_indices, logits_metadata
)
+ # For each chunk, we need to get the slice of the sequence_index_mapping
+ chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1)
+
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
- (
+ split_len_topk = self.get_top_logprobs(
+ chunk_input_logprobs,
+ logits_metadata,
+ chunk_slice,
input_top_logprobs_val,
input_top_logprobs_idx,
- ) = self.get_top_logprobs(input_logprobs, logits_metadata)
- else:
- input_top_logprobs_val = input_top_logprobs_idx = None
+ split_len_topk,
+ )
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
- (
+ split_len_token_ids = self.get_token_ids_logprobs(
+ chunk_input_logprobs,
+ logits_metadata,
+ chunk_slice,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
- ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
- else:
- input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
-
- input_token_logprobs = input_logprobs[
- torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
- logits_metadata.extend_input_logprob_token_ids_gpu,
- ]
+ split_len_token_ids,
+ )
- return LogitsProcessorOutput(
- next_token_logits=sampled_logits,
- input_token_logprobs=input_token_logprobs,
- input_top_logprobs_val=input_top_logprobs_val,
- input_top_logprobs_idx=input_top_logprobs_idx,
- hidden_states=hidden_states_to_store,
- input_token_ids_logprobs_val=input_token_ids_logprobs_val,
- input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+ # Handle sampled logits for the chunk if needed
+ chunk_sample_mask = (sample_indices >= start_idx) & (
+ sample_indices < end_idx
)
+ if i == 0: # Initialize sampled_logits on first chunk
+ sampled_logits = torch.empty(
+ (sample_indices.shape[0], chunk_logits.shape[1]),
+ dtype=chunk_logits.dtype,
+ device=chunk_logits.device,
+ )
+ if chunk_sample_mask.any():
+ chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
+ sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]
+
+ # Get the logprob of the requested token ids
+ chunk_input_token_logprobs = chunk_input_logprobs[
+ torch.arange(
+ chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
+ ),
+ logits_metadata.extend_input_logprob_token_ids_gpu[start_idx:end_idx],
+ ]
+ input_token_logprobs.append(chunk_input_token_logprobs)
+
+ # Concatenate the results
+ input_token_logprobs = torch.cat(input_token_logprobs, dim=0)
+
+ return LogitsProcessorOutput(
+ hidden_states=hidden_states_to_store,
+ next_token_logits=sampled_logits,
+ input_token_logprobs=input_token_logprobs,
+ input_top_logprobs_val=input_top_logprobs_val,
+ input_top_logprobs_idx=input_top_logprobs_idx,
+ input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+ )
def _get_logits(
self,
@@ -498,60 +618,142 @@ class LogitsProcessor(nn.Module):
return logits
@staticmethod
- def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+ def get_top_logprobs(
+ logprobs: torch.Tensor,
+ logits_metadata: LogitsMetadata,
+ chunk_slice: slice,
+ input_top_logprobs_val: List,
+ input_top_logprobs_idx: List,
+ split_pruned_len: int,
+ ):
+ """Get top-k logprobs for each sequence in the chunk.
+
+ Args:
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
+ logits_metadata: Metadata containing top-k and pruned length info
+ chunk_slice: Slice of sequences to process
+ input_top_logprobs_val: List to store top-k logprob values
+ input_top_logprobs_idx: List to store top-k token indices
+ split_pruned_len: Length of pruned tokens from previous chunk
+
+ Returns:
+ int: Number of remaining tokens to process in next chunk
+ """
+
max_k = max(logits_metadata.top_logprobs_nums)
- ret = all_logprobs.topk(max_k, dim=1)
+ ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
- input_top_logprobs_val, input_top_logprobs_idx = [], []
-
pt = 0
- for k, pruned_len in zip(
- logits_metadata.top_logprobs_nums,
- logits_metadata.extend_logprob_pruned_lens_cpu,
- ):
+ next_split_pruned_len = 0
+ top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
+
+ for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
+ # Adjust pruned length for first sequence
+ if n == 0:
+ pruned_len -= split_pruned_len
+ else:
+ split_pruned_len = 0
+
if pruned_len <= 0:
- input_top_logprobs_val.append([])
- input_top_logprobs_idx.append([])
+ if n == 0:
+ input_top_logprobs_val.append([])
+ input_top_logprobs_idx.append([])
continue
- input_top_logprobs_val.append(
- [values[pt + j][:k] for j in range(pruned_len)]
- )
- input_top_logprobs_idx.append(
- [indices[pt + j][:k] for j in range(pruned_len)]
- )
- pt += pruned_len
+ val = []
+ idx = []
+ for j in range(pruned_len):
+ # Handle remaining tokens in next chunk if any
+ if pt + j >= len(values):
+ next_split_pruned_len = split_pruned_len + j
+ break
+ val.append(values[pt + j][:k])
+ idx.append(indices[pt + j][:k])
+
+ if split_pruned_len <= 0 and len(val) > 0:
+ input_top_logprobs_val.append(val)
+ input_top_logprobs_idx.append(idx)
+ else:
+ input_top_logprobs_val[-1].extend(val)
+ input_top_logprobs_idx[-1].extend(idx)
- return input_top_logprobs_val, input_top_logprobs_idx
+ pt += pruned_len
+ return next_split_pruned_len
@staticmethod
def get_token_ids_logprobs(
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+ logprobs: torch.Tensor,
+ logits_metadata: LogitsMetadata,
+ chunk_slice: slice,
+ input_token_ids_logprobs_val: List,
+ input_token_ids_logprobs_idx: List,
+ split_pruned_len: int = 0,
):
- input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
+ """Get token_ids logprobs for each sequence in the chunk.
+
+ Args:
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
+ logits_metadata: Metadata containing token IDs and pruned length info
+ chunk_slice: Slice of sequences to process
+ input_token_ids_logprobs_val: List to store token logprob values
+ input_token_ids_logprobs_idx: List to store token indices
+ split_pruned_len: Length of pruned tokens from previous chunk
+
+ Returns:
+ int: Number of remaining tokens to process in next chunk
+ """
pt = 0
- for token_ids, pruned_len in zip(
- logits_metadata.token_ids_logprobs,
- logits_metadata.extend_logprob_pruned_lens_cpu,
+ next_split_pruned_len = 0
+ token_ids_logprobs_chunk = logits_metadata.token_ids_logprobs[chunk_slice]
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
+
+ for n, (token_ids, pruned_len) in enumerate(
+ zip(
+ token_ids_logprobs_chunk,
+ pruned_lens,
+ )
):
+ # Adjust pruned length for first sequence
+ if n == 0:
+ pruned_len -= split_pruned_len
+ else:
+ split_pruned_len = 0
+
if pruned_len <= 0:
- input_token_ids_logprobs_val.append([])
- input_token_ids_logprobs_idx.append([])
+ if n == 0:
+ input_token_ids_logprobs_val.append([])
+ input_token_ids_logprobs_idx.append([])
continue
- input_token_ids_logprobs_val.append(
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
- )
- input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
- pt += pruned_len
+ val = []
+ idx = []
+ for j in range(pruned_len):
+ # Handle remaining tokens in next chunk if any
+ if pt + j >= logprobs.shape[0]:
+ next_split_pruned_len = split_pruned_len + j
+ break
+ if token_ids is not None:
+ val.append(logprobs[pt + j, token_ids].tolist())
+ idx.append(token_ids)
+
+ if split_pruned_len <= 0 and len(val) > 0:
+ input_token_ids_logprobs_val.append(val)
+ input_token_ids_logprobs_idx.append(idx)
+ elif len(val) > 0:
+ input_token_ids_logprobs_val[-1].extend(val)
+ input_token_ids_logprobs_idx[-1].extend(idx)
- return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
+ pt += pruned_len
+ return next_split_pruned_len
@staticmethod
def compute_temp_top_p_normalized_logprobs(
- last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+ last_logits: torch.Tensor,
+ indices: torch.Tensor,
+ logits_metadata: LogitsMetadata,
) -> torch.Tensor:
"""
compute logprobs for the output token from the given logits.
@@ -561,19 +763,20 @@ class LogitsProcessor(nn.Module):
"""
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
- last_logits = last_logits / logits_metadata.temperature
+ last_logits = last_logits / logits_metadata.temperature[indices]
+
+ top_p = None
+ if logits_metadata.top_p is not None:
+ top_p = logits_metadata.top_p[indices]
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
- if (
- logits_metadata.top_p_normalized_logprobs
- and (logits_metadata.top_p != 1.0).any()
- ):
+ if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
- probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
+ probs = top_p_normalize_probs_torch(probs, top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py

View File

@ -1,5 +1,5 @@
[build-system]
-requires = ["setuptools>=61.0", "packaging", "torch", "pybind11>=2.10.0", "build>=1.2.1"]
+requires = ["setuptools>=61.0", "build>=1.2.1"]
build-backend = "setuptools.build_meta"

[project]

View File

@ -3,21 +3,5 @@
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
# Initialize preset config before all submodules. # Initialize preset config before all submodules.
from .base import prologue # isort: skip
from .api.cli_args import *
# Re-import these classes for clear documentation,
# otherwise the name will have a long prefix.
from .api.core.config import ModelName, ModelShardID
from .api.core.data_api import SequenceSample
from .api.core.dfg import MFCDef
from .api.core.model_api import (
FinetuneSpec,
Model,
ModelBackend,
ModelInterface,
PipelinableEngine,
ReaLModelConfig,
)
from .version import __version__ from .version import __version__

View File

@ -1,3 +1,4 @@
import getpass
import os import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
@ -5,7 +6,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union
from omegaconf import MISSING from omegaconf import MISSING
from realhf.base import pkg_version from realhf.base import pkg_version
from realhf.base.cluster import spec as cluster_spec
## Data and datasets. ## ## Data and datasets. ##
@ -847,6 +847,57 @@ class TensorBoardConfig:
path: Optional[str] = None path: Optional[str] = None
def get_user_tmp():
user = getpass.getuser()
user_tmp = os.path.join("/home", user, ".cache", "realhf")
os.makedirs(user_tmp, exist_ok=True)
return user_tmp
@dataclass
class ClusterSpecConfig:
config_path: str = field(
default="",
metadata={
"help": "JSON config path. If not given, use the following CLI args."
},
)
cluster_name: str = field(
default="local",
metadata={"help": "Name of the cluster. Used to set specific environs."},
)
fileroot: str = field(
default=get_user_tmp(),
metadata={
"help": "Root for logs and checkpoints. Should be available to all nodes."
},
)
gpu_type: str = field(
default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
gpu_infer_image: str = field(
default="", metadata={"help": "slurm image for LLM inference."}
)
node_name_prefix: str = field(
default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
)
n_nodes: int = field(
default=32,
metadata={
"help": "The size of the cluster. Used to decide slurm hostname suffix."
},
)
n_gpus_per_node: int = field(
default=8,
metadata={"help": "GPUs per node (physically)."},
)
@dataclass @dataclass
class BaseExperimentConfig: class BaseExperimentConfig:
"""Configuration for quickstart experiments. """Configuration for quickstart experiments.
@ -935,21 +986,20 @@ class BaseExperimentConfig:
default=1, metadata={"help": "Number of nodes for experiment."} default=1, metadata={"help": "Number of nodes for experiment."}
) )
n_gpus_per_node: int = field( n_gpus_per_node: int = field(
default=cluster_spec.n_gpus_per_node, default=8, metadata={"help": "Number of GPUs per node for this experiment."}
metadata={"help": "GPUs per node. Total GPUs = n_nodes * n_gpus_per_node."},
) )
nodelist: Optional[str] = field( nodelist: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "SLURM nodelist for manual allocation. " "help": "SLURM nodelist for manual allocation. "
"Format: 'NODE01:0,1,2,3' or 'NODE[01-02,03,07],COM08'." "Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'."
}, },
) )
exclude: Optional[str] = field( exclude: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "SLURM nodelist to exclude from allocation. " "help": "SLURM nodelist to exclude from allocation. "
"Format: 'NODE01:0,1,2,3' or 'NODE[01-02,03,07],COM08'." "Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'."
}, },
) )
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."}) seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
@ -996,6 +1046,13 @@ class BaseExperimentConfig:
shuffle_dataset: bool = field( shuffle_dataset: bool = field(
default=True, metadata={"help": "Shuffle in each epoch."} default=True, metadata={"help": "Shuffle in each epoch."}
) )
ray_temp_path: str = field(
default="/tmp/ray", metadata={"help": "Absolute path for Ray's log."}
)
cluster: ClusterSpecConfig = field(
default_factory=ClusterSpecConfig,
metadata={"help": "Cluster specification. Mainly used by slurm."},
)
## Configuration options of asynchronous experiments. ## ## Configuration options of asynchronous experiments. ##
@ -1033,7 +1090,7 @@ class AsyncRLOptions:
}, },
) )
flush_request_timeout: int = field( flush_request_timeout: int = field(
default=120, default=300,
metadata={"help": "The timeout of flushing requests upon weight update."}, metadata={"help": "The timeout of flushing requests upon weight update."},
) )

View File

@ -723,6 +723,7 @@ class SequenceSample:
class DataBatchMeta: class DataBatchMeta:
dp_rank: int dp_rank: int
meta_sample: SequenceSample | None meta_sample: SequenceSample | None
birth_times: List
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -4,6 +4,7 @@
import dataclasses import dataclasses
import os import os
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import realhf.api.core.dfg as dfg import realhf.api.core.dfg as dfg
@ -26,16 +27,21 @@ from realhf.base import constants, topology
from realhf.base.cluster import spec as cluster_spec from realhf.base.cluster import spec as cluster_spec
class ExpStatus(Enum):
RUNNING = "RUNNING"
ABORTED = "ABORTED"
COMPLETE = "COMPLETE"
@dataclasses.dataclass @dataclasses.dataclass
class Scheduling: class Scheduling:
# TODO: add partition
cpu: int cpu: int
gpu: int gpu: int
mem: int mem: int
gpu_type: str = "tesla"
node_type: str = None
nodelist: str = None nodelist: str = None
exclude: str = None exclude: str = None
container_image: str = cluster_spec.cpu_image container_image: str = None
env_vars: Dict[str, str] = dataclasses.field(default_factory=dict) env_vars: Dict[str, str] = dataclasses.field(default_factory=dict)
# time utils from "https://slurm.schedmd.com/sbatch.html" # time utils from "https://slurm.schedmd.com/sbatch.html"
time_limit: Optional[str] = None # see "--time" option for format time_limit: Optional[str] = None # see "--time" option for format
@ -241,7 +247,7 @@ class ExperimentScheduling:
generation_server: TasksGroup | None = None generation_server: TasksGroup | None = None
gserver_manager: TasksGroup | None = None gserver_manager: TasksGroup | None = None
rollout_worker: TasksGroup | None = None rollout_worker: TasksGroup | None = None
controller_image: str = cluster_spec.cpu_image controller_image: str = None
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -184,10 +184,10 @@ def make_device_mesh_from_name(
""" """
DeviceMesh name format: <prefix><node_indices>[:<gpu_ids>] DeviceMesh name format: <prefix><node_indices>[:<gpu_ids>]
slurm_nodelist is the name of slurm nodes the mesh is on, should follow slurm convention, slurm_nodelist is the name of slurm nodes the mesh is on, should follow slurm convention,
for example "NODE[40-43]" or "NODE[01,11,13-14]" with prefix NODE, for example "slurmd-[40-43]" or "slurmd-[01,11,13-14]" with prefix slurmd-,
if n_nodes=1, gpu_ids are the gpu id list delimited by comma if n_gpus < n_gpus_per_node, if n_nodes=1, gpu_ids are the gpu id list delimited by comma if n_gpus < n_gpus_per_node,
for example "0,1,2,3" or "0,1". An example of full device mesh name for example "0,1,2,3" or "0,1". An example of full device mesh name
in this situation is "NODE40:0,1,2,3" in this situation is "slurmd-40:0,1,2,3"
Note: cluster device mesh name must occupy entire nodes. Note: cluster device mesh name must occupy entire nodes.
""" """

View File

@ -16,13 +16,14 @@ from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf from omegaconf import MISSING, OmegaConf
import realhf.api.core.system_api as system_api import realhf.api.core.system_api as system_api
from realhf.api.cli_args import print_runtime_helper from realhf.base.constants import init_constants
from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT, QUICKSTART_EXPR_CACHE_PATH
from realhf.base.ray_utils import check_ray_availability from realhf.base.ray_utils import check_ray_availability
from realhf.base.slurm_utils import check_slurm_availability from realhf.base.slurm_utils import check_slurm_availability
def kind_reminder(config_name, logger, args): def kind_reminder(config_name, logger, args):
from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT
logger.info(f"Running {config_name} experiment.") logger.info(f"Running {config_name} experiment.")
logger.info( logger.info(
f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}" f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}"
@ -81,9 +82,13 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
trial_name = args.trial_name trial_name = args.trial_name
from realhf.apps.main import main_start, main_stop from realhf.apps.main import main_start, main_stop
init_constants(args)
from realhf.base.constants import LOG_ROOT, QUICKSTART_EXPR_CACHE_PATH
config_save_path = os.path.join( config_save_path = os.path.join(
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml" LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
) )
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
with open(config_save_path, "w") as f: with open(config_save_path, "w") as f:
yaml.dump( yaml.dump(
dataclasses.asdict(OmegaConf.to_object(args)), dataclasses.asdict(OmegaConf.to_object(args)),

View File

@ -11,13 +11,10 @@ import uuid
from typing import Dict, List, Optional from typing import Dict, List, Optional
import realhf.api.core.system_api as config_package import realhf.api.core.system_api as config_package
import realhf.base.constants as constants
import realhf.base.logging as logging
import realhf.base.name_resolve as name_resolve
import realhf.base.names as names
import realhf.base.recover as recover
import realhf.scheduler.client as sched_client import realhf.scheduler.client as sched_client
import realhf.system as system import realhf.system as system
from realhf.api.core.system_api import ExpStatus
from realhf.base import constants, logging, name_resolve, names, recover
from realhf.scheduler.client import JobException, JobState from realhf.scheduler.client import JobException, JobState
from realhf.scheduler.evaluator import AutomaticEvaluator from realhf.scheduler.evaluator import AutomaticEvaluator
from realhf.version import get_full_version_with_dirty_description from realhf.version import get_full_version_with_dirty_description
@ -55,7 +52,6 @@ def _submit_workers(
nodelist = sch_cfg.scheduling.nodelist nodelist = sch_cfg.scheduling.nodelist
exclude = sch_cfg.scheduling.exclude exclude = sch_cfg.scheduling.exclude
node_type = sch_cfg.scheduling.node_type
container_image = image_name or sch_cfg.scheduling.container_image container_image = image_name or sch_cfg.scheduling.container_image
scheduled_jobs.append( scheduled_jobs.append(
@ -65,10 +61,8 @@ def _submit_workers(
count=sch_cfg.count, count=sch_cfg.count,
cpu=sch_cfg.scheduling.cpu, cpu=sch_cfg.scheduling.cpu,
gpu=sch_cfg.scheduling.gpu, gpu=sch_cfg.scheduling.gpu,
gpu_type=sch_cfg.scheduling.gpu_type,
mem=sch_cfg.scheduling.mem, mem=sch_cfg.scheduling.mem,
container_image=container_image, container_image=container_image,
node_type=node_type,
nodelist=nodelist, nodelist=nodelist,
exclude=exclude, exclude=exclude,
env_vars=job_environs, env_vars=job_environs,
@ -147,18 +141,15 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "") cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "")
if not cluster_spec_path: if not cluster_spec_path:
if args.mode == "slurm": logger.info(
raise ValueError(
"Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! "
"See example/cluster_config.json for a template."
)
logger.warning(
"Environment variable CLUSTER_SPEC_PATH is not set. " "Environment variable CLUSTER_SPEC_PATH is not set. "
"Files of the experiment (logs, checkpoints, cache ...) " "Will use the fileroot specified in CLI args. "
"will be saved to temporary directory of the system. " )
"To change the fileroot, set the fileroot option of your choice in your CLUSTER_SPEC_PATH." else:
logger.warning(
"Environment variable CLUSTER_SPEC_PATH is set. "
"Will overwrite the cluster spec in CLI args. "
) )
# set env vars # set env vars
BASE_ENVIRONS = constants.get_env_vars( BASE_ENVIRONS = constants.get_env_vars(
REAL_MODE=args.mode.upper(), REAL_MODE=args.mode.upper(),
@ -273,6 +264,29 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
) )
recover_this = recover_this and reason in recover_states recover_this = recover_this and reason in recover_states
# Check whether this exception is caused by experiment finish.
name = names.experiment_status(
constants.experiment_name(), constants.trial_name()
)
try:
exp_status = name_resolve.get(name)
recover_this = recover_this and exp_status != str(ExpStatus.COMPLETE)
if exp_status == str(ExpStatus.COMPLETE):
logger.warning("*" * 100)
logger.warning(
"*"
+ f"Will not recover because the experiment has completed! Congrats!".center(
98, " "
)
+ "*"
)
logger.warning("*" * 100)
except name_resolve.NameEntryNotFoundError:
raise name_resolve.NameEntryNotFoundError(
f"Experiment status not found during recover. "
"This indicates that the master worker is not running. Exit the recover loop."
)
kill_signal = ( kill_signal = (
"SIGKILL" if args.mode == "slurm" else "SIGTERM" "SIGKILL" if args.mode == "slurm" else "SIGTERM"
) # use sigkill to terminate slurm jobs ) # use sigkill to terminate slurm jobs
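
For context, a minimal sketch (assumed call site, not part of this diff) of how the master worker could publish the COMPLETE status that the recover check above reads; it assumes the module-level name_resolve.add mirrors the repository add() API shown later in this PR.

from realhf.api.core.system_api import ExpStatus
from realhf.base import constants, name_resolve, names

def mark_experiment_complete():
    # Written by the master worker; read by main_start's recover loop above.
    key = names.experiment_status(constants.experiment_name(), constants.trial_name())
    name_resolve.add(key, str(ExpStatus.COMPLETE), replace=True)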

View File

@ -161,20 +161,6 @@ def launch_hydra_task(
if not any("hydra/job_logging=disabled" in x for x in sys.argv): if not any("hydra/job_logging=disabled" in x for x in sys.argv):
sys.argv.insert(2, "hydra/job_logging=disabled") sys.argv.insert(2, "hydra/job_logging=disabled")
if (
"--multirun" in sys.argv
or "hydra.mode=MULTIRUN" in sys.argv
or "-m" in sys.argv
):
raise NotImplementedError("Hydra multi-run is not supported.")
# non-multirun mode, add hydra run dir
sys.argv.insert(
2,
f"hydra.run.dir={cluster_spec.fileroot}/logs/{getpass.getuser()}/"
f"{experiment_name}/{trial_name}/hydra-outputs/",
)
sys.argv.pop(1) sys.argv.pop(1)
func() func()

View File

@ -20,11 +20,8 @@ from omegaconf import OmegaConf
multiprocessing.set_start_method("spawn", force=True) multiprocessing.set_start_method("spawn", force=True)
from realhf.api.quickstart.entrypoint import ( from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES
QUICKSTART_CONFIG_CLASSES, from realhf.base import gpu_utils, importing, logging
QUICKSTART_EXPR_CACHE_PATH,
)
from realhf.base import cluster, gpu_utils, importing, logging, name_resolve, names
from realhf.version import get_full_version_with_dirty_description from realhf.version import get_full_version_with_dirty_description
logger = logging.getLogger("Main-Workers") logger = logging.getLogger("Main-Workers")
@ -32,6 +29,7 @@ logger = logging.getLogger("Main-Workers")
def _patch_external_impl(exp_name, trial_name): def _patch_external_impl(exp_name, trial_name):
import realhf.api.core.system_api as system_api import realhf.api.core.system_api as system_api
from realhf.base.constants import QUICKSTART_EXPR_CACHE_PATH
if os.path.exists(QUICKSTART_EXPR_CACHE_PATH): if os.path.exists(QUICKSTART_EXPR_CACHE_PATH):
for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH): for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH):
@ -59,6 +57,12 @@ def main_worker(args):
constants.set_experiment_trial_names(args.experiment_name, args.trial_name) constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
_patch_external_impl(args.experiment_name, args.trial_name) _patch_external_impl(args.experiment_name, args.trial_name)
# Initialize cluster info from ENV or CLI args.
import realhf.api.core.system_api as system_api
experiment = system_api.make_experiment(name=args.experiment_name)
constants.init_constants(experiment)
worker_index_start = args.jobstep_id * args.wprocs_per_jobstep + args.wproc_offset worker_index_start = args.jobstep_id * args.wprocs_per_jobstep + args.wproc_offset
worker_index_end = min( worker_index_end = min(
worker_index_start + args.wprocs_per_jobstep, worker_index_start + args.wprocs_per_jobstep,
@ -174,6 +178,10 @@ def main_controller(args):
trial_name=args.trial_name, trial_name=args.trial_name,
) )
experiment = system_api.make_experiment(args.experiment_name) experiment = system_api.make_experiment(args.experiment_name)
# Initialize cluster info from ENV or CLI args.
constants.init_constants(experiment)
controller.start( controller.start(
experiment=experiment, experiment=experiment,
ignore_worker_error=args.ignore_worker_error, ignore_worker_error=args.ignore_worker_error,

View File

@ -1,3 +1,4 @@
# Copyright 2025 Ant Group Inc. # Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei # Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
from .prologue import * # isort: skip

View File

@ -2,61 +2,62 @@
# Copyright 2024 Wei Fu & Zhiyu Mei # Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
import getpass
import json import json
import os import os
import re from typing import TYPE_CHECKING, Dict
import subprocess
from typing import Dict, List, Optional, Union
CLUSTER_SPEC_PATH = os.environ.get("CLUSTER_SPEC_PATH", "") if TYPE_CHECKING:
from realhf.api.cli_args import BaseExperimentConfig
def get_user_tmp():
user = getpass.getuser()
user_tmp = os.path.join("/home", user, ".cache", "realhf")
os.makedirs(user_tmp, exist_ok=True)
return user_tmp
class ClusterSpec: class ClusterSpec:
def __init__(self): def __init__(self):
# Set default values to comfort ray
from realhf.api.cli_args import BaseExperimentConfig
self.load_spec_from_args(BaseExperimentConfig())
self.__loaded = False self.__loaded = False
def load_spec_from_file(self, file_path: str): def load_spec_from_file(self, file_path: str):
try: if not os.path.exists(file_path):
raise FileNotFoundError(f"Cluster spec file not found: {file_path}")
with open(file_path, "r") as f: with open(file_path, "r") as f:
spec: Dict = json.load(f) spec: Dict = json.load(f)
except FileNotFoundError:
if file_path == "":
spec = dict(
cluster_type="local",
cluster_name="local",
fileroot=get_user_tmp(),
)
else:
raise FileNotFoundError(f"Cluster spec file not found: {file_path}")
self.__cluster_type = spec["cluster_type"] self.__cluster_type = spec["cluster_type"]
self.__cluster_name = spec["cluster_name"] self.__cluster_name = spec["cluster_name"]
self.__fileroot = spec["fileroot"] self.__fileroot = spec["fileroot"]
self.__node_type_from_node_name_re = spec.get("node_type_from_node_name", None)
self.__gpu_type_from_node_name_re = spec.get("gpu_type_from_node_name", None)
self.__gpu_type = spec.get("gpu_type", None) self.__gpu_type = spec.get("gpu_type", None)
self.__default_mount = spec.get("default_mount", None) self.__mount = spec.get("default_mount", None)
self.__gpu_image = spec.get("gpu_image", None) self.__gpu_image = spec.get("gpu_image", None)
self.__gpu_infer_image = spec.get("gpu_infer_image", self.__gpu_image) self.__gpu_infer_image = spec.get("gpu_infer_image", self.__gpu_image)
self.__cpu_image = spec.get("cpu_image", None) self.__cpu_image = spec.get("cpu_image", None)
self.__node_name_prefix = spec.get("node_name_prefix", "NODE") self.__node_name_prefix = spec.get("node_name_prefix", "slurmd-")
# self.__n_nodes decides number of digits in slurm hostnames # self.__n_nodes decides number of digits in slurm hostnames
# e.g. if __n_nodes = 32, then the hostnames will be NODE{:02d} # e.g. if __n_nodes = 32, then the hostnames will be slurmd-{:02d}
# if __n_nodes = 128, then the hostnames will be NODE{:03d} # if __n_nodes = 128, then the hostnames will be slurmd-{:03d}
self.__n_nodes = int(spec.get("n_nodes", 32)) self.__n_nodes = int(spec.get("n_nodes", 32))
self.__n_gpus_per_node = int(spec.get("n_gpus_per_node", 8)) self.__n_gpus_per_node = int(spec.get("n_gpus_per_node", 8))
assert isinstance(self.__n_nodes, int) assert isinstance(self.__n_nodes, int)
self.__loaded = True self.__loaded = True
def load_spec_from_args(self, args: "BaseExperimentConfig"):
self.__cluster_type = args.mode
self.__cluster_name = args.cluster.cluster_name
self.__fileroot = args.cluster.fileroot
self.__gpu_type = args.cluster.gpu_type
self.__mount = args.cluster.mount
self.__gpu_image = args.cluster.gpu_image
self.__gpu_infer_image = args.cluster.gpu_infer_image
self.__cpu_image = args.cluster.cpu_image
self.__node_name_prefix = args.cluster.node_name_prefix
self.__n_nodes = args.cluster.n_nodes
self.__n_gpus_per_node = args.cluster.n_gpus_per_node
self.__loaded = True
@property @property
def name(self): def name(self):
assert self.__loaded assert self.__loaded
@ -67,32 +68,6 @@ class ClusterSpec:
assert self.__loaded assert self.__loaded
return self.__gpu_type return self.__gpu_type
def node_type_from_node_name(self, node_name: str) -> str:
"""Mapping nodename to slurm node type, including "g1", "g2", "g8",
"a100"."""
if self.__cluster_type != "slurm":
raise NotImplementedError(
"Only slurm cluster uses node_type_from_node_name."
)
assert self.__loaded
for regex, node_type in self.__node_type_from_node_name_re.items():
if re.match(regex, node_name):
return node_type
raise NotImplementedError(node_name)
def gpu_type_from_node_name(self, node_name: str) -> str:
"""Mapping nodename to slurm GPU type, including "geforce" and
"tesla"."""
if self.__cluster_type != "slurm":
raise NotImplementedError(
"Only slurm cluster uses gpu_type_from_node_name."
)
assert self.__loaded
for regex, gpu_type in self.__gpu_type_from_node_name_re.items():
if re.match(regex, node_name):
return gpu_type
raise NotImplementedError(node_name)
@property @property
def fileroot(self) -> str: def fileroot(self) -> str:
"""Return the root directory of the file system in the cluster. """Return the root directory of the file system in the cluster.
@ -109,11 +84,11 @@ class ClusterSpec:
self.__fileroot = root self.__fileroot = root
@property @property
def default_mount(self) -> str: def mount(self) -> str:
"""Directories that should be mounted to container that runs """Directories that should be mounted to container that runs
workers.""" workers."""
assert self.__loaded assert self.__loaded
return self.__default_mount return self.__mount
@property @property
def gpu_image(self) -> str: def gpu_image(self) -> str:
@ -156,23 +131,15 @@ class ClusterSpec:
return self.__cluster_type return self.__cluster_type
def node_name_is_node_type(
node_name: str, node_type: Optional[Union[List[str], str]] = None
) -> bool:
assert spec is not None
if node_type is None:
return True
if not isinstance(node_type, list):
node_type = [node_type]
nt_condition = []
for nt in node_type:
if nt not in ["g1", "g2", "g8", "a100"]:
raise ValueError(f"Unknown node type {nt}.")
else:
cond = spec.node_type_from_node_name(node_name) == nt
nt_condition.append(cond)
return any(nt_condition)
spec = ClusterSpec() spec = ClusterSpec()
def init_cluster_spec(args: "BaseExperimentConfig"):
global spec
CLUSTER_SPEC_PATH = os.environ.get("CLUSTER_SPEC_PATH", "")
if args.cluster.config_path:
spec.load_spec_from_file(args.cluster.config_path)
elif CLUSTER_SPEC_PATH:
spec.load_spec_from_file(CLUSTER_SPEC_PATH) spec.load_spec_from_file(CLUSTER_SPEC_PATH)
else:
spec.load_spec_from_args(args)
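
As a usage illustration, here is a hypothetical cluster spec file accepted by load_spec_from_file. The keys mirror what the parser above reads; the values are placeholders, not a real deployment.

import json

example_spec = {
    "cluster_type": "slurm",
    "cluster_name": "my-cluster",
    "fileroot": "/mnt/shared/realhf",
    "gpu_type": "tesla",
    "default_mount": "/mnt/shared:/mnt/shared",
    "gpu_image": "registry.example.com/realhf:gpu",
    "gpu_infer_image": "registry.example.com/realhf:gpu-infer",
    "cpu_image": "registry.example.com/realhf:cpu",
    "node_name_prefix": "slurmd-",
    "n_nodes": 32,
    "n_gpus_per_node": 8,
}
with open("cluster_config.json", "w") as f:
    json.dump(example_spec, f, indent=2)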

View File

@ -11,17 +11,18 @@ import os
import pathlib import pathlib
import subprocess import subprocess
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from typing import * from typing import *
import numpy as np import numpy as np
import torch import torch
import realhf.base.logging as logging import realhf.base.logging as logging
from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from realhf.api.cli_args import BaseExperimentConfig
from realhf.api.core.config import ModelName from realhf.api.core.config import ModelName
from realhf.api.core.system_api import ModelShardID from realhf.api.core.system_api import ModelShardID
from realhf.base.topology import ParallelGrid, ProcessTopology from realhf.base.topology import ParallelGrid, ProcessTopology
@ -68,23 +69,51 @@ TORCH_FORCE_CPU = False
# constants in experiment instance scope # constants in experiment instance scope
LOCAL_CACHE_DIR = "/tmp/realhf" LOCAL_CACHE_DIR = "/tmp/realhf"
MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}"
LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}"
RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}"
SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock"
PORT_LOCK_FILE_ROOT = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports"
PYTORCH_KERNEL_CACHE_PATH = ( PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels" f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
) )
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton" TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
DATASET_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets" QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache")
PROFILER_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler" os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
PARAM_REALLOC_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc" os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
SGLANG_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang" os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
# Global constants that should be initialized after cluster initialization.
MODEL_SAVE_ROOT = None
LOG_ROOT = None
RECOVER_ROOT = None
SLURM_LOCK_FILE_NAME = None
PORT_LOCK_FILE_ROOT = None
DATASET_CACHE_PATH = None
PROFILER_CACHE_PATH = None
PARAM_REALLOC_PATH = None
SGLANG_CACHE_PATH = None
TORCH_EXTENSIONS_DIR = None
BASE_ENVIRONS = None
def init_constants(args: "BaseExperimentConfig"):
from realhf.base.cluster import init_cluster_spec
from realhf.base.cluster import spec as cluster_spec
init_cluster_spec(args)
globals_dict = globals() # Get module's global variables
kwargs = dict(
MODEL_SAVE_ROOT=f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}",
LOG_ROOT=f"{cluster_spec.fileroot}/logs/{getpass.getuser()}",
RECOVER_ROOT=f"{cluster_spec.fileroot}/recover/{getpass.getuser()}",
SLURM_LOCK_FILE_NAME=f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock",
PORT_LOCK_FILE_ROOT=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports",
DATASET_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets",
PROFILER_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler",
PARAM_REALLOC_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc",
SGLANG_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang",
TORCH_EXTENSIONS_DIR=( TORCH_EXTENSIONS_DIR=(
f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions" f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
),
) )
QUICKSTART_EXPR_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/"
BASE_ENVIRONS = { BASE_ENVIRONS = {
# "PYTHONPATH": "/realhf", # "PYTHONPATH": "/realhf",
"REAL_IS_REMOTE": "1", "REAL_IS_REMOTE": "1",
@ -94,7 +123,7 @@ BASE_ENVIRONS = {
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH, "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH, "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"TOKENIZERS_PARALLELISM": "true", "TOKENIZERS_PARALLELISM": "true",
"TORCH_EXTENSIONS_DIR": TORCH_EXTENSIONS_DIR, "TORCH_EXTENSIONS_DIR": kwargs["TORCH_EXTENSIONS_DIR"],
# "TORCH_DISTRIBUTED_DEBUG": "DETAIL", # "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
# "NCCL_SOCKET_IFNAME": "ibp71s0", # "NCCL_SOCKET_IFNAME": "ibp71s0",
# "GLOO_SOCKET_IFNAME": "ibp71s0", # "GLOO_SOCKET_IFNAME": "ibp71s0",
@ -125,7 +154,7 @@ BASE_ENVIRONS = {
"LANG": "C", "LANG": "C",
"NCCL_DEBUG": "WARN", "NCCL_DEBUG": "WARN",
} }
kwargs["BASE_ENVIRONS"] = BASE_ENVIRONS
# Set PPU-specific environment variables for stable training. # Set PPU-specific environment variables for stable training.
if cluster_spec.name == "wa180": if cluster_spec.name == "wa180":
logger.warning("Detected PPU. Amending PPU-related environment variables.") logger.warning("Detected PPU. Amending PPU-related environment variables.")
@ -138,7 +167,7 @@ if cluster_spec.name == "wa180":
"NCCL_SOCKET_IFNAME": "bond0", "NCCL_SOCKET_IFNAME": "bond0",
"PCCL_STATE_MONITOR_DISABLE": "1", "PCCL_STATE_MONITOR_DISABLE": "1",
} }
BASE_ENVIRONS.update(PPU_ENVIRONS) kwargs["BASE_ENVIRONS"].update(PPU_ENVIRONS)
elif cluster_spec.name == "na132": elif cluster_spec.name == "na132":
# Specific environment variable for h800 cluster na132 # Specific environment variable for h800 cluster na132
NV_ENVIRONS = { NV_ENVIRONS = {
@ -154,22 +183,26 @@ elif cluster_spec.name == "na132":
"NCCL_SET_THREAD_NAME": "1", "NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH", "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
} }
BASE_ENVIRONS.update(NV_ENVIRONS) kwargs["BASE_ENVIRONS"].update(NV_ENVIRONS)
for key, value in kwargs.items():
if key not in globals_dict:
raise ValueError(f"Invalid constant name: {key}")
if globals_dict[key] is not None and globals_dict[key] != value:
raise RuntimeError(f"Constant '{key}' already initialized!")
globals_dict[key] = value
# make directories if does not exist # make directories if does not exist
os.makedirs(PARAM_REALLOC_PATH, exist_ok=True) os.makedirs(globals_dict["PARAM_REALLOC_PATH"], exist_ok=True)
os.makedirs(MODEL_SAVE_ROOT, exist_ok=True) os.makedirs(globals_dict["MODEL_SAVE_ROOT"], exist_ok=True)
os.makedirs(LOG_ROOT, exist_ok=True) os.makedirs(globals_dict["LOG_ROOT"], exist_ok=True)
os.makedirs(RECOVER_ROOT, exist_ok=True) os.makedirs(globals_dict["RECOVER_ROOT"], exist_ok=True)
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True) os.makedirs(globals_dict["DATASET_CACHE_PATH"], exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True) os.makedirs(globals_dict["PROFILER_CACHE_PATH"], exist_ok=True)
os.makedirs(DATASET_CACHE_PATH, exist_ok=True) os.makedirs(globals_dict["TORCH_EXTENSIONS_DIR"], exist_ok=True)
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True) os.makedirs(globals_dict["PORT_LOCK_FILE_ROOT"], exist_ok=True)
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True) os.makedirs(globals_dict["SGLANG_CACHE_PATH"], exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
os.makedirs(PORT_LOCK_FILE_ROOT, exist_ok=True)
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
# _model_name will be changed in the model_scope context manager # _model_name will be changed in the model_scope context manager
_model_name: "ModelName" = None _model_name: "ModelName" = None

View File

@ -33,6 +33,13 @@ logger = logging.getLogger("benchmark")
IF_MARK = False IF_MARK = False
@dataclasses.dataclass
class RolloutStat:
submitted: int = 0
accepted: int = 0
running: int = 0
def mock_time_mark(name, identifier, t, step): def mock_time_mark(name, identifier, t, step):
if IF_MARK: if IF_MARK:
logger.debug(f"*{name}* #{identifier}# ${t}$ ns step &{step}&") logger.debug(f"*{name}* #{identifier}# ${t}$ ns step &{step}&")

View File

@ -13,13 +13,14 @@ import time
import uuid import uuid
from typing import Callable, List, Optional from typing import Callable, List, Optional
import ray
try: try:
import etcd3 import etcd3
except Exception: except Exception:
etcd3 = None etcd3 = None
from realhf.base import cluster, logging, security, timeutil from realhf.base import logging, security, timeutil
from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger("name-resolve") logger = logging.getLogger("name-resolve")
@ -286,14 +287,19 @@ class MemoryNameRecordRepository(NameRecordRepository):
class NfsNameRecordRepository(NameRecordRepository): class NfsNameRecordRepository(NameRecordRepository):
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/" RECORD_ROOT = ""
os.makedirs(RECORD_ROOT, exist_ok=True)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.__to_delete = set() self.__to_delete = set()
@staticmethod @staticmethod
def __dir_path(name): def __dir_path(name):
if not NfsNameRecordRepository.RECORD_ROOT:
from realhf.base.cluster import spec as cluster_spec
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
os.makedirs(RECORD_ROOT, exist_ok=True)
NfsNameRecordRepository.RECORD_ROOT = RECORD_ROOT
return os.path.join(NfsNameRecordRepository.RECORD_ROOT, name) return os.path.join(NfsNameRecordRepository.RECORD_ROOT, name)
@staticmethod @staticmethod
@ -930,6 +936,446 @@ class Etcd3NameRecordRepository(NameRecordRepository):
logger.debug(f"Testonly: dropped key: {name}") logger.debug(f"Testonly: dropped key: {name}")
@ray.remote
class DistributedKVStore:
"""Ray actor implementing a distributed key-value store with TTL support."""
def __init__(self):
self.store = {}
self.ttl_store = {} # key -> expiry_time
self.lease_store = {} # key -> lease_id
self.lease_counter = 0
def put(self, key: str, value: str, lease_id: Optional[int] = None):
"""Store a key-value pair with optional lease."""
self.store[key] = value
if lease_id is not None:
self.lease_store[key] = lease_id
return True
def get(self, key: str):
"""Get value for a key, checking TTL expiry."""
self._cleanup_expired()
if key not in self.store:
return None
return self.store[key]
def delete(self, key: str):
"""Delete a key and its associated metadata."""
deleted = key in self.store
self.store.pop(key, None)
self.ttl_store.pop(key, None)
self.lease_store.pop(key, None)
return deleted
def get_prefix(self, prefix: str):
"""Get all key-value pairs with keys matching the prefix."""
self._cleanup_expired()
result = []
normalized_prefix = os.path.normpath(prefix)
for key, value in self.store.items():
normalized_key = os.path.normpath(key)
# Check if key matches prefix (exact match or starts with prefix/)
if normalized_key == normalized_prefix or normalized_key.startswith(
normalized_prefix.rstrip("/") + "/"
):
result.append((key, value))
return result
def delete_prefix(self, prefix: str):
"""Delete all keys matching the prefix."""
self._cleanup_expired()
normalized_prefix = os.path.normpath(prefix)
keys_to_delete = []
for key in self.store.keys():
normalized_key = os.path.normpath(key)
if normalized_key == normalized_prefix or normalized_key.startswith(
normalized_prefix.rstrip("/") + "/"
):
keys_to_delete.append(key)
for key in keys_to_delete:
self.delete(key)
return len(keys_to_delete)
def create_lease(self, ttl_seconds: int):
"""Create a lease with TTL."""
self.lease_counter += 1
lease_id = self.lease_counter
expiry_time = time.time() + ttl_seconds
return lease_id, expiry_time
def put_with_lease(self, key: str, value: str, ttl_seconds: int):
"""Store key-value with TTL lease."""
lease_id, expiry_time = self.create_lease(ttl_seconds)
self.store[key] = value
self.ttl_store[key] = expiry_time
self.lease_store[key] = lease_id
return lease_id
def refresh_lease(self, key: str, ttl_seconds: int):
"""Refresh the lease for a key."""
if key in self.store and key in self.lease_store:
self.ttl_store[key] = time.time() + ttl_seconds
return True
return False
def _cleanup_expired(self):
"""Remove expired keys."""
current_time = time.time()
expired_keys = []
for key, expiry_time in self.ttl_store.items():
if current_time > expiry_time:
expired_keys.append(key)
for key in expired_keys:
self.delete(key)
def get_all_keys(self):
"""Get all keys in the store."""
self._cleanup_expired()
return list(self.store.keys())
class RayNameResolveRepository:
"""Ray-based implementation of NameRecordRepository using distributed actors."""
KEEPALIVE_POLL_FREQUENCY = 1
@dataclasses.dataclass
class _Entry:
value: str
lease_id: Optional[int] = None
keepalive_ttl: Optional[int] = None
keeper: Optional[timeutil.FrequencyControl] = None
def __init__(self, actor_name: str = "distributed_kv_store", **kwargs):
"""Initialize Ray-based name record repository.
Args:
actor_name: Name for the Ray actor (for sharing across processes)
**kwargs: Additional configuration parameters
"""
super().__init__()
self._lock = threading.Lock()
self._actor_name = actor_name
# Initialize Ray if not already done
if not ray.is_initialized():
ray.init(ignore_reinit_error=True)
# Try to get existing actor or create new one
try:
self._kv_store = ray.get_actor(self._actor_name)
logger.debug(
f"Connected to existing Ray KV store actor: {self._actor_name}"
)
except ValueError:
# Actor doesn't exist, create it
self._kv_store = DistributedKVStore.options(
name=self._actor_name, lifetime="detached"
).remote()
logger.debug(f"Created new Ray KV store actor: {self._actor_name}")
# Track entries for cleanup and keepalive
self._entries = {}
self._keepalive_running = True
self._keepalive_thread = threading.Thread(
target=self._keepalive_thread_run, daemon=True
)
self._keepalive_thread.start()
self._to_delete = set()
def __del__(self):
"""Clean up resources when the object is deleted."""
try:
self.reset()
except Exception as e:
logger.info(
f"Exception ignored when deleting RayNameResolveRepository: {e}"
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.reset()
def add(
self,
name: str,
value: str,
delete_on_exit: bool = True,
keepalive_ttl: Optional[int] = None,
replace: bool = False,
):
"""Add a key-value pair to the distributed store.
Args:
name: Key name
value: Value to store
delete_on_exit: Whether to delete the key when this object is destroyed
keepalive_ttl: TTL in seconds for the key
replace: Whether to replace an existing key
Raises:
NameEntryExistsError: If the key already exists and replace is False
"""
if not name:
raise ValueError(f"Invalid name: {name}")
name = os.path.normpath(name)
value = str(value)
with self._lock:
# Check if key exists when replace=False
if not replace:
existing_value = ray.get(self._kv_store.get.remote(name))
if existing_value is not None:
raise NameEntryExistsError(
f"Key already exists: K={name} V={existing_value}"
)
# Store with or without TTL
lease_id = None
if keepalive_ttl is not None and keepalive_ttl > 0:
lease_id = ray.get(
self._kv_store.put_with_lease.remote(name, value, keepalive_ttl)
)
self._to_delete.add(name)
else:
ray.get(self._kv_store.put.remote(name, value))
if delete_on_exit:
self._to_delete.add(name)
# Store entry information for keepalive management
self._entries[name] = self._Entry(
value=value,
lease_id=lease_id,
keepalive_ttl=keepalive_ttl,
keeper=(
timeutil.FrequencyControl(frequency_seconds=keepalive_ttl / 3)
if keepalive_ttl
else None
),
)
def add_subentry(self, name: str, value: str, **kwargs):
"""Add a sub-entry to the key-root `name`."""
sub_name = os.path.join(os.path.normpath(name), str(uuid.uuid4())[:8])
self.add(sub_name, value, **kwargs)
return sub_name
def delete(self, name: str):
"""Delete a key from the distributed store.
Args:
name: Key to delete
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
with self._lock:
self._delete_locked(name)
if name in self._to_delete:
self._to_delete.remove(name)
def _delete_locked(self, name: str):
"""Delete a key with lock already acquired."""
# Check if key exists
existing_value = ray.get(self._kv_store.get.remote(name))
if existing_value is None:
raise NameEntryNotFoundError(f"No such entry to delete: {name}")
# Clean up entry tracking
if name in self._entries:
del self._entries[name]
# Delete from store
ray.get(self._kv_store.delete.remote(name))
def clear_subtree(self, name_root: str):
"""Delete all keys with the given prefix."""
with self._lock:
name_root = os.path.normpath(name_root)
count = ray.get(self._kv_store.delete_prefix.remote(name_root))
# Clean up local tracking for deleted keys
keys_to_remove = []
for key in self._entries.keys():
normalized_key = os.path.normpath(key)
if normalized_key == name_root or normalized_key.startswith(
name_root.rstrip("/") + "/"
):
keys_to_remove.append(key)
for key in keys_to_remove:
del self._entries[key]
logger.debug(f"Deleted {count} entries under {name_root}")
def get(self, name: str):
"""Get the value for a key.
Args:
name: Key to retrieve
Returns:
The value as a string
Raises:
NameEntryNotFoundError: If the key doesn't exist
"""
name = os.path.normpath(name)
with self._lock:
return self._get_locked(name)
def _get_locked(self, name: str):
"""Get a value with lock already acquired."""
value = ray.get(self._kv_store.get.remote(name))
if value is None:
raise NameEntryNotFoundError(f"No such entry: {name}")
return value
def get_subtree(self, name_root: str):
"""Get all values with keys having the given prefix."""
with self._lock:
name_root = os.path.normpath(name_root)
pairs = ray.get(self._kv_store.get_prefix.remote(name_root))
values = [value for key, value in pairs]
return sorted(values)
def find_subtree(self, name_root: str):
"""Find all keys with the given prefix."""
with self._lock:
name_root = os.path.normpath(name_root)
pairs = ray.get(self._kv_store.get_prefix.remote(name_root))
keys = [key for key, value in pairs]
return sorted(keys)
def wait(
self, name: str, timeout: Optional[float] = None, poll_frequency: float = 1
):
"""Wait until a name appears.
Raises:
TimeoutError: if timeout exceeds.
"""
start = time.monotonic()
while True:
try:
return self.get(name)
except NameEntryNotFoundError:
pass
if timeout is None or timeout > 0:
time.sleep(
poll_frequency + random.random() * 0.1
) # To reduce concurrency.
if timeout is not None and time.monotonic() - start > timeout:
raise TimeoutError(
f"Timeout waiting for key '{name}' ({self.__class__.__name__})"
)
def reset(self):
"""Delete all keys added via this repository instance."""
self._keepalive_running = False
if hasattr(self, "_keepalive_thread"):
self._keepalive_thread.join(timeout=5)
with self._lock:
count = 0
for name in list(self._to_delete):
try:
self._delete_locked(name)
count += 1
except NameEntryNotFoundError:
pass
self._to_delete = set()
self._entries = {}
logger.debug(f"Reset {count} saved entries")
def watch_names(
self,
names: List[str],
call_back: Callable,
poll_frequency: float = 15,
wait_timeout: float = 300,
):
"""Watch keys and call back when they are deleted.
Args:
names: Keys to watch
call_back: Function to call when any key is deleted
poll_frequency: How often to check in seconds
wait_timeout: Maximum time to wait for keys to exist
"""
if isinstance(names, str):
names = [names]
q = queue.Queue(maxsize=len(names))
for _ in range(len(names) - 1):
q.put(0)
def wrap_call_back():
try:
q.get_nowait()
except queue.Empty:
logger.info(f"Key {names} is gone. Executing callback {call_back}")
call_back()
for name in names:
t = threading.Thread(
target=self._watch_thread_run,
args=(name, wrap_call_back, poll_frequency, wait_timeout),
daemon=True,
)
t.start()
def _watch_thread_run(
self, name: str, call_back: Callable, poll_frequency: float, wait_timeout: float
):
"""Background thread to watch a key for deletion."""
self.wait(name, timeout=wait_timeout, poll_frequency=poll_frequency)
while True:
try:
self.get(name)
time.sleep(poll_frequency + random.random())
except NameEntryNotFoundError:
call_back()
break
def _keepalive_thread_run(self):
"""Background thread to keep leases alive."""
while self._keepalive_running:
time.sleep(self.KEEPALIVE_POLL_FREQUENCY)
with self._lock:
for name, entry in list(self._entries.items()):
if (
entry.keeper is not None
and entry.keepalive_ttl is not None
and entry.lease_id is not None
and entry.keeper.check()
):
try:
# Refresh the lease
success = ray.get(
self._kv_store.refresh_lease.remote(
name, entry.keepalive_ttl
)
)
if not success:
logger.warning(
f"Failed to refresh lease for key: {name}"
)
except Exception as e:
logger.error(
f"Failed to refresh lease for key: K={name} V={entry.value}. Error: {e}"
)
def make_repository(type_="nfs", **kwargs): def make_repository(type_="nfs", **kwargs):
if type_ == "memory": if type_ == "memory":
return MemoryNameRecordRepository(**kwargs) return MemoryNameRecordRepository(**kwargs)
@ -939,6 +1385,8 @@ def make_repository(type_="nfs", **kwargs):
return RedisNameRecordRepository(**kwargs) return RedisNameRecordRepository(**kwargs)
elif type_ == "etcd3": elif type_ == "etcd3":
return Etcd3NameRecordRepository(**kwargs) return Etcd3NameRecordRepository(**kwargs)
elif type_ == "ray":
return RayNameResolveRepository(**kwargs)
else: else:
raise NotImplementedError(f"No such name resolver: {type_}") raise NotImplementedError(f"No such name resolver: {type_}")

View File

@ -99,3 +99,11 @@ def used_ports(experiment_name, trial_name, host_name):
def gen_server_manager(experiment_name, trial_name): def gen_server_manager(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager" return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager"
def training_samples(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/training_samples"
def experiment_status(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/experiment_status"

View File

@ -28,11 +28,6 @@ def find_free_port(
"""Find a free port within the specified range, excluding certain ports.""" """Find a free port within the specified range, excluding certain ports."""
ports_name = names.used_ports(experiment_name, trial_name, gethostip()) ports_name = names.used_ports(experiment_name, trial_name, gethostip())
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
if exclude_ports is None:
exclude_ports = set(used_ports)
else:
exclude_ports = exclude_ports.union(set(used_ports))
free_port = None free_port = None
lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip()) lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip())
@ -40,10 +35,16 @@ def find_free_port(
with open(lockfile, "w") as fd: with open(lockfile, "w") as fd:
# This will block until lock is acquired # This will block until lock is acquired
fcntl.flock(fd, fcntl.LOCK_EX) fcntl.flock(fd, fcntl.LOCK_EX)
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
assert len(used_ports) == len(set(used_ports))
if exclude_ports is None:
exclude_ports = set(used_ports)
else:
exclude_ports = exclude_ports.union(set(used_ports))
try: try:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("", 0))
port = s.getsockname()[1] port = s.getsockname()[1]
if low <= port <= high and port not in exclude_ports: if low <= port <= high and port not in exclude_ports:
name_resolve.add_subentry(ports_name, str(port)) name_resolve.add_subentry(ports_name, str(port))
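
The design point of this change, shown in isolation (a generic sketch, not repo code): the shared used-port registry is read only after the exclusive lock is held, so two processes on the same host cannot race to the same port.

import fcntl

def with_host_lock(lockfile: str, allocate):
    with open(lockfile, "w") as fd:
        fcntl.flock(fd, fcntl.LOCK_EX)   # blocks until the host-wide lock is acquired
        try:
            return allocate()            # read the registry and pick a port inside the lock
        finally:
            fcntl.flock(fd, fcntl.LOCK_UN)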

View File

@ -8,7 +8,7 @@ import json
import os import os
import sys import sys
from omegaconf import DictConfig, OmegaConf from omegaconf import OmegaConf
PROLOGUE_FLAG_NAME = "--config" PROLOGUE_FLAG_NAME = "--config"
PROLOGUE_FLAG_VAR_NAME = "config" PROLOGUE_FLAG_VAR_NAME = "config"

View File

@ -60,6 +60,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
"version_start", "version_start",
"version_end", "version_end",
"rewards", "rewards",
"birth_time",
) )
rpcs["actor_train"].input_keys = ( rpcs["actor_train"].input_keys = (
*rpcs["actor_train"].input_keys, *rpcs["actor_train"].input_keys,

View File

@ -104,7 +104,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
scheduling=Scheduling.model_worker_default( scheduling=Scheduling.model_worker_default(
cpu=self.cpus_per_model_worker, cpu=self.cpus_per_model_worker,
gpu=1, gpu=1,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_model_worker, mem=self.mem_per_model_worker,
nodelist=self.nodelist, nodelist=self.nodelist,
exclude=self.exclude, exclude=self.exclude,
@ -115,7 +114,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
scheduling=Scheduling.generation_server_default( scheduling=Scheduling.generation_server_default(
cpu=self.cpus_per_generation_server, cpu=self.cpus_per_generation_server,
gpu=gen_tp_size, gpu=gen_tp_size,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_generation_server, mem=self.mem_per_generation_server,
nodelist=self.nodelist, nodelist=self.nodelist,
exclude=self.exclude, exclude=self.exclude,
@ -125,7 +123,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
count=1, count=1,
scheduling=Scheduling.gserver_manager_default( scheduling=Scheduling.gserver_manager_default(
cpu=self.cpus_per_gserver_manager, cpu=self.cpus_per_gserver_manager,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_gserver_manager, mem=self.mem_per_gserver_manager,
nodelist=self.nodelist, nodelist=self.nodelist,
exclude=self.exclude, exclude=self.exclude,
@ -135,7 +132,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
count=self.n_rollout_workers or train_world_size, count=self.n_rollout_workers or train_world_size,
scheduling=Scheduling.rollout_worker_default( scheduling=Scheduling.rollout_worker_default(
cpu=self.cpus_per_rollout_worker, cpu=self.cpus_per_rollout_worker,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_rollout_worker, mem=self.mem_per_rollout_worker,
nodelist=self.nodelist, nodelist=self.nodelist,
exclude=self.exclude, exclude=self.exclude,

View File

@ -176,7 +176,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
scheduling=Scheduling.model_worker_default( scheduling=Scheduling.model_worker_default(
cpu=self.cpus_per_model_worker, cpu=self.cpus_per_model_worker,
gpu=1, gpu=1,
gpu_type=cluster_spec.gpu_type,
mem=self.mem_per_model_worker, mem=self.mem_per_model_worker,
nodelist=self.nodelist, nodelist=self.nodelist,
exclude=self.exclude, exclude=self.exclude,
@ -573,6 +572,18 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
) )
def _check_legal_allocation_options(self): def _check_legal_allocation_options(self):
if self.n_nodes > self.cluster.n_nodes:
raise ValueError(
f"Number of used nodes {self.n_nodes} should not be larger than the cluster size {self.cluster.n_nodes}"
)
if self.n_gpus_per_node > self.cluster.n_gpus_per_node:
raise ValueError(
f"Number of 7used GPUs per node {self.n_gpus_per_node} should not be larger than the cluster limit {self.cluster.n_gpus_per_node}"
)
if self.n_nodes > 1 and self.n_gpus_per_node != self.cluster.n_gpus_per_node:
raise ValueError(
f"For distributed experiments, only using all GPUs on each node is allowed."
)
if self.n_nodes > 1 and self.mode == "local": if self.n_nodes > 1 and self.mode == "local":
raise ValueError( raise ValueError(
"Cannot run multi-node experiment in local mode, " "Cannot run multi-node experiment in local mode, "

View File

@ -4,6 +4,7 @@
import asyncio import asyncio
import json import json
import os import os
from datetime import datetime
from typing import List from typing import List
import colorama import colorama
@ -51,6 +52,7 @@ class MathSingleStepAgent(Agent):
assert prompt.bs == 1 assert prompt.bs == 1
prompt_token_ids = prompt.data["packed_prompts"].cpu().numpy().tolist() prompt_token_ids = prompt.data["packed_prompts"].cpu().numpy().tolist()
qid = prompt.ids[0] qid = prompt.ids[0]
birth_time = int(datetime.now().timestamp() * 1000)
await obs_queue.put((qid, prompt_token_ids, self.gconfig)) await obs_queue.put((qid, prompt_token_ids, self.gconfig))
act: BundledGenerationOutputs = await act_queue.get() act: BundledGenerationOutputs = await act_queue.get()
@ -107,7 +109,7 @@ class MathSingleStepAgent(Agent):
"version_start", "version_start",
"version_end", "version_end",
"rewards", "rewards",
"task_ids", "birth_time",
], ],
ids=[qid], ids=[qid],
dtypes=dict( dtypes=dict(
@ -119,7 +121,7 @@ class MathSingleStepAgent(Agent):
version_end=torch.int, version_end=torch.int,
packed_logprobs=torch.float32, packed_logprobs=torch.float32,
rewards=torch.float32, rewards=torch.float32,
task_ids=torch.long, birth_time=torch.long,
), ),
trailing_shapes=dict( trailing_shapes=dict(
packed_input_ids=(), packed_input_ids=(),
@ -130,7 +132,7 @@ class MathSingleStepAgent(Agent):
version_start=(), version_start=(),
packed_logprobs=(), packed_logprobs=(),
rewards=(), rewards=(),
task_ids=(), birth_time=(),
), ),
seqlens=dict( seqlens=dict(
packed_input_ids=[act.seqlens], packed_input_ids=[act.seqlens],
@ -141,7 +143,7 @@ class MathSingleStepAgent(Agent):
rewards=[[1 for _ in range(self.gconfig.n)]], rewards=[[1 for _ in range(self.gconfig.n)]],
version_start=[[1 for _ in range(self.gconfig.n)]], version_start=[[1 for _ in range(self.gconfig.n)]],
version_end=[[1 for _ in range(self.gconfig.n)]], version_end=[[1 for _ in range(self.gconfig.n)]],
task_ids=[[1]], birth_time=[[1]],
), ),
data=dict( data=dict(
packed_prompts=torch.tensor(act.prompt_ids, dtype=torch.long), packed_prompts=torch.tensor(act.prompt_ids, dtype=torch.long),
@ -153,6 +155,7 @@ class MathSingleStepAgent(Agent):
rewards=torch.tensor(rewards, dtype=torch.float32), rewards=torch.tensor(rewards, dtype=torch.float32),
version_start=torch.tensor(act.version_start, dtype=torch.int), version_start=torch.tensor(act.version_start, dtype=torch.int),
version_end=torch.tensor(act.version_end, dtype=torch.int), version_end=torch.tensor(act.version_end, dtype=torch.int),
birth_time=torch.tensor([birth_time], dtype=torch.long),
prompt_mask=torch.tensor( prompt_mask=torch.tensor(
sum( sum(
[ [
@ -163,9 +166,18 @@ class MathSingleStepAgent(Agent):
), ),
dtype=torch.bool, dtype=torch.bool,
), ),
task_ids=prompt.data["task_ids"],
), ),
) )
if "task_ids" in prompt.keys:
y = SequenceSample(
keys=["task_ids"],
ids=[qid],
dtypes=dict(task_ids=torch.long),
trailing_shapes=dict(task_ids=()),
seqlens=dict(task_ids=[[1]]),
data=dict(task_ids=prompt.data["task_ids"]),
)
x.update_(y)
return [x] return [x]

View File

@ -9,7 +9,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import transformers import transformers
from realhf.base import cluster, constants, logging from realhf.base import constants, logging
from realhf.impl.model.utils.padding import pad_input, unpad_input from realhf.impl.model.utils.padding import pad_input, unpad_input
logger = logging.getLogger("Modeling Functional Utils") logger = logging.getLogger("Modeling Functional Utils")
@ -166,7 +166,7 @@ def build_leave_one_indices(
) )
def gather_logprobs( def _gather_logprobs(
logits: torch.Tensor, logits: torch.Tensor,
labels: torch.Tensor, labels: torch.Tensor,
): ):
@ -186,8 +186,22 @@ def gather_logprobs(
return log_probs_labels return log_probs_labels
if cluster.spec.name != "wa180": _gather_logprobs_compiled = None
gather_logprobs = torch.compile(gather_logprobs)
def gather_logprobs(
logits: torch.Tensor,
labels: torch.Tensor,
):
from realhf.base import cluster
if cluster.spec.name == "wa180":
# torch.compile doesn't work on PPU
return _gather_logprobs(logits, labels)
global _gather_logprobs_compiled
if _gather_logprobs_compiled is None:
_gather_logprobs_compiled = torch.compile(_gather_logprobs)
return _gather_logprobs_compiled(logits, labels)
def gather_packed_shifted_log_probs( def gather_packed_shifted_log_probs(
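
The same lazy-compilation idea in isolation (illustrative only, not repo code): defer torch.compile to the first call so importing the module never triggers compilation, e.g. on backends where it is unsupported.

import torch

_square_compiled = None

def square(x: torch.Tensor) -> torch.Tensor:
    global _square_compiled
    if _square_compiled is None:
        _square_compiled = torch.compile(lambda t: t * t)  # compiled once, on first use
    return _square_compiled(x)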

View File

@ -111,13 +111,11 @@ class SlurmSchedulerClient(SchedulerClient):
cmd: str, # XXX: should be None for workers cmd: str, # XXX: should be None for workers
count: int, count: int,
cpu: int = 1, cpu: int = 1,
gpu_type: str = "geforce",
gpu: int = 0, gpu: int = 0,
mem: int = 1024, # MB mem: int = 1024, # MB
env_vars: Optional[Dict] = None, env_vars: Optional[Dict] = None,
container_image: str = cluster_spec.gpu_image, container_image: Optional[str] = None,
container_mounts: str = cluster_spec.default_mount, container_mounts: Optional[str] = None,
node_type: Optional[str] = None,
nodelist: Optional[str] = None, nodelist: Optional[str] = None,
exclude: Optional[str] = None, exclude: Optional[str] = None,
hostfile: bool = True, hostfile: bool = True,
@ -126,14 +124,14 @@ class SlurmSchedulerClient(SchedulerClient):
deadline: str = None, deadline: str = None,
time_limit: str = None, time_limit: str = None,
): ):
container_image = container_image or cluster_spec.cpu_image
container_mounts = container_mounts or cluster_spec.mount
# record launch information, do not submit to slurm until `wait()` is called # record launch information, do not submit to slurm until `wait()` is called
# NOTE: fractional GPU requirement will be resolved automatically in `__post_init__` of SlurmLaunchInfo # NOTE: fractional GPU requirement will be resolved automatically in `__post_init__` of SlurmLaunchInfo
launch_info = SlurmLaunchInfo( launch_info = SlurmLaunchInfo(
worker_type=worker_type, worker_type=worker_type,
wprocs_in_job=count, wprocs_in_job=count,
resource_requirement=SlurmResource( resource_requirement=SlurmResource(mem=mem, cpu=cpu, gpu=gpu),
mem=mem, cpu=cpu, gpu=gpu, gpu_type=gpu_type
),
cmd=cmd, cmd=cmd,
run_name=self.run_name, run_name=self.run_name,
exper_name=self.expr_name, exper_name=self.expr_name,
@ -141,7 +139,6 @@ class SlurmSchedulerClient(SchedulerClient):
container_image=container_image, container_image=container_image,
container_mounts=container_mounts, container_mounts=container_mounts,
env_vars=env_vars, env_vars=env_vars,
node_type=node_type,
nodelist=nodelist, nodelist=nodelist,
exclude=exclude, exclude=exclude,
hostfile=hostfile, hostfile=hostfile,

View File

@ -67,21 +67,8 @@ class SlurmResource:
# a data class that represents a slurm resource quota # a data class that represents a slurm resource quota
mem: int = 0 mem: int = 0
cpu: int = 0 cpu: int = 0
gpu_type: Optional[Literal["tesla", "geforce", "ppu"]] = None
gpu: Union[float, int] = 0 gpu: Union[float, int] = 0
def __check_gpu_type(self, other: SlurmResource) -> str:
self_gpu_type = None if self.gpu == 0 else self.gpu_type
other_gpu_type = None if other.gpu == 0 else other.gpu_type
valid_gpu_type = self_gpu_type == other_gpu_type or (
self_gpu_type or other_gpu_type
)
if not valid_gpu_type:
raise InvalidGPUTypeException(
f"Cannot add two different gpu types {self_gpu_type}, {other_gpu_type}."
)
return self_gpu_type if self_gpu_type else other_gpu_type
def __str__(self): def __str__(self):
return ( return (
"SlurmResource: \n" "SlurmResource: \n"
@ -93,9 +80,6 @@ class SlurmResource:
+ " \n" + " \n"
+ "gpu: " + "gpu: "
+ str(self.gpu) + str(self.gpu)
+ " \n"
+ "gpu_type: "
+ str(self.gpu_type)
) )
def __mul__(self, other: int) -> SlurmResource: def __mul__(self, other: int) -> SlurmResource:
@ -106,7 +90,6 @@ class SlurmResource:
mem=self.mem * other, mem=self.mem * other,
cpu=self.cpu * other, cpu=self.cpu * other,
gpu=self.gpu * other, gpu=self.gpu * other,
gpu_type=self.gpu_type,
) )
def __rmul__(self, other: int) -> SlurmResource: def __rmul__(self, other: int) -> SlurmResource:
@ -120,7 +103,6 @@ class SlurmResource:
mem=self.mem + other.mem, mem=self.mem + other.mem,
cpu=self.cpu + other.cpu, cpu=self.cpu + other.cpu,
gpu=self.gpu + other.gpu, gpu=self.gpu + other.gpu,
gpu_type=self.__check_gpu_type(other),
) )
def __sub__(self, other: SlurmResource) -> SlurmResource: def __sub__(self, other: SlurmResource) -> SlurmResource:
@ -131,28 +113,19 @@ class SlurmResource:
mem=self.mem - other.mem, mem=self.mem - other.mem,
cpu=self.cpu - other.cpu, cpu=self.cpu - other.cpu,
gpu=self.gpu - other.gpu, gpu=self.gpu - other.gpu,
gpu_type=self.__check_gpu_type(other),
) )
def __neg__(self) -> SlurmResource: def __neg__(self) -> SlurmResource:
return SlurmResource( return SlurmResource(
mem=-self.mem, cpu=-self.cpu, gpu=-self.gpu, gpu_type=self.gpu_type mem=-self.mem,
cpu=-self.cpu,
gpu=-self.gpu,
) )
def __eq__(self, other: SlurmResource) -> bool: def __eq__(self, other: SlurmResource) -> bool:
return ( return self.mem == other.mem and self.cpu == other.cpu and self.gpu == other.gpu
self.mem == other.mem
and self.cpu == other.cpu
and self.gpu == other.gpu
and self.gpu_type == other.gpu_type
)
def __lt__(self, other: SlurmResource) -> bool: def __lt__(self, other: SlurmResource) -> bool:
if self.gpu_type != other.gpu_type:
if self.gpu_type is None:
return True
if self.gpu_type == "geforce":
return self.gpu_type < other.gpu_type
if self.gpu != other.gpu: if self.gpu != other.gpu:
return self.gpu < other.gpu return self.gpu < other.gpu
if self.cpu != other.cpu: if self.cpu != other.cpu:
@ -162,8 +135,6 @@ class SlurmResource:
def valid(self) -> bool: def valid(self) -> bool:
# check if it is a valid resource requirement # check if it is a valid resource requirement
if self.gpu_type not in ["geforce", "tesla", "ppu", None]:
return False
if self.mem < 0 or self.cpu < 0 or self.gpu < 0: if self.mem < 0 or self.cpu < 0 or self.gpu < 0:
return False return False
return True return True
@ -207,7 +178,6 @@ class SlurmLaunchInfo:
this string should be of format 'docker://<image>'. this string should be of format 'docker://<image>'.
container_mounts (str): . container_mounts (str): .
env_vars (dict): . env_vars (dict): .
node_type (str): .
nodelist (str): . nodelist (str): .
exclude (str): . exclude (str): .
partition (str, optional): default to "all". partition (str, optional): default to "all".
@ -234,7 +204,6 @@ class SlurmLaunchInfo:
container_image: str container_image: str
container_mounts: str container_mounts: str
env_vars: dict env_vars: dict
node_type: str
nodelist: str nodelist: str
exclude: str exclude: str
partition: Optional[str] = "all" partition: Optional[str] = "all"
@ -377,7 +346,6 @@ class SlurmLaunchInfo:
cmd = self.cmd cmd = self.cmd
# assert gpu == 1 or gpu == 0, "Slurm job GPU requirement should be resolved to a integer." # assert gpu == 1 or gpu == 0, "Slurm job GPU requirement should be resolved to a integer."
gpu_type = self.resource_requirement.gpu_type
if self.multiprog: if self.multiprog:
with open(self.multiprog_path, "w") as f: with open(self.multiprog_path, "w") as f:
@ -400,8 +368,8 @@ class SlurmLaunchInfo:
# In current slurm cluster setup, we can only use "--gres" to # In current slurm cluster setup, we can only use "--gres" to
# allocate PPUs per node. There are no options to allocate customized # allocate PPUs per node. There are no options to allocate customized
# gres per tasks. # gres per tasks.
if gpu_type == "ppu": if cluster.spec.gpu_type == "ppu":
gres_line = f"--gres={gpu_type}:{cluster.spec.n_gpus_per_node}" gres_line = f"--gres=ppu:{cluster.spec.n_gpus_per_node}"
else: else:
gres_line = f"--gres=gpu:{cluster.spec.n_gpus_per_node}" gres_line = f"--gres=gpu:{cluster.spec.n_gpus_per_node}"
@ -596,12 +564,9 @@ def _parse_output_tres_line(tres):
res.cpu = int(t.split("=")[1]) res.cpu = int(t.split("=")[1])
elif t.startswith("gres/gpu"): elif t.startswith("gres/gpu"):
prefix, sgpu = t.split("=") prefix, sgpu = t.split("=")
if ":" in prefix:
res.gpu_type = prefix.split(":")[1]
res.gpu = int(sgpu) res.gpu = int(sgpu)
elif t.startswith("gres/ppu"): elif t.startswith("gres/ppu"):
prefix, sgpu = t.split("=") prefix, sgpu = t.split("=")
res.gpu_type = "ppu"
res.gpu = int(sgpu) res.gpu = int(sgpu)
elif t.startswith("billing"): elif t.startswith("billing"):
# slurm default resource to limit number of # slurm default resource to limit number of
@ -613,7 +578,6 @@ def _parse_output_tres_line(tres):
def available_hostnames( def available_hostnames(
node_type: Optional[List[str]] = None,
nodelist: Optional[str] = None, nodelist: Optional[str] = None,
exclude: Optional[str] = None, exclude: Optional[str] = None,
partition: Optional[str] = None, partition: Optional[str] = None,
@ -684,12 +648,7 @@ def available_hostnames(
for hn in invalid_hostnames: for hn in invalid_hostnames:
valid_hostnames.remove(hn) valid_hostnames.remove(hn)
return list( return valid_hostnames
filter(
lambda x: cluster.node_name_is_node_type(x, node_type),
valid_hostnames,
)
)
def get_all_node_resources() -> Dict[str, SlurmResource]: def get_all_node_resources() -> Dict[str, SlurmResource]:
@ -721,15 +680,11 @@ def get_all_node_resources() -> Dict[str, SlurmResource]:
ctres = _parse_output_tres_line(l) ctres = _parse_output_tres_line(l)
if l.startswith("AllocTRES"): if l.startswith("AllocTRES"):
atres = _parse_output_tres_line(l) atres = _parse_output_tres_line(l)
if ctres.gpu_type is None:
ctres.gpu_type = cluster.spec.gpu_type_from_node_name(node_name)
if atres.gpu_type is None:
atres.gpu_type = ctres.gpu_type
rres = ctres - atres rres = ctres - atres
if rres.valid(): if rres.valid():
all_rres[node_name] = rres all_rres[node_name] = rres
else: else:
all_rres[node_name] = SlurmResource(gpu_type=ctres.gpu_type) all_rres[node_name] = SlurmResource()
return all_rres return all_rres
@ -769,7 +724,6 @@ def allocate_resources(
prioritized_hosts = set() prioritized_hosts = set()
for info_idx, info in enumerate(infos): for info_idx, info in enumerate(infos):
valid_hostnames = available_hostnames( valid_hostnames = available_hostnames(
node_type=info.node_type,
nodelist=info.nodelist, nodelist=info.nodelist,
exclude=info.exclude, exclude=info.exclude,
partition=info.partition, partition=info.partition,
@ -833,10 +787,7 @@ def allocate_resources(
allocated[hostname] = tmp - task_left allocated[hostname] = tmp - task_left
all_resources[hostname] = resource all_resources[hostname] = resource
if task_left > 0: if task_left > 0:
if ( if cluster.spec.gpu_type == "ppu" and info.resource_requirement.gpu > 0:
info.resource_requirement.gpu_type == "ppu"
and info.resource_requirement.gpu > 0
):
logger.warning( logger.warning(
"For PPU resources, we can only allocate tasks in the " "For PPU resources, we can only allocate tasks in the "
f"granularity of nodes ({cluster.spec.n_gpus_per_node} PPUs)" f"granularity of nodes ({cluster.spec.n_gpus_per_node} PPUs)"
@ -845,7 +796,7 @@ def allocate_resources(
f'Unable to allocate {info.n_jobsteps} Jobs with name "{info.slurm_name}". ' f'Unable to allocate {info.n_jobsteps} Jobs with name "{info.slurm_name}". '
f"Resource Requirement of this job is: {dataclasses.asdict(info.resource_requirement)}. " f"Resource Requirement of this job is: {dataclasses.asdict(info.resource_requirement)}. "
f"Valid resources for this job is " f"Valid resources for this job is "
f"(according to NodeType={info.node_type}, NodeList={info.nodelist}, " f"(according to NodeList={info.nodelist}, "
f"and Exclude={info.exclude}):\n {resource_to_string({k: v for k, v in get_all_node_resources().items() if k in valid_hostnames})}" f"and Exclude={info.exclude}):\n {resource_to_string({k: v for k, v in get_all_node_resources().items() if k in valid_hostnames})}"
) )
for pinfo in infos[:info_idx]: for pinfo in infos[:info_idx]:
@ -878,7 +829,7 @@ def allocate_resources(
def show_tesla(): def show_tesla():
all_rres = get_all_node_resources() all_rres = get_all_node_resources()
hostname = socket.gethostname() hostname = socket.gethostname()
for k in available_hostnames(node_type=["a100"]): for k in available_hostnames():
print(k, all_rres[k]) print(k, all_rres[k])
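With `gpu_type` removed, a node's resources reduce to three counters that only need arithmetic and a validity check. A minimal stand-in sketch (field names follow the diff; defaults and units are assumptions):

import dataclasses

@dataclasses.dataclass
class MiniSlurmResource:
    mem: int = 0  # megabytes
    cpu: int = 0
    gpu: int = 0

    def __sub__(self, other: "MiniSlurmResource") -> "MiniSlurmResource":
        return MiniSlurmResource(
            mem=self.mem - other.mem,
            cpu=self.cpu - other.cpu,
            gpu=self.gpu - other.gpu,
        )

    def valid(self) -> bool:
        # A resource record is usable only if every counter is non-negative.
        return self.mem >= 0 and self.cpu >= 0 and self.gpu >= 0

# Remaining capacity = configured resources minus allocated resources.
free = MiniSlurmResource(mem=10_000, cpu=64, gpu=8) - MiniSlurmResource(mem=2_000, cpu=8, gpu=1)
assert free.valid() and free.gpu == 7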


@ -61,7 +61,7 @@ def run_worker(
) )
worker = worker_class(server=server) worker = worker_class(server=server)
try: try:
if worker_type in ["rollout_worker"]: if worker_type in ["rollout_worker", "master_worker", "gserver_manager"]:
asyncio.run(worker.run_async()) asyncio.run(worker.run_async())
else: else:
worker.run() worker.run()
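The dispatch above now runs the rollout worker, master worker, and gserver manager inside an asyncio event loop, while other workers keep their blocking poll loop. A minimal sketch of the same dispatch pattern with hypothetical stand-in worker classes:

import asyncio

class SyncWorker:
    def run(self):
        print("running synchronously")

class AsyncWorker:
    async def run_async(self):
        await asyncio.sleep(0)
        print("running in an event loop")

ASYNC_WORKER_TYPES = {"rollout_worker", "master_worker", "gserver_manager"}

def run_worker_sketch(worker_type: str, worker):
    # Async workers own an event loop for the lifetime of the process;
    # synchronous workers keep the original blocking poll loop.
    if worker_type in ASYNC_WORKER_TYPES:
        asyncio.run(worker.run_async())
    else:
        worker.run()

run_worker_sketch("master_worker", AsyncWorker())
run_worker_sketch("model_worker", SyncWorker())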


@ -244,7 +244,9 @@ class AsyncIOSequenceBuffer:
) )
return indices return indices
async def put_batch(self, samples: List[SequenceSample]): async def put_batch(
self, samples: List[SequenceSample], birth_times: List[int] | None = None
):
n = len(samples) n = len(samples)
if n == 0: if n == 0:
@ -269,9 +271,12 @@ class AsyncIOSequenceBuffer:
# Set a slight difference in birth time to let the order # Set a slight difference in birth time to let the order
# be deterministic. # be deterministic.
if birth_times is None:
self._birth_time[indices] = time.monotonic_ns() + np.arange( self._birth_time[indices] = time.monotonic_ns() + np.arange(
len(indices), dtype=np.int64 len(indices), dtype=np.int64
) )
else:
self._birth_time[indices] = birth_times
async with self._lock: async with self._lock:
self.__buffer._update_has_keys(indices) self.__buffer._update_has_keys(indices)
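When the caller supplies no birth times, the buffer stamps each sample with the current monotonic clock plus a per-sample offset so ordering stays deterministic within one batch. A minimal sketch of that default, using a plain NumPy array in place of the buffer's internal state:

import time
import numpy as np

def assign_birth_times(n: int, birth_times=None) -> np.ndarray:
    # Caller-provided timestamps (e.g. recorded when a rollout worker produced
    # the data) take precedence; otherwise fall back to "now" with a strictly
    # increasing per-sample offset.
    if birth_times is None:
        return time.monotonic_ns() + np.arange(n, dtype=np.int64)
    return np.asarray(birth_times, dtype=np.int64)

stamps = assign_birth_times(4)
assert np.all(np.diff(stamps) == 1)  # deterministic ordering within the batch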


@ -436,11 +436,9 @@ def run_ray_worker(
constants.set_experiment_trial_names(experiment_name, trial_name) constants.set_experiment_trial_names(experiment_name, trial_name)
import realhf.api.core.system_api as system_api import realhf.api.core.system_api as system_api
from realhf.api.quickstart.entrypoint import ( from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES
QUICKSTART_CONFIG_CLASSES,
QUICKSTART_EXPR_CACHE_PATH,
)
from realhf.base import importing from realhf.base import importing
from realhf.base.constants import QUICKSTART_EXPR_CACHE_PATH
if os.path.exists(QUICKSTART_EXPR_CACHE_PATH): if os.path.exists(QUICKSTART_EXPR_CACHE_PATH):
for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH): for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH):


@ -10,8 +10,8 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, logging from realhf.base import constants, logging
from realhf.base.topology import ProcessTopology, new_or_get_group from realhf.base.topology import ProcessTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids from realhf.impl.model.comm.global_comm import filter_match_mwids


@ -6,11 +6,11 @@ from typing import *
import networkx as nx import networkx as nx
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from realhf.api.core.config import ModelName, ModelShardID from realhf.api.core.config import ModelShardID
from realhf.api.core.data_api import DataBatchMeta, SequenceSample from realhf.api.core.data_api import DataBatchMeta, get_shuffle_indices
from realhf.api.core.dfg import MFCDef from realhf.api.core.dfg import MFCDef
from realhf.api.core.model_api import ReaLModelConfig from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging from realhf.base import constants, logging, name_resolve, names, seeding
from realhf.base.topology import ProcessTopology from realhf.base.topology import ProcessTopology
from realhf.system.buffer import AsyncIOSequenceBuffer from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl
@ -118,7 +118,10 @@ class FunctionExecutor:
received_ids = set() received_ids = set()
load_data_iter = 0
while buffer.size < max(rpc.n_seqs for rpc in self.rpcs): while buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
load_data_iter += 1
resps = await self.stream.call_async( resps = await self.stream.call_async(
handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)], handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)],
handle_type="fetch", handle_type="fetch",
@ -127,6 +130,7 @@ class FunctionExecutor:
) )
all_data = [] all_data = []
all_birth_time = []
data_cnt = [] data_cnt = []
gpu_id_data = {} gpu_id_data = {}
for dp_rank, x in enumerate(resps): for dp_rank, x in enumerate(resps):
@ -147,13 +151,21 @@ class FunctionExecutor:
gpu_id = self.stream.route_to(f"__data{dp_rank}__") gpu_id = self.stream.route_to(f"__data{dp_rank}__")
all_data += x.meta_sample.unpack() all_data += x.meta_sample.unpack()
all_birth_time += x.birth_times
gpu_id_data[gpu_id] = x.meta_sample.unpack() gpu_id_data[gpu_id] = x.meta_sample.unpack()
data_cnt.append(x.meta_sample.bs) data_cnt.append(x.meta_sample.bs)
if self.shuffle_dataset: if self.shuffle_dataset:
# We load data in a round-robin manner across different DP ranks, # We load data in a round-robin manner across different DP ranks,
# so we also need to shuffle the data to fuse different dataset splits. # so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data) shuffle_indices = get_shuffle_indices(
seeding.get_seed()
+ 47 * self.ctrl.step_info.global_step
+ 97 * load_data_iter,
len(all_data),
)
all_data = [all_data[i] for i in shuffle_indices]
all_birth_time = [all_birth_time[i] for i in shuffle_indices]
if len(all_data) > 0: if len(all_data) > 0:
# Update resource tracker for planning data redistribution. # Update resource tracker for planning data redistribution.
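Replacing `random.shuffle` with a permutation derived from the seed, the global step, and the load iteration keeps the fused data order reproducible across restarts. A minimal sketch using NumPy directly (the real code calls `get_shuffle_indices` and `seeding.get_seed()`; the multipliers 47 and 97 are taken from the diff):

import numpy as np

def deterministic_shuffle(items, birth_times, base_seed, global_step, load_iter):
    # Mix the step counters into the seed so each load gets a distinct but
    # reproducible permutation, then apply it to both parallel lists.
    rng = np.random.RandomState(base_seed + 47 * global_step + 97 * load_iter)
    perm = rng.permutation(len(items))
    return [items[i] for i in perm], [birth_times[i] for i in perm]

data, times = deterministic_shuffle(list("abcd"), [10, 11, 12, 13],
                                    base_seed=1, global_step=3, load_iter=2)
assert sorted(data) == list("abcd") and len(times) == 4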
@ -167,9 +179,21 @@ class FunctionExecutor:
) )
# Store into buffer! # Store into buffer!
buffer_indices = await buffer.put_batch(all_data) assert len(all_data) == len(all_birth_time)
buffer_indices = await buffer.put_batch(all_data, all_birth_time)
assert len(buffer_indices) == len(all_data) assert len(buffer_indices) == len(all_data)
training_sample_name = names.training_samples(
constants.experiment_name(), constants.trial_name()
)
try:
n_samples = int(name_resolve.get(training_sample_name))
except name_resolve.NameEntryNotFoundError:
n_samples = 0
name_resolve.add(
training_sample_name, str(n_samples + len(all_data)), replace=True
)
blogger.info( blogger.info(
f"Master worker loaded {len(all_data)} pieces of data from all dp ranks: " f"Master worker loaded {len(all_data)} pieces of data from all dp ranks: "
f"{data_cnt} from each rank. " f"{data_cnt} from each rank. "
@ -178,7 +202,7 @@ class FunctionExecutor:
else: else:
await asyncio.sleep(1) await asyncio.sleep(1)
def execute_step(self): async def execute_step(self):
logger.info("Waiting for the finish of the execution graph.") logger.info("Waiting for the finish of the execution graph.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -190,5 +214,5 @@ class FunctionExecutor:
loop.create_task(self.finish_traverse()), loop.create_task(self.finish_traverse()),
] ]
loop.run_until_complete(asyncio.gather(*tasks)) await asyncio.gather(*tasks)
self.buffer_id = (self.buffer_id + 1) % len(self.buffers) self.buffer_id = (self.buffer_id + 1) % len(self.buffers)
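The master publishes the cumulative number of training samples through the name-resolve store so other workers can read it without an extra RPC. A minimal sketch of the read-modify-write, with a small dictionary-backed class standing in for `name_resolve`:

class DictNameResolve:
    """A toy name-resolve store; the real backend may be NFS, etcd, or Ray."""
    def __init__(self):
        self._kv = {}
    def get(self, key: str) -> str:
        return self._kv[key]
    def add(self, key: str, value: str, replace: bool = False):
        if key in self._kv and not replace:
            raise ValueError(f"key {key} exists")
        self._kv[key] = value

def bump_training_samples(store: DictNameResolve, key: str, n_new: int) -> int:
    try:
        current = int(store.get(key))
    except KeyError:
        current = 0  # first batch of the experiment
    store.add(key, str(current + n_new), replace=True)
    return current + n_new

store = DictNameResolve()
assert bump_training_samples(store, "training_samples/demo", 32) == 32
assert bump_training_samples(store, "training_samples/demo", 32) == 64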


@ -7,8 +7,10 @@ from pathlib import Path
import requests import requests
from realhf.api.cli_args import SGLangConfig from realhf.api.cli_args import SGLangConfig
from realhf.api.core.system_api import ExpStatus
from realhf.api.core.system_api import GenerationServer as GenerationServerConfig from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
from realhf.base import ( from realhf.base import (
constants,
gpu_utils, gpu_utils,
logging, logging,
name_resolve, name_resolve,
@ -30,7 +32,12 @@ def execute_shell_command(command: str) -> subprocess.Popen:
# Replace newline continuations and split the command string. # Replace newline continuations and split the command string.
command = command.replace("\\\n", " ").replace("\\", " ") command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split() parts = command.split()
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT) return subprocess.Popen(
parts,
text=True,
stdout=sys.stdout,
stderr=subprocess.STDOUT,
)
def launch_server_cmd(command: str, port: int = 30000): def launch_server_cmd(command: str, port: int = 30000):
@ -187,8 +194,22 @@ class GenerationServer(Worker):
if self.server_process is None: if self.server_process is None:
self.launch_server_subprocess() self.launch_server_subprocess()
# TODO: we may want to collect some metrics from the server # Check experiment finish.
time.sleep(0.05) name = names.experiment_status(
constants.experiment_name(), constants.trial_name()
)
try:
exp_status = name_resolve.wait(name, timeout=300)
if exp_status != str(ExpStatus.RUNNING):
self.exit()
return PollResult(0, 0)
except TimeoutError:
raise TimeoutError(
f"Waiting for experiment status timeout. "
"This indicates that the master worker is not running. Exit the worker."
)
time.sleep(5)
return PollResult(0, 0) return PollResult(0, 0)
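The generation server (and, in later hunks, the gserver manager and rollout worker) now polls a shared experiment-status key: the master writes RUNNING while it is alive and a terminal status on exit, and any auxiliary worker that observes a non-RUNNING value shuts itself down. A minimal sketch of the worker-side check with stand-in status values and store (names here are illustrative, not the realhf API):

import enum
import time

class Status(str, enum.Enum):
    RUNNING = "RUNNING"
    COMPLETE = "COMPLETE"
    ABORTED = "ABORTED"

def wait_for_key(store: dict, key: str, timeout: float, poll: float = 0.1) -> str:
    # Block until the key appears, mirroring a name_resolve.wait(...) call.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if key in store:
            return store[key]
        time.sleep(poll)
    raise TimeoutError(f"status key {key} never appeared")

def should_exit(store: dict, key: str) -> bool:
    # Exit whenever the master no longer reports RUNNING; a key that never
    # appears means the master itself is gone, which is treated as an error.
    return wait_for_key(store, key, timeout=1.0) != Status.RUNNING

status_store = {"experiment_status/demo": Status.RUNNING}
assert not should_exit(status_store, "experiment_status/demo")
status_store["experiment_status/demo"] = Status.COMPLETE
assert should_exit(status_store, "experiment_status/demo")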


@ -13,40 +13,23 @@ import aiohttp
import numpy as np import numpy as np
from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
from realhf.api.core.system_api import ExpStatus
from realhf.api.core.system_api import GserverManager as GserverManagerConfig from realhf.api.core.system_api import GserverManager as GserverManagerConfig
from realhf.base import constants, logging, name_resolve, names, network, recover from realhf.base import constants, logging, name_resolve, names, network, recover
from realhf.system.worker_base import PollResult, Worker from realhf.base.monitor import RolloutStat
from realhf.system.worker_base import AsyncWorker, PollResult, Worker
logger = logging.getLogger("Generation Manager", "system") logger = logging.getLogger("Generation Manager", "system")
STALENESS_WARNED = defaultdict(lambda: False) STALENESS_WARNED = defaultdict(lambda: False)
@dataclass
class RolloutStat:
submit: int = 0
accepted: int = 0
running: int = 0
def inc(self):
self.submit += 1
self.accepted += 1
self.running += 1
def accept(self):
self.running -= 1
def reject(self):
self.running -= 1
self.accepted -= 1
@dataclass @dataclass
class AllocateRolloutInput: class AllocateRolloutInput:
qid: str qid: str
class GserverManager(Worker): class GserverManager(AsyncWorker):
"""This worker has the following functionalities: """This worker has the following functionalities:
1. As a router, it schedules generation requests and returns the 1. As a router, it schedules generation requests and returns the
best server urls to clients for submitting generation requests. best server urls to clients for submitting generation requests.
@ -104,7 +87,7 @@ class GserverManager(Worker):
self.config.train_batch_size self.config.train_batch_size
* self.__recover_info.last_step_info.global_step * self.__recover_info.last_step_info.global_step
) )
self.rollout_stat.submit = hist_rollouts self.rollout_stat.submitted = hist_rollouts
self.rollout_stat.accepted = hist_rollouts self.rollout_stat.accepted = hist_rollouts
return config.worker_info return config.worker_info
@ -187,7 +170,8 @@ class GserverManager(Worker):
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
server_url, server_url,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
total=self.config.flush_request_timeout, sock_connect=30 total=self.config.flush_request_timeout,
sock_connect=self.config.flush_request_timeout,
), ),
) as session: ) as session:
async with session.post( async with session.post(
@ -198,6 +182,7 @@ class GserverManager(Worker):
res = await resp.json() res = await resp.json()
success = res["success"] success = res["success"]
if success: if success:
if "num_paused_requests" in res:
logger.info( logger.info(
f"{res['num_paused_requests']} requests are interrupted " f"{res['num_paused_requests']} requests are interrupted "
f"during updateing weights for server {server_index}: {server_url}" f"during updateing weights for server {server_index}: {server_url}"
@ -228,7 +213,7 @@ class GserverManager(Worker):
url = min(self.server_urls, key=lambda k: self._server_token_usage[k]) url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
return self.server_urls.index(url) return self.server_urls.index(url)
def _poll(self): async def _poll_async(self):
if not self.thread: if not self.thread:
# Find addresses of generation servers # Find addresses of generation servers
self.server_urls = self._discover_servers(self.config.n_servers) self.server_urls = self._discover_servers(self.config.n_servers)
@ -244,6 +229,21 @@ class GserverManager(Worker):
f"GserverManager HTTP service started in background thread at {self.manager_addr}" f"GserverManager HTTP service started in background thread at {self.manager_addr}"
) )
# Check experiment finish.
name = names.experiment_status(
constants.experiment_name(), constants.trial_name()
)
try:
exp_status = name_resolve.wait(name, timeout=300)
if exp_status != str(ExpStatus.RUNNING):
self.exit()
return PollResult(0, 0)
except TimeoutError:
raise TimeoutError(
f"Waiting for experiment status timeout. "
"This indicates that the master worker is not running. Exit the worker."
)
# Check weights. # Check weights.
with self.threading_lock: with self.threading_lock:
# FIXME: we create a sync point across servers to update weights, # FIXME: we create a sync point across servers to update weights,
@ -254,12 +254,13 @@ class GserverManager(Worker):
self.flush_requests_and_update_weights(base_url, new_param_path) self.flush_requests_and_update_weights(base_url, new_param_path)
for base_url in self.server_urls for base_url in self.server_urls
] ]
loop = asyncio.get_event_loop() await asyncio.gather(*tasks)
loop.run_until_complete(asyncio.gather(*tasks))
logger.info(f"Generaion server updated weights from: {new_param_path}") logger.info(f"Generaion server updated weights from: {new_param_path}")
if self.schedule_policy == "least_token_usage":
tasks = [ tasks = [
self._get_server_token_usage(server_url) for server_url in self.server_urls self._get_server_token_usage(server_url)
for server_url in self.server_urls
] ]
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
token_usages = loop.run_until_complete(asyncio.gather(*tasks)) token_usages = loop.run_until_complete(asyncio.gather(*tasks))
@ -304,7 +305,8 @@ class GserverManager(Worker):
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
server_url, server_url,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
total=self.config.flush_request_timeout, sock_connect=30 total=self.config.flush_request_timeout,
sock_connect=self.config.flush_request_timeout,
), ),
) as session: ) as session:
async with session.get("/metrics") as resp: async with session.get("/metrics") as resp:
@ -319,7 +321,8 @@ class GserverManager(Worker):
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
server_url, server_url,
timeout=aiohttp.ClientTimeout( timeout=aiohttp.ClientTimeout(
total=self.config.flush_request_timeout, sock_connect=30 total=self.config.flush_request_timeout,
sock_connect=self.config.flush_request_timeout,
), ),
) as session: ) as session:
async with session.get(f"/metrics") as resp: async with session.get(f"/metrics") as resp:
@ -332,8 +335,16 @@ class GserverManager(Worker):
f"Failed to get num running requests metrics from {server_url}" f"Failed to get num running requests metrics from {server_url}"
) )
def get_training_sample_cnt(self):
name = names.training_samples(self.experiment_name, self.trial_name)
try:
return int(name_resolve.get(name))
except name_resolve.NameEntryNotFoundError:
return 0
def is_staled(self): def is_staled(self):
global_sample_cnt = self.rollout_stat.accepted # Use counter written by the trainer, local counter is inaccurate
global_sample_cnt = self.get_training_sample_cnt() + self.rollout_stat.running
expected_version = global_sample_cnt // self.config.train_batch_size expected_version = global_sample_cnt // self.config.train_batch_size
version = self._last_param_realloc_step version = self._last_param_realloc_step
staled = expected_version > self.config.max_head_offpolicyness + version staled = expected_version > self.config.max_head_offpolicyness + version
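The staleness test now estimates the global sample count from the trainer-written counter plus the locally running rollouts and compares the implied parameter version with the one currently loaded. A minimal sketch of that arithmetic with illustrative numbers:

def is_staled(trained_samples: int, running_rollouts: int,
              train_batch_size: int, current_version: int,
              max_head_offpolicyness: int) -> bool:
    # If all in-flight rollouts were accepted, training would have consumed
    # this many batches, i.e. produced this many parameter versions:
    expected_version = (trained_samples + running_rollouts) // train_batch_size
    # New rollouts are refused once the loaded weights would lag more than
    # `max_head_offpolicyness` versions behind that estimate.
    return expected_version > max_head_offpolicyness + current_version

# With batch size 256, 1024 trained samples and 256 running rollouts imply
# version 5; allowing one version of head off-policyness, weights at version 3
# are stale while version 4 is still acceptable.
assert is_staled(1024, 256, 256, current_version=3, max_head_offpolicyness=1)
assert not is_staled(1024, 256, 256, current_version=4, max_head_offpolicyness=1)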
@ -406,13 +417,22 @@ class GserverManager(Worker):
is_staled = self.is_staled() is_staled = self.is_staled()
reason = "" reason = ""
if has_capacity and not is_staled: if has_capacity and not is_staled:
self.rollout_stat.inc() self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.info(
f"Allocate rollout for qid {req.qid}. "
f"Submitted: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
return dict(success=True, reason=reason) return dict(success=True, reason=reason)
else: else:
if not has_capacity: if not has_capacity:
reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}" reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}"
if is_staled: if is_staled:
global_sample_cnt = self.rollout_stat.accepted global_sample_cnt = (
self.get_training_sample_cnt() + self.rollout_stat.running
)
expected_version = ( expected_version = (
global_sample_cnt // self.config.train_batch_size global_sample_cnt // self.config.train_batch_size
) )
@ -435,10 +455,15 @@ class GserverManager(Worker):
), "server request count < 0" ), "server request count < 0"
self._qid_to_server_url.pop(resp_meta.qid) self._qid_to_server_url.pop(resp_meta.qid)
self._gen_tokens += resp_meta.n_tokens self._gen_tokens += resp_meta.n_tokens
self.rollout_stat.running -= 1
if resp_meta.accepted: if resp_meta.accepted:
self.rollout_stat.accept() self.rollout_stat.accepted += 1
else: logger.info(
self.rollout_stat.reject() f"Finish rollout for qid {resp_meta.qid}. "
f"Running: {self.rollout_stat.running}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}"
)
return dict(success=True) return dict(success=True)
port = network.find_free_port( port = network.find_free_port(


@ -23,6 +23,7 @@ import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig from realhf.api.core.model_api import ReaLModelConfig
from realhf.api.core.system_api import ExpStatus
from realhf.base import ( from realhf.base import (
constants, constants,
logging, logging,
@ -40,7 +41,7 @@ logger = logging.getLogger("master worker", "system")
blogger = logging.getLogger("benchmark") blogger = logging.getLogger("benchmark")
class MasterWorker(worker_base.Worker): class MasterWorker(worker_base.AsyncWorker):
global_exp_tik = time.perf_counter() global_exp_tik = time.perf_counter()
def _configure(self, config: config_pkg.MasterWorker): def _configure(self, config: config_pkg.MasterWorker):
@ -142,6 +143,22 @@ class MasterWorker(worker_base.Worker):
f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}." f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}."
) )
# Recover the previous number of training samples
train_rpcs = list(
filter(
lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
self.__model_rpcs,
)
)
train_batch_size = train_rpcs[0].n_seqs
hist_samples = (
train_batch_size * self.__recover_info.last_step_info.global_step
)
training_sample_name = names.training_samples(
constants.experiment_name(), constants.trial_name()
)
name_resolve.add(training_sample_name, str(hist_samples), replace=True)
# for benchmark # for benchmark
self.e2e_time_history = [] self.e2e_time_history = []
self.__benchmark_steps = config.exp_ctrl.benchmark_steps self.__benchmark_steps = config.exp_ctrl.benchmark_steps
@ -274,6 +291,7 @@ class MasterWorker(worker_base.Worker):
] ]
# wandb init, connect to remote wandb host # wandb init, connect to remote wandb host
if self.wandb_config.mode != "disabled":
wandb.login() wandb.login()
wandb.init( wandb.init(
mode=self.wandb_config.mode, mode=self.wandb_config.mode,
@ -327,7 +345,7 @@ class MasterWorker(worker_base.Worker):
global_step=-1, global_step=-1,
) )
def _poll(self): async def __poll_async(self):
is_new_epoch = False is_new_epoch = False
if not self.__initialized: if not self.__initialized:
@ -369,7 +387,7 @@ class MasterWorker(worker_base.Worker):
self.logger.info(s) self.logger.info(s)
# Traverse over the dataflow graph for once. # Traverse over the dataflow graph for once.
self.func_executor.execute_step() await self.func_executor.execute_step()
# Post-process. # Post-process.
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt: if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
@ -431,6 +449,18 @@ class MasterWorker(worker_base.Worker):
return worker_base.PollResult(sample_count=1, batch_count=1) return worker_base.PollResult(sample_count=1, batch_count=1)
async def _poll_async(self):
name = names.experiment_status(
constants.experiment_name(), constants.trial_name()
)
name_resolve.add(name, ExpStatus.RUNNING, replace=True)
try:
r = await self.__poll_async()
except Exception as e:
name_resolve.add(name, ExpStatus.ABORTED, replace=True)
raise e
return r
def _log_training_stats(self, e2e_time: float, time_since_configure: float): def _log_training_stats(self, e2e_time: float, time_since_configure: float):
# calculate flops # calculate flops
######################################### #########################################
@ -482,8 +512,15 @@ class MasterWorker(worker_base.Worker):
+ colorama.Style.RESET_ALL + colorama.Style.RESET_ALL
) )
# Update experiment status to inform other workers
name = names.experiment_status(
constants.experiment_name(), constants.trial_name()
)
name_resolve.add(name, ExpStatus.COMPLETE, replace=True)
# Send requests to pause model workers. # Send requests to pause model workers.
# Model workers will not respond to this message. # Model workers will not respond to this message.
# FIXME: request to model workers is unnecessary
self.__stream.request( self.__stream.request(
handlers=list(range(self.config.n_model_workers)), handlers=list(range(self.config.n_model_workers)),
handle_type="reset", handle_type="reset",


@ -19,8 +19,7 @@ import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as config_pkg import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream import realhf.system.request_reply_stream as request_reply_stream
from realhf import ModelShardID from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, stats_tracker, topology from realhf.base import constants, logging, stats_tracker, topology
from realhf.system.buffer import AsyncIOSequenceBuffer from realhf.system.buffer import AsyncIOSequenceBuffer

View File

@ -465,6 +465,7 @@ class ModelWorker(worker_base.Worker):
# because we may want to copy huggingface configurations from it, and # because we may want to copy huggingface configurations from it, and
# the next recover save will remove this symlink. # the next recover save will remove this symlink.
dst_path = Path(model_path).parent / "_tmp_ckpt" dst_path = Path(model_path).parent / "_tmp_ckpt"
shutil.rmtree(dst_path, ignore_errors=True)
shutil.copytree(model_path, dst_path) shutil.copytree(model_path, dst_path)
os.unlink(model_path) os.unlink(model_path)
os.system(f"mv {str(dst_path)} {model_path}") os.system(f"mv {str(dst_path)} {model_path}")
@ -669,14 +670,26 @@ class ModelWorker(worker_base.Worker):
self.data_manager.store(x) self.data_manager.store(x)
assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded) assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded)
if len(data_loaded) > 0:
meta_sample = data_api.SequenceSample.gather(data_loaded).meta()
else:
meta_sample = None meta_sample = None
birth_times = []
if len(data_loaded) > 0:
sample = data_api.SequenceSample.gather(data_loaded)
meta_sample = sample.meta()
if "birth_time" in sample.keys:
birth_times = (
sample.data["birth_time"].flatten().cpu().numpy().tolist()
)
assert len(birth_times) == meta_sample.bs
else:
birth_times = (
time.monotonic_ns()
+ np.arange(len(data_loaded), dtype=np.int64)
).tolist()
res = data_api.DataBatchMeta( res = data_api.DataBatchMeta(
dp_rank=dp_rank, dp_rank=dp_rank,
meta_sample=meta_sample, meta_sample=meta_sample,
birth_times=birth_times,
) )
elif request.handle_name == "spec": elif request.handle_name == "spec":
# Raw dataset without filtering. # Raw dataset without filtering.


@ -80,7 +80,7 @@ class PartialRolloutManager:
async with session.post( async with session.post(
f"http://{self.gserver_manager_addr}/schedule_request", f"http://{self.gserver_manager_addr}/schedule_request",
json=asdict(req_meta), json=asdict(req_meta),
timeout=ClientTimeout(total=self.timeout, sock_connect=30), timeout=ClientTimeout(total=self.timeout, sock_connect=self.timeout),
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
res = await response.json() res = await response.json()

View File

@ -44,16 +44,9 @@ class ZMQJsonPusher:
TypeError: If data is not JSON-serializable TypeError: If data is not JSON-serializable
zmq.ZMQError: If ZeroMQ operation fails zmq.ZMQError: If ZeroMQ operation fails
""" """
try:
# Directly encode to bytes without intermediate string # Directly encode to bytes without intermediate string
json_bytes = asbytes(orjson.dumps(data)) json_bytes = asbytes(orjson.dumps(data))
self.socket.send(json_bytes, flags=zmq.NOBLOCK, copy=False) self.socket.send(json_bytes, copy=False)
except (TypeError, ValueError) as e:
raise TypeError(f"Data not JSON-serializable: {e}")
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
logger.warning("Push operation would block (queue full)")
raise
def close(self) -> None: def close(self) -> None:
"""Clean up resources.""" """Clean up resources."""


@ -14,6 +14,7 @@ from aiohttp.client import ClientTimeout
from realhf.api.core.agent_api import make_agent from realhf.api.core.agent_api import make_agent
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer, make_dataset from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer, make_dataset
from realhf.api.core.env_api import make_env from realhf.api.core.env_api import make_env
from realhf.api.core.system_api import ExpStatus
from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig
from realhf.base import ( from realhf.base import (
constants, constants,
@ -24,6 +25,7 @@ from realhf.base import (
recover, recover,
seeding, seeding,
) )
from realhf.base.monitor import RolloutStat
from realhf.system.partial_rollout import PartialRolloutManager from realhf.system.partial_rollout import PartialRolloutManager
from realhf.system.push_pull_stream import NameResolvingZmqPusher from realhf.system.push_pull_stream import NameResolvingZmqPusher
from realhf.system.worker_base import AsyncWorker, PollResult from realhf.system.worker_base import AsyncWorker, PollResult
@ -80,8 +82,9 @@ class RolloutWorker(AsyncWorker):
self.gserver_manager_addr = None self.gserver_manager_addr = None
self.rollout_tasks: Dict[Hashable, asyncio.Task] = {} self.rollout_tasks: Dict[Hashable, asyncio.Task] = {}
# recover info # Since the rollout worker doesn't compute staleness,
self.__recover_run, self.__recover_info = recover.load_recover_info() # we don't need to recover rollout_stat here.
self.rollout_stat = RolloutStat()
return config.worker_info return config.worker_info
@ -182,10 +185,8 @@ class RolloutWorker(AsyncWorker):
self.data_generator = enumerate(self.dataloader) self.data_generator = enumerate(self.dataloader)
return None return None
# NOTE: no need to ignore ids during recover, because model workers will do so
data_id = cur_sample.ids[0] data_id = cur_sample.ids[0]
if self.__recover_run and data_id in self.__recover_info.hash_vals_to_ignore:
self.__recover_info.hash_vals_to_ignore.remove(data_id)
return None
assert data_id not in self.rollout_tasks assert data_id not in self.rollout_tasks
return cur_sample return cur_sample
@ -195,11 +196,14 @@ class RolloutWorker(AsyncWorker):
f"http://{self.gserver_manager_addr}/allocate_rollout", f"http://{self.gserver_manager_addr}/allocate_rollout",
json=dict(qid=qid), json=dict(qid=qid),
timeout=ClientTimeout( timeout=ClientTimeout(
total=self.config.rollout_request_timeout, sock_connect=30 total=self.config.rollout_request_timeout,
sock_connect=self.config.rollout_request_timeout,
), ),
) as resp: ) as resp:
resp.raise_for_status() resp.raise_for_status()
res = await resp.json() res = await resp.json()
if not res["success"]:
logger.info(f"Cannot allocate new rollout because: {res['reason']}")
return res["success"] return res["success"]
async def _poll_async(self): async def _poll_async(self):
@ -213,6 +217,21 @@ class RolloutWorker(AsyncWorker):
f"Time consumed: {time.perf_counter() - tik}s" f"Time consumed: {time.perf_counter() - tik}s"
) )
# Check experiment finish.
name = names.experiment_status(
constants.experiment_name(), constants.trial_name()
)
try:
exp_status = name_resolve.wait(name, timeout=300)
if exp_status != str(ExpStatus.RUNNING):
self.exit()
return PollResult(0, 0)
except TimeoutError:
raise TimeoutError(
f"Waiting for experiment status timeout. "
"This indicates that the master worker is not running. Exit the worker."
)
if self.push_stream is None: if self.push_stream is None:
# Initialize stream after configure to ensure that puller names have been written. # Initialize stream after configure to ensure that puller names have been written.
self.push_stream = NameResolvingZmqPusher( self.push_stream = NameResolvingZmqPusher(
@ -236,13 +255,24 @@ class RolloutWorker(AsyncWorker):
qid = data.ids[0] qid = data.ids[0]
can_rollout = await self.allocate_new_rollout(qid) can_rollout = await self.allocate_new_rollout(qid)
if can_rollout: if can_rollout:
assert qid not in self.act_queues
self.act_queues[qid] = asyncio.Queue(1024) self.act_queues[qid] = asyncio.Queue(1024)
task = asyncio.create_task(self.rollout_task(qid, data)) task = asyncio.create_task(self.rollout_task(qid, data))
assert qid not in self.rollout_tasks
self.rollout_tasks[qid] = task self.rollout_tasks[qid] = task
self._cur_data = None self._cur_data = None
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.info(
f"Submit a new rollout for qid {qid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
# Run rollouts and wait # Run rollouts and wait
done, *_ = await asyncio.gather( done, *_ = await asyncio.gather(
self.poll_rollout_task(), self.poll_rollout_task(),
@ -261,10 +291,13 @@ class RolloutWorker(AsyncWorker):
self.rollout_tasks.pop(qid) self.rollout_tasks.pop(qid)
self.act_queues.pop(qid) self.act_queues.pop(qid)
self.rollout_stat.running -= 1
accepted = False accepted = False
if len(trajs) > 0: if len(trajs) > 0:
accepted = True accepted = True
self.push_stream.push([traj.as_json_compatible() for traj in trajs]) self.push_stream.push([traj.as_json_compatible() for traj in trajs])
self.rollout_stat.accepted += 1
n_tokens = 0 n_tokens = 0
for traj in trajs: for traj in trajs:
@ -278,11 +311,18 @@ class RolloutWorker(AsyncWorker):
"/finish_rollout", "/finish_rollout",
json=info, json=info,
timeout=ClientTimeout( timeout=ClientTimeout(
total=self.config.rollout_request_timeout, sock_connect=30 total=self.config.rollout_request_timeout,
sock_connect=self.config.rollout_request_timeout,
), ),
) as resp: ) as resp:
resp.raise_for_status() resp.raise_for_status()
assert (await resp.json())["success"] assert (await resp.json())["success"]
logger.info(
f"Finish rollout for qid {qid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
for traj in trajs: for traj in trajs:
batch_count += traj.bs batch_count += traj.bs
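The rollout worker now keeps its own submitted/running/accepted counters (a `RolloutStat` imported from `realhf.base.monitor`) purely for logging. A minimal sketch of that bookkeeping with a local stand-in for the counter class:

import dataclasses

@dataclasses.dataclass
class RolloutCounters:
    # Stand-in for realhf.base.monitor.RolloutStat; field names follow the diff.
    submitted: int = 0
    running: int = 0
    accepted: int = 0

def on_submit(stat: RolloutCounters):
    stat.submitted += 1
    stat.running += 1

def on_finish(stat: RolloutCounters, n_trajs: int):
    stat.running -= 1
    if n_trajs > 0:  # at least one trajectory was pushed downstream
        stat.accepted += 1

stat = RolloutCounters()
on_submit(stat); on_submit(stat)
on_finish(stat, n_trajs=1)  # accepted
on_finish(stat, n_trajs=0)  # rejected (e.g. filtered out)
assert (stat.submitted, stat.running, stat.accepted) == (2, 0, 1)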


@ -1,4 +1,5 @@
import queue import queue
import sys
import threading import threading
import time import time
from typing import Any, List, Optional from typing import Any, List, Optional
@ -12,9 +13,11 @@ from realhf.api.core.data_api import (
make_dataset, make_dataset,
register_dataset, register_dataset,
) )
from realhf.base import constants from realhf.base import constants, logging
from realhf.system.push_pull_stream import NameResolvingZmqPuller from realhf.system.push_pull_stream import NameResolvingZmqPuller
logger = logging.getLogger("StreamDataset")
class PullerStreamDataset(Dataset): class PullerStreamDataset(Dataset):
def __init__( def __init__(
@ -45,15 +48,12 @@ class PullerStreamDataset(Dataset):
del dataset, datasets del dataset, datasets
self.pull_timeout_ms = pull_timeout_ms self.pull_timeout_ms = pull_timeout_ms
self.data_queue = queue.Queue(maxsize=self.dataset_size) self.data_queue = queue.Queue(maxsize=self.dataset_size * util.world_size)
self._stop_event = threading.Event() self._stop_event = threading.Event()
# Pass ZMQ context (thread-safe) and let worker create the socket # Pass ZMQ context (thread-safe) and let worker create the socket
self.util = util self.util = util
self.worker_thread = threading.Thread( self.worker_thread = threading.Thread(target=self._pull_data_worker)
target=self._pull_data_worker,
daemon=True,
)
self.worker_thread.start() self.worker_thread.start()
def _pull_data_worker(self): def _pull_data_worker(self):
@ -71,13 +71,19 @@ class PullerStreamDataset(Dataset):
processed_data = [ processed_data = [
SequenceSample.from_json_compatible(x) for x in data SequenceSample.from_json_compatible(x) for x in data
] ]
self.data_queue.put_nowait(processed_data) logger.debug(
f"Get data {[x.ids[0] for x in processed_data]} from puller stream."
)
self.data_queue.put(processed_data)
except queue.Empty: except queue.Empty:
logger.debug(f"No data from puller stream.")
time.sleep(0.1) time.sleep(0.1)
continue continue
finally: finally:
# Ensure socket is closed in the same thread # Ensure socket is closed in the same thread
del stream del stream
# Exit if this thread has an error
sys.exit(1)
def __getitem__(self, idx: int) -> Optional[Any]: def __getitem__(self, idx: int) -> Optional[Any]:
samples = [] samples = []


@ -60,3 +60,7 @@ orjson>=3.10.16
flask flask
setuptools>=62.3.0,<75.9 setuptools>=62.3.0,<75.9
func_timeout func_timeout
jupyter-book
uvloop>=0.21.0
uvicorn>=0.34.2
fastapi>=0.115.12
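The new uvicorn/fastapi pins suggest an HTTP service behind endpoints such as `/allocate_rollout`; whether those routes are actually served this way is an assumption, so the following is purely an illustrative sketch:

from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI()

class AllocateRolloutInput(BaseModel):
    qid: str

@app.post("/allocate_rollout")
async def allocate_rollout(req: AllocateRolloutInput):
    # A real implementation would consult capacity and staleness here.
    return {"success": True, "reason": ""}

if __name__ == "__main__":
    # Served with uvicorn; host and port are placeholders.
    uvicorn.run(app, host="127.0.0.1", port=30000)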

setup.py

@ -2,280 +2,10 @@
# Copyright 2024 Wei Fu & Zhiyu Mei # Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
import contextlib
import io
import os
import re
import subprocess
import warnings
from pathlib import Path
from typing import List, Set
import setuptools import setuptools
import torch
import torch.utils.cpp_extension as torch_cpp_ext
from packaging.version import Version, parse
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
def _is_cuda() -> bool:
return os.getenv("REAL_CUDA", "0") == "1"
# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-std=c++17"]
NVCC_FLAGS = ["-O3", "-std=c++17"]
if _is_cuda() and CUDA_HOME is None:
raise RuntimeError(
"Cannot find CUDA_HOME. In GPU environment, CUDA must be available to build the package."
)
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
def glob(pattern: str):
root = Path(__name__).parent
return [str(p) for p in root.glob(pattern)]
def get_pybind11_include_path() -> str:
pybind11_meta = subprocess.check_output(
"python3 -m pip show pybind11", shell=True
).decode("ascii")
for line in pybind11_meta.split("\n"):
line = line.strip()
if line.startswith("Location: "):
return os.path.join(line.split(": ")[1], "pybind11", "include")
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"""Get the CUDA version from nvcc.
Adapted from
https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
nvcc_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version
def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g."8.0" or "7.5,8.0,8.6+PTX".Here, the "8.6+PTX" option asks the
# compiler to additionally include PTX code that can be runtime - compiled
# and executed on the 8.6 or newer architectures.While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if env_arch_list is None:
return set()
# List are separated by; or space.
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
if not torch_arch_list:
return set()
# Filter out the invalid architectures and print a warning.
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
{s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}
)
arch_list = torch_arch_list.intersection(valid_archs)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA architectures are: {valid_archs}."
)
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA architectures are: "
f"{valid_archs}.",
stacklevel=2,
)
return arch_list
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 7:
raise RuntimeError(
"GPUs with compute capability below 7.0 are not supported."
)
compute_capabilities.add(f"{major}.{minor}")
ext_modules = []
if _is_cuda():
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
if nvcc_cuda_version < Version("11.1"):
compute_capabilities.remove("8.6")
if nvcc_cuda_version < Version("11.8"):
compute_capabilities.remove("8.9")
compute_capabilities.remove("9.0")
# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if nvcc_cuda_version < Version("11.1") and any(
cc.startswith("8.6") for cc in compute_capabilities
):
raise RuntimeError(
"CUDA 11.1 or higher is required for compute capability 8.6."
)
if nvcc_cuda_version < Version("11.8"):
if any(cc.startswith("8.9") for cc in compute_capabilities):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
warnings.warn(
"CUDA 11.8 or higher is required for compute capability 8.9. "
"Targeting compute capability 8.0 instead.",
stacklevel=2,
)
compute_capabilities = set(
cc for cc in compute_capabilities if not cc.startswith("8.9")
)
compute_capabilities.add("8.0+PTX")
if any(cc.startswith("9.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 11.8 or higher is required for compute capability 9.0."
)
NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
num = capability[0] + capability[2]
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
if int(capability[0]) >= 8:
NVCC_FLAGS_PUNICA += [
"-gencode",
f"arch=compute_{num},code=sm_{num}",
]
if capability.endswith("+PTX"):
NVCC_FLAGS_PUNICA += [
"-gencode",
f"arch=compute_{num},code=compute_{num}",
]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_threads = min(os.cpu_count(), nvcc_threads)
NVCC_FLAGS += ["--threads", str(num_threads)]
if nvcc_cuda_version >= Version("11.8"):
NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
# changes for punica kernels
NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
REMOVE_NVCC_FLAGS = [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
for flag in REMOVE_NVCC_FLAGS:
with contextlib.suppress(ValueError):
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
os.makedirs(os.path.join(ROOT_DIR, "realhf", "_C"), exist_ok=True)
no_ext = os.getenv("REAL_NO_EXT", "0") == "1"
if not no_ext and _is_cuda():
gae_extension = CUDAExtension(
name="realhf._C.cugae",
sources=[
"csrc/cugae/gae.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS
+ [
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
],
},
libraries=["cuda"],
)
ext_modules.append(gae_extension)
interval_op_cuda = CUDAExtension(
name="realhf._C.interval_op_cuda",
sources=[
"csrc/interval_op/interval_op.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
libraries=["cuda"],
)
ext_modules.append(interval_op_cuda)
if not no_ext:
interval_extension = setuptools.Extension(
name="realhf._C.interval_op",
sources=[
"csrc/interval_op/interval_op.cpp",
],
language="c++",
extra_compile_args=[
"-O3",
"-Wall",
"-std=c++17",
],
include_dirs=[
get_pybind11_include_path(),
],
)
ext_modules.append(interval_extension)
setuptools.setup( setuptools.setup(
name="realhf", name="realhf",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
include_package_data=True, include_package_data=True,
package_data={
"": [
"csrc/**/*.cu",
"csrc/**/*.cuh",
"csrc/**/*.hpp",
"csrc/**/*.cpp",
],
},
) )


@ -85,7 +85,9 @@ async def test_collect_trajectory_happy_path(agent, mock_env, mock_prompt, mock_
sample = result[0] sample = result[0]
assert sample.ids == [str(123)] assert sample.ids == [str(123)]
assert torch.equal(sample.data["packed_prompts"], torch.tensor([1, 2, 3])) assert torch.equal(sample.data["packed_prompts"], torch.tensor([1, 2, 3]))
assert torch.equal(sample.data["rewards"], torch.tensor([0.8, 1.2])) # r = [0.5, 0.7]
# ((r - 0.5) * 2 - bias) * scaling, bias=0.1, scaling=2.0
assert torch.equal(sample.data["rewards"], torch.tensor([-0.2, 0.6]))
@pytest.mark.asyncio @pytest.mark.asyncio


@ -184,7 +184,7 @@ def _test_data_transfer(
] ]
storage_tracker.add_data_synced( storage_tracker.add_data_synced(
gpu_id, gpu_id,
ids=[i + dp_rank * world_size for i in range(world_size)], ids=[str(i + dp_rank * world_size) for i in range(world_size)],
key=key, key=key,
is_owner=True, is_owner=True,
) )
@ -199,7 +199,7 @@ def _test_data_transfer(
dist.all_reduce(input_ids) dist.all_reduce(input_ids)
s = SequenceSample.from_default( s = SequenceSample.from_default(
ids=[i + dp_rank * world_size for i in range(world_size)], ids=[str(i + dp_rank * world_size) for i in range(world_size)],
seqlens=seqlens.numpy().tolist(), seqlens=seqlens.numpy().tolist(),
data=dict(input_ids=input_ids), data=dict(input_ids=input_ids),
) )
@ -216,7 +216,7 @@ def _test_data_transfer(
dist.barrier() dist.barrier()
all_ids = list(range(world_size * from_topo.get_dim("data"))) all_ids = list(map(str, range(world_size * from_topo.get_dim("data"))))
np.random.shuffle(all_ids) np.random.shuffle(all_ids)
_all_ids = [all_ids] _all_ids = [all_ids]
dist.broadcast_object_list(_all_ids, src=0) dist.broadcast_object_list(_all_ids, src=0)
@ -236,7 +236,7 @@ def _test_data_transfer(
) )
] ]
size_per_dp = len(all_ids) // dp_size size_per_dp = len(all_ids) // dp_size
dests[gpu_id] = [coord.data * size_per_dp + i for i in range(size_per_dp)] dests[gpu_id] = [str(coord.data * size_per_dp + i) for i in range(size_per_dp)]
for gpu_id in range(world_size): for gpu_id in range(world_size):
if gpu_id not in dests: if gpu_id not in dests:
@ -253,29 +253,21 @@ def _test_data_transfer(
print("success") print("success")
parallelism = [(1, 4, 2), (1, 8, 1)] parallelism = [(4, 1, 1), (2, 2, 2), (1, 8, 1), (3, 2, 1), (2, 1, 2), (1, 2, 2)]
@pytest.mark.skipif( @pytest.mark.parametrize("from_pp_dp_tp", parallelism)
os.cpu_count() < 32 or testing.get_free_mem_gb() < 50, @pytest.mark.parametrize("to_pp_dp_tp", parallelism)
reason="The parameter reallocation test requires at least 32 CPUs and 50GB memory.",
)
@pytest.mark.parametrize("from_pp_dp_tp", [(1, 4, 2)])
@pytest.mark.parametrize("to_pp_dp_tp", [(1, 8, 1)])
@pytest.mark.distributed @pytest.mark.distributed
def test_data_transfer( def test_data_transfer(
tmp_path, tmp_path,
from_pp_dp_tp: Tuple, from_pp_dp_tp: Tuple,
to_pp_dp_tp: Tuple, to_pp_dp_tp: Tuple,
): ):
expr_name = uuid.uuid4()
trial_name = uuid.uuid4()
constants.set_force_cpu(True) constants.set_force_cpu(True)
test_impl = LocalMultiProcessTest( test_impl = LocalMultiProcessTest(
world_size=16, world_size=16,
func=_test_data_transfer, func=_test_data_transfer,
expr_name=expr_name,
trial_name=trial_name,
timeout_secs=300, timeout_secs=300,
tmp_path=tmp_path, tmp_path=tmp_path,
from_pp_dp_tp=from_pp_dp_tp, from_pp_dp_tp=from_pp_dp_tp,


@ -526,10 +526,7 @@ def _test_para_realloc(
parallelism = [(4, 1, 1), (2, 2, 2), (1, 8, 1), (3, 2, 1), (2, 1, 2), (1, 2, 2)] parallelism = [(4, 1, 1), (2, 2, 2), (1, 8, 1), (3, 2, 1), (2, 1, 2), (1, 2, 2)]
@pytest.mark.skipif( @pytest.mark.skip("NCCL-based parameter reallocation is not used currently.")
os.cpu_count() < 32 or testing.get_free_mem_gb() < 50,
reason="The parameter reallocation test requires at least 32 CPUs and 50GB memory.",
)
@pytest.mark.parametrize("model_family_name", ["gpt2", "llama"]) @pytest.mark.parametrize("model_family_name", ["gpt2", "llama"])
@pytest.mark.parametrize("is_critic", [False, True]) @pytest.mark.parametrize("is_critic", [False, True])
@pytest.mark.parametrize("from_pp_dp_tp", parallelism) @pytest.mark.parametrize("from_pp_dp_tp", parallelism)


@ -2,10 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei # Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
import os
import random
import time import time
import uuid
from typing import * from typing import *
import numpy as np import numpy as np
@ -44,6 +41,7 @@ def maybe_synchronize_cuda():
"n_intervals", list(reversed([1, 100, 500, 1000, 2000, 4000, 10000, 100000])) "n_intervals", list(reversed([1, 100, 500, 1000, 2000, 4000, 10000, 100000]))
) )
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32, torch.float16])
@pytest.mark.gpu
def test_get(n_intervals: int, dtype: torch.dtype): def test_get(n_intervals: int, dtype: torch.dtype):
device = torch.device("cuda") device = torch.device("cuda")


@ -100,9 +100,9 @@ def test_denominator_edge_cases(tracker):
tracker.denominator(mask=zero_mask) tracker.denominator(mask=zero_mask)
tracker.stat(denominator="mask", value=torch.FloatTensor([1.0, 2.0])) tracker.stat(denominator="mask", value=torch.FloatTensor([1.0, 2.0]))
results = tracker.export() results = tracker.export()
assert torch.isnan(torch.tensor(results["value/min"])) # Should be inf assert "value/min" not in results
assert torch.isnan(torch.tensor(results["value/max"])) # Should be -inf assert "value/max" not in results
assert results["value/avg"] == 0.0 assert "value/avg" not in results
def test_key_specific_export(tracker): def test_key_specific_export(tracker):


@ -19,6 +19,7 @@ from realhf.base.name_resolve import (
BACKENDS = [ BACKENDS = [
("memory", {}), ("memory", {}),
("nfs", {}), ("nfs", {}),
("ray", {}),
] ]
if os.environ.get("REAL_ETCD_ADDR"): if os.environ.get("REAL_ETCD_ADDR"):
BACKENDS.append( BACKENDS.append(
@ -61,6 +62,12 @@ def name_resolve(request):
repo = Etcd3NameRecordRepository(**kwargs) repo = Etcd3NameRecordRepository(**kwargs)
yield repo yield repo
repo.reset() repo.reset()
elif backend_type == "ray":
from realhf.base.name_resolve import RayNameResolveRepository
repo = RayNameResolveRepository(**kwargs)
yield repo
repo.reset()
def test_basic_add_get(name_resolve): def test_basic_add_get(name_resolve):
@ -381,18 +388,19 @@ def test_wait_with_concurrent_delete(name_resolve):
def add_then_delete(): def add_then_delete():
time.sleep(0.1) time.sleep(0.1)
name_resolve.add("test_wait_key", "test_value") name_resolve.add("test_wait_key", "test_value")
time.sleep(0.1) time.sleep(1.0)
name_resolve.delete("test_wait_key") name_resolve.delete("test_wait_key")
thread = threading.Thread(target=add_then_delete, daemon=True) thread = threading.Thread(target=add_then_delete, daemon=True)
thread.start() thread.start()
# Wait with a timeout long enough to capture the key # Wait with a timeout long enough to capture the key
value = name_resolve.wait("test_wait_key", timeout=2.0, poll_frequency=0.05) value = name_resolve.wait("test_wait_key", timeout=3.0, poll_frequency=0.05)
assert value == "test_value" assert value == "test_value"
# Wait for the thread to complete # Wait for the thread to complete
thread.join() thread.join()
time.sleep(0.5)
# Verify the key was deleted # Verify the key was deleted
with pytest.raises(NameEntryNotFoundError): with pytest.raises(NameEntryNotFoundError):


@ -76,9 +76,10 @@ def test_buffer_recover(
trial_name=trial_name, trial_name=trial_name,
mode="local", mode="local",
# allocation_mode=f"m1d{dp}p1", # allocation_mode=f"m1d{dp}p1",
nodelist="slurmd-01",
allocation_mode="manual", allocation_mode="manual",
inf=MFCConfig( inf=MFCConfig(
device_mesh="NODE01:0,1,2,3,4,5,6,7", device_mesh="slurmd-01:0,1,2,3,4,5,6,7",
parallel=ParallelismConfig( parallel=ParallelismConfig(
tensor_parallel_size=2, tensor_parallel_size=2,
pipeline_parallel_size=2, pipeline_parallel_size=2,
@ -86,7 +87,7 @@ def test_buffer_recover(
), ),
), ),
train=MFCConfig( train=MFCConfig(
device_mesh="NODE01:8,9,10,11,12,13,14,15", device_mesh="slurmd-01:8,9,10,11,12,13,14,15",
parallel=ParallelismConfig( parallel=ParallelismConfig(
tensor_parallel_size=2, tensor_parallel_size=2,
pipeline_parallel_size=2, pipeline_parallel_size=2,


@ -51,6 +51,7 @@ def math_code_dataset(request, save_path):
return dataset return dataset
@pytest.mark.skip("symmetric allocation is not used")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dp,pp,mp", "dp,pp,mp",
[ [
@ -121,9 +122,138 @@ def test_ppo_symm(
run_test_exp(exp_cfg) run_test_exp(exp_cfg)
@pytest.mark.parametrize(
"gdp,gpp,gmp",
[
(2, 1, 1),
(1, 1, 2),
],
)
@pytest.mark.parametrize(
"dp,pp,mp",
[
(2, 1, 1),
(1, 2, 1),
(1, 1, 2),
],
)
def test_ppo_decoupled(
tmp_path_factory,
tokenizer,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
dp,
pp,
mp,
gdp,
gpp,
gmp,
):
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("ppo")
cluster.spec.fileroot = str(log_root)
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
minbs = 32
exp_cfg = PPOMATHConfig(
experiment_name=testing._DEFAULT_EXPR_NAME,
trial_name=testing._DEFAULT_TRIAL_NAME,
mode="local",
allocation_mode=f"manual",
nodelist="slurmd-01",
n_nodes=1,
n_gpus_per_node=mp * dp * pp + gmp * gdp * gpp,
actor=ModelTrainEvalConfig(
path=str(save_path),
init_from_scratch=True,
backend="mock_train",
),
ref=ModelTrainEvalConfig(
path=str(save_path),
init_from_scratch=True,
),
critic=ModelTrainEvalConfig(
path=str(save_path),
init_from_scratch=True,
init_critic_from_actor=True,
backend="mock_train",
),
actor_gen=MFCConfig(
device_mesh="slurmd-01:0,1",
parallel=ParallelismConfig(
tensor_parallel_size=gmp,
pipeline_parallel_size=gpp,
data_parallel_size=gdp,
),
),
actor_train=MFCConfig(
device_mesh="slurmd-01:2,3",
parallel=ParallelismConfig(
tensor_parallel_size=mp,
pipeline_parallel_size=pp,
data_parallel_size=dp,
),
),
critic_train=MFCConfig(
device_mesh="slurmd-01:2,3",
parallel=ParallelismConfig(
tensor_parallel_size=mp,
pipeline_parallel_size=pp,
data_parallel_size=dp,
),
),
critic_inf=MFCConfig(
device_mesh="slurmd-01:2,3",
parallel=ParallelismConfig(
tensor_parallel_size=mp,
pipeline_parallel_size=pp,
data_parallel_size=dp,
),
),
ref_inf=MFCConfig(
device_mesh="slurmd-01:2,3",
parallel=ParallelismConfig(
tensor_parallel_size=mp,
pipeline_parallel_size=pp,
data_parallel_size=dp,
),
),
rew_inf=MFCConfig(
device_mesh="slurmd-01:2,3",
parallel=ParallelismConfig(
tensor_parallel_size=mp,
pipeline_parallel_size=pp,
data_parallel_size=dp,
),
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
),
ppo=PPOHyperparameters(
gen=GenerationHyperparameters(
max_new_tokens=4,
min_new_tokens=4,
greedy=True,
use_cuda_graph=False,
),
),
group_size=2,
)
run_test_exp(exp_cfg)
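
A quick sanity check on the GPU budget requested above (an illustration, not part of the test): training needs dp * pp * mp GPUs, generation needs gdp * gpp * gmp GPUs, and the two meshes "slurmd-01:0,1" and "slurmd-01:2,3" are disjoint.

# Smallest parametrization of test_ppo_decoupled: 2 training GPUs + 2 generation GPUs.
dp, pp, mp = 2, 1, 1
gdp, gpp, gmp = 2, 1, 1
assert mp * dp * pp + gmp * gdp * gpp == 4
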
# The global resharding strategy, where all MFCs
# occupy the same device mesh but with different
# parallelization strategies.
+@pytest.mark.skip("Global resharding is not used.")
@pytest.mark.parametrize("actor_gen", [(1, 2, 1)])
@pytest.mark.parametrize("actor_train", [(1, 1, 2)])
@pytest.mark.parametrize("critic_inf", [(1, 1, 2)])
@@ -244,6 +374,7 @@ def test_ppo_global_reshard(
# Actor/critic train and ref_inf/rew_inf are on disjoint
# device meshes and executed concurrently.
+@pytest.mark.skip("Critic is not used.")
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
@pytest.mark.parametrize("critic_inf", [(2, 1, 2)])
def test_ppo_param_realloc_sub_device_mesh(
@@ -306,7 +437,7 @@ def test_ppo_param_realloc_sub_device_mesh(
            ),
        ),
        actor_gen=MFCConfig(
-            device_mesh="NODE01:0,1,2,3",
+            device_mesh="slurmd-01:0,1,2,3",
            parallel=ParallelismConfig(
                data_parallel_size=actor_gen[0],
                tensor_parallel_size=actor_gen[1],
@@ -314,7 +445,7 @@
            ),
        ),
        actor_train=MFCConfig(
-            device_mesh="NODE01:4,5,6,7",
+            device_mesh="slurmd-01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=4,
                tensor_parallel_size=1,
@@ -322,7 +453,7 @@
            ),
        ),
        critic_inf=MFCConfig(
-            device_mesh="NODE01:4,5,6,7",
+            device_mesh="slurmd-01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=critic_inf[0],
                tensor_parallel_size=critic_inf[1],
@@ -330,7 +461,7 @@
            ),
        ),
        rew_inf=MFCConfig(
-            device_mesh="NODE01:4,5,6,7",
+            device_mesh="slurmd-01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=4,
                tensor_parallel_size=1,
@@ -338,7 +469,7 @@
            ),
        ),
        ref_inf=MFCConfig(
-            device_mesh="NODE01:4,5,6,7",
+            device_mesh="slurmd-01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                tensor_parallel_size=2,
@@ -346,7 +477,7 @@
            ),
        ),
        critic_train=MFCConfig(
-            device_mesh="NODE01:4,5,6,7",
+            device_mesh="slurmd-01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=2,
                tensor_parallel_size=1,
@@ -389,6 +520,7 @@ def test_ppo_save(
        allocation_mode="manual",
        n_nodes=1,
        n_gpus_per_node=2,
+        nodelist="slurmd-01",
        actor=ModelTrainEvalConfig(
            path=str(save_path),
            init_from_scratch=True,
@@ -436,7 +568,7 @@ def test_ppo_save(
            )
        ),
        actor_train=MFCConfig(
-            device_mesh="NODE01:0",
+            device_mesh="slurmd-01:0",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                tensor_parallel_size=1,
@@ -465,7 +597,7 @@ def test_ppo_save(
            )
        ),
        critic_train=MFCConfig(
-            device_mesh="NODE01:1",
+            device_mesh="slurmd-01:1",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                tensor_parallel_size=1,


@@ -28,6 +28,9 @@ def run_model_worker(cfg, mw, barrier, expr_name=None):
    system_api.ALL_EXPERIMENT_CLASSES = {}
    register_experiment(expr_name or testing._DEFAULT_EXPR_NAME, lambda: cfg)
+    constants.set_experiment_trial_names(
+        mw.worker_info.experiment_name, mw.worker_info.trial_name
+    )
    worker = ModelWorker()
    logger.info("Configuring model worker...")


@@ -1,100 +0,0 @@
import asyncio
import random
import uuid
import pytest
import torch
from realhf.api.cli_args import GenerationHyperparameters
from realhf.api.core.model_api import (
APIGenerateInput,
APIGenerateOutput,
BundledGenerationOutputs,
)
@pytest.fixture
def sglang_client(request):
from sglang.test.test_utils import is_in_ci
if is_in_ci():
from patch import launch_server_cmd
else:
from sglang.utils import launch_server_cmd
from sglang.utils import terminate_process, wait_for_server
server_process, port = launch_server_cmd(
f"python -m sglang.launch_server --model-path {request.param} --host 0.0.0.0 --skip-tokenizer-init "
)
wait_for_server(f"http://localhost:{port}")
from realhf.impl.model.backend.sglang import SGLangAPIClient
client = SGLangAPIClient(
generate_url=f"http://localhost:{port}/generate",
update_weights_url=f"http://localhost:{port}/update_weights_from_disk",
)
yield client
terminate_process(server_process)
@pytest.mark.parametrize(
"sglang_client",
["/storage/openpsi/models/Qwen__Qwen2.5-7B-Instruct/"],
indirect=True,
)
@pytest.mark.parametrize("group_size", [16])
@pytest.mark.asyncio
async def test_batch_generate(sglang_client, group_size):
bs = 8
# genlen = 16384
genlen = 10
prompt_len = 100
async with sglang_client:
tasks = []
qids = []
for i in range(bs):
qid = str(uuid.uuid4())
prompt_ids = [random.randint(10, 100) for _ in range(prompt_len)]
gconfig = GenerationHyperparameters(
n=group_size,
max_new_tokens=genlen,
)
req = APIGenerateInput(
qid=qid,
prompt_ids=prompt_ids,
input_ids=prompt_ids,
gconfig=gconfig,
return_logprob=True,
)
tasks.append(sglang_client.async_add_generate_request(req, stream=False))
qids.append(qid)
outputs = {}
for r in asyncio.as_completed(tasks):
out = await r
outputs[out.qid] = out
results = [outputs[key] for key in qids]
assert all([isinstance(r, APIGenerateOutput) for r in results])
batch_token_ids = []
batch_logprobs = []
max_seqlen = -1
for x in results:
max_seqlen = max(max_seqlen, max(x.output_lens))
batch_token_ids += x.output_ids
batch_logprobs += x.output_logprobs
pad_token_id = 0
# To be consistent with our internal implementation,
# we should pad generated tokens and logprobs
batch_token_ids = [
t + [pad_token_id] * (max_seqlen - len(t)) for t in batch_token_ids
]
batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]
tokens = torch.tensor(batch_token_ids, dtype=torch.long, device="cpu")
assert tokens.shape == (bs * group_size, genlen)
logprobs = torch.tensor(batch_logprobs, dtype=torch.float32, device="cpu")
assert logprobs.shape == (bs * group_size, genlen)


@@ -86,7 +86,7 @@ def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_datase
        batch_size=4,
        shuffle=True,
    )
-    from realhf.impl.model.interface.rw_interface import MultiTaskRewardInterface
+    from realhf.impl.model.interface.math_rw_interface import MultiTaskRewardInterface

    with constants.model_scope(testing.MODEL_NAME):
        interface = MultiTaskRewardInterface(


@@ -55,7 +55,7 @@ def check_sequences_consistency(
    )


-def test_fn(
+def _fn(
    rank: int,
    world_size: int,
    path: str,
@@ -194,12 +194,12 @@ def test_fn(
    dist.destroy_process_group()


-def test_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
+def check_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
    mp.set_start_method("spawn", force=True)
    world_size = dp * tp
    procs = [
        mp.Process(
-            target=test_fn,
+            target=_fn,
            args=(
                i,
                world_size,
@@ -236,7 +236,7 @@ if __name__ == "__main__":
    #     pp=1,
    #     tp=1,
    # )
-    test_sglang_consistency(
+    check_sglang_consistency(
        tp=2,
        dp=2,
        path=path,


@@ -59,7 +59,7 @@ def check_sequences_consistency(
    )


-def test_fn(
+def _fn(
    rank: int,
    world_size: int,
    path: str,
@@ -203,12 +203,12 @@ def test_fn(
    print("success")


-def test_vllm_tp_consistency(tp: int, dp: int, path: str, model_family_name: str):
+def check_vllm_tp_consistency(tp: int, dp: int, path: str, model_family_name: str):
    mp.set_start_method("spawn", force=True)
    world_size = dp * tp
    procs = [
        mp.Process(
-            target=test_fn,
+            target=_fn,
            args=(
                i,
                world_size,
@@ -236,7 +236,7 @@ def test_vllm_tp_consistency(tp: int, dp: int, path: str, model_family_name: str
if __name__ == "__main__":
    # for model_family_name in _available_model_classes:
    #     path = MODEL_CLASS_TO_PATH[model_family_name]
-    test_vllm_tp_consistency(
+    check_vllm_tp_consistency(
        tp=2,
        dp=2,
        path="/storage/models/Qwen__Qwen2-1.5B-Instruct/",


@@ -1,12 +1,16 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").

+import asyncio
import dataclasses
+import queue
import random
+import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

+import aiohttp
import pytest

from realhf.api.core.config import ModelName
@@ -118,7 +122,8 @@ def mock_servers():
@pytest.fixture
-def gserver_manager(mock_servers):
+def gserver_manager(request, mock_servers):
+    train_batch_size, offpolicyness = request.param
    testing.clear_name_resolve()
    constants.set_experiment_trial_names(
        testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
@@ -135,6 +140,10 @@ def gserver_manager(mock_servers):
    config = GserverManagerConfig(
        model_name=ModelName("default", 0),
        n_servers=N_SERVERS,
+        train_batch_size=train_batch_size,
+        max_head_offpolicyness=offpolicyness,
+        flush_request_timeout=300,
+        max_concurrent_rollouts=128,
        schedule_policy="round_robin",
        worker_info=WorkerInformation(
            experiment_name=testing._DEFAULT_EXPR_NAME,
@@ -151,6 +160,7 @@ def gserver_manager(mock_servers):
    m.exit()


+@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
@pytest.mark.asyncio
async def test_schedule_policy(gserver_manager):
    # Test round-robin scheduling
@@ -171,14 +181,8 @@ async def test_schedule_policy(gserver_manager):
    assert idx3 == 0


-@pytest.mark.asyncio
-async def test_weight_update(gserver_manager):
-    from fastapi.testclient import TestClient
-
-    from realhf.api.core.model_api import GenReqMeta
-
-    client = TestClient(gserver_manager.app)
+@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
+def test_weight_update(gserver_manager):
    # Set up a new parameter version
    name = names.model_version(
        testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME, "default"
@@ -187,22 +191,14 @@ async def test_weight_update(gserver_manager):
    global UPDATE_WEIGHTS_CALL_COUNT
    UPDATE_WEIGHTS_CALL_COUNT.clear()

-    req_meta = GenReqMeta(
-        "2",
-        prompt_len=100,
-        group_size=2,
-        new_token_budget=1024,
-        predicted_new_tokens=None,
-    )
-    client.post("/schedule_request", json=dataclasses.asdict(req_meta))
+    gserver_manager._poll()
    assert gserver_manager._last_param_realloc_step == 1
    assert len(UPDATE_WEIGHTS_CALL_COUNT) == N_SERVERS
    for v in UPDATE_WEIGHTS_CALL_COUNT.values():
        assert v == 1

    # weights updated, no more weights update
-    client.post("/schedule_request", json=dataclasses.asdict(req_meta))
+    gserver_manager._poll()
    assert gserver_manager._last_param_realloc_step == 1
    assert len(UPDATE_WEIGHTS_CALL_COUNT) == N_SERVERS
    for v in UPDATE_WEIGHTS_CALL_COUNT.values():
@@ -213,6 +209,7 @@ async def test_weight_update(gserver_manager):
    name_resolve.delete(name)


+@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
def test_server_lifecycle(gserver_manager):
    # Test that the server starts and stops properly
    assert gserver_manager.thread is not None
@@ -224,6 +221,7 @@ def test_server_lifecycle(gserver_manager):
    assert not gserver_manager.thread.is_alive()


+@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
@pytest.mark.asyncio
async def test_http_server_endpoints(gserver_manager):
    # Test the FastAPI endpoints
@@ -253,6 +251,74 @@ async def test_http_server_endpoints(gserver_manager):
    assert responses == set(gserver_manager.server_urls)


+@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
def test_unique_server_urls(gserver_manager):
    # Ensure server URLs are unique
    assert len(set(gserver_manager.server_urls)) == len(gserver_manager.server_urls)
@pytest.mark.parametrize("gserver_manager", [(4, 0), (4, 1), (4, 4)], indirect=True)
@pytest.mark.parametrize("n_clients", [1, 2, 3])
def test_offpolicyness_control(n_clients, gserver_manager):
train_batch_size = gserver_manager.config.train_batch_size
offpolicyness = gserver_manager.config.max_head_offpolicyness
addr = gserver_manager.manager_addr
res_queue = queue.Queue(n_clients)
async def _client_thread(res_queue):
cnt = 0
for _ in range(train_batch_size):
async with aiohttp.ClientSession() as session:
async with session.post(
f"http://{addr}/allocate_rollout",
json=dict(qid="a"),
) as resp:
resp.raise_for_status()
res = await resp.json()
cnt += int(res["success"])
res_queue.put(cnt)
def run_client(res_queue):
asyncio.run(_client_thread(res_queue))
total_cnt = 0
jobs = [
threading.Thread(target=run_client, args=(res_queue,)) for _ in range(n_clients)
]
for job in jobs:
job.start()
for job in jobs:
job.join()
total_cnt += res_queue.get()
assert total_cnt == min(
train_batch_size * n_clients, (1 + offpolicyness) * train_batch_size
)
# Increase the model version by 1
version_name = names.model_version(
constants.experiment_name(),
constants.trial_name(),
"default",
)
name_resolve.add(version_name, "1")
gserver_manager._poll()
# Run the rollout worker again
jobs = [
threading.Thread(target=run_client, args=(res_queue,)) for _ in range(n_clients)
]
for job in jobs:
job.start()
for job in jobs:
job.join()
total_cnt += res_queue.get()
# The rollout worker should produce new samples
assert total_cnt == min(
train_batch_size * n_clients * 2, (2 + offpolicyness) * train_batch_size
)
# Final clean up
name_resolve.delete(version_name)
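
The two assertions above encode the admission rule being tested: once the trained model is at version v, at most (v + 1 + max_head_offpolicyness) * train_batch_size rollouts may have been allocated in total, no matter how many clients ask. A standalone illustration of that arithmetic (hypothetical helper, not the manager's code):

def rollout_cap(version: int, offpolicyness: int, train_batch_size: int) -> int:
    # Upper bound on the number of rollouts allocated so far at this version.
    return (version + 1 + offpolicyness) * train_batch_size


# With offpolicyness = 0, train_batch_size = 4, and 3 clients each requesting
# 4 rollouts, only 4 are admitted before the first weight update...
assert min(3 * 4, rollout_cap(0, 0, 4)) == 4
# ...and 8 in total once the version bumps to 1, matching the assertions above.
assert min(2 * 3 * 4, rollout_cap(1, 0, 4)) == 8
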


@@ -84,6 +84,7 @@ def partial_rollout_manager():
        reply_queue=reply_queue,
        new_tokens_per_chunk=new_tokens_per_chunk,
        tokenizer=mock_tokenizer,
+        timeout=300,
    )
    yield manager
    # Cleanup if needed


@@ -1,174 +0,0 @@
import asyncio
import copy
from asyncio.queues import QueueEmpty
from unittest.mock import patch
import pytest
from realhf.api.core.config import (
AgentAbstraction,
DatasetAbstraction,
EnvServiceAbstraction,
ModelName,
)
from realhf.api.core.model_api import (
BundledGenerationOutputs,
GenerationHyperparameters,
)
from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig
from realhf.api.core.system_api import WorkerInformation
from realhf.base import constants, name_resolve, names, network, testing
from realhf.system.push_pull_stream import NameResolvingZmqPusher
from realhf.system.rollout_worker import RolloutWorker
from tests.fixtures import *
N_PULLERS = 3
class MockPartialRolloutManager:
def __init__(self, request_queue, reply_queue, **kwargs):
self.request_queue = request_queue
self.reply_queue = reply_queue
self.internal_queue = []
def get_num_gen_requests(self):
return len(self.internal_queue)
async def run_step(self):
async def poll_fresh_requests():
for _ in range(8):
try:
qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
assert isinstance(qid, str)
assert isinstance(prompt_token_ids, list)
assert all(isinstance(x, int) for x in prompt_token_ids)
assert isinstance(gconfig, GenerationHyperparameters)
self.internal_queue.append(qid)
except QueueEmpty:
await asyncio.sleep(0.01)
async def poll_old_requests():
for _ in range(8):
if random.random() < 0.5 and len(self.internal_queue) > 0:
# responses may not return in order
idx = random.randint(0, len(self.internal_queue) - 1)
qid = self.internal_queue.pop(idx)
out = BundledGenerationOutputs(
qid=qid,
prompt_ids=[1],
output_ids=[[2], [3]],
seqs=[[1, 2], [1, 3]],
logprobs=[[0.0, 0.1], [0.0, 2.0]],
no_eos=[True, True],
version_start=[0, 1],
version_end=[1, 2],
)
await self.reply_queue.put(out)
else:
await asyncio.sleep(0.01)
await asyncio.gather(poll_fresh_requests(), poll_old_requests())
@pytest.fixture
def rollout_workers(request):
testing.clear_name_resolve()
constants.set_experiment_trial_names(
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
)
# add name resolve to make zmq pusher happy
puller_ports = network.find_multiple_free_ports(N_PULLERS)
for puller_index in range(N_PULLERS):
name = names.stream_pullers(
testing._DEFAULT_EXPR_NAME,
testing._DEFAULT_TRIAL_NAME,
)
name_resolve.add_subentry(name, str(puller_index))
name = names.push_pull_stream(
testing._DEFAULT_EXPR_NAME,
testing._DEFAULT_TRIAL_NAME,
f"puller{puller_index}",
)
name_resolve.add(name, f"localhost:{puller_ports[puller_index]}")
with (
patch.object(NameResolvingZmqPusher, "push", return_value=None) as mock_push,
patch(
"realhf.system.rollout_worker.PartialRolloutManager",
new=MockPartialRolloutManager,
),
):
ms = [RolloutWorker() for _ in range(request.param)]
yield ms
@pytest.mark.parametrize("rollout_workers", [1, 2, 3], indirect=True)
@pytest.mark.parametrize("offpolicyness", [0, 1, 4])
@pytest.mark.asyncio
async def test_offpolicyness_control(
rollout_workers, save_path, dataset, offpolicyness
):
train_batch_size = 8
config = RolloutWorkerConfig(
base_seed=0,
model_name=ModelName("default", 0),
max_head_offpolicyness=offpolicyness,
train_batch_size=train_batch_size,
tokenizer_path="/storage/openpsi/models/Qwen__Qwen2.5-1.5B/",
new_tokens_per_chunk=1024,
max_concurrent_rollouts=128,
env=EnvServiceAbstraction("null"),
agent=AgentAbstraction("null", dict(episode_length=5, traj_size=5)),
datasets=[
DatasetAbstraction(
"prompt", args=dict(dataset_path=str(save_path / "dataset.jsonl"))
)
],
worker_info=WorkerInformation(
experiment_name=testing._DEFAULT_EXPR_NAME,
trial_name=testing._DEFAULT_TRIAL_NAME,
worker_type="rollout_worker",
worker_count=N_PULLERS * 2,
worker_index=0,
),
)
for i, m in enumerate(rollout_workers):
config = copy.deepcopy(config)
config.worker_info.worker_index = i
m._configure(config)
for i in range(10 * (offpolicyness + 1)):
for m in rollout_workers:
await m._poll_async()
# Ensure that data is not overly produced
for m in rollout_workers:
assert m.agent.ACT_GET_CNT > 0
assert (
(offpolicyness + 1) * train_batch_size >= m.push_stream.push.call_count > 0
)
# Increase the model version by 1
version_name = names.model_version(
constants.experiment_name(),
constants.trial_name(),
config.model_name.role,
)
name_resolve.add(version_name, "1")
# Run the rollout worker again
for i in range(10 * (offpolicyness + 1)):
for m in rollout_workers:
await m._poll_async()
# The rollout worker should produce new samples
for m in rollout_workers:
assert (offpolicyness + 2) * train_batch_size >= m.push_stream.push.call_count
assert (offpolicyness + 1) * train_batch_size < m.push_stream.push.call_count
# Final clean up
name_resolve.delete(version_name)
for m in rollout_workers:
await m._exit_async_tasks()

training/main_async_ppo.py (new file)

@@ -0,0 +1,65 @@
import dataclasses
import datetime
import os
from typing import Dict
import hydra
import yaml
from omegaconf import MISSING, OmegaConf
from realhf.api.quickstart.entrypoint import kind_reminder
from realhf.base.constants import init_constants
from realhf.experiments.async_exp.async_ppo_math_exp import AsyncPPOMATHConfig
from training.utils import run_experiment
@hydra.main(version_base=None, config_path="configs/async-ppo")
def main_ppo_math(args):
# NOTE: we import logging here to avoid hydra logging overwrite
import realhf.base.logging as logging
logger = logging.getLogger("quickstart", "colored")
# Overwrite the python dataclass configuration with yaml
default_args = OmegaConf.structured(AsyncPPOMATHConfig)
args = OmegaConf.merge(default_args, args)
args: AsyncPPOMATHConfig = OmegaConf.to_object(args)
# Set experiment trial name.
exp_name = args.experiment_name
if args.trial_name == MISSING:
args.trial_name = trial_name = (
f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)
else:
trial_name = args.trial_name
if args.mode != "ray":
raise RuntimeError("This script only supports the `ray` mode.")
init_constants(args)
from realhf.base.constants import LOG_ROOT
# Save overwritten configuration to yaml
config_save_path = os.path.join(
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
)
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
with open(config_save_path, "w") as f:
config_dict: Dict = dataclasses.asdict(args)
yaml.dump(
config_dict,
f,
default_flow_style=False,
sort_keys=False,
)
kind_reminder("async-ppo-math", logger, args)
run_experiment(args, exp_name, trial_name)
if __name__ == "__main__":
# Command: python3 training/main_async_ppo.py --config-name async-ppo-1.7b-gpu8
main_ppo_math()
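
This entrypoint, like the two that follow, builds a structured config from the experiment dataclass, merges the Hydra-loaded YAML on top of it, and materializes the result back into a dataclass. A minimal, self-contained illustration of that OmegaConf pattern (toy dataclass, not AsyncPPOMATHConfig):

import dataclasses

from omegaconf import OmegaConf


@dataclasses.dataclass
class ToyConfig:
    experiment_name: str = "demo"
    trial_name: str = "???"  # "???" is OmegaConf's MISSING marker
    lr: float = 1e-4


# Defaults come from the dataclass...
default_cfg = OmegaConf.structured(ToyConfig)
# ...and are overridden by whatever the YAML / CLI provides.
override_cfg = OmegaConf.create({"trial_name": "run1", "lr": 2e-5})
merged = OmegaConf.to_object(OmegaConf.merge(default_cfg, override_cfg))

assert isinstance(merged, ToyConfig)
assert merged.trial_name == "run1" and merged.lr == 2e-5
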

training/main_ppo.py (new file)

@@ -0,0 +1,65 @@
import dataclasses
import datetime
import os
from typing import Dict
import hydra
import yaml
from omegaconf import MISSING, OmegaConf
from realhf.api.quickstart.entrypoint import kind_reminder
from realhf.base.constants import init_constants
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
from training.utils import run_experiment
@hydra.main(version_base=None, config_path="configs/ppo")
def main(args):
# NOTE: we import logging here to avoid hydra logging overwrite
import realhf.base.logging as logging
logger = logging.getLogger("quickstart", "colored")
# Overwrite the python dataclass configuration with yaml
default_args = OmegaConf.structured(PPOMATHConfig)
args = OmegaConf.merge(default_args, args)
args: PPOMATHConfig = OmegaConf.to_object(args)
# Set experiment trial name.
exp_name = args.experiment_name
if args.trial_name == MISSING:
args.trial_name = trial_name = (
f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)
else:
trial_name = args.trial_name
if args.mode != "ray":
raise RuntimeError("This script only supports the `ray` mode.")
init_constants(args)
from realhf.base.constants import LOG_ROOT
# Save overwritten configuration to yaml
config_save_path = os.path.join(
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
)
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
with open(config_save_path, "w") as f:
config_dict: Dict = dataclasses.asdict(args)
yaml.dump(
config_dict,
f,
default_flow_style=False,
sort_keys=False,
)
kind_reminder("ppo-math", logger, args)
run_experiment(args, exp_name, trial_name)
if __name__ == "__main__":
# Command: python3 training/main_ppo.py --config-name ppo-1.5b-gpu32
main()

training/main_sft.py (new file)

@@ -0,0 +1,65 @@
import dataclasses
import datetime
import os
from typing import Dict
import hydra
import yaml
from omegaconf import MISSING, OmegaConf
from realhf.api.quickstart.entrypoint import kind_reminder
from realhf.base.constants import init_constants
from realhf.experiments.common.sft_exp import SFTConfig
from training.utils import run_experiment
@hydra.main(version_base=None, config_path="configs/sft")
def main(args):
# NOTE: we import logging here to avoid hydra logging overwrite
import realhf.base.logging as logging
logger = logging.getLogger("quickstart", "colored")
# Overwrite the python dataclass configuration with yaml
default_args = OmegaConf.structured(SFTConfig)
args = OmegaConf.merge(default_args, args)
args: SFTConfig = OmegaConf.to_object(args)
# Set experiment trial name.
exp_name = args.experiment_name
if args.trial_name == MISSING:
args.trial_name = trial_name = (
f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)
else:
trial_name = args.trial_name
if args.mode != "ray":
raise RuntimeError("This script only supports the `ray` mode.")
init_constants(args)
from realhf.base.constants import LOG_ROOT
# Save overwritten configuration to yaml
config_save_path = os.path.join(
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
)
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
with open(config_save_path, "w") as f:
config_dict: Dict = dataclasses.asdict(args)
yaml.dump(
config_dict,
f,
default_flow_style=False,
sort_keys=False,
)
kind_reminder("sft", logger, args)
run_experiment(args, exp_name, trial_name)
if __name__ == "__main__":
# Command: python3 training/main_sft.py --config-name sft-7b-gpu8
main()

training/utils.py (new file)

@@ -0,0 +1,282 @@
# Copyright 2025 Ant Group Inc.
import copy
import os
import re
import signal
import sys
import threading
from contextlib import redirect_stderr, redirect_stdout
from typing import Any, List
import psutil
import ray
from realhf.api.core.system_api import Experiment, ExperimentScheduling, TasksGroup
from realhf.base import constants, logging, name_resolve, names
from realhf.system import WORKER_TYPES, load_worker
from realhf.system.worker_base import AsyncWorker, Worker, WorkerServerStatus
# Copied from SGLang
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the process and all its child processes."""
# Remove sigchld handler to avoid spammy logs.
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
if parent_pid is None:
parent_pid = os.getpid()
include_parent = False
try:
itself = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
children = itself.children(recursive=True)
for child in children:
if child.pid == skip_pid:
continue
try:
child.kill()
except psutil.NoSuchProcess:
pass
if include_parent:
try:
if parent_pid == os.getpid():
itself.kill()
sys.exit(0)
itself.kill()
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess:
pass
@ray.remote
class RayWorker:
"""A thin wraper over realhf.system.worker_base.Worker."""
def __init__(
self,
worker_type: str,
worker_cls,
kv_store_name,
):
# Register all datasets and models
import realhf.impl.dataset # isort: skip
import realhf.impl.model # isort: skip
os.environ["REAL_MODE"] = "RAY"
name_resolve.reconfigure("ray", actor_name=kv_store_name)
self.worker: Worker | AsyncWorker = worker_cls()
self.worker_type = worker_type
def __repr__(self):
return "".join([c.capitalize() for c in self.worker_type.split("_")])
def configure(self, cfg: Any, expr_config: Any):
constants.init_constants(expr_config)
worker_info = cfg.worker_info
idx = worker_info.worker_index
constants.set_experiment_trial_names(
worker_info.experiment_name, worker_info.trial_name
)
self.worker.wandb_config = expr_config.wandb
self.worker.tensorboard_config = expr_config.tensorboard
self.logger = logging.getLogger(f"{self.worker_type} {idx}", "benchmark")
self.logger.info(f"Configuring {self.worker_type}...")
self.worker._configure(cfg)
self.logger.info(f"Configuring {self.worker_type}... Done.")
def run_sync(self):
self.logger.info(f"Running {self.worker_type} lazy initialization...")
self.worker._poll()
self.logger.info(f"Running {self.worker_type} lazy initialization... Done.")
while self.worker.status != WorkerServerStatus.PAUSED:
self.worker._poll()
async def run_async(self):
self.logger.info(f"Running {self.worker_type} lazy initialization...")
await self.worker._poll_async()
self.logger.info(f"Running {self.worker_type} lazy initialization... Done.")
while self.worker.status != WorkerServerStatus.PAUSED:
await self.worker._poll_async()
def _run_experiment(exp_cfg, expr_name, trial_name):
# Register all datasets and models
import realhf.impl.dataset # isort: skip
import realhf.impl.model # isort: skip
from realhf.api.core.system_api import ALL_EXPERIMENT_CLASSES
from realhf.system.master_worker import MasterWorker
constants.set_experiment_trial_names(expr_name, trial_name)
logger = logging.getLogger(f"RayMasterWorker", "benchmark")
# Initialize ray in the Ray cluster
env_vars = constants.get_env_vars(
WADNB_MODE=exp_cfg.wandb.mode,
REAL_MODE=os.environ.get("REAL_MODE", ""),
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
)
runtime_env = {
"env_vars": env_vars,
"working_dir": os.getcwd(),
}
logger.info(f"Ray workers runtime env: {runtime_env}")
ray_log_path = exp_cfg.ray_temp_path
os.makedirs(ray_log_path, exist_ok=True)
ray.init(runtime_env=runtime_env, _temp_dir=ray_log_path)
logger.info(f"Ray log root: {ray_log_path}")
logger.info("Ray initialized! Ready to run workers.")
ray_kv_store_name = f"{expr_name}/{trial_name}/ray_kv_store"
name_resolve.reconfigure("ray", actor_name=ray_kv_store_name)
name_resolve.clear_subtree(
names.trial_root(experiment_name=expr_name, trial_name=trial_name)
)
# Convert CLI args into worker configurations
exp_setup = exp_cfg.initial_setup()
exp_setup.set_worker_information(expr_name, trial_name)
exp_setup.lazy_init()
# Initialize all workers
all_workers = {}
scheduling: ExperimentScheduling = exp_cfg.scheduling_setup()
for worker_type in WORKER_TYPES:
sch = getattr(scheduling, worker_type)
if sch is None:
continue
available_resources = ray.available_resources()
cpu = sch.scheduling.cpu * sch.count
gpu = sch.scheduling.gpu * sch.count
mem = sch.scheduling.mem * sch.count / 1024 # in GB
acpu = available_resources.get("CPU", 0)
agpu = available_resources.get("GPU", 0)
amem = available_resources.get("memory", 0) / 1024**3
if acpu < cpu or agpu < gpu or amem < mem:
logger.critical(
f"Ray does not have enough resources to launch workers. "
f"Required: {cpu} CPU, {gpu} GPU, {mem:.2f} GB memory. "
f"Available: {acpu} CPU, {agpu} GPU, {amem:.2f} GB memory. "
f"Please launch more Ray nodes otherwise the experiment will get stuck."
)
# Use a customized packed scheduling method
# that sequentially allocates nodes.
available_nodes = [
k
for k in available_resources
if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k)
]
total_gpus = available_resources["GPU"]
n_gpus_per_node = int(total_gpus // len(available_nodes))
count = sch.count
all_schedules: List[TasksGroup] = []
for _ in range(sch.count):
s_ = copy.deepcopy(sch)
s_.count = 1
all_schedules.append(s_)
workers = []
for node_idx, i in enumerate(range(0, count, n_gpus_per_node)):
_schedules = all_schedules[i : i + n_gpus_per_node]
for _idx, sch in enumerate(_schedules):
# Schedule jobs one-by-one to maintain the order on remote nodes.
worker = RayWorker.options(
name=f"{worker_type}/{_idx + i}",
num_cpus=sch.scheduling.cpu,
num_gpus=sch.scheduling.gpu,
memory=sch.scheduling.mem * 1024**2,
).remote(
worker_type=worker_type,
worker_cls=load_worker(worker_type),
kv_store_name=ray_kv_store_name,
)
workers.append(worker)
all_workers[worker_type] = workers
try:
# Configure workers
configure_jobs = []
for worker_type in all_workers:
worker_configs = getattr(exp_setup, worker_type)
workers = all_workers[worker_type]
assert len(workers) == len(worker_configs), (
len(workers),
len(worker_configs),
)
jobs = [
w.configure.remote(c, exp_cfg) for w, c in zip(workers, worker_configs)
]
configure_jobs += jobs
ray.get(configure_jobs)
# Run workers
run_jobs = []
for worker_type in all_workers:
workers = all_workers[worker_type]
if worker_type in ["master_worker", "rollout_worker", "gserver_manager"]:
# Only the rollout worker is asynchronous
jobs = [w.run_async.remote() for w in workers]
else:
jobs = [w.run_sync.remote() for w in workers]
run_jobs += jobs
ray.get(run_jobs)
finally:
ray.shutdown()
class DualOutput:
def __init__(self, file, terminal):
self.file = file
self.terminal = terminal
def write(self, message):
self.terminal.write(message)
self.file.write(message)
def flush(self):
self.terminal.flush()
self.file.flush()
def fileno(self):
# Return the terminal's fileno to maintain expected behavior
return self.terminal.fileno()
def run_experiment(exp_cfg, expr_name, trial_name):
log_path = os.path.join(constants.LOG_ROOT, expr_name, trial_name, "main.log")
with open(log_path, "a") as f:
# Create dual output handler
dual_out = DualOutput(f, sys.stdout)
dual_err = DualOutput(f, sys.stderr)
# Redirect stdout and stderr
with redirect_stdout(dual_out), redirect_stderr(dual_err):
_run_experiment(exp_cfg, expr_name, trial_name)
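
The worker-placement loop in _run_experiment above packs workers onto nodes sequentially, in chunks of n_gpus_per_node. A small standalone sketch of that grouping (hypothetical helper mirroring the `range(0, count, n_gpus_per_node)` loop; not code from training/utils.py):

from typing import List


def pack_workers(count: int, n_gpus_per_node: int) -> List[List[int]]:
    # Fill each node with consecutive worker indices before moving to the next,
    # mirroring how RayWorker actors are scheduled one-by-one above.
    return [
        list(range(i, min(i + n_gpus_per_node, count)))
        for i in range(0, count, n_gpus_per_node)
    ]


# 12 workers on 8-GPU nodes: workers 0-7 land on the first node, 8-11 on the second.
assert pack_workers(12, 8) == [list(range(8)), [8, 9, 10, 11]]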