[Doc] Add doc for reproducing benchmark results (#65)

* update benchmark script

* .

* add benchmark docs
Wei Fu 2025-06-01 20:50:40 +08:00 committed by GitHub
parent ce4d7354bf
commit 11a34dfb51
7 changed files with 245 additions and 2 deletions

benchmark/verl_v0_3_1_312a8cb/README.md
@@ -0,0 +1,38 @@
## Benchmark Results

We measure *effective training throughput*, defined as the number of tokens used in each RL training step divided by the wall-clock time of that step. Note that since AReaL is asynchronous, some generated tokens may not be consumed by the trainer; these tokens are not counted in the throughput calculation.

We compare against the latest release of verl (v0.3.0.post1) as of May 7, 2025.

![Throughput Comparison](scaling_trend_vs_verl.png)

## How to Reproduce

Run `build_cmd.py` to generate the CLI command that launches AReaL:

```bash
python3 benchmark/verl_v0_3_1_312a8cb/build_cmd.py --model-size 1 --ctx 32768 --n-nodes 4
```
The command above prints the full command for running AReaL with `DeepSeek-R1-Distill-Qwen-1.5B` at a 32k context length (a 1k prompt budget plus a 31k generation length) on 4 nodes (32 GPUs). You can choose `--model-size` from {1, 7, 32} and `--n-nodes` from {4, 8, 16, 32, 64}.
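For reference, the printed command for this 1.5B, 4-node case looks like the following (abridged: the `...` stands for the full list of `key=value` options flattened from `areal_config.yaml`, and the line breaks are added here for readability):

```bash
python3 training/main_async_ppo.py \
    wandb.mode=disabled mode=ray max_head_offpolicyness=4 ... \
    actor.path=/storage/testing/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B \
    ppo.gen.max_new_tokens=31744 \
    trial_name=m1-ctx32768-n4 \
    allocation_mode=sglang.d24p1m1+d4m1p2 \
    n_nodes=4
```

The `allocation_mode` string reflects the split computed by `build_cmd.py`: 3/4 of the GPUs (24 of 32 here) serve SGLang generation, and the remaining 8 run training with pipeline parallelism of 2.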
Throughput is not printed directly; parse the per-step time and token count from the log, then divide the number of tokens by the time:
```bash
# obtain the time
grep -ai "execution time" /path/to/main.log
# obtain the number of processed tokens
grep -ai "n_tokens" /path/to/main.log
```
Alternatively, these metrics are recorded in the wandb and TensorBoard logs as `time_perf/e2e` and `ppo_actor/n_tokens`, respectively.
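If you prefer to script this step, a minimal sketch follows. It assumes each matching log line contains the metric name followed by a numeric value; adjust the regular expressions to the exact log format:

```python
import re
import sys

# Minimal sketch: sum "n_tokens" over the logged steps and divide by the
# summed per-step "execution time" to get effective throughput (tokens/s).
# Usage: python3 throughput.py /path/to/main.log
log_text = open(sys.argv[1]).read()
times = [float(x) for x in re.findall(r"execution time\D*?(\d+\.?\d*)", log_text, re.IGNORECASE)]
tokens = [float(x) for x in re.findall(r"n_tokens\D*?(\d+\.?\d*)", log_text, re.IGNORECASE)]
n_steps = min(len(times), len(tokens))
print(f"{n_steps} steps, {sum(tokens[:n_steps]) / sum(times[:n_steps]):.0f} tokens/s")
```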

## Settings

The key settings include:

+ **Batch size**: 512 prompts with 16 answers each, following DAPO and VAPO.
+ **Maximum Generation Length**: 31k (32k context length minus 1k prompt length).
+ **Minimum Generation Length**: 0. This should not be set to the maximum length because the generation length varies for each prompt.
+ **Model**: Distilled models from DeepSeek-R1, which have long-CoT reasoning abilities. These *cannot* simply be replaced with models like Qwen2-7B, Qwen2.5-Math-7B, or randomly initialized weights.
+ **Dataset**: Our [math dataset](https://huggingface.co/datasets/inclusionAI/AReaL-boba-Data) or other similar open-source datasets. The prompts should trigger the model to generate a long-CoT trajectory.
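For orientation, the batch and length settings above correspond to the following keys in `areal_config.yaml` (excerpted from the full config below; the comments are added here, and `max_new_tokens` itself is filled in by `build_cmd.py` as `ctx - 1024`):

```yaml
dataset:
  max_prompt_len: 1024    # 1k prompt budget
  train_bs_n_seqs: 512    # 512 prompts per batch
group_size: 16            # 16 sampled answers per prompt
ppo:
  gen:
    min_new_tokens: 0     # generation length varies per prompt
```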

benchmark/verl_v0_3_1_312a8cb/areal_config.yaml
@@ -0,0 +1,41 @@
wandb:
  mode: disabled
mode: ray
max_head_offpolicyness: 4
experiment_name: async-ppo-throughput
max_concurrent_rollouts: 512
exp_ctrl:
  total_train_epochs: 10
  benchmark_n_seqs: 5120
actor:
  type:
    _class: qwen2
  sglang:
    mem_fraction_static: 0.8
    context_length: 32768
dataset:
  path: /storage/testing/dataset/boba_106k_0319.jsonl
  max_prompt_len: 1024
  train_bs_n_seqs: 512
group_size: 16
ppo:
  gen:
    min_new_tokens: 0
    temperature: 1.0
  recompute_logprob: true
  use_decoupled_loss: true
  disable_value: true
  ppo_n_minibatches: 4
  kl_ctl: 0.0
  value_eps_clip: 0.2
actor_train:
  mb_spec:
    max_tokens_per_mb: 32768
actor_inf:
  mb_spec:
    max_tokens_per_mb: 32768
n_gpus_per_node: 8
cache_clear_freq: 1
recover_mode: disabled
recover_retries: 10
torch_cache_mysophobia: true

benchmark/verl_v0_3_1_312a8cb/build_cmd.py
@@ -0,0 +1,156 @@
from pathlib import Path
from typing import Any, List

import yaml

paths = {
    1: "/storage/testing/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B",
    7: "/storage/testing/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-7B",
    32: "/storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-32B",
}


def yaml_to_cli_args(yaml_file_path: str) -> List[str]:
    """Convert a YAML file to structured CLI configuration arguments.

    Args:
        yaml_file_path (str): Path to the YAML file

    Returns:
        List[str]: CLI arguments in the format "key=value" or "nested.key=value"

    Example:
        >>> yaml_to_cli_args("config.yaml")
        ['wandb.mode=disabled', 'mode=ray', 'max_head_offpolicyness=4', ...]
    """
    with open(yaml_file_path, "r") as file:
        data = yaml.safe_load(file)
    return _flatten_dict_to_cli_args(data)


def yaml_dict_to_cli_args(yaml_dict: dict) -> List[str]:
    """Convert a YAML dictionary to structured CLI configuration arguments.

    Args:
        yaml_dict (dict): Dictionary loaded from YAML

    Returns:
        List[str]: CLI arguments in the format "key=value" or "nested.key=value"
    """
    return _flatten_dict_to_cli_args(yaml_dict)


def _flatten_dict_to_cli_args(data: dict, parent_key: str = "") -> List[str]:
    """Recursively flatten a nested dictionary into CLI arguments.

    Args:
        data (dict): Dictionary to flatten
        parent_key (str): Current parent key path

    Returns:
        List[str]: Flattened CLI arguments
    """
    cli_args = []
    for key, value in data.items():
        # Build the full key path
        full_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            # Recursively process nested dictionaries
            cli_args.extend(_flatten_dict_to_cli_args(value, full_key))
        else:
            # Convert value to string representation
            cli_args.append(f"{full_key}={_format_value(value)}")
    return cli_args


def _format_value(value: Any) -> str:
    """Format a value for CLI argument representation.

    Args:
        value: The value to format

    Returns:
        str: Formatted string representation of the value
    """
    if isinstance(value, bool):
        return str(value).lower()
    elif isinstance(value, (int, float)):
        return str(value)
    elif isinstance(value, str):
        return value
    elif isinstance(value, list):
        # Convert list to comma-separated string
        return ",".join(str(item) for item in value)
    elif value is None:
        return "null"
    else:
        return str(value)


def get_allocation_mode(n_nodes: int, model_size: int):
    assert n_nodes >= 4, n_nodes
    sglang_tp = 1 if model_size < 32 else 8
    n_sglang_gpus = n_nodes * 8 // 4 * 3
    sglang_alloc = f"sglang.d{n_sglang_gpus // sglang_tp}p1m{sglang_tp}"
    n_train_gpus = n_nodes * 8 - n_sglang_gpus
    if model_size == 1:
        pp = 2
        tp = 1
    elif model_size == 7:
        pp = 4
        tp = 1
    elif model_size == 32:
        pp = 8
        tp = 2
    else:
        raise NotImplementedError(model_size)
    train_alloc = f"d{n_train_gpus // tp // pp}m{tp}p{pp}"
    return sglang_alloc + "+" + train_alloc


def get_trial_name(model_size: int, ctx: int, n_nodes: int):
    trial_name = f"m{model_size}-ctx{ctx}-n{n_nodes}"
    return trial_name


def build_cmd(model_size: int, ctx: int, n_nodes: int):
    trial_name = get_trial_name(model_size, ctx, n_nodes)
    allocation_mode = get_allocation_mode(model_size=model_size, n_nodes=n_nodes)
    config_path = Path(__file__).parent / "areal_config.yaml"
    cli_args = yaml_to_cli_args(str(config_path))
    cmd = (
        ["python3", "training/main_async_ppo.py"]
        + cli_args
        + [
            f"actor.path={paths[model_size]}",
            f"ppo.gen.max_new_tokens={ctx - 1024}",
            f"trial_name={trial_name}",
            f"allocation_mode={allocation_mode}",
            f"n_nodes={n_nodes}",
        ]
    )
    if model_size > 1:
        # Overwrite to avoid OOM.
        cmd += ["actor.sglang.mem_fraction_static=0.7"]
    return " ".join(cmd)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-x", "--model-size", type=int)
    parser.add_argument("-s", "--ctx", type=int)
    parser.add_argument("-n", "--n-nodes", type=int)
    args = parser.parse_args()
    print(build_cmd(args.model_size, args.ctx, args.n_nodes))

benchmark/verl_v0_3_1_312a8cb/scaling_trend_vs_verl.png
Binary file not shown (PNG image, 271 KiB).

docs/_toc.yml
@@ -20,6 +20,9 @@ parts:
   - file: developer/model_worker
   - file: developer/algo_interface
   - file: developer/allocation_parallel
+- caption: References
+  chapters:
+  - file: references/benchmark
 - caption: Contributing
   chapters:
   - file: contrib

docs/references/benchmark.md
@@ -0,0 +1,3 @@
# Benchmark Guide
### [Latest Benchmark Results and Instructions](https://github.com/inclusionAI/AReaL/tree/main/benchmark)

@@ -49,11 +49,13 @@ def topk(scores, gen_lengths, k) -> list:
     return [idx for idx, _ in sorted_indices]
 
 
 @torch.compile
+@torch.no_grad()
 def calc_entropy(logits, cu_seqlens):
     leave_one_indices = build_leave_one_indices(logits, cu_seqlens)
-    probs = torch.nn.functional.softmax(logits, dim=-1)
+    probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
     entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1)[leave_one_indices]
-    return entropy.detach()
+    return entropy
 
 
 def _ppo_actor_loss_from_model_outputs(