[Feature & Doc & Bug Fix] Add docs, simplified ray-based scripts, and fix issues to stablize asynchronous experiments (#52)
* feat: one buffer for each task * feat: support "one buffer for each task" for async * make kv_cache_dtype configurable Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com> * style: use plural form fix: use _seed_from_key to set different seeds for data loaders fix: call load_data for one buffer each time * PullRequest: 125 Support running async experiments in the 2407 image. Merge branch fw/async2407 of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/125 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * fix: handle multiple datasets in recover indices fix: `isinstance(self.__datasets, PullerStreamDataset)` feat: use the "spec" request to obtain the number of datasets fix: revert rollout worker * fix: revert async_rl_exp.py * fix flag for list (cuda_graph_bs) * format * [FIX] fix async task reward [sglang bf16-> fp16] * fix: define `self.__datasets` in advance * PullRequest: 130 [Refactor] Remove deprecated search related code Merge branch mzy/remove-search of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/130 Signed-off-by: 博惟 <bowei.fw@antgroup.com> * remove search related * PullRequest: 131 [Refactor] Change terminology "model parallel" into "tensor parallel" to align with megatron. Merge branch mzy/mp-to-tp of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/131?tab=comment Signed-off-by: 博惟 <bowei.fw@antgroup.com> * change mp to tp * . * . * PullRequest: 142 Fix an error for megatron backend destroy Merge branch fw/fix-meagatron-destroy of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/142 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 143 Fix the port conflict issue of generation servers Merge branch fw/fix-gen-port of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/143?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * somehow fix the port issue * add clearance period * . * . * PullRequest: 145 Add code environment Merge branch fw/code-env of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/145?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * add code env * somehow fix the port issue * fix * PullRequest: 144 Add decoupled PPO loss Merge branch fw/decoupled-ppo-loss of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/144?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix ppo step logging, nan in stats tracker, and add decoupled loss * . * somehow fix the port issue * fix typo * PullRequest: 146 Merge SLURM logs and save experiment configs in yaml format. Merge branch fw/better-logging of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/146 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * merge all slurm logs into one * write config to yaml * PullRequest: 141 Merge changes during NeurIPS submission Merge branch fw/async-dev of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/141 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * . * . * . * . * update script * . * . * . * . * [ADD] add least req scheduling * fix test genreq * . * . * fix stats tracker nan * . * . * . * . * . * . * . * uppper clip decoupled objective * add throughput exp script * . * remove behav upper clip param * . * . * . * plot curve * update thpt script * . * master worker raise error when exiting * update script * add gen throughput logging * . * . * add decoupled wandb data * . * fix port issue and add no training option * . * enlarge ttl * remove gserver manager await staled * update weights in groups * . * . * . * add port clearance period * . * . * . * add plot script * add sft throughput eval * . * log tokens in null interface * 消融实验和interruptible generation * 画图脚本/运行脚本/数据结果 * . * remove scripts * add port test * remove force_sync_reward * revert some changes * . * revert * revert fix * fix * revert * fix typo * support qwen3 training * PullRequest: 147 Support interruption in SGLang and fix a KeyError in gather-scatter communication Merge branch fw/sglang046-with-abort-request of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/147?tab=diff Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix ppo step logging, nan in stats tracker, and add decoupled loss * . * somehow fix the port issue * initial commit * add interupt request * fix data transfer issue * max concurrent rollouts defaults to train batch size * merge main * add patch * fix patch typp * revert sglang * fix typo * fix minor typo * . * pip show editable sglang path * PullRequest: 149 fix: code faas max_retries Merge branch xss/fix_code_verifier of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/149 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * fix: code faas max_retries * PullRequest: 150 [Bug Fix] Fix key errors in `_run_scatter` in data transfer Merge branch mzy/fix-scatter-groups of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/150 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * fix scatter groups key error * fix test * . * PullRequest: 151 Fix Qwen3 import error when using transformers with a lower version Merge branch fw/fix-qwen3 of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/151 Reviewed-by: 温差 <xushusheng.xss@antgroup.com> * merge all slurm logs into one * write config to yaml * . * PullRequest: 152 Support sglang0.4.6 and fix master_worker import error Merge branch adopt_sglang046 of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/152 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * Support sglang0.4.6 and fix master_worker import error * remove disable_mla option * PullRequest: 155 [FIX] reduce port conflicts Merge branch sxj/reduce_port_conflict of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/155 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * [FIX] reduce port conflicts * PullRequest: 153 Fix stuck and recover issues for async experiments Merge branch fw/stable-async of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/153 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix sample cnt stuck * fix recover * code cleanup * merge all slurm logs into one * write config to yaml * . * . * . * revert birth time change * . * enlarge sock connect timeout * PullRequest: 158 [Fix] Fix the error where "accepted" is not defined Merge branch fw/fix-rollout-accepted of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/158 Reviewed-by: 温差 <xushusheng.xss@antgroup.com> * . * PullRequest: 154 Fix unit tests and simplify package installation Merge branch fw/v0.3.0-tests of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/154?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix some tests * fix tests except for experiments * fix tests * fix tests * . * . * PullRequest: 159 [fix] Enlarge the default aiohttp connection timeout and fix a recover error in model worker Merge branch fw/stable-async of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/159 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * fix sample cnt stuck * fix recover * code cleanup * merge all slurm logs into one * write config to yaml * . * . * . * revert birth time change * . * enlarge sock connect timeout * . * PullRequest: 160 set sock_connect as rollout_request_timeout in partial_rollout.py Merge branch xss/rollout_timeout of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/160 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * set sock_connect as rollout_request_timeout in partial_rollout.py * PullRequest: 161 Prioritize rollouts that are submitted earlier rather than arrived earlier Merge branch fw/birth-time of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/161 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * blocking push * PullRequest: 163 [bugfix] Fix synchronized training when birth time is absent Merge branch fw/fix-sync-birthtime of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/163 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 164 [Refactor] Move cluster spec into CLI args Merge branch fw/refactor-cluster-spec of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/164?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * set cluster spec path in args * . * fix * add default cluster spec * PullRequest: 165 Normally exit all workers after experiment completion Merge branch fw/exit-all-workers of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/165 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * PullRequest: 167 [Feature] Use chunked logits computation to alleviate SGLang OOM Merge branch fw/patch-sglang-oom of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/167 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 166 [Feature] Support single-script experiment launch with Ray Merge branch fw/turbolaunch of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/166?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * add training script without ray name resolve * add ray name resolve * ray worker * run * run async * local run * set cluster spec path in args * . * . * fix * . * . * . * . * . * update config * . * minor renaming * PullRequest: 169 [Doc] Add v0.3.0 docs based on jupyter-book Merge branch fw/doc of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/169 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * add docs * refine doc * refine doc --------- Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com> Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com> Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com> Co-authored-by: kira.gw <kira.gw@antgroup.com> Co-authored-by: shenxujie.sxj <shenxujie.sxj@antgroup.com> Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com> Co-authored-by: sam.gjx <sam.gjx@antgroup.com> Co-authored-by: 温差 <xushusheng.xss@antgroup.com> Co-authored-by: 履渊 <yuhong.gyh@antgroup.com>
|
@ -0,0 +1,32 @@
|
|||
# Book settings
|
||||
# Learn more at https://jupyterbook.org/customize/config.html
|
||||
|
||||
title: AReaL Documentation
|
||||
author: Wei Fu
|
||||
logo: figures/logo.png
|
||||
|
||||
# Force re-execution of notebooks on each build.
|
||||
# See https://jupyterbook.org/content/execute.html
|
||||
execute:
|
||||
execute_notebooks: force
|
||||
|
||||
# Define the name of the latex output file for PDF builds
|
||||
latex:
|
||||
latex_documents:
|
||||
targetname: book.tex
|
||||
|
||||
# Add a bibtex file so that we can create citations
|
||||
bibtex_bibfiles:
|
||||
- references.bib
|
||||
|
||||
# Information about where the book exists on the web
|
||||
repository:
|
||||
url: https://github.com/inclusionAI/AReaL # Online location of your book
|
||||
path_to_book: docs # Optional path to your book, relative to the repository root
|
||||
branch: main # Which branch of the repository should be used when creating links (optional)
|
||||
|
||||
# Add GitHub buttons to your book
|
||||
# See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository
|
||||
html:
|
||||
use_issues_button: true
|
||||
use_repository_button: true
|
|
@ -0,0 +1,22 @@
|
|||
# Table of contents
|
||||
# Learn more at https://jupyterbook.org/customize/toc.html
|
||||
|
||||
format: jb-book
|
||||
root: intro
|
||||
parts:
|
||||
- caption: Tutorial
|
||||
chapters:
|
||||
- file: installation
|
||||
- file: training
|
||||
- file: eval
|
||||
- file: troubleshooting
|
||||
- caption: Developer Manual
|
||||
chapters:
|
||||
- file: developer/exp_launch
|
||||
- file: developer/master_worker
|
||||
- file: developer/model_worker
|
||||
- file: developer/algo_interface
|
||||
- file: developer/allocation_parallel
|
||||
- caption: Contributing
|
||||
chapters:
|
||||
- file: contrib
|
|
@ -0,0 +1,74 @@
|
|||
# Contribution Guide
|
||||
|
||||
Thank you for your interest in contributing to AReaL! We welcome contributions from everyone, whether you're fixing bugs, improving documentation, or adding new system and algorithmic features.
|
||||
|
||||
## Setting Up Your Development Environment
|
||||
|
||||
New contributors do not have write permissions to the official repository. Please fork the repository and clone your fork locally. AReaL is fully Python-based, making installation straightforward.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/${your-username}/AReaL
|
||||
cd AReaL
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install -e .
|
||||
```
|
||||
|
||||
## Issue Guidelines
|
||||
|
||||
### Issue Templates
|
||||
|
||||
Please follow the [issue template on GitHub](https://github.com/inclusionAI/AReaL/tree/main/.github/ISSUE_TEMPLATE). Issues can be:
|
||||
- Bug reports
|
||||
- Feature requests
|
||||
- Refactor requests
|
||||
|
||||
The required fields in the template help reduce communication overhead when resolving issues. **Issues with arbitrary formatting may be ignored.**
|
||||
|
||||
## Pull Request Guidelines
|
||||
|
||||
There are no specific PR templates, but **pull requests should be related to a well-templated issue**. Your PR should:
|
||||
- Explain how the issue is resolved
|
||||
- Describe the benefits this change will provide
|
||||
- Reference the related issue number
|
||||
|
||||
## Code Quality
|
||||
|
||||
### Code Formatting
|
||||
|
||||
Please format your code before opening a PR:
|
||||
|
||||
```bash
|
||||
isort . && black .
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
AReaL's unit tests are based on the `pytest` framework:
|
||||
|
||||
```bash
|
||||
# Run all tests (excluding GPU tests)
|
||||
pytest -m "not gpu"
|
||||
|
||||
# Run a specific test case
|
||||
pytest tests/test_something.py
|
||||
```
|
||||
|
||||
**Note**: Running all tests may take several hours to complete.
|
||||
|
||||
## Documentation
|
||||
|
||||
Writing documentation is an excellent starting point for new contributors. The documentation is located in the `docs` folder and built using [Jupyter Book](https://jupyterbook.org/en/stable/intro.html).
|
||||
|
||||
### Adding New Documentation
|
||||
|
||||
1. Create your documentation files in the `docs` folder
|
||||
2. Add the file path to `docs/_toc.yaml`
|
||||
3. Build the documentation:
|
||||
|
||||
```bash
|
||||
jb build docs
|
||||
```
|
||||
|
||||
4. Preview your changes by opening the HTML files in `docs/_build/html`
|
||||
|
||||
This process allows you to see how your documentation will appear before submitting your contribution.
|
|
@ -0,0 +1,103 @@
|
|||
# Algorithm, Interface & Backends
|
||||
|
||||
## Overview
|
||||

|
||||
|
||||
Model Interfaces define the computations that can be performed, such as training, inference, and generation. They provide abstract classes and implementations that decouple specific algorithms (e.g., PPO, SFT) from model backends (Megatron, SGLang, vLLM). Algorithm developpers may be more interested in adding customized model interfaces.
|
||||
|
||||
Model backends integrate external libraries to wrap over the model as a `PipelinableEngine`, such that they can provide efficient distributed training and inference capabilities.
|
||||
|
||||
## Registeration
|
||||
Backends and interfaces have similar registeration protocols:
|
||||
|
||||
```python
|
||||
# Registration (at the end of each interface implementation):
|
||||
model_api.register_interface("ppo", PPOActorInterface)
|
||||
|
||||
# Configuration (in experiment config file):
|
||||
interface_config = ModelInterfaceAbstraction(
|
||||
type_="ppo",
|
||||
args=dict(eps_clip=0.2)
|
||||
)
|
||||
|
||||
# Instantiation (in model worker):
|
||||
interface = make_interface(interface_config)
|
||||
```
|
||||
|
||||
## Customization
|
||||
### Interfaces
|
||||
An interface implementation essentially processes the data and loss function (e.g., reward clipping, computing GAEs) required by a `PipelinableEngine`, calls the actual execution method such as `PipelinableEngine.train_step`, and then runs some post-processing according to the data protocol.
|
||||
|
||||
Custom interfaces can be created by subclassing the `ModelInterface` class and implementing the required methods for the desired training paradigm.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomInterface(model_api.ModelInterface):
|
||||
# Custom parameters
|
||||
custom_param: float = 1.0
|
||||
|
||||
def train_step(self, model, data, mb_spec):
|
||||
module = model.module
|
||||
module.train()
|
||||
|
||||
# Custom training logic
|
||||
stats = module.train_batch(
|
||||
input_=data,
|
||||
loss_fn=custom_loss_function,
|
||||
loss_weight_fn=lambda x: x.data["mask"].count_nonzero(),
|
||||
token_normalize_scope="global",
|
||||
mb_spec=mb_spec,
|
||||
version_steps=model.version.global_step,
|
||||
)
|
||||
|
||||
model.inc_version()
|
||||
return stats
|
||||
|
||||
def save(self, model, save_dir):
|
||||
module = model.module
|
||||
module.save_to_hf(tokenizer=model.tokenizer, save_dir=save_dir)
|
||||
|
||||
# Register the interface
|
||||
model_api.register_interface("custom", CustomInterface)
|
||||
```
|
||||
|
||||
Required methods vary based on the interface purpose:
|
||||
|
||||
+ For training interfaces: `train_step()` and `save()`
|
||||
+ For inference-only interfaces: `inference()`
|
||||
+ For generation interfaces: `generate()`
|
||||
|
||||
The interface can be configured in the experiment configuration file, e.g., `ppo_math_exp.py`. Please refer to xxx how to run unittests on your implementation.
|
||||
|
||||
### Backends
|
||||
Backend requires implementing the `_initialize`method. Example:
|
||||
|
||||
```python
|
||||
class FSDPEngine(PipelinableEngine):
|
||||
def train_step(self, ...):
|
||||
...
|
||||
|
||||
class FSDPBackend(ModelBackend):
|
||||
|
||||
def _initialize(self, model):
|
||||
module = model.module
|
||||
model.module: PipelinableEngine = FSDPEngine(module)
|
||||
return model
|
||||
|
||||
register_backend("fsdp", FSDPBackend)
|
||||
```
|
||||
|
||||
## Existing Implementations
|
||||
### Interfaces
|
||||
+ `ppo_interface.py`: Implemetation of PPO actor and critic.
|
||||
+ `sft_interface.py`: Implementation of SFT.
|
||||
|
||||
### Backends
|
||||
+ `megatron.py`: Training wrapper based on Megatron Core's `DistributedDataParallel`
|
||||
+ `sglang.py`: A wrapper over a SGLang HTTP server for batched generation.
|
||||
+ `vllm.py`: Deprecated SPMD vLLM backend.
|
||||
|
||||
|
||||
|
After Width: | Height: | Size: 338 KiB |
|
@ -0,0 +1,67 @@
|
|||
# Allocation & Parallelism
|
||||
|
||||
## GPU Allocation
|
||||
GPU allocation is controlled by the `allocation_mode` CLI parameter. The most common pattern looks like `"sglang.d2t2p1+d1t4p1"`, which means:
|
||||
|
||||
+ The first 4 GPUs are allocated to SGLang for inference with:
|
||||
- 2-way tensor parallelism
|
||||
- 2-way data parallelism
|
||||
+ The remaining GPUs are allocated for training with 4-way tensor parallelism
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Training
|
||||
AReaL supports three parallelism strategies for dense models, similar to Megatron:
|
||||
|
||||
+ Data Parallelism: Uses Megatron's DistributedDataParallel with AReaL's balanced DP partitioning algorithm (`SequenceSample.split`)
|
||||
+ Tensor Parallelism: Fully replicates Megatron's `ColumnParallelLinear` and `RowParallelLinear`
|
||||
+ Pipeline Parallelism: Developed in-house with 1F1B scheduling (planned to be replaced with an open-source implementation due to maintenance challenges)
|
||||
|
||||
### Inference
|
||||
AReaL supports SGLang inference with intra-node tensor parallelism and customized data parallelism.
|
||||
|
||||
### Parameter Partitioning
|
||||
Each model worker holds multiple model shards based on the allocation configuration.
|
||||
|
||||
Example: With 4 GPUs configured as:
|
||||
|
||||
+ Actor model: First half GPUs with tensor parallelism
|
||||
+ Critic model: Second half GPUs with pipeline parallelism
|
||||
+ Reference model: All GPUs with tensor and pipeline parallelism
|
||||
|
||||
The parameter distribution would be:
|
||||
|
||||

|
||||
|
||||
## Torch NCCL Communication Groups
|
||||
During experiments, the following NCCL communication groups are created:
|
||||
|
||||
1. Global group: Includes all experiment GPUs (created in `global_comm.py`)
|
||||
2. Parallelism Group: 3D parallel communication groups for a specific model (may match global group or be a subset, created in `topology.py`)
|
||||
3. Data transfer groups: Groups between all data-parallel processes of any two models for data transfer (created in `data_manager.py`)
|
||||
|
||||
## Parallelism Ranks
|
||||
Each model worker has a unique GPU index, but may have different parallel strategy coordinates under different model names (actor, critic, etc.).
|
||||
|
||||
Example: GPU 2 might have:
|
||||
|
||||
+ TP rank 1 for actor model
|
||||
+ TP rank 0 for reference model
|
||||
|
||||
Parallel strategy coordinates are maintained in `realhf.base.constants` and accessed via:
|
||||
|
||||
```bash
|
||||
with constants.model_scope(ModelName("actor", 0)):
|
||||
dp_rank1 = constants.data_parallel_rank()
|
||||
with constants.model_scope(ModelName("ref", 0)):
|
||||
dp_rank2 = constants.data_parallel_rank()
|
||||
```
|
||||
|
||||
Note: Interface and backend methods are automatically called within a model scope, so the context manager can be omitted in those implementations.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
After Width: | Height: | Size: 162 KiB |
|
@ -0,0 +1,3 @@
|
|||
# Launching Procedure
|
||||
|
||||

|
After Width: | Height: | Size: 130 KiB |
After Width: | Height: | Size: 389 KiB |
After Width: | Height: | Size: 260 KiB |
|
@ -0,0 +1,29 @@
|
|||
# Master Worker
|
||||
|
||||
## Overview
|
||||

|
||||
|
||||
The worker architecture of AReaL consists of a single master worker coordinating multiple model workers.
|
||||
|
||||
An RL algorithm typically contains several model function calls (MFCs) that need to be executed in a certain order. For example in PPO,
|
||||
|
||||
1. `actor_gen` generates responses given a batch of user prompts;
|
||||
2. `ref_inf` computes the log-probabilities of the tokens under the reference policy;
|
||||
3. `rew_inf` computes the rewards of the responses;
|
||||
4. `actor_train` updates the policy with the PPO learning objective.
|
||||
|
||||
Here model function calls 2 and 3 depends on the output of 1. Model function call 4 depends on the outputs of 1, 2, and 3.
|
||||
|
||||
The MFCs are coordinated by a `FunctionExecutor` instance. It creates a `ModelFunctionCall` instance for each MFC. The actual computation is performed on model workers via remote procedure call.
|
||||
|
||||
## Buffer and MFC Execution Order
|
||||

|
||||
|
||||
The master worker creates a `AsyncIOSequenceBuffer`, which is referenced by the `FunctionExecutor` and the `ModelFunctionCall`'s. The buffer is responsible for managing (meta)data and deciding the execution order of the MFCs.
|
||||
|
||||
Each datapoint can be seen as a `dict` of tensors. For example, the keys may include `packed_prompts` and `task_ids`. Recall that some MFC may rely on the output of another. For example in PPO, the MFC `ref_inf` requires `packed_input_ids`, which is not presented initially. Instead, `packed_input_ids` appears as one of the results of the MFC `actor_gen`.
|
||||
|
||||
The buffer keeps track of the available keys of each datapoint. Each `ModelFunctionCall`instance obtains its next batch via `self.get_batch_for_rpc`, which waits for enough datapoints with all the required keys. This means that it would not start execution until all required keys are ready. After a model function call execution, it calls `self.amend_batch` and updates the corresponding datapoints with new keys.
|
||||
|
||||
While some keys are the results of MFCs, some are loaded from the dataset via `FunctionExecutor.load_data`. Also note that instead of the actual data, the buffer stores only metadata (data indices, keys, etc.) to reduce the cost of data transfer.
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# Model Worker
|
||||
## Master-Model Worker Interaction
|
||||
The master worker sends remote procedure calls (RPCs) to model workers to execute actual computations like `actor_gen` and `actor_train`. The figure below illustrates their interaction throughout an experiment:
|
||||
|
||||

|
||||
|
||||
Model worker "compute" involves running a model interface with a specific backend (covered in detail later). For PPO algorithms, model workers sequentially execute:
|
||||
|
||||
+ `actor_gen`: `actor` model with SGlang backend + `PPOActorInterface.generate`
|
||||
+ `rew_inf`: `reward` model (can be null for RLVR) + `MultiTaskRewardInterface.inference`
|
||||
+ `actor_train`: `actor` model with Megatron backend + `PPOActorInterface.train_step`
|
||||
|
||||
## Communication Protocol
|
||||
### Request-Reply Pattern
|
||||
The master worker and model workers communicate through a `request_reply_stream` channel that handles requests and metadata responses (actual data like `input_ids` transfers through other channels).
|
||||
|
||||
Master (client) can send these requests to model workers (servers):
|
||||
|
||||
+ **fetch**: Worker loads local dataset data and sends metadata (e.g., sequence length) to master for buffer storage
|
||||
+ **spec**: Worker returns dataset specifications for master to calculate experiment steps
|
||||
+ **model_config**: Worker provides transformer model configuration
|
||||
+ **clear_data_cache**: Worker clears data transfer and GPU caches
|
||||
+ **initialize**: Worker initializes parameters, gradient buffers, and optimizer states
|
||||
+ **generate/inference/train_step**: Worker executes corresponding computation (note: "inference" refers to single forward pass)
|
||||
|
||||
### Request Hooks
|
||||
Computation requests ("generate"/"inference"/"train_step") support pre- and post-hooks for:
|
||||
|
||||
+ Data transfer (pre-hook)
|
||||
+ Evaluation
|
||||
+ Offloading
|
||||
+ Parameter reallocation
|
||||
+ Checkpointing (post-hooks)
|
||||
|
||||
These hooks often require NCCL communication/synchronization between workers. Implementing them as dedicated hooks prevents deadlocks that could occur if these operations interleaved with other NCCL communications.
|
||||
|
||||
### Request Types
|
||||
+ **Blocking requests**: Long-running operations requiring NCCL synchronization. Workers can't execute immediately since concurrent blocking requests may need coordinated data transfers. Master sends a "flush" request to indicate all concurrent requests have been sent.
|
||||
+ **Non-blocking requests**: Shorter operations without NCCL requirements that can execute immediately.
|
||||
|
||||
## Data Management
|
||||
### Distributed Dataset Storage
|
||||
Datasets distribute across model workers without overlap. For each model:
|
||||
|
||||
+ Processes with PP rank = -1 and TP rank = 0 serve as DP heads
|
||||
+ Data stores on DP heads of the model used in the first MFC (e.g., actor model DP heads for PPO)
|
||||
|
||||
During "fetch" requests:
|
||||
|
||||
1. DP head worker loads data into local buffer
|
||||
2. Sends metadata to master
|
||||
3. Master tracks metadata and later instructs workers which data to use for each MFC via computation request hooks
|
||||
|
||||
### Data Transfer Process
|
||||
For each MFC, the master:
|
||||
|
||||
1. Specifies which data to use
|
||||
2. Provides data locations across workers
|
||||
3. Workers redistribute data using:
|
||||
- `Redistributor`: Generates NCCL broadcast/gather/scatter communication plan
|
||||
- `DataManager`: Executes the plan
|
||||
|
||||
After redistribution, workers with same DP rank receive identical input data.
|
||||
|
||||
### MFC Output Handling
|
||||
Only workers with PP rank=-1 and TP rank=0 produce output data. These workers:
|
||||
|
||||
1. Store data locally
|
||||
2. Notify master of data locations
|
||||
3. Master generates new redistribution plans for subsequent MFCs based on this layout information
|
||||
|
||||
|
||||
|
After Width: | Height: | Size: 24 KiB |
|
@ -0,0 +1,83 @@
|
|||
# Evaluation
|
||||
|
||||
The evaluation code is located in the `evaluation` folder of the repository. Following the previous tutorial, trained checkpoints will be saved under `/storage/ray/experiments/checkpoints/root/`.
|
||||
|
||||
## Setup Evaluation Environment
|
||||
|
||||
Start a new container to execute the evaluation script. **Note**: Evaluation requires updates to certain Python libraries, so avoid using the training container for this task.
|
||||
|
||||
```bash
|
||||
docker run -d --name areal-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "tail -f /dev/null"
|
||||
docker exec -it areal-eval bash
|
||||
```
|
||||
|
||||
## Install Dependencies and Run Evaluation
|
||||
|
||||
Execute the following commands inside the Docker container:
|
||||
|
||||
```bash
|
||||
cd /storage/codes/AReaL/evaluation
|
||||
cd latex2sympy
|
||||
pip install -e .
|
||||
cd ..
|
||||
pip install -r requirements.txt
|
||||
pip install vllm==0.8.5 --no-build-isolation
|
||||
pip install transformers==4.51.1
|
||||
pip install prettytable timeout_decorator
|
||||
mkdir -p /storage/ray/eval_output/
|
||||
nohup python eval_and_aggregate.py \
|
||||
--model_path /storage/ray/experiments/checkpoints/root/my-exp/my-trial/epoch1epochstep20globalstep20/ \
|
||||
--output_path /storage/ray/eval_output/ \
|
||||
--data_names "math_500,aime24,amc23" \
|
||||
--max_gen_tokens 32768 &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
|
||||
```
|
||||
|
||||
### Command Line Parameters
|
||||
|
||||
- **`--model_path`**: Path to the saved model parameters
|
||||
- **`--output_path`**: Path to store generated answers and log files during evaluation
|
||||
- **`--data_names`**: Dataset(s) to evaluate. Multiple datasets can be separated by commas. Available options: `math_500`, `math`, `gsm8k`, `train_amc_aime`, `aime24`, `amc23`
|
||||
- **`--max_gen_tokens`**: Maximum length of generated answers (default: 32768)
|
||||
|
||||
## Evaluation Results
|
||||
|
||||
The evaluation script will output a results table in the terminal:
|
||||
|
||||
```
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
| dataset | num_questions | greedy_length | sample_length | greedy_acc | sample_pass@1 | pass@8 | pass@16 |
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
| math_500 | 500 | 6757.4 | 4139.5 | 84.4 | 92.7 | 97.3 | 97.7 |
|
||||
| aime24 | 30 | 19328.0 | 13663.5 | 50.0 | 50.4 | 77.3 | 80.0 |
|
||||
| amc23 | 40 | 8850.0 | 6526.2 | 80.0 | 90.5 | 96.8 | 98.8 |
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
```
|
||||
|
||||
### Metrics Explanation
|
||||
|
||||
- **`{greedy|sample}_length`**: Average answer length under greedy or random sampling strategy
|
||||
- **`greedy_acc`**: Average accuracy under greedy sampling
|
||||
- **`sample_pass@{k}`**: Probability of generating a correct answer within `k` attempts under random sampling
|
||||
|
||||
## Configuration Details
|
||||
|
||||
### Sampling Parameters
|
||||
|
||||
- The evaluation script defaults to averaging 32 samples with temperature 0.6
|
||||
- We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance
|
||||
- When `enforce_eager=True`, we can reproduce the model performance reported in previous work
|
||||
- Without this setting, evaluation results may fall below reported performance
|
||||
- Therefore, we enforce `enforce_eager=True` during evaluation
|
||||
|
||||
### Runtime Expectations
|
||||
|
||||
Due to the sampling requirements and `enforce_eager` setting, the evaluation process typically takes considerable time.
|
||||
|
||||
Runtime depends on several factors:
|
||||
- Maximum generation length
|
||||
- Number of questions in the dataset
|
||||
- Model size
|
||||
|
||||
**Performance benchmarks** (on 8x H100 GPUs):
|
||||
- **AIME dataset**: ~80 minutes
|
||||
- **MATH_500 dataset**: ~160 minutes
|
After Width: | Height: | Size: 320 KiB |
|
@ -0,0 +1,78 @@
|
|||
# Installation
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Hardware Requirements
|
||||
|
||||
The following hardware configuration has been extensively tested:
|
||||
|
||||
- **GPU**: 8x H800 per node
|
||||
- **CPU**: 64 cores per node
|
||||
- **Memory**: 1TB per node
|
||||
- **Network**: NVSwitch + RoCE 3.2 Tbps
|
||||
- **Storage**:
|
||||
- 1TB local storage for single-node experiments
|
||||
- 10TB shared storage (NAS) for distributed experiments
|
||||
|
||||
### Software Requirements
|
||||
|
||||
| Component | Version |
|
||||
|---|:---:|
|
||||
| Operating System | CentOS 7 / Ubuntu 22.04 or any system meeting the requirements below |
|
||||
| NVIDIA Driver | 550.127.08 |
|
||||
| CUDA | 12.8 |
|
||||
| Git LFS | Required for downloading models, datasets, and AReaL code. See [installation guide](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) |
|
||||
| Docker | 27.5.1 |
|
||||
| NVIDIA Container Toolkit | See [installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) |
|
||||
| AReaL Image | `ghcr.io/inclusionai/areal-runtime:v0.3.0` (includes runtime dependencies and Ray components) |
|
||||
|
||||
**Note**: This tutorial does not cover the installation of NVIDIA Drivers, CUDA, or shared storage mounting, as these depend on your specific node configuration and system version. Please complete these installations independently.
|
||||
|
||||
## Runtime Environment
|
||||
|
||||
We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository.
|
||||
|
||||
Pull the Docker image:
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/inclusionai/areal-runtime:v0.3.0
|
||||
```
|
||||
|
||||
This image includes all training requirements for AReaL.
|
||||
|
||||
**For multi-node training**: Ensure shared storage is mounted to the `/storage` directory on every node. All downloads and resources will be stored in this directory, and the AReaL container will mount this directory to `/storage` within the container.
|
||||
|
||||
## Code Setup
|
||||
|
||||
Clone the AReaL project code to `/storage/codes`:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/codes
|
||||
cd /storage/codes/
|
||||
git clone https://github.com/inclusionAI/AReaL
|
||||
pip install -r AReaL/requirements.txt
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
Download the provided training dataset and place it in `/storage/datasets/`:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/datasets/
|
||||
cd /storage/datasets/
|
||||
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
|
||||
```
|
||||
|
||||
## Model
|
||||
|
||||
We train using open-source models available on Hugging Face Hub. Here's an example using Qwen3 (ensure Git LFS is installed):
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/models
|
||||
cd /storage/models
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Qwen/Qwen3-1.7B
|
||||
cd Qwen3-1.7B
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
**Alternative**: You can also use the Hugging Face CLI to download models after installing the `huggingface_hub` package. Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.
|
|
@ -0,0 +1,6 @@
|
|||
# Overview
|
||||
|
||||
## Welcome to AReaL’s documentation!
|
||||
|
||||
```{tableofcontents}
|
||||
```
|
|
@ -0,0 +1,3 @@
|
|||
jupyter-book
|
||||
matplotlib
|
||||
numpy
|
|
@ -0,0 +1,108 @@
|
|||
# Training
|
||||
|
||||
## Launch the Ray Cluster
|
||||
|
||||
### Start the Ray Head Node
|
||||
|
||||
On the first node, start the Ray Head with the following command:
|
||||
|
||||
```bash
|
||||
docker run -d --name r1-ray-head --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "ray start --head --port=6379 && tail -f /dev/null"
|
||||
```
|
||||
|
||||
### Start Ray Worker Nodes
|
||||
|
||||
On all other nodes, start the Ray Worker with the following command (skip this step for single-node setups):
|
||||
|
||||
```bash
|
||||
# Replace with the actual IP address of the first node
|
||||
RAY_HEAD_IP=xxx.xxx.xxx.xxx
|
||||
docker run -d --name r1-ray-worker --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "ray start --address=$RAY_HEAD_IP:6379 && tail -f /dev/null"
|
||||
```
|
||||
|
||||
### Verify Cluster Status
|
||||
|
||||
Once all nodes are running, check the Ray cluster status by entering the container on the first node:
|
||||
|
||||
```bash
|
||||
docker exec -it r1-ray-head bash
|
||||
ray status
|
||||
```
|
||||
|
||||
You should see the Ray resource status displayed.
|
||||
|
||||
## Launch an Experiment
|
||||
|
||||
On the first node (where the Ray Head is located), run the following to launch an asynchronous PPO experiment:
|
||||
|
||||
```bash
|
||||
docker exec -it r1-ray-head bash
|
||||
cd /storage/codes/AReaL
|
||||
pip3 install -e .
|
||||
python3 training/main_async_ppo.py --config-name=async-ppo-1.7b-gpu8
|
||||
```
|
||||
|
||||
This command will locate the YAML configuration file `async-ppo-1.7b-gpu8.yaml` in the `training/configs/async-ppo` folder. The meaning of each configuration entry can be found in `realhf/api/cli_args.py`. You can run asynchronous PPO, synchronous PPO, or SFT depending on the script you execute.
|
||||
|
||||
After starting, you'll see training launch information like this:
|
||||
|
||||
```
|
||||
20250528-17:12:16.804 quickstart INFO: Running async-ppo-math experiment.
|
||||
20250528-17:12:16.804 quickstart INFO: Logs will be dumped to /storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial
|
||||
20250528-17:12:16.804 quickstart INFO: Experiment configs will be dumped to /storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial/config.yaml
|
||||
20250528-17:12:16.804 quickstart INFO: Model checkpoints will be saved to /storage/experiments/checkpoints/admin/async-ppo-1.7b-gpu8/my-trial
|
||||
20250528-17:12:19.261 quickstart INFO: Launching experiments with RAY...
|
||||
```
|
||||
|
||||
**Note**: The saved YAML configuration at `/storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial/config.yaml` can be used to reproduce previous experiments.
|
||||
|
||||
## Command Line Options
|
||||
|
||||
To view all available options:
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart async-ppo-math --help
|
||||
```
|
||||
|
||||
### Important Parameters
|
||||
|
||||
- **`mode`**: Always set to `ray`. Do not change this value when following this tutorial.
|
||||
- **`{actor|critic|ref}.path`**: The path to the model files.
|
||||
- **`dataset.path`**: The path to the dataset JSONL file.
|
||||
- **`cluster.fileroot`**: The root path for saving training outputs.
|
||||
- **`n_nodes`**: The number of nodes in the cluster.
|
||||
- **`n_gpus_per_node`**: The number of GPUs per node.
|
||||
- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration for the experiment. Format:
|
||||
- `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures parallel strategies for SGLang generation and training respectively. Generation and training use separate GPU sets, and the total GPU count must equal: DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs.
|
||||
|
||||
### Training Control Parameters
|
||||
|
||||
- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset iterations).
|
||||
- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to persistent storage. Set to null to disable saving.
|
||||
- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary parameters for restart capability.
|
||||
- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per training iteration).
|
||||
- **`group_size`**: Number of responses sampled per prompt.
|
||||
- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per mini-batch for forward/backward passes during reference model inference and actor model training. Reduce to avoid OOM errors.
|
||||
- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO update.
|
||||
- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for training.
|
||||
- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training.
|
||||
- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt (default: 16k).
|
||||
- **`ppo.gen.min_new_tokens`**: Minimum tokens to generate per prompt (default: 0).
|
||||
|
||||
## Monitoring the Training Process
|
||||
|
||||
We recommend using Weights & Biases (wandb) for monitoring. Run `wandb login` or set the `WANDB_API_KEY` environment variable. Set `wandb.mode=True` in your configuration to upload training statistics.
|
||||
|
||||
The main log will be saved to `/storage/experiments/logs/admin/async-ppo-1.7b-gpu8/my-trial/main.log` and contains the statistics uploaded to wandb.
|
||||
|
||||
### Key Training Statistics
|
||||
|
||||
- **`Epoch 1/5`**: Indicates total epochs required and current epoch being trained.
|
||||
- **`step 6/19`**: Shows current epoch has 19 steps, with the 6th step just completed.
|
||||
- **`global step 6`**: Step count across all epochs.
|
||||
- **`task_reward`**: Average reward value of all sampled responses in this step. Should steadily increase during training and eventually stabilize.
|
||||
- **`importance_weight`**: Average importance sampling ratio across all tokens in the PPO loss. Typically close to 1.0.
|
||||
- **`actor_clip_ratio`**: Ratio of clipped tokens in PPO loss to total tokens. Usually less than 0.1.
|
||||
- **`actor_loss`**: PPO loss value. **Does not show clear trends during training** and should not be used as a performance indicator.
|
||||
- **`avg_seq_len`**: Average length of all sequences (prompts with sampled responses) in this step.
|
||||
- **`no_eos_ratio`**: Ratio of sampled responses truncated due to exceeding maximum generation length. An increase indicates longer average response lengths.
|
|
@ -0,0 +1,56 @@
|
|||
# Troubleshooting
|
||||
|
||||
If the following content does not address your issue, feel free to raise a GitHub Issue.
|
||||
|
||||
## Automatic Recovery
|
||||
|
||||
When setting `recover_mode=auto` and the experiment configuration remains unchanged, AReaL will attempt to discover previous checkpoints and recover the experiment from them.
|
||||
|
||||
### Recovery Failure Causes
|
||||
|
||||
If automatic recovery fails, check the following possibilities:
|
||||
|
||||
**Configuration Changes:**
|
||||
- The `experiment_name` and `trial_name` in the training script differ from the previous run
|
||||
- Changes in batch size (`dataset.train_bs_n_seqs` parameter)
|
||||
- Changes in group size (`group_size` parameter)
|
||||
- Changes in number of nodes (`n_nodes` parameter)
|
||||
|
||||
**Missing Recovery Checkpoints:**
|
||||
Recovery checkpoints are generated under two conditions by default:
|
||||
- After completion of the second step
|
||||
- When a step completes and more than 600 seconds have passed since the last recovery checkpoint (controlled by `exp_ctrl.ckpt_freq_secs=600`)
|
||||
|
||||
### Verify Recovery Checkpoint Creation
|
||||
|
||||
You can confirm if a recovery checkpoint was generated by searching for the following message in the logs:
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.760 master worker INFO: Dumped recover info to file.
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### torch.cuda.CudaOutOfMemoryError
|
||||
|
||||
The key to resolving this issue is identifying the phase where the error occurs:
|
||||
|
||||
#### During Initialization
|
||||
- Check for idle processes on the GPU
|
||||
- **Distributed scenarios**: Restart the Ray cluster
|
||||
- **Single-machine scenarios**: Use `pkill` to terminate processes
|
||||
|
||||
#### During SGLang Generation
|
||||
- Decrease the `actor.sglang.mem_fraction_static` parameter
|
||||
- Increase the tensor parallelism degree
|
||||
|
||||
#### During `actor_inf` or `actor_train`
|
||||
- **Adjust microbatch size**: Set parameters like `actor_train.mb_spec.max_tokens_per_mb=20480`. This parameter limits tokens per forward/backward pass and can be set as low as the maximum sequence length (including prompt)
|
||||
- **Modify parallelism strategy**: Adjust `allocation_mode` by:
|
||||
- Reducing data parallelism
|
||||
- Increasing tensor or pipeline parallelism
|
||||
- Preferring pipeline parallelism over tensor parallelism
|
||||
|
||||
### CUDA Error: Out of Memory
|
||||
|
||||
This issue may occur during data transfer. Try increasing `mem_per_xx_worker` in the CLI arguments.
|
|
@ -1,427 +0,0 @@
|
|||
# Improving LLM's Reasoning Capabilities with AReaL: A Complete Guide
|
||||
|
||||
# Prerequisites
|
||||
## Hardware Requirements
|
||||
Check if your hardware meets these minimum requirements:
|
||||
|
||||
|
||||
|**Model Size**| **1.5B** |**1.5B**|**1.5B**| **7B** | **7B** | **32B** |
|
||||
|---|:---:|:---:|:---:|:-------------------------:|:---:|:---:|
|
||||
| **Nodes** | **1** | **4** | **16** | **4** | **16** | **16** |
|
||||
| GPU | 8x H800 |8x H800 per node| 8x H800 per node | 8x H800 per node | 8x H800 per node | 8x H800 per node |
|
||||
| CPU | 48 cores |48 cores per node|48 cores per node| 48 cores per node | 48 cores per node| 48 cores per node|
|
||||
| Memory | 1 TB |1 TB per node|1 TB per node| 1 TB per node | 1 TB per node| 1 TB per node|
|
||||
| Network | NVSwitch |NVSwitch + RoCE 3.2 Tbps|NVSwitch + RoCE 3.2 Tbps| NVSwitch + RoCE 3.2 Tbps | NVSwitch + RoCE 3.2 Tbps| NVSwitch + RoCE 3.2 Tbps|
|
||||
| Storage | 1TB |Shared storage (NAS) 10TB|Shared storage (NAS) 10TB| Shared storage (NAS) 10TB |Shared storage (NAS) 10TB| Shared storage (NAS) 10TB|
|
||||
| BatchSize x GroupSize | 512x16 | 512x16 | 512x16 | 512x16 | 512x16 | 512x32|
|
||||
| **Single-step Time (seconds)** | **3461** | **997** | **391** | **2275** | **815** | **6707**|
|
||||
| **#Steps Until Convergence** | **~250** |**~250** |**~250** |**~400** |**~400** | - |
|
||||
| **Total Time (Hours)** | **~240** | **~69** | **~27** | **~252** | **~90** | - |
|
||||
|
||||
Notes:
|
||||
- GPUs need to have 80GB memory. Other GPU models with similar specs are acceptable.
|
||||
- Single-node training can use local storage, but multi-node training requires shared storage.
|
||||
- We haven't successfully train a powerful 32B model, so we cannot estimate the required steps and time.
|
||||
|
||||
## Software Requirements
|
||||
This tutorial provides a Docker image. Below are the tested software versions:
|
||||
|
||||
| | Version |
|
||||
|---|:---:|
|
||||
| OS | CentOS 7 / Ubuntu 22.04 or any other system that meets the software requirements below |
|
||||
| NVIDIA Driver | 550.127.08 |
|
||||
| CUDA | 12.5 |
|
||||
| Git LFS | Refer to: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage. Mainly used for downloading models, datasets, and AReaL project code. |
|
||||
| Docker | 27.5.1 |
|
||||
|NVIDIA Container Toolkit|[Installing the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)|
|
||||
| AReaL Image | `ghcr.io/inclusionai/areal-runtime:v0.2.0`. This image includes AReaL's runtime dependencies and Ray components. |
|
||||
|
||||
Since the installation of NVIDIA Drivers and CUDA, as well as the mounting of shared storage, depends on node configurations and system versions, please complete these installations independently. This tutorial does not cover their setup.
|
||||
|
||||
For multi-node training, ensure that the shared storage is mounted to the `/storage` directory on every node. All subsequent downloads and resources will be stored in this directory. The AReaL container will also mount this directory to `/storage` within the container, enabling seamless access during training.
|
||||
|
||||
|
||||
# One-Click Environment Setup and Training Launch
|
||||
|
||||
This section provides a one-click setup script to automatically configure the node environment:
|
||||
|
||||
1. Install Docker, Git LFS, and NVIDIA Container Toolkit
|
||||
2. Pull the AReaL image on each node
|
||||
3. Download AReaL code, models, and datasets
|
||||
4. Set up a Ray cluster
|
||||
5. [Optional] Launch a training task within the Ray cluster
|
||||
|
||||
Please perform the following operations on any chosen node:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/codes
|
||||
cd /storage/codes/
|
||||
git clone https://github.com/inclusionAI/AReaL.git
|
||||
cd /storage/codes/AReaL
|
||||
|
||||
python ./examples/env/setup_env_and_start_train.py setup --private_key_file /path/to/ssh_key --ssh_port 22 --username root --hostnames NODE_IP_1 NODE_IP_2 NODE_IP_3 NODE_IP_4 --train_param 1.5B_n1
|
||||
```
|
||||
|
||||
`setup_env_and_start_train.py setup` arguments:
|
||||
|
||||
- `private_key_file`: SSH secret key. Using by connecting nodes.
|
||||
- `ssh_port`: SSH port
|
||||
- `username`: SSH username
|
||||
- `hostnames`: IP list. Split with space. Can be 1, 4, or 16 node IPs
|
||||
- `train_param`: [Optional] Training parameters used to launch a training task immediately after environment setup. Valid options are: `1.5B_n1`, `1.5B_n4`, `1.5B_n16`, `7B_n4`, `7B_n16`
|
||||
|
||||
If the script in this section fails to execute or encounters errors due to environmental discrepancies, you may manually configure the environment and launch training by following the instructions in the subsequent sections of this tutorial.
|
||||
|
||||
# Environment Setup
|
||||
|
||||
Since shared storage is used, downloading only needs to be done on one node.
|
||||
|
||||
## Code
|
||||
Clone the AReaL project code to `/storage/codes`:
|
||||
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/codes
|
||||
cd /storage/codes/
|
||||
git clone https://github.com/inclusionAI/AReaL
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
We provide a dataset for training. Download the dataset and place it in `/storage/datasets/`:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/datasets/
|
||||
cd /storage/datasets/
|
||||
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
|
||||
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/orz-zero_56k_0319.jsonl?download=true
|
||||
```
|
||||
|
||||
## Model
|
||||
|
||||
We train based on open-source models, which can be downloaded directly from HuggingFaceHub (Please ensure that Git LFS is installed):
|
||||
|
||||
```
|
||||
mkdir -p /storage/models
|
||||
cd /storage/models
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
|
||||
cd DeepSeek-R1-Distill-Qwen-7B
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
You can also use the HuggingFace CLI to download after installing PyPI and huggingface_hub. Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.
|
||||
|
||||
## Launch the Ray Cluster
|
||||
|
||||
Before proceeding, pull the AReaL environment image, which already includes Ray components.
|
||||
|
||||
On the first node, start the Ray Head with the following command:
|
||||
|
||||
```bash
|
||||
docker run -d --name r1-ray-head --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --head --port=6379 && tail -f /dev/null"
|
||||
```
|
||||
|
||||
On all other nodes, start the Ray Worker with the following command (skip this step if you only have one node):
|
||||
|
||||
```bash
|
||||
# RAY_HEAD_IP is the IP of the first node
|
||||
RAY_HEAD_IP=xxx.xxx.xxx.xxx
|
||||
docker run -d --name r1-ray-worker --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --address=$RAY_HEAD_IP:6379 && tail -f /dev/null"
|
||||
```
|
||||
|
||||
Once all nodes are up, check the Ray cluster status by entering the container on the first node:
|
||||
|
||||
```bash
|
||||
docker exec -it r1-ray-head bash
|
||||
ray status
|
||||
```
|
||||
|
||||
You should see the Ray resource status. The output will vary depending on your node count (e.g., a 16-node, 128-GPU cluster will show the following results).
|
||||
|
||||
```
|
||||
======== Autoscaler status: 2025-02-22 14:08:51.061250 ========
|
||||
Node status
|
||||
---------------------------------------------------------------
|
||||
Active:
|
||||
1 node_d5634ae61bfe6732d957811bed65c8a39f13ece07e0326f941acbc4e
|
||||
1 node_23b0c08045c9a39bc4c454cae298ee531d9a474215ac5e77a5b01e74
|
||||
1 node_bc1016320658e92645f29cecb8aaf51c0b7e01a44e8ac9c814dfee59
|
||||
1 node_4e7d15e9cee9ee0da5d65e45f1e346228c52bc0c557511c6eeab40dc
|
||||
1 node_c5bcf15e28a00515be5d2a7e8e33d71f0f57cdfaf1003db9e0c74788
|
||||
1 node_ec3f6ee8f6fdf3a5392bb4dac244668da75d094e084dcbb520ce2525
|
||||
1 node_dc2f1eef88126ae4ac7902574714af9ab74b78ba037217e73e063639
|
||||
1 node_a4728608c1fda187dc33bb24e831c42fe5c8a582ad428b6e595933bc
|
||||
1 node_970379a3ba750ee3b13e31612b6a6b758d50bd4943555b2a13d1bd61
|
||||
1 node_bf6b658bea9e437fcb642a2d881425662a689d668c92fe1545899b36
|
||||
1 node_2c69511f410d9360f1d05893fde2c97dd32240e0315afea9b2d286a3
|
||||
1 node_e4c90c17cc48ad469d123041d3302dcff1f7a82a4805279300812b19
|
||||
1 node_3f772cbffb206c30b6ccedade83789d78397804bab874ee59563cb96
|
||||
1 node_429bd5115b5590b612590bb455f2d3ed4f77055d746a184baf807655
|
||||
1 node_75071820f2c16dc51fa271316b72cd45335ec877c06450d292ab7d54
|
||||
1 node_6f4323f9038248d82b91321e2c4ca5fa99e65efa2d976c0b896a8964
|
||||
Pending:
|
||||
(no pending nodes)
|
||||
Recent failures:
|
||||
(no failures)
|
||||
|
||||
Resources
|
||||
---------------------------------------------------------------
|
||||
Usage:
|
||||
0.0/2128.0 CPU
|
||||
0.0/128.0 GPU
|
||||
0B/21.08TiB memory
|
||||
0B/2.91TiB object_store_memory
|
||||
|
||||
Demands:
|
||||
(no resource demands)
|
||||
```
|
||||
|
||||
# RL Training
|
||||
|
||||
Before starting distributed training, ensure the Ray cluster is up and running properly.
|
||||
Then, on the first node (where the Ray Head is located), enter the container:
|
||||
|
||||
```
|
||||
docker exec -it r1-ray-head bash
|
||||
cd /storage/codes/AReaL
|
||||
```
|
||||
|
||||
Choose a config file that matches your hardware environment and run it:
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart ppo-math --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-128.yaml
|
||||
```
|
||||
|
||||
After starting, check the training launch information:
|
||||
|
||||
```
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ Setting PPOMATHConfig with the Following Values │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
|
||||
───────────────────────── Current Configuration Begin ──────────────────────────
|
||||
actor (ModelTrainEvalConfig)
|
||||
actor.type (ModelFamily)
|
||||
actor.type._class (str) - qwen2
|
||||
actor.type.size (int) - 7
|
||||
actor.type.is_critic (bool) - False
|
||||
...
|
||||
────────────────────────── Current Configuration End ───────────────────────────
|
||||
|
||||
20250222-10:26:34.877 quickstart INFO: Running ppo-math experiment.
|
||||
20250222-10:44:15.581 quickstart INFO: Logs will be dumped to /storage/ray/experiments/logs/root/ppo-7B-distill-gpus-128/512x16
|
||||
20250222-10:44:15.581 quickstart INFO: Model checkpoints will be saved to /storage/ray/experiments/checkpoints/root/ppo-7B-distill-gpus-128/512x16
|
||||
20250222-10:26:36.408 quickstart INFO: Launching experiments with RAY...
|
||||
```
|
||||
|
||||
If errors occur during execution (e.g., keywords like "Error" appear), refer to the troubleshooting section.
|
||||
|
||||
## Commandline Options
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart ppo-math --help
|
||||
```
|
||||
|
||||
The descriptions of the important parameters are as follows:
|
||||
|
||||
|
||||
+ `mode`: It is always `ray`, and do not change it to other values when referring to this tutorial for training.
|
||||
+ `{actor|critic|ref}.path`: The path of the model.
|
||||
+ `dataset.path`: The path of the dataset jsonl file
|
||||
+ `external_configs.cluster_config`: Set config for cluster_config. e.g. fileroot is the root path for saving traning outputs.
|
||||
|
||||
+ `n_nodes`: The number of nodes
|
||||
+ `n_gpus_per_node`: The number of GPUs per node
|
||||
+ `allocation_mode`: The GPU allocation and 3D parallel strategy of the model in the experiment, mainly in the following form:
|
||||
+ `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configure the parallel strategies for SGLang generation and training respectively. The generation and training use disjoint sets of GPUs, and the sum of the number of GPUs used by the two should be equal to the total number of GPUs, i.e DP1xTP1xPP1+DP2xTP2xPP2=#GPUs.
|
||||
|
||||
+ `exp_ctrl.total_train_epochs`: The number of training epochs (i.e., the number of times to iterate over the entire dataset)
|
||||
+ `exp_ctrl.save_freq_{epochs|steps|secs}`: The frequency of saving the model parameters in persistent storage. If it is set to null, the model will not be saved.
|
||||
+ `exp_ctrl.ckpt_freq_{epochs|steps|secs}`: The frequency of saving temporary parameters for restart
|
||||
+ `dataset.train_bs_n_seqs`: The training batch size, that is, the number of prompts to be sampled each time during training
|
||||
+ `group_size`: The number of answers to be sampled for each prompt
|
||||
+ `{actor_train|ref_inf}.mb_spec.max_tokens_per_mb`: The maximum number of tokens in the data for each forward/backward pass during the inference of the reference model and the training of the actor model. It can be reduced to avoid OOM errors. These data will accumulate gradients for a single parameter update.
|
||||
+ `ppo.ppo_n_minibatches`: The number of parts into which all the data will be divided for each PPO update to calculate the loss and update the parameters.
|
||||
+ `ppo.gen.max_new_tokens`: The maximum number of tokens to be generated for a single prompt, default to 16k.
|
||||
+ `ppo.gen.min_new_tokens`: The minimum number of tokens to be generated for a single prompt, default to 0.
|
||||
|
||||
## Monitoring the Training Process
|
||||
|
||||
Here, we use the logs from a 16-node run (the same applies to 1-node and 4-node setups) to explain several methods for observing training progress and results.
|
||||
|
||||
### Training Progress
|
||||
|
||||
Search for the keyword `Epoch` in the logs to see the total number of Epochs and Steps:
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.997 master worker INFO: Epoch 1/1 step 1/19 (global step 1) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2124.429*s. Total time consumption: 2283.862s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.719 master worker INFO: Epoch 1/1 step 2/19 (global step 2) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2405.716*s. Total time consumption: 4689.584s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.084 master worker INFO: Epoch 1/1 step 3/19 (global step 3) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2122.318*s. Total time consumption: 6811.949s. Estimated remaining time: 33957.093s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.246 master worker INFO: Epoch 1/1 step 4/19 (global step 4) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2313.134*s. Total time consumption: 9125.111s. Estimated remaining time: 33265.891s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.349 master worker INFO: Epoch 1/1 step 5/19 (global step 5) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2296.076*s. Total time consumption: 11421.214s. Estimated remaining time: 31413.800s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.864 master worker INFO: Epoch 1/1 step 6/19 (global step 6) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2299.448*s. Total time consumption: 13720.729s. Estimated remaining time: 29350.673s.
|
||||
```
|
||||
|
||||
Six log entries are found. We explain the meaning of each field based on the last entry:
|
||||
- `Epoch 1/1`: Indicates that a total of 1 Epoch is required, and the first Epoch is currently being trained. This example only trains for 1 Epoch. Normally, training should run for 10 Epochs or more.
|
||||
- `step 6/19`: Indicates that the current Epoch has 19 Steps, and the 6th Step has just finished.
|
||||
- `global step 6`: Represents the step count across all Epochs.
|
||||
- `#End to end# execution time: *2299.448*s`: Indicates that the current Step took 2299.448 seconds to complete.
|
||||
- `Total time consumption: 13720.729s`: The total time elapsed since training started is 13720.729 seconds.
|
||||
- `Estimated remaining time: 29350.673s`: The estimated time remaining to complete training is 29350.673 seconds.
|
||||
|
||||
|
||||
### Model Performance
|
||||
|
||||
Search for the keyword `task_reward` in the logs.
|
||||
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.991 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.2640759198111482e-05, 'actor_loss': 1.1128166761409375e-06, 'actor_clip_ratio': 2.1122002635820536e-07, 'importance_weight': 1.0000014305114746, 'task_reward': -0.2996826171875, 'kl_reward': -2.27004832709099e-07, 'final_reward': -0.30145370960235596, 'advantage': 0.003593671601265669, 'avg_seq_len': 7907.8955078125, 'avg_prompt_len': 105.845703125, 'n_tokens': 127828786.0, 'n_valid_tokens': 127828786.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.122802734375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.712 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.493159263394773e-05, 'actor_loss': -3.846728588996484e-07, 'actor_clip_ratio': 3.16789424914532e-07, 'importance_weight': 0.9999996423721313, 'task_reward': -0.6793212890625, 'kl_reward': -2.536311853873485e-07, 'final_reward': -0.6813737154006958, 'advantage': 0.004844569601118565, 'avg_seq_len': 8203.9453125, 'avg_prompt_len': 111.892578125, 'n_tokens': 132580185.0, 'n_valid_tokens': 132580185.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.13812255859375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.077 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.572356243035756e-05, 'actor_loss': -5.036404786551429e-07, 'actor_clip_ratio': 1.8960582792715286e-07, 'importance_weight': 0.9999992251396179, 'task_reward': -0.6280517578125, 'kl_reward': -2.988609537624143e-07, 'final_reward': -0.6303607225418091, 'advantage': 0.004505862481892109, 'avg_seq_len': 7834.6328125, 'avg_prompt_len': 108.900390625, 'n_tokens': 126578395.0, 'n_valid_tokens': 126578395.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.11761474609375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.239 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4861981728463434e-05, 'actor_loss': 1.3935685672095133e-07, 'actor_clip_ratio': 3.02603467616791e-07, 'importance_weight': 0.9999998807907104, 'task_reward': -0.78857421875, 'kl_reward': -3.672174671009998e-07, 'final_reward': -0.791388750076294, 'advantage': 0.005053278990089893, 'avg_seq_len': 7773.39404296875, 'avg_prompt_len': 108.7890625, 'n_tokens': 125576883.0, 'n_valid_tokens': 125576883.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.117919921875, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.342 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.516058702894952e-05, 'actor_loss': -7.665488510610885e-07, 'actor_clip_ratio': 1.9505058901359007e-07, 'importance_weight': 0.9999997615814209, 'task_reward': -0.6158447265625, 'kl_reward': -4.6867208425283025e-07, 'final_reward': -0.6195111274719238, 'advantage': 0.004475570283830166, 'avg_seq_len': 7928.50830078125, 'avg_prompt_len': 105.517578125, 'n_tokens': 128171874.0, 'n_valid_tokens': 128171874.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12353515625, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.857 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4821250917739235e-05, 'actor_loss': -3.922649227661168e-07, 'actor_clip_ratio': 3.323623900541861e-07, 'importance_weight': 1.0000001192092896, 'task_reward': -0.7025146484375, 'kl_reward': -5.863367960046162e-07, 'final_reward': -0.7071446776390076, 'advantage': 0.004277692176401615, 'avg_seq_len': 8002.4873046875, 'avg_prompt_len': 105.951171875, 'n_tokens': 129376851.0, 'n_valid_tokens': 129376851.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12286376953125, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
```
|
||||
|
||||
The last entry is used to explain the meaning of key fields:
|
||||
- `task_reward`: The average reward value of all sampled answers in this step. This value should steadily increase during training and eventually stabilize.
|
||||
- `importance_weight`: The average importance sampling ratio across all tokens in the PPO loss. This value is typically close to 1.
|
||||
- `actor_clip_ratio`: The ratio of tokens clipped in the PPO loss to the total number of tokens. This is usually less than 0.1.
|
||||
- `actor_loss`: The PPO loss. **It does not show a clear upward or downward trend during training** and should not be used as a reference for model performance.
|
||||
- `avg_seq_len`: The average length of all sequences (i.e., prompts with sampled answers) in this step. In a full multi-stage training process, this value will first decrease and then increase.
|
||||
- `no_eos_ratio`: The ratio of sampled answers truncated due to exceeding the maximum generation length. An increase in this value indicates that the average length of answers is increasing.
|
||||
|
||||
# Evaluation
|
||||
|
||||
## Evaluation Process
|
||||
|
||||
The evaluation code is located in the `evaluation` folder of the repository. As per the previous tutorial, the trained checkpoints will be saved under the path `/storage/ray/experiments/checkpoints/root/`, for example, `/storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/`.
|
||||
|
||||
Start a new container to execute the evaluation script (note: evaluation requires updates to certain Python libraries; avoid using the training container for this task):
|
||||
```
|
||||
docker run -d --name r1-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "tail -f /dev/null"
|
||||
docker exec -it r1-eval bash
|
||||
```
|
||||
|
||||
Run the following script inside the Docker container to evaluate:
|
||||
|
||||
```bash
|
||||
cd /storage/codes/AReaL/evaluation
|
||||
cd latex2sympy
|
||||
pip install -e .
|
||||
cd ..
|
||||
pip install -r requirements.txt
|
||||
pip install vllm --no-build-isolation
|
||||
pip install transformers==4.47.0
|
||||
pip install prettytable timeout_decorator
|
||||
mkdir /storage/ray/eval_output/
|
||||
nohup python eval_and_aggregate.py \
|
||||
--model_path /storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/ \
|
||||
--output_path /storage/ray/eval_output/ \
|
||||
--data_names "math_500,aime24,amc23" \
|
||||
--max_gen_tokens 32768 &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
|
||||
```
|
||||
|
||||
+ `--model_path`: Path to the saved model parameters.
|
||||
+ `--output_path`: Path to store the generated answers and log files during evaluation.
|
||||
+ `--data_names`: Specify the dataset(s) to evaluate. Multiple datasets can be separated by commas. Default is `math_500, math, gsm8k, train_amc_aime, aime24, amc23`.
|
||||
+ `--max_gen_tokens`: Maximum length of generated answers. Default is `32768`.
|
||||
|
||||
## Evaluation Results
|
||||
|
||||
The evaluation script will output a table in the terminal, for example:
|
||||
|
||||
```
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
| dataset | num_questions | greedy_length | sample_length | greedy_acc | sample_pass@1 | pass@8 | pass@16 |
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
| math_500 | 500 | 6757.4 | 4139.5 | 84.4 | 92.7 | 97.3 | 97.7 |
|
||||
| aime24 | 30 | 19328.0 | 13663.5 | 50.0 | 50.4 | 77.3 | 80.0 |
|
||||
| amc23 | 40 | 8850.0 | 6526.2 | 80.0 | 90.5 | 96.8 | 98.8 |
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
```
|
||||
|
||||
|
||||
+ `{greedy|sample}_length`: Average answer length under greedy or random sampling strategy.
|
||||
+ `greedy_acc`: Average accuracy under greedy sampling.
|
||||
+ `sample_pass@{k}`: Probability of generating a correct answer on average per `k` attempts under random sampling.
|
||||
|
||||
## Additional Notes
|
||||
|
||||
### Key Parameters
|
||||
|
||||
+ The evaluation script defaults to taking the average of 32 samples with temperature 0.6.
|
||||
+ We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance. When `enforce_eager=True`, we can reproduce the model performance reported in previous work. Otherwise, the evaluation results may fall below the reported performance. Therefore, we enforce `enforce_eager` to be enabled during evaluation.
|
||||
|
||||
Due to the above reasons, the evaluation process typically takes a considerable amount of time.
|
||||
|
||||
### Runtime
|
||||
|
||||
The runtime of the evaluation depends on factors such as the maximum generation length, the number of questions in the dataset, and the model size. On a machine with 8x H100 GPUs, evaluating `aime` and `math_500` takes approximately 80 minutes and 160 minutes, respectively.
|
||||
|
||||
# Troubleshooting
|
||||
|
||||
If the following content does not address your issue, feel free to raise a GitHub Issue.
|
||||
|
||||
|
||||
## Automatic Recover
|
||||
|
||||
When setting `recover_mode=auto` and the experiment config remains the same, AReaL will try to discover previous checkpoints and recover the experiment from it.
|
||||
|
||||
If the automatic recover fails, please check the following possibilities:
|
||||
|
||||
* The `experiment_name` and `trial_name` in the training script differ from the previous run.
|
||||
|
||||
* Changes in Batch Size (`dataset.train_bs_n_seqs` in the parameters), Group Size (`group_size` in the parameters), or the number of nodes (`n_nodes` in the parameters).
|
||||
|
||||
* No recover checkpoint was created in the previous run. By default, recover checkpoints are generated under two conditions:
|
||||
|
||||
* After the completion of the second Step.
|
||||
|
||||
* When a Step completes and more than 600 seconds have passed since the last recover checkpoint. This parameter is in the `./examples/configs/*/*.yaml`, named `exp_ctrl.ckpt_freq_secs=600`.
|
||||
|
||||
You can confirm if a recover checkpoint was generated by searching in the log:
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.760 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.105 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.264 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.411 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.883 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:59:44.925 master worker INFO: Dumped recover info to file.
|
||||
```
|
||||
|
||||
## Series of OutOfMemory Errors
|
||||
|
||||
While our scripts are designed to minimize OOM (Out of Memory) errors, they can still occasionally occur, especially due to memory fragmentation and increasing sequence lengths. Although these issues are often resolved by automatic restarts, users may require the following targeted solutions.
|
||||
|
||||
### torch.cuda.CudaOutOfMemoryError
|
||||
|
||||
The key to resolving this issue is identifying the phase in which the error occurs.
|
||||
|
||||
- **If it occurs during initialization (before `actor_gen`):**
|
||||
- Check if there are any idle processes on the GPU. In distributed scenarios, restart the Ray cluster. In single-machine scenarios, use `pkill`.
|
||||
- **This error typically does not occur during the `actor_gen` phase.**
|
||||
- **If it occurs during `ref_inf` or `actor_train`:**
|
||||
- Adjust the microbatch size for the corresponding computation task. For example, set `actor_train.mb_spec.max_tokens_per_mb=20480`. This parameter limits the number of tokens per forward/backward pass and can be set as low as the maximum sequence length (including the prompt).
|
||||
- Modify the parallelism strategy (`allocation_mode`) for the 7B model. Try reducing data parallelism and increasing tensor or pipeline parallelism.
|
||||
|
||||
### CUDA error: out of memory
|
||||
|
||||
This issue may occur during vLLM's initialization of the CPU KV cache, indicating insufficient memory on the machine. To resolve this, reduce the value of `actor.vllm.swap_space`.
|
||||
|
||||
### RuntimeError: Aborted due to the lack of CPU swap space.
|
||||
|
||||
This issue arises when the sequence length and KV cache demand exceed GPU memory, and the CPU swap space is insufficient. It is closely related to [Preemption errors](https://docs.vllm.ai/en/latest/performance/optimization.html). To resolve this, increase `actor.vllm.swap_space`. If the error persists, reduce `actor.vllm.max_num_seqs` and refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/performance/optimization.html).
|
||||
|
||||
### CUDA error: an illegal memory access was encountered
|
||||
|
||||
This error typically occurs during the vLLM generation phase and is another symptom of insufficient GPU memory. Solutions include:
|
||||
|
||||
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
|
||||
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).
|
||||
|
||||
|
||||
|
|
@ -1,423 +0,0 @@
|
|||
# 利用AReaL提升大语言模型推理能力:中文教程
|
||||
|
||||
# 前置要求
|
||||
|
||||
## 硬件要求
|
||||
|
||||
为了能正常完成训练流程,请参照下表确认你的硬件是否满足要求:
|
||||
|
||||
| **模型大小** | **1.5B** | **1.5B** |**1.5B** | **7B** |**7B** | **32B** |
|
||||
|---------------------|---|---|---|---------------------------|---|---|
|
||||
| 节点 | 1 | 4 | 16 | 4 | 16 | 16 |
|
||||
| GPU | 8 张 H800 | 每节点 8 张 H800 |每节点 8 张 H800 | 每节点 8 张 H800 |每节点 8 张 H800 |每节点 8 张 H800 |
|
||||
| CPU | 48 核 | 每节点 48 核 |每节点 48 核 | 每节点 48 核 |每节点 48 核 | 每节点 48 核 |
|
||||
| 内存 | 1 TB |每节点 1 TB|每节点 1 TB | 每节点 1 TB |每节点 1 TB | 每节点 1 TB |
|
||||
| 通信 | NVSwitch |NVSwitch+RoCE 带宽 3.2 Tbps|NVSwitch+RoCE 带宽 3.2 Tbps| NVSwitch+RoCE 带宽 3.2 Tbps |NVSwitch+RoCE 带宽 3.2 Tbps| NVSwitch+RoCE 带宽 3.2 Tbps|
|
||||
| 存储 | 1TB |共享存储(NAS)10TB |共享存储(NAS)10TB | 共享存储(NAS)10TB |共享存储(NAS)10TB | 共享存储(NAS)10TB |
|
||||
| BatchSize x GroupSize | 512x16 | 512x16 | 512x16 | 512x16 | 512x16 | 512x32|
|
||||
| 单步训练时间(秒) | **3461** | **997** | **391** | **2275** | **815** | **6707**|
|
||||
| 训练至收敛需要步数 | **~250** |**~250** |**~250** |**~400** |**~400** | - |
|
||||
| 总训练时间(小时) | **~240** | **~69** | **~27** | **~252** | **~90** | - |
|
||||
|
||||
关于硬件要求的说明:
|
||||
|
||||
- GPU 需要 80GB 显存,可以选择同级别其他 GPU 型号。
|
||||
|
||||
- 单节点训练时可以使用本地存储,但多节点训练必须要提供共享存储,否则无法进行训练。
|
||||
|
||||
- 目前32B模型没有训练出有意义的结果,所以无法估计训练到收敛需要的步数和时间。
|
||||
|
||||
|
||||
## 软件要求
|
||||
|
||||
本教程提供 Docker镜像。以下是经过测试的软件版本,可以参考如下软件版本进行配置。
|
||||
|
||||
||版本说明|
|
||||
|---|---|
|
||||
|OS|CentOS 7 / Ubuntu 22.04 或其他满足下方软件运行的系统|
|
||||
|NVIDIA Driver|版本:550.127.08|
|
||||
|CUDA|版本:12.5|
|
||||
|Git LFS|参考:[Git LFS 安装指南](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) 主要用于下载模型,数据集,AReaL 工程代码|
|
||||
|Docker|版本:27.5.1|
|
||||
|NVIDIA Container Toolkit|[NVIDIA Container Toolkit 安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)|
|
||||
|镜像|ghcr.io/inclusionai/areal-runtime:v0.2.0 这个镜像中包含运行依赖和 Ray 的相关组件|
|
||||
|
||||
|
||||
由于 NVIDIA Driver 和 CUDA 的安装以及共享存储的挂载与节点和系统版本有关,请自行完成安装,本教程不进行介绍。
|
||||
|
||||
如果是多节点训练,请先将共享存储挂载到每个节点的 `/storage` 目录上,后续下载的内容都将放在这个目录下,并且 AReaL 容器也会将该目录挂载到容器的 `/storage`,以便训练时访问。
|
||||
|
||||
|
||||
# 一键搭建环境并启动训练
|
||||
|
||||
本节提供一个一键安装脚本,自动完成节点的环境配置工作:
|
||||
1. 安装 Docker,Git LFS,NVIDIA Container Toolkit
|
||||
2. 在每个节点上拉取 AReaL 镜像
|
||||
3. 下载 AReaL 代码,模型,数据集
|
||||
4. 搭建 Ray 集群
|
||||
5. 【可选】在 Ray 集群中启动一个训练任务
|
||||
|
||||
请选择任意一个节点执行如下操作:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/codes
|
||||
cd /storage/codes/
|
||||
git clone https://github.com/inclusionAI/AReaL.git
|
||||
cd /storage/codes/AReaL
|
||||
|
||||
python ./examples/env/setup_env_and_start_train.py setup --private_key_file /path/to/ssh_key --ssh_port 22 --username root --hostnames NODE_IP_1 NODE_IP_2 NODE_IP_3 NODE_IP_4 --train_param 1.5B_n1
|
||||
```
|
||||
|
||||
`setup_env_and_start_train.py setup` 参数说明:
|
||||
|
||||
- `private_key_file`:SSH 私钥文件,用于连接节点
|
||||
- `ssh_port`:SSH 端口
|
||||
- `username`:SSH 用户名
|
||||
- `hostnames`:IP 列表,用空格分割。可以是 1/4/16 个节点 IP
|
||||
- `train_param`:【可选】训练参数,用于在完成环境搭建后直接启动一个训练任务。可选值为 `1.5B_n1`,`1.5B_n4`,`1.5B_n16`,`7B_n4`,`7B_n16`
|
||||
|
||||
如果因为环境差异,无法运行本节中的脚本或运行出现错误,也可以按照本教程后续章节的内容手动完成环境配置和启动训练。
|
||||
|
||||
# 环境配置
|
||||
|
||||
由于使用了共享存储,下载操作只需要在一个节点上完成。
|
||||
|
||||
## 代码
|
||||
将 AReaL 项目代码克隆到 `/storage/codes` 中:
|
||||
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/codes
|
||||
cd /storage/codes/
|
||||
git clone https://github.com/inclusionAI/AReaL.git
|
||||
```
|
||||
|
||||
## 数据集
|
||||
|
||||
我们提供了用于训练的数据集,请下载数据集并放置在 /storage/datasets/
|
||||
```bash
|
||||
mkdir -p /storage/datasets/
|
||||
cd /storage/datasets/
|
||||
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
|
||||
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/orz-zero_56k_0319.jsonl?download=true
|
||||
```
|
||||
|
||||
## 模型
|
||||
我们基于开源模型进行训练,该模型可以从 HuggingFace Hub 直接下载(请确保已经安装了 Git LFS):
|
||||
|
||||
```
|
||||
mkdir -p /storage/models
|
||||
cd /storage/models
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
|
||||
cd DeepSeek-R1-Distill-Qwen-7B
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
你也可以在安装 PyPI 和 huggingface_hub 后利用 huggingface CLI 进行下载,具体请参考[官方文档](https://huggingface.co/docs/huggingface_hub/guides/cli)
|
||||
|
||||
|
||||
## 启动 Ray 集群
|
||||
|
||||
在执行这一步之前,请先拉取 AReaL 环境镜像,这个镜像中已经包含了 Ray 相关的组件。
|
||||
|
||||
在第一个节点上执行如下命令启动 Ray Head:
|
||||
|
||||
```bash
|
||||
docker run -d --name r1-ray-head --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --head --port=6379 && tail -f /dev/null"
|
||||
```
|
||||
|
||||
在除了第一个节点以外的每个节点上执行如下命令启动 Ray Worker(如果只有一个节点,这一步就不用执行了):
|
||||
|
||||
```bash
|
||||
# RAY_HEAD_IP 是第一个节点的 IP
|
||||
RAY_HEAD_IP=xxx.xxx.xxx.xxx
|
||||
docker run -d --name r1-ray-worker --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "ray start --address=$RAY_HEAD_IP:6379 && tail -f /dev/null"
|
||||
```
|
||||
|
||||
全部启动完成后,在第一个节点上通过 docker exec 进入容器,查看 Ray 集群的状态:
|
||||
|
||||
```bash
|
||||
docker exec -it r1-ray-head bash
|
||||
ray status
|
||||
```
|
||||
|
||||
可以看到 Ray 的资源情况,输出如下(这是一个 16 节点 128 卡的集群,根据你的节点数量,这里的输出会有所不同):
|
||||
|
||||
```
|
||||
======== Autoscaler status: 2025-02-22 14:08:51.061250 ========
|
||||
Node status
|
||||
---------------------------------------------------------------
|
||||
Active:
|
||||
1 node_d5634ae61bfe6732d957811bed65c8a39f13ece07e0326f941acbc4e
|
||||
1 node_23b0c08045c9a39bc4c454cae298ee531d9a474215ac5e77a5b01e74
|
||||
1 node_bc1016320658e92645f29cecb8aaf51c0b7e01a44e8ac9c814dfee59
|
||||
1 node_4e7d15e9cee9ee0da5d65e45f1e346228c52bc0c557511c6eeab40dc
|
||||
1 node_c5bcf15e28a00515be5d2a7e8e33d71f0f57cdfaf1003db9e0c74788
|
||||
1 node_ec3f6ee8f6fdf3a5392bb4dac244668da75d094e084dcbb520ce2525
|
||||
1 node_dc2f1eef88126ae4ac7902574714af9ab74b78ba037217e73e063639
|
||||
1 node_a4728608c1fda187dc33bb24e831c42fe5c8a582ad428b6e595933bc
|
||||
1 node_970379a3ba750ee3b13e31612b6a6b758d50bd4943555b2a13d1bd61
|
||||
1 node_bf6b658bea9e437fcb642a2d881425662a689d668c92fe1545899b36
|
||||
1 node_2c69511f410d9360f1d05893fde2c97dd32240e0315afea9b2d286a3
|
||||
1 node_e4c90c17cc48ad469d123041d3302dcff1f7a82a4805279300812b19
|
||||
1 node_3f772cbffb206c30b6ccedade83789d78397804bab874ee59563cb96
|
||||
1 node_429bd5115b5590b612590bb455f2d3ed4f77055d746a184baf807655
|
||||
1 node_75071820f2c16dc51fa271316b72cd45335ec877c06450d292ab7d54
|
||||
1 node_6f4323f9038248d82b91321e2c4ca5fa99e65efa2d976c0b896a8964
|
||||
Pending:
|
||||
(no pending nodes)
|
||||
Recent failures:
|
||||
(no failures)
|
||||
|
||||
Resources
|
||||
---------------------------------------------------------------
|
||||
Usage:
|
||||
0.0/2128.0 CPU
|
||||
0.0/128.0 GPU
|
||||
0B/21.08TiB memory
|
||||
0B/2.91TiB object_store_memory
|
||||
|
||||
Demands:
|
||||
(no resource demands)
|
||||
```
|
||||
|
||||
# RL训练
|
||||
|
||||
在进行分布式训练之前,请确保已经启动了 Ray 集群,并且集群状态正常。
|
||||
然后在第一个节点(Ray Head 所在节点),进入容器:
|
||||
|
||||
```
|
||||
docker exec -it r1-ray-head bash
|
||||
cd /storage/codes/AReaL
|
||||
```
|
||||
|
||||
选择匹配硬件环境的一个配置运行即可:
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart ppo-math --config ./examples/configs/7B-distill/ppo-7B-distill-gpus-128.yaml
|
||||
```
|
||||
|
||||
启动后,在终端可以看到启动日志:
|
||||
```
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ Setting PPOMATHConfig with the Following Values │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
|
||||
───────────────────────── Current Configuration Begin ──────────────────────────
|
||||
actor (ModelTrainEvalConfig)
|
||||
actor.type (ModelFamily)
|
||||
actor.type._class (str) - qwen2
|
||||
actor.type.size (int) - 7
|
||||
actor.type.is_critic (bool) - False
|
||||
...
|
||||
────────────────────────── Current Configuration End ───────────────────────────
|
||||
|
||||
20250222-10:26:34.877 quickstart INFO: Running ppo-math experiment.
|
||||
20250222-10:44:15.581 quickstart INFO: Logs will be dumped to /storage/ray/experiments/logs/root/ppo-7B-distill-gpus-128/512x16
|
||||
20250222-10:44:15.581 quickstart INFO: Model checkpoints will be saved to /storage/ray/experiments/checkpoints/root/ppo-7B-distill-gpus-128/512x16
|
||||
20250222-10:26:36.408 quickstart INFO: Launching experiments with RAY...
|
||||
```
|
||||
|
||||
如果运行过程中出现错误(比如出现 Error 关键字),请参考Troubleshooting解决。
|
||||
|
||||
## Commandline Options
|
||||
|
||||
```bash
|
||||
python3 -m realhf.apps.quickstart ppo-math --help
|
||||
```
|
||||
|
||||
其中重要的参数的说明如下:
|
||||
|
||||
+ mode:总是为 ray,参考本教程进行训练时不要改成其他值。
|
||||
+ {actor|critic|ref}.path:模型的路径
|
||||
+ dataset.path:数据集 jsonl 文件的路径
|
||||
+ external_configs.cluster_config:设置 cluster_config 的配置,比如 fileroot 是存放训练输出的根目录。
|
||||
|
||||
+ n_nodes:节点数量
|
||||
+ n_gpus_per_node:每个节点的 GPU 数量
|
||||
+ allocation_mode:实验中模型的 GPU 分配和 3D 并行策略,推荐的策略有以下形式:
|
||||
+ `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: 分别配置 SGLang 生成和训练的并行策略,生成和训练分离,使用两部分不同的 GPU。二者所用的GPU数量相加要等于总的 GPU 数量,即 DP1xTP1xPP1+DP2xTP2xPP2=#GPUs。
|
||||
|
||||
+ exp_ctrl.total_train_epochs:训练的 epoch 数量(即迭代整个数据集的次数)
|
||||
+ exp_ctrl.save_freq_{epochs|steps|secs}:保存持久化存储模型参数的频率,如果设成 null 会不保存模型
|
||||
+ exp_ctrl.ckpt_freq_{epochs|steps|secs}:保存临时参数用于重启的频率
|
||||
+ dataset.train_bs_n_seqs:训练的批量大小,即每次训练需要采样的 prompt 数量
|
||||
+ group_size:每个 prompt 需要采样的答案数量
|
||||
+ {actor_train|ref_inf}.mb_spec.max_tokens_per_mb:reference模型推理和actor模型训练每次forward/backward数据中最大的token数量,可以减小以避免OOM错误。这些数据会累积梯度进行一次参数更新。
|
||||
+ ppo.ppo_n_minibatches:每次PPO更新中会把所有数据划分成多少份以此进行loss计算和参数更新。
|
||||
+ ppo.gen.max_new_tokens:每条prompt生成的最大token数,默认训练脚本中为16k。
|
||||
+ ppo.gen.min_new_tokens:每条prompt生成的最小token数,默认为0。
|
||||
|
||||
|
||||
## 过程观测
|
||||
这里以 16 节点的运行日志为例(1 节点和 4 节点也一样),说明几个观察训练进度和效果的方法。
|
||||
|
||||
### 查看训练进度
|
||||
|
||||
搜索日志中的 Epoch 关键字,查看总的 Epoch 数量和 Step 数量:
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.997 master worker INFO: Epoch 1/1 step 1/19 (global step 1) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2124.429*s. Total time consumption: 2283.862s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.719 master worker INFO: Epoch 1/1 step 2/19 (global step 2) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2405.716*s. Total time consumption: 4689.584s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.084 master worker INFO: Epoch 1/1 step 3/19 (global step 3) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2122.318*s. Total time consumption: 6811.949s. Estimated remaining time: 33957.093s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.246 master worker INFO: Epoch 1/1 step 4/19 (global step 4) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2313.134*s. Total time consumption: 9125.111s. Estimated remaining time: 33265.891s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.349 master worker INFO: Epoch 1/1 step 5/19 (global step 5) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2296.076*s. Total time consumption: 11421.214s. Estimated remaining time: 31413.800s.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.864 master worker INFO: Epoch 1/1 step 6/19 (global step 6) finishes. Average #tokens per batch is 111847. #End to end# execution time: *2299.448*s. Total time consumption: 13720.729s. Estimated remaining time: 29350.673s.
|
||||
```
|
||||
|
||||
出现了 6 条日志信息,以最后一条信息的内容说明各个字段的含义:
|
||||
+ `Epoch 1/1`:表示总共需要训练 1 个 Epochs,当前在训练第 1 个。这里作为例子总共只训练 1 个 Epoch,正常训练应该是 10 个 Epochs 或者更多。
|
||||
+ `step 6/19`:表示当前 Epoch 有 19 个 Steps,当前在训练第 6 个
|
||||
+ `global step 6`: 表示当前 Step 在所有 Epochs 的 Steps 里的序号
|
||||
+ `#End to end# execution time: *2299.448*s`:表示当前 Step 训练耗费了 2299.448 秒
|
||||
+ `Total time consumption: 13720.729s`:从训练启动开始一共耗费了 13720.729 秒
|
||||
+ `Estimated remaining time: 29350.673s`:预计完成训练还需要 29350.673 秒
|
||||
|
||||
|
||||
### 查看训练的效果
|
||||
|
||||
搜索日志中的 `task_reward` 关键字
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:11:56.991 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.2640759198111482e-05, 'actor_loss': 1.1128166761409375e-06, 'actor_clip_ratio': 2.1122002635820536e-07, 'importance_weight': 1.0000014305114746, 'task_reward': -0.2996826171875, 'kl_reward': -2.27004832709099e-07, 'final_reward': -0.30145370960235596, 'advantage': 0.003593671601265669, 'avg_seq_len': 7907.8955078125, 'avg_prompt_len': 105.845703125, 'n_tokens': 127828786.0, 'n_valid_tokens': 127828786.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.122802734375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.712 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.493159263394773e-05, 'actor_loss': -3.846728588996484e-07, 'actor_clip_ratio': 3.16789424914532e-07, 'importance_weight': 0.9999996423721313, 'task_reward': -0.6793212890625, 'kl_reward': -2.536311853873485e-07, 'final_reward': -0.6813737154006958, 'advantage': 0.004844569601118565, 'avg_seq_len': 8203.9453125, 'avg_prompt_len': 111.892578125, 'n_tokens': 132580185.0, 'n_valid_tokens': 132580185.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.13812255859375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.077 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.572356243035756e-05, 'actor_loss': -5.036404786551429e-07, 'actor_clip_ratio': 1.8960582792715286e-07, 'importance_weight': 0.9999992251396179, 'task_reward': -0.6280517578125, 'kl_reward': -2.988609537624143e-07, 'final_reward': -0.6303607225418091, 'advantage': 0.004505862481892109, 'avg_seq_len': 7834.6328125, 'avg_prompt_len': 108.900390625, 'n_tokens': 126578395.0, 'n_valid_tokens': 126578395.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.11761474609375, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.239 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4861981728463434e-05, 'actor_loss': 1.3935685672095133e-07, 'actor_clip_ratio': 3.02603467616791e-07, 'importance_weight': 0.9999998807907104, 'task_reward': -0.78857421875, 'kl_reward': -3.672174671009998e-07, 'final_reward': -0.791388750076294, 'advantage': 0.005053278990089893, 'avg_seq_len': 7773.39404296875, 'avg_prompt_len': 108.7890625, 'n_tokens': 125576883.0, 'n_valid_tokens': 125576883.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.117919921875, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.342 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.516058702894952e-05, 'actor_loss': -7.665488510610885e-07, 'actor_clip_ratio': 1.9505058901359007e-07, 'importance_weight': 0.9999997615814209, 'task_reward': -0.6158447265625, 'kl_reward': -4.6867208425283025e-07, 'final_reward': -0.6195111274719238, 'advantage': 0.004475570283830166, 'avg_seq_len': 7928.50830078125, 'avg_prompt_len': 105.517578125, 'n_tokens': 128171874.0, 'n_valid_tokens': 128171874.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12353515625, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.857 master worker INFO: RPC name actor_train returns {'ppo_approx_kl': -2.4821250917739235e-05, 'actor_loss': -3.922649227661168e-07, 'actor_clip_ratio': 3.323623900541861e-07, 'importance_weight': 1.0000001192092896, 'task_reward': -0.7025146484375, 'kl_reward': -5.863367960046162e-07, 'final_reward': -0.7071446776390076, 'advantage': 0.004277692176401615, 'avg_seq_len': 8002.4873046875, 'avg_prompt_len': 105.951171875, 'n_tokens': 129376851.0, 'n_valid_tokens': 129376851.0, 'n_seqs': 16384.0, 'no_eos_ratio': 0.12286376953125, 'disable_value': 1.0, 'mask_no_eos_with_zero': 0.0}
|
||||
```
|
||||
|
||||
以最后一条说明其中几个重点字段的含义:
|
||||
+ `task_reward`:这个step中采样的所有答案的平均奖励值,训练稳步进行的话这个值会持续上升,最终维持不变
|
||||
+ `importance_weight`: PPO loss中重要性采样比率在所有token上的平均值,通常接近1。
|
||||
+ `actor_clip_ratio`: PPO loss中被clip掉的token占所有token的比率,通常小于0.1。
|
||||
+ `actor_loss`: PPO loss,**不会随着训练过程有明显的上升或下降趋势**,不应作为模型表现的参考。
|
||||
+ `avg_seq_len`: 这一步中采样的所有序列(即提示词和答案相加)的平均长度。在完整的多阶段训练中,这个值会先下降再上升。
|
||||
+ `no_eos_ratio`: 这一步中采样的所有答案因为超出最大生成长度被截断的比率。这个值上升也代表了答案的平均长度在上升。
|
||||
|
||||
# 评估
|
||||
|
||||
## 评估流程
|
||||
|
||||
评估代码包含在仓库的`evaluation`文件夹中。按照以上的教程,训练得到的checkpoint会保存在`/storage/ray/experiments/checkpoints/root/`路径下,例如`/storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/`。
|
||||
|
||||
启动一个新的容器用于运行评估脚本(评估需要更新部分 python 库,请不要在训练容器中进行):
|
||||
```
|
||||
docker run -d --name r1-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.2.0 /bin/bash -c "tail -f /dev/null"
|
||||
docker exec -it r1-eval bash
|
||||
```
|
||||
|
||||
在docker容器内部运行以下脚本进行评估:
|
||||
|
||||
```bash
|
||||
cd /storage/codes/AReaL/evaluation
|
||||
cd latex2sympy
|
||||
pip install -e .
|
||||
cd ..
|
||||
pip install -r requirements.txt
|
||||
pip install vllm --no-build-isolation
|
||||
pip install transformers==4.47.0
|
||||
pip install prettytable timeout_decorator
|
||||
mkdir /storage/ray/eval_output/
|
||||
nohup python eval_and_aggregate.py \
|
||||
--model_path /storage/ray/experiments/checkpoints/root/ppo-zero-distill-7B-n16/1024x16-n16/actor/epoch1epochstep20globalstep20/ \
|
||||
--output_path /storage/ray/eval_output/ \
|
||||
--data_names "math_500,aime24,amc23" \
|
||||
--max_gen_tokens 32768 &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
|
||||
```
|
||||
|
||||
+ `--model_path`:模型参数的保存路径
|
||||
+ `--output_path`:评估过程中生成的答案和日志文件路径
|
||||
+ `--data_names`: 可以指定评测某个数据,多个数据集用逗号隔开,默认为 math_500, aime24, amc23
|
||||
+ `--max_gen_tokens`:最长的答案生成长度,默认值 32768
|
||||
|
||||
## 评估结果
|
||||
|
||||
评估脚本运行完后会在 /storage/ray/eval_output/eval_and_aggregate_parallel.log 日志文件输出一个表格,例如:
|
||||
|
||||
```
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
| dataset | num_questions | greedy_length | sample_length | greedy_acc | sample_pass@1 | pass@8 | pass@16 |
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
| math_500 | 500 | 6757.4 | 4139.5 | 84.4 | 92.7 | 97.3 | 97.7 |
|
||||
| aime24 | 30 | 19328.0 | 13663.5 | 50.0 | 50.4 | 77.3 | 80.0 |
|
||||
| amc23 | 40 | 8850.0 | 6526.2 | 80.0 | 90.5 | 96.8 | 98.8 |
|
||||
+----------+---------------+---------------+---------------+------------+---------------+--------+---------+
|
||||
```
|
||||
|
||||
+ `{greedy|sample}_length`: 在greedy或随机采样策略下生成的平均答案长度
|
||||
+ `greedy_acc`:在greedy采样下的平均准确率
|
||||
+ `sample_pass@{k}`:在随机采样下平均每k个答案产生正确答案的概率
|
||||
|
||||
## 额外说明
|
||||
|
||||
### 关键参数
|
||||
|
||||
+ 我们提供的评估脚本默认采样32次取平均值,采样温度值为0.6
|
||||
+ 我们发现vLLM的`enforce_eager`参数很大程度影响评估性能,当`enforce_eager=True`时我们才能够复现先前工作汇报的模型表现,否则评估结果会低于先前工作汇报的结果,因此我们会在执行 `eval_and_aggregate_parallel.py` 时将`enforce_eager`强制开启。
|
||||
|
||||
由于以上原因,评估过程通常会消耗较长时间。
|
||||
|
||||
### 运行时间
|
||||
评估的运行时间取决于最长生成长度、数据集的题目数量和模型大小等等。在1台8*H100机器上,7B模型,数据集为`math_500,aime24,amc23`,生成长度为32768,评估脚本运行时间为 5 个小时。
|
||||
|
||||
|
||||
# Troubleshooting
|
||||
|
||||
如果以下内容没有解答你的问题,欢迎在 GitHub Issue 中进行提问。
|
||||
|
||||
## 自动恢复
|
||||
|
||||
当设置了 `recover_mode=auto` 并且训练配置和之前相同,AReaL 会尝试找到之前生成的 checkpoints 并且从这个 checkpoints 恢复训练。
|
||||
|
||||
如果自动恢复失败,有这些可能性:
|
||||
|
||||
+ 训练配置里的 `experiment_name` 和 `trial_name` 与之前的不一样
|
||||
+ Batch Size(参数里的 `dataset.train_bs_n_seqs`),Group Size(参数里的 `group_size`),节点数(参数里的 `n_nodes`)三个值发生了变化
|
||||
+ 之前的训练没有创建过 recover checkpoint 。默认的 recover checkpoint 规则有 2 个:
|
||||
+ 从第 2 个 step 完成后才生成 recover checkpoint
|
||||
+ 一个 step 训练完成,且距离上次 recover checkpoint 时间超过 600s,则生成一个新的 recover checkpoint。这个参数在 `./examples/configs/*/*.yaml` 文件里,参数名为 :`exp_ctrl.ckpt_freq_secs=600`。
|
||||
|
||||
|
||||
可以通过搜索 `Dumped recover` 确认是否生成过 recover checkpoint
|
||||
|
||||
```bash
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-11:52:02.760 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-12:27:25.105 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:05:58.264 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-13:44:14.411 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:22:33.883 master worker INFO: Dumped recover info to file.
|
||||
(master_worker/0 pid=96390, ip=xxx.xxx.xxx.xxx) 20250222-14:59:44.925 master worker INFO: Dumped recover info to file.
|
||||
```
|
||||
|
||||
## 一系列OutOfMemory错误
|
||||
|
||||
我们提供的脚本已经尽最大努力避免了OOM错误的发生,但是OOM问题仍然会随着训练进行,在内存碎片增加和生成序列长度越来越长时偶尔发生。虽然这些问题通常可以通过自动重启解决,当重启频繁时,用户还可以尝试以下针对性的解决方式。
|
||||
|
||||
### torch.cuda.CudaOutOfMemoryError
|
||||
|
||||
解决这个问题的关键是定位错误发生的阶段。
|
||||
|
||||
- 如果发生在初始化阶段(在进入到actor_gen之前):
|
||||
- 检查当前GPU上是否存在残留进程。在分布式场景下,可以通过重启ray cluster解决;在单机场景下,可以通过pkill解决。
|
||||
- 该错误通常不会发生在actor_gen阶段。
|
||||
- 如果发生在ref_inf或actor_train阶段
|
||||
- 改变相应计算任务的microbatch大小,例如`actor_train.mb_spec.max_tokens_per_mb=20480`,这个参数代表每次模型forward/backward的数据最多只会包含20480个token,这个值最小可以设为生成序列的最长长度(包括prompt)
|
||||
- 改变模型的并行策略,即`allocation_mode`,可以尝试减少数据并行的大小,增加张量或流水线并行的大小。
|
||||
|
||||
### CUDA error: out of memory
|
||||
|
||||
这个问题可能会发生在vLLM初始化CPU KV cache时,表示每台机器的内存不够了。可以减小`actor.vllm.swap_space`解决。
|
||||
|
||||
### RuntimeError: Aborted due to the lack of CPU swap space.
|
||||
|
||||
问题的原因是序列长、对KV cache需求大,在GPU显存不够时KV cache会被卸载到内存,而内存中设置的swap space不够。这个问题和[Preemption的报错](https://docs.vllm.ai/en/latest/performance/optimization.html)紧密相关。解决方案是增加`actor.vllm.swap_space`,如果同样的错误出现,请减少`actor.vllm.max_num_seqs`并参考[vLLM官方文档](https://docs.vllm.ai/en/latest/performance/optimization.html)。
|
||||
|
||||
### CUDA error: an illegal memory access was encountered
|
||||
|
||||
通常会在vLLM生成阶段出现,同样是显存不足的一种表现。解决方案包括:
|
||||
|
||||
|
||||
+ 减小训练batch size或者每个prompt生成的答案数量,但减小后会降低样本效率、延长训练时间
|
||||
+ [将vLLM的attention backend换成xformers](https://github.com/vllm-project/vllm/issues/5376)
|
||||
|
|
@ -1,3 +1,501 @@
|
|||
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
|
||||
index 60091b9a..7a1c856b 100644
|
||||
--- a/python/sglang/srt/layers/logits_processor.py
|
||||
+++ b/python/sglang/srt/layers/logits_processor.py
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
+import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -44,6 +45,15 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
)
|
||||
from sglang.srt.utils import dump_to_file
|
||||
|
||||
+# When compute the input and output tokens logprobs, if the rows of the
|
||||
+# logprobs are too large, the peak memory usage will be too high. For example,
|
||||
+# if the logprobs are [10000, 150000], the peak memory usage will be greater
|
||||
+# than 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB. (4 is the size of float)
|
||||
+# So we split the logprobs into multiple chunks.
|
||||
+LOGITS_PROCESSER_CHUNK_SIZE = int(
|
||||
+ os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048")
|
||||
+)
|
||||
+
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -286,15 +296,17 @@ class LogitsProcessor(nn.Module):
|
||||
input_logprob_indices = None
|
||||
else:
|
||||
# Input logprobs are required.
|
||||
- # Find 3 different indices.
|
||||
+ # Find 4 different indices.
|
||||
# 1. pruned_states: hidden states that we want logprobs from.
|
||||
# 2. sample_indices: Indices that have sampled tokens.
|
||||
# 3. input_logprob_indices: Indices that have input logprob tokens.
|
||||
+ # 4. sequence_index_mapping: map pruned_states indices to top_logprobs_nums and token_ids_logprobs indices.
|
||||
sample_index_pt = -1
|
||||
sample_indices = []
|
||||
input_logprob_indices_pt = 0
|
||||
input_logprob_indices = []
|
||||
pt, pruned_states = 0, []
|
||||
+ idx, sequence_index_mapping = 0, []
|
||||
for extend_logprob_start_len, extend_len in zip(
|
||||
logits_metadata.extend_logprob_start_lens_cpu,
|
||||
logits_metadata.extend_seq_lens_cpu,
|
||||
@@ -310,7 +322,11 @@ class LogitsProcessor(nn.Module):
|
||||
# by a caller.
|
||||
assert extend_len > start_len
|
||||
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||
+ # sequence_index_mapping, repeat this for loop index
|
||||
+ sequence_index_mapping.extend([idx] * (extend_len - start_len))
|
||||
+ idx += 1
|
||||
pt += extend_len
|
||||
+
|
||||
sample_index_pt += extend_len - start_len
|
||||
sample_indices.append(sample_index_pt)
|
||||
input_logprob_indices.extend(
|
||||
@@ -321,6 +337,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
input_logprob_indices_pt += extend_len - start_len
|
||||
|
||||
+ sequence_index_mapping.append(idx - 1)
|
||||
pruned_states = torch.cat(pruned_states)
|
||||
sample_indices = torch.tensor(
|
||||
sample_indices, device=pruned_states.device, dtype=torch.int64
|
||||
@@ -329,12 +346,6 @@ class LogitsProcessor(nn.Module):
|
||||
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
|
||||
)
|
||||
|
||||
- # Compute logits for both input and sampled tokens.
|
||||
- logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||
- sampled_logits = (
|
||||
- logits[sample_indices] if sample_indices is not None else logits
|
||||
- )
|
||||
-
|
||||
if self.debug_tensor_dump_output_folder:
|
||||
assert (
|
||||
not self.do_tensor_parallel_all_gather
|
||||
@@ -370,67 +381,176 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
assert False, "Should never reach"
|
||||
|
||||
+ del hidden_states
|
||||
+
|
||||
if not logits_metadata.extend_return_logprob:
|
||||
+ # Compute logits for both input and sampled tokens.
|
||||
+ logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||
+ sampled_logits = (
|
||||
+ logits[sample_indices] if sample_indices is not None else logits
|
||||
+ )
|
||||
+
|
||||
# Decode mode or extend mode without return_logprob.
|
||||
return LogitsProcessorOutput(
|
||||
next_token_logits=sampled_logits,
|
||||
hidden_states=hidden_states_to_store,
|
||||
)
|
||||
else:
|
||||
- input_logprobs = logits[input_logprob_indices]
|
||||
- del hidden_states, logits
|
||||
+ # Compute logprobs requires lot of memory, so we split pruned_states
|
||||
+ # into chunks of rows to compute input_logprobs separately, then
|
||||
+ # concatenate the results.
|
||||
+ return self._compute_output_by_chunk(
|
||||
+ pruned_states,
|
||||
+ sample_indices,
|
||||
+ hidden_states_to_store,
|
||||
+ input_logprob_indices,
|
||||
+ sequence_index_mapping,
|
||||
+ lm_head,
|
||||
+ logits_metadata,
|
||||
+ )
|
||||
|
||||
- # Normalize the logprob w/o temperature, top-p
|
||||
- pruned_lens = torch.tensor(
|
||||
- logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
- device=input_logprobs.device,
|
||||
+ def _compute_output_by_chunk(
|
||||
+ self,
|
||||
+ pruned_states: torch.Tensor,
|
||||
+ sample_indices: torch.Tensor,
|
||||
+ hidden_states_to_store: Optional[torch.Tensor],
|
||||
+ input_logprob_indices: torch.Tensor,
|
||||
+ index_mapping: list[int],
|
||||
+ lm_head: VocabParallelEmbedding,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
+ ) -> LogitsProcessorOutput:
|
||||
+ """
|
||||
+ compute logprobs for the output token from the hidden states.
|
||||
+ To avoid using too much memory, we split pruned_states into chunks of
|
||||
+ rows to compute input_logprobs separately, then concatenate the results.
|
||||
+
|
||||
+ Returns:
|
||||
+ LogitsProcessorOutput: logits processor output class
|
||||
+ """
|
||||
+
|
||||
+ # Normalize the logprob w/o temperature, top-p
|
||||
+ pruned_lens = torch.tensor(
|
||||
+ logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
+ device=pruned_states.device,
|
||||
+ )
|
||||
+ if logits_metadata.temp_scaled_logprobs:
|
||||
+ logits_metadata.temperature = torch.repeat_interleave(
|
||||
+ logits_metadata.temperature.view(-1),
|
||||
+ pruned_lens,
|
||||
+ ).view(-1, 1)
|
||||
+ if logits_metadata.top_p_normalized_logprobs:
|
||||
+ logits_metadata.top_p = torch.repeat_interleave(
|
||||
+ logits_metadata.top_p,
|
||||
+ pruned_lens,
|
||||
)
|
||||
- if logits_metadata.temp_scaled_logprobs:
|
||||
- logits_metadata.temperature = torch.repeat_interleave(
|
||||
- logits_metadata.temperature.view(-1),
|
||||
- pruned_lens,
|
||||
- ).view(-1, 1)
|
||||
- if logits_metadata.top_p_normalized_logprobs:
|
||||
- logits_metadata.top_p = torch.repeat_interleave(
|
||||
- logits_metadata.top_p,
|
||||
- pruned_lens,
|
||||
- )
|
||||
- input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||
- input_logprobs, logits_metadata
|
||||
+
|
||||
+ # The peak memory usage is proportional to the chunk size.
|
||||
+ chunk_size = LOGITS_PROCESSER_CHUNK_SIZE
|
||||
+ num_chunks = (pruned_states.shape[0] + chunk_size - 1) // chunk_size
|
||||
+
|
||||
+ input_token_logprobs = []
|
||||
+ if logits_metadata.extend_return_top_logprob:
|
||||
+ input_top_logprobs_val = []
|
||||
+ input_top_logprobs_idx = []
|
||||
+ else:
|
||||
+ input_top_logprobs_val = None
|
||||
+ input_top_logprobs_idx = None
|
||||
+
|
||||
+ if logits_metadata.extend_token_ids_logprob:
|
||||
+ input_token_ids_logprobs_val = []
|
||||
+ input_token_ids_logprobs_idx = []
|
||||
+ else:
|
||||
+ input_token_ids_logprobs_val = None
|
||||
+ input_token_ids_logprobs_idx = None
|
||||
+
|
||||
+ # It a single sequence is split into multiple chunks, we need to keep track
|
||||
+ # of the pruned length of the sequences in the previous chunks.
|
||||
+ split_len_topk = 0
|
||||
+ split_len_token_ids = 0
|
||||
+
|
||||
+ for i in range(num_chunks):
|
||||
+ start_idx = i * chunk_size
|
||||
+ end_idx = min((i + 1) * chunk_size, pruned_states.shape[0])
|
||||
+
|
||||
+ # Get indices for this chunk
|
||||
+ chunk_mask = (input_logprob_indices >= start_idx) & (
|
||||
+ input_logprob_indices < end_idx
|
||||
+ )
|
||||
+
|
||||
+ global_indices = input_logprob_indices[chunk_mask]
|
||||
+ chunk_indices = global_indices - start_idx
|
||||
+ chunk_states = pruned_states[start_idx:end_idx]
|
||||
+ chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)
|
||||
+
|
||||
+ if chunk_indices.numel() == 0:
|
||||
+ continue
|
||||
+
|
||||
+ # Compute the logprobs of the chunk
|
||||
+ chunk_input_logprobs = chunk_logits[chunk_indices]
|
||||
+ chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||
+ chunk_input_logprobs, global_indices, logits_metadata
|
||||
)
|
||||
|
||||
+ # For each chunk, we need to get the slice of the sequence_index_mapping
|
||||
+ chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1)
|
||||
+
|
||||
# Get the logprob of top-k tokens
|
||||
if logits_metadata.extend_return_top_logprob:
|
||||
- (
|
||||
+ split_len_topk = self.get_top_logprobs(
|
||||
+ chunk_input_logprobs,
|
||||
+ logits_metadata,
|
||||
+ chunk_slice,
|
||||
input_top_logprobs_val,
|
||||
input_top_logprobs_idx,
|
||||
- ) = self.get_top_logprobs(input_logprobs, logits_metadata)
|
||||
- else:
|
||||
- input_top_logprobs_val = input_top_logprobs_idx = None
|
||||
+ split_len_topk,
|
||||
+ )
|
||||
|
||||
# Get the logprob of given token id
|
||||
if logits_metadata.extend_token_ids_logprob:
|
||||
- (
|
||||
+ split_len_token_ids = self.get_token_ids_logprobs(
|
||||
+ chunk_input_logprobs,
|
||||
+ logits_metadata,
|
||||
+ chunk_slice,
|
||||
input_token_ids_logprobs_val,
|
||||
input_token_ids_logprobs_idx,
|
||||
- ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
|
||||
- else:
|
||||
- input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
|
||||
-
|
||||
- input_token_logprobs = input_logprobs[
|
||||
- torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
|
||||
- logits_metadata.extend_input_logprob_token_ids_gpu,
|
||||
- ]
|
||||
+ split_len_token_ids,
|
||||
+ )
|
||||
|
||||
- return LogitsProcessorOutput(
|
||||
- next_token_logits=sampled_logits,
|
||||
- input_token_logprobs=input_token_logprobs,
|
||||
- input_top_logprobs_val=input_top_logprobs_val,
|
||||
- input_top_logprobs_idx=input_top_logprobs_idx,
|
||||
- hidden_states=hidden_states_to_store,
|
||||
- input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
||||
- input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
||||
+ # Handle sampled logits for the chunk if needed
|
||||
+ chunk_sample_mask = (sample_indices >= start_idx) & (
|
||||
+ sample_indices < end_idx
|
||||
)
|
||||
+ if i == 0: # Initialize sampled_logits on first chunk
|
||||
+ sampled_logits = torch.empty(
|
||||
+ (sample_indices.shape[0], chunk_logits.shape[1]),
|
||||
+ dtype=chunk_logits.dtype,
|
||||
+ device=chunk_logits.device,
|
||||
+ )
|
||||
+ if chunk_sample_mask.any():
|
||||
+ chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
|
||||
+ sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]
|
||||
+
|
||||
+ # Get the logprob of the requested token ids
|
||||
+ chunk_input_token_logprobs = chunk_input_logprobs[
|
||||
+ torch.arange(
|
||||
+ chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
|
||||
+ ),
|
||||
+ logits_metadata.extend_input_logprob_token_ids_gpu[start_idx:end_idx],
|
||||
+ ]
|
||||
+ input_token_logprobs.append(chunk_input_token_logprobs)
|
||||
+
|
||||
+ # Concatenate the results
|
||||
+ input_token_logprobs = torch.cat(input_token_logprobs, dim=0)
|
||||
+
|
||||
+ return LogitsProcessorOutput(
|
||||
+ hidden_states=hidden_states_to_store,
|
||||
+ next_token_logits=sampled_logits,
|
||||
+ input_token_logprobs=input_token_logprobs,
|
||||
+ input_top_logprobs_val=input_top_logprobs_val,
|
||||
+ input_top_logprobs_idx=input_top_logprobs_idx,
|
||||
+ input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
||||
+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
||||
+ )
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
@@ -498,60 +618,142 @@ class LogitsProcessor(nn.Module):
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
- def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||
+ def get_top_logprobs(
|
||||
+ logprobs: torch.Tensor,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
+ chunk_slice: slice,
|
||||
+ input_top_logprobs_val: List,
|
||||
+ input_top_logprobs_idx: List,
|
||||
+ split_pruned_len: int,
|
||||
+ ):
|
||||
+ """Get top-k logprobs for each sequence in the chunk.
|
||||
+
|
||||
+ Args:
|
||||
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
|
||||
+ logits_metadata: Metadata containing top-k and pruned length info
|
||||
+ chunk_slice: Slice of sequences to process
|
||||
+ input_top_logprobs_val: List to store top-k logprob values
|
||||
+ input_top_logprobs_idx: List to store top-k token indices
|
||||
+ split_pruned_len: Length of pruned tokens from previous chunk
|
||||
+
|
||||
+ Returns:
|
||||
+ int: Number of remaining tokens to process in next chunk
|
||||
+ """
|
||||
+
|
||||
max_k = max(logits_metadata.top_logprobs_nums)
|
||||
- ret = all_logprobs.topk(max_k, dim=1)
|
||||
+ ret = logprobs.topk(max_k, dim=1)
|
||||
values = ret.values.tolist()
|
||||
indices = ret.indices.tolist()
|
||||
|
||||
- input_top_logprobs_val, input_top_logprobs_idx = [], []
|
||||
-
|
||||
pt = 0
|
||||
- for k, pruned_len in zip(
|
||||
- logits_metadata.top_logprobs_nums,
|
||||
- logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
- ):
|
||||
+ next_split_pruned_len = 0
|
||||
+ top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
|
||||
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
|
||||
+
|
||||
+ for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
|
||||
+ # Adjust pruned length for first sequence
|
||||
+ if n == 0:
|
||||
+ pruned_len -= split_pruned_len
|
||||
+ else:
|
||||
+ split_pruned_len = 0
|
||||
+
|
||||
if pruned_len <= 0:
|
||||
- input_top_logprobs_val.append([])
|
||||
- input_top_logprobs_idx.append([])
|
||||
+ if n == 0:
|
||||
+ input_top_logprobs_val.append([])
|
||||
+ input_top_logprobs_idx.append([])
|
||||
continue
|
||||
|
||||
- input_top_logprobs_val.append(
|
||||
- [values[pt + j][:k] for j in range(pruned_len)]
|
||||
- )
|
||||
- input_top_logprobs_idx.append(
|
||||
- [indices[pt + j][:k] for j in range(pruned_len)]
|
||||
- )
|
||||
- pt += pruned_len
|
||||
+ val = []
|
||||
+ idx = []
|
||||
+ for j in range(pruned_len):
|
||||
+ # Handle remaining tokens in next chunk if any
|
||||
+ if pt + j >= len(values):
|
||||
+ next_split_pruned_len = split_pruned_len + j
|
||||
+ break
|
||||
+ val.append(values[pt + j][:k])
|
||||
+ idx.append(indices[pt + j][:k])
|
||||
+
|
||||
+ if split_pruned_len <= 0 and len(val) > 0:
|
||||
+ input_top_logprobs_val.append(val)
|
||||
+ input_top_logprobs_idx.append(idx)
|
||||
+ else:
|
||||
+ input_top_logprobs_val[-1].extend(val)
|
||||
+ input_top_logprobs_idx[-1].extend(idx)
|
||||
|
||||
- return input_top_logprobs_val, input_top_logprobs_idx
|
||||
+ pt += pruned_len
|
||||
+ return next_split_pruned_len
|
||||
|
||||
@staticmethod
|
||||
def get_token_ids_logprobs(
|
||||
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
|
||||
+ logprobs: torch.Tensor,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
+ chunk_slice: slice,
|
||||
+ input_token_ids_logprobs_val: List,
|
||||
+ input_token_ids_logprobs_idx: List,
|
||||
+ split_pruned_len: int = 0,
|
||||
):
|
||||
- input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
||||
+ """Get token_ids logprobs for each sequence in the chunk.
|
||||
+
|
||||
+ Args:
|
||||
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
|
||||
+ logits_metadata: Metadata containing token IDs and pruned length info
|
||||
+ chunk_slice: Slice of sequences to process
|
||||
+ input_token_ids_logprobs_val: List to store token logprob values
|
||||
+ input_token_ids_logprobs_idx: List to store token indices
|
||||
+ split_pruned_len: Length of pruned tokens from previous chunk
|
||||
+
|
||||
+ Returns:
|
||||
+ int: Number of remaining tokens to process in next chunk
|
||||
+ """
|
||||
pt = 0
|
||||
- for token_ids, pruned_len in zip(
|
||||
- logits_metadata.token_ids_logprobs,
|
||||
- logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
+ next_split_pruned_len = 0
|
||||
+ token_ids_logprobs_chunk = logits_metadata.token_ids_logprobs[chunk_slice]
|
||||
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
|
||||
+
|
||||
+ for n, (token_ids, pruned_len) in enumerate(
|
||||
+ zip(
|
||||
+ token_ids_logprobs_chunk,
|
||||
+ pruned_lens,
|
||||
+ )
|
||||
):
|
||||
+ # Adjust pruned length for first sequence
|
||||
+ if n == 0:
|
||||
+ pruned_len -= split_pruned_len
|
||||
+ else:
|
||||
+ split_pruned_len = 0
|
||||
+
|
||||
if pruned_len <= 0:
|
||||
- input_token_ids_logprobs_val.append([])
|
||||
- input_token_ids_logprobs_idx.append([])
|
||||
+ if n == 0:
|
||||
+ input_token_ids_logprobs_val.append([])
|
||||
+ input_token_ids_logprobs_idx.append([])
|
||||
continue
|
||||
|
||||
- input_token_ids_logprobs_val.append(
|
||||
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
|
||||
- )
|
||||
- input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
||||
- pt += pruned_len
|
||||
+ val = []
|
||||
+ idx = []
|
||||
+ for j in range(pruned_len):
|
||||
+ # Handle remaining tokens in next chunk if any
|
||||
+ if pt + j >= logprobs.shape[0]:
|
||||
+ next_split_pruned_len = split_pruned_len + j
|
||||
+ break
|
||||
+ if token_ids is not None:
|
||||
+ val.append(logprobs[pt + j, token_ids].tolist())
|
||||
+ idx.append(token_ids)
|
||||
+
|
||||
+ if split_pruned_len <= 0 and len(val) > 0:
|
||||
+ input_token_ids_logprobs_val.append(val)
|
||||
+ input_token_ids_logprobs_idx.append(idx)
|
||||
+ elif len(val) > 0:
|
||||
+ input_token_ids_logprobs_val[-1].extend(val)
|
||||
+ input_token_ids_logprobs_idx[-1].extend(idx)
|
||||
|
||||
- return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
|
||||
+ pt += pruned_len
|
||||
+ return next_split_pruned_len
|
||||
|
||||
@staticmethod
|
||||
def compute_temp_top_p_normalized_logprobs(
|
||||
- last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
||||
+ last_logits: torch.Tensor,
|
||||
+ indices: torch.Tensor,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
compute logprobs for the output token from the given logits.
|
||||
@@ -561,19 +763,20 @@ class LogitsProcessor(nn.Module):
|
||||
"""
|
||||
# Scale logits if temperature scaling is enabled
|
||||
if logits_metadata.temp_scaled_logprobs:
|
||||
- last_logits = last_logits / logits_metadata.temperature
|
||||
+ last_logits = last_logits / logits_metadata.temperature[indices]
|
||||
+
|
||||
+ top_p = None
|
||||
+ if logits_metadata.top_p is not None:
|
||||
+ top_p = logits_metadata.top_p[indices]
|
||||
|
||||
# Normalize logprobs if top_p normalization is enabled
|
||||
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
|
||||
- if (
|
||||
- logits_metadata.top_p_normalized_logprobs
|
||||
- and (logits_metadata.top_p != 1.0).any()
|
||||
- ):
|
||||
+ if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
|
||||
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
|
||||
|
||||
probs = torch.softmax(last_logits, dim=-1)
|
||||
del last_logits
|
||||
- probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
|
||||
+ probs = top_p_normalize_probs_torch(probs, top_p)
|
||||
return torch.log(probs)
|
||||
else:
|
||||
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
|
||||
index 5390668c..db370d19 100644
|
||||
--- a/python/sglang/srt/managers/io_struct.py
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=61.0", "packaging", "torch", "pybind11>=2.10.0", "build>=1.2.1"]
|
||||
requires = ["setuptools>=61.0", "build>=1.2.1"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
|
|
|
@ -3,21 +3,5 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
# Initialize preset config before all submodules.
|
||||
from .base import prologue # isort: skip
|
||||
|
||||
from .api.cli_args import *
|
||||
|
||||
# Re-import these classes for clear documentation,
|
||||
# otherwise the name will have a long prefix.
|
||||
from .api.core.config import ModelName, ModelShardID
|
||||
from .api.core.data_api import SequenceSample
|
||||
from .api.core.dfg import MFCDef
|
||||
from .api.core.model_api import (
|
||||
FinetuneSpec,
|
||||
Model,
|
||||
ModelBackend,
|
||||
ModelInterface,
|
||||
PipelinableEngine,
|
||||
ReaLModelConfig,
|
||||
)
|
||||
from .version import __version__
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import getpass
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
@ -5,7 +6,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union
|
|||
from omegaconf import MISSING
|
||||
|
||||
from realhf.base import pkg_version
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
|
||||
## Data and datasets. ##
|
||||
|
||||
|
@ -847,6 +847,57 @@ class TensorBoardConfig:
|
|||
path: Optional[str] = None
|
||||
|
||||
|
||||
def get_user_tmp():
|
||||
user = getpass.getuser()
|
||||
user_tmp = os.path.join("/home", user, ".cache", "realhf")
|
||||
os.makedirs(user_tmp, exist_ok=True)
|
||||
return user_tmp
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusterSpecConfig:
|
||||
config_path: str = field(
|
||||
default="",
|
||||
metadata={
|
||||
"help": "JSON config path. If not given, use the following CLI args."
|
||||
},
|
||||
)
|
||||
cluster_name: str = field(
|
||||
default="local",
|
||||
metadata={"help": "Name of the cluster. Used to set specific environs."},
|
||||
)
|
||||
fileroot: str = field(
|
||||
default=get_user_tmp(),
|
||||
metadata={
|
||||
"help": "Root for logs and checkpoints. Should be available to all nodes."
|
||||
},
|
||||
)
|
||||
gpu_type: str = field(
|
||||
default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
|
||||
)
|
||||
mount: str = field(
|
||||
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
|
||||
)
|
||||
gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
|
||||
cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
|
||||
gpu_infer_image: str = field(
|
||||
default="", metadata={"help": "slurm image for LLM inference."}
|
||||
)
|
||||
node_name_prefix: str = field(
|
||||
default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
|
||||
)
|
||||
n_nodes: int = field(
|
||||
default=32,
|
||||
metadata={
|
||||
"help": "The size of the cluster. Used to decide slurm hostname suffix."
|
||||
},
|
||||
)
|
||||
n_gpus_per_node: int = field(
|
||||
default=8,
|
||||
metadata={"help": "GPUs per node (physically)."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseExperimentConfig:
|
||||
"""Configuration for quickstart experiments.
|
||||
|
@ -935,21 +986,20 @@ class BaseExperimentConfig:
|
|||
default=1, metadata={"help": "Number of nodes for experiment."}
|
||||
)
|
||||
n_gpus_per_node: int = field(
|
||||
default=cluster_spec.n_gpus_per_node,
|
||||
metadata={"help": "GPUs per node. Total GPUs = n_nodes * n_gpus_per_node."},
|
||||
default=8, metadata={"help": "Number of GPUs per node for this experiment."}
|
||||
)
|
||||
nodelist: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "SLURM nodelist for manual allocation. "
|
||||
"Format: 'NODE01:0,1,2,3' or 'NODE[01-02,03,07],COM08'."
|
||||
"Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'."
|
||||
},
|
||||
)
|
||||
exclude: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "SLURM nodelist to exclude from allocation. "
|
||||
"Format: 'NODE01:0,1,2,3' or 'NODE[01-02,03,07],COM08'."
|
||||
"Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'."
|
||||
},
|
||||
)
|
||||
seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
|
||||
|
@ -996,6 +1046,13 @@ class BaseExperimentConfig:
|
|||
shuffle_dataset: bool = field(
|
||||
default=True, metadata={"help": "Shuffle in each epoch."}
|
||||
)
|
||||
ray_temp_path: str = field(
|
||||
default="/tmp/ray", metadata={"help": "Absolute path for Ray's log."}
|
||||
)
|
||||
cluster: ClusterSpecConfig = field(
|
||||
default_factory=ClusterSpecConfig,
|
||||
metadata={"help": "Cluster specification. Mainly used by slurm."},
|
||||
)
|
||||
|
||||
|
||||
## Configuration options of asynchronous experiments. ##
|
||||
|
@ -1033,7 +1090,7 @@ class AsyncRLOptions:
|
|||
},
|
||||
)
|
||||
flush_request_timeout: int = field(
|
||||
default=120,
|
||||
default=300,
|
||||
metadata={"help": "The timeout of flushing requests upon weight update."},
|
||||
)
|
||||
|
||||
|
|
|
@ -723,6 +723,7 @@ class SequenceSample:
|
|||
class DataBatchMeta:
|
||||
dp_rank: int
|
||||
meta_sample: SequenceSample | None
|
||||
birth_times: List
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import dataclasses
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import realhf.api.core.dfg as dfg
|
||||
|
@ -26,16 +27,21 @@ from realhf.base import constants, topology
|
|||
from realhf.base.cluster import spec as cluster_spec
|
||||
|
||||
|
||||
class ExpStatus(Enum):
|
||||
RUNNING = "RUNNING"
|
||||
ABORTED = "ABORTED"
|
||||
COMPLETE = "COMPLETE"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Scheduling:
|
||||
# TODO: add partition
|
||||
cpu: int
|
||||
gpu: int
|
||||
mem: int
|
||||
gpu_type: str = "tesla"
|
||||
node_type: str = None
|
||||
nodelist: str = None
|
||||
exclude: str = None
|
||||
container_image: str = cluster_spec.cpu_image
|
||||
container_image: str = None
|
||||
env_vars: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||
# time utils from "https://slurm.schedmd.com/sbatch.html"
|
||||
time_limit: Optional[str] = None # see "--time" option for format
|
||||
|
@ -241,7 +247,7 @@ class ExperimentScheduling:
|
|||
generation_server: TasksGroup | None = None
|
||||
gserver_manager: TasksGroup | None = None
|
||||
rollout_worker: TasksGroup | None = None
|
||||
controller_image: str = cluster_spec.cpu_image
|
||||
controller_image: str = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
|
@ -184,10 +184,10 @@ def make_device_mesh_from_name(
|
|||
"""
|
||||
DeviceMesh name format: <prefix><node_indices>[:<gpu_ids>]
|
||||
slurm_nodelist is the name of slurm nodes the mesh is on, should follow slurm convention,
|
||||
for example "NODE[40-43]" or "NODE[01,11,13-14]" with prefix NODE,
|
||||
for example "slurmd-[40-43]" or "slurmd-[01,11,13-14]" with prefix slurmd-,
|
||||
if n_nodes=1, gpu_ids are the gpu id list delimited by comma if n_gpus < n_gpus_per_node,
|
||||
for example "0,1,2,3" or "0,1". An example of full device mesh name
|
||||
in this situation is "NODE40:0,1,2,3"
|
||||
in this situation is "slurmd-40:0,1,2,3"
|
||||
|
||||
Note: cluster device mesh name must occupy entire nodes.
|
||||
"""
|
||||
|
|
|
@ -16,13 +16,14 @@ from hydra.core.config_store import ConfigStore
|
|||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
import realhf.api.core.system_api as system_api
|
||||
from realhf.api.cli_args import print_runtime_helper
|
||||
from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT, QUICKSTART_EXPR_CACHE_PATH
|
||||
from realhf.base.constants import init_constants
|
||||
from realhf.base.ray_utils import check_ray_availability
|
||||
from realhf.base.slurm_utils import check_slurm_availability
|
||||
|
||||
|
||||
def kind_reminder(config_name, logger, args):
|
||||
from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT
|
||||
|
||||
logger.info(f"Running {config_name} experiment.")
|
||||
logger.info(
|
||||
f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}"
|
||||
|
@ -81,9 +82,13 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
|
|||
trial_name = args.trial_name
|
||||
from realhf.apps.main import main_start, main_stop
|
||||
|
||||
init_constants(args)
|
||||
from realhf.base.constants import LOG_ROOT, QUICKSTART_EXPR_CACHE_PATH
|
||||
|
||||
config_save_path = os.path.join(
|
||||
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
|
||||
)
|
||||
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
|
||||
with open(config_save_path, "w") as f:
|
||||
yaml.dump(
|
||||
dataclasses.asdict(OmegaConf.to_object(args)),
|
||||
|
|
|
@ -11,13 +11,10 @@ import uuid
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import realhf.api.core.system_api as config_package
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
import realhf.base.name_resolve as name_resolve
|
||||
import realhf.base.names as names
|
||||
import realhf.base.recover as recover
|
||||
import realhf.scheduler.client as sched_client
|
||||
import realhf.system as system
|
||||
from realhf.api.core.system_api import ExpStatus
|
||||
from realhf.base import constants, logging, name_resolve, names, recover
|
||||
from realhf.scheduler.client import JobException, JobState
|
||||
from realhf.scheduler.evaluator import AutomaticEvaluator
|
||||
from realhf.version import get_full_version_with_dirty_description
|
||||
|
@ -55,7 +52,6 @@ def _submit_workers(
|
|||
|
||||
nodelist = sch_cfg.scheduling.nodelist
|
||||
exclude = sch_cfg.scheduling.exclude
|
||||
node_type = sch_cfg.scheduling.node_type
|
||||
container_image = image_name or sch_cfg.scheduling.container_image
|
||||
|
||||
scheduled_jobs.append(
|
||||
|
@ -65,10 +61,8 @@ def _submit_workers(
|
|||
count=sch_cfg.count,
|
||||
cpu=sch_cfg.scheduling.cpu,
|
||||
gpu=sch_cfg.scheduling.gpu,
|
||||
gpu_type=sch_cfg.scheduling.gpu_type,
|
||||
mem=sch_cfg.scheduling.mem,
|
||||
container_image=container_image,
|
||||
node_type=node_type,
|
||||
nodelist=nodelist,
|
||||
exclude=exclude,
|
||||
env_vars=job_environs,
|
||||
|
@ -147,18 +141,15 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
|
|||
|
||||
cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "")
|
||||
if not cluster_spec_path:
|
||||
if args.mode == "slurm":
|
||||
raise ValueError(
|
||||
"Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! "
|
||||
"See example/cluster_config.json for a template."
|
||||
)
|
||||
logger.warning(
|
||||
logger.info(
|
||||
"Environment variable CLUSTER_SPEC_PATH is not set. "
|
||||
"Files of the experiment (logs, checkpoints, cache ...) "
|
||||
"will be saved to temporary directory of the system. "
|
||||
"To change the fileroot, set the fileroot option of your choice in your CLUSTER_SPEC_PATH."
|
||||
"Will use the fileroot specified in CLI args. "
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Environment variable CLUSTER_SPEC_PATH is set. "
|
||||
"Will overwrite the cluster spec in CLI args. "
|
||||
)
|
||||
|
||||
# set env vars
|
||||
BASE_ENVIRONS = constants.get_env_vars(
|
||||
REAL_MODE=args.mode.upper(),
|
||||
|
@ -273,6 +264,29 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
|
|||
)
|
||||
recover_this = recover_this and reason in recover_states
|
||||
|
||||
# Check whether this exception is caused by experiment finish.
|
||||
name = names.experiment_status(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
try:
|
||||
exp_status = name_resolve.get(name)
|
||||
recover_this = recover_this and exp_status != str(ExpStatus.COMPLETE)
|
||||
if exp_status == str(ExpStatus.COMPLETE):
|
||||
logger.warning("*" * 100)
|
||||
logger.warning(
|
||||
"*"
|
||||
+ f"Will not recover because the experiment has completed! Congrats!".center(
|
||||
98, " "
|
||||
)
|
||||
+ "*"
|
||||
)
|
||||
logger.warning("*" * 100)
|
||||
except name_resolve.NameEntryNotFoundError:
|
||||
raise name_resolve.NameEntryNotFoundError(
|
||||
f"Experiment status not found during recover. "
|
||||
"This indicates that the master worker is not running. Exit the recover loop."
|
||||
)
|
||||
|
||||
kill_signal = (
|
||||
"SIGKILL" if args.mode == "slurm" else "SIGTERM"
|
||||
) # use sigkill to terminate slurm jobs
|
||||
|
|
|
@ -161,20 +161,6 @@ def launch_hydra_task(
|
|||
if not any("hydra/job_logging=disabled" in x for x in sys.argv):
|
||||
sys.argv.insert(2, "hydra/job_logging=disabled")
|
||||
|
||||
if (
|
||||
"--multirun" in sys.argv
|
||||
or "hydra.mode=MULTIRUN" in sys.argv
|
||||
or "-m" in sys.argv
|
||||
):
|
||||
raise NotImplementedError("Hydra multi-run is not supported.")
|
||||
|
||||
# non-multirun mode, add hydra run dir
|
||||
sys.argv.insert(
|
||||
2,
|
||||
f"hydra.run.dir={cluster_spec.fileroot}/logs/{getpass.getuser()}/"
|
||||
f"{experiment_name}/{trial_name}/hydra-outputs/",
|
||||
)
|
||||
|
||||
sys.argv.pop(1)
|
||||
|
||||
func()
|
||||
|
|
|
@ -20,11 +20,8 @@ from omegaconf import OmegaConf
|
|||
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
from realhf.api.quickstart.entrypoint import (
|
||||
QUICKSTART_CONFIG_CLASSES,
|
||||
QUICKSTART_EXPR_CACHE_PATH,
|
||||
)
|
||||
from realhf.base import cluster, gpu_utils, importing, logging, name_resolve, names
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES
|
||||
from realhf.base import gpu_utils, importing, logging
|
||||
from realhf.version import get_full_version_with_dirty_description
|
||||
|
||||
logger = logging.getLogger("Main-Workers")
|
||||
|
@ -32,6 +29,7 @@ logger = logging.getLogger("Main-Workers")
|
|||
|
||||
def _patch_external_impl(exp_name, trial_name):
|
||||
import realhf.api.core.system_api as system_api
|
||||
from realhf.base.constants import QUICKSTART_EXPR_CACHE_PATH
|
||||
|
||||
if os.path.exists(QUICKSTART_EXPR_CACHE_PATH):
|
||||
for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH):
|
||||
|
@ -59,6 +57,12 @@ def main_worker(args):
|
|||
constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
|
||||
_patch_external_impl(args.experiment_name, args.trial_name)
|
||||
|
||||
# Initialize cluster infor from ENV or CLI args.
|
||||
import realhf.api.core.system_api as system_api
|
||||
|
||||
experiment = system_api.make_experiment(name=args.experiment_name)
|
||||
constants.init_constants(experiment)
|
||||
|
||||
worker_index_start = args.jobstep_id * args.wprocs_per_jobstep + args.wproc_offset
|
||||
worker_index_end = min(
|
||||
worker_index_start + args.wprocs_per_jobstep,
|
||||
|
@ -174,6 +178,10 @@ def main_controller(args):
|
|||
trial_name=args.trial_name,
|
||||
)
|
||||
experiment = system_api.make_experiment(args.experiment_name)
|
||||
|
||||
# Initialize cluster infor from ENV or CLI args.
|
||||
constants.init_constants(experiment)
|
||||
|
||||
controller.start(
|
||||
experiment=experiment,
|
||||
ignore_worker_error=args.ignore_worker_error,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
from .prologue import * # isort: skip
|
||||
|
|
|
@ -2,61 +2,62 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
CLUSTER_SPEC_PATH = os.environ.get("CLUSTER_SPEC_PATH", "")
|
||||
|
||||
|
||||
def get_user_tmp():
|
||||
user = getpass.getuser()
|
||||
user_tmp = os.path.join("/home", user, ".cache", "realhf")
|
||||
os.makedirs(user_tmp, exist_ok=True)
|
||||
return user_tmp
|
||||
if TYPE_CHECKING:
|
||||
from realhf.api.cli_args import BaseExperimentConfig
|
||||
|
||||
|
||||
class ClusterSpec:
|
||||
def __init__(self):
|
||||
# Set default values to comfort ray
|
||||
from realhf.api.cli_args import BaseExperimentConfig
|
||||
|
||||
self.load_spec_from_args(BaseExperimentConfig())
|
||||
|
||||
self.__loaded = False
|
||||
|
||||
def load_spec_from_file(self, file_path: str):
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Cluster spec file not found: {file_path}")
|
||||
|
||||
with open(file_path, "r") as f:
|
||||
spec: Dict = json.load(f)
|
||||
except FileNotFoundError:
|
||||
if file_path == "":
|
||||
spec = dict(
|
||||
cluster_type="local",
|
||||
cluster_name="local",
|
||||
fileroot=get_user_tmp(),
|
||||
)
|
||||
else:
|
||||
raise FileNotFoundError(f"Cluster spec file not found: {file_path}")
|
||||
|
||||
self.__cluster_type = spec["cluster_type"]
|
||||
self.__cluster_name = spec["cluster_name"]
|
||||
self.__fileroot = spec["fileroot"]
|
||||
self.__node_type_from_node_name_re = spec.get("node_type_from_node_name", None)
|
||||
self.__gpu_type_from_node_name_re = spec.get("gpu_type_from_node_name", None)
|
||||
self.__gpu_type = spec.get("gpu_type", None)
|
||||
self.__default_mount = spec.get("default_mount", None)
|
||||
self.__mount = spec.get("default_mount", None)
|
||||
self.__gpu_image = spec.get("gpu_image", None)
|
||||
self.__gpu_infer_image = spec.get("gpu_infer_image", self.__gpu_image)
|
||||
self.__cpu_image = spec.get("cpu_image", None)
|
||||
self.__node_name_prefix = spec.get("node_name_prefix", "NODE")
|
||||
self.__node_name_prefix = spec.get("node_name_prefix", "slurmd-")
|
||||
# self.__n_nodes decides number of digits in slurm hostnames
|
||||
# e.g. if __n_nodes = 32, then the hostnames will be NODE{:02d}
|
||||
# if __n_nodes = 128, then the hostnames will be NODE{:03d}
|
||||
# e.g. if __n_nodes = 32, then the hostnames will be slurmd-{:02d}
|
||||
# if __n_nodes = 128, then the hostnames will be slurmd-{:03d}
|
||||
self.__n_nodes = int(spec.get("n_nodes", 32))
|
||||
self.__n_gpus_per_node = int(spec.get("n_gpus_per_node", 8))
|
||||
assert isinstance(self.__n_nodes, int)
|
||||
|
||||
self.__loaded = True
|
||||
|
||||
def load_spec_from_args(self, args: "BaseExperimentConfig"):
|
||||
self.__cluster_type = args.mode
|
||||
self.__cluster_name = args.cluster.cluster_name
|
||||
self.__fileroot = args.cluster.fileroot
|
||||
self.__gpu_type = args.cluster.gpu_type
|
||||
self.__mount = args.cluster.mount
|
||||
self.__gpu_image = args.cluster.gpu_image
|
||||
self.__gpu_infer_image = args.cluster.gpu_infer_image
|
||||
self.__cpu_image = args.cluster.cpu_image
|
||||
self.__node_name_prefix = args.cluster.node_name_prefix
|
||||
self.__n_nodes = args.cluster.n_nodes
|
||||
self.__n_gpus_per_node = args.cluster.n_gpus_per_node
|
||||
self.__loaded = True
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
assert self.__loaded
|
||||
|
@ -67,32 +68,6 @@ class ClusterSpec:
|
|||
assert self.__loaded
|
||||
return self.__gpu_type
|
||||
|
||||
def node_type_from_node_name(self, node_name: str) -> str:
|
||||
"""Mapping nodename to slurm node type, including "g1", "g2", "g8",
|
||||
"a100"."""
|
||||
if self.__cluster_type != "slurm":
|
||||
raise NotImplementedError(
|
||||
"Only slurm cluster uses node_type_from_node_name."
|
||||
)
|
||||
assert self.__loaded
|
||||
for regex, node_type in self.__node_type_from_node_name_re.items():
|
||||
if re.match(regex, node_name):
|
||||
return node_type
|
||||
raise NotImplementedError(node_name)
|
||||
|
||||
def gpu_type_from_node_name(self, node_name: str) -> str:
|
||||
"""Mapping nodename to slurm GPU type, including "geforce" and
|
||||
"tesla"."""
|
||||
if self.__cluster_type != "slurm":
|
||||
raise NotImplementedError(
|
||||
"Only slurm cluster uses gpu_type_from_node_name."
|
||||
)
|
||||
assert self.__loaded
|
||||
for regex, gpu_type in self.__gpu_type_from_node_name_re.items():
|
||||
if re.match(regex, node_name):
|
||||
return gpu_type
|
||||
raise NotImplementedError(node_name)
|
||||
|
||||
@property
|
||||
def fileroot(self) -> str:
|
||||
"""Return the root directory of the file system in the cluster.
|
||||
|
@ -109,11 +84,11 @@ class ClusterSpec:
|
|||
self.__fileroot = root
|
||||
|
||||
@property
|
||||
def default_mount(self) -> str:
|
||||
def mount(self) -> str:
|
||||
"""Directories that should be mounted to container that runs
|
||||
workers."""
|
||||
assert self.__loaded
|
||||
return self.__default_mount
|
||||
return self.__mount
|
||||
|
||||
@property
|
||||
def gpu_image(self) -> str:
|
||||
|
@ -156,23 +131,15 @@ class ClusterSpec:
|
|||
return self.__cluster_type
|
||||
|
||||
|
||||
def node_name_is_node_type(
|
||||
node_name: str, node_type: Optional[Union[List[str], str]] = None
|
||||
) -> bool:
|
||||
assert spec is not None
|
||||
if node_type is None:
|
||||
return True
|
||||
if not isinstance(node_type, list):
|
||||
node_type = [node_type]
|
||||
nt_condition = []
|
||||
for nt in node_type:
|
||||
if nt not in ["g1", "g2", "g8", "a100"]:
|
||||
raise ValueError(f"Unknown node type {nt}.")
|
||||
else:
|
||||
cond = spec.node_type_from_node_name(node_name) == nt
|
||||
nt_condition.append(cond)
|
||||
return any(nt_condition)
|
||||
|
||||
|
||||
spec = ClusterSpec()
|
||||
spec.load_spec_from_file(CLUSTER_SPEC_PATH)
|
||||
|
||||
|
||||
def init_cluster_spec(args: "BaseExperimentConfig"):
|
||||
global spec
|
||||
CLUSTER_SPEC_PATH = os.environ.get("CLUSTER_SPEC_PATH", "")
|
||||
if args.cluster.config_path:
|
||||
spec.load_spec_from_file(args.cluster.config_path)
|
||||
elif CLUSTER_SPEC_PATH:
|
||||
spec.load_spec_from_file(CLUSTER_SPEC_PATH)
|
||||
else:
|
||||
spec.load_spec_from_args(args)
|
||||
|
|
|
@ -11,17 +11,18 @@ import os
|
|||
import pathlib
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from realhf.api.cli_args import BaseExperimentConfig
|
||||
from realhf.api.core.config import ModelName
|
||||
from realhf.api.core.system_api import ModelShardID
|
||||
from realhf.base.topology import ParallelGrid, ProcessTopology
|
||||
|
@ -68,24 +69,52 @@ TORCH_FORCE_CPU = False
|
|||
|
||||
# constants in experiment instance scope
|
||||
LOCAL_CACHE_DIR = "/tmp/realhf"
|
||||
MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}"
|
||||
LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}"
|
||||
RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}"
|
||||
SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock"
|
||||
PORT_LOCK_FILE_ROOT = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports"
|
||||
PYTORCH_KERNEL_CACHE_PATH = (
|
||||
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
|
||||
)
|
||||
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
|
||||
DATASET_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets"
|
||||
PROFILER_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler"
|
||||
PARAM_REALLOC_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc"
|
||||
SGLANG_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang"
|
||||
TORCH_EXTENSIONS_DIR = (
|
||||
QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache")
|
||||
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
|
||||
|
||||
# Global constants that should be initialized after cluster initialization.
|
||||
MODEL_SAVE_ROOT = None
|
||||
LOG_ROOT = None
|
||||
RECOVER_ROOT = None
|
||||
SLURM_LOCK_FILE_NAME = None
|
||||
PORT_LOCK_FILE_ROOT = None
|
||||
DATASET_CACHE_PATH = None
|
||||
PROFILER_CACHE_PATH = None
|
||||
PARAM_REALLOC_PATH = None
|
||||
SGLANG_CACHE_PATH = None
|
||||
TORCH_EXTENSIONS_DIR = None
|
||||
BASE_ENVIRONS = None
|
||||
|
||||
|
||||
def init_constants(args: "BaseExperimentConfig"):
|
||||
from realhf.base.cluster import init_cluster_spec
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
|
||||
init_cluster_spec(args)
|
||||
|
||||
globals_dict = globals() # Get module's global variables
|
||||
|
||||
kwargs = dict(
|
||||
MODEL_SAVE_ROOT=f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}",
|
||||
LOG_ROOT=f"{cluster_spec.fileroot}/logs/{getpass.getuser()}",
|
||||
RECOVER_ROOT=f"{cluster_spec.fileroot}/recover/{getpass.getuser()}",
|
||||
SLURM_LOCK_FILE_NAME=f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock",
|
||||
PORT_LOCK_FILE_ROOT=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports",
|
||||
DATASET_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets",
|
||||
PROFILER_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler",
|
||||
PARAM_REALLOC_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc",
|
||||
SGLANG_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang",
|
||||
TORCH_EXTENSIONS_DIR=(
|
||||
f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
|
||||
)
|
||||
QUICKSTART_EXPR_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/"
|
||||
BASE_ENVIRONS = {
|
||||
),
|
||||
)
|
||||
BASE_ENVIRONS = {
|
||||
# "PYTHONPATH": "/realhf",
|
||||
"REAL_IS_REMOTE": "1",
|
||||
# "NCCL_P2P_DISABLE": "1",
|
||||
|
@ -94,7 +123,7 @@ BASE_ENVIRONS = {
|
|||
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
|
||||
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
|
||||
"TOKENIZERS_PARALLELISM": "true",
|
||||
"TORCH_EXTENSIONS_DIR": TORCH_EXTENSIONS_DIR,
|
||||
"TORCH_EXTENSIONS_DIR": kwargs["TORCH_EXTENSIONS_DIR"],
|
||||
# "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
|
||||
# "NCCL_SOCKET_IFNAME": "ibp71s0",
|
||||
# "GLOO_SOCKET_IFNAME": "ibp71s0",
|
||||
|
@ -124,10 +153,10 @@ BASE_ENVIRONS = {
|
|||
"LC_ALL": "C",
|
||||
"LANG": "C",
|
||||
"NCCL_DEBUG": "WARN",
|
||||
}
|
||||
|
||||
# Set PPU-specific environment variables for stable training.
|
||||
if cluster_spec.name == "wa180":
|
||||
}
|
||||
kwargs["BASE_ENVIRONS"] = BASE_ENVIRONS
|
||||
# Set PPU-specific environment variables for stable training.
|
||||
if cluster_spec.name == "wa180":
|
||||
logger.warning("Detected PPU. Amending PPU-related environment variables.")
|
||||
PPU_ENVIRONS = {
|
||||
"NCCL_DEBUG": "INFO",
|
||||
|
@ -138,8 +167,8 @@ if cluster_spec.name == "wa180":
|
|||
"NCCL_SOCKET_IFNAME": "bond0",
|
||||
"PCCL_STATE_MONITOR_DISABLE": "1",
|
||||
}
|
||||
BASE_ENVIRONS.update(PPU_ENVIRONS)
|
||||
elif cluster_spec.name == "na132":
|
||||
kwargs["BASE_ENVIRONS"].update(PPU_ENVIRONS)
|
||||
elif cluster_spec.name == "na132":
|
||||
# Specific environment variable for h800 cluster na132
|
||||
NV_ENVIRONS = {
|
||||
"NCCL_SOCKET_IFNAME": "bond0",
|
||||
|
@ -154,22 +183,26 @@ elif cluster_spec.name == "na132":
|
|||
"NCCL_SET_THREAD_NAME": "1",
|
||||
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
|
||||
}
|
||||
BASE_ENVIRONS.update(NV_ENVIRONS)
|
||||
kwargs["BASE_ENVIRONS"].update(NV_ENVIRONS)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key not in globals_dict:
|
||||
raise ValueError(f"Invalid constant name: {key}")
|
||||
if globals_dict[key] is not None and globals_dict[key] != value:
|
||||
raise RuntimeError(f"Constant '{key}' already initialized!")
|
||||
globals_dict[key] = value
|
||||
|
||||
# make directories if does not exist
|
||||
os.makedirs(globals_dict["PARAM_REALLOC_PATH"], exist_ok=True)
|
||||
os.makedirs(globals_dict["MODEL_SAVE_ROOT"], exist_ok=True)
|
||||
os.makedirs(globals_dict["LOG_ROOT"], exist_ok=True)
|
||||
os.makedirs(globals_dict["RECOVER_ROOT"], exist_ok=True)
|
||||
os.makedirs(globals_dict["DATASET_CACHE_PATH"], exist_ok=True)
|
||||
os.makedirs(globals_dict["PROFILER_CACHE_PATH"], exist_ok=True)
|
||||
os.makedirs(globals_dict["TORCH_EXTENSIONS_DIR"], exist_ok=True)
|
||||
os.makedirs(globals_dict["PORT_LOCK_FILE_ROOT"], exist_ok=True)
|
||||
os.makedirs(globals_dict["SGLANG_CACHE_PATH"], exist_ok=True)
|
||||
|
||||
# make directories if does not exist
|
||||
os.makedirs(PARAM_REALLOC_PATH, exist_ok=True)
|
||||
os.makedirs(MODEL_SAVE_ROOT, exist_ok=True)
|
||||
os.makedirs(LOG_ROOT, exist_ok=True)
|
||||
os.makedirs(RECOVER_ROOT, exist_ok=True)
|
||||
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
|
||||
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(PORT_LOCK_FILE_ROOT, exist_ok=True)
|
||||
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
|
||||
|
||||
# _model_name will be changed in the model_scope context manager
|
||||
_model_name: "ModelName" = None
|
||||
|
|
|
@ -33,6 +33,13 @@ logger = logging.getLogger("benchmark")
|
|||
IF_MARK = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RolloutStat:
|
||||
submitted: int = 0
|
||||
accepted: int = 0
|
||||
running: int = 0
|
||||
|
||||
|
||||
def mock_time_mark(name, identifier, t, step):
|
||||
if IF_MARK:
|
||||
logger.debug(f"*{name}* #{identifier}# ${t}$ ns step &{step}&")
|
||||
|
|
|
@ -13,13 +13,14 @@ import time
|
|||
import uuid
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import ray
|
||||
|
||||
try:
|
||||
import etcd3
|
||||
except Exception:
|
||||
etcd3 = None
|
||||
|
||||
from realhf.base import cluster, logging, security, timeutil
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base import logging, security, timeutil
|
||||
|
||||
logger = logging.getLogger("name-resolve")
|
||||
|
||||
|
@ -286,14 +287,19 @@ class MemoryNameRecordRepository(NameRecordRepository):
|
|||
|
||||
|
||||
class NfsNameRecordRepository(NameRecordRepository):
|
||||
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
|
||||
os.makedirs(RECORD_ROOT, exist_ok=True)
|
||||
RECORD_ROOT = ""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__to_delete = set()
|
||||
|
||||
@staticmethod
|
||||
def __dir_path(name):
|
||||
if not NfsNameRecordRepository.RECORD_ROOT:
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
|
||||
RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
|
||||
os.makedirs(RECORD_ROOT, exist_ok=True)
|
||||
NfsNameRecordRepository.RECORD_ROOT = RECORD_ROOT
|
||||
return os.path.join(NfsNameRecordRepository.RECORD_ROOT, name)
|
||||
|
||||
@staticmethod
|
||||
|
@ -930,6 +936,446 @@ class Etcd3NameRecordRepository(NameRecordRepository):
|
|||
logger.debug(f"Testonly: dropped key: {name}")
|
||||
|
||||
|
||||
@ray.remote
|
||||
class DistributedKVStore:
|
||||
"""Ray actor implementing a distributed key-value store with TTL support."""
|
||||
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
self.ttl_store = {} # key -> expiry_time
|
||||
self.lease_store = {} # key -> lease_id
|
||||
self.lease_counter = 0
|
||||
|
||||
def put(self, key: str, value: str, lease_id: Optional[int] = None):
|
||||
"""Store a key-value pair with optional lease."""
|
||||
self.store[key] = value
|
||||
if lease_id is not None:
|
||||
self.lease_store[key] = lease_id
|
||||
return True
|
||||
|
||||
def get(self, key: str):
|
||||
"""Get value for a key, checking TTL expiry."""
|
||||
self._cleanup_expired()
|
||||
if key not in self.store:
|
||||
return None
|
||||
return self.store[key]
|
||||
|
||||
def delete(self, key: str):
|
||||
"""Delete a key and its associated metadata."""
|
||||
deleted = key in self.store
|
||||
self.store.pop(key, None)
|
||||
self.ttl_store.pop(key, None)
|
||||
self.lease_store.pop(key, None)
|
||||
return deleted
|
||||
|
||||
def get_prefix(self, prefix: str):
|
||||
"""Get all key-value pairs with keys matching the prefix."""
|
||||
self._cleanup_expired()
|
||||
result = []
|
||||
normalized_prefix = os.path.normpath(prefix)
|
||||
|
||||
for key, value in self.store.items():
|
||||
normalized_key = os.path.normpath(key)
|
||||
# Check if key matches prefix (exact match or starts with prefix/)
|
||||
if normalized_key == normalized_prefix or normalized_key.startswith(
|
||||
normalized_prefix.rstrip("/") + "/"
|
||||
):
|
||||
result.append((key, value))
|
||||
return result
|
||||
|
||||
def delete_prefix(self, prefix: str):
|
||||
"""Delete all keys matching the prefix."""
|
||||
self._cleanup_expired()
|
||||
normalized_prefix = os.path.normpath(prefix)
|
||||
keys_to_delete = []
|
||||
|
||||
for key in self.store.keys():
|
||||
normalized_key = os.path.normpath(key)
|
||||
if normalized_key == normalized_prefix or normalized_key.startswith(
|
||||
normalized_prefix.rstrip("/") + "/"
|
||||
):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
for key in keys_to_delete:
|
||||
self.delete(key)
|
||||
return len(keys_to_delete)
|
||||
|
||||
def create_lease(self, ttl_seconds: int):
|
||||
"""Create a lease with TTL."""
|
||||
self.lease_counter += 1
|
||||
lease_id = self.lease_counter
|
||||
expiry_time = time.time() + ttl_seconds
|
||||
return lease_id, expiry_time
|
||||
|
||||
def put_with_lease(self, key: str, value: str, ttl_seconds: int):
|
||||
"""Store key-value with TTL lease."""
|
||||
lease_id, expiry_time = self.create_lease(ttl_seconds)
|
||||
self.store[key] = value
|
||||
self.ttl_store[key] = expiry_time
|
||||
self.lease_store[key] = lease_id
|
||||
return lease_id
|
||||
|
||||
def refresh_lease(self, key: str, ttl_seconds: int):
|
||||
"""Refresh the lease for a key."""
|
||||
if key in self.store and key in self.lease_store:
|
||||
self.ttl_store[key] = time.time() + ttl_seconds
|
||||
return True
|
||||
return False
|
||||
|
||||
def _cleanup_expired(self):
|
||||
"""Remove expired keys."""
|
||||
current_time = time.time()
|
||||
expired_keys = []
|
||||
|
||||
for key, expiry_time in self.ttl_store.items():
|
||||
if current_time > expiry_time:
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
self.delete(key)
|
||||
|
||||
def get_all_keys(self):
|
||||
"""Get all keys in the store."""
|
||||
self._cleanup_expired()
|
||||
return list(self.store.keys())
|
||||
|
||||
|
||||
class RayNameResolveRepository:
|
||||
"""Ray-based implementation of NameRecordRepository using distributed actors."""
|
||||
|
||||
KEEPALIVE_POLL_FREQUENCY = 1
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Entry:
|
||||
value: str
|
||||
lease_id: Optional[int] = None
|
||||
keepalive_ttl: Optional[int] = None
|
||||
keeper: Optional[timeutil.FrequencyControl] = None
|
||||
|
||||
def __init__(self, actor_name: str = "distributed_kv_store", **kwargs):
|
||||
"""Initialize Ray-based name record repository.
|
||||
|
||||
Args:
|
||||
actor_name: Name for the Ray actor (for sharing across processes)
|
||||
**kwargs: Additional configuration parameters
|
||||
"""
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
self._actor_name = actor_name
|
||||
|
||||
# Initialize Ray if not already done
|
||||
if not ray.is_initialized():
|
||||
ray.init(ignore_reinit_error=True)
|
||||
|
||||
# Try to get existing actor or create new one
|
||||
try:
|
||||
self._kv_store = ray.get_actor(self._actor_name)
|
||||
logger.debug(
|
||||
f"Connected to existing Ray KV store actor: {self._actor_name}"
|
||||
)
|
||||
except ValueError:
|
||||
# Actor doesn't exist, create it
|
||||
self._kv_store = DistributedKVStore.options(
|
||||
name=self._actor_name, lifetime="detached"
|
||||
).remote()
|
||||
logger.debug(f"Created new Ray KV store actor: {self._actor_name}")
|
||||
|
||||
# Track entries for cleanup and keepalive
|
||||
self._entries = {}
|
||||
self._keepalive_running = True
|
||||
self._keepalive_thread = threading.Thread(
|
||||
target=self._keepalive_thread_run, daemon=True
|
||||
)
|
||||
self._keepalive_thread.start()
|
||||
|
||||
self._to_delete = set()
|
||||
|
||||
def __del__(self):
|
||||
"""Clean up resources when the object is deleted."""
|
||||
try:
|
||||
self.reset()
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Exception ignored when deleting RayNameResolveRepository: {e}"
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.reset()
|
||||
|
||||
def add(
|
||||
self,
|
||||
name: str,
|
||||
value: str,
|
||||
delete_on_exit: bool = True,
|
||||
keepalive_ttl: Optional[int] = None,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Add a key-value pair to the distributed store.
|
||||
|
||||
Args:
|
||||
name: Key name
|
||||
value: Value to store
|
||||
delete_on_exit: Whether to delete the key when this object is destroyed
|
||||
keepalive_ttl: TTL in seconds for the key
|
||||
replace: Whether to replace an existing key
|
||||
|
||||
Raises:
|
||||
NameEntryExistsError: If the key already exists and replace is False
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError(f"Invalid name: {name}")
|
||||
name = os.path.normpath(name)
|
||||
value = str(value)
|
||||
|
||||
with self._lock:
|
||||
# Check if key exists when replace=False
|
||||
if not replace:
|
||||
existing_value = ray.get(self._kv_store.get.remote(name))
|
||||
if existing_value is not None:
|
||||
raise NameEntryExistsError(
|
||||
f"Key already exists: K={name} V={existing_value}"
|
||||
)
|
||||
|
||||
# Store with or without TTL
|
||||
lease_id = None
|
||||
if keepalive_ttl is not None and keepalive_ttl > 0:
|
||||
lease_id = ray.get(
|
||||
self._kv_store.put_with_lease.remote(name, value, keepalive_ttl)
|
||||
)
|
||||
self._to_delete.add(name)
|
||||
else:
|
||||
ray.get(self._kv_store.put.remote(name, value))
|
||||
if delete_on_exit:
|
||||
self._to_delete.add(name)
|
||||
|
||||
# Store entry information for keepalive management
|
||||
self._entries[name] = self._Entry(
|
||||
value=value,
|
||||
lease_id=lease_id,
|
||||
keepalive_ttl=keepalive_ttl,
|
||||
keeper=(
|
||||
timeutil.FrequencyControl(frequency_seconds=keepalive_ttl / 3)
|
||||
if keepalive_ttl
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def add_subentry(self, name: str, value: str, **kwargs):
|
||||
"""Add a sub-entry to the key-root `name`."""
|
||||
sub_name = os.path.join(os.path.normpath(name), str(uuid.uuid4())[:8])
|
||||
self.add(sub_name, value, **kwargs)
|
||||
return sub_name
|
||||
|
||||
def delete(self, name: str):
|
||||
"""Delete a key from the distributed store.
|
||||
|
||||
Args:
|
||||
name: Key to delete
|
||||
|
||||
Raises:
|
||||
NameEntryNotFoundError: If the key doesn't exist
|
||||
"""
|
||||
with self._lock:
|
||||
self._delete_locked(name)
|
||||
if name in self._to_delete:
|
||||
self._to_delete.remove(name)
|
||||
|
||||
def _delete_locked(self, name: str):
|
||||
"""Delete a key with lock already acquired."""
|
||||
# Check if key exists
|
||||
existing_value = ray.get(self._kv_store.get.remote(name))
|
||||
if existing_value is None:
|
||||
raise NameEntryNotFoundError(f"No such entry to delete: {name}")
|
||||
|
||||
# Clean up entry tracking
|
||||
if name in self._entries:
|
||||
del self._entries[name]
|
||||
|
||||
# Delete from store
|
||||
ray.get(self._kv_store.delete.remote(name))
|
||||
|
||||
def clear_subtree(self, name_root: str):
|
||||
"""Delete all keys with the given prefix."""
|
||||
with self._lock:
|
||||
name_root = os.path.normpath(name_root)
|
||||
count = ray.get(self._kv_store.delete_prefix.remote(name_root))
|
||||
|
||||
# Clean up local tracking for deleted keys
|
||||
keys_to_remove = []
|
||||
for key in self._entries.keys():
|
||||
normalized_key = os.path.normpath(key)
|
||||
if normalized_key == name_root or normalized_key.startswith(
|
||||
name_root.rstrip("/") + "/"
|
||||
):
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._entries[key]
|
||||
|
||||
logger.debug(f"Deleted {count} entries under {name_root}")
|
||||
|
||||
def get(self, name: str):
|
||||
"""Get the value for a key.
|
||||
|
||||
Args:
|
||||
name: Key to retrieve
|
||||
|
||||
Returns:
|
||||
The value as a string
|
||||
|
||||
Raises:
|
||||
NameEntryNotFoundError: If the key doesn't exist
|
||||
"""
|
||||
name = os.path.normpath(name)
|
||||
with self._lock:
|
||||
return self._get_locked(name)
|
||||
|
||||
def _get_locked(self, name: str):
|
||||
"""Get a value with lock already acquired."""
|
||||
value = ray.get(self._kv_store.get.remote(name))
|
||||
if value is None:
|
||||
raise NameEntryNotFoundError(f"No such entry: {name}")
|
||||
return value
|
||||
|
||||
def get_subtree(self, name_root: str):
|
||||
"""Get all values with keys having the given prefix."""
|
||||
with self._lock:
|
||||
name_root = os.path.normpath(name_root)
|
||||
pairs = ray.get(self._kv_store.get_prefix.remote(name_root))
|
||||
values = [value for key, value in pairs]
|
||||
return sorted(values)
|
||||
|
||||
def find_subtree(self, name_root: str):
|
||||
"""Find all keys with the given prefix."""
|
||||
with self._lock:
|
||||
name_root = os.path.normpath(name_root)
|
||||
pairs = ray.get(self._kv_store.get_prefix.remote(name_root))
|
||||
keys = [key for key, value in pairs]
|
||||
return sorted(keys)
|
||||
|
||||
def wait(
|
||||
self, name: str, timeout: Optional[float] = None, poll_frequency: float = 1
|
||||
):
|
||||
"""Wait until a name appears.
|
||||
|
||||
Raises:
|
||||
TimeoutError: if timeout exceeds.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
try:
|
||||
return self.get(name)
|
||||
except NameEntryNotFoundError:
|
||||
pass
|
||||
if timeout is None or timeout > 0:
|
||||
time.sleep(
|
||||
poll_frequency + random.random() * 0.1
|
||||
) # To reduce concurrency.
|
||||
if timeout is not None and time.monotonic() - start > timeout:
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for key '{name}' ({self.__class__.__name__})"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""Delete all keys added via this repository instance."""
|
||||
self._keepalive_running = False
|
||||
if hasattr(self, "_keepalive_thread"):
|
||||
self._keepalive_thread.join(timeout=5)
|
||||
|
||||
with self._lock:
|
||||
count = 0
|
||||
for name in list(self._to_delete):
|
||||
try:
|
||||
self._delete_locked(name)
|
||||
count += 1
|
||||
except NameEntryNotFoundError:
|
||||
pass
|
||||
self._to_delete = set()
|
||||
self._entries = {}
|
||||
logger.debug(f"Reset {count} saved entries")
|
||||
|
||||
def watch_names(
|
||||
self,
|
||||
names: List[str],
|
||||
call_back: Callable,
|
||||
poll_frequency: float = 15,
|
||||
wait_timeout: float = 300,
|
||||
):
|
||||
"""Watch keys and call back when they are deleted.
|
||||
|
||||
Args:
|
||||
names: Keys to watch
|
||||
call_back: Function to call when any key is deleted
|
||||
poll_frequency: How often to check in seconds
|
||||
wait_timeout: Maximum time to wait for keys to exist
|
||||
"""
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
q = queue.Queue(maxsize=len(names))
|
||||
for _ in range(len(names) - 1):
|
||||
q.put(0)
|
||||
|
||||
def wrap_call_back():
|
||||
try:
|
||||
q.get_nowait()
|
||||
except queue.Empty:
|
||||
logger.info(f"Key {names} is gone. Executing callback {call_back}")
|
||||
call_back()
|
||||
|
||||
for name in names:
|
||||
t = threading.Thread(
|
||||
target=self._watch_thread_run,
|
||||
args=(name, wrap_call_back, poll_frequency, wait_timeout),
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
|
||||
def _watch_thread_run(
|
||||
self, name: str, call_back: Callable, poll_frequency: float, wait_timeout: float
|
||||
):
|
||||
"""Background thread to watch a key for deletion."""
|
||||
self.wait(name, timeout=wait_timeout, poll_frequency=poll_frequency)
|
||||
while True:
|
||||
try:
|
||||
self.get(name)
|
||||
time.sleep(poll_frequency + random.random())
|
||||
except NameEntryNotFoundError:
|
||||
call_back()
|
||||
break
|
||||
|
||||
def _keepalive_thread_run(self):
|
||||
"""Background thread to keep leases alive."""
|
||||
while self._keepalive_running:
|
||||
time.sleep(self.KEEPALIVE_POLL_FREQUENCY)
|
||||
with self._lock:
|
||||
for name, entry in list(self._entries.items()):
|
||||
if (
|
||||
entry.keeper is not None
|
||||
and entry.keepalive_ttl is not None
|
||||
and entry.lease_id is not None
|
||||
and entry.keeper.check()
|
||||
):
|
||||
try:
|
||||
# Refresh the lease
|
||||
success = ray.get(
|
||||
self._kv_store.refresh_lease.remote(
|
||||
name, entry.keepalive_ttl
|
||||
)
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Failed to refresh lease for key: {name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to refresh lease for key: K={name} V={entry.value}. Error: {e}"
|
||||
)
|
||||
|
||||
|
||||
def make_repository(type_="nfs", **kwargs):
|
||||
if type_ == "memory":
|
||||
return MemoryNameRecordRepository(**kwargs)
|
||||
|
@ -939,6 +1385,8 @@ def make_repository(type_="nfs", **kwargs):
|
|||
return RedisNameRecordRepository(**kwargs)
|
||||
elif type_ == "etcd3":
|
||||
return Etcd3NameRecordRepository(**kwargs)
|
||||
elif type_ == "ray":
|
||||
return RayNameResolveRepository(**kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"No such name resolver: {type_}")
|
||||
|
||||
|
|
|
@ -99,3 +99,11 @@ def used_ports(experiment_name, trial_name, host_name):
|
|||
|
||||
def gen_server_manager(experiment_name, trial_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager"
|
||||
|
||||
|
||||
def training_samples(experiment_name, trial_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/training_samples"
|
||||
|
||||
|
||||
def experiment_status(experiment_name, trial_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/experiment_status"
|
||||
|
|
|
@ -28,11 +28,6 @@ def find_free_port(
|
|||
"""Find a free port within the specified range, excluding certain ports."""
|
||||
|
||||
ports_name = names.used_ports(experiment_name, trial_name, gethostip())
|
||||
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
|
||||
if exclude_ports is None:
|
||||
exclude_ports = set(used_ports)
|
||||
else:
|
||||
exclude_ports = exclude_ports.union(set(used_ports))
|
||||
|
||||
free_port = None
|
||||
lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip())
|
||||
|
@ -40,10 +35,16 @@ def find_free_port(
|
|||
with open(lockfile, "w") as fd:
|
||||
# This will block until lock is acquired
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
|
||||
assert len(used_ports) == len(set(used_ports))
|
||||
if exclude_ports is None:
|
||||
exclude_ports = set(used_ports)
|
||||
else:
|
||||
exclude_ports = exclude_ports.union(set(used_ports))
|
||||
try:
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("", 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind(("", 0))
|
||||
port = s.getsockname()[1]
|
||||
if low <= port <= high and port not in exclude_ports:
|
||||
name_resolve.add_subentry(ports_name, str(port))
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
PROLOGUE_FLAG_NAME = "--config"
|
||||
PROLOGUE_FLAG_VAR_NAME = "config"
|
||||
|
|
|
@ -60,6 +60,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
|
|||
"version_start",
|
||||
"version_end",
|
||||
"rewards",
|
||||
"birth_time",
|
||||
)
|
||||
rpcs["actor_train"].input_keys = (
|
||||
*rpcs["actor_train"].input_keys,
|
||||
|
|
|
@ -104,7 +104,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
scheduling=Scheduling.model_worker_default(
|
||||
cpu=self.cpus_per_model_worker,
|
||||
gpu=1,
|
||||
gpu_type=cluster_spec.gpu_type,
|
||||
mem=self.mem_per_model_worker,
|
||||
nodelist=self.nodelist,
|
||||
exclude=self.exclude,
|
||||
|
@ -115,7 +114,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
scheduling=Scheduling.generation_server_default(
|
||||
cpu=self.cpus_per_generation_server,
|
||||
gpu=gen_tp_size,
|
||||
gpu_type=cluster_spec.gpu_type,
|
||||
mem=self.mem_per_generation_server,
|
||||
nodelist=self.nodelist,
|
||||
exclude=self.exclude,
|
||||
|
@ -125,7 +123,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
count=1,
|
||||
scheduling=Scheduling.gserver_manager_default(
|
||||
cpu=self.cpus_per_gserver_manager,
|
||||
gpu_type=cluster_spec.gpu_type,
|
||||
mem=self.mem_per_gserver_manager,
|
||||
nodelist=self.nodelist,
|
||||
exclude=self.exclude,
|
||||
|
@ -135,7 +132,6 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
count=self.n_rollout_workers or train_world_size,
|
||||
scheduling=Scheduling.rollout_worker_default(
|
||||
cpu=self.cpus_per_rollout_worker,
|
||||
gpu_type=cluster_spec.gpu_type,
|
||||
mem=self.mem_per_rollout_worker,
|
||||
nodelist=self.nodelist,
|
||||
exclude=self.exclude,
|
||||
|
|
|
@ -176,7 +176,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
scheduling=Scheduling.model_worker_default(
|
||||
cpu=self.cpus_per_model_worker,
|
||||
gpu=1,
|
||||
gpu_type=cluster_spec.gpu_type,
|
||||
mem=self.mem_per_model_worker,
|
||||
nodelist=self.nodelist,
|
||||
exclude=self.exclude,
|
||||
|
@ -573,6 +572,18 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
)
|
||||
|
||||
def _check_legal_allocation_options(self):
|
||||
if self.n_nodes > self.cluster.n_nodes:
|
||||
raise ValueError(
|
||||
f"Number of used nodes {self.n_nodes} should not be larger than the cluster size {self.cluster.n_nodes}"
|
||||
)
|
||||
if self.n_gpus_per_node > self.cluster.n_gpus_per_node:
|
||||
raise ValueError(
|
||||
f"Number of 7used GPUs per node {self.n_gpus_per_node} should not be larger than the cluster limit {self.cluster.n_gpus_per_node}"
|
||||
)
|
||||
if self.n_nodes > 1 and self.n_gpus_per_node != self.cluster.n_gpus_per_node:
|
||||
raise ValueError(
|
||||
f"For distributed experiments, only using all GPUs on each node is allowed."
|
||||
)
|
||||
if self.n_nodes > 1 and self.mode == "local":
|
||||
raise ValueError(
|
||||
"Cannot run multi-node experiment in local mode, "
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import colorama
|
||||
|
@ -51,6 +52,7 @@ class MathSingleStepAgent(Agent):
|
|||
assert prompt.bs == 1
|
||||
prompt_token_ids = prompt.data["packed_prompts"].cpu().numpy().tolist()
|
||||
qid = prompt.ids[0]
|
||||
birth_time = int(datetime.now().timestamp() * 1000)
|
||||
await obs_queue.put((qid, prompt_token_ids, self.gconfig))
|
||||
|
||||
act: BundledGenerationOutputs = await act_queue.get()
|
||||
|
@ -107,7 +109,7 @@ class MathSingleStepAgent(Agent):
|
|||
"version_start",
|
||||
"version_end",
|
||||
"rewards",
|
||||
"task_ids",
|
||||
"birth_time",
|
||||
],
|
||||
ids=[qid],
|
||||
dtypes=dict(
|
||||
|
@ -119,7 +121,7 @@ class MathSingleStepAgent(Agent):
|
|||
version_end=torch.int,
|
||||
packed_logprobs=torch.float32,
|
||||
rewards=torch.float32,
|
||||
task_ids=torch.long,
|
||||
birth_time=torch.long,
|
||||
),
|
||||
trailing_shapes=dict(
|
||||
packed_input_ids=(),
|
||||
|
@ -130,7 +132,7 @@ class MathSingleStepAgent(Agent):
|
|||
version_start=(),
|
||||
packed_logprobs=(),
|
||||
rewards=(),
|
||||
task_ids=(),
|
||||
birth_time=(),
|
||||
),
|
||||
seqlens=dict(
|
||||
packed_input_ids=[act.seqlens],
|
||||
|
@ -141,7 +143,7 @@ class MathSingleStepAgent(Agent):
|
|||
rewards=[[1 for _ in range(self.gconfig.n)]],
|
||||
version_start=[[1 for _ in range(self.gconfig.n)]],
|
||||
version_end=[[1 for _ in range(self.gconfig.n)]],
|
||||
task_ids=[[1]],
|
||||
birth_time=[[1]],
|
||||
),
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(act.prompt_ids, dtype=torch.long),
|
||||
|
@ -153,6 +155,7 @@ class MathSingleStepAgent(Agent):
|
|||
rewards=torch.tensor(rewards, dtype=torch.float32),
|
||||
version_start=torch.tensor(act.version_start, dtype=torch.int),
|
||||
version_end=torch.tensor(act.version_end, dtype=torch.int),
|
||||
birth_time=torch.tensor([birth_time], dtype=torch.long),
|
||||
prompt_mask=torch.tensor(
|
||||
sum(
|
||||
[
|
||||
|
@ -163,9 +166,18 @@ class MathSingleStepAgent(Agent):
|
|||
),
|
||||
dtype=torch.bool,
|
||||
),
|
||||
task_ids=prompt.data["task_ids"],
|
||||
),
|
||||
)
|
||||
if "task_ids" in prompt.keys:
|
||||
y = SequenceSample(
|
||||
keys=["task_ids"],
|
||||
ids=[qid],
|
||||
dtypes=dict(task_ids=torch.long),
|
||||
trailing_shapes=dict(task_ids=()),
|
||||
seqlens=dict(task_ids=[[1]]),
|
||||
data=dict(task_ids=prompt.data["task_ids"]),
|
||||
)
|
||||
x.update_(y)
|
||||
|
||||
return [x]
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import transformers
|
||||
|
||||
from realhf.base import cluster, constants, logging
|
||||
from realhf.base import constants, logging
|
||||
from realhf.impl.model.utils.padding import pad_input, unpad_input
|
||||
|
||||
logger = logging.getLogger("Modeling Functional Utils")
|
||||
|
@ -166,7 +166,7 @@ def build_leave_one_indices(
|
|||
)
|
||||
|
||||
|
||||
def gather_logprobs(
|
||||
def _gather_logprobs(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
):
|
||||
|
@ -186,8 +186,22 @@ def gather_logprobs(
|
|||
return log_probs_labels
|
||||
|
||||
|
||||
if cluster.spec.name != "wa180":
|
||||
gather_logprobs = torch.compile(gather_logprobs)
|
||||
_gather_logprobs_compiled = None
|
||||
|
||||
|
||||
def gather_logprobs(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
):
|
||||
from realhf.base import cluster
|
||||
|
||||
if cluster.spec.name == "wa180":
|
||||
# torch.compile doesn't work on PPU
|
||||
return _gather_logprobs(logits, labels)
|
||||
global _gather_logprobs_compiled
|
||||
if _gather_logprobs_compiled is None:
|
||||
_gather_logprobs_compiled = torch.compile(_gather_logprobs)
|
||||
return _gather_logprobs_compiled(logits, labels)
|
||||
|
||||
|
||||
def gather_packed_shifted_log_probs(
|
||||
|
|
|
@ -111,13 +111,11 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
cmd: str, # XXX: should be None for workers
|
||||
count: int,
|
||||
cpu: int = 1,
|
||||
gpu_type: str = "geforce",
|
||||
gpu: int = 0,
|
||||
mem: int = 1024, # MB
|
||||
env_vars: Optional[Dict] = None,
|
||||
container_image: str = cluster_spec.gpu_image,
|
||||
container_mounts: str = cluster_spec.default_mount,
|
||||
node_type: Optional[str] = None,
|
||||
container_image: Optional[str] = None,
|
||||
container_mounts: Optional[str] = None,
|
||||
nodelist: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
hostfile: bool = True,
|
||||
|
@ -126,14 +124,14 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
deadline: str = None,
|
||||
time_limit: str = None,
|
||||
):
|
||||
container_image = container_image or cluster_spec.cpu_image
|
||||
container_mounts = container_mounts or cluster_spec.mount
|
||||
# record launch information, do not submit to slurm until `wait()` is called
|
||||
# NOTE: fractional GPU requirement will be resolved automatically in `__post_init__` of SlurnLaunchInfo
|
||||
launch_info = SlurmLaunchInfo(
|
||||
worker_type=worker_type,
|
||||
wprocs_in_job=count,
|
||||
resource_requirement=SlurmResource(
|
||||
mem=mem, cpu=cpu, gpu=gpu, gpu_type=gpu_type
|
||||
),
|
||||
resource_requirement=SlurmResource(mem=mem, cpu=cpu, gpu=gpu),
|
||||
cmd=cmd,
|
||||
run_name=self.run_name,
|
||||
exper_name=self.expr_name,
|
||||
|
@ -141,7 +139,6 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
container_image=container_image,
|
||||
container_mounts=container_mounts,
|
||||
env_vars=env_vars,
|
||||
node_type=node_type,
|
||||
nodelist=nodelist,
|
||||
exclude=exclude,
|
||||
hostfile=hostfile,
|
||||
|
|
|
@ -67,21 +67,8 @@ class SlurmResource:
|
|||
# a data class that represents a slurm resource quota
|
||||
mem: int = 0
|
||||
cpu: int = 0
|
||||
gpu_type: Optional[Literal["tesla", "geforce", "ppu"]] = None
|
||||
gpu: Union[float, int] = 0
|
||||
|
||||
def __check_gpu_type(self, other: SlurmResource) -> str:
|
||||
self_gpu_type = None if self.gpu == 0 else self.gpu_type
|
||||
other_gpu_type = None if other.gpu == 0 else other.gpu_type
|
||||
valid_gpu_type = self_gpu_type == other_gpu_type or (
|
||||
self_gpu_type or other_gpu_type
|
||||
)
|
||||
if not valid_gpu_type:
|
||||
raise InvalidGPUTypeException(
|
||||
f"Cannot add two different gpu types {self_gpu_type}, {other_gpu_type}."
|
||||
)
|
||||
return self_gpu_type if self_gpu_type else other_gpu_type
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
"SlurmResource: \n"
|
||||
|
@ -93,9 +80,6 @@ class SlurmResource:
|
|||
+ " \n"
|
||||
+ "gpu: "
|
||||
+ str(self.gpu)
|
||||
+ " \n"
|
||||
+ "gpu_type: "
|
||||
+ str(self.gpu_type)
|
||||
)
|
||||
|
||||
def __mul__(self, other: int) -> SlurmResource:
|
||||
|
@ -106,7 +90,6 @@ class SlurmResource:
|
|||
mem=self.mem * other,
|
||||
cpu=self.cpu * other,
|
||||
gpu=self.gpu * other,
|
||||
gpu_type=self.gpu_type,
|
||||
)
|
||||
|
||||
def __rmul__(self, other: int) -> SlurmResource:
|
||||
|
@ -120,7 +103,6 @@ class SlurmResource:
|
|||
mem=self.mem + other.mem,
|
||||
cpu=self.cpu + other.cpu,
|
||||
gpu=self.gpu + other.gpu,
|
||||
gpu_type=self.__check_gpu_type(other),
|
||||
)
|
||||
|
||||
def __sub__(self, other: SlurmResource) -> SlurmResource:
|
||||
|
@ -131,28 +113,19 @@ class SlurmResource:
|
|||
mem=self.mem - other.mem,
|
||||
cpu=self.cpu - other.cpu,
|
||||
gpu=self.gpu - other.gpu,
|
||||
gpu_type=self.__check_gpu_type(other),
|
||||
)
|
||||
|
||||
def __neg__(self) -> SlurmResource:
|
||||
return SlurmResource(
|
||||
mem=-self.mem, cpu=-self.cpu, gpu=-self.gpu, gpu_type=self.gpu_type
|
||||
mem=-self.mem,
|
||||
cpu=-self.cpu,
|
||||
gpu=-self.gpu,
|
||||
)
|
||||
|
||||
def __eq__(self, other: SlurmResource) -> bool:
|
||||
return (
|
||||
self.mem == other.mem
|
||||
and self.cpu == other.cpu
|
||||
and self.gpu == other.gpu
|
||||
and self.gpu_type == other.gpu_type
|
||||
)
|
||||
return self.mem == other.mem and self.cpu == other.cpu and self.gpu == other.gpu
|
||||
|
||||
def __lt__(self, other: SlurmResource) -> bool:
|
||||
if self.gpu_type != other.gpu_type:
|
||||
if self.gpu_type is None:
|
||||
return True
|
||||
if self.gpu_type == "geforce":
|
||||
return self.gpu_type < other.gpu_type
|
||||
if self.gpu != other.gpu:
|
||||
return self.gpu < other.gpu
|
||||
if self.cpu != other.cpu:
|
||||
|
@ -162,8 +135,6 @@ class SlurmResource:
|
|||
|
||||
def valid(self) -> bool:
|
||||
# check if it is a valid resource requirement
|
||||
if self.gpu_type not in ["geforce", "tesla", "ppu", None]:
|
||||
return False
|
||||
if self.mem < 0 or self.cpu < 0 or self.gpu < 0:
|
||||
return False
|
||||
return True
|
||||
|
@ -207,7 +178,6 @@ class SlurmLaunchInfo:
|
|||
this string should be of format 'docker://<image>'.
|
||||
container_mounts (str): .
|
||||
env_vars (dict): .
|
||||
node_type (str): .
|
||||
nodelist (str): .
|
||||
exclude (str): .
|
||||
partition (str, optional): default to "all".
|
||||
|
@ -234,7 +204,6 @@ class SlurmLaunchInfo:
|
|||
container_image: str
|
||||
container_mounts: str
|
||||
env_vars: dict
|
||||
node_type: str
|
||||
nodelist: str
|
||||
exclude: str
|
||||
partition: Optional[str] = "all"
|
||||
|
@ -377,7 +346,6 @@ class SlurmLaunchInfo:
|
|||
cmd = self.cmd
|
||||
|
||||
# assert gpu == 1 or gpu == 0, "Slurm job GPU requirement should be resolved to a integer."
|
||||
gpu_type = self.resource_requirement.gpu_type
|
||||
|
||||
if self.multiprog:
|
||||
with open(self.multiprog_path, "w") as f:
|
||||
|
@ -400,8 +368,8 @@ class SlurmLaunchInfo:
|
|||
# In current slurm cluster setup, we can only use "--gres" to
|
||||
# allocate PPUs per node. There are no options to allocate customized
|
||||
# gres per tasks.
|
||||
if gpu_type == "ppu":
|
||||
gres_line = f"--gres={gpu_type}:{cluster.spec.n_gpus_per_node}"
|
||||
if cluster.spec.gpu_type == "ppu":
|
||||
gres_line = f"--gres=ppu:{cluster.spec.n_gpus_per_node}"
|
||||
else:
|
||||
gres_line = f"--gres=gpu:{cluster.spec.n_gpus_per_node}"
|
||||
|
||||
|
@ -596,12 +564,9 @@ def _parse_output_tres_line(tres):
|
|||
res.cpu = int(t.split("=")[1])
|
||||
elif t.startswith("gres/gpu"):
|
||||
prefix, sgpu = t.split("=")
|
||||
if ":" in prefix:
|
||||
res.gpu_type = prefix.split(":")[1]
|
||||
res.gpu = int(sgpu)
|
||||
elif t.startswith("gres/ppu"):
|
||||
prefix, sgpu = t.split("=")
|
||||
res.gpu_type = "ppu"
|
||||
res.gpu = int(sgpu)
|
||||
elif t.startswith("billing"):
|
||||
# slurm default resource to limit number of
|
||||
|
@ -613,7 +578,6 @@ def _parse_output_tres_line(tres):
|
|||
|
||||
|
||||
def available_hostnames(
|
||||
node_type: Optional[List[str]] = None,
|
||||
nodelist: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
partition: Optional[str] = None,
|
||||
|
@ -684,12 +648,7 @@ def available_hostnames(
|
|||
for hn in invalid_hostnames:
|
||||
valid_hostnames.remove(hn)
|
||||
|
||||
return list(
|
||||
filter(
|
||||
lambda x: cluster.node_name_is_node_type(x, node_type),
|
||||
valid_hostnames,
|
||||
)
|
||||
)
|
||||
return valid_hostnames
|
||||
|
||||
|
||||
def get_all_node_resources() -> Dict[str, SlurmResource]:
|
||||
|
@ -721,15 +680,11 @@ def get_all_node_resources() -> Dict[str, SlurmResource]:
|
|||
ctres = _parse_output_tres_line(l)
|
||||
if l.startswith("AllocTRES"):
|
||||
atres = _parse_output_tres_line(l)
|
||||
if ctres.gpu_type is None:
|
||||
ctres.gpu_type = cluster.spec.gpu_type_from_node_name(node_name)
|
||||
if atres.gpu_type is None:
|
||||
atres.gpu_type = ctres.gpu_type
|
||||
rres = ctres - atres
|
||||
if rres.valid():
|
||||
all_rres[node_name] = rres
|
||||
else:
|
||||
all_rres[node_name] = SlurmResource(gpu_type=ctres.gpu_type)
|
||||
all_rres[node_name] = SlurmResource()
|
||||
|
||||
return all_rres
|
||||
|
||||
|
@ -769,7 +724,6 @@ def allocate_resources(
|
|||
prioritized_hosts = set()
|
||||
for info_idx, info in enumerate(infos):
|
||||
valid_hostnames = available_hostnames(
|
||||
node_type=info.node_type,
|
||||
nodelist=info.nodelist,
|
||||
exclude=info.exclude,
|
||||
partition=info.partition,
|
||||
|
@ -833,10 +787,7 @@ def allocate_resources(
|
|||
allocated[hostname] = tmp - task_left
|
||||
all_resources[hostname] = resource
|
||||
if task_left > 0:
|
||||
if (
|
||||
info.resource_requirement.gpu_type == "ppu"
|
||||
and info.resource_requirement.gpu > 0
|
||||
):
|
||||
if cluster.spec.gpu_type == "ppu" and info.resource_requirement.gpu > 0:
|
||||
logger.warning(
|
||||
"For PPU resources, we can only allocate tasks in the "
|
||||
f"granularity of nodes ({cluster.spec.n_gpus_per_node} PPUs)"
|
||||
|
@ -845,7 +796,7 @@ def allocate_resources(
|
|||
f'Unable to allocate {info.n_jobsteps} Jobs with name "{info.slurm_name}". '
|
||||
f"Resource Requirement of this job is: {dataclasses.asdict(info.resource_requirement)}. "
|
||||
f"Valid resources for this job is "
|
||||
f"(according to NodeType={info.node_type}, NodeList={info.nodelist}, "
|
||||
f"(according to NodeList={info.nodelist}, "
|
||||
f"and Exclude={info.exclude}):\n {resource_to_string({k: v for k, v in get_all_node_resources().items() if k in valid_hostnames})}"
|
||||
)
|
||||
for pinfo in infos[:info_idx]:
|
||||
|
@ -878,7 +829,7 @@ def allocate_resources(
|
|||
def show_tesla():
|
||||
all_rres = get_all_node_resources()
|
||||
hostname = socket.gethostname()
|
||||
for k in available_hostnames(node_type=["a100"]):
|
||||
for k in available_hostnames():
|
||||
print(k, all_rres[k])
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ def run_worker(
|
|||
)
|
||||
worker = worker_class(server=server)
|
||||
try:
|
||||
if worker_type in ["rollout_worker"]:
|
||||
if worker_type in ["rollout_worker", "master_worker", "gserver_manager"]:
|
||||
asyncio.run(worker.run_async())
|
||||
else:
|
||||
worker.run()
|
||||
|
|
|
@ -244,7 +244,9 @@ class AsyncIOSequenceBuffer:
|
|||
)
|
||||
return indices
|
||||
|
||||
async def put_batch(self, samples: List[SequenceSample]):
|
||||
async def put_batch(
|
||||
self, samples: List[SequenceSample], birth_times: List[int] | None = None
|
||||
):
|
||||
n = len(samples)
|
||||
|
||||
if n == 0:
|
||||
|
@ -269,9 +271,12 @@ class AsyncIOSequenceBuffer:
|
|||
|
||||
# Set a slight difference in birth time to let the order
|
||||
# be deterministic.
|
||||
if birth_times is None:
|
||||
self._birth_time[indices] = time.monotonic_ns() + np.arange(
|
||||
len(indices), dtype=np.int64
|
||||
)
|
||||
else:
|
||||
self._birth_time[indices] = birth_times
|
||||
|
||||
async with self._lock:
|
||||
self.__buffer._update_has_keys(indices)
|
||||
|
|
|
@ -436,11 +436,9 @@ def run_ray_worker(
|
|||
constants.set_experiment_trial_names(experiment_name, trial_name)
|
||||
|
||||
import realhf.api.core.system_api as system_api
|
||||
from realhf.api.quickstart.entrypoint import (
|
||||
QUICKSTART_CONFIG_CLASSES,
|
||||
QUICKSTART_EXPR_CACHE_PATH,
|
||||
)
|
||||
from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES
|
||||
from realhf.base import importing
|
||||
from realhf.base.constants import QUICKSTART_EXPR_CACHE_PATH
|
||||
|
||||
if os.path.exists(QUICKSTART_EXPR_CACHE_PATH):
|
||||
for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH):
|
||||
|
|
|
@ -10,8 +10,8 @@ import numpy as np
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from realhf import SequenceSample
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.api.core.data_api import SequenceSample
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.topology import ProcessTopology, new_or_get_group
|
||||
from realhf.impl.model.comm.global_comm import filter_match_mwids
|
||||
|
|
|
@ -6,11 +6,11 @@ from typing import *
|
|||
import networkx as nx
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.api.core.data_api import DataBatchMeta, SequenceSample
|
||||
from realhf.api.core.config import ModelShardID
|
||||
from realhf.api.core.data_api import DataBatchMeta, get_shuffle_indices
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import logging
|
||||
from realhf.base import constants, logging, name_resolve, names, seeding
|
||||
from realhf.base.topology import ProcessTopology
|
||||
from realhf.system.buffer import AsyncIOSequenceBuffer
|
||||
from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl
|
||||
|
@ -118,7 +118,10 @@ class FunctionExecutor:
|
|||
|
||||
received_ids = set()
|
||||
|
||||
load_data_iter = 0
|
||||
|
||||
while buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
|
||||
load_data_iter += 1
|
||||
resps = await self.stream.call_async(
|
||||
handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)],
|
||||
handle_type="fetch",
|
||||
|
@ -127,6 +130,7 @@ class FunctionExecutor:
|
|||
)
|
||||
|
||||
all_data = []
|
||||
all_birth_time = []
|
||||
data_cnt = []
|
||||
gpu_id_data = {}
|
||||
for dp_rank, x in enumerate(resps):
|
||||
|
@ -147,13 +151,21 @@ class FunctionExecutor:
|
|||
|
||||
gpu_id = self.stream.route_to(f"__data{dp_rank}__")
|
||||
all_data += x.meta_sample.unpack()
|
||||
all_birth_time += x.birth_times
|
||||
gpu_id_data[gpu_id] = x.meta_sample.unpack()
|
||||
data_cnt.append(x.meta_sample.bs)
|
||||
|
||||
if self.shuffle_dataset:
|
||||
# We load data in a round-robin manner across different DP ranks,
|
||||
# so we also need to shuffle the data to fuse different dataset splits.
|
||||
random.shuffle(all_data)
|
||||
shuffle_indices = get_shuffle_indices(
|
||||
seeding.get_seed()
|
||||
+ 47 * self.ctrl.step_info.global_step
|
||||
+ 97 * load_data_iter,
|
||||
len(all_data),
|
||||
)
|
||||
all_data = [all_data[i] for i in shuffle_indices]
|
||||
all_birth_time = [all_birth_time[i] for i in shuffle_indices]
|
||||
|
||||
if len(all_data) > 0:
|
||||
# Update resource tracker for planning data redistribution.
|
||||
|
@ -167,9 +179,21 @@ class FunctionExecutor:
|
|||
)
|
||||
|
||||
# Store into buffer!
|
||||
buffer_indices = await buffer.put_batch(all_data)
|
||||
assert len(all_data) == len(all_birth_time)
|
||||
buffer_indices = await buffer.put_batch(all_data, all_birth_time)
|
||||
assert len(buffer_indices) == len(all_data)
|
||||
|
||||
training_sample_name = names.training_samples(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
try:
|
||||
n_samples = int(name_resolve.get(training_sample_name))
|
||||
except name_resolve.NameEntryNotFoundError:
|
||||
n_samples = 0
|
||||
name_resolve.add(
|
||||
training_sample_name, str(n_samples + len(all_data)), replace=True
|
||||
)
|
||||
|
||||
blogger.info(
|
||||
f"Master worker loaded {len(all_data)} pieces of data from all dp ranks: "
|
||||
f"{data_cnt} from each rank. "
|
||||
|
@ -178,7 +202,7 @@ class FunctionExecutor:
|
|||
else:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def execute_step(self):
|
||||
async def execute_step(self):
|
||||
logger.info("Waiting for the finish of the execution graph.")
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
@ -190,5 +214,5 @@ class FunctionExecutor:
|
|||
loop.create_task(self.finish_traverse()),
|
||||
]
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*tasks))
|
||||
await asyncio.gather(*tasks)
|
||||
self.buffer_id = (self.buffer_id + 1) % len(self.buffers)
|
||||
|
|
|
@ -7,8 +7,10 @@ from pathlib import Path
|
|||
import requests
|
||||
|
||||
from realhf.api.cli_args import SGLangConfig
|
||||
from realhf.api.core.system_api import ExpStatus
|
||||
from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
|
||||
from realhf.base import (
|
||||
constants,
|
||||
gpu_utils,
|
||||
logging,
|
||||
name_resolve,
|
||||
|
@ -30,7 +32,12 @@ def execute_shell_command(command: str) -> subprocess.Popen:
|
|||
# Replace newline continuations and split the command string.
|
||||
command = command.replace("\\\n", " ").replace("\\", " ")
|
||||
parts = command.split()
|
||||
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)
|
||||
return subprocess.Popen(
|
||||
parts,
|
||||
text=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
|
||||
def launch_server_cmd(command: str, port: int = 30000):
|
||||
|
@ -187,8 +194,22 @@ class GenerationServer(Worker):
|
|||
if self.server_process is None:
|
||||
self.launch_server_subprocess()
|
||||
|
||||
# TODO: we may want to collect some metrics from the server
|
||||
time.sleep(0.05)
|
||||
# Check experiment finish.
|
||||
name = names.experiment_status(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
try:
|
||||
exp_status = name_resolve.wait(name, timeout=300)
|
||||
if exp_status != str(ExpStatus.RUNNING):
|
||||
self.exit()
|
||||
return PollResult(0, 0)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Waiting for experiment status timeout. "
|
||||
"This indicates that the master worker is not running. Exit the worker."
|
||||
)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
return PollResult(0, 0)
|
||||
|
||||
|
|
|
@ -13,40 +13,23 @@ import aiohttp
|
|||
import numpy as np
|
||||
|
||||
from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
|
||||
from realhf.api.core.system_api import ExpStatus
|
||||
from realhf.api.core.system_api import GserverManager as GserverManagerConfig
|
||||
from realhf.base import constants, logging, name_resolve, names, network, recover
|
||||
from realhf.system.worker_base import PollResult, Worker
|
||||
from realhf.base.monitor import RolloutStat
|
||||
from realhf.system.worker_base import AsyncWorker, PollResult, Worker
|
||||
|
||||
logger = logging.getLogger("Generation Manager", "system")
|
||||
|
||||
STALENESS_WARNED = defaultdict(lambda: False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutStat:
|
||||
submit: int = 0
|
||||
accepted: int = 0
|
||||
running: int = 0
|
||||
|
||||
def inc(self):
|
||||
self.submit += 1
|
||||
self.accepted += 1
|
||||
self.running += 1
|
||||
|
||||
def accept(self):
|
||||
self.running -= 1
|
||||
|
||||
def reject(self):
|
||||
self.running -= 1
|
||||
self.accepted -= 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllocateRolloutInput:
|
||||
qid: str
|
||||
|
||||
|
||||
class GserverManager(Worker):
|
||||
class GserverManager(AsyncWorker):
|
||||
"""This worker has the following functionalities:
|
||||
1. As a router, it schedules generation requests and returns the
|
||||
best server urls to clients for submitting generation requests.
|
||||
|
@ -104,7 +87,7 @@ class GserverManager(Worker):
|
|||
self.config.train_batch_size
|
||||
* self.__recover_info.last_step_info.global_step
|
||||
)
|
||||
self.rollout_stat.submit = hist_rollouts
|
||||
self.rollout_stat.submitted = hist_rollouts
|
||||
self.rollout_stat.accepted = hist_rollouts
|
||||
|
||||
return config.worker_info
|
||||
|
@ -187,7 +170,8 @@ class GserverManager(Worker):
|
|||
async with aiohttp.ClientSession(
|
||||
server_url,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.config.flush_request_timeout, sock_connect=30
|
||||
total=self.config.flush_request_timeout,
|
||||
sock_connect=self.config.flush_request_timeout,
|
||||
),
|
||||
) as session:
|
||||
async with session.post(
|
||||
|
@ -198,6 +182,7 @@ class GserverManager(Worker):
|
|||
res = await resp.json()
|
||||
success = res["success"]
|
||||
if success:
|
||||
if "num_paused_requests" in res:
|
||||
logger.info(
|
||||
f"{res['num_paused_requests']} requests are interrupted "
|
||||
f"during updateing weights for server {server_index}: {server_url}"
|
||||
|
@ -228,7 +213,7 @@ class GserverManager(Worker):
|
|||
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
|
||||
return self.server_urls.index(url)
|
||||
|
||||
def _poll(self):
|
||||
async def _poll_async(self):
|
||||
if not self.thread:
|
||||
# Find addresses of generation servers
|
||||
self.server_urls = self._discover_servers(self.config.n_servers)
|
||||
|
@ -244,6 +229,21 @@ class GserverManager(Worker):
|
|||
f"GserverManager HTTP service started in background thread at {self.manager_addr}"
|
||||
)
|
||||
|
||||
# Check experiment finish.
|
||||
name = names.experiment_status(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
try:
|
||||
exp_status = name_resolve.wait(name, timeout=300)
|
||||
if exp_status != str(ExpStatus.RUNNING):
|
||||
self.exit()
|
||||
return PollResult(0, 0)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Waiting for experiment status timeout. "
|
||||
"This indicates that the master worker is not running. Exit the worker."
|
||||
)
|
||||
|
||||
# Check weights.
|
||||
with self.threading_lock:
|
||||
# FIXME: we create a sync point across servers to update weights,
|
||||
|
@ -254,12 +254,13 @@ class GserverManager(Worker):
|
|||
self.flush_requests_and_update_weights(base_url, new_param_path)
|
||||
for base_url in self.server_urls
|
||||
]
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(asyncio.gather(*tasks))
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"Generaion server updated weights from: {new_param_path}")
|
||||
|
||||
if self.schedule_policy == "least_token_usage":
|
||||
tasks = [
|
||||
self._get_server_token_usage(server_url) for server_url in self.server_urls
|
||||
self._get_server_token_usage(server_url)
|
||||
for server_url in self.server_urls
|
||||
]
|
||||
loop = asyncio.get_event_loop()
|
||||
token_usages = loop.run_until_complete(asyncio.gather(*tasks))
|
||||
|
@ -304,7 +305,8 @@ class GserverManager(Worker):
|
|||
async with aiohttp.ClientSession(
|
||||
server_url,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.config.flush_request_timeout, sock_connect=30
|
||||
total=self.config.flush_request_timeout,
|
||||
sock_connect=self.config.flush_request_timeout,
|
||||
),
|
||||
) as session:
|
||||
async with session.get("/metrics") as resp:
|
||||
|
@ -319,7 +321,8 @@ class GserverManager(Worker):
|
|||
async with aiohttp.ClientSession(
|
||||
server_url,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.config.flush_request_timeout, sock_connect=30
|
||||
total=self.config.flush_request_timeout,
|
||||
sock_connect=self.config.flush_request_timeout,
|
||||
),
|
||||
) as session:
|
||||
async with session.get(f"/metrics") as resp:
|
||||
|
@ -332,8 +335,16 @@ class GserverManager(Worker):
|
|||
f"Failed to get num running requests metrics from {server_url}"
|
||||
)
|
||||
|
||||
def get_training_sample_cnt(self):
|
||||
name = names.training_samples(self.experiment_name, self.trial_name)
|
||||
try:
|
||||
return int(name_resolve.get(name))
|
||||
except name_resolve.NameEntryNotFoundError:
|
||||
return 0
|
||||
|
||||
def is_staled(self):
|
||||
global_sample_cnt = self.rollout_stat.accepted
|
||||
# Use counter written by the trainer, local counter is inaccurate
|
||||
global_sample_cnt = self.get_training_sample_cnt() + self.rollout_stat.running
|
||||
expected_version = global_sample_cnt // self.config.train_batch_size
|
||||
version = self._last_param_realloc_step
|
||||
staled = expected_version > self.config.max_head_offpolicyness + version
|
||||
|
@ -406,13 +417,22 @@ class GserverManager(Worker):
|
|||
is_staled = self.is_staled()
|
||||
reason = ""
|
||||
if has_capacity and not is_staled:
|
||||
self.rollout_stat.inc()
|
||||
self.rollout_stat.submitted += 1
|
||||
self.rollout_stat.running += 1
|
||||
logger.info(
|
||||
f"Allocate rollout for qid {req.qid}. "
|
||||
f"Submitted: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
return dict(success=True, reason=reason)
|
||||
else:
|
||||
if not has_capacity:
|
||||
reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}"
|
||||
if is_staled:
|
||||
global_sample_cnt = self.rollout_stat.accepted
|
||||
global_sample_cnt = (
|
||||
self.get_training_sample_cnt() + self.rollout_stat.running
|
||||
)
|
||||
expected_version = (
|
||||
global_sample_cnt // self.config.train_batch_size
|
||||
)
|
||||
|
@ -435,10 +455,15 @@ class GserverManager(Worker):
|
|||
), "server request count < 0"
|
||||
self._qid_to_server_url.pop(resp_meta.qid)
|
||||
self._gen_tokens += resp_meta.n_tokens
|
||||
self.rollout_stat.running -= 1
|
||||
if resp_meta.accepted:
|
||||
self.rollout_stat.accept()
|
||||
else:
|
||||
self.rollout_stat.reject()
|
||||
self.rollout_stat.accepted += 1
|
||||
logger.info(
|
||||
f"Finish rollout for qid {resp_meta.qid}. "
|
||||
f"Running: {self.rollout_stat.running}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}"
|
||||
)
|
||||
return dict(success=True)
|
||||
|
||||
port = network.find_free_port(
|
||||
|
|
|
@ -23,6 +23,7 @@ import realhf.system.request_reply_stream as request_reply_stream
|
|||
import realhf.system.worker_base as worker_base
|
||||
from realhf.api.core.config import ModelName
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.api.core.system_api import ExpStatus
|
||||
from realhf.base import (
|
||||
constants,
|
||||
logging,
|
||||
|
@ -40,7 +41,7 @@ logger = logging.getLogger("master worker", "system")
|
|||
blogger = logging.getLogger("benchmark")
|
||||
|
||||
|
||||
class MasterWorker(worker_base.Worker):
|
||||
class MasterWorker(worker_base.AsyncWorker):
|
||||
global_exp_tik = time.perf_counter()
|
||||
|
||||
def _configure(self, config: config_pkg.MasterWorker):
|
||||
|
@ -142,6 +143,22 @@ class MasterWorker(worker_base.Worker):
|
|||
f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}."
|
||||
)
|
||||
|
||||
# Recover the previous number of training samples
|
||||
train_rpcs = list(
|
||||
filter(
|
||||
lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
|
||||
self.__model_rpcs,
|
||||
)
|
||||
)
|
||||
train_batch_size = train_rpcs[0].n_seqs
|
||||
hist_samples = (
|
||||
train_batch_size * self.__recover_info.last_step_info.global_step
|
||||
)
|
||||
training_sample_name = names.training_samples(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
name_resolve.add(training_sample_name, str(hist_samples), replace=True)
|
||||
|
||||
# for benchmark
|
||||
self.e2e_time_history = []
|
||||
self.__benchmark_steps = config.exp_ctrl.benchmark_steps
|
||||
|
@ -274,6 +291,7 @@ class MasterWorker(worker_base.Worker):
|
|||
]
|
||||
|
||||
# wandb init, connect to remote wandb host
|
||||
if self.wandb_config.mode != "disabled":
|
||||
wandb.login()
|
||||
wandb.init(
|
||||
mode=self.wandb_config.mode,
|
||||
|
@ -327,7 +345,7 @@ class MasterWorker(worker_base.Worker):
|
|||
global_step=-1,
|
||||
)
|
||||
|
||||
def _poll(self):
|
||||
async def __poll_async(self):
|
||||
is_new_epoch = False
|
||||
|
||||
if not self.__initialized:
|
||||
|
@ -369,7 +387,7 @@ class MasterWorker(worker_base.Worker):
|
|||
self.logger.info(s)
|
||||
|
||||
# Traverse over the dataflow graph for once.
|
||||
self.func_executor.execute_step()
|
||||
await self.func_executor.execute_step()
|
||||
|
||||
# Post-process.
|
||||
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
|
||||
|
@ -431,6 +449,18 @@ class MasterWorker(worker_base.Worker):
|
|||
|
||||
return worker_base.PollResult(sample_count=1, batch_count=1)
|
||||
|
||||
async def _poll_async(self):
|
||||
name = names.experiment_status(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
name_resolve.add(name, ExpStatus.RUNNING, replace=True)
|
||||
try:
|
||||
r = await self.__poll_async()
|
||||
except Exception as e:
|
||||
name_resolve.add(name, ExpStatus.ABORTED, replace=True)
|
||||
raise e
|
||||
return r
|
||||
|
||||
def _log_training_stats(self, e2e_time: float, time_since_configure: float):
|
||||
# calculate flops
|
||||
#########################################
|
||||
|
@ -482,8 +512,15 @@ class MasterWorker(worker_base.Worker):
|
|||
+ colorama.Style.RESET_ALL
|
||||
)
|
||||
|
||||
# Update experiment status to inform other workers
|
||||
name = names.experiment_status(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
name_resolve.add(name, ExpStatus.COMPLETE, replace=True)
|
||||
|
||||
# Send requests to pause model workers.
|
||||
# Model workers will not respond to this message.
|
||||
# FIXME: request to model workers is unnecessary
|
||||
self.__stream.request(
|
||||
handlers=list(range(self.config.n_model_workers)),
|
||||
handle_type="reset",
|
||||
|
|
|
@ -19,8 +19,7 @@ import realhf.api.core.dfg as dfg
|
|||
import realhf.api.core.system_api as config_pkg
|
||||
import realhf.base.recover as recover
|
||||
import realhf.system.request_reply_stream as request_reply_stream
|
||||
from realhf import ModelShardID
|
||||
from realhf.api.core.config import ModelName
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import constants, logging, stats_tracker, topology
|
||||
from realhf.system.buffer import AsyncIOSequenceBuffer
|
||||
|
|
|
@ -465,6 +465,7 @@ class ModelWorker(worker_base.Worker):
|
|||
# because we may want to copy huggingface configurations from it, and
|
||||
# th next recover save will remove this symlink.
|
||||
dst_path = Path(model_path).parent / "_tmp_ckpt"
|
||||
shutil.rmtree(dst_path, ignore_errors=True)
|
||||
shutil.copytree(model_path, dst_path)
|
||||
os.unlink(model_path)
|
||||
os.system(f"mv {str(dst_path)} {model_path}")
|
||||
|
@ -669,14 +670,26 @@ class ModelWorker(worker_base.Worker):
|
|||
self.data_manager.store(x)
|
||||
assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded)
|
||||
|
||||
if len(data_loaded) > 0:
|
||||
meta_sample = data_api.SequenceSample.gather(data_loaded).meta()
|
||||
else:
|
||||
meta_sample = None
|
||||
birth_times = []
|
||||
if len(data_loaded) > 0:
|
||||
sample = data_api.SequenceSample.gather(data_loaded)
|
||||
meta_sample = sample.meta()
|
||||
if "birth_time" in sample.keys:
|
||||
birth_times = (
|
||||
sample.data["birth_time"].flatten().cpu().numpy().tolist()
|
||||
)
|
||||
assert len(birth_times) == meta_sample.bs
|
||||
else:
|
||||
birth_times = (
|
||||
time.monotonic_ns()
|
||||
+ np.arange(len(data_loaded), dtype=np.int64)
|
||||
).tolist()
|
||||
|
||||
res = data_api.DataBatchMeta(
|
||||
dp_rank=dp_rank,
|
||||
meta_sample=meta_sample,
|
||||
birth_times=birth_times,
|
||||
)
|
||||
elif request.handle_name == "spec":
|
||||
# Raw dataset without filtering.
|
||||
|
|
|
@ -80,7 +80,7 @@ class PartialRolloutManager:
|
|||
async with session.post(
|
||||
f"http://{self.gserver_manager_addr}/schedule_request",
|
||||
json=asdict(req_meta),
|
||||
timeout=ClientTimeout(total=self.timeout, sock_connect=30),
|
||||
timeout=ClientTimeout(total=self.timeout, sock_connect=self.timeout),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
|
|
@ -44,16 +44,9 @@ class ZMQJsonPusher:
|
|||
TypeError: If data is not JSON-serializable
|
||||
zmq.ZMQError: If ZeroMQ operation fails
|
||||
"""
|
||||
try:
|
||||
# Directly encode to bytes without intermediate string
|
||||
json_bytes = asbytes(orjson.dumps(data))
|
||||
self.socket.send(json_bytes, flags=zmq.NOBLOCK, copy=False)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise TypeError(f"Data not JSON-serializable: {e}")
|
||||
except zmq.ZMQError as e:
|
||||
if e.errno == zmq.EAGAIN:
|
||||
logger.warning("Push operation would block (queue full)")
|
||||
raise
|
||||
self.socket.send(json_bytes, copy=False)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
|
|
|
@ -14,6 +14,7 @@ from aiohttp.client import ClientTimeout
|
|||
from realhf.api.core.agent_api import make_agent
|
||||
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer, make_dataset
|
||||
from realhf.api.core.env_api import make_env
|
||||
from realhf.api.core.system_api import ExpStatus
|
||||
from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig
|
||||
from realhf.base import (
|
||||
constants,
|
||||
|
@ -24,6 +25,7 @@ from realhf.base import (
|
|||
recover,
|
||||
seeding,
|
||||
)
|
||||
from realhf.base.monitor import RolloutStat
|
||||
from realhf.system.partial_rollout import PartialRolloutManager
|
||||
from realhf.system.push_pull_stream import NameResolvingZmqPusher
|
||||
from realhf.system.worker_base import AsyncWorker, PollResult
|
||||
|
@ -80,8 +82,9 @@ class RolloutWorker(AsyncWorker):
|
|||
self.gserver_manager_addr = None
|
||||
self.rollout_tasks: Dict[Hashable, asyncio.Task] = {}
|
||||
|
||||
# recover info
|
||||
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
||||
# Since the rollout worker doesn't compute staleness,
|
||||
# we don't need to recover rollout_stat here.
|
||||
self.rollout_stat = RolloutStat()
|
||||
|
||||
return config.worker_info
|
||||
|
||||
|
@ -182,10 +185,8 @@ class RolloutWorker(AsyncWorker):
|
|||
self.data_generator = enumerate(self.dataloader)
|
||||
return None
|
||||
|
||||
# NOTE: no need to ignore ids during recover, because model workers will do so
|
||||
data_id = cur_sample.ids[0]
|
||||
if self.__recover_run and data_id in self.__recover_info.hash_vals_to_ignore:
|
||||
self.__recover_info.hash_vals_to_ignore.remove(data_id)
|
||||
return None
|
||||
assert data_id not in self.rollout_tasks
|
||||
return cur_sample
|
||||
|
||||
|
@ -195,11 +196,14 @@ class RolloutWorker(AsyncWorker):
|
|||
f"http://{self.gserver_manager_addr}/allocate_rollout",
|
||||
json=dict(qid=qid),
|
||||
timeout=ClientTimeout(
|
||||
total=self.config.rollout_request_timeout, sock_connect=30
|
||||
total=self.config.rollout_request_timeout,
|
||||
sock_connect=self.config.rollout_request_timeout,
|
||||
),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
res = await resp.json()
|
||||
if not res["success"]:
|
||||
logger.info(f"Cannot allocate new rollout because: {res['reason']}")
|
||||
return res["success"]
|
||||
|
||||
async def _poll_async(self):
|
||||
|
@ -213,6 +217,21 @@ class RolloutWorker(AsyncWorker):
|
|||
f"Time consumed: {time.perf_counter() - tik}s"
|
||||
)
|
||||
|
||||
# Check experiment finish.
|
||||
name = names.experiment_status(
|
||||
constants.experiment_name(), constants.trial_name()
|
||||
)
|
||||
try:
|
||||
exp_status = name_resolve.wait(name, timeout=300)
|
||||
if exp_status != str(ExpStatus.RUNNING):
|
||||
self.exit()
|
||||
return PollResult(0, 0)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Waiting for experiment status timeout. "
|
||||
"This indicates that the master worker is not running. Exit the worker."
|
||||
)
|
||||
|
||||
if self.push_stream is None:
|
||||
# Initialize stream after configure to ensure that puller names have been written.
|
||||
self.push_stream = NameResolvingZmqPusher(
|
||||
|
@ -236,13 +255,24 @@ class RolloutWorker(AsyncWorker):
|
|||
qid = data.ids[0]
|
||||
can_rollout = await self.allocate_new_rollout(qid)
|
||||
if can_rollout:
|
||||
assert qid not in self.act_queues
|
||||
self.act_queues[qid] = asyncio.Queue(1024)
|
||||
|
||||
task = asyncio.create_task(self.rollout_task(qid, data))
|
||||
assert qid not in self.rollout_tasks
|
||||
self.rollout_tasks[qid] = task
|
||||
|
||||
self._cur_data = None
|
||||
|
||||
self.rollout_stat.submitted += 1
|
||||
self.rollout_stat.running += 1
|
||||
logger.info(
|
||||
f"Submit a new rollout for qid {qid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
|
||||
# Run rollouts and wait
|
||||
done, *_ = await asyncio.gather(
|
||||
self.poll_rollout_task(),
|
||||
|
@ -261,10 +291,13 @@ class RolloutWorker(AsyncWorker):
|
|||
self.rollout_tasks.pop(qid)
|
||||
self.act_queues.pop(qid)
|
||||
|
||||
self.rollout_stat.running -= 1
|
||||
|
||||
accepted = False
|
||||
if len(trajs) > 0:
|
||||
accepted = True
|
||||
self.push_stream.push([traj.as_json_compatible() for traj in trajs])
|
||||
self.rollout_stat.accepted += 1
|
||||
|
||||
n_tokens = 0
|
||||
for traj in trajs:
|
||||
|
@ -278,11 +311,18 @@ class RolloutWorker(AsyncWorker):
|
|||
"/finish_rollout",
|
||||
json=info,
|
||||
timeout=ClientTimeout(
|
||||
total=self.config.rollout_request_timeout, sock_connect=30
|
||||
total=self.config.rollout_request_timeout,
|
||||
sock_connect=self.config.rollout_request_timeout,
|
||||
),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
assert (await resp.json())["success"]
|
||||
logger.info(
|
||||
f"Finish rollout for qid {qid}. "
|
||||
f"Submit: {self.rollout_stat.submitted}, "
|
||||
f"running: {self.rollout_stat.running}, "
|
||||
f"accepted: {self.rollout_stat.accepted}."
|
||||
)
|
||||
|
||||
for traj in trajs:
|
||||
batch_count += traj.bs
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
@ -12,9 +13,11 @@ from realhf.api.core.data_api import (
|
|||
make_dataset,
|
||||
register_dataset,
|
||||
)
|
||||
from realhf.base import constants
|
||||
from realhf.base import constants, logging
|
||||
from realhf.system.push_pull_stream import NameResolvingZmqPuller
|
||||
|
||||
logger = logging.getLogger("StreamDataset")
|
||||
|
||||
|
||||
class PullerStreamDataset(Dataset):
|
||||
def __init__(
|
||||
|
@ -45,15 +48,12 @@ class PullerStreamDataset(Dataset):
|
|||
del dataset, datasets
|
||||
|
||||
self.pull_timeout_ms = pull_timeout_ms
|
||||
self.data_queue = queue.Queue(maxsize=self.dataset_size)
|
||||
self.data_queue = queue.Queue(maxsize=self.dataset_size * util.world_size)
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Pass ZMQ context (thread-safe) and let worker create the socket
|
||||
self.util = util
|
||||
self.worker_thread = threading.Thread(
|
||||
target=self._pull_data_worker,
|
||||
daemon=True,
|
||||
)
|
||||
self.worker_thread = threading.Thread(target=self._pull_data_worker)
|
||||
self.worker_thread.start()
|
||||
|
||||
def _pull_data_worker(self):
|
||||
|
@ -71,13 +71,19 @@ class PullerStreamDataset(Dataset):
|
|||
processed_data = [
|
||||
SequenceSample.from_json_compatible(x) for x in data
|
||||
]
|
||||
self.data_queue.put_nowait(processed_data)
|
||||
logger.debug(
|
||||
f"Get data {[x.ids[0] for x in processed_data]} from puller stream."
|
||||
)
|
||||
self.data_queue.put(processed_data)
|
||||
except queue.Empty:
|
||||
logger.debug(f"No data from puller stream.")
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
finally:
|
||||
# Ensure socket is closed in the same thread
|
||||
del stream
|
||||
# Exit if this thread has an error
|
||||
sys.exit(1)
|
||||
|
||||
def __getitem__(self, idx: int) -> Optional[Any]:
|
||||
samples = []
|
||||
|
|
|
@ -60,3 +60,7 @@ orjson>=3.10.16
|
|||
flask
|
||||
setuptools>=62.3.0,<75.9
|
||||
func_timeout
|
||||
jupyter-book
|
||||
uvloop>=0.21.0
|
||||
uvicorn>=0.34.2
|
||||
fastapi>=0.115.12
|
||||
|
|
270
setup.py
|
@ -2,280 +2,10 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List, Set
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
import torch.utils.cpp_extension as torch_cpp_ext
|
||||
from packaging.version import Version, parse
|
||||
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
|
||||
|
||||
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
# Supported NVIDIA GPU architectures.
|
||||
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
|
||||
|
||||
def _is_cuda() -> bool:
|
||||
return os.getenv("REAL_CUDA", "0") == "1"
|
||||
|
||||
|
||||
# Compiler flags.
|
||||
CXX_FLAGS = ["-g", "-O3", "-std=c++17"]
|
||||
NVCC_FLAGS = ["-O3", "-std=c++17"]
|
||||
|
||||
if _is_cuda() and CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find CUDA_HOME. In GPU environment, CUDA must be available to build the package."
|
||||
)
|
||||
|
||||
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
|
||||
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
|
||||
def glob(pattern: str):
|
||||
root = Path(__name__).parent
|
||||
return [str(p) for p in root.glob(pattern)]
|
||||
|
||||
|
||||
def get_pybind11_include_path() -> str:
|
||||
pybind11_meta = subprocess.check_output(
|
||||
"python3 -m pip show pybind11", shell=True
|
||||
).decode("ascii")
|
||||
for line in pybind11_meta.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("Location: "):
|
||||
return os.path.join(line.split(": ")[1], "pybind11", "include")
|
||||
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
"""Get the CUDA version from nvcc.
|
||||
|
||||
Adapted from
|
||||
https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||
"""
|
||||
nvcc_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
output = nvcc_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
|
||||
return nvcc_cuda_version
|
||||
|
||||
|
||||
def get_torch_arch_list() -> Set[str]:
|
||||
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||
# e.g."8.0" or "7.5,8.0,8.6+PTX".Here, the "8.6+PTX" option asks the
|
||||
# compiler to additionally include PTX code that can be runtime - compiled
|
||||
# and executed on the 8.6 or newer architectures.While the PTX code will
|
||||
# not give the best performance on the newer architectures, it provides
|
||||
# forward compatibility.
|
||||
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
|
||||
if env_arch_list is None:
|
||||
return set()
|
||||
|
||||
# List are separated by; or space.
|
||||
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
|
||||
if not torch_arch_list:
|
||||
return set()
|
||||
|
||||
# Filter out the invalid architectures and print a warning.
|
||||
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
|
||||
{s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}
|
||||
)
|
||||
arch_list = torch_arch_list.intersection(valid_archs)
|
||||
# If none of the specified architectures are valid, raise an error.
|
||||
if not arch_list:
|
||||
raise RuntimeError(
|
||||
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
f"variable ({env_arch_list}) is supported. "
|
||||
f"Supported CUDA architectures are: {valid_archs}."
|
||||
)
|
||||
invalid_arch_list = torch_arch_list - valid_archs
|
||||
if invalid_arch_list:
|
||||
warnings.warn(
|
||||
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
||||
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
||||
f"({env_arch_list}). Supported CUDA architectures are: "
|
||||
f"{valid_archs}.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return arch_list
|
||||
|
||||
|
||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||
compute_capabilities = get_torch_arch_list()
|
||||
|
||||
if _is_cuda() and not compute_capabilities:
|
||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||
# GPUs on the current machine.
|
||||
device_count = torch.cuda.device_count()
|
||||
for i in range(device_count):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
if major < 7:
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability below 7.0 are not supported."
|
||||
)
|
||||
compute_capabilities.add(f"{major}.{minor}")
|
||||
|
||||
ext_modules = []
|
||||
|
||||
if _is_cuda():
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
# Validate the NVCC CUDA version.
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||
if nvcc_cuda_version < Version("11.1") and any(
|
||||
cc.startswith("8.6") for cc in compute_capabilities
|
||||
):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for compute capability 8.6."
|
||||
)
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
compute_capabilities = set(
|
||||
cc for cc in compute_capabilities if not cc.startswith("8.9")
|
||||
)
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0."
|
||||
)
|
||||
|
||||
NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||
if int(capability[0]) >= 8:
|
||||
NVCC_FLAGS_PUNICA += [
|
||||
"-gencode",
|
||||
f"arch=compute_{num},code=sm_{num}",
|
||||
]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS_PUNICA += [
|
||||
"-gencode",
|
||||
f"arch=compute_{num},code=compute_{num}",
|
||||
]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
|
||||
num_threads = min(os.cpu_count(), nvcc_threads)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
|
||||
if nvcc_cuda_version >= Version("11.8"):
|
||||
NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
|
||||
|
||||
# changes for punica kernels
|
||||
NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
|
||||
REMOVE_NVCC_FLAGS = [
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
]
|
||||
for flag in REMOVE_NVCC_FLAGS:
|
||||
with contextlib.suppress(ValueError):
|
||||
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
|
||||
|
||||
os.makedirs(os.path.join(ROOT_DIR, "realhf", "_C"), exist_ok=True)
|
||||
|
||||
no_ext = os.getenv("REAL_NO_EXT", "0") == "1"
|
||||
|
||||
if not no_ext and _is_cuda():
|
||||
gae_extension = CUDAExtension(
|
||||
name="realhf._C.cugae",
|
||||
sources=[
|
||||
"csrc/cugae/gae.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS
|
||||
+ [
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"--use_fast_math",
|
||||
],
|
||||
},
|
||||
libraries=["cuda"],
|
||||
)
|
||||
ext_modules.append(gae_extension)
|
||||
|
||||
interval_op_cuda = CUDAExtension(
|
||||
name="realhf._C.interval_op_cuda",
|
||||
sources=[
|
||||
"csrc/interval_op/interval_op.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
libraries=["cuda"],
|
||||
)
|
||||
ext_modules.append(interval_op_cuda)
|
||||
|
||||
if not no_ext:
|
||||
interval_extension = setuptools.Extension(
|
||||
name="realhf._C.interval_op",
|
||||
sources=[
|
||||
"csrc/interval_op/interval_op.cpp",
|
||||
],
|
||||
language="c++",
|
||||
extra_compile_args=[
|
||||
"-O3",
|
||||
"-Wall",
|
||||
"-std=c++17",
|
||||
],
|
||||
include_dirs=[
|
||||
get_pybind11_include_path(),
|
||||
],
|
||||
)
|
||||
ext_modules.append(interval_extension)
|
||||
|
||||
|
||||
setuptools.setup(
|
||||
name="realhf",
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
packages=setuptools.find_packages(),
|
||||
include_package_data=True,
|
||||
package_data={
|
||||
"": [
|
||||
"csrc/**/*.cu",
|
||||
"csrc/**/*.cuh",
|
||||
"csrc/**/*.hpp",
|
||||
"csrc/**/*.cpp",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
|
|
@ -85,7 +85,9 @@ async def test_collect_trajectory_happy_path(agent, mock_env, mock_prompt, mock_
|
|||
sample = result[0]
|
||||
assert sample.ids == [str(123)]
|
||||
assert torch.equal(sample.data["packed_prompts"], torch.tensor([1, 2, 3]))
|
||||
assert torch.equal(sample.data["rewards"], torch.tensor([0.8, 1.2]))
|
||||
# r = [0.5, 0.7]
|
||||
# ((r - 0.5) * 2 - bias) * scaling, bias=0.1, scaling=2.0
|
||||
assert torch.equal(sample.data["rewards"], torch.tensor([-0.2, 0.6]))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -184,7 +184,7 @@ def _test_data_transfer(
|
|||
]
|
||||
storage_tracker.add_data_synced(
|
||||
gpu_id,
|
||||
ids=[i + dp_rank * world_size for i in range(world_size)],
|
||||
ids=[str(i + dp_rank * world_size) for i in range(world_size)],
|
||||
key=key,
|
||||
is_owner=True,
|
||||
)
|
||||
|
@ -199,7 +199,7 @@ def _test_data_transfer(
|
|||
dist.all_reduce(input_ids)
|
||||
|
||||
s = SequenceSample.from_default(
|
||||
ids=[i + dp_rank * world_size for i in range(world_size)],
|
||||
ids=[str(i + dp_rank * world_size) for i in range(world_size)],
|
||||
seqlens=seqlens.numpy().tolist(),
|
||||
data=dict(input_ids=input_ids),
|
||||
)
|
||||
|
@ -216,7 +216,7 @@ def _test_data_transfer(
|
|||
|
||||
dist.barrier()
|
||||
|
||||
all_ids = list(range(world_size * from_topo.get_dim("data")))
|
||||
all_ids = list(map(str, range(world_size * from_topo.get_dim("data"))))
|
||||
np.random.shuffle(all_ids)
|
||||
_all_ids = [all_ids]
|
||||
dist.broadcast_object_list(_all_ids, src=0)
|
||||
|
@ -236,7 +236,7 @@ def _test_data_transfer(
|
|||
)
|
||||
]
|
||||
size_per_dp = len(all_ids) // dp_size
|
||||
dests[gpu_id] = [coord.data * size_per_dp + i for i in range(size_per_dp)]
|
||||
dests[gpu_id] = [str(coord.data * size_per_dp + i) for i in range(size_per_dp)]
|
||||
|
||||
for gpu_id in range(world_size):
|
||||
if gpu_id not in dests:
|
||||
|
@ -253,29 +253,21 @@ def _test_data_transfer(
|
|||
print("success")
|
||||
|
||||
|
||||
parallelism = [(1, 4, 2), (1, 8, 1)]
|
||||
parallelism = [(4, 1, 1), (2, 2, 2), (1, 8, 1), (3, 2, 1), (2, 1, 2), (1, 2, 2)]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.cpu_count() < 32 or testing.get_free_mem_gb() < 50,
|
||||
reason="The parameter reallocation test requires at least 32 CPUs and 50GB memory.",
|
||||
)
|
||||
@pytest.mark.parametrize("from_pp_dp_tp", [(1, 4, 2)])
|
||||
@pytest.mark.parametrize("to_pp_dp_tp", [(1, 8, 1)])
|
||||
@pytest.mark.parametrize("from_pp_dp_tp", parallelism)
|
||||
@pytest.mark.parametrize("to_pp_dp_tp", parallelism)
|
||||
@pytest.mark.distributed
|
||||
def test_data_transfer(
|
||||
tmp_path,
|
||||
from_pp_dp_tp: Tuple,
|
||||
to_pp_dp_tp: Tuple,
|
||||
):
|
||||
expr_name = uuid.uuid4()
|
||||
trial_name = uuid.uuid4()
|
||||
constants.set_force_cpu(True)
|
||||
test_impl = LocalMultiProcessTest(
|
||||
world_size=16,
|
||||
func=_test_data_transfer,
|
||||
expr_name=expr_name,
|
||||
trial_name=trial_name,
|
||||
timeout_secs=300,
|
||||
tmp_path=tmp_path,
|
||||
from_pp_dp_tp=from_pp_dp_tp,
|
||||
|
|
|
@ -526,10 +526,7 @@ def _test_para_realloc(
|
|||
parallelism = [(4, 1, 1), (2, 2, 2), (1, 8, 1), (3, 2, 1), (2, 1, 2), (1, 2, 2)]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.cpu_count() < 32 or testing.get_free_mem_gb() < 50,
|
||||
reason="The parameter reallocation test requires at least 32 CPUs and 50GB memory.",
|
||||
)
|
||||
@pytest.mark.skip("NCCL-based parameter reallocation is not used currently.")
|
||||
@pytest.mark.parametrize("model_family_name", ["gpt2", "llama"])
|
||||
@pytest.mark.parametrize("is_critic", [False, True])
|
||||
@pytest.mark.parametrize("from_pp_dp_tp", parallelism)
|
||||
|
|
|
@ -2,10 +2,7 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
|
@ -44,6 +41,7 @@ def maybe_synchronize_cuda():
|
|||
"n_intervals", list(reversed([1, 100, 500, 1000, 2000, 4000, 10000, 100000]))
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32, torch.float16])
|
||||
@pytest.mark.gpu
|
||||
def test_get(n_intervals: int, dtype: torch.dtype):
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
|
|
@ -100,9 +100,9 @@ def test_denominator_edge_cases(tracker):
|
|||
tracker.denominator(mask=zero_mask)
|
||||
tracker.stat(denominator="mask", value=torch.FloatTensor([1.0, 2.0]))
|
||||
results = tracker.export()
|
||||
assert torch.isnan(torch.tensor(results["value/min"])) # Should be inf
|
||||
assert torch.isnan(torch.tensor(results["value/max"])) # Should be -inf
|
||||
assert results["value/avg"] == 0.0
|
||||
assert "value/min" not in results
|
||||
assert "value/max" not in results
|
||||
assert "value/avg" not in results
|
||||
|
||||
|
||||
def test_key_specific_export(tracker):
|
||||
|
|
|
@ -19,6 +19,7 @@ from realhf.base.name_resolve import (
|
|||
BACKENDS = [
|
||||
("memory", {}),
|
||||
("nfs", {}),
|
||||
("ray", {}),
|
||||
]
|
||||
if os.environ.get("REAL_ETCD_ADDR"):
|
||||
BACKENDS.append(
|
||||
|
@ -61,6 +62,12 @@ def name_resolve(request):
|
|||
repo = Etcd3NameRecordRepository(**kwargs)
|
||||
yield repo
|
||||
repo.reset()
|
||||
elif backend_type == "ray":
|
||||
from realhf.base.name_resolve import RayNameResolveRepository
|
||||
|
||||
repo = RayNameResolveRepository(**kwargs)
|
||||
yield repo
|
||||
repo.reset()
|
||||
|
||||
|
||||
def test_basic_add_get(name_resolve):
|
||||
|
@ -381,18 +388,19 @@ def test_wait_with_concurrent_delete(name_resolve):
|
|||
def add_then_delete():
|
||||
time.sleep(0.1)
|
||||
name_resolve.add("test_wait_key", "test_value")
|
||||
time.sleep(0.1)
|
||||
time.sleep(1.0)
|
||||
name_resolve.delete("test_wait_key")
|
||||
|
||||
thread = threading.Thread(target=add_then_delete, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait with a timeout long enough to capture the key
|
||||
value = name_resolve.wait("test_wait_key", timeout=2.0, poll_frequency=0.05)
|
||||
value = name_resolve.wait("test_wait_key", timeout=3.0, poll_frequency=0.05)
|
||||
assert value == "test_value"
|
||||
|
||||
# Wait for the thread to complete
|
||||
thread.join()
|
||||
time.sleep(0.5)
|
||||
|
||||
# Verify the key was deleted
|
||||
with pytest.raises(NameEntryNotFoundError):
|
||||
|
|
|
@ -76,9 +76,10 @@ def test_buffer_recover(
|
|||
trial_name=trial_name,
|
||||
mode="local",
|
||||
# allocation_mode=f"m1d{dp}p1",
|
||||
nodelist="slurmd-01",
|
||||
allocation_mode="manual",
|
||||
inf=MFCConfig(
|
||||
device_mesh="NODE01:0,1,2,3,4,5,6,7",
|
||||
device_mesh="slurmd-01:0,1,2,3,4,5,6,7",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=2,
|
||||
|
@ -86,7 +87,7 @@ def test_buffer_recover(
|
|||
),
|
||||
),
|
||||
train=MFCConfig(
|
||||
device_mesh="NODE01:8,9,10,11,12,13,14,15",
|
||||
device_mesh="slurmd-01:8,9,10,11,12,13,14,15",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=2,
|
||||
|
|
|
@ -51,6 +51,7 @@ def math_code_dataset(request, save_path):
|
|||
return dataset
|
||||
|
||||
|
||||
@pytest.mark.skip("symmetric allocation is not used")
|
||||
@pytest.mark.parametrize(
|
||||
"dp,pp,mp",
|
||||
[
|
||||
|
@ -121,9 +122,138 @@ def test_ppo_symm(
|
|||
run_test_exp(exp_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gdp,gpp,gmp",
|
||||
[
|
||||
(2, 1, 1),
|
||||
(1, 1, 2),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"dp,pp,mp",
|
||||
[
|
||||
(2, 1, 1),
|
||||
(1, 2, 1),
|
||||
(1, 1, 2),
|
||||
],
|
||||
)
|
||||
def test_ppo_decoupled(
|
||||
tmp_path_factory,
|
||||
tokenizer,
|
||||
math_code_dataset,
|
||||
save_path,
|
||||
cpu_hf_model,
|
||||
mconfig,
|
||||
dp,
|
||||
pp,
|
||||
mp,
|
||||
gdp,
|
||||
gpp,
|
||||
gmp,
|
||||
):
|
||||
# Setup experiment env. Should be done before any other operations.
|
||||
log_root = tmp_path_factory.mktemp("ppo")
|
||||
cluster.spec.fileroot = str(log_root)
|
||||
constants.set_experiment_trial_names(
|
||||
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
|
||||
)
|
||||
|
||||
minbs = 32
|
||||
exp_cfg = PPOMATHConfig(
|
||||
experiment_name=testing._DEFAULT_EXPR_NAME,
|
||||
trial_name=testing._DEFAULT_TRIAL_NAME,
|
||||
mode="local",
|
||||
allocation_mode=f"manual",
|
||||
nodelist="slurmd-01",
|
||||
n_nodes=1,
|
||||
n_gpus_per_node=mp * dp * pp + gmp * gdp * gpp,
|
||||
actor=ModelTrainEvalConfig(
|
||||
path=str(save_path),
|
||||
init_from_scratch=True,
|
||||
backend="mock_train",
|
||||
),
|
||||
ref=ModelTrainEvalConfig(
|
||||
path=str(save_path),
|
||||
init_from_scratch=True,
|
||||
),
|
||||
critic=ModelTrainEvalConfig(
|
||||
path=str(save_path),
|
||||
init_from_scratch=True,
|
||||
init_critic_from_actor=True,
|
||||
backend="mock_train",
|
||||
),
|
||||
actor_gen=MFCConfig(
|
||||
device_mesh="slurmd-01:0,1",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=gmp,
|
||||
pipeline_parallel_size=gpp,
|
||||
data_parallel_size=gdp,
|
||||
),
|
||||
),
|
||||
actor_train=MFCConfig(
|
||||
device_mesh="slurmd-01:2,3",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=mp,
|
||||
pipeline_parallel_size=pp,
|
||||
data_parallel_size=dp,
|
||||
),
|
||||
),
|
||||
critic_train=MFCConfig(
|
||||
device_mesh="slurmd-01:2,3",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=mp,
|
||||
pipeline_parallel_size=pp,
|
||||
data_parallel_size=dp,
|
||||
),
|
||||
),
|
||||
critic_inf=MFCConfig(
|
||||
device_mesh="slurmd-01:2,3",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=mp,
|
||||
pipeline_parallel_size=pp,
|
||||
data_parallel_size=dp,
|
||||
),
|
||||
),
|
||||
ref_inf=MFCConfig(
|
||||
device_mesh="slurmd-01:2,3",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=mp,
|
||||
pipeline_parallel_size=pp,
|
||||
data_parallel_size=dp,
|
||||
),
|
||||
),
|
||||
rew_inf=MFCConfig(
|
||||
device_mesh="slurmd-01:2,3",
|
||||
parallel=ParallelismConfig(
|
||||
tensor_parallel_size=mp,
|
||||
pipeline_parallel_size=pp,
|
||||
data_parallel_size=dp,
|
||||
),
|
||||
),
|
||||
dataset=PromptOnlyDatasetConfig(
|
||||
path=str(save_path / "math_code_dataset.jsonl"),
|
||||
max_prompt_len=mconfig.n_positions // 2,
|
||||
train_bs_n_seqs=minbs,
|
||||
fill_to_max_length=False,
|
||||
),
|
||||
ppo=PPOHyperparameters(
|
||||
gen=GenerationHyperparameters(
|
||||
max_new_tokens=4,
|
||||
min_new_tokens=4,
|
||||
greedy=True,
|
||||
use_cuda_graph=False,
|
||||
),
|
||||
),
|
||||
group_size=2,
|
||||
)
|
||||
|
||||
run_test_exp(exp_cfg)
|
||||
|
||||
|
||||
# The global resharding strategy, where all MFCs
|
||||
# occupy the same device mesh but with different
|
||||
# parallelization strategies.
|
||||
@pytest.mark.skip("Global resharding is not used.")
|
||||
@pytest.mark.parametrize("actor_gen", [(1, 2, 1)])
|
||||
@pytest.mark.parametrize("actor_train", [(1, 1, 2)])
|
||||
@pytest.mark.parametrize("critic_inf", [(1, 1, 2)])
|
||||
|
@ -244,6 +374,7 @@ def test_ppo_global_reshard(
|
|||
|
||||
# Actor/critic train and ref_inf/rew_inf are on disjoint
|
||||
# device meshes and executed concurrently.
|
||||
@pytest.mark.skip("Critic is not used.")
|
||||
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
|
||||
@pytest.mark.parametrize("critic_inf", [(2, 1, 2)])
|
||||
def test_ppo_param_realloc_sub_device_mesh(
|
||||
|
@ -306,7 +437,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
),
|
||||
),
|
||||
actor_gen=MFCConfig(
|
||||
device_mesh="NODE01:0,1,2,3",
|
||||
device_mesh="slurmd-01:0,1,2,3",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=actor_gen[0],
|
||||
tensor_parallel_size=actor_gen[1],
|
||||
|
@ -314,7 +445,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
),
|
||||
),
|
||||
actor_train=MFCConfig(
|
||||
device_mesh="NODE01:4,5,6,7",
|
||||
device_mesh="slurmd-01:4,5,6,7",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=4,
|
||||
tensor_parallel_size=1,
|
||||
|
@ -322,7 +453,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
),
|
||||
),
|
||||
critic_inf=MFCConfig(
|
||||
device_mesh="NODE01:4,5,6,7",
|
||||
device_mesh="slurmd-01:4,5,6,7",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=critic_inf[0],
|
||||
tensor_parallel_size=critic_inf[1],
|
||||
|
@ -330,7 +461,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
),
|
||||
),
|
||||
rew_inf=MFCConfig(
|
||||
device_mesh="NODE01:4,5,6,7",
|
||||
device_mesh="slurmd-01:4,5,6,7",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=4,
|
||||
tensor_parallel_size=1,
|
||||
|
@ -338,7 +469,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
),
|
||||
),
|
||||
ref_inf=MFCConfig(
|
||||
device_mesh="NODE01:4,5,6,7",
|
||||
device_mesh="slurmd-01:4,5,6,7",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=1,
|
||||
tensor_parallel_size=2,
|
||||
|
@ -346,7 +477,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
),
|
||||
),
|
||||
critic_train=MFCConfig(
|
||||
device_mesh="NODE01:4,5,6,7",
|
||||
device_mesh="slurmd-01:4,5,6,7",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=2,
|
||||
tensor_parallel_size=1,
|
||||
|
@ -389,6 +520,7 @@ def test_ppo_save(
|
|||
allocation_mode="manual",
|
||||
n_nodes=1,
|
||||
n_gpus_per_node=2,
|
||||
nodelist="slurmd-01",
|
||||
actor=ModelTrainEvalConfig(
|
||||
path=str(save_path),
|
||||
init_from_scratch=True,
|
||||
|
@ -436,7 +568,7 @@ def test_ppo_save(
|
|||
)
|
||||
),
|
||||
actor_train=MFCConfig(
|
||||
device_mesh="NODE01:0",
|
||||
device_mesh="slurmd-01:0",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=1,
|
||||
tensor_parallel_size=1,
|
||||
|
@ -465,7 +597,7 @@ def test_ppo_save(
|
|||
)
|
||||
),
|
||||
critic_train=MFCConfig(
|
||||
device_mesh="NODE01:1",
|
||||
device_mesh="slurmd-01:1",
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=1,
|
||||
tensor_parallel_size=1,
|
||||
|
|
|
@ -28,6 +28,9 @@ def run_model_worker(cfg, mw, barrier, expr_name=None):
|
|||
|
||||
system_api.ALL_EXPERIMENT_CLASSES = {}
|
||||
register_experiment(expr_name or testing._DEFAULT_EXPR_NAME, lambda: cfg)
|
||||
constants.set_experiment_trial_names(
|
||||
mw.worker_info.experiment_name, mw.worker_info.trial_name
|
||||
)
|
||||
|
||||
worker = ModelWorker()
|
||||
logger.info("Configuring model worker...")
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
import asyncio
|
||||
import random
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from realhf.api.cli_args import GenerationHyperparameters
|
||||
from realhf.api.core.model_api import (
|
||||
APIGenerateInput,
|
||||
APIGenerateOutput,
|
||||
BundledGenerationOutputs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sglang_client(request):
|
||||
from sglang.test.test_utils import is_in_ci
|
||||
|
||||
if is_in_ci():
|
||||
from patch import launch_server_cmd
|
||||
else:
|
||||
from sglang.utils import launch_server_cmd
|
||||
|
||||
from sglang.utils import terminate_process, wait_for_server
|
||||
|
||||
server_process, port = launch_server_cmd(
|
||||
f"python -m sglang.launch_server --model-path {request.param} --host 0.0.0.0 --skip-tokenizer-init "
|
||||
)
|
||||
|
||||
wait_for_server(f"http://localhost:{port}")
|
||||
from realhf.impl.model.backend.sglang import SGLangAPIClient
|
||||
|
||||
client = SGLangAPIClient(
|
||||
generate_url=f"http://localhost:{port}/generate",
|
||||
update_weights_url=f"http://localhost:{port}/update_weights_from_disk",
|
||||
)
|
||||
yield client
|
||||
terminate_process(server_process)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sglang_client",
|
||||
["/storage/openpsi/models/Qwen__Qwen2.5-7B-Instruct/"],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.parametrize("group_size", [16])
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_generate(sglang_client, group_size):
|
||||
bs = 8
|
||||
# genlen = 16384
|
||||
genlen = 10
|
||||
prompt_len = 100
|
||||
async with sglang_client:
|
||||
tasks = []
|
||||
qids = []
|
||||
for i in range(bs):
|
||||
qid = str(uuid.uuid4())
|
||||
prompt_ids = [random.randint(10, 100) for _ in range(prompt_len)]
|
||||
gconfig = GenerationHyperparameters(
|
||||
n=group_size,
|
||||
max_new_tokens=genlen,
|
||||
)
|
||||
req = APIGenerateInput(
|
||||
qid=qid,
|
||||
prompt_ids=prompt_ids,
|
||||
input_ids=prompt_ids,
|
||||
gconfig=gconfig,
|
||||
return_logprob=True,
|
||||
)
|
||||
tasks.append(sglang_client.async_add_generate_request(req, stream=False))
|
||||
qids.append(qid)
|
||||
|
||||
outputs = {}
|
||||
for r in asyncio.as_completed(tasks):
|
||||
out = await r
|
||||
outputs[out.qid] = out
|
||||
|
||||
results = [outputs[key] for key in qids]
|
||||
assert all([isinstance(r, APIGenerateOutput) for r in results])
|
||||
batch_token_ids = []
|
||||
batch_logprobs = []
|
||||
max_seqlen = -1
|
||||
for x in results:
|
||||
max_seqlen = max(max_seqlen, max(x.output_lens))
|
||||
batch_token_ids += x.output_ids
|
||||
batch_logprobs += x.output_logprobs
|
||||
|
||||
pad_token_id = 0
|
||||
# To be consistent with our internal implementation,
|
||||
# we should pad generated tokens and logprobs
|
||||
batch_token_ids = [
|
||||
t + [pad_token_id] * (max_seqlen - len(t)) for t in batch_token_ids
|
||||
]
|
||||
batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]
|
||||
|
||||
tokens = torch.tensor(batch_token_ids, dtype=torch.long, device="cpu")
|
||||
assert tokens.shape == (bs * group_size, genlen)
|
||||
logprobs = torch.tensor(batch_logprobs, dtype=torch.float32, device="cpu")
|
||||
assert logprobs.shape == (bs * group_size, genlen)
|
|
@ -86,7 +86,7 @@ def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_datase
|
|||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
from realhf.impl.model.interface.rw_interface import MultiTaskRewardInterface
|
||||
from realhf.impl.model.interface.math_rw_interface import MultiTaskRewardInterface
|
||||
|
||||
with constants.model_scope(testing.MODEL_NAME):
|
||||
interface = MultiTaskRewardInterface(
|
||||
|
|
|
@ -55,7 +55,7 @@ def check_sequences_consistency(
|
|||
)
|
||||
|
||||
|
||||
def test_fn(
|
||||
def _fn(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
path: str,
|
||||
|
@ -194,12 +194,12 @@ def test_fn(
|
|||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
|
||||
def check_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
world_size = dp * tp
|
||||
procs = [
|
||||
mp.Process(
|
||||
target=test_fn,
|
||||
target=_fn,
|
||||
args=(
|
||||
i,
|
||||
world_size,
|
||||
|
@ -236,7 +236,7 @@ if __name__ == "__main__":
|
|||
# pp=1,
|
||||
# tp=1,
|
||||
# )
|
||||
test_sglang_consistency(
|
||||
check_sglang_consistency(
|
||||
tp=2,
|
||||
dp=2,
|
||||
path=path,
|
||||
|
|
|
@ -59,7 +59,7 @@ def check_sequences_consistency(
|
|||
)
|
||||
|
||||
|
||||
def test_fn(
|
||||
def _fn(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
path: str,
|
||||
|
@ -203,12 +203,12 @@ def test_fn(
|
|||
print("success")
|
||||
|
||||
|
||||
def test_vllm_tp_consistency(tp: int, dp: int, path: str, model_family_name: str):
|
||||
def check_vllm_tp_consistency(tp: int, dp: int, path: str, model_family_name: str):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
world_size = dp * tp
|
||||
procs = [
|
||||
mp.Process(
|
||||
target=test_fn,
|
||||
target=_fn,
|
||||
args=(
|
||||
i,
|
||||
world_size,
|
||||
|
@ -236,7 +236,7 @@ def test_vllm_tp_consistency(tp: int, dp: int, path: str, model_family_name: str
|
|||
if __name__ == "__main__":
|
||||
# for model_family_name in _available_model_classes:
|
||||
# path = MODEL_CLASS_TO_PATH[model_family_name]
|
||||
test_vllm_tp_consistency(
|
||||
check_vllm_tp_consistency(
|
||||
tp=2,
|
||||
dp=2,
|
||||
path="/storage/models/Qwen__Qwen2-1.5B-Instruct/",
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from realhf.api.core.config import ModelName
|
||||
|
@ -118,7 +122,8 @@ def mock_servers():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def gserver_manager(mock_servers):
|
||||
def gserver_manager(request, mock_servers):
|
||||
train_batch_size, offpolicyness = request.param
|
||||
testing.clear_name_resolve()
|
||||
constants.set_experiment_trial_names(
|
||||
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
|
||||
|
@ -135,6 +140,10 @@ def gserver_manager(mock_servers):
|
|||
config = GserverManagerConfig(
|
||||
model_name=ModelName("default", 0),
|
||||
n_servers=N_SERVERS,
|
||||
train_batch_size=train_batch_size,
|
||||
max_head_offpolicyness=offpolicyness,
|
||||
flush_request_timeout=300,
|
||||
max_concurrent_rollouts=128,
|
||||
schedule_policy="round_robin",
|
||||
worker_info=WorkerInformation(
|
||||
experiment_name=testing._DEFAULT_EXPR_NAME,
|
||||
|
@ -151,6 +160,7 @@ def gserver_manager(mock_servers):
|
|||
m.exit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
|
||||
@pytest.mark.asyncio
|
||||
async def test_schedule_policy(gserver_manager):
|
||||
# Test round-robin scheduling
|
||||
|
@ -171,14 +181,8 @@ async def test_schedule_policy(gserver_manager):
|
|||
assert idx3 == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_weight_update(gserver_manager):
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from realhf.api.core.model_api import GenReqMeta
|
||||
|
||||
client = TestClient(gserver_manager.app)
|
||||
|
||||
@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
|
||||
def test_weight_update(gserver_manager):
|
||||
# Set up a new parameter version
|
||||
name = names.model_version(
|
||||
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME, "default"
|
||||
|
@ -187,22 +191,14 @@ async def test_weight_update(gserver_manager):
|
|||
global UPDATE_WEIGHTS_CALL_COUNT
|
||||
UPDATE_WEIGHTS_CALL_COUNT.clear()
|
||||
|
||||
req_meta = GenReqMeta(
|
||||
"2",
|
||||
prompt_len=100,
|
||||
group_size=2,
|
||||
new_token_budget=1024,
|
||||
predicted_new_tokens=None,
|
||||
)
|
||||
|
||||
client.post("/schedule_request", json=dataclasses.asdict(req_meta))
|
||||
gserver_manager._poll()
|
||||
assert gserver_manager._last_param_realloc_step == 1
|
||||
assert len(UPDATE_WEIGHTS_CALL_COUNT) == N_SERVERS
|
||||
for v in UPDATE_WEIGHTS_CALL_COUNT.values():
|
||||
assert v == 1
|
||||
|
||||
# weights updated, no more weights update
|
||||
client.post("/schedule_request", json=dataclasses.asdict(req_meta))
|
||||
gserver_manager._poll()
|
||||
assert gserver_manager._last_param_realloc_step == 1
|
||||
assert len(UPDATE_WEIGHTS_CALL_COUNT) == N_SERVERS
|
||||
for v in UPDATE_WEIGHTS_CALL_COUNT.values():
|
||||
|
@ -213,6 +209,7 @@ async def test_weight_update(gserver_manager):
|
|||
name_resolve.delete(name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
|
||||
def test_server_lifecycle(gserver_manager):
|
||||
# Test that the server starts and stops properly
|
||||
assert gserver_manager.thread is not None
|
||||
|
@ -224,6 +221,7 @@ def test_server_lifecycle(gserver_manager):
|
|||
assert not gserver_manager.thread.is_alive()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_server_endpoints(gserver_manager):
|
||||
# Test the FastAPI endpoints
|
||||
|
@ -253,6 +251,74 @@ async def test_http_server_endpoints(gserver_manager):
|
|||
assert responses == set(gserver_manager.server_urls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gserver_manager", [(4, 1000)], indirect=True)
|
||||
def test_unique_server_urls(gserver_manager):
|
||||
# Ensure server URLs are unique
|
||||
assert len(set(gserver_manager.server_urls)) == len(gserver_manager.server_urls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gserver_manager", [(4, 0), (4, 1), (4, 4)], indirect=True)
|
||||
@pytest.mark.parametrize("n_clients", [1, 2, 3])
|
||||
def test_offpolicyness_control(n_clients, gserver_manager):
|
||||
train_batch_size = gserver_manager.config.train_batch_size
|
||||
offpolicyness = gserver_manager.config.max_head_offpolicyness
|
||||
addr = gserver_manager.manager_addr
|
||||
|
||||
res_queue = queue.Queue(n_clients)
|
||||
|
||||
async def _client_thread(res_queue):
|
||||
cnt = 0
|
||||
for _ in range(train_batch_size):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://{addr}/allocate_rollout",
|
||||
json=dict(qid="a"),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
res = await resp.json()
|
||||
cnt += int(res["success"])
|
||||
res_queue.put(cnt)
|
||||
|
||||
def run_client(res_queue):
|
||||
asyncio.run(_client_thread(res_queue))
|
||||
|
||||
total_cnt = 0
|
||||
|
||||
jobs = [
|
||||
threading.Thread(target=run_client, args=(res_queue,)) for _ in range(n_clients)
|
||||
]
|
||||
for job in jobs:
|
||||
job.start()
|
||||
for job in jobs:
|
||||
job.join()
|
||||
total_cnt += res_queue.get()
|
||||
assert total_cnt == min(
|
||||
train_batch_size * n_clients, (1 + offpolicyness) * train_batch_size
|
||||
)
|
||||
|
||||
# Increase the model version by 1
|
||||
version_name = names.model_version(
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"default",
|
||||
)
|
||||
name_resolve.add(version_name, "1")
|
||||
gserver_manager._poll()
|
||||
|
||||
# Run the rollout worker again
|
||||
jobs = [
|
||||
threading.Thread(target=run_client, args=(res_queue,)) for _ in range(n_clients)
|
||||
]
|
||||
for job in jobs:
|
||||
job.start()
|
||||
for job in jobs:
|
||||
job.join()
|
||||
total_cnt += res_queue.get()
|
||||
|
||||
# The rollout worker should produce new samples
|
||||
assert total_cnt == min(
|
||||
train_batch_size * n_clients * 2, (2 + offpolicyness) * train_batch_size
|
||||
)
|
||||
|
||||
# Final clean up
|
||||
name_resolve.delete(version_name)
|
||||
|
|
|
@ -84,6 +84,7 @@ def partial_rollout_manager():
|
|||
reply_queue=reply_queue,
|
||||
new_tokens_per_chunk=new_tokens_per_chunk,
|
||||
tokenizer=mock_tokenizer,
|
||||
timeout=300,
|
||||
)
|
||||
yield manager
|
||||
# Cleanup if needed
|
||||
|
|
|
@ -1,174 +0,0 @@
|
|||
import asyncio
|
||||
import copy
|
||||
from asyncio.queues import QueueEmpty
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from realhf.api.core.config import (
|
||||
AgentAbstraction,
|
||||
DatasetAbstraction,
|
||||
EnvServiceAbstraction,
|
||||
ModelName,
|
||||
)
|
||||
from realhf.api.core.model_api import (
|
||||
BundledGenerationOutputs,
|
||||
GenerationHyperparameters,
|
||||
)
|
||||
from realhf.api.core.system_api import RolloutWorker as RolloutWorkerConfig
|
||||
from realhf.api.core.system_api import WorkerInformation
|
||||
from realhf.base import constants, name_resolve, names, network, testing
|
||||
from realhf.system.push_pull_stream import NameResolvingZmqPusher
|
||||
from realhf.system.rollout_worker import RolloutWorker
|
||||
from tests.fixtures import *
|
||||
|
||||
N_PULLERS = 3
|
||||
|
||||
|
||||
class MockPartialRolloutManager:
|
||||
def __init__(self, request_queue, reply_queue, **kwargs):
|
||||
self.request_queue = request_queue
|
||||
self.reply_queue = reply_queue
|
||||
self.internal_queue = []
|
||||
|
||||
def get_num_gen_requests(self):
|
||||
return len(self.internal_queue)
|
||||
|
||||
async def run_step(self):
|
||||
async def poll_fresh_requests():
|
||||
for _ in range(8):
|
||||
try:
|
||||
qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
|
||||
assert isinstance(qid, str)
|
||||
assert isinstance(prompt_token_ids, list)
|
||||
assert all(isinstance(x, int) for x in prompt_token_ids)
|
||||
assert isinstance(gconfig, GenerationHyperparameters)
|
||||
self.internal_queue.append(qid)
|
||||
except QueueEmpty:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def poll_old_requests():
|
||||
for _ in range(8):
|
||||
if random.random() < 0.5 and len(self.internal_queue) > 0:
|
||||
# responses may not return in order
|
||||
idx = random.randint(0, len(self.internal_queue) - 1)
|
||||
qid = self.internal_queue.pop(idx)
|
||||
out = BundledGenerationOutputs(
|
||||
qid=qid,
|
||||
prompt_ids=[1],
|
||||
output_ids=[[2], [3]],
|
||||
seqs=[[1, 2], [1, 3]],
|
||||
logprobs=[[0.0, 0.1], [0.0, 2.0]],
|
||||
no_eos=[True, True],
|
||||
version_start=[0, 1],
|
||||
version_end=[1, 2],
|
||||
)
|
||||
await self.reply_queue.put(out)
|
||||
else:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await asyncio.gather(poll_fresh_requests(), poll_old_requests())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rollout_workers(request):
|
||||
testing.clear_name_resolve()
|
||||
constants.set_experiment_trial_names(
|
||||
testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
|
||||
)
|
||||
|
||||
# add name resolve to make zmq pusher happy
|
||||
puller_ports = network.find_multiple_free_ports(N_PULLERS)
|
||||
for puller_index in range(N_PULLERS):
|
||||
name = names.stream_pullers(
|
||||
testing._DEFAULT_EXPR_NAME,
|
||||
testing._DEFAULT_TRIAL_NAME,
|
||||
)
|
||||
name_resolve.add_subentry(name, str(puller_index))
|
||||
name = names.push_pull_stream(
|
||||
testing._DEFAULT_EXPR_NAME,
|
||||
testing._DEFAULT_TRIAL_NAME,
|
||||
f"puller{puller_index}",
|
||||
)
|
||||
name_resolve.add(name, f"localhost:{puller_ports[puller_index]}")
|
||||
|
||||
with (
|
||||
patch.object(NameResolvingZmqPusher, "push", return_value=None) as mock_push,
|
||||
patch(
|
||||
"realhf.system.rollout_worker.PartialRolloutManager",
|
||||
new=MockPartialRolloutManager,
|
||||
),
|
||||
):
|
||||
|
||||
ms = [RolloutWorker() for _ in range(request.param)]
|
||||
yield ms
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rollout_workers", [1, 2, 3], indirect=True)
|
||||
@pytest.mark.parametrize("offpolicyness", [0, 1, 4])
|
||||
@pytest.mark.asyncio
|
||||
async def test_offpolicyness_control(
|
||||
rollout_workers, save_path, dataset, offpolicyness
|
||||
):
|
||||
train_batch_size = 8
|
||||
config = RolloutWorkerConfig(
|
||||
base_seed=0,
|
||||
model_name=ModelName("default", 0),
|
||||
max_head_offpolicyness=offpolicyness,
|
||||
train_batch_size=train_batch_size,
|
||||
tokenizer_path="/storage/openpsi/models/Qwen__Qwen2.5-1.5B/",
|
||||
new_tokens_per_chunk=1024,
|
||||
max_concurrent_rollouts=128,
|
||||
env=EnvServiceAbstraction("null"),
|
||||
agent=AgentAbstraction("null", dict(episode_length=5, traj_size=5)),
|
||||
datasets=[
|
||||
DatasetAbstraction(
|
||||
"prompt", args=dict(dataset_path=str(save_path / "dataset.jsonl"))
|
||||
)
|
||||
],
|
||||
worker_info=WorkerInformation(
|
||||
experiment_name=testing._DEFAULT_EXPR_NAME,
|
||||
trial_name=testing._DEFAULT_TRIAL_NAME,
|
||||
worker_type="rollout_worker",
|
||||
worker_count=N_PULLERS * 2,
|
||||
worker_index=0,
|
||||
),
|
||||
)
|
||||
for i, m in enumerate(rollout_workers):
|
||||
config = copy.deepcopy(config)
|
||||
config.worker_info.worker_index = i
|
||||
m._configure(config)
|
||||
|
||||
for i in range(10 * (offpolicyness + 1)):
|
||||
for m in rollout_workers:
|
||||
await m._poll_async()
|
||||
|
||||
# Ensure that data is not overly produced
|
||||
for m in rollout_workers:
|
||||
assert m.agent.ACT_GET_CNT > 0
|
||||
assert (
|
||||
(offpolicyness + 1) * train_batch_size >= m.push_stream.push.call_count > 0
|
||||
)
|
||||
|
||||
# Increase the model version by 1
|
||||
version_name = names.model_version(
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
config.model_name.role,
|
||||
)
|
||||
name_resolve.add(version_name, "1")
|
||||
|
||||
# Run the rollout worker again
|
||||
for i in range(10 * (offpolicyness + 1)):
|
||||
for m in rollout_workers:
|
||||
await m._poll_async()
|
||||
|
||||
# The rollout worker should produce new samples
|
||||
for m in rollout_workers:
|
||||
assert (offpolicyness + 2) * train_batch_size >= m.push_stream.push.call_count
|
||||
assert (offpolicyness + 1) * train_batch_size < m.push_stream.push.call_count
|
||||
|
||||
# Final clean up
|
||||
name_resolve.delete(version_name)
|
||||
for m in rollout_workers:
|
||||
await m._exit_async_tasks()
|
|
@ -0,0 +1,65 @@
|
|||
import dataclasses
|
||||
import datetime
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import hydra
|
||||
import yaml
|
||||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
from realhf.api.quickstart.entrypoint import kind_reminder
|
||||
from realhf.base.constants import init_constants
|
||||
from realhf.experiments.async_exp.async_ppo_math_exp import AsyncPPOMATHConfig
|
||||
from training.utils import run_experiment
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs/async-ppo")
|
||||
def main_ppo_math(args):
|
||||
# NOTE: we import logging here to avoid hydra logging overwrite
|
||||
import realhf.base.logging as logging
|
||||
|
||||
logger = logging.getLogger("quickstart", "colored")
|
||||
|
||||
# Overwrite the python dataclass configuration with yaml
|
||||
default_args = OmegaConf.structured(AsyncPPOMATHConfig)
|
||||
args = OmegaConf.merge(default_args, args)
|
||||
args: AsyncPPOMATHConfig = OmegaConf.to_object(args)
|
||||
|
||||
# Set experiment trial name.
|
||||
exp_name = args.experiment_name
|
||||
if args.trial_name == MISSING:
|
||||
args.trial_name = trial_name = (
|
||||
f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
|
||||
)
|
||||
else:
|
||||
trial_name = args.trial_name
|
||||
|
||||
if args.mode != "ray":
|
||||
raise RuntimeError("This script only supports the `ray` mode.")
|
||||
|
||||
init_constants(args)
|
||||
|
||||
from realhf.base.constants import LOG_ROOT
|
||||
|
||||
# Save overwritten configuration to yaml
|
||||
config_save_path = os.path.join(
|
||||
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
|
||||
)
|
||||
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
|
||||
with open(config_save_path, "w") as f:
|
||||
config_dict: Dict = dataclasses.asdict(args)
|
||||
yaml.dump(
|
||||
config_dict,
|
||||
f,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
||||
kind_reminder("async-ppo-math", logger, args)
|
||||
|
||||
run_experiment(args, exp_name, trial_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command: python3 training/main_async_ppo.py --config-name async-ppo-1.7b-gpu8
|
||||
main_ppo_math()
|
|
@ -0,0 +1,65 @@
|
|||
import dataclasses
|
||||
import datetime
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import hydra
|
||||
import yaml
|
||||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
from realhf.api.quickstart.entrypoint import kind_reminder
|
||||
from realhf.base.constants import init_constants
|
||||
from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
|
||||
from training.utils import run_experiment
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs/ppo")
|
||||
def main(args):
|
||||
# NOTE: we import logging here to avoid hydra logging overwrite
|
||||
import realhf.base.logging as logging
|
||||
|
||||
logger = logging.getLogger("quickstart", "colored")
|
||||
|
||||
# Overwrite the python dataclass configuration with yaml
|
||||
default_args = OmegaConf.structured(PPOMATHConfig)
|
||||
args = OmegaConf.merge(default_args, args)
|
||||
args: PPOMATHConfig = OmegaConf.to_object(args)
|
||||
|
||||
# Set experiment trial name.
|
||||
exp_name = args.experiment_name
|
||||
if args.trial_name == MISSING:
|
||||
args.trial_name = trial_name = (
|
||||
f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
|
||||
)
|
||||
else:
|
||||
trial_name = args.trial_name
|
||||
|
||||
if args.mode != "ray":
|
||||
raise RuntimeError("This script only supports the `ray` mode.")
|
||||
|
||||
init_constants(args)
|
||||
|
||||
from realhf.base.constants import LOG_ROOT
|
||||
|
||||
# Save overwritten configuration to yaml
|
||||
config_save_path = os.path.join(
|
||||
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
|
||||
)
|
||||
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
|
||||
with open(config_save_path, "w") as f:
|
||||
config_dict: Dict = dataclasses.asdict(args)
|
||||
yaml.dump(
|
||||
config_dict,
|
||||
f,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
||||
kind_reminder("ppo-math", logger, args)
|
||||
|
||||
run_experiment(args, exp_name, trial_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command: python3 training/main_ppo.py --config-name ppo-1.5b-gpu32
|
||||
main()
|
|
@ -0,0 +1,65 @@
|
|||
import dataclasses
|
||||
import datetime
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import hydra
|
||||
import yaml
|
||||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
from realhf.api.quickstart.entrypoint import kind_reminder
|
||||
from realhf.base.constants import init_constants
|
||||
from realhf.experiments.common.sft_exp import SFTConfig
|
||||
from training.utils import run_experiment
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs/sft")
|
||||
def main(args):
|
||||
# NOTE: we import logging here to avoid hydra logging overwrite
|
||||
import realhf.base.logging as logging
|
||||
|
||||
logger = logging.getLogger("quickstart", "colored")
|
||||
|
||||
# Overwrite the python dataclass configuration with yaml
|
||||
default_args = OmegaConf.structured(SFTConfig)
|
||||
args = OmegaConf.merge(default_args, args)
|
||||
args: SFTConfig = OmegaConf.to_object(args)
|
||||
|
||||
# Set experiment trial name.
|
||||
exp_name = args.experiment_name
|
||||
if args.trial_name == MISSING:
|
||||
args.trial_name = trial_name = (
|
||||
f"run{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
|
||||
)
|
||||
else:
|
||||
trial_name = args.trial_name
|
||||
|
||||
if args.mode != "ray":
|
||||
raise RuntimeError("This script only supports the `ray` mode.")
|
||||
|
||||
init_constants(args)
|
||||
|
||||
from realhf.base.constants import LOG_ROOT
|
||||
|
||||
# Save overwritten configuration to yaml
|
||||
config_save_path = os.path.join(
|
||||
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
|
||||
)
|
||||
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
|
||||
with open(config_save_path, "w") as f:
|
||||
config_dict: Dict = dataclasses.asdict(args)
|
||||
yaml.dump(
|
||||
config_dict,
|
||||
f,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
||||
kind_reminder("sft", logger, args)
|
||||
|
||||
run_experiment(args, exp_name, trial_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command: python3 training/main_sft.py --config-name sft-7b-gpu8
|
||||
main()
|
|
@ -0,0 +1,282 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from typing import Any, List
|
||||
|
||||
import psutil
|
||||
import ray
|
||||
|
||||
from realhf.api.core.system_api import Experiment, ExperimentScheduling, TasksGroup
|
||||
from realhf.base import constants, logging, name_resolve, names
|
||||
from realhf.system import WORKER_TYPES, load_worker
|
||||
from realhf.system.worker_base import AsyncWorker, Worker, WorkerServerStatus
|
||||
|
||||
|
||||
# Copied from SGLang
|
||||
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
|
||||
"""Kill the process and all its child processes."""
|
||||
# Remove sigchld handler to avoid spammy logs.
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
|
||||
|
||||
if parent_pid is None:
|
||||
parent_pid = os.getpid()
|
||||
include_parent = False
|
||||
|
||||
try:
|
||||
itself = psutil.Process(parent_pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return
|
||||
|
||||
children = itself.children(recursive=True)
|
||||
for child in children:
|
||||
if child.pid == skip_pid:
|
||||
continue
|
||||
try:
|
||||
child.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
if include_parent:
|
||||
try:
|
||||
if parent_pid == os.getpid():
|
||||
itself.kill()
|
||||
sys.exit(0)
|
||||
|
||||
itself.kill()
|
||||
|
||||
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
|
||||
# so we send an additional signal to kill them.
|
||||
itself.send_signal(signal.SIGQUIT)
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
|
||||
@ray.remote
|
||||
class RayWorker:
|
||||
"""A thin wraper over realhf.system.worker_base.Worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_type: str,
|
||||
worker_cls,
|
||||
kv_store_name,
|
||||
):
|
||||
# Register all datasets and models
|
||||
import realhf.impl.dataset # isort: skip
|
||||
import realhf.impl.model # isort: skip
|
||||
|
||||
os.environ["REAL_MODE"] = "RAY"
|
||||
|
||||
name_resolve.reconfigure("ray", actor_name=kv_store_name)
|
||||
self.worker: Worker | AsyncWorker = worker_cls()
|
||||
self.worker_type = worker_type
|
||||
|
||||
def __repr__(self):
|
||||
return "".join([c.capitalize() for c in self.worker_type.split("_")])
|
||||
|
||||
def configure(self, cfg: Any, expr_config: Any):
|
||||
constants.init_constants(expr_config)
|
||||
|
||||
worker_info = cfg.worker_info
|
||||
idx = worker_info.worker_index
|
||||
constants.set_experiment_trial_names(
|
||||
worker_info.experiment_name, worker_info.trial_name
|
||||
)
|
||||
self.worker.wandb_config = expr_config.wandb
|
||||
self.worker.tensorboard_config = expr_config.tensorboard
|
||||
self.logger = logging.getLogger(f"{self.worker_type} {idx}", "benchmark")
|
||||
self.logger.info(f"Configuring {self.worker_type}...")
|
||||
self.worker._configure(cfg)
|
||||
self.logger.info(f"Configuring {self.worker_type}... Done.")
|
||||
|
||||
def run_sync(self):
|
||||
self.logger.info(f"Running {self.worker_type} lazy initialization...")
|
||||
self.worker._poll()
|
||||
self.logger.info(f"Running {self.worker_type} lazy initialization... Done.")
|
||||
while self.worker.status != WorkerServerStatus.PAUSED:
|
||||
self.worker._poll()
|
||||
|
||||
async def run_async(self):
|
||||
self.logger.info(f"Running {self.worker_type} lazy initialization...")
|
||||
await self.worker._poll_async()
|
||||
self.logger.info(f"Running {self.worker_type} lazy initialization... Done.")
|
||||
while self.worker.status != WorkerServerStatus.PAUSED:
|
||||
await self.worker._poll_async()
|
||||
|
||||
|
||||
def _run_experiment(exp_cfg, expr_name, trial_name):
|
||||
# Register all datasets and models
|
||||
import realhf.impl.dataset # isort: skip
|
||||
import realhf.impl.model # isort: skip
|
||||
from realhf.api.core.system_api import ALL_EXPERIMENT_CLASSES
|
||||
from realhf.system.master_worker import MasterWorker
|
||||
|
||||
constants.set_experiment_trial_names(expr_name, trial_name)
|
||||
|
||||
logger = logging.getLogger(f"RayMasterWorker", "benchmark")
|
||||
|
||||
# Initialize ray in the Ray cluster
|
||||
env_vars = constants.get_env_vars(
|
||||
WADNB_MODE=exp_cfg.wandb.mode,
|
||||
REAL_MODE=os.environ.get("REAL_MODE", ""),
|
||||
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
|
||||
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
||||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
|
||||
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
|
||||
REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
|
||||
REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
|
||||
)
|
||||
runtime_env = {
|
||||
"env_vars": env_vars,
|
||||
"working_dir": os.getcwd(),
|
||||
}
|
||||
logger.info(f"Ray workers runtime env: {runtime_env}")
|
||||
ray_log_path = exp_cfg.ray_temp_path
|
||||
os.makedirs(ray_log_path, exist_ok=True)
|
||||
ray.init(runtime_env=runtime_env, _temp_dir=ray_log_path)
|
||||
logger.info(f"Ray log root: {ray_log_path}")
|
||||
logger.info("Ray initialized! Ready to run workers.")
|
||||
|
||||
ray_kv_store_name = f"{expr_name}/{trial_name}/ray_kv_store"
|
||||
name_resolve.reconfigure("ray", actor_name=ray_kv_store_name)
|
||||
|
||||
name_resolve.clear_subtree(
|
||||
names.trial_root(experiment_name=expr_name, trial_name=trial_name)
|
||||
)
|
||||
|
||||
# Convert CLI args into worker configurations
|
||||
exp_setup = exp_cfg.initial_setup()
|
||||
exp_setup.set_worker_information(expr_name, trial_name)
|
||||
exp_setup.lazy_init()
|
||||
|
||||
# Initialize all workers
|
||||
all_workers = {}
|
||||
|
||||
scheduling: ExperimentScheduling = exp_cfg.scheduling_setup()
|
||||
|
||||
for worker_type in WORKER_TYPES:
|
||||
sch = getattr(scheduling, worker_type)
|
||||
if sch is None:
|
||||
continue
|
||||
|
||||
available_resources = ray.available_resources()
|
||||
cpu = sch.scheduling.cpu * sch.count
|
||||
gpu = sch.scheduling.gpu * sch.count
|
||||
mem = sch.scheduling.mem * sch.count / 1024 # in GB
|
||||
acpu = available_resources.get("CPU", 0)
|
||||
agpu = available_resources.get("GPU", 0)
|
||||
amem = available_resources.get("memory", 0) / 1024**3
|
||||
if acpu < cpu or agpu < gpu or amem < mem:
|
||||
logger.critical(
|
||||
f"Ray does not have enough resources to launch workers. "
|
||||
f"Required: {cpu} CPU, {gpu} GPU, {mem:.2f} GB memory. "
|
||||
f"Available: {acpu} CPU, {agpu} GPU, {amem:.2f} GB memory. "
|
||||
f"Please launch more Ray nodes otherwise the experiment will get stuck."
|
||||
)
|
||||
|
||||
# Use a customized packed scheduling method
|
||||
# that sequentially allocates nodes.
|
||||
available_nodes = [
|
||||
k
|
||||
for k in available_resources
|
||||
if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k)
|
||||
]
|
||||
total_gpus = available_resources["GPU"]
|
||||
n_gpus_per_node = int(total_gpus // len(available_nodes))
|
||||
|
||||
count = sch.count
|
||||
all_schedules: List[TasksGroup] = []
|
||||
for _ in range(sch.count):
|
||||
s_ = copy.deepcopy(sch)
|
||||
s_.count = 1
|
||||
all_schedules.append(s_)
|
||||
|
||||
workers = []
|
||||
|
||||
for node_idx, i in enumerate(range(0, count, n_gpus_per_node)):
|
||||
_schedules = all_schedules[i : i + n_gpus_per_node]
|
||||
for _idx, sch in enumerate(_schedules):
|
||||
# Schedule jobs one-by-one to maintain the order on remote nodes.
|
||||
worker = RayWorker.options(
|
||||
name=f"{worker_type}/{_idx + i}",
|
||||
num_cpus=sch.scheduling.cpu,
|
||||
num_gpus=sch.scheduling.gpu,
|
||||
memory=sch.scheduling.mem * 1024**2,
|
||||
).remote(
|
||||
worker_type=worker_type,
|
||||
worker_cls=load_worker(worker_type),
|
||||
kv_store_name=ray_kv_store_name,
|
||||
)
|
||||
workers.append(worker)
|
||||
|
||||
all_workers[worker_type] = workers
|
||||
|
||||
try:
|
||||
# Configure workers
|
||||
configure_jobs = []
|
||||
for worker_type in all_workers:
|
||||
worker_configs = getattr(exp_setup, worker_type)
|
||||
workers = all_workers[worker_type]
|
||||
assert len(workers) == len(worker_configs), (
|
||||
len(workers),
|
||||
len(worker_configs),
|
||||
)
|
||||
jobs = [
|
||||
w.configure.remote(c, exp_cfg) for w, c in zip(workers, worker_configs)
|
||||
]
|
||||
configure_jobs += jobs
|
||||
|
||||
ray.get(configure_jobs)
|
||||
|
||||
# Run workers
|
||||
run_jobs = []
|
||||
for worker_type in all_workers:
|
||||
workers = all_workers[worker_type]
|
||||
if worker_type in ["master_worker", "rollout_worker", "gserver_manager"]:
|
||||
# Only the rollout worker is asynchronous
|
||||
jobs = [w.run_async.remote() for w in workers]
|
||||
else:
|
||||
jobs = [w.run_sync.remote() for w in workers]
|
||||
run_jobs += jobs
|
||||
|
||||
ray.get(run_jobs)
|
||||
finally:
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
class DualOutput:
|
||||
def __init__(self, file, terminal):
|
||||
self.file = file
|
||||
self.terminal = terminal
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.file.write(message)
|
||||
|
||||
def flush(self):
|
||||
self.terminal.flush()
|
||||
self.file.flush()
|
||||
|
||||
def fileno(self):
|
||||
# Return the terminal's fileno to maintain expected behavior
|
||||
return self.terminal.fileno()
|
||||
|
||||
|
||||
def run_experiment(exp_cfg, expr_name, trial_name):
|
||||
log_path = os.path.join(constants.LOG_ROOT, expr_name, trial_name, "main.log")
|
||||
with open(log_path, "a") as f:
|
||||
# Create dual output handler
|
||||
dual_out = DualOutput(f, sys.stdout)
|
||||
dual_err = DualOutput(f, sys.stderr)
|
||||
|
||||
# Redirect stdout and stderr
|
||||
with redirect_stdout(dual_out), redirect_stderr(dual_err):
|
||||
_run_experiment(exp_cfg, expr_name, trial_name)
|