mirror of https://github.com/inclusionAI/AReaL
[Doc & Fix] Simplify the environment setup procedure (#62)

* PullRequest: 176 [FIX] clear sensitive info
  Merge branch fw/fix-sensitive-info of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/176
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* test env setup
* fix
* allow cached model
* revise docs
* change docs
* format docs
* update readme
This commit is contained in:
parent d87c898d36
commit 4fab3ac769

README.md (51 changes)
|
@ -3,7 +3,7 @@
|
|||
</h1>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://inclusionai.github.io/AReaL/"><b>Documentation</b></a> |
|
||||
| <a href="https://inclusionai.github.io/AReaL/"><b>Documentation</b></a> | <a href="https://deepwiki.com/inclusionAI/AReaL"><b>Ask DeepWiki</b></a> |
|
||||
</p>
|
||||
|
||||
<img align="right" alt="ReaL" src="/assets/logo.png" width="20%">
|
||||
|
@ -17,7 +17,6 @@ AReaL (Ant Reasoning RL) is a fully open-sourced, scalable, and efficient reinfo
|
|||
+ 🔪 **Cutting-Edge Performances:** AReaL can produce models with cutting-edge reasoning capabilities. We are also actively working on other domains, such as coding and agents.
|
||||
|
||||
## News
|
||||
**[2025/04/27]** 🔥 We've built a [documentation website](https://deepwiki.com/inclusionAI/AReaL) using the amazing [DeepWiki](https://deepwiki.com/) tool. Check the link to learn more about AReaL and to ask questions!
|
||||
|
||||
**[2025/03/31]** **(v0.2, Boba)** Our milestone release, Boba! Please call it A-ReaL-Boba! This release includes significantly accelerated training with SGLang support and SOTA 7B and 32B models on math reasoning.
|
||||
|
||||
|
@ -72,23 +71,51 @@ Building upon **R1-Distill-Qwen-32B**, we replicate **QwQ-32B's** inference perf
|
|||
|
||||
## Getting Started
|
||||
### Quick Start
|
||||
```bash
|
||||
# Train the distilled 7B model
|
||||
python3 -m realhf.apps.quickstart ppo-math \
|
||||
--config examples/configs/7B-distill/ppo-7B-distill-gpus-128.yaml
|
||||
|
||||
# Evaluate the 7B model
|
||||
python evaluation/eval_and_aggregate.py \
|
||||
Train Qwen3 1.7B locally:
|
||||
|
||||
```bash
|
||||
bash examples/env/scripts/setup-pip-deps.sh
|
||||
python3 training/main_async_ppo.py \
|
||||
n_nodes=1 n_gpus_per_node=8 \
|
||||
allocation_mode=sglang.d4p1m1+d2p2m1 \
|
||||
cluster.fileroot=/storage/testing/experiments \
|
||||
actor.type._class=qwen3 \
|
||||
actor.path=Qwen/Qwen3-1.7B \
|
||||
ref.type._class=qwen3 \
|
||||
ref.path=Qwen/Qwen3-1.7B \
|
||||
dataset.path=/path/to/dataset/boba_106k_0319.jsonl \
|
||||
dataset.train_bs_n_seqs=32 \
|
||||
group_size=8 \
|
||||
ppo.gen.max_new_tokens=4096 \
|
||||
ppo.ppo_n_minibatches=4 \
|
||||
actor_train.mb_spec.max_tokens_per_mb=32768 \
|
||||
actor_inf.mb_spec.max_tokens_per_mb=32768 \
|
||||
max_concurrent_rollouts=16 \
|
||||
max_head_offpolicyness=4
|
||||
```
|
||||
|
||||
Evaluation
|
||||
|
||||
```bash
|
||||
bash examples/env/scripts/setup-eval-pip-deps.sh
|
||||
cd evaluation
|
||||
# Evaluate the model
|
||||
python eval_and_aggregate.py \
|
||||
--model_path ${MODEL_PATH} \
|
||||
--output_path ${OUTPUT_PATH} \
|
||||
--data_names aime24,aime25 \
|
||||
--prompt_type AReaL-boba \
|
||||
--output_path outputs --temperature 1.0
|
||||
--max_gen_tokens 32768 \
|
||||
--data_names codeforces,lcb_v5 \
|
||||
--prompt_type qwen3-think-pure \
|
||||
--temperature 1.0
|
||||
```
|
||||
|
||||
### Resources
|
||||
+ [Tutorial](/examples/README.md)
|
||||
+ [Tutorial (中文)](/examples/README_zh.md)
|
||||
|
||||
+ [Installation](https://inclusionai.github.io/AReaL/tutorial/installation.html)
|
||||
+ [Quickstart](https://inclusionai.github.io/AReaL/tutorial/quickstart.html)
|
||||
+ [Contributing](https://inclusionai.github.io/AReaL/contrib.html)
|
||||
|
||||
## Future Plan
|
||||
AReaL is under active development. We plan to ship major releases on a weekly basis, and we highly appreciate contributions from the community. Here we highlight our future research and development plans.
|
||||
|
|
|
@ -7,7 +7,7 @@ parts:
|
|||
- caption: Tutorial
|
||||
chapters:
|
||||
- file: tutorial/installation
|
||||
- file: tutorial/training
|
||||
- file: tutorial/quickstart
|
||||
- file: tutorial/eval
|
||||
- file: tutorial/agent
|
||||
- file: tutorial/troubleshooting
|
||||
|
|
|
@ -1,68 +1,60 @@
|
|||
# Evaluation
|
||||
|
||||
The evaluation code is located in the `evaluation` folder of the repository. Following the previous tutorial, trained checkpoints will be saved under `/storage/ray/experiments/checkpoints/root/`.
|
||||
The evaluation code is located in the `evaluation` folder of the repository. Following the previous tutorial, trained checkpoints will be saved under `${fileroot}/checkpoints/${USER}/${experiment_name}/${trial_name}/`.
|
||||
|
||||
## Setup Evaluation Environment
|
||||
|
||||
Start a new container to execute the evaluation script. **Note**: Evaluation requires updates to certain Python libraries, so avoid using the training container for this task.
|
||||
> **Note**: Evaluation requires updates to certain Python libraries, so avoid using the training container or virtual environment for this task.
|
||||
|
||||
From the repository directory, create a new conda environment:
|
||||
|
||||
```bash
|
||||
docker run -d --name areal-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "tail -f /dev/null"
|
||||
docker exec -it areal-eval bash
|
||||
conda create -n areal-eval python=3.12
|
||||
conda activate areal-eval
|
||||
```
|
||||
|
||||
## Install Dependencies
|
||||
|
||||
Execute the following commands inside the Docker container:
|
||||
Install dependencies:
|
||||
|
||||
```bash
|
||||
cd /storage/codes/AReaL/evaluation
|
||||
cd latex2sympy
|
||||
pip install -e .
|
||||
cd ..
|
||||
pip install -r requirements.txt
|
||||
pip install vllm==0.8.5 --no-build-isolation
|
||||
pip install transformers==4.51.1
|
||||
pip install prettytable timeout_decorator
|
||||
bash examples/env/scripts/setup-eval-pip-deps.sh
|
||||
```
|
||||
|
||||
## Run Evaluation
|
||||
|
||||
Specify an output_path to save the test results (optional — if not specified, the results will be saved in `model_path`):
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/ray/eval_output/
|
||||
```
|
||||
|
||||
### Math Eval
|
||||
Specify an `output_path` to save the test results. If not specified, the results will be saved in `model_path`.
|
||||
|
||||
### Math Evaluation
|
||||
|
||||
```bash
|
||||
cd evaluation
|
||||
nohup python eval_and_aggregate.py \
|
||||
--model_path /storage/ray/experiments/checkpoints/root/my-exp/my-trial/epoch1epochstep20globalstep20/ \
|
||||
--output_path /storage/ray/eval_output/ \
|
||||
--max_gen_tokens 32768
|
||||
--model_path /path/to/checkpoint \
|
||||
--output_path /path/to/outputs \
|
||||
--max_gen_tokens 32768 \
|
||||
--data_names math_500,aime24,amc23 \
|
||||
--prompt_type qwen3-think \
|
||||
--task math &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
|
||||
--task math &> eval_and_aggregate_parallel.log &
|
||||
```
|
||||
|
||||
### Code Eval
|
||||
**Obtaining Data:**
|
||||
- Consider the size of code datasets (Because some of test cases are relatively large), we upload all our code datasets to Huggingface: [todo:upload the code dataset to Huggingface]().
|
||||
- Once you have downloaded the code dataset, place it under **`./evaluation/data/`**.
|
||||
### Code Evaluation
|
||||
|
||||
**Running Eval:**
|
||||
**Obtaining Data:**
|
||||
- Due to the size of code datasets (some test cases are relatively large), we have uploaded all our code datasets to [Hugging Face](https://huggingface.co/inclusionAI).
|
||||
- Once you have downloaded the code dataset, place it under `./evaluation/data/`.
|
||||
|
||||
**Running Evaluation:**
|
||||
```bash
|
||||
cd evaluation
|
||||
nohup python eval_and_aggregate.py \
|
||||
--model_path /storage/ray/experiments/checkpoints/root/my-exp/my-trial/epoch1epochstep20globalstep20/ \
|
||||
--output_path /storage/ray/eval_output/ \
|
||||
--max_gen_tokens 32768 \
|
||||
--data_names codeforces,lcb_v5 \
|
||||
--prompt_type qwen3-think-pure \
|
||||
--num_sample_nodes 8 \
|
||||
--samples_per_node 1 \
|
||||
--n_sampling $((num_sample_nodes * samples_per_node)) \
|
||||
--task code &> /storage/ray/eval_output/eval_and_aggregate_parallel.log &
|
||||
--model_path /path/to/checkpoint \
|
||||
--output_path /path/to/outputs \
|
||||
--max_gen_tokens 32768 \
|
||||
--data_names codeforces,lcb_v5 \
|
||||
--prompt_type qwen3-think-pure \
|
||||
--num_sample_nodes 8 \
|
||||
--samples_per_node 1 \
|
||||
--n_sampling $((num_sample_nodes * samples_per_node)) \
|
||||
--task code &> eval_and_aggregate_parallel.log &
|
||||
```
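
The `$((num_sample_nodes * samples_per_node))` expression expands plain shell variables that are not defined in the snippet above; if they are unset it evaluates to 0. A minimal sketch, assuming you keep the variable names aligned with the `--num_sample_nodes` and `--samples_per_node` flags:

```bash
# Define the shell variables before launching so the arithmetic expands as intended
num_sample_nodes=8
samples_per_node=1
echo $(( num_sample_nodes * samples_per_node ))  # 8, the value passed to --n_sampling
```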
|
||||
|
||||
### Command Line Parameters
|
||||
|
@ -70,14 +62,16 @@ nohup python eval_and_aggregate.py \
|
|||
- **`--model_path`**: Path to the saved model parameters
|
||||
- **`--output_path`**: Path to store generated answers and log files during evaluation
|
||||
- **`--data_names`**: Dataset(s) to evaluate. Multiple datasets can be separated by commas. Available options:
|
||||
- math: `math_500`, `aime24`, `aime25`, `amc23`
|
||||
- code: `lcb_v5`, `lcb_v5_2410_2502`, `codeforces`, `code_contest_all`
|
||||
- Math: `math_500`, `aime24`, `aime25`, `amc23`
|
||||
- Code: `lcb_v5`, `lcb_v5_2410_2502`, `codeforces`, `code_contest_all`
|
||||
- **`--max_gen_tokens`**: Maximum length of generated answers (default: 32768)
|
||||
- **`--prompt_type`**: Specify the prompt template, for our latest model, we use `qwen3-think` for math dataset, and `qwen3-think-pure` for code dataset.
|
||||
- **`--num_sample_nodes`**: Number of multiple sampling seeds to ensure saampling diversity.
|
||||
- **`--samples_per_node`**: Number of samples to generate per seed for each problem.
|
||||
- **`--prompt_type`**: Specify the prompt template. For our latest model, we use `qwen3-think` for math datasets and `qwen3-think-pure` for code datasets.
|
||||
- **`--num_sample_nodes`**: Number of multiple sampling seeds to ensure sampling diversity.
|
||||
- **`--samples_per_node`**: Number of samples to generate per seed for each problem.
|
||||
|
||||
## Evaluation Results
|
||||
## Logs and Evaluation Results
|
||||
|
||||
Check `${output_path}/math_eval_${max_gen_tokens}/logs` to review the log of each worker.
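
For example, for a math run with the default generation length, the per-worker logs can be listed as follows (the output path is a placeholder matching the commands above):

```bash
# List per-worker log files of a math evaluation run
ls /path/to/outputs/math_eval_32768/logs
```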
|
||||
|
||||
The evaluation script will output a results table in the terminal:
|
||||
|
||||
|
@ -97,7 +91,7 @@ The evaluation script will output a results table in the terminal:
|
|||
- **`greedy_acc`**: Average accuracy under greedy sampling
|
||||
- **`sample_pass@{k}`**: Probability of generating a correct answer within `k` attempts under random sampling
|
||||
|
||||
For Codeforces dataset, we use the Elo ranking algorithm to evaluate model performance, referring to [CodeElo](https://github.com/QwenLM/CodeElo) and [rllm](https://github.com/agentica-project/rllm):
|
||||
For the Codeforces dataset, we use the Elo ranking algorithm to evaluate model performance, referring to [CodeElo](https://github.com/QwenLM/CodeElo) and [rllm](https://github.com/agentica-project/rllm):
|
||||
|
||||
```
|
||||
+------------+----------------+-----------+
|
||||
|
@ -107,19 +101,17 @@ For Codeforces dataset, we use the Elo ranking algorithm to evaluate model perfo
|
|||
+------------+----------------+-----------+
|
||||
```
|
||||
|
||||
- **`CF Rating`**: The overall elo rank score of model across 57 Codeforces contests.
|
||||
- **`Percentile`**: The Elo ranking percentile of model among all Codeforces users.
|
||||
**Note**: As the penalty mechanism may cause fluctuations in Elo rankings, we suggest performing multiple evaluations and regard the average score as the final result.
|
||||
- **`CF Rating`**: The overall Elo rank score of the model across 57 Codeforces contests.
|
||||
- **`Percentile`**: The Elo ranking percentile of the model among all Codeforces users.
|
||||
|
||||
> **Note**: As the penalty mechanism may cause fluctuations in Elo rankings, we suggest performing multiple evaluations and taking the average score as the final result.
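
A rough sketch of how such repeated runs could be scripted, reusing the code-evaluation command shown earlier and writing each run to its own output directory (the reported `CF Rating` values are then averaged by hand):

```bash
# Run the Codeforces evaluation three times and average the reported CF Rating afterwards
cd evaluation
for run in 1 2 3; do
  python eval_and_aggregate.py \
    --model_path /path/to/checkpoint \
    --output_path /path/to/outputs/run_${run} \
    --max_gen_tokens 32768 \
    --data_names codeforces \
    --prompt_type qwen3-think-pure \
    --num_sample_nodes 8 \
    --samples_per_node 1 \
    --n_sampling 8 \
    --task code
done
```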
|
||||
|
||||
## Configuration Details
|
||||
|
||||
### Sampling Parameters
|
||||
|
||||
- The evaluation script defaults to averaging 32 samples with temperature 1.0. For the code dataset, we set it to 8 samples.
|
||||
- We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance
|
||||
- When `enforce_eager=True`, we can reproduce the model performance reported in previous work
|
||||
- Without this setting, evaluation results may fall below reported performance
|
||||
- Therefore, we enforce `enforce_eager=True` during evaluation
|
||||
- We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance. When `enforce_eager=True`, we can reproduce the model performance reported in previous work. Without this setting, evaluation results may fall below reported performance. Therefore, we enforce `enforce_eager=True` during evaluation.
|
||||
|
||||
### Runtime Expectations
|
||||
|
||||
|
|
|
@ -30,49 +30,60 @@ The following hardware configuration has been extensively tested:
|
|||
|
||||
## Runtime Environment
|
||||
|
||||
We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository.
|
||||
**For multi-node training**: Ensure a shared storage path is mounted on every node (and mounted to the container if you are using Docker). This path will be used to save checkpoints and logs.
|
||||
|
||||
Pull the Docker image:
|
||||
### Option 1: Docker (Recommended)
|
||||
|
||||
We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository.
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/inclusionai/areal-runtime:v0.3.0
|
||||
docker run -it --name areal-node1 \
|
||||
--privileged --gpus all --network host \
|
||||
--shm-size 700g -v /path/to/mount:/path/to/mount \
|
||||
ghcr.io/inclusionai/areal-runtime:v0.3.0 \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
This image includes all training requirements for AReaL.
|
||||
### Option 2: Custom Environment Installation
|
||||
|
||||
**For multi-node training**: Ensure shared storage is mounted to the `/storage` directory on every node. All downloads and resources will be stored in this directory, and the AReaL container will mount this directory to `/storage` within the container.
|
||||
1. Install [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install) or [Anaconda](https://www.anaconda.com/docs/getting-started/anaconda/install).
|
||||
|
||||
## Code Setup
|
||||
|
||||
Clone the AReaL project code to `/storage/codes`:
|
||||
2. Create a conda virtual environment:
|
||||
|
||||
```bash
|
||||
conda create -n areal python=3.12
|
||||
conda activate areal
|
||||
```
|
||||
|
||||
3. Install pip dependencies:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/codes
|
||||
cd /storage/codes/
|
||||
git clone https://github.com/inclusionAI/AReaL
|
||||
pip install -r AReaL/requirements.txt
|
||||
cd AReaL
|
||||
bash examples/env/scripts/setup-pip-deps.sh
|
||||
```
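
As a quick sanity check after the setup script finishes, you can try importing the core packages. This is only a sketch and assumes the package names installed above (`realhf` from the editable install, plus `sglang` and `flash_attn`):

```bash
# Rough sanity check: the main training dependencies should import without errors
python3 -c "import realhf, sglang, flash_attn; print('environment looks OK')"
```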
|
||||
|
||||
## Dataset
|
||||
## (Optional) Launch Ray Cluster for Distributed Training
|
||||
|
||||
Download the provided training dataset and place it in `/storage/datasets/`:
|
||||
On the first node, start the Ray Head:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/datasets/
|
||||
cd /storage/datasets/
|
||||
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true
|
||||
ray start --head
|
||||
```
|
||||
|
||||
## Model
|
||||
|
||||
We train using open-source models available on Hugging Face Hub. Here's an example using Qwen3 (ensure Git LFS is installed):
|
||||
On all other nodes, start the Ray Worker:
|
||||
|
||||
```bash
|
||||
mkdir -p /storage/models
|
||||
cd /storage/models
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Qwen/Qwen3-1.7B
|
||||
cd Qwen3-1.7B
|
||||
git lfs pull
|
||||
# Replace with the actual IP address of the first node
|
||||
RAY_HEAD_IP=xxx.xxx.xxx.xxx
|
||||
ray start --address $RAY_HEAD_IP
|
||||
```
|
||||
|
||||
**Alternative**: You can also use the Hugging Face CLI to download models after installing the `huggingface_hub` package. Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.
|
||||
You should see the Ray resource status displayed when running `ray status`.
|
||||
|
||||
Set the `n_nodes` argument appropriately in AReaL's training command; the training script will then automatically detect the available resources and allocate workers across the cluster.
|
||||
|
||||
## Next Steps
|
||||
|
||||
Check the [quickstart section](quickstart.md) to launch your first AReaL job.
|
|
@ -0,0 +1,107 @@
|
|||
# Quickstart
|
||||
|
||||
This guide walks through a simple example of training an LLM to solve math problems.
|
||||
|
||||
## Dataset
|
||||
|
||||
Use `huggingface-cli` to download our open-source dataset:
|
||||
|
||||
```bash
|
||||
huggingface-cli download --repo-type=dataset inclusionAI/AReaL-RL-Data
|
||||
```
|
||||
|
||||
> **Note**: The above command will display the path of the downloaded dataset. You'll need to pass this path to the training command.
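
`huggingface-cli download` prints the local snapshot directory, so one way to capture it and reuse it in the training command (a sketch, assuming the dataset keeps the `data/boba_106k_0319.jsonl` layout) is:

```bash
# Capture the snapshot path printed by huggingface-cli and point the trainer at the JSONL file
DATA_DIR=$(huggingface-cli download --repo-type=dataset inclusionAI/AReaL-RL-Data)
echo "dataset.path=${DATA_DIR}/data/boba_106k_0319.jsonl"
```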
|
||||
|
||||
## Model
|
||||
|
||||
We train using open-source models available on Hugging Face Hub. You can either download the model in advance or use the model identifier when running the experiment.
|
||||
|
||||
```bash
|
||||
# If you want to download it in advance
|
||||
huggingface-cli download Qwen/Qwen3-1.7B
|
||||
```
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for more information on using `huggingface-cli`.
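
Similarly, if you download the model in advance, you can reuse the snapshot path printed by `huggingface-cli` (a sketch; alternatively, pass the `Qwen/Qwen3-1.7B` identifier directly as shown in the README quickstart):

```bash
# Capture the local snapshot path of the model and reuse it for actor.path / ref.path
MODEL_DIR=$(huggingface-cli download Qwen/Qwen3-1.7B)
echo "actor.path=${MODEL_DIR} ref.path=${MODEL_DIR}"
```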
|
||||
|
||||
## Training
|
||||
|
||||
From the repository directory, run:
|
||||
|
||||
```bash
|
||||
# examples/run_async_ppo.sh
|
||||
python3 training/main_sync_ppo.py \
|
||||
n_nodes=1 n_gpus_per_node=8 \
|
||||
allocation_mode=sglang.d4p1m1+d2p2m1 \
|
||||
cluster.fileroot=/path/to/save/logs/checkpoints/ \
|
||||
actor.type._class=qwen3 \
|
||||
actor.path=/path/to/models/Qwen__Qwen3-1.7B \
|
||||
ref.type._class=qwen3 \
|
||||
ref.path=/path/to/models/Qwen__Qwen3-1.7B \
|
||||
dataset.path=/path/to/dataset/boba_106k_0319.jsonl \
|
||||
dataset.train_bs_n_seqs=32 \
|
||||
group_size=8 \
|
||||
ppo.gen.max_new_tokens=4096 \
|
||||
ppo.ppo_n_minibatches=4 \
|
||||
actor_train.mb_spec.max_tokens_per_mb=32768
|
||||
```
|
||||
|
||||
## Command Line Options
|
||||
|
||||
To view all available options:
|
||||
|
||||
```bash
|
||||
python3 training/main_sync_ppo.py --help
|
||||
```
|
||||
|
||||
### Configuration Parameters
|
||||
|
||||
- **`experiment_name`**: The name of your project.
|
||||
- **`trial_name`**: The name of this trial in your project.
|
||||
- **`{actor|ref}.path`**: The path to the model files.
|
||||
- **`dataset.path`**: The path to the dataset JSONL file.
|
||||
- **`cluster.fileroot`**: The root path for saving training outputs (logs and checkpoints).
|
||||
- **`n_nodes`**: The number of nodes in the cluster.
|
||||
- **`n_gpus_per_node`**: The number of GPUs per node.
|
||||
- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration for the experiment. Format:
|
||||
- `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures the parallel strategies for SGLang generation and training respectively. Generation and training use separate GPU sets, and the total GPU count must satisfy DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs (see the worked example below).
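
For example, the `allocation_mode=sglang.d4p1m1+d2p2m1` setting used above splits a single 8-GPU node into 4 GPUs for SGLang generation and 4 GPUs for training; a quick check of the arithmetic:

```bash
# sglang.d4p1m1+d2p2m1 -> generation uses 4*1*1 GPUs, training uses 2*2*1 GPUs
echo $(( 4*1*1 + 2*2*1 ))  # 8, which must equal n_nodes * n_gpus_per_node
```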
|
||||
|
||||
### Training Control
|
||||
|
||||
- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset iterations).
|
||||
- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to persistent storage. Set to null to disable saving.
|
||||
- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary parameters for restart capability.
|
||||
- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per training iteration).
|
||||
- **`group_size`**: Number of responses sampled per prompt.
|
||||
|
||||
### Memory and Performance
|
||||
|
||||
- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per mini-batch for forward/backward passes during reference model inference and actor model training. Reduce this value to avoid OOM errors.
|
||||
- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO update.
|
||||
|
||||
### PPO Configuration
|
||||
|
||||
- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for training. Defaults to True for asynchronous experiments and False for synchronous baselines.
|
||||
- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training. Defaults to True.
|
||||
- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt.
|
||||
|
||||
## Monitoring the Training Process
|
||||
|
||||
We recommend using Weights & Biases (wandb) for monitoring. Run `wandb login` or set the `WANDB_API_KEY` environment variable. Set `wandb.mode=True` in your configuration to upload training statistics.
|
||||
|
||||
You can also use TensorBoard by setting the `tensorboard.path` parameter.
|
||||
|
||||
The main log will be saved to `${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the statistics uploaded to wandb.
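
To follow it live, substitute your own settings for the placeholders below (`fileroot`, `experiment_name`, and `trial_name` are whatever you passed on the command line):

```bash
# Follow the main training log; the path mirrors ${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log
tail -f /path/to/save/logs/checkpoints/logs/${USER}/my-experiment/my-trial/main.log
```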
|
||||
|
||||
### Key Training Statistics
|
||||
|
||||
- **`Epoch 1/5`**: Indicates total epochs required and current epoch being trained.
|
||||
- **`step 6/19`**: Shows the current epoch has 19 steps, with the 6th step just completed.
|
||||
- **`global step 6`**: Step count across all epochs.
|
||||
- **`ppo_actor/task_reward/avg`**: Average reward value of all sampled responses in this step. Should steadily increase during training and eventually stabilize.
|
||||
- **`ppo_actor/importance_weight/avg`**: Average importance sampling ratio across all tokens in the PPO loss. Typically close to 1.0.
|
||||
- **`ppo_actor/actor_clip_ratio/avg`**: Ratio of clipped tokens in PPO loss to total tokens. Usually less than 0.1.
|
||||
- **`ppo_actor/actor_loss/avg`**: PPO loss value. **Does not show clear trends during training** and should not be used as a performance indicator.
|
||||
|
||||
## Next Steps
|
||||
|
||||
[Evaluate your model](eval.md) or check the [troubleshooting section](troubleshooting.md) if you encounter any issues.
|
|
@ -43,9 +43,10 @@ The key to resolving this issue is identifying the phase where the error occurs:
|
|||
#### During SGLang Generation
|
||||
- Decrease the `actor.sglang.mem_fraction_static` parameter
|
||||
- Increase the tensor parallelism degree
|
||||
- Decrease the `max_concurrent_rollouts` parameter for asynchronous RL
|
||||
|
||||
#### During `actor_inf` or `actor_train`
|
||||
- **Adjust microbatch size**: Set parameters like `actor_train.mb_spec.max_tokens_per_mb=20480`. This parameter limits tokens per forward/backward pass and can be set as low as the maximum sequence length (including prompt)
|
||||
- **Adjust microbatch size**: Decrease the parameter `{actor_train|actor_inf}.mb_spec.max_tokens_per_mb=20480`. This parameter limits tokens per forward/backward pass and can be set as low as the maximum sequence length (including prompt)
|
||||
- **Modify parallelism strategy**: Adjust `allocation_mode` by:
|
||||
- Reducing data parallelism
|
||||
- Increasing tensor or pipeline parallelism
|
||||
|
@ -53,4 +54,4 @@ The key to resolving this issue is identifying the phase where the error occurs:
|
|||
|
||||
### CUDA Error: Out of Memory
|
||||
|
||||
This issue may occur during data transfer. Try increasing `mem_per_xx_worker` in the CLI arguments.
|
||||
This issue may occur during data transfer. Try increasing `mem_per_model_worker` in the CLI arguments.
|
|
@ -173,7 +173,7 @@ def evaluate(samples):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl")
|
||||
# data_list = load_jsonl("/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl")
|
||||
data_path = "/storage/openpsi/data/code/deepcoder/deepcoder_0415_v3_verify_new_correct.jsonl"
|
||||
|
||||
id2info = defaultdict(dict)
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
# common
|
||||
vllm
|
||||
tqdm
|
||||
datasets
|
||||
torch
|
||||
transformers==4.47.0
|
||||
transformers==4.51.1
|
||||
hf_transfer
|
||||
huggingface_hub
|
||||
python_dateutil
|
||||
flash_attn
|
||||
|
||||
# math_eval
|
||||
sympy==1.12
|
||||
|
@ -14,4 +13,5 @@ word2number
|
|||
Pebble
|
||||
prettytable
|
||||
timeout-decorator
|
||||
timeout_decorator
|
||||
wandb
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
#!/bin/bash
|
||||
# basic dependencies
|
||||
pip install -U pip
|
||||
pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
|
||||
pip install nvidia-ml-py
|
||||
pip install -e evaluation/latex2sympy
|
||||
pip install vllm==0.8.5 --no-build-isolation
|
||||
pip install flash_attn --no-build-isolation
|
||||
pip install -r evaluation/requirements.txt
|
|
@ -0,0 +1,28 @@
|
|||
#!/bin/bash
|
||||
# basic dependencies
|
||||
pip install -U pip
|
||||
pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
|
||||
pip install "sglang[all]==0.4.6.post4"
|
||||
pip install megatron-core==0.11.0 nvidia-ml-py
|
||||
pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# the sympy virtual env for reward computation
|
||||
pip install virtualenv
|
||||
rm -rf ./sympy
|
||||
python3 -m venv ./sympy
|
||||
# equivalent to installing `./evaluation/latex2sympy` in the sympy virtual env
|
||||
./sympy/bin/pip install git+https://github.com/QwenLM/Qwen2.5-Math.git#subdirectory=evaluation/latex2sympy
|
||||
./sympy/bin/pip install regex numpy tqdm datasets python_dateutil sympy==1.12 antlr4-python3-runtime==4.11.1 word2number Pebble timeout-decorator prettytable
|
||||
|
||||
# Install an editable sglang
|
||||
rm -rf ./sglang
|
||||
git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang
|
||||
AREAL_PATH=$PWD
|
||||
cd sglang
|
||||
git apply ../patch/sglang/v0.4.6.post4.patch
|
||||
pip install -e "python[all]" --no-deps
|
||||
cd $AREAL_PATH
|
||||
|
||||
# Install AReaL
|
||||
pip install -e .
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
python3 training/main_async_ppo.py \
|
||||
n_nodes=1 n_gpus_per_node=8 \
|
||||
allocation_mode=sglang.d4p1m1+d2p2m1 \
|
||||
cluster.fileroot=/storage/testing/experiments \
|
||||
actor.type._class=qwen3 \
|
||||
actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
ref.type._class=qwen3 \
|
||||
ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \
|
||||
dataset.train_bs_n_seqs=32 \
|
||||
group_size=8 \
|
||||
ppo.gen.max_new_tokens=4096 \
|
||||
ppo.ppo_n_minibatches=4 \
|
||||
actor_train.mb_spec.max_tokens_per_mb=32768 \
|
||||
actor_inf.mb_spec.max_tokens_per_mb=32768 \
|
||||
max_concurrent_rollouts=16 \
|
||||
max_head_offpolicyness=4
|
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
|
||||
python3 training/main_sft.py \
|
||||
n_nodes=1 n_gpus_per_node=8 \
|
||||
allocation_mode=d4p2m1 \
|
||||
cluster.fileroot=/storage/testing/experiments \
|
||||
model.type._class=qwen3 \
|
||||
exp_ctrl.eval_freq_epochs=1 \
|
||||
model.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
dataset.train_path=/storage/testing/dataset/areal-sft-stage2-200.jsonl \
|
||||
dataset.valid_path=/storage/testing/dataset/areal-sft-stage2-200.jsonl \
|
||||
dataset.train_bs_n_seqs=64 \
|
||||
dataset.valid_bs_n_seqs=64 \
|
||||
allocation.mb_spec.max_tokens_per_mb=32768
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
|
||||
python3 training/main_sync_ppo.py \
|
||||
n_nodes=1 n_gpus_per_node=8 \
|
||||
allocation_mode=sglang.d4p1m1+d2p2m1 \
|
||||
cluster.fileroot=/storage/testing/experiments \
|
||||
actor.type._class=qwen3 \
|
||||
actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
ref.type._class=qwen3 \
|
||||
ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \
|
||||
dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \
|
||||
dataset.train_bs_n_seqs=32 \
|
||||
group_size=8 \
|
||||
ppo.gen.max_new_tokens=4096 \
|
||||
ppo.ppo_n_minibatches=4 \
|
||||
actor_train.mb_spec.max_tokens_per_mb=32768
|
|
@ -121,7 +121,9 @@ def code_verify(id2info, generateds, query_ids, debug=False):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl")
|
||||
data_list = load_jsonl(
|
||||
"/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl"
|
||||
)
|
||||
id2info = defaultdict(dict)
|
||||
for item in data_list:
|
||||
id2info[item["query_id"]] = item
|
||||
|
|
|
@ -32,7 +32,11 @@ def construct_testcases(
|
|||
result.append({"input": input_, "expectedOutput": output_})
|
||||
continue
|
||||
|
||||
oss_basepath = "http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/"
|
||||
oss_basepath = os.getenv("REAL_OSS_TESTCASE_PATH", "")
|
||||
if not oss_basepath:
|
||||
raise FileNotFoundError(
|
||||
"REAL_OSS_TESTCASE_PATH not set. Cannot use FAAS code reward."
|
||||
)
|
||||
input_url = (
|
||||
input_ if input_.startswith("http") else os.path.join(oss_basepath, input_)
|
||||
)
|
||||
|
@ -153,7 +157,9 @@ def code_verify(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl")
|
||||
data_list = load_jsonl(
|
||||
"/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl"
|
||||
)
|
||||
id2info = defaultdict(dict)
|
||||
for item in data_list:
|
||||
id2info[item["query_id"]] = item
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
{"task": "code", "query_id": "0001", "prompt": "", "solutions": ["from typing import *\n\nclass Solution:\n def solveNQueens(self, n: int) -> List[List[str]]:\n def generateBoard():\n board = list()\n for i in range(n):\n row[queens[i]] = \"Q\"\n board.append(\"\".join(row))\n row[queens[i]] = \".\"\n return board\n\n def solve(row: int, columns: int, diagonals1: int, diagonals2: int):\n if row == n:\n board = generateBoard()\n solutions.append(board)\n else:\n availablePositions = ((1 << n) - 1) & (~(columns | diagonals1 | diagonals2))\n while availablePositions:\n position = availablePositions & (-availablePositions)\n availablePositions = availablePositions & (availablePositions - 1)\n column = bin(position - 1).count(\"1\")\n queens[row] = column\n solve(row + 1, columns | position, (diagonals1 | position) << 1, (diagonals2 | position) >> 1)\n\n solutions = list()\n queens = [-1] * n\n row = [\".\"] * n\n solve(0, 0, 0, 0)\n return solutions\n# Test case 1: Smallest case, n = 1\n# There is only one queen, so the only solution is a board with a single 'Q'.\nsolution = Solution()\nassert solution.solveNQueens(1) == [['Q']]\n"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"}
|
||||
{"task": "code", "query_id": "0002", "prompt": "", "solutions": ["package main\n\nimport (\n\t\"bufio\"\n\t\"fmt\"\n\t\"os\"\n\t\"strconv\"\n\t\"strings\"\n)\n\nfunc main() {\n\tnums, target := arrayToOneStyleInput()\n\tresult := twoSum(nums, target)\n\tfmt.Println(result)\n}\n\nfunc twoSum(nums []int, target int) []int {\n\tprevNums := map[int]int{}\n\tfor i, num := range nums {\n\t\ttargetNum := target - num\n\t\ttargetNumIndex, ok := prevNums[targetNum]\n\t\tif ok {\n\t\t\treturn []int{targetNumIndex, i}\n\t\t} else {\n\t\t\tprevNums[num] = i\n\t\t}\n\t}\n\treturn []int{}\n}\n\nfunc arrayToOneStyleInput() ([]int, int) {\n\t// 读取数组\n\tscanner := bufio.NewScanner(os.Stdin)\n\tscanner.Scan()\n\tarrStr := strings.Trim(scanner.Text(), \"[]\")\n\tarr := stringToIntSlice(strings.ReplaceAll(arrStr, \",\", \" \"))\n\n\t// 读取目标值\n\tscanner.Scan()\n\ttarget, _ := strconv.Atoi(scanner.Text())\n\n\treturn arr, target\n}\n\nfunc stringToIntSlice(s string) []int {\n\tparts := strings.Split(s, \" \")\n\tres := make([]int, len(parts))\n\tfor i, p := range parts {\n\t\tres[i], _ = strconv.Atoi(p)\n\t}\n\treturn res\n}\n"], "input_output": "{\"inputs\": [\"https://artifacts.antgroup-inc.cn/artifact/repositories/artifacts-pre-test-common/runtime/golang/1.0.1/input.txt\"], \"outputs\": [\"https://artifacts.antgroup-inc.cn/artifact/repositories/artifacts-pre-test-common/runtime/golang/1.0.1/output.txt\"], \"fn_name\": \"\", \"remote\": true }", "language": "GO"}
|
||||
{"task": "code", "query_id": "0003", "prompt": "", "solutions": ["public class TestMain {\n public static void main(String[] args) {\n assert \"test\".equals(\"test\");\n }\n}"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "JAVA"}
|
||||
{"task": "code", "query_id": "0004", "prompt": "", "solutions": ["#include <iostream>\n#include <string>\nint main() {\n std::string name = \"Alice\";\n std::cout << \"Hello, \" << name << \"! Welcome to the world of C++!\\n\";\n return 0;\n}\n"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "CPP"}
|
||||
{"task": "code", "query_id": "0005", "prompt": "", "solutions": ["import time\n\ndef square(num):\n return num ** 2\n\nresult=square(int(input()))\nprint(result)"], "input_output": "{\"inputs\": [\"2\", \"2\", \"2\"], \"outputs\": [\"4\\n\", \"4\\n\", \"4\\n\"], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"}
|
||||
{"task": "code", "query_id": "0006", "prompt": "", "solutions": ["#include <iostream>\n\nint square(int number) {\n return number * number;\n}\n\nint main() {\n int num;\n std::cin >> num;\n\n int result = square(num);\n std::cout << result << std::endl;\n\n return 0;\n}"], "input_output": "{\"inputs\": [\"http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/functioncall/content_2.txt\"], \"outputs\": [\"http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/functioncall/content_4.txt\"], \"fn_name\": \"\", \"remote\": true }", "language": "CPP"}
|
|
@ -1,501 +1,3 @@
|
|||
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
|
||||
index 60091b9a..7a1c856b 100644
|
||||
--- a/python/sglang/srt/layers/logits_processor.py
|
||||
+++ b/python/sglang/srt/layers/logits_processor.py
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
+import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -44,6 +45,15 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
)
|
||||
from sglang.srt.utils import dump_to_file
|
||||
|
||||
+# When compute the input and output tokens logprobs, if the rows of the
|
||||
+# logprobs are too large, the peak memory usage will be too high. For example,
|
||||
+# if the logprobs are [10000, 150000], the peak memory usage will be greater
|
||||
+# than 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB. (4 is the size of float)
|
||||
+# So we split the logprobs into multiple chunks.
|
||||
+LOGITS_PROCESSER_CHUNK_SIZE = int(
|
||||
+ os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048")
|
||||
+)
|
||||
+
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -286,15 +296,17 @@ class LogitsProcessor(nn.Module):
|
||||
input_logprob_indices = None
|
||||
else:
|
||||
# Input logprobs are required.
|
||||
- # Find 3 different indices.
|
||||
+ # Find 4 different indices.
|
||||
# 1. pruned_states: hidden states that we want logprobs from.
|
||||
# 2. sample_indices: Indices that have sampled tokens.
|
||||
# 3. input_logprob_indices: Indices that have input logprob tokens.
|
||||
+ # 4. sequence_index_mapping: map pruned_states indices to top_logprobs_nums and token_ids_logprobs indices.
|
||||
sample_index_pt = -1
|
||||
sample_indices = []
|
||||
input_logprob_indices_pt = 0
|
||||
input_logprob_indices = []
|
||||
pt, pruned_states = 0, []
|
||||
+ idx, sequence_index_mapping = 0, []
|
||||
for extend_logprob_start_len, extend_len in zip(
|
||||
logits_metadata.extend_logprob_start_lens_cpu,
|
||||
logits_metadata.extend_seq_lens_cpu,
|
||||
@@ -310,7 +322,11 @@ class LogitsProcessor(nn.Module):
|
||||
# by a caller.
|
||||
assert extend_len > start_len
|
||||
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||
+ # sequence_index_mapping, repeat this for loop index
|
||||
+ sequence_index_mapping.extend([idx] * (extend_len - start_len))
|
||||
+ idx += 1
|
||||
pt += extend_len
|
||||
+
|
||||
sample_index_pt += extend_len - start_len
|
||||
sample_indices.append(sample_index_pt)
|
||||
input_logprob_indices.extend(
|
||||
@@ -321,6 +337,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
input_logprob_indices_pt += extend_len - start_len
|
||||
|
||||
+ sequence_index_mapping.append(idx - 1)
|
||||
pruned_states = torch.cat(pruned_states)
|
||||
sample_indices = torch.tensor(
|
||||
sample_indices, device=pruned_states.device, dtype=torch.int64
|
||||
@@ -329,12 +346,6 @@ class LogitsProcessor(nn.Module):
|
||||
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
|
||||
)
|
||||
|
||||
- # Compute logits for both input and sampled tokens.
|
||||
- logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||
- sampled_logits = (
|
||||
- logits[sample_indices] if sample_indices is not None else logits
|
||||
- )
|
||||
-
|
||||
if self.debug_tensor_dump_output_folder:
|
||||
assert (
|
||||
not self.do_tensor_parallel_all_gather
|
||||
@@ -370,67 +381,176 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
assert False, "Should never reach"
|
||||
|
||||
+ del hidden_states
|
||||
+
|
||||
if not logits_metadata.extend_return_logprob:
|
||||
+ # Compute logits for both input and sampled tokens.
|
||||
+ logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||
+ sampled_logits = (
|
||||
+ logits[sample_indices] if sample_indices is not None else logits
|
||||
+ )
|
||||
+
|
||||
# Decode mode or extend mode without return_logprob.
|
||||
return LogitsProcessorOutput(
|
||||
next_token_logits=sampled_logits,
|
||||
hidden_states=hidden_states_to_store,
|
||||
)
|
||||
else:
|
||||
- input_logprobs = logits[input_logprob_indices]
|
||||
- del hidden_states, logits
|
||||
+ # Compute logprobs requires lot of memory, so we split pruned_states
|
||||
+ # into chunks of rows to compute input_logprobs separately, then
|
||||
+ # concatenate the results.
|
||||
+ return self._compute_output_by_chunk(
|
||||
+ pruned_states,
|
||||
+ sample_indices,
|
||||
+ hidden_states_to_store,
|
||||
+ input_logprob_indices,
|
||||
+ sequence_index_mapping,
|
||||
+ lm_head,
|
||||
+ logits_metadata,
|
||||
+ )
|
||||
|
||||
- # Normalize the logprob w/o temperature, top-p
|
||||
- pruned_lens = torch.tensor(
|
||||
- logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
- device=input_logprobs.device,
|
||||
+ def _compute_output_by_chunk(
|
||||
+ self,
|
||||
+ pruned_states: torch.Tensor,
|
||||
+ sample_indices: torch.Tensor,
|
||||
+ hidden_states_to_store: Optional[torch.Tensor],
|
||||
+ input_logprob_indices: torch.Tensor,
|
||||
+ index_mapping: list[int],
|
||||
+ lm_head: VocabParallelEmbedding,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
+ ) -> LogitsProcessorOutput:
|
||||
+ """
|
||||
+ compute logprobs for the output token from the hidden states.
|
||||
+ To avoid using too much memory, we split pruned_states into chunks of
|
||||
+ rows to compute input_logprobs separately, then concatenate the results.
|
||||
+
|
||||
+ Returns:
|
||||
+ LogitsProcessorOutput: logits processor output class
|
||||
+ """
|
||||
+
|
||||
+ # Normalize the logprob w/o temperature, top-p
|
||||
+ pruned_lens = torch.tensor(
|
||||
+ logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
+ device=pruned_states.device,
|
||||
+ )
|
||||
+ if logits_metadata.temp_scaled_logprobs:
|
||||
+ logits_metadata.temperature = torch.repeat_interleave(
|
||||
+ logits_metadata.temperature.view(-1),
|
||||
+ pruned_lens,
|
||||
+ ).view(-1, 1)
|
||||
+ if logits_metadata.top_p_normalized_logprobs:
|
||||
+ logits_metadata.top_p = torch.repeat_interleave(
|
||||
+ logits_metadata.top_p,
|
||||
+ pruned_lens,
|
||||
)
|
||||
- if logits_metadata.temp_scaled_logprobs:
|
||||
- logits_metadata.temperature = torch.repeat_interleave(
|
||||
- logits_metadata.temperature.view(-1),
|
||||
- pruned_lens,
|
||||
- ).view(-1, 1)
|
||||
- if logits_metadata.top_p_normalized_logprobs:
|
||||
- logits_metadata.top_p = torch.repeat_interleave(
|
||||
- logits_metadata.top_p,
|
||||
- pruned_lens,
|
||||
- )
|
||||
- input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||
- input_logprobs, logits_metadata
|
||||
+
|
||||
+ # The peak memory usage is proportional to the chunk size.
|
||||
+ chunk_size = LOGITS_PROCESSER_CHUNK_SIZE
|
||||
+ num_chunks = (pruned_states.shape[0] + chunk_size - 1) // chunk_size
|
||||
+
|
||||
+ input_token_logprobs = []
|
||||
+ if logits_metadata.extend_return_top_logprob:
|
||||
+ input_top_logprobs_val = []
|
||||
+ input_top_logprobs_idx = []
|
||||
+ else:
|
||||
+ input_top_logprobs_val = None
|
||||
+ input_top_logprobs_idx = None
|
||||
+
|
||||
+ if logits_metadata.extend_token_ids_logprob:
|
||||
+ input_token_ids_logprobs_val = []
|
||||
+ input_token_ids_logprobs_idx = []
|
||||
+ else:
|
||||
+ input_token_ids_logprobs_val = None
|
||||
+ input_token_ids_logprobs_idx = None
|
||||
+
|
||||
+ # If a single sequence is split into multiple chunks, we need to keep track
|
||||
+ # of the pruned length of the sequences in the previous chunks.
|
||||
+ split_len_topk = 0
|
||||
+ split_len_token_ids = 0
|
||||
+
|
||||
+ for i in range(num_chunks):
|
||||
+ start_idx = i * chunk_size
|
||||
+ end_idx = min((i + 1) * chunk_size, pruned_states.shape[0])
|
||||
+
|
||||
+ # Get indices for this chunk
|
||||
+ chunk_mask = (input_logprob_indices >= start_idx) & (
|
||||
+ input_logprob_indices < end_idx
|
||||
+ )
|
||||
+
|
||||
+ global_indices = input_logprob_indices[chunk_mask]
|
||||
+ chunk_indices = global_indices - start_idx
|
||||
+ chunk_states = pruned_states[start_idx:end_idx]
|
||||
+ chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)
|
||||
+
|
||||
+ if chunk_indices.numel() == 0:
|
||||
+ continue
|
||||
+
|
||||
+ # Compute the logprobs of the chunk
|
||||
+ chunk_input_logprobs = chunk_logits[chunk_indices]
|
||||
+ chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||
+ chunk_input_logprobs, global_indices, logits_metadata
|
||||
)
|
||||
|
||||
+ # For each chunk, we need to get the slice of the sequence_index_mapping
|
||||
+ chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1)
|
||||
+
|
||||
# Get the logprob of top-k tokens
|
||||
if logits_metadata.extend_return_top_logprob:
|
||||
- (
|
||||
+ split_len_topk = self.get_top_logprobs(
|
||||
+ chunk_input_logprobs,
|
||||
+ logits_metadata,
|
||||
+ chunk_slice,
|
||||
input_top_logprobs_val,
|
||||
input_top_logprobs_idx,
|
||||
- ) = self.get_top_logprobs(input_logprobs, logits_metadata)
|
||||
- else:
|
||||
- input_top_logprobs_val = input_top_logprobs_idx = None
|
||||
+ split_len_topk,
|
||||
+ )
|
||||
|
||||
# Get the logprob of given token id
|
||||
if logits_metadata.extend_token_ids_logprob:
|
||||
- (
|
||||
+ split_len_token_ids = self.get_token_ids_logprobs(
|
||||
+ chunk_input_logprobs,
|
||||
+ logits_metadata,
|
||||
+ chunk_slice,
|
||||
input_token_ids_logprobs_val,
|
||||
input_token_ids_logprobs_idx,
|
||||
- ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
|
||||
- else:
|
||||
- input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
|
||||
-
|
||||
- input_token_logprobs = input_logprobs[
|
||||
- torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
|
||||
- logits_metadata.extend_input_logprob_token_ids_gpu,
|
||||
- ]
|
||||
+ split_len_token_ids,
|
||||
+ )
|
||||
|
||||
- return LogitsProcessorOutput(
|
||||
- next_token_logits=sampled_logits,
|
||||
- input_token_logprobs=input_token_logprobs,
|
||||
- input_top_logprobs_val=input_top_logprobs_val,
|
||||
- input_top_logprobs_idx=input_top_logprobs_idx,
|
||||
- hidden_states=hidden_states_to_store,
|
||||
- input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
||||
- input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
||||
+ # Handle sampled logits for the chunk if needed
|
||||
+ chunk_sample_mask = (sample_indices >= start_idx) & (
|
||||
+ sample_indices < end_idx
|
||||
)
|
||||
+ if i == 0: # Initialize sampled_logits on first chunk
|
||||
+ sampled_logits = torch.empty(
|
||||
+ (sample_indices.shape[0], chunk_logits.shape[1]),
|
||||
+ dtype=chunk_logits.dtype,
|
||||
+ device=chunk_logits.device,
|
||||
+ )
|
||||
+ if chunk_sample_mask.any():
|
||||
+ chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
|
||||
+ sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]
|
||||
+
|
||||
+ # Get the logprob of the requested token ids
|
||||
+ chunk_input_token_logprobs = chunk_input_logprobs[
|
||||
+ torch.arange(
|
||||
+ chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
|
||||
+ ),
|
||||
+ logits_metadata.extend_input_logprob_token_ids_gpu[start_idx:end_idx],
|
||||
+ ]
|
||||
+ input_token_logprobs.append(chunk_input_token_logprobs)
|
||||
+
|
||||
+ # Concatenate the results
|
||||
+ input_token_logprobs = torch.cat(input_token_logprobs, dim=0)
|
||||
+
|
||||
+ return LogitsProcessorOutput(
|
||||
+ hidden_states=hidden_states_to_store,
|
||||
+ next_token_logits=sampled_logits,
|
||||
+ input_token_logprobs=input_token_logprobs,
|
||||
+ input_top_logprobs_val=input_top_logprobs_val,
|
||||
+ input_top_logprobs_idx=input_top_logprobs_idx,
|
||||
+ input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
||||
+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
||||
+ )
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
@@ -498,60 +618,142 @@ class LogitsProcessor(nn.Module):
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
- def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||
+ def get_top_logprobs(
|
||||
+ logprobs: torch.Tensor,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
+ chunk_slice: slice,
|
||||
+ input_top_logprobs_val: List,
|
||||
+ input_top_logprobs_idx: List,
|
||||
+ split_pruned_len: int,
|
||||
+ ):
|
||||
+ """Get top-k logprobs for each sequence in the chunk.
|
||||
+
|
||||
+ Args:
|
||||
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
|
||||
+ logits_metadata: Metadata containing top-k and pruned length info
|
||||
+ chunk_slice: Slice of sequences to process
|
||||
+ input_top_logprobs_val: List to store top-k logprob values
|
||||
+ input_top_logprobs_idx: List to store top-k token indices
|
||||
+ split_pruned_len: Length of pruned tokens from previous chunk
|
||||
+
|
||||
+ Returns:
|
||||
+ int: Number of remaining tokens to process in next chunk
|
||||
+ """
|
||||
+
|
||||
max_k = max(logits_metadata.top_logprobs_nums)
|
||||
- ret = all_logprobs.topk(max_k, dim=1)
|
||||
+ ret = logprobs.topk(max_k, dim=1)
|
||||
values = ret.values.tolist()
|
||||
indices = ret.indices.tolist()
|
||||
|
||||
- input_top_logprobs_val, input_top_logprobs_idx = [], []
|
||||
-
|
||||
pt = 0
|
||||
- for k, pruned_len in zip(
|
||||
- logits_metadata.top_logprobs_nums,
|
||||
- logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
- ):
|
||||
+ next_split_pruned_len = 0
|
||||
+ top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
|
||||
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
|
||||
+
|
||||
+ for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
|
||||
+ # Adjust pruned length for first sequence
|
||||
+ if n == 0:
|
||||
+ pruned_len -= split_pruned_len
|
||||
+ else:
|
||||
+ split_pruned_len = 0
|
||||
+
|
||||
if pruned_len <= 0:
|
||||
- input_top_logprobs_val.append([])
|
||||
- input_top_logprobs_idx.append([])
|
||||
+ if n == 0:
|
||||
+ input_top_logprobs_val.append([])
|
||||
+ input_top_logprobs_idx.append([])
|
||||
continue
|
||||
|
||||
- input_top_logprobs_val.append(
|
||||
- [values[pt + j][:k] for j in range(pruned_len)]
|
||||
- )
|
||||
- input_top_logprobs_idx.append(
|
||||
- [indices[pt + j][:k] for j in range(pruned_len)]
|
||||
- )
|
||||
- pt += pruned_len
|
||||
+ val = []
|
||||
+ idx = []
|
||||
+ for j in range(pruned_len):
|
||||
+ # Handle remaining tokens in next chunk if any
|
||||
+ if pt + j >= len(values):
|
||||
+ next_split_pruned_len = split_pruned_len + j
|
||||
+ break
|
||||
+ val.append(values[pt + j][:k])
|
||||
+ idx.append(indices[pt + j][:k])
|
||||
+
|
||||
+ if split_pruned_len <= 0 and len(val) > 0:
|
||||
+ input_top_logprobs_val.append(val)
|
||||
+ input_top_logprobs_idx.append(idx)
|
||||
+ else:
|
||||
+ input_top_logprobs_val[-1].extend(val)
|
||||
+ input_top_logprobs_idx[-1].extend(idx)
|
||||
|
||||
- return input_top_logprobs_val, input_top_logprobs_idx
|
||||
+ pt += pruned_len
|
||||
+ return next_split_pruned_len
|
||||
|
||||
@staticmethod
|
||||
def get_token_ids_logprobs(
|
||||
- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
|
||||
+ logprobs: torch.Tensor,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
+ chunk_slice: slice,
|
||||
+ input_token_ids_logprobs_val: List,
|
||||
+ input_token_ids_logprobs_idx: List,
|
||||
+ split_pruned_len: int = 0,
|
||||
):
|
||||
- input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
||||
+ """Get token_ids logprobs for each sequence in the chunk.
|
||||
+
|
||||
+ Args:
|
||||
+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
|
||||
+ logits_metadata: Metadata containing token IDs and pruned length info
|
||||
+ chunk_slice: Slice of sequences to process
|
||||
+ input_token_ids_logprobs_val: List to store token logprob values
|
||||
+ input_token_ids_logprobs_idx: List to store token indices
|
||||
+ split_pruned_len: Length of pruned tokens from previous chunk
|
||||
+
|
||||
+ Returns:
|
||||
+ int: Number of remaining tokens to process in next chunk
|
||||
+ """
|
||||
pt = 0
|
||||
- for token_ids, pruned_len in zip(
|
||||
- logits_metadata.token_ids_logprobs,
|
||||
- logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
+ next_split_pruned_len = 0
|
||||
+ token_ids_logprobs_chunk = logits_metadata.token_ids_logprobs[chunk_slice]
|
||||
+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice]
|
||||
+
|
||||
+ for n, (token_ids, pruned_len) in enumerate(
|
||||
+ zip(
|
||||
+ token_ids_logprobs_chunk,
|
||||
+ pruned_lens,
|
||||
+ )
|
||||
):
|
||||
+ # Adjust pruned length for first sequence
|
||||
+ if n == 0:
|
||||
+ pruned_len -= split_pruned_len
|
||||
+ else:
|
||||
+ split_pruned_len = 0
|
||||
+
|
||||
if pruned_len <= 0:
|
||||
- input_token_ids_logprobs_val.append([])
|
||||
- input_token_ids_logprobs_idx.append([])
|
||||
+ if n == 0:
|
||||
+ input_token_ids_logprobs_val.append([])
|
||||
+ input_token_ids_logprobs_idx.append([])
|
||||
continue
|
||||
|
||||
- input_token_ids_logprobs_val.append(
|
||||
- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
|
||||
- )
|
||||
- input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
||||
- pt += pruned_len
|
||||
+ val = []
|
||||
+ idx = []
|
||||
+ for j in range(pruned_len):
|
||||
+ # Handle remaining tokens in next chunk if any
|
||||
+ if pt + j >= logprobs.shape[0]:
|
||||
+ next_split_pruned_len = split_pruned_len + j
|
||||
+ break
|
||||
+ if token_ids is not None:
|
||||
+ val.append(logprobs[pt + j, token_ids].tolist())
|
||||
+ idx.append(token_ids)
|
||||
+
|
||||
+ if split_pruned_len <= 0 and len(val) > 0:
|
||||
+ input_token_ids_logprobs_val.append(val)
|
||||
+ input_token_ids_logprobs_idx.append(idx)
|
||||
+ elif len(val) > 0:
|
||||
+ input_token_ids_logprobs_val[-1].extend(val)
|
||||
+ input_token_ids_logprobs_idx[-1].extend(idx)
|
||||
|
||||
- return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
|
||||
+ pt += pruned_len
|
||||
+ return next_split_pruned_len
|
||||
|
||||
@staticmethod
|
||||
def compute_temp_top_p_normalized_logprobs(
|
||||
- last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
||||
+ last_logits: torch.Tensor,
|
||||
+ indices: torch.Tensor,
|
||||
+ logits_metadata: LogitsMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
compute logprobs for the output token from the given logits.
|
||||
@@ -561,19 +763,20 @@ class LogitsProcessor(nn.Module):
|
||||
"""
|
||||
# Scale logits if temperature scaling is enabled
|
||||
if logits_metadata.temp_scaled_logprobs:
|
||||
- last_logits = last_logits / logits_metadata.temperature
|
||||
+ last_logits = last_logits / logits_metadata.temperature[indices]
|
||||
+
|
||||
+ top_p = None
|
||||
+ if logits_metadata.top_p is not None:
|
||||
+ top_p = logits_metadata.top_p[indices]
|
||||
|
||||
# Normalize logprobs if top_p normalization is enabled
|
||||
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
|
||||
- if (
|
||||
- logits_metadata.top_p_normalized_logprobs
|
||||
- and (logits_metadata.top_p != 1.0).any()
|
||||
- ):
|
||||
+ if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
|
||||
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
|
||||
|
||||
probs = torch.softmax(last_logits, dim=-1)
|
||||
del last_logits
|
||||
- probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
|
||||
+ probs = top_p_normalize_probs_torch(probs, top_p)
|
||||
return torch.log(probs)
|
||||
else:
|
||||
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
|
||||
index 5390668c..db370d19 100644
|
||||
--- a/python/sglang/srt/managers/io_struct.py
|
||||
|
|
pyproject.toml (100 changes)
|
@ -25,6 +25,106 @@ classifiers = [
|
|||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
# Core ML/AI libraries
|
||||
"torch>2.0.0",
|
||||
"huggingface_hub",
|
||||
"datasets",
|
||||
"accelerate",
|
||||
"transformers==4.51.1",
|
||||
|
||||
# Scientific computing
|
||||
"numpy<2.0.0",
|
||||
"scipy",
|
||||
"pandas",
|
||||
"matplotlib",
|
||||
"seaborn",
|
||||
"h5py",
|
||||
|
||||
# Utilities and data processing
|
||||
"nltk",
|
||||
"sentencepiece",
|
||||
"einops",
|
||||
"tqdm",
|
||||
"rich",
|
||||
"orjson>=3.10.16",
|
||||
"pydantic",
|
||||
"PyYAML",
|
||||
"omegaconf",
|
||||
"hydra-core",
|
||||
"packaging",
|
||||
"tabulate",
|
||||
|
||||
# Monitoring and logging
|
||||
"wandb",
|
||||
"tensorboardx",
|
||||
"colorama",
|
||||
"colorlog",
|
||||
"psutil",
|
||||
"pynvml",
|
||||
|
||||
# Performance and compression
|
||||
"ninja",
|
||||
"numba",
|
||||
"blosc",
|
||||
"pybind11>=2.10.0",
|
||||
|
||||
# Networking and async
|
||||
"networkx==3.3",
|
||||
"aiofiles",
|
||||
"aiohttp>=3.11.10",
|
||||
"httpx>=0.28.1",
|
||||
"pyzmq",
|
||||
"paramiko",
|
||||
"etcd3",
|
||||
"protobuf<3.21",
|
||||
|
||||
# Distributed computing
|
||||
"ray",
|
||||
"redis",
|
||||
|
||||
# Web frameworks
|
||||
"fastapi>=0.115.12",
|
||||
"uvicorn>=0.34.2",
|
||||
"uvloop>=0.21.0",
|
||||
"flask",
|
||||
|
||||
# Build and packaging tools
|
||||
"build>=1.2.1",
|
||||
"wheel>=0.43.0",
|
||||
"setuptools>=62.3.0,<75.9",
|
||||
"cookiecutter>2.1.1",
|
||||
|
||||
# System utilities
|
||||
"distro-info>=1.0",
|
||||
"python-debian>=0.1.49",
|
||||
"func_timeout",
|
||||
|
||||
# Development tools (consider moving to optional dependencies)
|
||||
"pytest",
|
||||
"ipython",
|
||||
"jupyter-book",
|
||||
"sphinx",
|
||||
"sphinx-nefertiti",
|
||||
"black==25.1.0",
|
||||
"isort==5.13.2",
|
||||
"clang-format==19.1.7",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"black==25.1.0",
|
||||
"isort==5.13.2",
|
||||
"clang-format==19.1.7",
|
||||
]
|
||||
|
||||
docs = [
|
||||
"sphinx",
|
||||
"sphinx-nefertiti",
|
||||
"jupyter-book",
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "realhf.__version__"}
@@ -153,11 +153,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
     # set env vars
     BASE_ENVIRONS = constants.get_env_vars(
         REAL_MODE=args.mode.upper(),
         CLUSTER_SPEC_PATH=cluster_spec_path,
         REAL_RECOVER_RUN="1" if is_recover_run else "0",
         REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
         FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
         REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", ""),
     )
     for k, v in BASE_ENVIRONS.items():
         os.environ[k] = v

@@ -571,6 +571,15 @@ def get_repo_path() -> pathlib.Path:


 def get_env_vars(**kwargs):
+    kwargs.update(
+        CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
+        REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
+        REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
+        FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
+        REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
+        REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
+        REAL_OSS_TESTCASE_PATH=os.getenv("REAL_OSS_TESTCASE_PATH", ""),
+    )
     return {
         **kwargs,
         "REAL_PACKAGE_PATH": str(get_repo_path()),
@@ -2,7 +2,9 @@
 # Copyright 2024 Wei Fu & Zhiyu Mei
 # Licensed under the Apache License, Version 2.0 (the "License").
 import os
-from typing import List
+from typing import List, Optional
+
+from huggingface_hub import snapshot_download, try_to_load_from_cache

 from realhf.api.cli_args import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
 from realhf.api.quickstart.device_mesh import RPCAllocation

@@ -56,13 +58,89 @@ def check_valid_optimizer(model: ModelTrainEvalConfig):
     )


-def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig):
-    if not os.path.exists(model.path):
-        raise FileNotFoundError(
-            f"The model path `{model.path}` for `{role}` does not exist locally. "
-            "You must download the HuggingFace checkpoint before loading it."
-        )
+def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig, fileroot):
+    """
+    Check if the model path exists locally, and download it from the HuggingFace Hub if not.
+
+    Args:
+        role: The role identifier for the model
+        model: ModelTrainEvalConfig object containing the model configuration
+
+    Returns:
+        None. `model.path` is updated in place to the resolved local path
+        (either the existing one or the newly downloaded location).
+
+    Raises:
+        Exception: If the download fails or other errors occur
+    """
+    if os.path.exists(model.path):
+        return
+
+    logger.info(f"Model path `{model.path}` for `{role}` does not exist locally.")
+
+    # Extract the model name from the path or use the path as the model identifier.
+    # Adjust this logic based on how ModelTrainEvalConfig stores the model identifier.
+    model_name = model.path
+
+    # First, check whether the model exists in the HuggingFace cache.
+    logger.info(f"Checking HuggingFace cache for model: {model_name}")
+    cached_path = _check_huggingface_cache(model_name)
+    if cached_path:
+        logger.info(f"Found model in HuggingFace cache: {cached_path}")
+        model.path = cached_path
+        return
+
+    # If not in the cache, download under the `models/` directory of the cluster fileroot.
+    logger.info("Model not found in cache. Downloading from HuggingFace Hub...")
+    # '/' is replaced in the directory name to avoid path issues.
+    target_path = os.path.join(fileroot, "models", model_name.replace("/", "--"))
+    if not os.path.exists(target_path):
+        snapshot_download(
+            repo_id=model_name,
+            local_dir=target_path,
+            local_dir_use_symlinks=False,
+        )
+
+    logger.info(f"Model downloaded successfully to: {target_path}")
+    # Update the model object's path to point to the downloaded location.
+    model.path = target_path
+
+
+def _check_huggingface_cache(model_name: str) -> Optional[str]:
+    """
+    Check whether a model exists in the HuggingFace cache.
+
+    Args:
+        model_name: The HuggingFace model identifier (e.g., 'bert-base-uncased')
+
+    Returns:
+        Optional[str]: Path to the cached model if found, None otherwise
+    """
+    # Try to find the model files in the cache.
+    # We check for common files that should exist in a model repo.
+    common_files = [
+        "config.json",
+        "pytorch_model.bin",
+        "model.safetensors",
+        "tf_model.h5",
+    ]
+
+    cached_path = None
+    for filename in common_files:
+        file_path = try_to_load_from_cache(
+            repo_id=model_name, filename=filename, repo_type="model"
+        )
+        if file_path is not None:
+            # Get the directory containing the cached file.
+            cached_path = os.path.dirname(file_path)
+            break
+
+    # Verify that the cached directory exists and contains model files.
+    if cached_path and os.path.exists(cached_path):
+        # Double-check that it is a valid model directory.
+        if any(os.path.exists(os.path.join(cached_path, f)) for f in common_files):
+            return cached_path
+
+    return None


 def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
     try:
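The new `check_valid_model_and_path` resolves a model identifier in three steps: use the local path if present, fall back to the HuggingFace cache, and only then download under `<fileroot>/models/`. A hedged usage sketch of the same resolution flow with `huggingface_hub` directly (the model id and fileroot are placeholders):

```python
import os

from huggingface_hub import snapshot_download, try_to_load_from_cache

def resolve_model_path(model_id: str, fileroot: str) -> str:
    # 1) Already a local directory? Use it as-is.
    if os.path.exists(model_id):
        return model_id
    # 2) Already in the HuggingFace cache? Reuse the cached snapshot.
    cached_config = try_to_load_from_cache(repo_id=model_id, filename="config.json")
    if isinstance(cached_config, str):
        return os.path.dirname(cached_config)
    # 3) Otherwise download once into a predictable location under fileroot.
    target = os.path.join(fileroot, "models", model_id.replace("/", "--"))
    if not os.path.exists(target):
        snapshot_download(repo_id=model_id, local_dir=target)
    return target

print(resolve_model_path("Qwen/Qwen3-1.7B", "/storage/testing/experiments"))
```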
@@ -4,6 +4,7 @@ import json
 import os
 import signal
 import subprocess
+import sys
 import uuid
 from typing import *

@@ -54,6 +55,10 @@ def parse_line(id2info, prompt_str, generated, query_id):
         f.write(json.dumps({"answer": generated, "solution": cur_solution}) + "\n")

     venv_python = "/sympy/bin/python3"
+    if not os.path.exists(venv_python):
+        venv_python = "sympy/bin/python3"
+    if not os.path.exists(venv_python):
+        venv_python = sys.executable
     # logger.info(f"math verify working dir: `{os.getcwd()}`")
     pro = subprocess.Popen(
         " ".join(
@@ -67,6 +72,7 @@ def parse_line(id2info, prompt_str, generated, query_id):
         shell=True,
         preexec_fn=os.setsid,
         stdout=subprocess.DEVNULL,
+        stderr=sys.stdout,
     )
     pro.wait()
     try:
@@ -131,6 +137,10 @@ def parse_lines_in_parallel(
     all_query_indices.append(query_indices)

     venv_python = "/sympy/bin/python3"
+    if not os.path.exists(venv_python):
+        venv_python = "sympy/bin/python3"
+    if not os.path.exists(venv_python):
+        venv_python = sys.executable
     # logger.info(f"math verify working dir: `{os.getcwd()}`")
     procs = []
     for tmp_id in tmp_ids:
@@ -148,6 +158,7 @@ def parse_lines_in_parallel(
         shell=True,
         preexec_fn=os.setsid,
         stdout=subprocess.DEVNULL,
+        stderr=sys.stdout,
     )
     procs.append(pro)
     for pro in procs:
@@ -183,7 +194,7 @@ def parse_lines_in_parallel(

 if __name__ == "__main__":
     sample = {
-        "answers": ["-\\frac{2}{3}"],
+        "answers": ["\\boxed{-\\frac{2}{3}}"],
         "solutions": [
             "1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
             "\\boxed{-\\frac{2}{3}}",
@@ -56,7 +56,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
         unique_params = params_tensor[1]

         if constants.parallelism_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"CONFIG: default_train_mbs={self.pipe_runner.default_train_mbs} "
                 f"default_inf_mbs={self.pipe_runner.default_inf_mbs} "
                 f"num_layers(this stage)={self.module.num_layers} "
@@ -65,7 +65,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
                 f"tp_size={constants.tensor_parallel_world_size()} "
             )
         if constants.data_parallel_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"rank={constants.parallelism_rank()} "
                 f"stage={constants.pipe_parallel_rank()} "
                 f"layers={self.module.num_layers} "
@@ -105,7 +105,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
             )
         mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec)
         if constants.parallelism_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
                 f"#tokens: {input_.data['packed_input_ids'].shape[0]}, "
                 f"pp_size={constants.pipe_parallel_world_size()}, "
@@ -163,7 +163,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
         # mini-batches, so we split mini-batches in the outer loop.
         mb_inputs, *_ = input_.split(mb_spec)
         if constants.parallelism_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
                 f"#tokens: {input_.data['packed_input_ids'].shape[0]}, "
                 f"pp_size={constants.pipe_parallel_world_size()}, "

@@ -809,7 +809,7 @@ class PipelineRunner:
             )
         mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec)
         if constants.parallelism_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
                 f"#tokens: {input_.data['packed_input_ids'].shape[0]}, "
                 f"pp_size={constants.pipe_parallel_world_size()}, "
@@ -886,7 +886,7 @@ class PipelineRunner:
         mb_spec = MicroBatchSpec(n_mbs=self.default_inf_mbs)
         mb_inputs, *_ = input_.split(mb_spec)
         if constants.parallelism_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
                 f"#tokens: {input_.data['packed_input_ids'].shape[0]}, "
                 f"pp_size={constants.pipe_parallel_world_size()}, "
@@ -1013,7 +1013,7 @@ class PipelineRunner:
         dist.all_reduce(total_loss_weight, group=constants.data_parallel_group())

         if constants.parallelism_rank() == 0:
-            logger.info(
+            logger.debug(
                 f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
                 f"#tokens: {input_.data['packed_input_ids'].shape[0]}, "
                 f"pp_size={constants.pipe_parallel_world_size()}, "

@@ -63,7 +63,7 @@ class _TensorDictSequenceBuffer:
             tuple(sorted(self.__storage[idx].sample.keys)),
         )
         if ck not in BUFFER_KEY_WARN_CACHE:
-            logger.warning(
+            logger.debug(
                 f"Unexpected keys in the sample. Expected keys from all MFCDef: {self.__keys}. "
                 f"Keys in the current sample: {self.__storage[idx].sample.keys}"
             )
@@ -300,7 +300,7 @@ class AsyncIOSequenceBuffer:
         )

         can_do_rpcs = {rpc.name: self._can_do_rpc(rpc) for rpc in self.rpcs}
-        logger.info(f"After putting batch, can do RPCs? {can_do_rpcs}.")
+        logger.debug(f"After putting batch, can do RPCs? {can_do_rpcs}.")

         self._lock.notify(len(self._rpc_names))
         return indices
@@ -348,7 +348,7 @@ class AsyncIOSequenceBuffer:
     async def get_batch_for_rpc(
         self, rpc: dfg.MFCDef
     ) -> Tuple[List[int], SequenceSample]:
-        logger.info(
+        logger.debug(
             f"MFC {rpc.name} is waiting for its input keys: {rpc.input_keys}..."
         )
         rpc_idx = self._rpc_names.index(rpc.name)
@@ -359,7 +359,7 @@ class AsyncIOSequenceBuffer:
         while not self._can_do_rpc(rpc):
             await self._lock.wait()

-        logger.info(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!")
+        logger.debug(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!")
         self._assert_valid_indicator()

         ready_indices = np.nonzero(
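The messages demoted to debug above report how a packed batch is split into micro-batches under a token budget (`MB spec`, `#mbs`, `#tokens`). A standalone sketch of greedy splitting by `max_tokens_per_mb`; this is illustrative, not realhf's actual `split` implementation:

```python
from typing import List

def split_by_token_budget(seqlens: List[int], max_tokens_per_mb: int) -> List[List[int]]:
    """Greedily pack sequence indices into micro-batches that stay under a token budget."""
    mbs, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(seqlens):
        if current and current_tokens + n_tokens > max_tokens_per_mb:
            mbs.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        mbs.append(current)
    return mbs

# Example: sequences of 9k/20k/3k/30k tokens with a 32k budget -> 2 micro-batches.
print(split_by_token_budget([9000, 20000, 3000, 30000], max_tokens_per_mb=32768))
```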
@@ -671,14 +671,8 @@ class RayController:

         env_vars = constants.get_env_vars(
             REAL_MODE=os.environ.get("REAL_MODE", ""),
-            CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
             REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
             REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
-            FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
-            REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
-            REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
-            REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
-            REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", ""),
         )
         runtime_env = {
             "env_vars": env_vars,
@@ -86,7 +86,7 @@ class FunctionExecutor:
         for level, w in enumerate(self.topo_widths):
             for _ in range(w):
                 await self.ctrl.topo_level_count.get()
-            logger.info(f"DFG level {level}. Flushing {w} function calls.")
+            logger.debug(f"DFG level {level}. Flushing {w} function calls.")
             self.stream.request(
                 handlers=list(range(self.n_model_workers)), handle_type="flush"
             )
@@ -203,7 +203,7 @@ class FunctionExecutor:
             await asyncio.sleep(1)

     async def execute_step(self):
-        logger.info("Waiting for the finish of the execution graph.")
+        logger.debug("Waiting for the finish of the execution graph.")
         loop = asyncio.get_event_loop()

         tasks = [

@@ -183,7 +183,7 @@ class GserverManager(AsyncWorker):
         success = res["success"]
         if success:
             if "num_paused_requests" in res:
-                logger.info(
+                logger.debug(
                     f"{res['num_paused_requests']} requests are interrupted "
                     f"during updateing weights for server {server_index}: {server_url}"
                 )
@@ -419,7 +419,7 @@ class GserverManager(AsyncWorker):
         if has_capacity and not is_staled:
             self.rollout_stat.submitted += 1
             self.rollout_stat.running += 1
-            logger.info(
+            logger.debug(
                 f"Allocate rollout for qid {req.qid}. "
                 f"Submitted: {self.rollout_stat.submitted}, "
                 f"running: {self.rollout_stat.running}, "
@@ -458,7 +458,7 @@ class GserverManager(AsyncWorker):
         self.rollout_stat.running -= 1
         if resp_meta.accepted:
             self.rollout_stat.accepted += 1
-        logger.info(
+        logger.debug(
             f"Finish rollout for qid {resp_meta.qid}. "
-            f"Running: {self.rollout_stat.running}, "
+            f"running: {self.rollout_stat.running}, "

@@ -62,7 +62,7 @@ class MasterWorker(worker_base.AsyncWorker):
         self.__topo_widths = []
         for generation in nx.topological_generations(self.__model_rpcs[0]._G):
             self.__topo_widths.append(len(generation))
-        logger.info("Topological widths: " + str(self.__topo_widths))
+        logger.debug("Topological widths: " + str(self.__topo_widths))

         self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs))
         self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs))
@@ -169,7 +169,7 @@ class MasterWorker(worker_base.AsyncWorker):
     def initialize_models(self):
         # Initialize model backends.
         model_names = list(self.__model_topos.keys())
-        self.logger.info(f"Initialize model backends with order: {model_names}.")
+        self.logger.debug(f"Initialize model backends with order: {model_names}.")
         train_rpcs = list(
             filter(
                 lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
@@ -318,7 +318,7 @@ class MasterWorker(worker_base.AsyncWorker):
             self.__summary_writer = SummaryWriter(log_dir=self.tensorboard_config.path)

         # Create coroutines for model RPCs.
-        logger.info(f"Creating asyncio coroutines...")
+        logger.debug(f"Creating asyncio coroutines...")
         self.func_executor = FunctionExecutor(
             rpcs=self.__model_rpcs,
             msid2mwid=self.config.msid2mwid,
@@ -334,7 +334,7 @@ class MasterWorker(worker_base.AsyncWorker):
             self.func_executor.data_loading_dp_idx = (
                 self.__recover_info.data_loading_dp_idx
             )
-        logger.info(f"Coroutines created. The master worker is ready to run.")
+        logger.debug(f"Coroutines created. The master worker is ready to run.")

         self.__initialized = True
         self._train_start_time = time.perf_counter()

@@ -292,7 +292,7 @@ class ModelFunctionCall:
         samples, forward_indices, _ = sample.split(
             data_api.MicroBatchSpec(n_mbs=self.dp_size)
         )
-        blogger.info(
+        blogger.debug(
             f"DP split (DP size {self.dp_size}) for RPC {self.rpc.name}: "
             f"#seqs: {[s.bs for s in samples]}, "
             f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}"
@@ -353,7 +353,7 @@ class ModelFunctionCall:
             keys=rpc.input_keys,
             pattern=pattern,
         )
-        blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.")
+        blogger.debug(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.")

         # Update storage tracker for transferred data.
         if pattern == "bcast":
@@ -450,7 +450,7 @@ class ModelFunctionCall:
         elif isinstance(res, list):
             for j, r in enumerate(res):
                 logger.info(
-                    f"RPC name {rpc.name} returns ({j}/{len(res)})\n{data_api.tabulate_stats(r)}"
+                    f"RPC name {rpc.name} returns ({j + 1}/{len(res)})\n{data_api.tabulate_stats(r)}"
                 )
         offset = len(res) * ctrl.step_info.global_step
         logging.log_wandb_tensorboard(

@@ -510,7 +510,7 @@ class ModelWorker(worker_base.Worker):
         elif hook == "save":
             self.__save_model(hook_data)
         elif hook == "evaluate":
-            logger.info(f"hook_data: {hook_data}")
+            logger.debug(f"hook_data: {hook_data}")
             with constants.model_scope(hook_data["model_name"]):
                 ret = self._interface.evaluate(self._model, self._eval_dataloader)
             if ret:
@@ -557,7 +557,7 @@ class ModelWorker(worker_base.Worker):
         assert not handled and res is None
         with constants.model_scope(request.handler.model_name):
             if constants.parallelism_rank() == 0:
-                logger.info(
+                logger.debug(
                     f"Model `{request.handler.model_name}` handling "
                     f"{len(request.pre_hooks)} pre-hook for request `{request.handle_name}`. "
                     f"The current hook is `{request.pre_hooks[0]}`. "
@@ -714,7 +714,7 @@ class ModelWorker(worker_base.Worker):
         blogger.debug(
             f"Model worker {self.__worker_index} cleared cache in {et-st:.4f}s. "
         )
-        logger.info(
+        logger.debug(
             "Get clear_data_cache, dump cuda tmark. "
             f"Remaining data in local storage: {self.data_manager.storage_size()}. "
         )
@@ -742,7 +742,7 @@ class ModelWorker(worker_base.Worker):
         model_name = request.handler.model_name
         with constants.model_scope(model_name):
             if constants.parallelism_rank() == 0:
-                blogger.info(
+                blogger.debug(
                     f"Model #{request.handler.model_name}# "
                     f"starts handling request *{request.handle_name}*."
                 )
@@ -789,7 +789,7 @@ class ModelWorker(worker_base.Worker):
             and self._is_dp_head
             and self._dp_rank == 0
         ):
-            blogger.info(
+            blogger.debug(
                 f"Model #{request.handler.model_name}# handle "
                 f"request *{request.handle_name}*"
                 f" in ${time.perf_counter() - tik:.4f}$s"

@@ -209,7 +209,9 @@ class RolloutWorker(AsyncWorker):
             resp.raise_for_status()
             res = await resp.json()
             if not res["success"]:
-                logger.info(f"Cannot allocate new rollout because: {res['reason']}")
+                logger.debug(
+                    f"Cannot allocate new rollout because: {res['reason']}"
+                )
             return res["success"]

     async def _poll_async(self):
@@ -272,7 +274,7 @@ class RolloutWorker(AsyncWorker):

         self.rollout_stat.submitted += 1
         self.rollout_stat.running += 1
-        logger.info(
+        logger.debug(
             f"Submit a new rollout for qid {qid}. "
             f"Submit: {self.rollout_stat.submitted}, "
             f"running: {self.rollout_stat.running}, "
@@ -323,7 +325,7 @@ class RolloutWorker(AsyncWorker):
         ) as resp:
             resp.raise_for_status()
             assert (await resp.json())["success"]
-            logger.info(
+            logger.debug(
                 f"Finish rollout for qid {qid}. "
                 f"Submit: {self.rollout_stat.submitted}, "
                 f"running: {self.rollout_stat.running}, "
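The changes above demote high-frequency progress messages from info to debug so that default runs stay quiet. Assuming the realhf loggers are built on the standard `logging` module (an assumption on our side), the old verbosity can be restored per logger:

```python
import logging

# Assumption: realhf logger names follow the modules that create them.
logging.getLogger("realhf").setLevel(logging.DEBUG)

# Or globally, for quick debugging sessions:
logging.basicConfig(level=logging.DEBUG)
```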
@ -0,0 +1,198 @@
|
|||
# Basic experiment info
|
||||
experiment_name: async-ppo
|
||||
trial_name: my-trial
|
||||
seed: 1
|
||||
mode: ray
|
||||
metric_discovery_port: 17997
|
||||
wandb:
|
||||
mode: disabled
|
||||
entity: null
|
||||
project: null
|
||||
name: null
|
||||
job_type: null
|
||||
group: null
|
||||
notes: null
|
||||
tags: null
|
||||
config: null
|
||||
tensorboard:
|
||||
path: null
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
recover_after: 10
|
||||
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
save_freq_steps: null
|
||||
save_freq_secs: null
|
||||
ckpt_freq_epochs: null
|
||||
ckpt_freq_steps: null
|
||||
ckpt_freq_secs: 600
|
||||
eval_freq_epochs: null
|
||||
eval_freq_steps: null
|
||||
eval_freq_secs: null
|
||||
benchmark_steps: null
|
||||
benchmark_n_seqs: null
|
||||
torch_cache_mysophobia: true
|
||||
cache_clear_freq: 1
|
||||
|
||||
# Asynchronous RL options
|
||||
new_tokens_per_chunk: 10000000000
|
||||
max_head_offpolicyness: 4
|
||||
n_rollout_workers: null
|
||||
max_concurrent_rollouts: null
|
||||
flush_request_timeout: 300
|
||||
|
||||
# Asynchronous worker resources
|
||||
cpus_per_generation_server: 4
|
||||
mem_per_generation_server: 61440
|
||||
cpus_per_gserver_manager: 4
|
||||
mem_per_gserver_manager: 10240
|
||||
cpus_per_rollout_worker: 4
|
||||
mem_per_rollout_worker: 20480
|
||||
|
||||
# Allocation and parallelism
|
||||
allocation_mode: sglang.d4p1m1+d2p2m1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Cluster configuration
|
||||
ray_temp_path: /tmp/ray
|
||||
cluster:
|
||||
fileroot: /storage/ray/experiments
|
||||
n_nodes: 32
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Model
|
||||
actor:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: true
|
||||
bf16: false
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2.0e-05
|
||||
weight_decay: 0.05
|
||||
beta1: 0.9
|
||||
beta2: 0.95
|
||||
eps: 1.0e-05
|
||||
min_lr_ratio: 0.0
|
||||
lr_scheduler_type: constant
|
||||
warmup_steps_proportion: 0.001
|
||||
initial_loss_scale: 4294967296.0
|
||||
min_loss_scale: 1.0
|
||||
loss_scale_window: 5.0
|
||||
hysteresis: 2
|
||||
gradient_clipping: 1.0
|
||||
megatron:
|
||||
ddp:
|
||||
grad_reduce_in_fp32: true
|
||||
overlap_grad_reduce: true
|
||||
use_distributed_optimizer: true
|
||||
sglang:
|
||||
disable_cuda_graph: false
|
||||
disable_radix_cache: false
|
||||
disable_cuda_graph_padding: false
|
||||
enable_nccl_nvls: false
|
||||
disable_outlines_disk_cache: false
|
||||
disable_custom_all_reduce: false
|
||||
disable_overlap_schedule: false
|
||||
enable_mixed_chunk: false
|
||||
enable_torch_compile: false
|
||||
torch_compile_max_bs: 32
|
||||
cuda_graph_max_bs: null
|
||||
cuda_graph_bs: null
|
||||
torchao_config: ''
|
||||
enable_nan_detection: false
|
||||
enable_p2p_check: false
|
||||
triton_attention_reduce_in_fp32: false
|
||||
triton_attention_num_kv_splits: 8
|
||||
num_continuous_decode_steps: 1
|
||||
enable_memory_saver: false
|
||||
allow_auto_truncate: false
|
||||
attention_backend: flashinfer
|
||||
sampling_backend: null
|
||||
context_length: 32768
|
||||
mem_fraction_static: 0.8
|
||||
max_running_requests: null
|
||||
chunked_prefill_size: -1
|
||||
max_prefill_tokens: 32768
|
||||
schedule_policy: lpm
|
||||
schedule_conservativeness: 1.0
|
||||
cpu_offload_gb: 0
|
||||
dtype: float16
|
||||
kv_cache_dtype: auto
|
||||
log_level: warning
|
||||
log_level_http: warning
|
||||
log_requests: false
|
||||
log_requests_level: 0
|
||||
show_time_cost: false
|
||||
enable_metrics: true
|
||||
decode_log_interval: 1
|
||||
ref:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
bf16: false
|
||||
actor_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
ref_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
|
||||
# Dataset
|
||||
shuffle_dataset: true
|
||||
dataset:
|
||||
path: /storage/datasets/boba_106k_0319.jsonl
|
||||
max_prompt_len: 1024
|
||||
train_bs_n_seqs: 512
|
||||
|
||||
# Algorithm
|
||||
group_size: 16
|
||||
mask_too_long: false
|
||||
group_adv_norm: false
|
||||
rw_type: sparse
|
||||
success_rate_ub: 1.0
|
||||
success_rate_lb: 0.0
|
||||
ppo:
|
||||
gen:
|
||||
n: 1
|
||||
max_new_tokens: 27648
|
||||
min_new_tokens: 0
|
||||
greedy: false
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
ppo_n_minibatches: 4
|
||||
eps_clip: 0.2
|
||||
c_clip: null
|
||||
value_eps_clip: 0.2
|
||||
early_stop_imp_ratio: 5.0
|
||||
actor_sample_reuse: 1
|
||||
critic_sample_reuse: 1
|
||||
max_reward_clip: 20.0
|
||||
reward_output_scaling: 5.0
|
||||
reward_output_bias: 0.0
|
||||
fuse_rew_ref: true
|
||||
discount: 1.0
|
||||
gae_lambda: 1.0
|
||||
adv_norm: true
|
||||
kl_ctl: 0.0
|
||||
use_adaptive_kl_ctl: false
|
||||
disable_value: true
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
behav_imp_weight_cap: null
|
||||
|
||||
# worker resources
|
||||
cpus_per_master_worker: 4
|
||||
mem_per_master_worker: 20000
|
||||
cpus_per_model_worker: 4
|
||||
mem_per_model_worker: 90000
|
|
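The consolidated async PPO example config added above is plain YAML, so it can be inspected or tweaked programmatically before launching. A hedged sketch with OmegaConf (the file path mirrors the repo's `training/configs/` layout and is an assumption; adjust to where the config actually lives):

```python
from omegaconf import OmegaConf

# Load the checked-in example config and apply a couple of overrides.
cfg = OmegaConf.load("training/configs/async-ppo.yaml")
overrides = OmegaConf.from_dotlist([
    "dataset.train_bs_n_seqs=32",
    "cluster.fileroot=/storage/testing/experiments",
])
cfg = OmegaConf.merge(cfg, overrides)

print(cfg.actor.path)               # /storage/openpsi/models/Qwen3-1.7B/
print(cfg.dataset.train_bs_n_seqs)  # 32 after the override
```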
@ -1,65 +0,0 @@
|
|||
max_head_offpolicyness: 4
|
||||
experiment_name: async-ppo-1.7b-gpu32
|
||||
trial_name: my-trial
|
||||
mode: ray
|
||||
cluster:
|
||||
fileroot: /storage/ray/experiments
|
||||
wandb:
|
||||
mode: disabled
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
allocation_mode: sglang.d24p1m1+d4p2m1
|
||||
n_nodes: 4
|
||||
n_gpus_per_node: 8
|
||||
cache_clear_freq: 1
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
ckpt_freq_secs: 600
|
||||
torch_cache_mysophobia: true
|
||||
dataset:
|
||||
path: /storage/datasets/boba_106k_0319.jsonl
|
||||
max_prompt_len: 1024
|
||||
train_bs_n_seqs: 512
|
||||
group_size: 16
|
||||
group_adv_norm: false
|
||||
actor:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
optimizer:
|
||||
lr: 2e-05
|
||||
lr_scheduler_type: constant
|
||||
eps: 1e-5
|
||||
warmup_steps_proportion: 0.001
|
||||
hysteresis: 2
|
||||
sglang:
|
||||
mem_fraction_static: 0.8
|
||||
actor_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_gen:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
ppo:
|
||||
gen:
|
||||
max_new_tokens: 27648
|
||||
min_new_tokens: 0
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
ppo_n_minibatches: 4
|
||||
kl_ctl: 0.0
|
||||
discount: 1.0
|
||||
value_eps_clip: 0.2
|
||||
disable_value: true
|
||||
reward_output_scaling: 5
|
||||
reward_output_bias: 0.0
|
||||
adv_norm: true
|
||||
value_norm: true
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
max_head_offpolicyness: 4
|
||||
experiment_name: async-ppo-1.7b-gpu8
|
||||
trial_name: my-trial
|
||||
mode: ray
|
||||
cluster:
|
||||
fileroot: /storage/ray/experiments
|
||||
wandb:
|
||||
mode: disabled
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
allocation_mode: sglang.d4p1m1+d2p2m1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
cache_clear_freq: 1
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
ckpt_freq_secs: 600
|
||||
torch_cache_mysophobia: true
|
||||
dataset:
|
||||
path: /storage/datasets/boba_106k_0319.jsonl
|
||||
max_prompt_len: 1024
|
||||
train_bs_n_seqs: 512
|
||||
group_size: 16
|
||||
group_adv_norm: false
|
||||
actor:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
optimizer:
|
||||
lr: 2e-05
|
||||
lr_scheduler_type: constant
|
||||
eps: 1e-5
|
||||
warmup_steps_proportion: 0.001
|
||||
hysteresis: 2
|
||||
sglang:
|
||||
mem_fraction_static: 0.8
|
||||
actor_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_gen:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
ppo:
|
||||
gen:
|
||||
max_new_tokens: 27648
|
||||
min_new_tokens: 0
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
ppo_n_minibatches: 4
|
||||
kl_ctl: 0.0
|
||||
discount: 1.0
|
||||
value_eps_clip: 0.2
|
||||
disable_value: true
|
||||
reward_output_scaling: 5
|
||||
reward_output_bias: 0.0
|
||||
adv_norm: true
|
||||
value_norm: true
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
experiment_name: ppo-1.7b-gpu32
|
||||
trial_name: my-trial
|
||||
mode: ray
|
||||
cluster:
|
||||
fileroot: /storage/ray/experiments
|
||||
wandb:
|
||||
mode: disabled
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
allocation_mode: sglang.d16p1m1+d8p2m1
|
||||
n_nodes: 4
|
||||
n_gpus_per_node: 8
|
||||
cache_clear_freq: 1
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
ckpt_freq_secs: 600
|
||||
torch_cache_mysophobia: true
|
||||
dataset:
|
||||
path: /storage/datasets/boba_106k_0319.jsonl
|
||||
max_prompt_len: 1024
|
||||
train_bs_n_seqs: 512
|
||||
group_size: 16
|
||||
group_adv_norm: false
|
||||
actor:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
optimizer:
|
||||
lr: 2e-05
|
||||
lr_scheduler_type: constant
|
||||
eps: 1e-5
|
||||
warmup_steps_proportion: 0.001
|
||||
hysteresis: 2
|
||||
sglang:
|
||||
mem_fraction_static: 0.8
|
||||
actor_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_gen:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
ppo:
|
||||
gen:
|
||||
max_new_tokens: 27648
|
||||
min_new_tokens: 0
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
ppo_n_minibatches: 4
|
||||
kl_ctl: 0.0
|
||||
discount: 1.0
|
||||
value_eps_clip: 0.2
|
||||
disable_value: true
|
||||
reward_output_scaling: 5
|
||||
reward_output_bias: 0.0
|
||||
adv_norm: true
|
||||
value_norm: true
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
# Basic experiment info
|
||||
experiment_name: sft
|
||||
trial_name: my-trial
|
||||
seed: 1
|
||||
mode: ray
|
||||
metric_discovery_port: 17997
|
||||
wandb:
|
||||
mode: disabled
|
||||
entity: null
|
||||
project: null
|
||||
name: null
|
||||
job_type: null
|
||||
group: null
|
||||
notes: null
|
||||
tags: null
|
||||
config: null
|
||||
tensorboard:
|
||||
path: null
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
recover_after: 10
|
||||
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
save_freq_steps: null
|
||||
save_freq_secs: null
|
||||
ckpt_freq_epochs: null
|
||||
ckpt_freq_steps: null
|
||||
ckpt_freq_secs: 600
|
||||
eval_freq_epochs: null
|
||||
eval_freq_steps: null
|
||||
eval_freq_secs: null
|
||||
benchmark_steps: null
|
||||
benchmark_n_seqs: null
|
||||
torch_cache_mysophobia: true
|
||||
cache_clear_freq: 1
|
||||
|
||||
# Allocation and parallelism
|
||||
allocation_mode: d4p2m1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Cluster configuration
|
||||
ray_temp_path: /tmp/ray
|
||||
cluster:
|
||||
fileroot: /storage/ray/experiments
|
||||
n_nodes: 32
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Model
|
||||
model:
|
||||
type:
|
||||
_class: qwen2
|
||||
path: /storage/models/DeepSeek-R1-Distill-Qwen-7B
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: true
|
||||
bf16: false
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 1.0e-05
|
||||
weight_decay: 0.1
|
||||
beta1: 0.9
|
||||
beta2: 0.95
|
||||
eps: 1.0e-05
|
||||
min_lr_ratio: 0.0
|
||||
lr_scheduler_type: constant
|
||||
warmup_steps_proportion: 0.03
|
||||
offload: false
|
||||
initial_loss_scale: 262144.0
|
||||
min_loss_scale: 1.0
|
||||
loss_scale_window: 10.0
|
||||
hysteresis: 2
|
||||
gradient_clipping: 1.0
|
||||
megatron:
|
||||
ddp:
|
||||
grad_reduce_in_fp32: true
|
||||
overlap_grad_reduce: true
|
||||
use_distributed_optimizer: true
|
||||
allocation:
|
||||
mb_spec:
|
||||
n_mbs: 1
|
||||
max_tokens_per_mb: 32768
|
||||
|
||||
# Dataset
|
||||
dataset:
|
||||
train_path: /storage/datasets/boba-sft_200_0319.jsonl
|
||||
valid_path: /storage/datasets/boba-sft_200_0319.jsonl
|
||||
max_seqlen: 32768
|
||||
train_bs_n_seqs: 16
|
||||
valid_bs_n_seqs: 16
|
||||
|
||||
# worker resources
|
||||
cpus_per_master_worker: 4
|
||||
mem_per_master_worker: 20000
|
||||
cpus_per_model_worker: 4
|
||||
mem_per_model_worker: 90000
|
|
@ -1,40 +0,0 @@
|
|||
experiment_name: sft-7b-gpu8
|
||||
trial_name: my-trial
|
||||
mode: ray
|
||||
wandb:
|
||||
mode: disabled
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
allocation_mode: d2p4m1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
exp_ctrl:
|
||||
total_train_epochs: 200
|
||||
save_freq_epochs: 1
|
||||
ckpt_freq_secs: 600
|
||||
torch_cache_mysophobia: true
|
||||
dataset:
|
||||
train_path: /storage/datasets/boba-sft_200_0319.jsonl
|
||||
valid_path: /storage/datasets/boba-sft_200_0319.jsonl
|
||||
max_seqlen: 32768
|
||||
train_bs_n_seqs: 16
|
||||
valid_bs_n_seqs: 16
|
||||
model:
|
||||
type:
|
||||
_class: qwen2
|
||||
path: /storage/models/DeepSeek-R1-Distill-Qwen-7B
|
||||
optimizer:
|
||||
type: adam
|
||||
lr_scheduler_type: constant
|
||||
lr: 1e-5
|
||||
warmup_steps_proportion: 0.03
|
||||
initial_loss_scale: 262144.0
|
||||
loss_scale_window: 10
|
||||
hysteresis: 2
|
||||
weight_decay: 0.1
|
||||
eps: 1e-5
|
||||
bf16: true
|
||||
allocation:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 32768
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
# Basic experiment info
|
||||
experiment_name: sync-ppo
|
||||
trial_name: my-trial
|
||||
seed: 1
|
||||
mode: ray
|
||||
metric_discovery_port: 17997
|
||||
wandb:
|
||||
mode: disabled
|
||||
entity: null
|
||||
project: null
|
||||
name: null
|
||||
job_type: null
|
||||
group: null
|
||||
notes: null
|
||||
tags: null
|
||||
config: null
|
||||
tensorboard:
|
||||
path: null
|
||||
recover_mode: auto
|
||||
recover_retries: 10
|
||||
recover_after: 10
|
||||
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
save_freq_steps: null
|
||||
save_freq_secs: null
|
||||
ckpt_freq_epochs: null
|
||||
ckpt_freq_steps: null
|
||||
ckpt_freq_secs: 600
|
||||
eval_freq_epochs: null
|
||||
eval_freq_steps: null
|
||||
eval_freq_secs: null
|
||||
benchmark_steps: null
|
||||
benchmark_n_seqs: null
|
||||
torch_cache_mysophobia: true
|
||||
cache_clear_freq: 1
|
||||
|
||||
# Allocation and parallelism
|
||||
allocation_mode: sglang.d4p1m1+d2p2m1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Cluster configuration
|
||||
ray_temp_path: /tmp/ray
|
||||
cluster:
|
||||
fileroot: /storage/ray/experiments
|
||||
n_nodes: 32
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Model
|
||||
actor:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: true
|
||||
bf16: false
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2.0e-05
|
||||
weight_decay: 0.05
|
||||
beta1: 0.9
|
||||
beta2: 0.95
|
||||
eps: 1.0e-05
|
||||
min_lr_ratio: 0.0
|
||||
lr_scheduler_type: constant
|
||||
warmup_steps_proportion: 0.001
|
||||
initial_loss_scale: 4294967296.0
|
||||
min_loss_scale: 1.0
|
||||
loss_scale_window: 5.0
|
||||
hysteresis: 2
|
||||
gradient_clipping: 1.0
|
||||
megatron:
|
||||
ddp:
|
||||
grad_reduce_in_fp32: true
|
||||
overlap_grad_reduce: true
|
||||
use_distributed_optimizer: true
|
||||
sglang:
|
||||
disable_cuda_graph: false
|
||||
disable_radix_cache: false
|
||||
disable_cuda_graph_padding: false
|
||||
enable_nccl_nvls: false
|
||||
disable_outlines_disk_cache: false
|
||||
disable_custom_all_reduce: false
|
||||
disable_overlap_schedule: false
|
||||
enable_mixed_chunk: false
|
||||
enable_torch_compile: false
|
||||
torch_compile_max_bs: 32
|
||||
cuda_graph_max_bs: null
|
||||
cuda_graph_bs: null
|
||||
torchao_config: ''
|
||||
enable_nan_detection: false
|
||||
enable_p2p_check: false
|
||||
triton_attention_reduce_in_fp32: false
|
||||
triton_attention_num_kv_splits: 8
|
||||
num_continuous_decode_steps: 1
|
||||
enable_memory_saver: false
|
||||
allow_auto_truncate: false
|
||||
attention_backend: flashinfer
|
||||
sampling_backend: null
|
||||
context_length: 32768
|
||||
mem_fraction_static: 0.8
|
||||
max_running_requests: null
|
||||
chunked_prefill_size: -1
|
||||
max_prefill_tokens: 32768
|
||||
schedule_policy: lpm
|
||||
schedule_conservativeness: 1.0
|
||||
cpu_offload_gb: 0
|
||||
dtype: float16
|
||||
kv_cache_dtype: auto
|
||||
log_level: warning
|
||||
log_level_http: warning
|
||||
log_requests: false
|
||||
log_requests_level: 0
|
||||
show_time_cost: false
|
||||
enable_metrics: true
|
||||
decode_log_interval: 1
|
||||
ref:
|
||||
type:
|
||||
_class: qwen3
|
||||
path: /storage/openpsi/models/Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
bf16: false
|
||||
actor_train:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
ref_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
actor_inf:
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 30720
|
||||
|
||||
# Dataset
|
||||
shuffle_dataset: true
|
||||
dataset:
|
||||
path: /storage/datasets/boba_106k_0319.jsonl
|
||||
max_prompt_len: 1024
|
||||
train_bs_n_seqs: 512
|
||||
|
||||
# Algorithm
|
||||
group_size: 16
|
||||
mask_too_long: false
|
||||
group_adv_norm: false
|
||||
rw_type: sparse
|
||||
success_rate_ub: 1.0
|
||||
success_rate_lb: 0.0
|
||||
ppo:
|
||||
gen:
|
||||
n: 1
|
||||
max_new_tokens: 27648
|
||||
min_new_tokens: 0
|
||||
greedy: false
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
ppo_n_minibatches: 4
|
||||
eps_clip: 0.2
|
||||
c_clip: null
|
||||
value_eps_clip: 0.2
|
||||
early_stop_imp_ratio: 5.0
|
||||
actor_sample_reuse: 1
|
||||
critic_sample_reuse: 1
|
||||
max_reward_clip: 20.0
|
||||
reward_output_scaling: 5.0
|
||||
reward_output_bias: 0.0
|
||||
fuse_rew_ref: true
|
||||
discount: 1.0
|
||||
gae_lambda: 1.0
|
||||
adv_norm: true
|
||||
kl_ctl: 0.0
|
||||
use_adaptive_kl_ctl: false
|
||||
disable_value: true
|
||||
recompute_logprob: false
|
||||
use_decoupled_loss: false
|
||||
behav_imp_weight_cap: null
|
||||
|
||||
# worker resources
|
||||
cpus_per_master_worker: 4
|
||||
mem_per_master_worker: 20000
|
||||
cpus_per_model_worker: 4
|
||||
mem_per_model_worker: 90000
|
|
@@ -61,5 +61,15 @@ def main_ppo_math(args):


 if __name__ == "__main__":
     # Command: python3 training/main_async_ppo.py --config-name async-ppo-1.7b-gpu8
     import argparse

     parser = argparse.ArgumentParser(add_help=False)
     parser.add_argument("--help", action="store_true")
     args = parser.parse_args()
     if args.help:
         from realhf.api.cli_args import print_config_help

         print_config_help(AsyncPPOMATHConfig())
         exit(0)

     main_ppo_math()
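The `--help` branch above prints the full structured experiment config instead of Hydra's default usage text. As an illustration of what such a helper does (a generic re-implementation for demonstration, not realhf's `print_config_help`):

```python
from dataclasses import dataclass, fields

def print_dataclass_help(cfg) -> None:
    """Walk a dataclass instance and print field names, declared types, and current values."""
    for f in fields(cfg):
        value = getattr(cfg, f.name)
        print(f"{f.name} ({f.type}): {value!r}")

@dataclass
class DemoConfig:  # hypothetical stand-in for AsyncPPOMATHConfig
    n_nodes: int = 1
    n_gpus_per_node: int = 8
    allocation_mode: str = "sglang.d4p1m1+d2p2m1"

print_dataclass_help(DemoConfig())
```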
@@ -13,7 +13,7 @@ from realhf.experiments.common.sft_exp import SFTConfig
 from training.utils import run_experiment


-@hydra.main(version_base=None, config_path="configs/sft")
+@hydra.main(version_base=None, config_path="configs", config_name="sft")
 def main(args):
     # NOTE: we import logging here to avoid hydra logging overwrite
     import realhf.base.logging as logging
@@ -61,5 +61,14 @@ def main(args):


 if __name__ == "__main__":
     # Command: python3 training/main_sft.py --config-name sft-7b-gpu8
     import argparse

     parser = argparse.ArgumentParser(add_help=False)
     parser.add_argument("--help", action="store_true")
     args = parser.parse_args()
     if args.help:
         from realhf.api.cli_args import print_config_help

         print_config_help(SFTConfig())
         exit(0)
     main()

@@ -13,7 +13,7 @@ from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
 from training.utils import run_experiment


-@hydra.main(version_base=None, config_path="configs/ppo")
+@hydra.main(version_base=None, config_path="configs", config_name="sync-ppo")
 def main(args):
     # NOTE: we import logging here to avoid hydra logging overwrite
     import realhf.base.logging as logging
@@ -61,5 +61,14 @@ def main(args):


 if __name__ == "__main__":
     # Command: python3 training/main_ppo.py --config-name ppo-1.5b-gpu32
     import argparse

     parser = argparse.ArgumentParser(add_help=False)
     parser.add_argument("--help", action="store_true")
     args = parser.parse_args()
     if args.help:
         from realhf.api.cli_args import print_config_help

         print_config_help(PPOMATHConfig())
         exit(0)
     main()

@@ -124,15 +124,9 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
     # Initialize ray in the Ray cluster
     env_vars = constants.get_env_vars(
         WADNB_MODE=exp_cfg.wandb.mode,
-        REAL_MODE=os.environ.get("REAL_MODE", ""),
-        CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
-        REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
-        REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
-        FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
-        REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
-        REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
-        REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
-        REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
+        REAL_MODE="ray",
+        REAL_RECOVER_RUN="0",
+        REAL_SAVE_RECOVER_STATES="1",
     )
     runtime_env = {
         "env_vars": env_vars,
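Because the entry points above now point `@hydra.main` at a single `configs/` directory with a fixed `config_name`, the same configs can also be composed programmatically, for example in tests. A hedged sketch with Hydra's compose API (the relative config path assumes the calling script sits next to `training/configs/`):

```python
from hydra import compose, initialize

# initialize() takes a path relative to this file, mirroring @hydra.main(config_path="configs").
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="sync-ppo", overrides=["n_nodes=1", "group_size=8"])

print(cfg.allocation_mode)  # sglang.d4p1m1+d2p2m1
print(cfg.group_size)       # 8 after the override
```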