mirror of https://github.com/inclusionAI/AReaL
[Doc] Add doc for reproducing released models (#73)
* update benchmark script
* .
* add benchmark docs
* add v0.3.0 configs
* .
* PullRequest: 178 multi turn math agent training

  Merge branch gjx/multi-turn-math of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/178?tab=diff

  Reviewed-by: 博惟 <bowei.fw@antgroup.com>

* multi turn math agent training
* training data logging and clean math multi-turn exp
* fix
* .
* fix
* add docs and config
* format
* revert multi-turn agent
* add config

---------

Co-authored-by: 步偶 <sam.gjx@antgroup.com>
parent ad6e5bd3fa
commit 2d4d937d10
@@ -31,6 +31,7 @@ parts:
 - caption: References
   chapters:
   - file: references/benchmark
+  - file: references/reproduce
 - caption: Contributing
   chapters:
   - file: contrib

@@ -1,3 +1,3 @@
 # Benchmark Guide
 
-### [Latest Benchmark Results and Instructions](https://github.com/inclusionAI/AReaL/tree/main/benchmark)
+## [Latest Benchmark Results and Instructions](https://github.com/inclusionAI/AReaL/tree/main/benchmark)

@@ -0,0 +1,31 @@
# Reproduction Guide

The experiment configurations used to train our released models can be found in `examples/configs`. You can overwrite `training/config/async_ppo.yaml` with one of the provided configs (e.g., by copy-pasting) and then run `python3 training/main_async_ppo.py`, as illustrated in the [quickstart section](../tutorial/quickstart.md).
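
For concreteness, the copy step can be scripted. Below is a minimal sketch, assuming it runs from the repository root; the source file name is hypothetical, so substitute the actual YAML under `examples/configs` for your target model:

```python
# Minimal sketch of the overwrite-and-run flow described above.
# Assumption: executed from the repository root; the source config
# file name is hypothetical -- use the real file under examples/configs.
import shutil
import subprocess

shutil.copy(
    "examples/configs/v0.3-qwen3-code/qwen3-8b-code.yaml",  # hypothetical name
    "training/config/async_ppo.yaml",
)
subprocess.run(["python3", "training/main_async_ppo.py"], check=True)
```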

**Available Configurations:**

+ `examples/configs/v0.2-qwen2-math`: Configuration for reproducing the boba math models based on the R1-Distilled-Qwen models
  + Dataset: [AReaL-boba-106k](https://huggingface.co/datasets/inclusionAI/AReaL-boba-Data/blob/main/AReaL-boba-106k.jsonl) (a download sketch follows this list)
+ `examples/configs/v0.3-qwen3-code`: Configuration for reproducing the boba² coding model based on Qwen3
  + Dataset: [DeepCoder Dataset](https://huggingface.co/datasets/agentica-org/DeepCoder-Preview-Dataset) (duplicated problems and problems with incorrect answers have been filtered out)
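
To fetch the math dataset file programmatically, here is a minimal sketch assuming the `huggingface_hub` package is installed; the repo and file names are taken from the link above:

```python
# Download AReaL-boba-106k.jsonl from the Hugging Face Hub.
# Assumption: huggingface_hub is installed (pip install huggingface_hub).
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id="inclusionAI/AReaL-boba-Data",
    filename="AReaL-boba-106k.jsonl",
    repo_type="dataset",
)
print(local_path)  # use this path as dataset.path in the YAML config
```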

## Adjusting Computation Resources

We recommend using the provided number of GPUs and the corresponding parallelization strategy for optimal reproducibility. When resources are limited, try decreasing `n_nodes` and reducing the data-parallelism degree in `allocation_mode`.

**Example Resource Adjustment:**

When reproducing the Qwen3 8B coding model, the original configuration is:

+ `n_nodes=24`
+ `allocation_mode=sglang.d80m2p1+d4m2p4`

If only 6 nodes are available, you can adjust the configuration to the following (a quick sanity check for such strings is sketched after this list):

+ `n_nodes=6`
+ `allocation_mode=sglang.d20m2p1+d1m2p4`
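
The digits in each `allocation_mode` group multiply out to the GPU count of that group (generation before the `+`, training after it), and the two groups together should match `n_nodes * n_gpus_per_node`. The sketch below encodes our reading of the `d/m/p` format; it is not an AReaL utility:

```python
# Sanity-check an allocation_mode string against the GPU budget.
# Assumption (ours, not AReaL's API): each "dXmYpZ" group encodes
# parallel degrees whose product is the GPU count of that group.
import re

def gpus_required(allocation_mode: str) -> int:
    total = 0
    for group in allocation_mode.removeprefix("sglang.").split("+"):
        product = 1
        for degree in re.findall(r"\d+", group):
            product *= int(degree)
        total += product
    return total

assert gpus_required("sglang.d80m2p1+d4m2p4") == 24 * 8  # original: 24 nodes
assert gpus_required("sglang.d20m2p1+d1m2p4") == 6 * 8   # reduced: 6 nodes
```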

If you encounter out-of-memory (OOM) errors, please refer to the [troubleshooting section](../tutorial/troubleshooting.md) or raise an issue on GitHub.

```{note}
We acknowledge that our provided device allocation and parallelism strategies could be further optimized. However, since this guide focuses on reproducibility, we do not recommend changing them arbitrarily due to data loading randomness and potential precision issues. We hope our provided configurations can serve as a starting point for further development and optimization.
```

@@ -0,0 +1,178 @@
experiment_name: qwen3-14b-code
trial_name: my-trial
seed: 1
mode: ray
metric_discovery_port: 0
wandb:
  mode: disabled
  entity: null
  project: null
  name: null
  job_type: null
  group: null
  notes: null
  tags: null
  config: null
tensorboard:
  path: null
recover_mode: auto
recover_retries: 10
recover_after: 10
exp_ctrl:
  total_train_epochs: 10
  save_freq_epochs: null
  save_freq_steps: 20
  save_freq_secs: null
  ckpt_freq_epochs: null
  ckpt_freq_steps: null
  ckpt_freq_secs: 600
  eval_freq_epochs: null
  eval_freq_steps: null
  eval_freq_secs: null
  benchmark_steps: null
  benchmark_n_seqs: null
torch_cache_mysophobia: true
cache_clear_freq: 1
max_head_offpolicyness: 16
n_rollout_workers: null
flush_request_timeout: 600
cpus_per_generation_server: 4
mem_per_generation_server: 61440
cpus_per_gserver_manager: 4
mem_per_gserver_manager: 10240
cpus_per_rollout_worker: 4
mem_per_rollout_worker: 20480
allocation_mode: sglang.d80m4p1+d4p8m2
n_nodes: 48
n_gpus_per_node: 8
ray_temp_path: /tmp/ray
cluster:
  fileroot: /home/admin/.cache/realhf
  n_nodes: 48
  n_gpus_per_node: 8
actor:
  type:
    _class: qwen3
  path: /storage/testing/models/Qwen__Qwen3-14B/
  init_from_scratch: false
  gradient_checkpointing: true
  bf16: false
  optimizer:
    type: adam
    lr: 2.0e-05
    weight_decay: 0.05
    beta1: 0.9
    beta2: 0.95
    eps: 1.0e-05
    min_lr_ratio: 0.0
    lr_scheduler_type: constant
    warmup_steps_proportion: 0.001
    initial_loss_scale: 4294967296.0
    min_loss_scale: 1.0
    loss_scale_window: 5.0
    hysteresis: 2
    gradient_clipping: 1.0
  megatron:
    ddp:
      grad_reduce_in_fp32: true
      overlap_grad_reduce: true
      use_distributed_optimizer: true
  sglang:
    disable_cuda_graph: false
    disable_radix_cache: false
    disable_cuda_graph_padding: false
    enable_nccl_nvls: false
    disable_outlines_disk_cache: false
    disable_custom_all_reduce: false
    disable_overlap_schedule: false
    enable_mixed_chunk: false
    enable_torch_compile: false
    torch_compile_max_bs: 32
    cuda_graph_max_bs: null
    cuda_graph_bs: null
    torchao_config: ''
    enable_nan_detection: false
    enable_p2p_check: false
    triton_attention_reduce_in_fp32: false
    triton_attention_num_kv_splits: 16
    num_continuous_decode_steps: 1
    enable_memory_saver: false
    allow_auto_truncate: false
    attention_backend: flashinfer
    sampling_backend: null
    context_length: 30720
    mem_fraction_static: 0.7
    max_running_requests: null
    chunked_prefill_size: -1
    max_prefill_tokens: 32768
    schedule_policy: lpm
    schedule_conservativeness: 1.0
    cpu_offload_gb: 0
    dtype: float16
    kv_cache_dtype: auto
    log_level: warning
    log_level_http: warning
    log_requests: false
    log_requests_level: 0
    show_time_cost: false
    enable_metrics: true
    decode_log_interval: 1
ref:
  type:
    _class: qwen3
  path: /storage/testing/models/Qwen__Qwen3-14B/
  init_from_scratch: false
  bf16: false
actor_train:
  mb_spec:
    max_tokens_per_mb: 30720
ref_inf:
  mb_spec:
    max_tokens_per_mb: 30720
actor_inf:
  mb_spec:
    max_tokens_per_mb: 30720
shuffle_dataset: true
dataset:
  path: /path/to/deepcoder_0415_v3_verify_new_correct_dedup_lcb.jsonl
  max_prompt_len: 2048
  train_bs_n_seqs: 128
group_size: 16
mask_too_long: false
group_adv_norm: false
rw_type: sparse
success_rate_ub: 0.95
success_rate_lb: 0.05
ppo:
  gen:
    n: 1
    max_new_tokens: 27648
    min_new_tokens: 0
    greedy: false
    top_p: 1.0
    top_k: 100000000
    temperature: 1.0
  ppo_n_minibatches: 1
  eps_clip: 0.2
  c_clip: null
  value_eps_clip: 0.2
  early_stop_imp_ratio: 5.0
  actor_sample_reuse: 1
  critic_sample_reuse: 1
  max_reward_clip: 20.0
  reward_output_scaling: 5.0
  reward_output_bias: 0.0
  fuse_rew_ref: true
  discount: 1.0
  gae_lambda: 1.0
  adv_norm: true
  kl_ctl: 0.0
  use_adaptive_kl_ctl: false
  disable_value: true
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0
cpus_per_master_worker: 4
mem_per_master_worker: 20000
cpus_per_model_worker: 4
mem_per_model_worker: 90000

@@ -0,0 +1,178 @@
experiment_name: qwen3-8b-code
trial_name: my-trial
seed: 1
mode: ray
metric_discovery_port: 0
wandb:
  mode: disabled
  entity: null
  project: null
  name: null
  job_type: null
  group: null
  notes: null
  tags: null
  config: null
tensorboard:
  path: null
recover_mode: auto
recover_retries: 10
recover_after: 10
exp_ctrl:
  total_train_epochs: 10
  save_freq_epochs: null
  save_freq_steps: 20
  save_freq_secs: null
  ckpt_freq_epochs: null
  ckpt_freq_steps: null
  ckpt_freq_secs: 600
  eval_freq_epochs: null
  eval_freq_steps: null
  eval_freq_secs: null
  benchmark_steps: null
  benchmark_n_seqs: null
torch_cache_mysophobia: true
cache_clear_freq: 1
max_head_offpolicyness: 16
n_rollout_workers: null
flush_request_timeout: 300
cpus_per_generation_server: 4
mem_per_generation_server: 61440
cpus_per_gserver_manager: 4
mem_per_gserver_manager: 10240
cpus_per_rollout_worker: 4
mem_per_rollout_worker: 20480
allocation_mode: sglang.d80m2p1+d4m2p4
n_nodes: 24
n_gpus_per_node: 8
ray_temp_path: /tmp/ray
cluster:
  fileroot: /home/admin/.cache/realhf
  n_nodes: 32
  n_gpus_per_node: 8
actor:
  type:
    _class: qwen3
  path: /storage/testing/models/Qwen__Qwen3-8B/
  init_from_scratch: false
  gradient_checkpointing: true
  bf16: false
  optimizer:
    type: adam
    lr: 2.0e-05
    weight_decay: 0.05
    beta1: 0.9
    beta2: 0.95
    eps: 1.0e-05
    min_lr_ratio: 0.0
    lr_scheduler_type: constant
    warmup_steps_proportion: 0.001
    initial_loss_scale: 4294967296.0
    min_loss_scale: 1.0
    loss_scale_window: 5.0
    hysteresis: 2
    gradient_clipping: 1.0
  megatron:
    ddp:
      grad_reduce_in_fp32: true
      overlap_grad_reduce: true
      use_distributed_optimizer: true
  sglang:
    disable_cuda_graph: false
    disable_radix_cache: false
    disable_cuda_graph_padding: false
    enable_nccl_nvls: false
    disable_outlines_disk_cache: false
    disable_custom_all_reduce: false
    disable_overlap_schedule: false
    enable_mixed_chunk: false
    enable_torch_compile: false
    torch_compile_max_bs: 32
    cuda_graph_max_bs: null
    cuda_graph_bs: null
    torchao_config: ''
    enable_nan_detection: false
    enable_p2p_check: false
    triton_attention_reduce_in_fp32: false
    triton_attention_num_kv_splits: 16
    num_continuous_decode_steps: 1
    enable_memory_saver: false
    allow_auto_truncate: false
    attention_backend: flashinfer
    sampling_backend: null
    context_length: 30720
    mem_fraction_static: 0.7
    max_running_requests: null
    chunked_prefill_size: -1
    max_prefill_tokens: 32768
    schedule_policy: lpm
    schedule_conservativeness: 1.0
    cpu_offload_gb: 0
    dtype: float16
    kv_cache_dtype: auto
    log_level: warning
    log_level_http: warning
    log_requests: false
    log_requests_level: 0
    show_time_cost: false
    enable_metrics: true
    decode_log_interval: 1
ref:
  type:
    _class: qwen3
  path: /storage/testing/models/Qwen__Qwen3-8B/
  init_from_scratch: false
  bf16: false
actor_train:
  mb_spec:
    max_tokens_per_mb: 30720
ref_inf:
  mb_spec:
    max_tokens_per_mb: 30720
actor_inf:
  mb_spec:
    max_tokens_per_mb: 30720
shuffle_dataset: true
dataset:
  path: /path/to/deepcoder_0415_v3_verify_new_correct_dedup_lcb.jsonl
  max_prompt_len: 2048
  train_bs_n_seqs: 128
group_size: 16
mask_too_long: false
group_adv_norm: false
rw_type: sparse
success_rate_ub: 0.95
success_rate_lb: 0.05
ppo:
  gen:
    n: 1
    max_new_tokens: 27648
    min_new_tokens: 0
    greedy: false
    top_p: 1.0
    top_k: 100000000
    temperature: 1.0
  ppo_n_minibatches: 1
  eps_clip: 0.2
  c_clip: null
  value_eps_clip: 0.2
  early_stop_imp_ratio: 5.0
  actor_sample_reuse: 1
  critic_sample_reuse: 1
  max_reward_clip: 20.0
  reward_output_scaling: 5.0
  reward_output_bias: 0.0
  fuse_rew_ref: true
  discount: 1.0
  gae_lambda: 1.0
  adv_norm: true
  kl_ctl: 0.0
  use_adaptive_kl_ctl: false
  disable_value: true
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0
cpus_per_master_worker: 4
mem_per_master_worker: 20000
cpus_per_model_worker: 4
mem_per_model_worker: 90000

@@ -37,7 +37,7 @@ def check_code_metadata_entries(data):
         data["problem_id"] = data["query_id"]
     assert isinstance(data["prompt"], str)
     case_size = sys.getsizeof(data["input_output"])
-    if not os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""):
+    if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""):
         assert (
             case_size < 500 * 1024
         ), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."

@@ -4,6 +4,7 @@ import sys
 import time
 from pathlib import Path
 
+import ray
 import requests
 
 from realhf.api.cli_args import SGLangConfig

@@ -125,11 +126,17 @@ class GenerationServer(Worker):
         )
 
         # Cancel the effect of CUDA device isolation
-        if "CUDA_VISIBLE_DEVICES" in os.environ:
+        if ray.is_initialized():
+            self.base_gpu_id = 0
+        elif "CUDA_VISIBLE_DEVICES" in os.environ:
             self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
             os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                 map(str, range(gpu_utils.gpu_count()))
             )
         else:
             servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
             idx_on_this_node = self.worker_index % servers_per_node
             self.base_gpu_id = idx_on_this_node * self.config.tp_size
 
         self.server_process = None
         self.server_addr = None

@@ -37,7 +37,6 @@ torch_cache_mysophobia: true
 cache_clear_freq: 1
 
 # Asynchronous RL options
-new_tokens_per_chunk: 10000000000
 max_head_offpolicyness: 4
 n_rollout_workers: null
 max_concurrent_rollouts: null