From 4fab3ac7694ad393139bb3a67edd52225377da74 Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Sun, 1 Jun 2025 14:57:21 +0800 Subject: [PATCH] [Doc & Fix] Simplify the environment setup procedure (#62) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * PullRequest: 176 [FIX] clear sensitive info Merge branch fw/fix-sensitive-info of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/176 Reviewed-by: 晓雷 * . * . * . * . * . * test env setup * fix * allow cached model * . * revise docs * change docs * format docs * update readme --- README.md | 51 +- docs/_toc.yml | 2 +- docs/tutorial/eval.md | 100 ++-- docs/tutorial/installation.md | 59 ++- docs/tutorial/quickstart.md | 107 ++++ docs/tutorial/troubleshooting.md | 5 +- evaluation/code_verifier/local_verify.py | 2 +- evaluation/requirements.txt | 8 +- examples/env/scripts/setup-eval-pip-deps.sh | 9 + examples/env/scripts/setup-pip-deps.sh | 28 + examples/run_async_ppo.sh | 18 + examples/run_sft.sh | 13 + examples/run_sync_ppo.sh | 15 + functioncall/code/local_verify.py | 4 +- functioncall/code/verify.py | 10 +- functioncall/test/test_success_dataset.jsonl | 6 - patch/sglang/v0.4.6.post4.patch | 498 ------------------ pyproject.toml | 100 ++++ realhf/apps/main.py | 3 - realhf/base/constants.py | 9 + realhf/experiments/common/check.py | 90 +++- realhf/impl/dataset/math_parser.py | 13 +- realhf/impl/model/backend/inference.py | 8 +- realhf/impl/model/backend/pipe_runner.py | 6 +- realhf/system/buffer.py | 8 +- realhf/system/controller.py | 6 - realhf/system/function_executor.py | 4 +- realhf/system/gserver_manager.py | 6 +- realhf/system/master_worker.py | 8 +- realhf/system/model_function_call.py | 6 +- realhf/system/model_worker.py | 10 +- realhf/system/rollout_worker.py | 8 +- training/configs/async-ppo.yaml | 198 +++++++ .../async-ppo/async-ppo-1.7b-gpu32.yaml | 65 --- .../async-ppo/async-ppo-1.7b-gpu8.yaml | 65 --- training/configs/ppo/ppo-1.5b-gpu32.yaml | 62 --- training/configs/sft.yaml | 97 ++++ training/configs/sft/sft-7b-gpu8.yaml | 40 -- training/configs/sync-ppo.yaml | 183 +++++++ training/main_async_ppo.py | 12 +- training/main_sft.py | 13 +- training/{main_ppo.py => main_sync_ppo.py} | 13 +- training/utils.py | 12 +- 43 files changed, 1082 insertions(+), 898 deletions(-) create mode 100644 docs/tutorial/quickstart.md create mode 100644 examples/env/scripts/setup-eval-pip-deps.sh create mode 100644 examples/env/scripts/setup-pip-deps.sh create mode 100644 examples/run_async_ppo.sh create mode 100644 examples/run_sft.sh create mode 100644 examples/run_sync_ppo.sh delete mode 100644 functioncall/test/test_success_dataset.jsonl create mode 100644 training/configs/async-ppo.yaml delete mode 100644 training/configs/async-ppo/async-ppo-1.7b-gpu32.yaml delete mode 100644 training/configs/async-ppo/async-ppo-1.7b-gpu8.yaml delete mode 100644 training/configs/ppo/ppo-1.5b-gpu32.yaml create mode 100644 training/configs/sft.yaml delete mode 100644 training/configs/sft/sft-7b-gpu8.yaml create mode 100644 training/configs/sync-ppo.yaml rename training/{main_ppo.py => main_sync_ppo.py} (82%) diff --git a/README.md b/README.md index 10b1315..eed448c 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

-| Documentation | +| Documentation | Ask DeepWiki |

ReaL @@ -17,7 +17,6 @@ AReaL (Ant Reasoning RL) is a fully open-sourced, scalable, and efficient reinfo + 🔪 **Cutting-Edge Performances:** AReaL can produce models with cutting-edge reasoning capabilities. We are actively working on other domains, such as coding and agent, as well. ## News -**[2025/04/27]** 🔥 We've built a [documentation website](https://deepwiki.com/inclusionAI/AReaL) using the amazing [DeepWiki](https://deepwiki.com/) tool. Check the link to know and ask about AReaL! **[2025/03/31]** **(v0.2, Boba)** Our milestone release Boba! Please call it A-ReaL-Boba! This release includes much accelerated training with SGLang support and SOTA 7B and 32B models on math reasoning. @@ -72,23 +71,51 @@ Building upon **R1-Distill-Qwen-32B**, we replicate **QwQ-32B's** inference perf ## Getting Started ### Quick Start -```bash -# Train the distilled 7B model -python3 -m realhf.apps.quickstart ppo-math \ - --config examples/configs/7B-distill/ppo-7B-distill-gpus-128.yaml -# Evaluate the 7B model -python evaluation/eval_and_aggregate.py \ +Train Qwen3 1.7B locally: + +```bash +bash examples/env/scripts/setup-pip-deps.sh +python3 training/main_async_ppo.py \ + n_nodes=1 n_gpus_per_node=8 \ + allocation_mode=sglang.d4p1m1+d2p2m1 \ + cluster.fileroot=/storage/testing/experiments \ + actor.type._class=qwen3 \ + actor.path=Qwen/Qwen3-1.7B \ + ref.type._class=qwen3 \ + ref.path=Qwen/Qwen3-1.7B \ + dataset.path=/path/to/dataset/boba_106k_0319.jsonl \ + dataset.train_bs_n_seqs=32 \ + group_size=8 \ + ppo.gen.max_new_tokens=4096 \ + ppo.ppo_n_minibatches=4 \ + actor_train.mb_spec.max_tokens_per_mb=32768 \ + actor_inf.mb_spec.max_tokens_per_mb=32768 \ + max_concurrent_rollouts=16 \ + max_head_offpolicyness=4 +``` + +Evaluation + +```bash +bash examples/env/scripts/setup-eval-pip-deps.sh +cd evaluation +# Evaluate the model +python eval_and_aggregate.py \ --model_path ${MODEL_PATH} \ --output_path ${OUTPUT_PATH} \ --data_names aime24,aime25 \ - --prompt_type AReaL-boba \ - --output_path outputs --temperature 1.0 + --max_gen_tokens 32768 \ + --data_names codeforces,lcb_v5 \ + --prompt_type qwen3-think-pure \ + --temperature 1.0 ``` ### Resources -+ [Tutorial](/examples/README.md) -+ [Tutorial (中文)](/examples/README_zh.md) + ++ [Installation](https://inclusionai.github.io/AReaL/tutorial/installation.html) ++ [Quickstart](https://inclusionai.github.io/AReaL/tutorial/quickstart.html) ++ [Contributing](https://inclusionai.github.io/AReaL/contrib.html) ## Future Plan AReaL is under active development. We will have major releases in a weekly manner. We also highly appreciate efforts from the community as well. Here we highlight our future research and development plan. diff --git a/docs/_toc.yml b/docs/_toc.yml index e487d4c..71af37b 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -7,7 +7,7 @@ parts: - caption: Tutorial chapters: - file: tutorial/installation - - file: tutorial/training + - file: tutorial/quickstart - file: tutorial/eval - file: tutorial/agent - file: tutorial/troubleshooting diff --git a/docs/tutorial/eval.md b/docs/tutorial/eval.md index 5a34879..96c4aee 100644 --- a/docs/tutorial/eval.md +++ b/docs/tutorial/eval.md @@ -1,68 +1,60 @@ # Evaluation -The evaluation code is located in the `evaluation` folder of the repository. Following the previous tutorial, trained checkpoints will be saved under `/storage/ray/experiments/checkpoints/root/`. +The evaluation code is located in the `evaluation` folder of the repository. 
Following the previous tutorial, trained checkpoints will be saved under `${fileroot}/checkpoints/${USER}/${experiment_name}/${trial_name}/`. ## Setup Evaluation Environment -Start a new container to execute the evaluation script. **Note**: Evaluation requires updates to certain Python libraries, so avoid using the training container for this task. +> **Note**: Evaluation requires updates to certain Python libraries, so avoid using the training container or virtual environment for this task. + +From the repository directory, create a new conda environment: ```bash -docker run -d --name areal-eval --privileged --gpus all --network host --shm-size 700g -v /storage:/storage ghcr.io/inclusionai/areal-runtime:v0.3.0 /bin/bash -c "tail -f /dev/null" -docker exec -it areal-eval bash +conda create -n areal-eval python=3.12 +conda activate areal-eval ``` -## Install Dependencies - -Execute the following commands inside the Docker container: +Install dependencies: ```bash -cd /storage/codes/AReaL/evaluation -cd latex2sympy -pip install -e . -cd .. -pip install -r requirements.txt -pip install vllm==0.8.5 --no-build-isolation -pip install transformers==4.51.1 -pip install prettytable timeout_decorator +bash examples/env/scripts/setup-eval-pip-deps.sh ``` ## Run Evaluation -Specify an output_path to save the test results (optional — if not specified, the results will be saved in `model_path`): - -```bash -mkdir -p /storage/ray/eval_output/ -``` - -### Math Eval +Specify an `output_path` to save the test results. If not specified, the results will be saved in `model_path`. + +### Math Evaluation ```bash +cd evaluation nohup python eval_and_aggregate.py \ - --model_path /storage/ray/experiments/checkpoints/root/my-exp/my-trial/epoch1epochstep20globalstep20/ \ - --output_path /storage/ray/eval_output/ \ - --max_gen_tokens 32768 + --model_path /path/to/checkpoint \ + --output_path /path/to/outputs \ + --max_gen_tokens 32768 \ --data_names math_500,aime24,amc23 \ --prompt_type qwen3-think \ - --task math &> /storage/ray/eval_output/eval_and_aggregate_parallel.log & + --task math &> eval_and_aggregate_parallel.log & ``` -### Code Eval -**Obtaining Data:** -- Consider the size of code datasets (Because some of test cases are relatively large), we upload all our code datasets to Huggingface: [todo:upload the code dataset to Huggingface](). -- Once you have downloaded the code dataset, place it under **`./evaluation/data/`**. +### Code Evaluation -**Running Eval:** +**Obtaining Data:** +- Due to the size of code datasets (some test cases are relatively large), we have uploaded all our code datasets to [Hugging Face](https://huggingface.co/inclusionAI). +- Once you have downloaded the code dataset, place it under `./evaluation/data/`. 
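+
+For example, the dataset can be fetched with `huggingface-cli` (a minimal sketch; the repository id below is a placeholder, substitute the actual code dataset repository published under the [inclusionAI](https://huggingface.co/inclusionAI) organization):
+
+```bash
+# Sketch only: replace <code-eval-dataset> with the actual dataset repo id.
+huggingface-cli download --repo-type=dataset inclusionAI/<code-eval-dataset> \
+    --local-dir ./evaluation/data/
+```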
+ +**Running Evaluation:** ```bash +cd evaluation nohup python eval_and_aggregate.py \ - --model_path /storage/ray/experiments/checkpoints/root/my-exp/my-trial/epoch1epochstep20globalstep20/ \ - --output_path /storage/ray/eval_output/ \ - --max_gen_tokens 32768 \ - --data_names codeforces,lcb_v5 \ - --prompt_type qwen3-think-pure \ - --num_sample_nodes 8 \ - --samples_per_node 1 \ - --n_sampling $((num_sample_nodes * samples_per_node)) \ - --task code &> /storage/ray/eval_output/eval_and_aggregate_parallel.log & + --model_path /path/to/checkpoint \ + --output_path /path/to/outputs \ + --max_gen_tokens 32768 \ + --data_names codeforces,lcb_v5 \ + --prompt_type qwen3-think-pure \ + --num_sample_nodes 8 \ + --samples_per_node 1 \ + --n_sampling $((num_sample_nodes * samples_per_node)) \ + --task code &> eval_and_aggregate_parallel.log & ``` ### Command Line Parameters @@ -70,14 +62,16 @@ nohup python eval_and_aggregate.py \ - **`--model_path`**: Path to the saved model parameters - **`--output_path`**: Path to store generated answers and log files during evaluation - **`--data_names`**: Dataset(s) to evaluate. Multiple datasets can be separated by commas. Available options: - - math: `math_500`, `aime24`, `aime25`, `amc23` - - code: `lcb_v5`, `lcb_v5_2410_2502`, `codeforces`, `code_contest_all` + - Math: `math_500`, `aime24`, `aime25`, `amc23` + - Code: `lcb_v5`, `lcb_v5_2410_2502`, `codeforces`, `code_contest_all` - **`--max_gen_tokens`**: Maximum length of generated answers (default: 32768) -- **`--prompt_type`**: Specify the prompt template, for our latest model, we use `qwen3-think` for math dataset, and `qwen3-think-pure` for code dataset. -- **`--num_sample_nodes`**: Number of multiple sampling seeds to ensure saampling diversity. -- **`--samples_per_node`**: Number of samples to generate per seed for each problem. +- **`--prompt_type`**: Specify the prompt template. For our latest model, we use `qwen3-think` for math datasets and `qwen3-think-pure` for code datasets. +- **`--num_sample_nodes`**: Number of multiple sampling seeds to ensure sampling diversity. +- **`--samples_per_node`**: Number of samples to generate per seed for each problem. -## Evaluation Results +## Logs and Evaluation Results + +Check `${output_path}/math_eval_${max_gen_tokens}/logs` to review the log of each worker. The evaluation script will output a results table in the terminal: @@ -97,7 +91,7 @@ The evaluation script will output a results table in the terminal: - **`greedy_acc`**: Average accuracy under greedy sampling - **`sample_pass@{k}`**: Probability of generating a correct answer within `k` attempts under random sampling -For Codeforces dataset, we use the Elo ranking algorithm to evaluate model performance, referring to [CodeElo](https://github.com/QwenLM/CodeElo) and [rllm](https://github.com/agentica-project/rllm): +For the Codeforces dataset, we use the Elo ranking algorithm to evaluate model performance, referring to [CodeElo](https://github.com/QwenLM/CodeElo) and [rllm](https://github.com/agentica-project/rllm): ``` +------------+----------------+-----------+ @@ -107,19 +101,17 @@ For Codeforces dataset, we use the Elo ranking algorithm to evaluate model perfo +------------+----------------+-----------+ ``` -- **`CF Rating`**: The overall elo rank score of model across 57 Codeforces contests. -- **`Percentile`**: The Elo ranking percentile of model among all Codeforces users. 
-**Note**: As the penalty mechanism may cause fluctuations in Elo rankings, we suggest performing multiple evaluations and regard the average score as the final result. +- **`CF Rating`**: The overall Elo rank score of the model across 57 Codeforces contests. +- **`Percentile`**: The Elo ranking percentile of the model among all Codeforces users. + +> **Note**: As the penalty mechanism may cause fluctuations in Elo rankings, we suggest performing multiple evaluations and taking the average score as the final result. ## Configuration Details ### Sampling Parameters - The evaluation script defaults to averaging 32 samples with temperature 1.0. For the code dataset, we set it to 8 samples. -- We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance -- When `enforce_eager=True`, we can reproduce the model performance reported in previous work -- Without this setting, evaluation results may fall below reported performance -- Therefore, we enforce `enforce_eager=True` during evaluation +- We observed that the `enforce_eager` parameter in vLLM significantly impacts evaluation performance. When `enforce_eager=True`, we can reproduce the model performance reported in previous work. Without this setting, evaluation results may fall below reported performance. Therefore, we enforce `enforce_eager=True` during evaluation. ### Runtime Expectations diff --git a/docs/tutorial/installation.md b/docs/tutorial/installation.md index 863419c..fb71bbe 100644 --- a/docs/tutorial/installation.md +++ b/docs/tutorial/installation.md @@ -30,49 +30,60 @@ The following hardware configuration has been extensively tested: ## Runtime Environment -We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository. +**For multi-node training**: Ensure a shared storage path is mounted on every node (and mounted to the container if you are using Docker). This path will be used to save checkpoints and logs. -Pull the Docker image: +### Option 1: Docker (Recommended) + +We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository. ```bash docker pull ghcr.io/inclusionai/areal-runtime:v0.3.0 +docker run -it --name areal-node1 \ + --privileged --gpus all --network host \ + --shm-size 700g -v /path/to/mount:/path/to/mount \ + ghcr.io/inclusionai/areal-runtime:v0.3.0 \ + /bin/bash ``` -This image includes all training requirements for AReaL. +### Option 2: Custom Environment Installation -**For multi-node training**: Ensure shared storage is mounted to the `/storage` directory on every node. All downloads and resources will be stored in this directory, and the AReaL container will mount this directory to `/storage` within the container. +1. Install [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install) or [Anaconda](https://www.anaconda.com/docs/getting-started/anaconda/install). -## Code Setup - -Clone the AReaL project code to `/storage/codes`: +2. Create a conda virtual environment: + +```bash +conda create -n areal python=3.12 +conda activate areal +``` + +3. 
Install pip dependencies: ```bash -mkdir -p /storage/codes -cd /storage/codes/ git clone https://github.com/inclusionAI/AReaL -pip install -r AReaL/requirements.txt +cd AReaL +bash examples/env/scripts/setup-pip-deps.sh ``` -## Dataset +## (Optional) Launch Ray Cluster for Distributed Training -Download the provided training dataset and place it in `/storage/datasets/`: +On the first node, start the Ray Head: ```bash -mkdir -p /storage/datasets/ -cd /storage/datasets/ -wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/boba_106k_0319.jsonl?download=true +ray start --head ``` -## Model - -We train using open-source models available on Hugging Face Hub. Here's an example using Qwen3 (ensure Git LFS is installed): +On all other nodes, start the Ray Worker: ```bash -mkdir -p /storage/models -cd /storage/models -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Qwen/Qwen3-1.7B -cd Qwen3-1.7B -git lfs pull +# Replace with the actual IP address of the first node +RAY_HEAD_IP=xxx.xxx.xxx.xxx +ray start --address $RAY_HEAD_IP ``` -**Alternative**: You can also use the Hugging Face CLI to download models after installing the `huggingface_hub` package. Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details. \ No newline at end of file +You should see the Ray resource status displayed when running `ray status`. + +Properly set the `n_nodes` argument in AReaL's training command, then AReaL's training script will automatically detect the resources and allocate workers to the cluster. + +## Next Steps + +Check the [quickstart section](quickstart.md) to launch your first AReaL job. \ No newline at end of file diff --git a/docs/tutorial/quickstart.md b/docs/tutorial/quickstart.md new file mode 100644 index 0000000..11c70ad --- /dev/null +++ b/docs/tutorial/quickstart.md @@ -0,0 +1,107 @@ +# Quickstart + +This guide walks through a simple example of training an LLM to solve math problems. + +## Dataset + +Use `huggingface-cli` to download our open-source dataset: + +```bash +huggingface-cli download --repo-type=dataset inclusionAI/AReaL-RL-Data +``` + +> **Note**: The above command will display the path of the downloaded dataset. You'll need to pass this path to the training command. + +## Model + +We train using open-source models available on Hugging Face Hub. You can either download the model in advance or use the model identifier when running the experiment. + +```bash +# If you want to download it in advance +huggingface-cli download Qwen/Qwen3-1.7B +``` + +Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for more information on using `huggingface-cli`. + +## Training + +From the repository directory, run: + +```bash +# examples/run_async_ppo.sh +python3 training/main_sync_ppo.py \ + n_nodes=1 n_gpus_per_node=8 \ + allocation_mode=sglang.d4p1m1+d2p2m1 \ + cluster.fileroot=/path/to/save/logs/checkpoints/ \ + actor.type._class=qwen3 \ + actor.path=/path/to/models/Qwen__Qwen3-1.7B \ + ref.type._class=qwen3 \ + ref.path=/path/to/models/Qwen__Qwen3-1.7B \ + dataset.path=/path/to/dataset/boba_106k_0319.jsonl \ + dataset.train_bs_n_seqs=32 \ + group_size=8 \ + ppo.gen.max_new_tokens=4096 \ + ppo.ppo_n_minibatches=4 \ + actor_train.mb_spec.max_tokens_per_mb=32768 +``` + +## Command Line Options + +To view all available options: + +```bash +python3 training/main_sync_ppo.py --help +``` + +### Configuration Parameters + +- **`experiment_name`**: The name of your project. 
+- **`trial_name`**: The name of this trial in your project. +- **`{actor|ref}.path`**: The path to the model files. +- **`dataset.path`**: The path to the dataset JSONL file. +- **`cluster.fileroot`**: The root path for saving training outputs (logs and checkpoints). +- **`n_nodes`**: The number of nodes in the cluster. +- **`n_gpus_per_node`**: The number of GPUs per node. +- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration for the experiment. Format: + - `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures parallel strategies for SGLang generation and training respectively. Generation and training use separate GPU sets, and the total GPU count must equal: DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs. + +### Training Control + +- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset iterations). +- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to persistent storage. Set to null to disable saving. +- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary parameters for restart capability. +- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per training iteration). +- **`group_size`**: Number of responses sampled per prompt. + +### Memory and Performance + +- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per mini-batch for forward/backward passes during reference model inference and actor model training. Reduce this value to avoid OOM errors. +- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO update. + +### PPO Configuration + +- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for training. Defaults to True for asynchronous experiments and False for synchronous baselines. +- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training. Defaults to True. +- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt. + +## Monitoring the Training Process + +We recommend using Weights & Biases (wandb) for monitoring. Run `wandb login` or set the `WANDB_API_KEY` environment variable. Set `wandb.mode=True` in your configuration to upload training statistics. + +You can also use TensorBoard by setting the `tensorboard.path` parameter. + +The main log will be saved to `${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the statistics uploaded to wandb. + +### Key Training Statistics + +- **`Epoch 1/5`**: Indicates total epochs required and current epoch being trained. +- **`step 6/19`**: Shows the current epoch has 19 steps, with the 6th step just completed. +- **`global step 6`**: Step count across all epochs. +- **`ppo_actor/task_reward/avg`**: Average reward value of all sampled responses in this step. Should steadily increase during training and eventually stabilize. +- **`ppo_actor/importance_weight/avg`**: Average importance sampling ratio across all tokens in the PPO loss. Typically close to 1.0. +- **`ppo_actor/actor_clip_ratio/avg`**: Ratio of clipped tokens in PPO loss to total tokens. Usually less than 0.1. +- **`ppo_actor/actor_loss/avg`**: PPO loss value. **Does not show clear trends during training** and should not be used as a performance indicator. + +## Next Steps + +[Evaluate your model](eval.md) or check the [troubleshooting section](troubleshooting.md) if you encounter any issues. 
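+
+For example, once a checkpoint has been saved, the following is a minimal sketch of handing it to the evaluation script (the checkpoint directory name depends on your run; see [Evaluation](eval.md) for the full set of arguments):
+
+```bash
+# Checkpoints are written under the layout described in the evaluation tutorial.
+CKPT_ROOT=${fileroot}/checkpoints/${USER}/${experiment_name}/${trial_name}
+ls ${CKPT_ROOT}   # pick one of the saved checkpoint directories
+
+cd evaluation
+python eval_and_aggregate.py \
+    --model_path ${CKPT_ROOT}/<checkpoint-dir> \
+    --output_path /path/to/outputs \
+    --data_names aime24,aime25 \
+    --prompt_type qwen3-think \
+    --task math
+```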
\ No newline at end of file diff --git a/docs/tutorial/troubleshooting.md b/docs/tutorial/troubleshooting.md index 5061a5a..2ebde1b 100644 --- a/docs/tutorial/troubleshooting.md +++ b/docs/tutorial/troubleshooting.md @@ -43,9 +43,10 @@ The key to resolving this issue is identifying the phase where the error occurs: #### During SGLang Generation - Decrease the `actor.sglang.mem_fraction_static` parameter - Increase the tensor parallelism degree +- Decrease the `max_concurrent_rollouts` parameter for asynchronous RL #### During `actor_inf` or `actor_train` -- **Adjust microbatch size**: Set parameters like `actor_train.mb_spec.max_tokens_per_mb=20480`. This parameter limits tokens per forward/backward pass and can be set as low as the maximum sequence length (including prompt) +- **Adjust microbatch size**: Decrease the parameter `{actor_train|actor_inf}.mb_spec.max_tokens_per_mb=20480`. This parameter limits tokens per forward/backward pass and can be set as low as the maximum sequence length (including prompt) - **Modify parallelism strategy**: Adjust `allocation_mode` by: - Reducing data parallelism - Increasing tensor or pipeline parallelism @@ -53,4 +54,4 @@ The key to resolving this issue is identifying the phase where the error occurs: ### CUDA Error: Out of Memory -This issue may occur during data transfer. Try increasing `mem_per_xx_worker` in the CLI arguments. \ No newline at end of file +This issue may occur during data transfer. Try increasing `mem_per_model_worker` in the CLI arguments. \ No newline at end of file diff --git a/evaluation/code_verifier/local_verify.py b/evaluation/code_verifier/local_verify.py index a402dd5..8c0a2b6 100644 --- a/evaluation/code_verifier/local_verify.py +++ b/evaluation/code_verifier/local_verify.py @@ -173,7 +173,7 @@ def evaluate(samples): if __name__ == "__main__": - # data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl") + # data_list = load_jsonl("/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl") data_path = "/storage/openpsi/data/code/deepcoder/deepcoder_0415_v3_verify_new_correct.jsonl" id2info = defaultdict(dict) diff --git a/evaluation/requirements.txt b/evaluation/requirements.txt index 49a45d5..3b211d4 100644 --- a/evaluation/requirements.txt +++ b/evaluation/requirements.txt @@ -1,11 +1,10 @@ # common -vllm tqdm datasets -torch -transformers==4.47.0 +transformers==4.51.1 +hf_transfer +huggingface_hub python_dateutil -flash_attn # math_eval sympy==1.12 @@ -14,4 +13,5 @@ word2number Pebble prettytable timeout-decorator +timeout_decorator wandb diff --git a/examples/env/scripts/setup-eval-pip-deps.sh b/examples/env/scripts/setup-eval-pip-deps.sh new file mode 100644 index 0000000..fbdf4ee --- /dev/null +++ b/examples/env/scripts/setup-eval-pip-deps.sh @@ -0,0 +1,9 @@ +#/bin/bash +# basic dependencies +pip install -U pip +pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y +pip install nvidia-ml-py +pip install -e evaluation/latex2sympy +pip install vllm==0.8.5 --no-build-isolation +pip install flash_attn --no-build-isolation +pip install -r evaluation/requirements.txt \ No newline at end of file diff --git a/examples/env/scripts/setup-pip-deps.sh b/examples/env/scripts/setup-pip-deps.sh new file mode 100644 index 0000000..6acc768 --- /dev/null +++ b/examples/env/scripts/setup-pip-deps.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# basic dependencies +pip install -U pip +pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda 
cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y +pip install "sglang[all]==0.4.6.post4" +pip install megatron-core==0.11.0 nvidia-ml-py +pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose +pip install flash-attn --no-build-isolation + +# the sympy virtual env for reward computation +pip install virtualenv +rm -rf ./sympy +python3 -m venv ./sympy +# equivalent to install `./evaluation/latex2sympy` in the sympy virtual env +./sympy/bin/pip install git+https://github.com/QwenLM/Qwen2.5-Math.git#subdirectory=evaluation/latex2sympy +./sympy/bin/pip install regex numpy tqdm datasets python_dateutil sympy==1.12 antlr4-python3-runtime==4.11.1 word2number Pebble timeout-decorator prettytable + +# Install an editable sglang +rm -rf ./sglang +git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang +AREAL_PATH=$PWD +cd sglang +git apply ../patch/sglang/v0.4.6.post4.patch +pip install -e "python[all]" --no-deps +cd $AREAL_PATH + +# Install AReaL +pip install -e . diff --git a/examples/run_async_ppo.sh b/examples/run_async_ppo.sh new file mode 100644 index 0000000..e6c7d43 --- /dev/null +++ b/examples/run_async_ppo.sh @@ -0,0 +1,18 @@ +#!/bin/bash +python3 training/main_async_ppo.py \ + n_nodes=1 n_gpus_per_node=8 \ + allocation_mode=sglang.d4p1m1+d2p2m1 \ + cluster.fileroot=/storage/testing/experiments \ + actor.type._class=qwen3 \ + actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + ref.type._class=qwen3 \ + ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \ + dataset.train_bs_n_seqs=32 \ + group_size=8 \ + ppo.gen.max_new_tokens=4096 \ + ppo.ppo_n_minibatches=4 \ + actor_train.mb_spec.max_tokens_per_mb=32768 \ + actor_inf.mb_spec.max_tokens_per_mb=32768 \ + max_concurrent_rollouts=16 \ + max_head_offpolicyness=4 \ No newline at end of file diff --git a/examples/run_sft.sh b/examples/run_sft.sh new file mode 100644 index 0000000..0a36a52 --- /dev/null +++ b/examples/run_sft.sh @@ -0,0 +1,13 @@ +#!/bin/bash +python3 training/main_sft.py \ + n_nodes=1 n_gpus_per_node=8 \ + allocation_mode=d4p2m1 \ + cluster.fileroot=/storage/testing/experiments \ + model.type._class=qwen3 \ + exp_ctrl.eval_freq_epochs=1 \ + model.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + dataset.train_path=/storage/testing/dataset/areal-sft-stage2-200.jsonl \ + dataset.valid_path=/storage/testing/dataset/areal-sft-stage2-200.jsonl \ + dataset.train_bs_n_seqs=64 \ + dataset.valid_bs_n_seqs=64 \ + allocation.mb_spec.max_tokens_per_mb=32768 \ No newline at end of file diff --git a/examples/run_sync_ppo.sh b/examples/run_sync_ppo.sh new file mode 100644 index 0000000..d36964c --- /dev/null +++ b/examples/run_sync_ppo.sh @@ -0,0 +1,15 @@ +#!/bin/bash +python3 training/main_sync_ppo.py \ + n_nodes=1 n_gpus_per_node=8 \ + allocation_mode=sglang.d4p1m1+d2p2m1 \ + cluster.fileroot=/storage/testing/experiments \ + actor.type._class=qwen3 \ + actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + ref.type._class=qwen3 \ + ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \ + dataset.train_bs_n_seqs=32 \ + group_size=8 \ + ppo.gen.max_new_tokens=4096 \ + ppo.ppo_n_minibatches=4 \ + actor_train.mb_spec.max_tokens_per_mb=32768 \ No newline at end of file diff --git a/functioncall/code/local_verify.py b/functioncall/code/local_verify.py index be32608..d8c9ff6 100644 --- a/functioncall/code/local_verify.py +++ b/functioncall/code/local_verify.py @@ -121,7 
+121,9 @@ def code_verify(id2info, generateds, query_ids, debug=False): if __name__ == "__main__": - data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl") + data_list = load_jsonl( + "/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl" + ) id2info = defaultdict(dict) for item in data_list: id2info[item["query_id"]] = item diff --git a/functioncall/code/verify.py b/functioncall/code/verify.py index cbfad34..5425d29 100644 --- a/functioncall/code/verify.py +++ b/functioncall/code/verify.py @@ -32,7 +32,11 @@ def construct_testcases( result.append({"input": input_, "expectedOutput": output_}) continue - oss_basepath = "http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/" + oss_basepath = os.getenv("REAL_OSS_TESTCASE_PATH", "") + if not oss_basepath: + raise FileNotFoundError( + "REAL_OSS_TESTCASE_PATH not set. Cannot use FAAS code reward." + ) input_url = ( input_ if input_.startswith("http") else os.path.join(oss_basepath, input_) ) @@ -153,7 +157,9 @@ def code_verify( if __name__ == "__main__": - data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl") + data_list = load_jsonl( + "/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl" + ) id2info = defaultdict(dict) for item in data_list: id2info[item["query_id"]] = item diff --git a/functioncall/test/test_success_dataset.jsonl b/functioncall/test/test_success_dataset.jsonl deleted file mode 100644 index c8f5a35..0000000 --- a/functioncall/test/test_success_dataset.jsonl +++ /dev/null @@ -1,6 +0,0 @@ -{"task": "code", "query_id": "0001", "prompt": "", "solutions": ["from typing import *\n\nclass Solution:\n def solveNQueens(self, n: int) -> List[List[str]]:\n def generateBoard():\n board = list()\n for i in range(n):\n row[queens[i]] = \"Q\"\n board.append(\"\".join(row))\n row[queens[i]] = \".\"\n return board\n\n def solve(row: int, columns: int, diagonals1: int, diagonals2: int):\n if row == n:\n board = generateBoard()\n solutions.append(board)\n else:\n availablePositions = ((1 << n) - 1) & (~(columns | diagonals1 | diagonals2))\n while availablePositions:\n position = availablePositions & (-availablePositions)\n availablePositions = availablePositions & (availablePositions - 1)\n column = bin(position - 1).count(\"1\")\n queens[row] = column\n solve(row + 1, columns | position, (diagonals1 | position) << 1, (diagonals2 | position) >> 1)\n\n solutions = list()\n queens = [-1] * n\n row = [\".\"] * n\n solve(0, 0, 0, 0)\n return solutions\n# Test case 1: Smallest case, n = 1\n# There is only one queen, so the only solution is a board with a single 'Q'.\nsolution = Solution()\nassert solution.solveNQueens(1) == [['Q']]\n"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"} -{"task": "code", "query_id": "0002", "prompt": "", "solutions": ["package main\n\nimport (\n\t\"bufio\"\n\t\"fmt\"\n\t\"os\"\n\t\"strconv\"\n\t\"strings\"\n)\n\nfunc main() {\n\tnums, target := arrayToOneStyleInput()\n\tresult := twoSum(nums, target)\n\tfmt.Println(result)\n}\n\nfunc twoSum(nums []int, target int) []int {\n\tprevNums := map[int]int{}\n\tfor i, num := range nums {\n\t\ttargetNum := target - num\n\t\ttargetNumIndex, ok := prevNums[targetNum]\n\t\tif ok {\n\t\t\treturn []int{targetNumIndex, i}\n\t\t} else {\n\t\t\tprevNums[num] = i\n\t\t}\n\t}\n\treturn []int{}\n}\n\nfunc arrayToOneStyleInput() ([]int, int) {\n\t// 读取数组\n\tscanner := bufio.NewScanner(os.Stdin)\n\tscanner.Scan()\n\tarrStr := 
strings.Trim(scanner.Text(), \"[]\")\n\tarr := stringToIntSlice(strings.ReplaceAll(arrStr, \",\", \" \"))\n\n\t// 读取目标值\n\tscanner.Scan()\n\ttarget, _ := strconv.Atoi(scanner.Text())\n\n\treturn arr, target\n}\n\nfunc stringToIntSlice(s string) []int {\n\tparts := strings.Split(s, \" \")\n\tres := make([]int, len(parts))\n\tfor i, p := range parts {\n\t\tres[i], _ = strconv.Atoi(p)\n\t}\n\treturn res\n}\n"], "input_output": "{\"inputs\": [\"https://artifacts.antgroup-inc.cn/artifact/repositories/artifacts-pre-test-common/runtime/golang/1.0.1/input.txt\"], \"outputs\": [\"https://artifacts.antgroup-inc.cn/artifact/repositories/artifacts-pre-test-common/runtime/golang/1.0.1/output.txt\"], \"fn_name\": \"\", \"remote\": true }", "language": "GO"} -{"task": "code", "query_id": "0003", "prompt": "", "solutions": ["public class TestMain {\n public static void main(String[] args) {\n assert \"test\".equals(\"test\");\n }\n}"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "JAVA"} -{"task": "code", "query_id": "0004", "prompt": "", "solutions": ["#include \n#include \nint main() {\n std::string name = \"Alice\";\n std::cout << \"Hello, \" << name << \"! Welcome to the world of C++!\\n\";\n return 0;\n}\n"], "input_output": "{\"inputs\": [], \"outputs\": [], \"fn_name\": \"\", \"remote\": false}", "language": "CPP"} -{"task": "code", "query_id": "0005", "prompt": "", "solutions": ["import time\n\ndef square(num):\n return num ** 2\n\nresult=square(int(input()))\nprint(result)"], "input_output": "{\"inputs\": [\"2\", \"2\", \"2\"], \"outputs\": [\"4\\n\", \"4\\n\", \"4\\n\"], \"fn_name\": \"\", \"remote\": false}", "language": "PYTHON"} -{"task": "code", "query_id": "0006", "prompt": "", "solutions": ["#include \n\nint square(int number) {\n return number * number;\n}\n\nint main() {\n int num;\n std::cin >> num;\n\n int result = square(num);\n std::cout << result << std::endl;\n\n return 0;\n}"], "input_output": "{\"inputs\": [\"http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/functioncall/content_2.txt\"], \"outputs\": [\"http://antsys-hcsfaas-images-dev.cn-heyuan-alipay-office.oss-alipay.aliyuncs.com/functioncall/content_4.txt\"], \"fn_name\": \"\", \"remote\": true }", "language": "CPP"} \ No newline at end of file diff --git a/patch/sglang/v0.4.6.post4.patch b/patch/sglang/v0.4.6.post4.patch index 892f335..b7dbd09 100644 --- a/patch/sglang/v0.4.6.post4.patch +++ b/patch/sglang/v0.4.6.post4.patch @@ -1,501 +1,3 @@ -diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py -index 60091b9a..7a1c856b 100644 ---- a/python/sglang/srt/layers/logits_processor.py -+++ b/python/sglang/srt/layers/logits_processor.py -@@ -15,6 +15,7 @@ - - import dataclasses - import logging -+import os - from typing import List, Optional, Union - - import torch -@@ -44,6 +45,15 @@ from sglang.srt.model_executor.forward_batch_info import ( - ) - from sglang.srt.utils import dump_to_file - -+# When compute the input and output tokens logprobs, if the rows of the -+# logprobs are too large, the peak memory usage will be too high. For example, -+# if the logprobs are [10000, 150000], the peak memory usage will be greater -+# than 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB. (4 is the size of float) -+# So we split the logprobs into multiple chunks. 
-+LOGITS_PROCESSER_CHUNK_SIZE = int( -+ os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048") -+) -+ - logger = logging.getLogger(__name__) - - -@@ -286,15 +296,17 @@ class LogitsProcessor(nn.Module): - input_logprob_indices = None - else: - # Input logprobs are required. -- # Find 3 different indices. -+ # Find 4 different indices. - # 1. pruned_states: hidden states that we want logprobs from. - # 2. sample_indices: Indices that have sampled tokens. - # 3. input_logprob_indices: Indices that have input logprob tokens. -+ # 4. sequence_index_mapping: map pruned_states indices to top_logprobs_nums and token_ids_logprobs indices. - sample_index_pt = -1 - sample_indices = [] - input_logprob_indices_pt = 0 - input_logprob_indices = [] - pt, pruned_states = 0, [] -+ idx, sequence_index_mapping = 0, [] - for extend_logprob_start_len, extend_len in zip( - logits_metadata.extend_logprob_start_lens_cpu, - logits_metadata.extend_seq_lens_cpu, -@@ -310,7 +322,11 @@ class LogitsProcessor(nn.Module): - # by a caller. - assert extend_len > start_len - pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) -+ # sequence_index_mapping, repeat this for loop index -+ sequence_index_mapping.extend([idx] * (extend_len - start_len)) -+ idx += 1 - pt += extend_len -+ - sample_index_pt += extend_len - start_len - sample_indices.append(sample_index_pt) - input_logprob_indices.extend( -@@ -321,6 +337,7 @@ class LogitsProcessor(nn.Module): - ) - input_logprob_indices_pt += extend_len - start_len - -+ sequence_index_mapping.append(idx - 1) - pruned_states = torch.cat(pruned_states) - sample_indices = torch.tensor( - sample_indices, device=pruned_states.device, dtype=torch.int64 -@@ -329,12 +346,6 @@ class LogitsProcessor(nn.Module): - input_logprob_indices, device=pruned_states.device, dtype=torch.int64 - ) - -- # Compute logits for both input and sampled tokens. -- logits = self._get_logits(pruned_states, lm_head, logits_metadata) -- sampled_logits = ( -- logits[sample_indices] if sample_indices is not None else logits -- ) -- - if self.debug_tensor_dump_output_folder: - assert ( - not self.do_tensor_parallel_all_gather -@@ -370,67 +381,176 @@ class LogitsProcessor(nn.Module): - else: - assert False, "Should never reach" - -+ del hidden_states -+ - if not logits_metadata.extend_return_logprob: -+ # Compute logits for both input and sampled tokens. -+ logits = self._get_logits(pruned_states, lm_head, logits_metadata) -+ sampled_logits = ( -+ logits[sample_indices] if sample_indices is not None else logits -+ ) -+ - # Decode mode or extend mode without return_logprob. - return LogitsProcessorOutput( - next_token_logits=sampled_logits, - hidden_states=hidden_states_to_store, - ) - else: -- input_logprobs = logits[input_logprob_indices] -- del hidden_states, logits -+ # Compute logprobs requires lot of memory, so we split pruned_states -+ # into chunks of rows to compute input_logprobs separately, then -+ # concatenate the results. 
-+ return self._compute_output_by_chunk( -+ pruned_states, -+ sample_indices, -+ hidden_states_to_store, -+ input_logprob_indices, -+ sequence_index_mapping, -+ lm_head, -+ logits_metadata, -+ ) - -- # Normalize the logprob w/o temperature, top-p -- pruned_lens = torch.tensor( -- logits_metadata.extend_logprob_pruned_lens_cpu, -- device=input_logprobs.device, -+ def _compute_output_by_chunk( -+ self, -+ pruned_states: torch.Tensor, -+ sample_indices: torch.Tensor, -+ hidden_states_to_store: Optional[torch.Tensor], -+ input_logprob_indices: torch.Tensor, -+ index_mapping: list[int], -+ lm_head: VocabParallelEmbedding, -+ logits_metadata: LogitsMetadata, -+ ) -> LogitsProcessorOutput: -+ """ -+ compute logprobs for the output token from the hidden states. -+ To avoid using too much memory, we split pruned_states into chunks of -+ rows to compute input_logprobs separately, then concatenate the results. -+ -+ Returns: -+ LogitsProcessorOutput: logits processor output class -+ """ -+ -+ # Normalize the logprob w/o temperature, top-p -+ pruned_lens = torch.tensor( -+ logits_metadata.extend_logprob_pruned_lens_cpu, -+ device=pruned_states.device, -+ ) -+ if logits_metadata.temp_scaled_logprobs: -+ logits_metadata.temperature = torch.repeat_interleave( -+ logits_metadata.temperature.view(-1), -+ pruned_lens, -+ ).view(-1, 1) -+ if logits_metadata.top_p_normalized_logprobs: -+ logits_metadata.top_p = torch.repeat_interleave( -+ logits_metadata.top_p, -+ pruned_lens, - ) -- if logits_metadata.temp_scaled_logprobs: -- logits_metadata.temperature = torch.repeat_interleave( -- logits_metadata.temperature.view(-1), -- pruned_lens, -- ).view(-1, 1) -- if logits_metadata.top_p_normalized_logprobs: -- logits_metadata.top_p = torch.repeat_interleave( -- logits_metadata.top_p, -- pruned_lens, -- ) -- input_logprobs = self.compute_temp_top_p_normalized_logprobs( -- input_logprobs, logits_metadata -+ -+ # The peak memory usage is proportional to the chunk size. -+ chunk_size = LOGITS_PROCESSER_CHUNK_SIZE -+ num_chunks = (pruned_states.shape[0] + chunk_size - 1) // chunk_size -+ -+ input_token_logprobs = [] -+ if logits_metadata.extend_return_top_logprob: -+ input_top_logprobs_val = [] -+ input_top_logprobs_idx = [] -+ else: -+ input_top_logprobs_val = None -+ input_top_logprobs_idx = None -+ -+ if logits_metadata.extend_token_ids_logprob: -+ input_token_ids_logprobs_val = [] -+ input_token_ids_logprobs_idx = [] -+ else: -+ input_token_ids_logprobs_val = None -+ input_token_ids_logprobs_idx = None -+ -+ # It a single sequence is split into multiple chunks, we need to keep track -+ # of the pruned length of the sequences in the previous chunks. 
-+ split_len_topk = 0 -+ split_len_token_ids = 0 -+ -+ for i in range(num_chunks): -+ start_idx = i * chunk_size -+ end_idx = min((i + 1) * chunk_size, pruned_states.shape[0]) -+ -+ # Get indices for this chunk -+ chunk_mask = (input_logprob_indices >= start_idx) & ( -+ input_logprob_indices < end_idx -+ ) -+ -+ global_indices = input_logprob_indices[chunk_mask] -+ chunk_indices = global_indices - start_idx -+ chunk_states = pruned_states[start_idx:end_idx] -+ chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata) -+ -+ if chunk_indices.numel() == 0: -+ continue -+ -+ # Compute the logprobs of the chunk -+ chunk_input_logprobs = chunk_logits[chunk_indices] -+ chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs( -+ chunk_input_logprobs, global_indices, logits_metadata - ) - -+ # For each chunk, we need to get the slice of the sequence_index_mapping -+ chunk_slice = slice(index_mapping[start_idx], index_mapping[end_idx] + 1) -+ - # Get the logprob of top-k tokens - if logits_metadata.extend_return_top_logprob: -- ( -+ split_len_topk = self.get_top_logprobs( -+ chunk_input_logprobs, -+ logits_metadata, -+ chunk_slice, - input_top_logprobs_val, - input_top_logprobs_idx, -- ) = self.get_top_logprobs(input_logprobs, logits_metadata) -- else: -- input_top_logprobs_val = input_top_logprobs_idx = None -+ split_len_topk, -+ ) - - # Get the logprob of given token id - if logits_metadata.extend_token_ids_logprob: -- ( -+ split_len_token_ids = self.get_token_ids_logprobs( -+ chunk_input_logprobs, -+ logits_metadata, -+ chunk_slice, - input_token_ids_logprobs_val, - input_token_ids_logprobs_idx, -- ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata) -- else: -- input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None -- -- input_token_logprobs = input_logprobs[ -- torch.arange(input_logprobs.shape[0], device=input_logprobs.device), -- logits_metadata.extend_input_logprob_token_ids_gpu, -- ] -+ split_len_token_ids, -+ ) - -- return LogitsProcessorOutput( -- next_token_logits=sampled_logits, -- input_token_logprobs=input_token_logprobs, -- input_top_logprobs_val=input_top_logprobs_val, -- input_top_logprobs_idx=input_top_logprobs_idx, -- hidden_states=hidden_states_to_store, -- input_token_ids_logprobs_val=input_token_ids_logprobs_val, -- input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, -+ # Handle sampled logits for the chunk if needed -+ chunk_sample_mask = (sample_indices >= start_idx) & ( -+ sample_indices < end_idx - ) -+ if i == 0: # Initialize sampled_logits on first chunk -+ sampled_logits = torch.empty( -+ (sample_indices.shape[0], chunk_logits.shape[1]), -+ dtype=chunk_logits.dtype, -+ device=chunk_logits.device, -+ ) -+ if chunk_sample_mask.any(): -+ chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx -+ sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices] -+ -+ # Get the logprob of the requested token ids -+ chunk_input_token_logprobs = chunk_input_logprobs[ -+ torch.arange( -+ chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device -+ ), -+ logits_metadata.extend_input_logprob_token_ids_gpu[start_idx:end_idx], -+ ] -+ input_token_logprobs.append(chunk_input_token_logprobs) -+ -+ # Concatenate the results -+ input_token_logprobs = torch.cat(input_token_logprobs, dim=0) -+ -+ return LogitsProcessorOutput( -+ hidden_states=hidden_states_to_store, -+ next_token_logits=sampled_logits, -+ input_token_logprobs=input_token_logprobs, -+ input_top_logprobs_val=input_top_logprobs_val, -+ 
input_top_logprobs_idx=input_top_logprobs_idx, -+ input_token_ids_logprobs_val=input_token_ids_logprobs_val, -+ input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, -+ ) - - def _get_logits( - self, -@@ -498,60 +618,142 @@ class LogitsProcessor(nn.Module): - return logits - - @staticmethod -- def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): -+ def get_top_logprobs( -+ logprobs: torch.Tensor, -+ logits_metadata: LogitsMetadata, -+ chunk_slice: slice, -+ input_top_logprobs_val: List, -+ input_top_logprobs_idx: List, -+ split_pruned_len: int, -+ ): -+ """Get top-k logprobs for each sequence in the chunk. -+ -+ Args: -+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size] -+ logits_metadata: Metadata containing top-k and pruned length info -+ chunk_slice: Slice of sequences to process -+ input_top_logprobs_val: List to store top-k logprob values -+ input_top_logprobs_idx: List to store top-k token indices -+ split_pruned_len: Length of pruned tokens from previous chunk -+ -+ Returns: -+ int: Number of remaining tokens to process in next chunk -+ """ -+ - max_k = max(logits_metadata.top_logprobs_nums) -- ret = all_logprobs.topk(max_k, dim=1) -+ ret = logprobs.topk(max_k, dim=1) - values = ret.values.tolist() - indices = ret.indices.tolist() - -- input_top_logprobs_val, input_top_logprobs_idx = [], [] -- - pt = 0 -- for k, pruned_len in zip( -- logits_metadata.top_logprobs_nums, -- logits_metadata.extend_logprob_pruned_lens_cpu, -- ): -+ next_split_pruned_len = 0 -+ top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice] -+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice] -+ -+ for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)): -+ # Adjust pruned length for first sequence -+ if n == 0: -+ pruned_len -= split_pruned_len -+ else: -+ split_pruned_len = 0 -+ - if pruned_len <= 0: -- input_top_logprobs_val.append([]) -- input_top_logprobs_idx.append([]) -+ if n == 0: -+ input_top_logprobs_val.append([]) -+ input_top_logprobs_idx.append([]) - continue - -- input_top_logprobs_val.append( -- [values[pt + j][:k] for j in range(pruned_len)] -- ) -- input_top_logprobs_idx.append( -- [indices[pt + j][:k] for j in range(pruned_len)] -- ) -- pt += pruned_len -+ val = [] -+ idx = [] -+ for j in range(pruned_len): -+ # Handle remaining tokens in next chunk if any -+ if pt + j >= len(values): -+ next_split_pruned_len = split_pruned_len + j -+ break -+ val.append(values[pt + j][:k]) -+ idx.append(indices[pt + j][:k]) -+ -+ if split_pruned_len <= 0 and len(val) > 0: -+ input_top_logprobs_val.append(val) -+ input_top_logprobs_idx.append(idx) -+ else: -+ input_top_logprobs_val[-1].extend(val) -+ input_top_logprobs_idx[-1].extend(idx) - -- return input_top_logprobs_val, input_top_logprobs_idx -+ pt += pruned_len -+ return next_split_pruned_len - - @staticmethod - def get_token_ids_logprobs( -- all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata -+ logprobs: torch.Tensor, -+ logits_metadata: LogitsMetadata, -+ chunk_slice: slice, -+ input_token_ids_logprobs_val: List, -+ input_token_ids_logprobs_idx: List, -+ split_pruned_len: int = 0, - ): -- input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], [] -+ """Get token_ids logprobs for each sequence in the chunk. 
-+ -+ Args: -+ logprobs: Log probabilities tensor of shape [seq_len, vocab_size] -+ logits_metadata: Metadata containing token IDs and pruned length info -+ chunk_slice: Slice of sequences to process -+ input_token_ids_logprobs_val: List to store token logprob values -+ input_token_ids_logprobs_idx: List to store token indices -+ split_pruned_len: Length of pruned tokens from previous chunk -+ -+ Returns: -+ int: Number of remaining tokens to process in next chunk -+ """ - pt = 0 -- for token_ids, pruned_len in zip( -- logits_metadata.token_ids_logprobs, -- logits_metadata.extend_logprob_pruned_lens_cpu, -+ next_split_pruned_len = 0 -+ token_ids_logprobs_chunk = logits_metadata.token_ids_logprobs[chunk_slice] -+ pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[chunk_slice] -+ -+ for n, (token_ids, pruned_len) in enumerate( -+ zip( -+ token_ids_logprobs_chunk, -+ pruned_lens, -+ ) - ): -+ # Adjust pruned length for first sequence -+ if n == 0: -+ pruned_len -= split_pruned_len -+ else: -+ split_pruned_len = 0 -+ - if pruned_len <= 0: -- input_token_ids_logprobs_val.append([]) -- input_token_ids_logprobs_idx.append([]) -+ if n == 0: -+ input_token_ids_logprobs_val.append([]) -+ input_token_ids_logprobs_idx.append([]) - continue - -- input_token_ids_logprobs_val.append( -- [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)] -- ) -- input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)]) -- pt += pruned_len -+ val = [] -+ idx = [] -+ for j in range(pruned_len): -+ # Handle remaining tokens in next chunk if any -+ if pt + j >= logprobs.shape[0]: -+ next_split_pruned_len = split_pruned_len + j -+ break -+ if token_ids is not None: -+ val.append(logprobs[pt + j, token_ids].tolist()) -+ idx.append(token_ids) -+ -+ if split_pruned_len <= 0 and len(val) > 0: -+ input_token_ids_logprobs_val.append(val) -+ input_token_ids_logprobs_idx.append(idx) -+ elif len(val) > 0: -+ input_token_ids_logprobs_val[-1].extend(val) -+ input_token_ids_logprobs_idx[-1].extend(idx) - -- return input_token_ids_logprobs_val, input_token_ids_logprobs_idx -+ pt += pruned_len -+ return next_split_pruned_len - - @staticmethod - def compute_temp_top_p_normalized_logprobs( -- last_logits: torch.Tensor, logits_metadata: LogitsMetadata -+ last_logits: torch.Tensor, -+ indices: torch.Tensor, -+ logits_metadata: LogitsMetadata, - ) -> torch.Tensor: - """ - compute logprobs for the output token from the given logits. 
-@@ -561,19 +763,20 @@ class LogitsProcessor(nn.Module): - """ - # Scale logits if temperature scaling is enabled - if logits_metadata.temp_scaled_logprobs: -- last_logits = last_logits / logits_metadata.temperature -+ last_logits = last_logits / logits_metadata.temperature[indices] -+ -+ top_p = None -+ if logits_metadata.top_p is not None: -+ top_p = logits_metadata.top_p[indices] - - # Normalize logprobs if top_p normalization is enabled - # NOTE: only normalize logprobs when top_p is set and not equal to 1.0 -- if ( -- logits_metadata.top_p_normalized_logprobs -- and (logits_metadata.top_p != 1.0).any() -- ): -+ if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any(): - from sglang.srt.layers.sampler import top_p_normalize_probs_torch - - probs = torch.softmax(last_logits, dim=-1) - del last_logits -- probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) -+ probs = top_p_normalize_probs_torch(probs, top_p) - return torch.log(probs) - else: - return torch.nn.functional.log_softmax(last_logits, dim=-1) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5390668c..db370d19 100644 --- a/python/sglang/srt/managers/io_struct.py diff --git a/pyproject.toml b/pyproject.toml index b12ec4a..459e69b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,106 @@ classifiers = [ "Programming Language :: Python :: 3.10", ] +dependencies = [ + # Core ML/AI libraries + "torch>2.0.0", + "huggingface_hub", + "datasets", + "accelerate", + "transformers==4.51.1", + + # Scientific computing + "numpy<2.0.0", + "scipy", + "pandas", + "matplotlib", + "seaborn", + "h5py", + + # Utilities and data processing + "nltk", + "sentencepiece", + "einops", + "tqdm", + "rich", + "orjson>=3.10.16", + "pydantic", + "PyYAML", + "omegaconf", + "hydra-core", + "packaging", + "tabulate", + + # Monitoring and logging + "wandb", + "tensorboardx", + "colorama", + "colorlog", + "psutil", + "pynvml", + + # Performance and compression + "ninja", + "numba", + "blosc", + "pybind11>=2.10.0", + + # Networking and async + "networkx==3.3", + "aiofiles", + "aiohttp>=3.11.10", + "httpx>=0.28.1", + "pyzmq", + "paramiko", + "etcd3", + "protobuf<3.21", + + # Distributed computing + "ray", + "redis", + + # Web frameworks + "fastapi>=0.115.12", + "uvicorn>=0.34.2", + "uvloop>=0.21.0", + "flask", + + # Build and packaging tools + "build>=1.2.1", + "wheel>=0.43.0", + "setuptools>=62.3.0,<75.9", + "cookiecutter>2.1.1", + + # System utilities + "distro-info>=1.0", + "python-debian>=0.1.49", + "func_timeout", + + # Development tools (consider moving to optional dependencies) + "pytest", + "ipython", + "jupyter-book", + "sphinx", + "sphinx-nefertiti", + "black==25.1.0", + "isort==5.13.2", + "clang-format==19.1.7", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "black==25.1.0", + "isort==5.13.2", + "clang-format==19.1.7", +] + +docs = [ + "sphinx", + "sphinx-nefertiti", + "jupyter-book", +] + [tool.setuptools.dynamic] version = {attr = "realhf.__version__"} diff --git a/realhf/apps/main.py b/realhf/apps/main.py index 55c07f5..2a5a23f 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -153,11 +153,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): # set env vars BASE_ENVIRONS = constants.get_env_vars( REAL_MODE=args.mode.upper(), - CLUSTER_SPEC_PATH=cluster_spec_path, REAL_RECOVER_RUN="1" if is_recover_run else "0", REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0", - 
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""), - REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", ""), ) for k, v in BASE_ENVIRONS.items(): os.environ[k] = v diff --git a/realhf/base/constants.py b/realhf/base/constants.py index e96ca3c..6b0e946 100644 --- a/realhf/base/constants.py +++ b/realhf/base/constants.py @@ -571,6 +571,15 @@ def get_repo_path() -> pathlib.Path: def get_env_vars(**kwargs): + kwargs.update( + CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""), + REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"), + REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"), + FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""), + REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"), + REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"), + REAL_OSS_TESTCASE_PATH=os.getenv("REAL_OSS_TESTCASE_PATH", ""), + ) return { **kwargs, "REAL_PACKAGE_PATH": str(get_repo_path()), diff --git a/realhf/experiments/common/check.py b/realhf/experiments/common/check.py index 23f167e..b1b46bc 100644 --- a/realhf/experiments/common/check.py +++ b/realhf/experiments/common/check.py @@ -2,7 +2,9 @@ # Copyright 2024 Wei Fu & Zhiyu Mei # Licensed under the Apache License, Version 2.0 (the "License"). import os -from typing import List +from typing import List, Optional + +from huggingface_hub import snapshot_download, try_to_load_from_cache from realhf.api.cli_args import ModelTrainEvalConfig, SGLangConfig, vLLMConfig from realhf.api.quickstart.device_mesh import RPCAllocation @@ -56,13 +58,89 @@ def check_valid_optimizer(model: ModelTrainEvalConfig): ) -def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig): - if not os.path.exists(model.path): - raise FileNotFoundError( - f"The model path `{model.path}` for `{role}` does not exist locally. " - "You must download the HuggingFace checkpoint before loading it." +def check_valid_model_and_path(role: str, model: ModelTrainEvalConfig, fileroot): + """ + Check if model path exists locally, download from HuggingFace Hub if not. + + Args: + role: The role identifier for the model + model: ModelTrainEvalConfig object containing model configuration + + Returns: + str: The local path to the model (either existing or newly downloaded) + + Raises: + Exception: If download fails or other errors occur + """ + if os.path.exists(model.path): + return + + logger.info(f"Model path `{model.path}` for `{role}` does not exist locally.") + + # Extract model name from path or use the path as model identifier + # Adjust this logic based on how your ModelTrainEvalConfig stores the model identifier + model_name = model.path + + # First, check if model exists in HuggingFace cache + logger.info(f"Checking HuggingFace cache for model: {model_name}") + cached_path = _check_huggingface_cache(model_name) + if cached_path: + logger.info(f"Found model in HuggingFace cache: {cached_path}") + model.path = cached_path + return + + # If not in cache, download to /models/ directory + logger.info(f"Model not found in cache. 
Downloading from HuggingFace Hub...") + target_path = os.path.join(fileroot, "models", model_name.replace("/", "--")) + if not os.path.exists(target_path): + snapshot_download( + repo_id=model_name, + local_dir=target_path, # Replace '/' to avoid path issues + local_dir_use_symlinks=False, ) + logger.info(f"Model downloaded successfully to: {target_path}") + # Update the model object's path to point to the downloaded location + model.path = target_path + + +def _check_huggingface_cache(model_name: str) -> Optional[str]: + """ + Check if a model exists in the HuggingFace cache. + + Args: + model_name: The HuggingFace model identifier (e.g., 'bert-base-uncased') + + Returns: + Optional[str]: Path to cached model if found, None otherwise + """ + # Try to find the model files in cache + # We'll check for common files that should exist in a model repo + common_files = [ + "config.json", + "pytorch_model.bin", + "model.safetensors", + "tf_model.h5", + ] + + cached_path = None + for filename in common_files: + file_path = try_to_load_from_cache( + repo_id=model_name, filename=filename, repo_type="model" + ) + if file_path is not None: + # Get the directory containing the cached file + cached_path = os.path.dirname(file_path) + break + + # Verify the cached directory exists and contains model files + if cached_path and os.path.exists(cached_path): + # Double-check that it's a valid model directory + if any(os.path.exists(os.path.join(cached_path, f)) for f in common_files): + return cached_path + + return None + def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation): try: diff --git a/realhf/impl/dataset/math_parser.py b/realhf/impl/dataset/math_parser.py index 58e4d43..49945a5 100644 --- a/realhf/impl/dataset/math_parser.py +++ b/realhf/impl/dataset/math_parser.py @@ -4,6 +4,7 @@ import json import os import signal import subprocess +import sys import uuid from typing import * @@ -54,6 +55,10 @@ def parse_line(id2info, prompt_str, generated, query_id): f.write(json.dumps({"answer": generated, "solution": cur_solution}) + "\n") venv_python = "/sympy/bin/python3" + if not os.path.exists(venv_python): + venv_python = "sympy/bin/python3" + if not os.path.exists(venv_python): + venv_python = sys.executable # logger.info(f"math verify working dir: `{os.getcwd()}`") pro = subprocess.Popen( " ".join( @@ -67,6 +72,7 @@ def parse_line(id2info, prompt_str, generated, query_id): shell=True, preexec_fn=os.setsid, stdout=subprocess.DEVNULL, + stderr=sys.stdout, ) pro.wait() try: @@ -131,6 +137,10 @@ def parse_lines_in_parallel( all_query_indices.append(query_indices) venv_python = "/sympy/bin/python3" + if not os.path.exists(venv_python): + venv_python = "sympy/bin/python3" + if not os.path.exists(venv_python): + venv_python = sys.executable # logger.info(f"math verify working dir: `{os.getcwd()}`") procs = [] for tmp_id in tmp_ids: @@ -148,6 +158,7 @@ def parse_lines_in_parallel( shell=True, preexec_fn=os.setsid, stdout=subprocess.DEVNULL, + stderr=sys.stdout, ) procs.append(pro) for pro in procs: @@ -183,7 +194,7 @@ def parse_lines_in_parallel( if __name__ == "__main__": sample = { - "answers": ["-\\frac{2}{3}"], + "answers": ["\\boxed{-\\frac{2}{3}}"], "solutions": [ "1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. 
**Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]", "\\boxed{-\\frac{2}{3}}", diff --git a/realhf/impl/model/backend/inference.py b/realhf/impl/model/backend/inference.py index 3a68d0a..1cc16f4 100644 --- a/realhf/impl/model/backend/inference.py +++ b/realhf/impl/model/backend/inference.py @@ -56,7 +56,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine): unique_params = params_tensor[1] if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"CONFIG: default_train_mbs={self.pipe_runner.default_train_mbs} " f"default_inf_mbs={self.pipe_runner.default_inf_mbs} " f"num_layers(this stage)={self.module.num_layers} " @@ -65,7 +65,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine): f"tp_size={constants.tensor_parallel_world_size()} " ) if constants.data_parallel_rank() == 0: - logger.info( + logger.debug( f"rank={constants.parallelism_rank()} " f"stage={constants.pipe_parallel_rank()} " f"layers={self.module.num_layers} " @@ -105,7 +105,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine): ) mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec) if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, " f"#tokens: {input_.data['packed_input_ids'].shape[0]}, " f"pp_size={constants.pipe_parallel_world_size()}, " @@ -163,7 +163,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine): # mini-batches, so we split mini-batches in the outer loop. 
mb_inputs, *_ = input_.split(mb_spec) if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, " f"#tokens: {input_.data['packed_input_ids'].shape[0]}, " f"pp_size={constants.pipe_parallel_world_size()}, " diff --git a/realhf/impl/model/backend/pipe_runner.py b/realhf/impl/model/backend/pipe_runner.py index dcc30b0..67ddeeb 100644 --- a/realhf/impl/model/backend/pipe_runner.py +++ b/realhf/impl/model/backend/pipe_runner.py @@ -809,7 +809,7 @@ class PipelineRunner: ) mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec) if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, " f"#tokens: {input_.data['packed_input_ids'].shape[0]}, " f"pp_size={constants.pipe_parallel_world_size()}, " @@ -886,7 +886,7 @@ class PipelineRunner: mb_spec = MicroBatchSpec(n_mbs=self.default_inf_mbs) mb_inputs, *_ = input_.split(mb_spec) if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, " f"#tokens: {input_.data['packed_input_ids'].shape[0]}, " f"pp_size={constants.pipe_parallel_world_size()}, " @@ -1013,7 +1013,7 @@ class PipelineRunner: dist.all_reduce(total_loss_weight, group=constants.data_parallel_group()) if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, " f"#tokens: {input_.data['packed_input_ids'].shape[0]}, " f"pp_size={constants.pipe_parallel_world_size()}, " diff --git a/realhf/system/buffer.py b/realhf/system/buffer.py index 16d695a..a3cc43b 100644 --- a/realhf/system/buffer.py +++ b/realhf/system/buffer.py @@ -63,7 +63,7 @@ class _TensorDictSequenceBuffer: tuple(sorted(self.__storage[idx].sample.keys)), ) if ck not in BUFFER_KEY_WARN_CACHE: - logger.warning( + logger.debug( f"Unexpected keys in the sample. Expected keys from all MFCDef: {self.__keys}. " f"Keys in the current sample: {self.__storage[idx].sample.keys}" ) @@ -300,7 +300,7 @@ class AsyncIOSequenceBuffer: ) can_do_rpcs = {rpc.name: self._can_do_rpc(rpc) for rpc in self.rpcs} - logger.info(f"After putting batch, can do RPCs? {can_do_rpcs}.") + logger.debug(f"After putting batch, can do RPCs? {can_do_rpcs}.") self._lock.notify(len(self._rpc_names)) return indices @@ -348,7 +348,7 @@ class AsyncIOSequenceBuffer: async def get_batch_for_rpc( self, rpc: dfg.MFCDef ) -> Tuple[List[int], SequenceSample]: - logger.info( + logger.debug( f"MFC {rpc.name} is waiting for its input keys: {rpc.input_keys}..." 
) rpc_idx = self._rpc_names.index(rpc.name) @@ -359,7 +359,7 @@ class AsyncIOSequenceBuffer: while not self._can_do_rpc(rpc): await self._lock.wait() - logger.info(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!") + logger.debug(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!") self._assert_valid_indicator() ready_indices = np.nonzero( diff --git a/realhf/system/controller.py b/realhf/system/controller.py index 7757fea..678fa9e 100644 --- a/realhf/system/controller.py +++ b/realhf/system/controller.py @@ -671,14 +671,8 @@ class RayController: env_vars = constants.get_env_vars( REAL_MODE=os.environ.get("REAL_MODE", ""), - CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""), REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""), REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""), - FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""), - REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"), - REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"), - REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"), - REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", ""), ) runtime_env = { "env_vars": env_vars, diff --git a/realhf/system/function_executor.py b/realhf/system/function_executor.py index f62d3b2..8add93c 100644 --- a/realhf/system/function_executor.py +++ b/realhf/system/function_executor.py @@ -86,7 +86,7 @@ class FunctionExecutor: for level, w in enumerate(self.topo_widths): for _ in range(w): await self.ctrl.topo_level_count.get() - logger.info(f"DFG level {level}. Flushing {w} function calls.") + logger.debug(f"DFG level {level}. Flushing {w} function calls.") self.stream.request( handlers=list(range(self.n_model_workers)), handle_type="flush" ) @@ -203,7 +203,7 @@ class FunctionExecutor: await asyncio.sleep(1) async def execute_step(self): - logger.info("Waiting for the finish of the execution graph.") + logger.debug("Waiting for the finish of the execution graph.") loop = asyncio.get_event_loop() tasks = [ diff --git a/realhf/system/gserver_manager.py b/realhf/system/gserver_manager.py index fcffc47..cc6811d 100644 --- a/realhf/system/gserver_manager.py +++ b/realhf/system/gserver_manager.py @@ -183,7 +183,7 @@ class GserverManager(AsyncWorker): success = res["success"] if success: if "num_paused_requests" in res: - logger.info( + logger.debug( f"{res['num_paused_requests']} requests are interrupted " f"during updateing weights for server {server_index}: {server_url}" ) @@ -419,7 +419,7 @@ class GserverManager(AsyncWorker): if has_capacity and not is_staled: self.rollout_stat.submitted += 1 self.rollout_stat.running += 1 - logger.info( + logger.debug( f"Allocate rollout for qid {req.qid}. " f"Submitted: {self.rollout_stat.submitted}, " f"running: {self.rollout_stat.running}, " @@ -458,7 +458,7 @@ class GserverManager(AsyncWorker): self.rollout_stat.running -= 1 if resp_meta.accepted: self.rollout_stat.accepted += 1 - logger.info( + logger.debug( f"Finish rollout for qid {resp_meta.qid}. 
" f"Running: {self.rollout_stat.running}, " f"running: {self.rollout_stat.running}, " diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index ac419ee..9a80c7c 100644 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -62,7 +62,7 @@ class MasterWorker(worker_base.AsyncWorker): self.__topo_widths = [] for generation in nx.topological_generations(self.__model_rpcs[0]._G): self.__topo_widths.append(len(generation)) - logger.info("Topological widths: " + str(self.__topo_widths)) + logger.debug("Topological widths: " + str(self.__topo_widths)) self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs)) self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs)) @@ -169,7 +169,7 @@ class MasterWorker(worker_base.AsyncWorker): def initialize_models(self): # Initialize model backends. model_names = list(self.__model_topos.keys()) - self.logger.info(f"Initialize model backends with order: {model_names}.") + self.logger.debug(f"Initialize model backends with order: {model_names}.") train_rpcs = list( filter( lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP, @@ -318,7 +318,7 @@ class MasterWorker(worker_base.AsyncWorker): self.__summary_writer = SummaryWriter(log_dir=self.tensorboard_config.path) # Create coroutines for model RPCs. - logger.info(f"Creating asyncio coroutines...") + logger.debug(f"Creating asyncio coroutines...") self.func_executor = FunctionExecutor( rpcs=self.__model_rpcs, msid2mwid=self.config.msid2mwid, @@ -334,7 +334,7 @@ class MasterWorker(worker_base.AsyncWorker): self.func_executor.data_loading_dp_idx = ( self.__recover_info.data_loading_dp_idx ) - logger.info(f"Coroutines created. The master worker is ready to run.") + logger.debug(f"Coroutines created. The master worker is ready to run.") self.__initialized = True self._train_start_time = time.perf_counter() diff --git a/realhf/system/model_function_call.py b/realhf/system/model_function_call.py index 8a5c193..659dc6e 100644 --- a/realhf/system/model_function_call.py +++ b/realhf/system/model_function_call.py @@ -292,7 +292,7 @@ class ModelFunctionCall: samples, forward_indices, _ = sample.split( data_api.MicroBatchSpec(n_mbs=self.dp_size) ) - blogger.info( + blogger.debug( f"DP split (DP size {self.dp_size}) for RPC {self.rpc.name}: " f"#seqs: {[s.bs for s in samples]}, " f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}" @@ -353,7 +353,7 @@ class ModelFunctionCall: keys=rpc.input_keys, pattern=pattern, ) - blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.") + blogger.debug(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.") # Update storage tracker for transferred data. 
if pattern == "bcast": @@ -450,7 +450,7 @@ class ModelFunctionCall: elif isinstance(res, list): for j, r in enumerate(res): logger.info( - f"RPC name {rpc.name} returns ({j}/{len(res)})\n{data_api.tabulate_stats(r)}" + f"RPC name {rpc.name} returns ({j + 1}/{len(res)})\n{data_api.tabulate_stats(r)}" ) offset = len(res) * ctrl.step_info.global_step logging.log_wandb_tensorboard( diff --git a/realhf/system/model_worker.py b/realhf/system/model_worker.py index f9253e1..bd2e99b 100644 --- a/realhf/system/model_worker.py +++ b/realhf/system/model_worker.py @@ -510,7 +510,7 @@ class ModelWorker(worker_base.Worker): elif hook == "save": self.__save_model(hook_data) elif hook == "evaluate": - logger.info(f"hook_data: {hook_data}") + logger.debug(f"hook_data: {hook_data}") with constants.model_scope(hook_data["model_name"]): ret = self._interface.evaluate(self._model, self._eval_dataloader) if ret: @@ -557,7 +557,7 @@ class ModelWorker(worker_base.Worker): assert not handled and res is None with constants.model_scope(request.handler.model_name): if constants.parallelism_rank() == 0: - logger.info( + logger.debug( f"Model `{request.handler.model_name}` handling " f"{len(request.pre_hooks)} pre-hook for request `{request.handle_name}`. " f"The current hook is `{request.pre_hooks[0]}`. " @@ -714,7 +714,7 @@ class ModelWorker(worker_base.Worker): blogger.debug( f"Model worker {self.__worker_index} cleared cache in {et-st:.4f}s. " ) - logger.info( + logger.debug( "Get clear_data_cache, dump cuda tmark. " f"Remaining data in local storage: {self.data_manager.storage_size()}. " ) @@ -742,7 +742,7 @@ class ModelWorker(worker_base.Worker): model_name = request.handler.model_name with constants.model_scope(model_name): if constants.parallelism_rank() == 0: - blogger.info( + blogger.debug( f"Model #{request.handler.model_name}# " f"starts handling request *{request.handle_name}*." ) @@ -789,7 +789,7 @@ class ModelWorker(worker_base.Worker): and self._is_dp_head and self._dp_rank == 0 ): - blogger.info( + blogger.debug( f"Model #{request.handler.model_name}# handle " f"request *{request.handle_name}*" f" in ${time.perf_counter() - tik:.4f}$s" diff --git a/realhf/system/rollout_worker.py b/realhf/system/rollout_worker.py index 5b575ce..045bf06 100644 --- a/realhf/system/rollout_worker.py +++ b/realhf/system/rollout_worker.py @@ -209,7 +209,9 @@ class RolloutWorker(AsyncWorker): resp.raise_for_status() res = await resp.json() if not res["success"]: - logger.info(f"Cannot allocate new rollout because: {res['reason']}") + logger.debug( + f"Cannot allocate new rollout because: {res['reason']}" + ) return res["success"] async def _poll_async(self): @@ -272,7 +274,7 @@ class RolloutWorker(AsyncWorker): self.rollout_stat.submitted += 1 self.rollout_stat.running += 1 - logger.info( + logger.debug( f"Submit a new rollout for qid {qid}. " f"Submit: {self.rollout_stat.submitted}, " f"running: {self.rollout_stat.running}, " @@ -323,7 +325,7 @@ class RolloutWorker(AsyncWorker): ) as resp: resp.raise_for_status() assert (await resp.json())["success"] - logger.info( + logger.debug( f"Finish rollout for qid {qid}. 
" f"Submit: {self.rollout_stat.submitted}, " f"running: {self.rollout_stat.running}, " diff --git a/training/configs/async-ppo.yaml b/training/configs/async-ppo.yaml new file mode 100644 index 0000000..5e098ec --- /dev/null +++ b/training/configs/async-ppo.yaml @@ -0,0 +1,198 @@ +# Basic experiment info +experiment_name: async-ppo +trial_name: my-trial +seed: 1 +mode: ray +metric_discovery_port: 17997 +wandb: + mode: disabled + entity: null + project: null + name: null + job_type: null + group: null + notes: null + tags: null + config: null +tensorboard: + path: null +recover_mode: auto +recover_retries: 10 +recover_after: 10 + +exp_ctrl: + total_train_epochs: 5 + save_freq_epochs: 1 + save_freq_steps: null + save_freq_secs: null + ckpt_freq_epochs: null + ckpt_freq_steps: null + ckpt_freq_secs: 600 + eval_freq_epochs: null + eval_freq_steps: null + eval_freq_secs: null + benchmark_steps: null + benchmark_n_seqs: null +torch_cache_mysophobia: true +cache_clear_freq: 1 + +# Asynchronous RL options +new_tokens_per_chunk: 10000000000 +max_head_offpolicyness: 4 +n_rollout_workers: null +max_concurrent_rollouts: null +flush_request_timeout: 300 + +# Asynchronous worker resources +cpus_per_generation_server: 4 +mem_per_generation_server: 61440 +cpus_per_gserver_manager: 4 +mem_per_gserver_manager: 10240 +cpus_per_rollout_worker: 4 +mem_per_rollout_worker: 20480 + +# Allocation and parallelism +allocation_mode: sglang.d4p1m1+d2p2m1 +n_nodes: 1 +n_gpus_per_node: 8 + +# Cluster configuration +ray_temp_path: /tmp/ray +cluster: + fileroot: /storage/ray/experiments + n_nodes: 32 + n_gpus_per_node: 8 + +# Model +actor: + type: + _class: qwen3 + path: /storage/openpsi/models/Qwen3-1.7B/ + init_from_scratch: false + gradient_checkpointing: true + bf16: false + optimizer: + type: adam + lr: 2.0e-05 + weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + eps: 1.0e-05 + min_lr_ratio: 0.0 + lr_scheduler_type: constant + warmup_steps_proportion: 0.001 + initial_loss_scale: 4294967296.0 + min_loss_scale: 1.0 + loss_scale_window: 5.0 + hysteresis: 2 + gradient_clipping: 1.0 + megatron: + ddp: + grad_reduce_in_fp32: true + overlap_grad_reduce: true + use_distributed_optimizer: true + sglang: + disable_cuda_graph: false + disable_radix_cache: false + disable_cuda_graph_padding: false + enable_nccl_nvls: false + disable_outlines_disk_cache: false + disable_custom_all_reduce: false + disable_overlap_schedule: false + enable_mixed_chunk: false + enable_torch_compile: false + torch_compile_max_bs: 32 + cuda_graph_max_bs: null + cuda_graph_bs: null + torchao_config: '' + enable_nan_detection: false + enable_p2p_check: false + triton_attention_reduce_in_fp32: false + triton_attention_num_kv_splits: 8 + num_continuous_decode_steps: 1 + enable_memory_saver: false + allow_auto_truncate: false + attention_backend: flashinfer + sampling_backend: null + context_length: 32768 + mem_fraction_static: 0.8 + max_running_requests: null + chunked_prefill_size: -1 + max_prefill_tokens: 32768 + schedule_policy: lpm + schedule_conservativeness: 1.0 + cpu_offload_gb: 0 + dtype: float16 + kv_cache_dtype: auto + log_level: warning + log_level_http: warning + log_requests: false + log_requests_level: 0 + show_time_cost: false + enable_metrics: true + decode_log_interval: 1 +ref: + type: + _class: qwen3 + path: /storage/openpsi/models/Qwen3-1.7B/ + init_from_scratch: false + bf16: false +actor_train: + mb_spec: + max_tokens_per_mb: 30720 +ref_inf: + mb_spec: + max_tokens_per_mb: 30720 +actor_inf: + mb_spec: + max_tokens_per_mb: 30720 + +# 
Dataset +shuffle_dataset: true +dataset: + path: /storage/datasets/boba_106k_0319.jsonl + max_prompt_len: 1024 + train_bs_n_seqs: 512 + +# Algorithm +group_size: 16 +mask_too_long: false +group_adv_norm: false +rw_type: sparse +success_rate_ub: 1.0 +success_rate_lb: 0.0 +ppo: + gen: + n: 1 + max_new_tokens: 27648 + min_new_tokens: 0 + greedy: false + top_p: 1.0 + top_k: 1000000 + temperature: 1.0 + ppo_n_minibatches: 4 + eps_clip: 0.2 + c_clip: null + value_eps_clip: 0.2 + early_stop_imp_ratio: 5.0 + actor_sample_reuse: 1 + critic_sample_reuse: 1 + max_reward_clip: 20.0 + reward_output_scaling: 5.0 + reward_output_bias: 0.0 + fuse_rew_ref: true + discount: 1.0 + gae_lambda: 1.0 + adv_norm: true + kl_ctl: 0.0 + use_adaptive_kl_ctl: false + disable_value: true + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: null + +# worker resources +cpus_per_master_worker: 4 +mem_per_master_worker: 20000 +cpus_per_model_worker: 4 +mem_per_model_worker: 90000 diff --git a/training/configs/async-ppo/async-ppo-1.7b-gpu32.yaml b/training/configs/async-ppo/async-ppo-1.7b-gpu32.yaml deleted file mode 100644 index 72eb287..0000000 --- a/training/configs/async-ppo/async-ppo-1.7b-gpu32.yaml +++ /dev/null @@ -1,65 +0,0 @@ -max_head_offpolicyness: 4 -experiment_name: async-ppo-1.7b-gpu32 -trial_name: my-trial -mode: ray -cluster: - fileroot: /storage/ray/experiments -wandb: - mode: disabled -recover_mode: auto -recover_retries: 10 -allocation_mode: sglang.d24p1m1+d4p2m1 -n_nodes: 4 -n_gpus_per_node: 8 -cache_clear_freq: 1 -exp_ctrl: - total_train_epochs: 5 - save_freq_epochs: 1 - ckpt_freq_secs: 600 -torch_cache_mysophobia: true -dataset: - path: /storage/datasets/boba_106k_0319.jsonl - max_prompt_len: 1024 - train_bs_n_seqs: 512 -group_size: 16 -group_adv_norm: false -actor: - type: - _class: qwen3 - path: /storage/openpsi/models/Qwen3-1.7B/ - optimizer: - lr: 2e-05 - lr_scheduler_type: constant - eps: 1e-5 - warmup_steps_proportion: 0.001 - hysteresis: 2 - sglang: - mem_fraction_static: 0.8 -actor_train: - mb_spec: - max_tokens_per_mb: 30720 -actor_gen: - mb_spec: - max_tokens_per_mb: 30720 -actor_inf: - mb_spec: - max_tokens_per_mb: 30720 -ppo: - gen: - max_new_tokens: 27648 - min_new_tokens: 0 - top_p: 1.0 - top_k: 1000000 - temperature: 1.0 - ppo_n_minibatches: 4 - kl_ctl: 0.0 - discount: 1.0 - value_eps_clip: 0.2 - disable_value: true - reward_output_scaling: 5 - reward_output_bias: 0.0 - adv_norm: true - value_norm: true - recompute_logprob: true - use_decoupled_loss: true - diff --git a/training/configs/async-ppo/async-ppo-1.7b-gpu8.yaml b/training/configs/async-ppo/async-ppo-1.7b-gpu8.yaml deleted file mode 100644 index 418b586..0000000 --- a/training/configs/async-ppo/async-ppo-1.7b-gpu8.yaml +++ /dev/null @@ -1,65 +0,0 @@ -max_head_offpolicyness: 4 -experiment_name: async-ppo-1.7b-gpu8 -trial_name: my-trial -mode: ray -cluster: - fileroot: /storage/ray/experiments -wandb: - mode: disabled -recover_mode: auto -recover_retries: 10 -allocation_mode: sglang.d4p1m1+d2p2m1 -n_nodes: 1 -n_gpus_per_node: 8 -cache_clear_freq: 1 -exp_ctrl: - total_train_epochs: 5 - save_freq_epochs: 1 - ckpt_freq_secs: 600 -torch_cache_mysophobia: true -dataset: - path: /storage/datasets/boba_106k_0319.jsonl - max_prompt_len: 1024 - train_bs_n_seqs: 512 -group_size: 16 -group_adv_norm: false -actor: - type: - _class: qwen3 - path: /storage/openpsi/models/Qwen3-1.7B/ - optimizer: - lr: 2e-05 - lr_scheduler_type: constant - eps: 1e-5 - warmup_steps_proportion: 0.001 - hysteresis: 2 - sglang: - 
mem_fraction_static: 0.8 -actor_train: - mb_spec: - max_tokens_per_mb: 30720 -actor_gen: - mb_spec: - max_tokens_per_mb: 30720 -actor_inf: - mb_spec: - max_tokens_per_mb: 30720 -ppo: - gen: - max_new_tokens: 27648 - min_new_tokens: 0 - top_p: 1.0 - top_k: 1000000 - temperature: 1.0 - ppo_n_minibatches: 4 - kl_ctl: 0.0 - discount: 1.0 - value_eps_clip: 0.2 - disable_value: true - reward_output_scaling: 5 - reward_output_bias: 0.0 - adv_norm: true - value_norm: true - recompute_logprob: true - use_decoupled_loss: true - diff --git a/training/configs/ppo/ppo-1.5b-gpu32.yaml b/training/configs/ppo/ppo-1.5b-gpu32.yaml deleted file mode 100644 index 777a407..0000000 --- a/training/configs/ppo/ppo-1.5b-gpu32.yaml +++ /dev/null @@ -1,62 +0,0 @@ -experiment_name: ppo-1.7b-gpu32 -trial_name: my-trial -mode: ray -cluster: - fileroot: /storage/ray/experiments -wandb: - mode: disabled -recover_mode: auto -recover_retries: 10 -allocation_mode: sglang.d16p1m1+d8p2m1 -n_nodes: 4 -n_gpus_per_node: 8 -cache_clear_freq: 1 -exp_ctrl: - total_train_epochs: 5 - save_freq_epochs: 1 - ckpt_freq_secs: 600 -torch_cache_mysophobia: true -dataset: - path: /storage/datasets/boba_106k_0319.jsonl - max_prompt_len: 1024 - train_bs_n_seqs: 512 -group_size: 16 -group_adv_norm: false -actor: - type: - _class: qwen3 - path: /storage/openpsi/models/Qwen3-1.7B/ - optimizer: - lr: 2e-05 - lr_scheduler_type: constant - eps: 1e-5 - warmup_steps_proportion: 0.001 - hysteresis: 2 - sglang: - mem_fraction_static: 0.8 -actor_train: - mb_spec: - max_tokens_per_mb: 30720 -actor_gen: - mb_spec: - max_tokens_per_mb: 30720 -actor_inf: - mb_spec: - max_tokens_per_mb: 30720 -ppo: - gen: - max_new_tokens: 27648 - min_new_tokens: 0 - top_p: 1.0 - top_k: 1000000 - temperature: 1.0 - ppo_n_minibatches: 4 - kl_ctl: 0.0 - discount: 1.0 - value_eps_clip: 0.2 - disable_value: true - reward_output_scaling: 5 - reward_output_bias: 0.0 - adv_norm: true - value_norm: true - diff --git a/training/configs/sft.yaml b/training/configs/sft.yaml new file mode 100644 index 0000000..822369b --- /dev/null +++ b/training/configs/sft.yaml @@ -0,0 +1,97 @@ +# Basic experiment info +experiment_name: sft +trial_name: my-trial +seed: 1 +mode: ray +metric_discovery_port: 17997 +wandb: + mode: disabled + entity: null + project: null + name: null + job_type: null + group: null + notes: null + tags: null + config: null +tensorboard: + path: null +recover_mode: auto +recover_retries: 10 +recover_after: 10 + +exp_ctrl: + total_train_epochs: 5 + save_freq_epochs: 1 + save_freq_steps: null + save_freq_secs: null + ckpt_freq_epochs: null + ckpt_freq_steps: null + ckpt_freq_secs: 600 + eval_freq_epochs: null + eval_freq_steps: null + eval_freq_secs: null + benchmark_steps: null + benchmark_n_seqs: null +torch_cache_mysophobia: true +cache_clear_freq: 1 + +# Allocation and parallelism +allocation_mode: d4p2m1 +n_nodes: 1 +n_gpus_per_node: 8 + +# Cluster configuration +ray_temp_path: /tmp/ray +cluster: + fileroot: /storage/ray/experiments + n_nodes: 32 + n_gpus_per_node: 8 + +# Model +model: + type: + _class: qwen2 + path: /storage/models/DeepSeek-R1-Distill-Qwen-7B + init_from_scratch: false + gradient_checkpointing: true + bf16: false + optimizer: + type: adam + lr: 1.0e-05 + weight_decay: 0.1 + beta1: 0.9 + beta2: 0.95 + eps: 1.0e-05 + min_lr_ratio: 0.0 + lr_scheduler_type: constant + warmup_steps_proportion: 0.03 + offload: false + initial_loss_scale: 262144.0 + min_loss_scale: 1.0 + loss_scale_window: 10.0 + hysteresis: 2 + gradient_clipping: 1.0 + megatron: + ddp: + 
grad_reduce_in_fp32: true + overlap_grad_reduce: true + use_distributed_optimizer: true +allocation: + mb_spec: + n_mbs: 1 + max_tokens_per_mb: 32768 + +# Dataset +dataset: + train_path: /storage/datasets/boba-sft_200_0319.jsonl + valid_path: /storage/datasets/boba-sft_200_0319.jsonl + max_seqlen: 32768 + train_bs_n_seqs: 16 + valid_bs_n_seqs: 16 + +# worker resources +cpus_per_master_worker: 4 +mem_per_master_worker: 20000 +cpus_per_model_worker: 4 +mem_per_model_worker: 90000 \ No newline at end of file diff --git a/training/configs/sft/sft-7b-gpu8.yaml b/training/configs/sft/sft-7b-gpu8.yaml deleted file mode 100644 index f35824e..0000000 --- a/training/configs/sft/sft-7b-gpu8.yaml +++ /dev/null @@ -1,40 +0,0 @@ -experiment_name: sft-7b-gpu8 -trial_name: my-trial -mode: ray -wandb: - mode: disabled -recover_mode: auto -recover_retries: 10 -allocation_mode: d2p4m1 -n_nodes: 1 -n_gpus_per_node: 8 -exp_ctrl: - total_train_epochs: 200 - save_freq_epochs: 1 - ckpt_freq_secs: 600 -torch_cache_mysophobia: true -dataset: - train_path: /storage/datasets/boba-sft_200_0319.jsonl - valid_path: /storage/datasets/boba-sft_200_0319.jsonl - max_seqlen: 32768 - train_bs_n_seqs: 16 - valid_bs_n_seqs: 16 -model: - type: - _class: qwen2 - path: /storage/models/DeepSeek-R1-Distill-Qwen-7B - optimizer: - type: adam - lr_scheduler_type: constant - lr: 1e-5 - warmup_steps_proportion: 0.03 - initial_loss_scale: 262144.0 - loss_scale_window: 10 - hysteresis: 2 - weight_decay: 0.1 - eps: 1e-5 - bf16: true -allocation: - mb_spec: - max_tokens_per_mb: 32768 - diff --git a/training/configs/sync-ppo.yaml b/training/configs/sync-ppo.yaml new file mode 100644 index 0000000..88ae35f --- /dev/null +++ b/training/configs/sync-ppo.yaml @@ -0,0 +1,183 @@ +# Basic experiment info +experiment_name: sync-ppo +trial_name: my-trial +seed: 1 +mode: ray +metric_discovery_port: 17997 +wandb: + mode: disabled + entity: null + project: null + name: null + job_type: null + group: null + notes: null + tags: null + config: null +tensorboard: + path: null +recover_mode: auto +recover_retries: 10 +recover_after: 10 + +exp_ctrl: + total_train_epochs: 5 + save_freq_epochs: 1 + save_freq_steps: null + save_freq_secs: null + ckpt_freq_epochs: null + ckpt_freq_steps: null + ckpt_freq_secs: 600 + eval_freq_epochs: null + eval_freq_steps: null + eval_freq_secs: null + benchmark_steps: null + benchmark_n_seqs: null +torch_cache_mysophobia: true +cache_clear_freq: 1 + +# Allocation and parallelism +allocation_mode: sglang.d4p1m1+d2p2m1 +n_nodes: 1 +n_gpus_per_node: 8 + +# Cluster configuration +ray_temp_path: /tmp/ray +cluster: + fileroot: /storage/ray/experiments + n_nodes: 32 + n_gpus_per_node: 8 + +# Model +actor: + type: + _class: qwen3 + path: /storage/openpsi/models/Qwen3-1.7B/ + init_from_scratch: false + gradient_checkpointing: true + bf16: false + optimizer: + type: adam + lr: 2.0e-05 + weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + eps: 1.0e-05 + min_lr_ratio: 0.0 + lr_scheduler_type: constant + warmup_steps_proportion: 0.001 + initial_loss_scale: 4294967296.0 + min_loss_scale: 1.0 + loss_scale_window: 5.0 + hysteresis: 2 + gradient_clipping: 1.0 + megatron: + ddp: + grad_reduce_in_fp32: true + overlap_grad_reduce: true + use_distributed_optimizer: true + sglang: + disable_cuda_graph: false + disable_radix_cache: false + disable_cuda_graph_padding: false + enable_nccl_nvls: false + disable_outlines_disk_cache: false + disable_custom_all_reduce: false + disable_overlap_schedule: false + enable_mixed_chunk: false + enable_torch_compile: 
false + torch_compile_max_bs: 32 + cuda_graph_max_bs: null + cuda_graph_bs: null + torchao_config: '' + enable_nan_detection: false + enable_p2p_check: false + triton_attention_reduce_in_fp32: false + triton_attention_num_kv_splits: 8 + num_continuous_decode_steps: 1 + enable_memory_saver: false + allow_auto_truncate: false + attention_backend: flashinfer + sampling_backend: null + context_length: 32768 + mem_fraction_static: 0.8 + max_running_requests: null + chunked_prefill_size: -1 + max_prefill_tokens: 32768 + schedule_policy: lpm + schedule_conservativeness: 1.0 + cpu_offload_gb: 0 + dtype: float16 + kv_cache_dtype: auto + log_level: warning + log_level_http: warning + log_requests: false + log_requests_level: 0 + show_time_cost: false + enable_metrics: true + decode_log_interval: 1 +ref: + type: + _class: qwen3 + path: /storage/openpsi/models/Qwen3-1.7B/ + init_from_scratch: false + bf16: false +actor_train: + mb_spec: + max_tokens_per_mb: 30720 +ref_inf: + mb_spec: + max_tokens_per_mb: 30720 +actor_inf: + mb_spec: + max_tokens_per_mb: 30720 + +# Dataset +shuffle_dataset: true +dataset: + path: /storage/datasets/boba_106k_0319.jsonl + max_prompt_len: 1024 + train_bs_n_seqs: 512 + +# Algorithm +group_size: 16 +mask_too_long: false +group_adv_norm: false +rw_type: sparse +success_rate_ub: 1.0 +success_rate_lb: 0.0 +ppo: + gen: + n: 1 + max_new_tokens: 27648 + min_new_tokens: 0 + greedy: false + top_p: 1.0 + top_k: 1000000 + temperature: 1.0 + ppo_n_minibatches: 4 + eps_clip: 0.2 + c_clip: null + value_eps_clip: 0.2 + early_stop_imp_ratio: 5.0 + actor_sample_reuse: 1 + critic_sample_reuse: 1 + max_reward_clip: 20.0 + reward_output_scaling: 5.0 + reward_output_bias: 0.0 + fuse_rew_ref: true + discount: 1.0 + gae_lambda: 1.0 + adv_norm: true + kl_ctl: 0.0 + use_adaptive_kl_ctl: false + disable_value: true + recompute_logprob: false + use_decoupled_loss: false + behav_imp_weight_cap: null + +# worker resources +cpus_per_master_worker: 4 +mem_per_master_worker: 20000 +cpus_per_model_worker: 4 +mem_per_model_worker: 90000 \ No newline at end of file diff --git a/training/main_async_ppo.py b/training/main_async_ppo.py index 3254f45..aa89065 100644 --- a/training/main_async_ppo.py +++ b/training/main_async_ppo.py @@ -61,5 +61,15 @@ def main_ppo_math(args): if __name__ == "__main__": - # Command: python3 training/main_async_ppo.py --config-name async-ppo-1.7b-gpu8 + import argparse + + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--help", action="store_true") + args = parser.parse_args() + if args.help: + from realhf.api.cli_args import print_config_help + + print_config_help(AsyncPPOMATHConfig()) + exit(0) + main_ppo_math() diff --git a/training/main_sft.py b/training/main_sft.py index 968c371..00240b5 100644 --- a/training/main_sft.py +++ b/training/main_sft.py @@ -13,7 +13,7 @@ from realhf.experiments.common.sft_exp import SFTConfig from training.utils import run_experiment -@hydra.main(version_base=None, config_path="configs/sft") +@hydra.main(version_base=None, config_path="configs", config_name="sft") def main(args): # NOTE: we import logging here to avoid hydra logging overwrite import realhf.base.logging as logging @@ -61,5 +61,14 @@ def main(args): if __name__ == "__main__": - # Command: python3 training/main_sft.py --config-name sft-7b-gpu8 + import argparse + + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--help", action="store_true") + args = parser.parse_args() + if args.help: + from realhf.api.cli_args import print_config_help 
+ + print_config_help(SFTConfig()) + exit(0) main() diff --git a/training/main_ppo.py b/training/main_sync_ppo.py similarity index 82% rename from training/main_ppo.py rename to training/main_sync_ppo.py index b2f1b6e..358f29b 100644 --- a/training/main_ppo.py +++ b/training/main_sync_ppo.py @@ -13,7 +13,7 @@ from realhf.experiments.common.ppo_math_exp import PPOMATHConfig from training.utils import run_experiment -@hydra.main(version_base=None, config_path="configs/ppo") +@hydra.main(version_base=None, config_path="configs", config_name="sync-ppo") def main(args): # NOTE: we import logging here to avoid hydra logging overwrite import realhf.base.logging as logging @@ -61,5 +61,14 @@ def main(args): if __name__ == "__main__": - # Command: python3 training/main_ppo.py --config-name ppo-1.5b-gpu32 + import argparse + + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--help", action="store_true") + args = parser.parse_args() + if args.help: + from realhf.api.cli_args import print_config_help + + print_config_help(PPOMATHConfig()) + exit(0) main() diff --git a/training/utils.py b/training/utils.py index 9632509..a282714 100644 --- a/training/utils.py +++ b/training/utils.py @@ -124,15 +124,9 @@ def _run_experiment(exp_cfg, expr_name, trial_name): # Initialize ray in the Ray cluster env_vars = constants.get_env_vars( WADNB_MODE=exp_cfg.wandb.mode, - REAL_MODE=os.environ.get("REAL_MODE", ""), - CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""), - REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""), - REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""), - FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""), - REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"), - REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"), - REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"), - REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"), + REAL_MODE="ray", + REAL_RECOVER_RUN="0", + REAL_SAVE_RECOVER_STATES="1", ) runtime_env = { "env_vars": env_vars,