PullRequest: 49 Enable Parallel Execution for Code, Math, and Reference Tasks

Merge branch async-ref-rew of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/49

Signed-off-by: 博惟 <bowei.fw@antgroup.com>
君末 2025-03-28 09:40:54 +08:00
commit f8afa97484
39 changed files with 1340 additions and 2137 deletions

View File

@ -54,22 +54,21 @@ Please refer to the [tutorial](/examples/README.md) under the `examples` directo
# Download the dataset
DATA_PATH=/storage/datasets/
cd $DATA_PATH
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
# Training in a Ray cluster with 16 nodes
# stage 1
MODEL_PATH=${path_to_DeepSeek-R1-Distill-Qwen-1.5B}
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 8192
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 8192
# stage 2
MODEL_PATH=${model_path_from_stage_1}
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 16384
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 16384
# stage 3
MODEL_PATH=${model_path_from_stage_2}
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 24000
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 24000
```
@ -134,11 +133,10 @@ Table 2. Evaluation on AIME 2024, AMC 2023, and MATH-500. The results are report
# Training in a Ray cluster with 16 nodes
DATA_PATH=/storage/datasets/
cd $DATA_PATH
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
MODEL_PATH=${path_to_Qwen2.5-7B}
bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_zero.jsonl $DATA_PATH/id2info.json 24000
bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_orz_zero.jsonl 24000
```

View File

@ -114,9 +114,8 @@ We provide a dataset for training. Download the dataset and place it in `/storag
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
```
## Model
@ -291,7 +290,6 @@ The descriptions of the important parameters are as follows:
+ `MODE`: Always `ray`; do not change it to other values when following this tutorial.
+ `BASE_MODEL_PATH`: The path of the model.
+ `DATA_PATH`: The path of the dataset jsonl file.
+ `REAL_MATH_METADATA_PATH`: Set it to the path of the math metadata json file; refer to troubleshooting.
+ `CLUSTER_SPEC_PATH`: Set it to the path of cluster_config.json
+ `n_nodes`: The number of nodes
@ -513,10 +511,5 @@ This error typically occurs during the vLLM generation phase and is another symp
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376); a sketch of one way to do this follows below.
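As a sketch (assuming a vLLM release that still ships the xformers backend), the backend can usually be switched through the `VLLM_ATTENTION_BACKEND` environment variable before launching the training script:
```bash
# Hedged example: ask vLLM to use the xformers attention backend.
# The variable name and accepted values may differ across vLLM versions.
export VLLM_ATTENTION_BACKEND=XFORMERS
# ...then launch the training script as usual.
```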
## Others
### How to train with other datasets
The dataset must be in `jsonl` format, with each entry containing two keys: `prompt` (a math problem) and `query_id` (a unique identifier for the problem). After preparing the dataset, update `REAL_MATH_METADATA_PATH` to point to the new dataset's metadata, a JSON file that records the solutions for each problem. The training code uses this metadata to determine whether the model solved a problem correctly. An illustrative entry is sketched below.
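For illustration only, a new entry could be appended to a hypothetical dataset file like this (the file name and `query_id` value are invented):
```bash
# Hypothetical example entry; use your own unique query_id values in practice.
cat >> my_math_dataset.jsonl <<'EOF'
{"prompt": "What is 7 * 8?", "query_id": "example-0001"}
EOF
```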

View File

@ -120,18 +120,8 @@ cd /storage/ray/
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
```
If you cannot access `huggingface.co`, you can also download from ModelScope:
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_r1_distilled.jsonl
wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_zero.jsonl
wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/id2info.json
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
```
## Model
@ -146,15 +136,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-D
You can also download with the Hugging Face CLI after installing `huggingface_hub` from PyPI; see the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.
If you cannot access `huggingface.co`, you can also download from ModelScope (make sure Git LFS is installed):
```
mkdir -p /storage/models
cd /storage/models
git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B.git
git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.git
```
## Start the Ray Cluster
@ -315,7 +296,6 @@ python3 -m realhf.apps.quickstart ppo-math --show-args
+ `MODE`: Always `ray`; do not change it to other values when following this tutorial.
+ `BASE_MODEL_PATH`: The path of the model.
+ `DATA_PATH`: The path of the dataset jsonl file.
+ `REAL_MATH_METADATA_PATH`: Set it to the path of the math metadata json file; refer to troubleshooting.
+ `CLUSTER_SPEC_PATH`: Set it to the path of cluster_config.json.
+ `n_nodes`: The number of nodes.
@ -530,8 +510,3 @@ ALL_PARAMS=(
+ Reduce the training batch size or the number of answers generated per prompt. Note that this lowers sample efficiency and extends training time.
+ [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376)
## Others
### How to train with other datasets
The dataset must be a `jsonl` file in which each entry contains two keys: `prompt` (a math problem) and `query_id` (a unique identifier for that problem). After preparing the dataset, update the contents of `REAL_MATH_METADATA_PATH` according to the problems in the new dataset. The metadata is a JSON file that records the answer, source, and solutions of each problem; the training code uses it to determine whether the model solved a problem correctly.

View File

@ -0,0 +1,240 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
"""
Dataset Toolkit - Process and validate code/math datasets with flexible input support
"""
import json
import argparse
import logging
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
def load_jsonl(file_path: str) -> List[Dict]:
"""Load JSONL file with validation"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise
def process_code_data(file_path: str) -> List[Dict]:
"""Process code dataset from JSONL file"""
if not file_path:
return []
raw_data = load_jsonl(file_path)
processed = []
for item in raw_data:
# Field extraction and transformation
input_output = json.loads(item["input_output"])
processed.append(
{
"task": "code",
"query_id": str(item["id"]),
"prompt": item["question"],
"input_output": json.dumps(
{
"inputs": input_output.get("inputs", []),
"outputs": input_output.get("outputs", []),
"fn_name": item.get("metadata", {}).get("fn_name", ""),
}
),
}
)
return processed
def process_math_data(file_path: str) -> List[Dict]:
"""Process math dataset from JSON/JSONL file"""
if not file_path:
return []
raw_data = load_jsonl(file_path)
processed = []
for item in raw_data:
processed.append(
{
"task": "math",
"query_id": str(item["query_id"]),
"prompt": item["prompt"],
"solutions": item["solutions"],
}
)
return processed
def validate_raw_code(item: Dict) -> Tuple[bool, List[str]]:
"""Validate raw code item structure"""
errors = []
required = {
"task": str,
"query_id": str,
"prompt": str,
"input_output": str,
}
code_input_output_required = {
"inputs": list,
"outputs": list,
#'fn_name': str
}
for field, typ in required.items():
if field not in item:
errors.append(f"Missing required field: {field}")
elif not isinstance(item[field], typ):
errors.append(f"Invalid type for {field}: expected {typ.__name__}")
input_output = json.loads(item["input_output"])
for io_field, io_typ in code_input_output_required.items():
if io_field not in input_output:
errors.append(f"Missing required field: {io_field} in input_output")
elif not isinstance(input_output[io_field], io_typ):
errors.append(
f"Invalid type for {io_field}: expected {io_typ.__name__} in input_output"
)
return (not errors, errors)
def validate_raw_math(item: Dict) -> Tuple[bool, List[str]]:
"""Validate raw math item structure"""
errors = []
required = {"query_id": str, "prompt": str, "solutions": list}
for field, typ in required.items():
if field not in item:
errors.append(f"Missing required field: {field}")
elif not isinstance(item[field], typ):
type_names = (
[t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
)
errors.append(f"Invalid type for {field}: expected {type_names}")
return (not errors, errors)
def validate_raw_item(item: Dict) -> Tuple[bool, List[str]]:
"""Validate raw code item structure"""
if item.get("task", "") == "math":
return validate_raw_math(item)
else:
return validate_raw_code(item)
def check_item_valid(raw_data: list):
if not raw_data:
return defaultdict(int)
logger.info(f"Validating code-math dataset...")
total_stats = defaultdict(int)
for index, item in enumerate(raw_data):
valid, errors = validate_raw_item(item)
if valid:
total_stats["valid"] += 1
else:
total_stats["invalid"] += 1
logger.warning(f"item {index}: {errors}")
return total_stats
def filter_shuffle(shuffle: bool, processed_data):
# Final validation and write output
valid_output = []
invalid_count = 0
for item in processed_data:
valid, errors = validate_raw_item(item)
if valid:
valid_output.append(item)
else:
invalid_count += 1
logger.error(
f'{item["task"]} item {item.get("query_id")} is invalid: {errors}'
)
if shuffle:
random.shuffle(valid_output)
return valid_output, invalid_count
def save_file(output_path: str, processed_data: list):
with open(output_path, "w") as f:
for item in processed_data:
f.write(json.dumps(item) + "\n")
def main():
parser = argparse.ArgumentParser(
description="Dataset Toolkit: Process and validate STEM datasets",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--code", help="Path to code dataset (JSONL)")
parser.add_argument("--math", help="Path to math dataset (JSONL)")
parser.add_argument("--output", help="Output file path (JSONL)")
parser.add_argument(
"--mode",
choices=["check", "process"],
default="check",
help="Operation mode: check raw data or process datasets",
)
parser.add_argument(
"--shuffle", action="store_true", help="Shuffle output data in process mode"
)
args = parser.parse_args()
if args.mode == "check":
# Validate raw data structure
code_data = load_jsonl(args.code) if args.code else []
math_data = load_jsonl(args.math) if args.math else []
total_stats = check_item_valid(code_data + math_data)
# Print validation summary
logger.info(
f"\nValidation Summary: {total_stats['valid']} valid, {total_stats['invalid']} invalid"
)
elif args.mode == "process":
# Process and merge datasets
if not args.output:
logger.error("Output file required in process mode")
return
processed_data = []
stats = defaultdict(int)
if args.code:
code_data = process_code_data(args.code)
logger.info(f"Loaded {len(code_data)} code items")
processed_data.extend(code_data)
stats["code"] = len(code_data)
if args.math:
math_data = process_math_data(args.math)
logger.info(f"Loaded {len(math_data)} math items")
processed_data.extend(math_data)
stats["math"] = len(math_data)
processed_data, invalid_count = filter_shuffle(args.shuffle, processed_data)
stats["invalid"] = invalid_count
save_file(args.output, processed_data)
logger.info("\nProcessing Complete:")
logger.info(f"Total items: {len(processed_data)}")
logger.info(f"Code items: {stats['code']}")
logger.info(f"Math items: {stats['math']}")
logger.info(f"Invalid items: {stats['invalid']}")
if __name__ == "__main__":
main()
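A possible way to run this toolkit is sketched below; the script name and dataset file names are placeholders, since the diff does not show the actual paths:
```bash
# Hypothetical usage: validate the raw files first, then merge them into one shuffled jsonl.
python3 dataset_toolkit.py --code code_raw.jsonl --math math_raw.jsonl --mode check
python3 dataset_toolkit.py --code code_raw.jsonl --math math_raw.jsonl \
    --mode process --shuffle --output merged_code_math.jsonl
```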

View File

@ -0,0 +1,119 @@
#!/usr/bin/env python3
import argparse
import json
import random
import sys
from pathlib import Path
from typing import List, Dict, Optional
def load_jsonl(file_path: Path) -> List[Dict]:
"""Load JSONL file with validation"""
try:
with file_path.open("r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise
def load_id2info(file_path: Path) -> Dict:
"""Load ID mapping file with structure validation"""
try:
with file_path.open("r", encoding="utf-8") as f:
data = json.load(f)
if not all(isinstance(v, dict) and "solutions" in v for v in data.values()):
raise ValueError("Invalid id2info structure")
return data
except FileNotFoundError:
print(f"ERROR: ID mapping file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise
def process_data(prompts_data: List[Dict], id2info: Dict, output_path: Path) -> None:
"""Process and save transformed data"""
processed = []
missing_ids = 0
for item in prompts_data:
query_id = item.get("query_id")
if not query_id or query_id not in id2info:
missing_ids += 1
continue
processed.append(
{
"prompt": item.get("prompt", ""),
"task": "math",
"query_id": query_id,
"solutions": id2info[query_id].get("solutions", []),
}
)
if missing_ids:
print(f"WARNING: Found {missing_ids} items with missing/invalid query_id")
if not processed:
print("ERROR: No valid data to process")
sys.exit(1)
random.shuffle(processed)
try:
with output_path.open("w", encoding="utf-8") as f:
for item in processed:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
print(f"SUCCESS: Wrote {len(processed)} items to {output_path}")
except IOError as e:
print(f"ERROR: Failed to write output: {str(e)}")
raise
def main():
"""
Command line entry point
Example usage:
python math_process.py \
--prompts_path ./input/prompts.jsonl \
--id2info_path ./input/id2info.json \
--output_path ./output/processed.jsonl
"""
parser = argparse.ArgumentParser(
description="Math dataset processing tool",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--prompts_path",
type=Path,
required=True,
help="Path to prompts.jsonl input file",
)
parser.add_argument(
"--id2info_path",
type=Path,
required=True,
help="Path to id2info.json input file",
)
parser.add_argument(
"--output_path",
type=Path,
default="output.jsonl",
help="Path for processed output file",
)
args = parser.parse_args()
print("Starting data processing...")
try:
prompts_data = load_jsonl(args.prompts_path)
id2info = load_id2info(args.id2info_path)
process_data(prompts_data, id2info, args.output_path)
except Exception as e:
print(f"FATAL: Processing failed - {str(e)}")
sys.exit(1)
print("Operation completed successfully")
if __name__ == "__main__":
main()
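To make the expected inputs concrete, here is a minimal sketch; the file contents are invented for illustration and only include the fields that `load_jsonl`, `load_id2info`, and `process_data` actually read:
```bash
# Hypothetical inputs: prompts.jsonl entries need "prompt" and "query_id";
# id2info.json values must be dicts containing a "solutions" list.
cat > prompts.jsonl <<'EOF'
{"prompt": "Compute 1 + 1.", "query_id": "q-001"}
EOF
cat > id2info.json <<'EOF'
{"q-001": {"solutions": ["\\boxed{2}"]}}
EOF
python3 math_process.py --prompts_path prompts.jsonl --id2info_path id2info.json --output_path processed.jsonl
```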

View File

@ -6,7 +6,7 @@ TRAIN_BATCH_SIZE="1024"
GROUP_SIZE="8"
NODES="16"
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"
MAX_NEW_TOKENS=$4
MAX_NEW_TOKENS=$3
MAX_NUM_SEQS=128
PPO_MBS=4
KL_CTL=0.001
@ -18,7 +18,6 @@ BASE_MODEL_PATH="$1"
# original data
DATA_PATH="$2"
REAL_MATH_METADATA_PATH="$3"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -53,7 +52,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -5,7 +5,7 @@ TRAIN_BATCH_SIZE="512"
GROUP_SIZE="64"
NODES="16"
ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"
MAX_NEW_TOKENS=$4
MAX_NEW_TOKENS=$3
MAX_NUM_SEQS=128
PPO_MBS=4
KL_CTL=0.0
@ -17,7 +17,6 @@ BASE_MODEL_PATH="$1"
# original data
DATA_PATH="$2"
REAL_MATH_METADATA_PATH="$3"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -52,7 +51,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n1
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=1
ALLOCATION_MODE="actor_gen:d4p1m2,*:d4p2m1"

View File

@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n16
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

View File

@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n4
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=4
ALLOCATION_MODE="vllm.d16p1m1+d8p2m1"

View File

@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-7B-n16
MODEL_NAME="DeepSeek-R1-Distill-Qwen-7B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"

View File

@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-7B-zero-n16
MODEL_NAME="Qwen2.5-7B"
DATASET_NAME="prompts_for_zero.jsonl"
DATASET_NAME="full_orz_zero.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

View File

@ -1,52 +0,0 @@
#!/bin/bash
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n16-jun1
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="codeparrot-apps-test.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"
LOG_DIR="/storage/ray/train_batch_logs/${EXP_NAME}/$(date +'%Y%m%d-%H%M%S')"
mkdir -p ${LOG_DIR}
echo "Log Dir: ${LOG_DIR}"
MAX_WORKERS=$(expr 16 / ${NODES})
FIFO_NAME=$(mktemp -u)
mkfifo "$FIFO_NAME"
exec 3<>"$FIFO_NAME"
rm -f "$FIFO_NAME"
for ((i=0; i<MAX_WORKERS; i++)); do
echo >&3
done
ALL_PARAMS=(
"${EXP_NAME} ${MODEL_NAME} ${DATASET_NAME} 1024 8 ${NODES} ${ALLOCATION_MODE} 16384 128 1 0.001"
#"${EXP_NAME} ${MODEL_NAME} ${DATASET_NAME} 1024 8 ${NODES} ${ALLOCATION_MODE} 16384 128 1 0.001"
#"${EXP_NAME} ${MODEL_NAME} ${DATASET_NAME} 1024 8 ${NODES} ${ALLOCATION_MODE} 16384 128 1 0.001"
)
echo "Task Count: ${#ALL_PARAMS[@]}"
for ((i=0; i<${#ALL_PARAMS[@]}; i++)); do
read -u3
{
echo "$(date +"%Y-%m-%d %H:%M.%S") Task $i started: ${ALL_PARAMS[$i]}"
bash -c "bash ${SCRIPT_DIR}/train_code_small_on_ray.sh ${ALL_PARAMS[$i]} &> ${LOG_DIR}/${i}.log"
echo "$(date +"%Y-%m-%d %H:%M.%S") Task $i completed with exit code: $?, ${ALL_PARAMS[$i]}"
sleep 120
echo >&3
} &
sleep 120
done
wait
exec 3>&-
echo "All tasks completed"

View File

@ -1,123 +0,0 @@
#!/bin/sh
MODEL_FAMILY=qwen2
EXP_NAME="$1"
MODEL_NAME="$2"
DATASET_NAME="$3"
TRAIN_BATCH_SIZE="$4"
GROUP_SIZE="$5"
NODES="$6"
ALLOCATION_MODE="$7"
MAX_NEW_TOKENS=$8
MAX_NUM_SEQS=$9
PPO_MBS=${10}
KL_CTL=${11}
MAX_TOKEN_PER_MB=$(expr 2048 + ${MAX_NEW_TOKENS} + 1024)
MAX_SEQ_LEN_TO_CAPTURE=$(expr 2048 + ${MAX_NEW_TOKENS})
BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_CODE_METADATA_PATH="/storage/datasets/codeparrot-apps-test.jsonl"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
# Option 2: The experiment runs in a Ray cluster
# MODE=ray
# Option 3: The experiment runs in a SLURM + pyxis cluster
# Using the slurm mode requires a cluster spec file
# and setting CLUSTER_SPEC_PATH to the path of it.
MODE=ray
# `experiment_name` and `trial_name` can be arbitrary.
# Logs and saved checkpoints will be indexed by them.
#EXP_NAME=ppo-zero--${MODEL_NAME}--${DATASET_NAME}
#EXP_NAME=ppo-zero-distill-1.5B-default
TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# We use the "heuristic" allocation mode here to automatically determine the parallelism strategy
# for each model function call, i.e., actor generation, critic inference, actor train, etc.
# The number of GPUs is `n_nodes` * `n_gpus_per_node` (not set explictly here, defaults to 8).
# ReaL will make full use of these available GPUs to design allocations.
# This does not ensure the optimal throughput, but it is a good starting point.
# The `heuristic` allocation mode is not ensured to run with every model configurations.
# For example, if the vocabulary size is an odd number, the model parallelism may not work.
# In these cases, you can use the `ppo_manual.sh` to specify the parallelism strategy manually.
# The `ppo` subcommand specifies that this is a PPO experiment.
# The `save_freq_steps` is set to `null` to disable saving checkpoints.
# Enable it if you want to save checkpoints.
# The `ppo` option is used to control the generation and PPO algorithm hyperparameters.
# Note that the performance of PPO is sensitive to the the pre-trained model and hyperparameters.
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_CODE_METADATA_PATH=${REAL_CODE_METADATA_PATH} \
FUNCTIONCALL_SERVICE_DOMAIN="" \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-code \
mode=$MODE \
experiment_name=$EXP_NAME \
trial_name=$TRIAL_NAME \
wandb.mode=disabled \
exp_ctrl.total_train_epochs=1 \
exp_ctrl.save_freq_epochs=1 \
exp_ctrl.ckpt_freq_secs=600 \
group_size=${GROUP_SIZE} \
group_adv_norm=False \
use_dense_reward=False \
reward_delta=True \
rw_type=sparse \
check_xml_format=False \
actor.type._class=$MODEL_FAMILY \
actor.path=$BASE_MODEL_PATH \
actor.vllm.hybrid_train=False \
actor.vllm.enforce_eager=False \
actor.vllm.max_seq_len_to_capture=${MAX_SEQ_LEN_TO_CAPTURE} \
actor.vllm.max_num_seqs=${MAX_NUM_SEQS} \
actor.vllm.gpu_memory_utilization=1 \
actor.vllm.swap_space=64 \
critic.type._class=$MODEL_FAMILY \
critic.type.is_critic=True \
critic.init_critic_from_actor=True \
critic.path=$BASE_MODEL_PATH\
ref.type._class=$MODEL_FAMILY \
ref.path=$BASE_MODEL_PATH \
rew.type._class=$MODEL_FAMILY \
rew.type.is_critic=True \
rew.init_critic_from_actor=True \
rew.path=$BASE_MODEL_PATH \
dataset.path=$DATA_PATH \
dataset.max_prompt_len=2048 \
dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \
ppo.gen.max_new_tokens=${MAX_NEW_TOKENS} \
ppo.gen.min_new_tokens=0 \
ppo.disable_value=True \
ppo.gen.top_p=1 ppo.gen.top_k=1000000 \
ppo.ppo_n_minibatches=${PPO_MBS} \
ppo.gen.temperature=0.6 \
ppo.kl_ctl=${KL_CTL} \
ppo.value_eps_clip=0.2 \
ppo.reward_output_scaling=5 \
ppo.reward_output_bias=0.0 \
ppo.adv_norm=True ppo.value_norm=True \
mask_too_long=False \
ppo.discount=1.0 \
actor.optimizer.lr=1e-6 \
critic.optimizer.lr=5e-6 \
actor.optimizer.lr_scheduler_type=constant \
actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
cache_clear_freq=1 \
n_nodes=${NODES} \
allocation_mode="'${ALLOCATION_MODE}'" n_gpus_per_node=8 \
recover_mode=auto \
recover_retries=10 \
torch_cache_mysophobia=True

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \
@ -67,8 +65,6 @@ python3 -m realhf.apps.quickstart ppo-math \
exp_ctrl.ckpt_freq_secs=600 \
group_size=${GROUP_SIZE} \
group_adv_norm=False \
use_dense_reward=False \
reward_delta=True \
rw_type=sparse \
check_xml_format=False \
actor.type._class=$MODEL_FAMILY \
@ -85,10 +81,6 @@ python3 -m realhf.apps.quickstart ppo-math \
critic.path=$BASE_MODEL_PATH\
ref.type._class=$MODEL_FAMILY \
ref.path=$BASE_MODEL_PATH \
rew.type._class=$MODEL_FAMILY \
rew.type.is_critic=True \
rew.init_critic_from_actor=True \
rew.path=$BASE_MODEL_PATH \
dataset.path=$DATA_PATH \
dataset.max_prompt_len=2048 \
dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \
@ -103,6 +95,7 @@ python3 -m realhf.apps.quickstart ppo-math \
ppo.reward_output_scaling=5 \
ppo.reward_output_bias=0.0 \
ppo.adv_norm=True ppo.value_norm=True \
ppo.fuse_rew_ref=False \
mask_too_long=False \
ppo.discount=1.0 \
actor.optimizer.lr=1e-6 \
@ -110,7 +103,7 @@ python3 -m realhf.apps.quickstart ppo-math \
actor.optimizer.lr_scheduler_type=constant \
actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \
@ -67,8 +65,6 @@ python3 -m realhf.apps.quickstart ppo-math \
exp_ctrl.ckpt_freq_secs=600 \
group_size=${GROUP_SIZE} \
group_adv_norm=False \
use_dense_reward=False \
reward_delta=True \
rw_type=sparse \
check_xml_format=False \
actor.type._class=$MODEL_FAMILY \
@ -85,10 +81,6 @@ python3 -m realhf.apps.quickstart ppo-math \
critic.path=$BASE_MODEL_PATH\
ref.type._class=$MODEL_FAMILY \
ref.path=$BASE_MODEL_PATH \
rew.type._class=$MODEL_FAMILY \
rew.type.is_critic=True \
rew.init_critic_from_actor=True \
rew.path=$BASE_MODEL_PATH \
dataset.path=$DATA_PATH \
dataset.max_prompt_len=2048 \
dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \
@ -103,6 +95,7 @@ python3 -m realhf.apps.quickstart ppo-math \
ppo.reward_output_scaling=5 \
ppo.reward_output_bias=0.0 \
ppo.adv_norm=True ppo.value_norm=True \
ppo.fuse_rew_ref=False \
mask_too_long=False \
ppo.discount=1.0 \
actor.optimizer.lr=1e-6 \
@ -110,7 +103,7 @@ python3 -m realhf.apps.quickstart ppo-math \
actor.optimizer.lr_scheduler_type=constant \
actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \
@ -67,8 +65,6 @@ python3 -m realhf.apps.quickstart ppo-math \
exp_ctrl.ckpt_freq_secs=600 \
group_size=${GROUP_SIZE} \
group_adv_norm=False \
use_dense_reward=False \
reward_delta=True \
rw_type=sparse \
check_xml_format=True \
actor.type._class=$MODEL_FAMILY \
@ -85,10 +81,6 @@ python3 -m realhf.apps.quickstart ppo-math \
critic.path=$BASE_MODEL_PATH\
ref.type._class=$MODEL_FAMILY \
ref.path=$BASE_MODEL_PATH \
rew.type._class=$MODEL_FAMILY \
rew.type.is_critic=True \
rew.init_critic_from_actor=True \
rew.path=$BASE_MODEL_PATH \
dataset.path=$DATA_PATH \
dataset.max_prompt_len=2048 \
dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \
@ -103,6 +95,7 @@ python3 -m realhf.apps.quickstart ppo-math \
ppo.reward_output_scaling=0.5 \
ppo.reward_output_bias=-1.0 \
ppo.adv_norm=True ppo.value_norm=True \
ppo.fuse_rew_ref=False \
mask_too_long=False \
ppo.discount=1.0 \
actor.optimizer.lr=1e-6 \
@ -110,7 +103,7 @@ python3 -m realhf.apps.quickstart ppo-math \
actor.optimizer.lr_scheduler_type=constant \
actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \

View File

@ -5,38 +5,25 @@ import sys
import time
import traceback
from io import StringIO
from function.testing_util import run_test
from typing import Dict, List
from functioncall.base import logging
from functioncall.code.function.testing_util import run_test
logger = logging.getLogger("Functioncall")
def try_load_json(data):
"""
Attempts to load the given data as JSON. If successful, returns the parsed JSON.
Otherwise, logs the failure and returns the data unchanged.
"""
try:
loaded_data = json.loads(data)
return loaded_data
except json.JSONDecodeError as e:
logger.info(f"Failed to load JSON: {e}")
return data
def capture_stdout(code):
original_stdout = sys.stdout
fake_stdout = StringIO()
try:
sys.stdout = fake_stdout # redirect stdout
exec(code, {"__builtins__": __builtins__}) # run the code in an isolated environment
sys.stdout = fake_stdout
exec(code, {"__builtins__": __builtins__})
except Exception as e:
return f"error: {str(e)}, traceback: {traceback.format_exc()}"
finally:
sys.stdout = original_stdout # restore the original stdout
sys.stdout = original_stdout
return fake_stdout.getvalue()
@ -46,12 +33,17 @@ def _temp_run(problem, generation, debug, result):
try:
if debug:
logger.debug(f"Running test for problem: {problem}")
result.append(run_test(sample=problem, test=generation, debug=debug))
r = run_test(sample=problem, test=generation, debug=debug)
result.append(r)
if debug:
logger.debug(f"Test completed with result: {result}")
except Exception as e:
if debug:
logger.error(f"Error in _temp_run: {e}, problem:{problem}", exe_info=True)
logger.warning(
f"Error in _temp_run: {e}\n"
f"traceback: {''.join(traceback.format_exception(*sys.exc_info()))}\n"
f"problem:{problem}"
)
execution_time = time.time() - start_time
logger.info(
@ -64,8 +56,9 @@ def check_correctness(problem, generation, timeout, debug=False):
The global timeout is to catch some extreme/rare cases not handled by the timeouts
inside `run_test`"""
if debug:
# FIXME: the variables "problem", "generation", and "debug" are not defined in this exec scope
result = capture_stdout(
"from function.testing_util import run_test\n"
"from functioncall.code.function.testing_util import run_test\n"
+ "run_test(sample=problem, test=generation, debug=debug)"
)
return result[0], result[1]
@ -86,7 +79,7 @@ def check_correctness(problem, generation, timeout, debug=False):
# Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
# so we use 21 = the average number of tests for a sample in the test split instead
avg_number_tests = 21
result = [[-1] * avg_number_tests]
result = [[-1 for _ in range(avg_number_tests)], {}]
if debug:
logger.debug(f"Global timeout occurred, returning default result.")
if debug:
@ -99,96 +92,56 @@ def check_correctness(problem, generation, timeout, debug=False):
return result[0]
def load_problems(path):
problem_map = {}
for idx, line in enumerate(open(path, "rb")):
if line is not None:
line = line.strip().decode("utf-8")
row = json.loads(line)
query_id = str(row["id"])
problem_map[query_id] = {
"problem_id": query_id,
# "question": row["question"],
"solutions": row["solutions"],
"input_output": row["input_output"],
# "difficulty": level,
# "url": row["url"],
# "starter_code": row["starter_code"],
}
return problem_map
global_problems = load_problems(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
)
)
def code_verify(generateds, query_ids, debug=False):
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
)
def code_verify(id2info, generateds, query_ids, debug=False):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]
result = []
global global_problems
for idx, query_id in enumerate(query_ids):
if query_id not in global_problems:
continue
problem = global_problems[query_id]
test_code = generateds[idx]
for query_id, generated, problem in zip(query_ids, generateds, problems):
logger.debug(f"run_batch_code, query_id: {query_id}")
try:
curr_res, metadata = check_correctness(
problem=problem, generation=test_code, timeout=6000, debug=debug
problem=problem, generation=generated, timeout=6000, debug=debug
)
if any(x != True for x in curr_res):
logger.debug(
f'id:{problem["problem_id"]}, Results were not all True: {metadata}'
)
result.append(f"{query_id} failed")
logger.debug(f"id:{query_id}, Results were not all True: {metadata}")
result.append(0)
else:
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
result.append(f"{query_id} success")
result.append(1)
except Exception as e:
logger.error(f"test framework exception = {repr(e)}{e}\n", exe_info=True)
result.append(f"{query_id} failed")
break
finally:
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
# assert isinstance(curr_res, list)
pass
exc_info = sys.exc_info()
logger.error(
f"test framework exception = {repr(e)}{e}\n{traceback.format_exception(*exc_info)}"
)
result.append(0)
return result
if __name__ == "__main__":
path = "/storage/openpsi/data/code/apps/test.jsonl"
data = []
with open(path, "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
def create_test_params(index_list):
global global_problems
codes, query_ids = [], []
id2info = {}
solutions = []
query_ids = []
for i in range(10):
problem = code_data[i]
problem["problem_id"] = problem["id"]
id2info[problem["problem_id"]] = problem
solutions.append(json.loads(problem["solutions"])[0])
query_ids.append(problem["id"])
for index in index_list:
if str(index) not in global_problems:
continue
problem = global_problems[str(index)]
if "solutions" not in problem or not problem["solutions"]:
continue
codes.append(try_load_json(problem["solutions"])[0])
query_ids.append(str(index))
return codes, query_ids
codes, query_ids = create_test_params(list(range(805, 806)))
result = code_verify(codes, query_ids, True)
result = code_verify(
id2info,
solutions,
query_ids,
debug=False,
)
print(result)

View File

@ -1,5 +1,6 @@
import json
import os
import random
from collections import defaultdict
from functioncall.base import logging
@ -8,29 +9,14 @@ from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")
def try_load_json(data):
"""
Attempts to load the given data as JSON. If successful, returns the parsed JSON.
Otherwise, returns None and logs an error.
"""
try:
loaded_data = json.loads(data)
return loaded_data
except json.JSONDecodeError as e:
# print(f"Failed to load JSON: {e}")
return data
def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=None):
def load_problems_with_testcase_batch(
id2info, query_ids, debug=False, test_case_batch_size=None
):
problem_map = defaultdict(list)
for idx, line in enumerate(open(path, "rb")):
if line is None:
continue
for idx, query_id in enumerate(query_ids):
problem = id2info[query_id]
# parse one problem
row = json.loads(line.strip().decode("utf-8"))
query_id = str(row.get("id", row.get("query_id")))
input_output = json.loads(row["input_output"]) if "input_output" in row else {}
input_output = json.loads(problem["input_output"])
inputs = input_output.get("inputs", [])
outputs = input_output.get("outputs", [])
assert len(inputs) == len(
@ -57,17 +43,14 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
"batche_index": batch_idx,
}
if debug:
sub_problem["solutions"] = row.get("solutions", [])
sub_problem["solutions"] = problem.get("solutions", [])
problem_map[query_id].append(sub_problem)
return problem_map
global_problems = None
def code_verify(
generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
id2info, generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
):
assert len(generateds) == len(query_ids), (
len(generateds),
@ -75,24 +58,13 @@ def code_verify(
)
payload_list = []
global global_problems
if global_problems is None:
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
debug=True,
test_case_batch_size=20,
)
global_problems = load_problems_with_testcase_batch(
id2info,
query_ids,
debug=True,
test_case_batch_size=20,
)
for idx, query_id in enumerate(query_ids):
if query_id not in global_problems:
payload_list.append(None)
logger.warning(
f"Invalid query id : {query_id}, type: {type(query_id)}, should be in problem dataset!"
)
continue
problems = global_problems[query_id]
for problem in problems:
payload_list.append(
@ -129,34 +101,28 @@ def code_verify(
if __name__ == "__main__":
path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
data = []
with open(path, "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
id2info = {}
def create_test_params(count=10):
global global_problems
if global_problems is None:
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
debug=True,
test_case_batch_size=20,
)
codes, query_ids = [], []
idx = 0
for query_id, problems in global_problems.items():
if idx >= count:
break
problem = problems[0]
if "solutions" not in problem or not problem["solutions"]:
global id2info
query_ids = []
generateds = []
cnt = 0
while cnt < count:
d = random.choice(code_data)
if not d["solutions"]:
continue
id2info[d["id"]] = d
query_ids.append(d["id"])
generateds.append(d["solutions"][0])
cnt += 1
return generateds, query_ids
codes.append(try_load_json(problem["solutions"])[0])
query_ids.append(query_id)
idx += 1
return codes, query_ids
codes, query_ids = create_test_params(100)
result = code_verify(codes, query_ids, True)
generateds, query_ids = create_test_params(100)
result = code_verify(id2info, generateds, query_ids, True)
print(result)

View File

@ -7,7 +7,7 @@ from grader import math_equal
def process_results(answer, solution):
extracted_answer = extract_answer(answer, "math", use_last_number=False)
extracted_solution = solution
extracted_solution = extract_answer(solution, "math", use_last_number=True)
if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
retval = 0

View File

@ -9,28 +9,9 @@ from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")
def loadJson(dataDir):
with open(dataDir, "r") as f:
if dataDir.endswith(".jsonl"):
samples = [json.loads(line) for line in f.readlines()]
else:
samples = json.load(f)
return samples
id2info = None
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000) -> List:
global id2info
if id2info is None:
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
def math_verify(
id2info, generateds: List, query_ids: List, batch_size=10, timeout=1000
) -> List:
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
@ -43,7 +24,7 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000)
for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
base_query_id = query_id.split("@idx:")[0]
info = id2info[base_query_id]
for cur_solution in info["answers"]:
for cur_solution in info["solutions"]:
parameters.append((generated, cur_solution, idx))
query_indices.append(idx)
@ -95,27 +76,19 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000)
if __name__ == "__main__":
# sample = {
# "prompt": "",
# "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
# "answer": "\\boxed{-\\frac{1}{30}}",
# }
if id2info is None:
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
answers = []
query_ids = []
for id, value in id2info.items():
answers.append(value["solutions"][0])
query_ids.append(id)
sample = {
"answers": ["-\\frac{2}{3}"],
"solutions": [
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
"\\boxed{-\\frac{2}{3}}",
],
}
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
start_time = time.time()
result = math_verify(answers[:200], query_ids[:200])
result = math_verify(
id2info,
sample["answers"] * 100,
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
)
print(result)

View File

@ -43,6 +43,8 @@ from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger("api.data")
RL_TASKS = ["math", "code", "rlhf"]
def load_hf_tokenizer(
model_name_or_path: str,
@ -529,6 +531,7 @@ class SequenceSample:
"rewards",
"greedy_rewards",
"base_scores",
"task_ids",
]:
return [[1] for _ in seqlens]
elif key in [
@ -720,9 +723,6 @@ def load_shuffle_split_dataset(
if dataset_path.endswith(".jsonl"):
with open(dataset_path, "r") as f:
data = [json.loads(ff) for ff in f]
elif dataset_path.endswith(".json"):
with open(dataset_path, "r") as f:
data = json.load(f)
else:
raise NotImplementedError(f"Unknown dataset extension: {dataset_path}")
else:

View File

@ -160,8 +160,6 @@ def main_start(args, recover_count: int = 0):
CLUSTER_SPEC_PATH=cluster_spec_path,
REAL_RECOVER_RUN="1" if is_recover_run else "0",
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
)

View File

@ -513,6 +513,10 @@ def tp_and_pp_group():
return grid().get_model_parallel_group()
def tp_and_pp_cpu_group():
return grid().ds_model_proc_group_gloo
def tp_and_pp_rank():
"""Used as the rank in the world group of vLLM."""
return grid().get_model_parallel_rank()

View File

@ -206,9 +206,9 @@ class LocalMultiProcessTest:
def init_global_constants(
num_dp,
num_mp,
num_pp,
num_dp=1,
num_mp=1,
num_pp=1,
topo=None,
model_name=None,
msid2mwid=None,
@ -217,7 +217,12 @@ def init_global_constants(
gradient_accumulation_fusion=False,
max_prompt_len=None,
is_train: bool = True,
expr_name=None,
trial_name=None,
):
expr_name = expr_name if expr_name is not None else _DEFAULT_EXPR_NAME
trial_name = trial_name if trial_name is not None else _DEFAULT_TRIAL_NAME
constants.set_experiment_trial_names(expr_name, trial_name)
model_name = model_name if model_name is not None else MODEL_NAME
if topo is None:

View File

@ -21,7 +21,7 @@ def check_is_realhf_native_impl(_cls):
def check_is_realhf_native_model_interface(name):
# NOTE: we should not import interfaces here,
# such that we can avoid CUDA initialization.
return name in ["ppo_actor", "ppo_critic", "sft"]
return name in ["ppo_actor", "ppo_critic", "sft", "rw-math-code", "fused-threading"]
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):

View File

@ -1,562 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import copy
import dataclasses
import math
import os
import pprint
from typing import *
import numpy as np
from omegaconf import DictConfig, OmegaConf
import realhf.base.logging as logging
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
ModelInterfaceType,
)
from realhf.api.core.dfg import MFCDef, ParamReallocHook
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import DeviceMesh, MFCConfig, RPCAllocation
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
logger = logging.getLogger("PPO Code exp", "colored")
@dataclasses.dataclass
class PPOHyperparameters:
"""Configuration of PPO hyperparameters.
:param gen: Generation hyperparameters.
:type gen: GenerationHyperparameters
:param ppo_n_minibatches: Number of minibatches in each PPO update.
:type ppo_n_minibatches: int
:param kl_ctl: Coefficient of KL divergence rewards.
:type kl_ctl: float
:param discount: Discount factor.
:type discount: float
:param gae_lambda: Lambda factor in GAE.
:type gae_lambda: float
:param eps_clip: PPO actor probability ratio clipping factor.
:type eps_clip: float
:param value_eps_clip: PPO value clipping factor.
:type value_eps_clip: float
:param max_reward_clip: Maximum reward value.
:type max_reward_clip: float
:param reward_output_scaling: Scaling factor of the reward model output.
:type reward_output_scaling: float
:param reward_output_bias: Bias of the reward model output.
The number outputed by the reward model will be
CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip).
:type reward_output_bias: float
:param early_stop_imp_ratio: PPO update will be early stopped if importance ratio
exceeds this maximum value.
:type early_stop_imp_ratio: float
:param use_adaptive_kl_ctl: Whether to use adaptive KL divergence coefficient.
:type use_adaptive_kl_ctl: bool
:param adv_norm: Whether to use advantage normalization.
:type adv_norm: bool
:param value_norm: Whether to denormalize valued and normalize return predictions.
:type value_norm: bool
:param value_norm_type: Type of value normalization.
Either exponential moving average ("exp") or moving average ("ma").
:type value_norm_type: str
:param value_norm_beta: Exponential decay factor
in exponential moving average.
:type value_norm_beta: float
:param value_norm_eps: Epsilon factor in the
denominator of exponential moving average.
:type value_norm_eps: float
:param disable_value: A shortcut option to disable the critic model.
:type disable_value: bool
"""
gen: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
)
ppo_n_minibatches: int = 4
kl_ctl: float = 0.1
discount: float = 1.0
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
disable_value: bool = False
max_reward_clip: float = 20.0
reward_output_scaling: float = 1.0
reward_output_bias: float = 0.0
early_stop_imp_ratio: float = 5.0
use_adaptive_kl_ctl: bool = False
adv_norm: bool = True
value_norm: bool = True
value_norm_type: str = dataclasses.field(
metadata={"choices": ["exp", "ma"]}, default="exp"
)
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
@dataclasses.dataclass
class PPOCODEConfig(CommonExperimentConfig):
"""PPO experiment configuration.
It is a subclass of :class:`CommonExperimentConfig`,
so all CLI options in the base class are available.
We don't implement runtime evaluation for PPO.
We identify that the RLHF process is composed of four
distinct models with independent parameters and six
*model function calls* upon these models.
The four models are\:
- Actor\: The primary LLM that generates text.
- Critic\: The value function that estimates the value of a state.
- Ref\: The reference LLM that provides KL regularization.
- Rew\: The reward model that provides reward signals.
The four model function calls and their dependencies are\:
- Rollout\: Generate text from the actor model.
- InfReward\: Infer rewards from the reward model given generated text.
- InfRef\: Infer log probabilities from the reference model given generated text.
- InfValues\: Infer values from the critic model given generated text.
- TrainActor\: Train the actor model given generated text, rewards, values, and reference log probabilities.
- TrainCritic\: Train the critic model given generated text, rewards, values, and reference log probabilities.
This class resolves these dependencies under the hood.
What the users should specify are the runtime configurations
of models and allocations of *each model function call*.
:param actor: Runtime configuration of the primary LLM.
:type actor: ModelTrainEvalConfig
:param critic: Runtime configuration of the critic model of PPO.
:type critic: ModelTrainEvalConfig
:param ref: Runtime configuration of the reference LLM.
:type ref: ModelTrainEvalConfig
:param rew: Runtime configuration of the reward LLM.
:type rew: ModelTrainEvalConfig
:param actor_train: :class:`MFCConfig` for TrainActor.
:type actor_train: MFCConfig
:param critic_train: :class:`MFCConfig` for TrainCritic.
:type critic_train: MFCConfig
:param actor_gen: :class:`MFCConfig` for Rollout.
:type actor_gen: MFCConfig
:param critic_inf: :class:`MFCConfig` for InfValues.
:type critic_inf: MFCConfig
:param rew_inf: :class:`MFCConfig` for InfReward.
:type rew_inf: MFCConfig
:param ref_inf: :class:`MFCConfig` for InfRef.
:type ref_inf: MFCConfig
:param dataset: Dataset configuration.
:type dataset: PromptOnlyDatasetConfig
:param ppo: Configuration for the PPO algorithm.
:type ppo: PPOHyperparameters
:param group_size: The number of answers remained for each prompt.
:type group_size: int
:param generation_size: The number of answers sampled for each prompt.
Among them, only `group_size` samples are remained according to
the reward score, aka best-of-n sampling. If None, use `group_size`.
:type generation_size: Optional[int]
:param mask_no_eos_with_zero: Whether to mask out the reward if an
answer is truncated due to exceeding the length limit.
:type mask_no_eos_with_zero: bool
:param mask_too_long: Whether to mask out the PPO loss if an
answer is truncated due to exceeding the length limit.
:type mask_too_long: bool
:param check_verifier_status: If True, raise an error
when the reward is all-zero. This usually indicates a bug
of the verifier.
:type check_verifier_status: bool
:param group_adv_norm: Whther to use grouped advantage
normaliztion in GRPO.
:type group_adv_norm: bool
"""
actor: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
critic: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
# for manual allocation only
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
default_factory=PromptOnlyDatasetConfig
)
ppo: PPOHyperparameters = dataclasses.field(default_factory=PPOHyperparameters)
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
rm_output_scaling: Optional[float] = None
ref_ema_eta: Optional[float] = None
group_adv_norm: bool = False
mask_too_long: bool = False
rw_type: Optional[str] = "sparse"
task: str = "code"
check_xml_format: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
check_verifier_status: bool = False
dataset_filter_threshold: float = 100.0
dataset_max_filter_percentage: float = 0.0
def __post_init__(self):
self.ppo_kwargs = dict(
n_minibatches=self.ppo.ppo_n_minibatches,
kl_ctl=self.ppo.kl_ctl,
discount=self.ppo.discount,
gae_lambda=self.ppo.gae_lambda,
eps_clip=self.ppo.eps_clip,
value_eps_clip=self.ppo.value_eps_clip,
max_reward_clip=self.ppo.max_reward_clip,
adaptive_kl_ctl=self.ppo.use_adaptive_kl_ctl,
value_norm=self.ppo.value_norm,
value_norm_type=self.ppo.value_norm_type,
value_norm_beta=self.ppo.value_norm_beta,
value_norm_eps=self.ppo.value_norm_eps,
disable_value=self.ppo.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
)
if self.rm_output_scaling is None:
self.rm_output_scaling = self.ppo.reward_output_scaling
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
# role to config
if self.ppo.disable_value:
return {
"actor": self.actor,
# "critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
else:
return {
"actor": self.actor,
"critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
@property
def rpcs(self):
if (
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
> self.actor.vllm.max_seq_len_to_capture
):
raise RuntimeError(
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
)
if not os.path.exists(os.getenv("REAL_CODE_METADATA_PATH")):
raise RuntimeError(
"Dataset json path REAL_CODE_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if not (domain.startswith("http://") and ":" in domain):
raise RuntimeError(
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
# interfaces
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",
args={
**copy.deepcopy(self.ppo_kwargs),
# NOTE: to_container converts the object to a dict
# It is used for unifying the profiling API, which requires to
# pass external interface configurations in the launch command.
# Customized dataclass objects will not work in that case.
"generation_config": (
OmegaConf.to_container(self.ppo.gen, resolve=True)
if isinstance(self.ppo.gen, (OmegaConf, DictConfig))
else dataclasses.asdict(self.ppo.gen)
),
"early_stop_imp_ratio": self.ppo.early_stop_imp_ratio,
"adv_norm": self.ppo.adv_norm,
"group_size": self.group_size,
"generation_size": self.generation_size,
"group_adv_norm": self.group_adv_norm,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
ref_interface = copy.deepcopy(actor_interface)
ref_interface.args["enable_save"] = False
critic_interface = ModelInterfaceAbstraction(
"ppo_critic",
args={
**copy.deepcopy(self.ppo_kwargs),
"group_size": self.group_size,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
critic_interface.args.pop("eps_clip")
rw_interface = ModelInterfaceAbstraction(
"rw_code",
args=dict(
rw_type=self.rw_type,
task=self.task,
check_xml_format=self.check_xml_format,
tokenizer_path=self.actor.path,
enable_save=False,
output_scaling=self.ppo.reward_output_scaling,
rm_output_scaling=self.rm_output_scaling,
output_bias=self.ppo.reward_output_bias,
group_size=self.group_size,
check_verifier_status=self.check_verifier_status,
max_sync_length=self.ppo.gen.max_new_tokens
+ self.dataset.max_prompt_len
+ 128,
),
)
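# The MFC definitions below (generation, reward/reference/critic inference, and
# actor/critic training) together describe the per-step PPO dataflow.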
rollout = MFCDef(
name="actor_gen",
model_name="actor",
mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_prompts"],
output_keys=[
"seq_no_eos_mask",
"packed_input_ids",
"packed_logprobs",
"prompt_mask",
],
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_reward = MFCDef(
name="rew_inf",
model_name="reward",
mb_spec=self.rew_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts"],
output_keys=["rewards", "dense_rewards"],
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_ref_inputs = ["packed_input_ids"]
inf_ref_logits = MFCDef(
name="ref_inf",
model_name="ref",
mb_spec=self.ref_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
model_type=self.ref.type,
model_path=self.ref.path,
interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=inf_ref_inputs,
output_keys=["packed_ref_logprobs"],
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_values = MFCDef(
name="critic_inf",
model_name="critic",
mb_spec=self.critic_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "seq_no_eos_mask"],
output_keys=["values"],
n_seqs=self.dataset.train_bs_n_seqs,
)
train_actor_inputs = [
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
]
if self.ppo.disable_value:
train_actor_inputs.remove("values")
train_actor = MFCDef(
name="actor_train",
model_name="actor",
mb_spec=self.actor_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=train_actor_inputs,
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
)
train_critic = MFCDef(
name="critic_train",
model_name="critic",
mb_spec=self.critic_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
input_keys=[
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
],
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
)
if self.ppo.disable_value:
return {
"actor_gen": rollout,
"actor_train": train_actor,
# "critic_inf": inf_values,
# "critic_train": train_critic,
"ref_inf": inf_ref_logits,
"rew_inf": inf_reward,
}
else:
return {
"actor_gen": rollout,
"actor_train": train_actor,
"critic_inf": inf_values,
"critic_train": train_critic,
"ref_inf": inf_ref_logits,
"rew_inf": inf_reward,
}
@property
def allocations(self):
if self.ppo.disable_value:
return {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
# "critic_inf": self.critic_inf,
# "critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"rew_inf": self.rew_inf,
}
else:
return {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
"critic_inf": self.critic_inf,
"critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"rew_inf": self.rew_inf,
}
@property
def datasets(self):
return [
DatasetAbstraction(
"code_prompt",
args=dict(
dataset_path=self.dataset.path,
max_length=self.dataset.max_prompt_len,
fill_to_max_length=self.dataset.fill_to_max_length,
filter_threshold=self.dataset_filter_threshold,
max_filter_percentage=self.dataset_max_filter_percentage,
),
)
]
@property
def tokenizer_name_or_path(self) -> str:
return self.actor.path
@property
def search_kwargs(self):
return {
"num_gen_tokens": self.ppo.gen.max_new_tokens,
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
"seq_len": self.dataset.max_prompt_len,
}
@property
def max_prompt_len(self):
return self.dataset.max_prompt_len
def initial_setup(self) -> ExperimentConfig:
rpc_allocs = self._get_rpc_allocations()
resolve_replica_ids(rpc_allocs, self.models)
resolve_rpc_hooks(
rpc_allocs, self.models
) # inplace modify MFCDefs in rpc allocations
pprint.pprint(rpc_allocs)
######### update ref model using ema, ref_ema_eta = 0 means fixed ref model #########
def _find_rpc(name):
return next(alloc.rpc for alloc in rpc_allocs if alloc.rpc.name == name)
# Remove the offload hook of ref_inf, because
# we need to receive parameters from peer GPUs and update it immediately.
if self.ref_ema_eta is not None:
ref_inf = _find_rpc("ref_inf")
ref_inf._post_hooks = []
# Add an unidirectional parameter reallocation hook.
actor_train = _find_rpc("actor_train")
actor_train.add_post_hook(
ParamReallocHook(
target=ref_inf.model_name,
eta=self.ref_ema_eta,
)
)
######### The main difference from normal PPO #########
model_worker = self._get_model_worker_configs(rpc_allocs)
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,
)
register_quickstart_exp("ppo-code", PPOCODEConfig)

View File

@@ -7,8 +7,6 @@ import os
import pprint
from typing import *
from omegaconf import DictConfig, OmegaConf
import realhf.base.logging as logging
from realhf.api.core.config import (
DatasetAbstraction,
@@ -89,8 +87,6 @@ class PPOHyperparameters:
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
disable_value: bool = False
recompute_logprob: bool = False
max_reward_clip: float = 20.0
reward_output_scaling: float = 1.0
reward_output_bias: float = 0.0
@@ -104,6 +100,10 @@ class PPOHyperparameters:
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
disable_value: bool = False
recompute_logprob: bool = False
fuse_rew_ref: bool = True
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig):
@@ -190,7 +190,6 @@ class PPOMATHConfig(CommonExperimentConfig):
default_factory=ModelTrainEvalConfig
)
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
# for manual allocation only
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
@@ -210,15 +209,11 @@ class PPOMATHConfig(CommonExperimentConfig):
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
rm_output_scaling: Optional[float] = None
ref_ema_eta: Optional[float] = None
group_adv_norm: bool = False
mask_too_long: bool = False
rw_type: Optional[str] = "sparse"
task: str = "math"
check_xml_format: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
check_verifier_status: bool = False
@@ -244,26 +239,21 @@ class PPOMATHConfig(CommonExperimentConfig):
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
)
if self.rm_output_scaling is None:
self.rm_output_scaling = self.ppo.reward_output_scaling
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
# role to config
reward = copy.deepcopy(self.actor)
models = {
"actor": self.actor,
"critic": self.critic,
"ref": self.ref,
"reward": reward,
}
if self.ppo.disable_value:
return {
"actor": self.actor,
# "critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
else:
return {
"actor": self.actor,
"critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
models.pop("critic")
if self.ppo.fuse_rew_ref:
models.pop("reward")
return models
@property
def rpcs(self):
@@ -278,10 +268,6 @@ class PPOMATHConfig(CommonExperimentConfig):
f"smaller than the prompt length + generation length "
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
)
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
raise RuntimeError(
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if domain and (not (domain.startswith("http://") and ":" in domain)):
@@ -305,12 +291,8 @@ class PPOMATHConfig(CommonExperimentConfig):
"generation_size": self.generation_size,
"group_adv_norm": self.group_adv_norm,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
ref_interface = copy.deepcopy(actor_interface)
ref_interface.args["enable_save"] = False
critic_interface = ModelInterfaceAbstraction(
"ppo_critic",
@@ -318,29 +300,31 @@ class PPOMATHConfig(CommonExperimentConfig):
**copy.deepcopy(self.ppo_kwargs),
"group_size": self.group_size,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
critic_interface.args.pop("eps_clip")
rw_interface = ModelInterfaceAbstraction(
"rw_math",
"rw-math-code",
args=dict(
rw_type=self.rw_type,
task=self.task,
check_xml_format=self.check_xml_format,
dataset_path=self.dataset.path,
tokenizer_path=self.actor.path,
enable_save=False,
output_scaling=self.ppo.reward_output_scaling,
rm_output_scaling=self.rm_output_scaling,
output_bias=self.ppo.reward_output_bias,
rw_type=self.rw_type,
check_xml_format=self.check_xml_format,
group_size=self.group_size,
check_verifier_status=self.check_verifier_status,
max_sync_length=self.ppo.gen.max_new_tokens
+ self.dataset.max_prompt_len
+ 128,
),
)
ref_interface = copy.deepcopy(actor_interface)
ref_interface.args["enable_save"] = False
if self.ppo.fuse_rew_ref:
ref_interface = ModelInterfaceAbstraction(
"fused-threading",
args=dict(interfaces=dict(rew=rw_interface, ref=ref_interface)),
)
rollout_output_keys = [
"seq_no_eos_mask",
"packed_input_ids",
@@ -357,7 +341,7 @@ class PPOMATHConfig(CommonExperimentConfig):
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_prompts"],
input_keys=["packed_prompts", "task_ids"],
output_keys=rollout_output_keys,
n_seqs=self.dataset.train_bs_n_seqs,
)
@@ -379,18 +363,21 @@ class PPOMATHConfig(CommonExperimentConfig):
inf_reward = MFCDef(
name="rew_inf",
model_name="reward",
mb_spec=self.rew_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts"],
output_keys=["rewards", "dense_rewards"],
input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
output_keys=["rewards"],
n_seqs=self.dataset.train_bs_n_seqs,
)
# add rew param into ref MFC
inf_ref_inputs = ["packed_input_ids"]
inf_ref_outputs = ["packed_ref_logprobs"]
if self.ppo.fuse_rew_ref:
inf_ref_inputs += ["packed_prompts", "task_ids"]
inf_ref_outputs += ["rewards"]
inf_ref_logits = MFCDef(
name="ref_inf",
model_name="ref",
@@ -401,7 +388,7 @@ class PPOMATHConfig(CommonExperimentConfig):
interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=inf_ref_inputs,
output_keys=["packed_ref_logprobs"],
output_keys=inf_ref_outputs,
output_key_remap=dict(logprobs="packed_ref_logprobs"),
n_seqs=self.dataset.train_bs_n_seqs,
)
@@ -425,7 +412,6 @@ class PPOMATHConfig(CommonExperimentConfig):
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
@@ -459,7 +445,6 @@ class PPOMATHConfig(CommonExperimentConfig):
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
@@ -483,6 +468,8 @@ class PPOMATHConfig(CommonExperimentConfig):
rpcs.pop("critic_train")
if not self.ppo.recompute_logprob:
rpcs.pop("actor_inf")
if self.ppo.fuse_rew_ref:
rpcs.pop("rew_inf")
return rpcs
@property
@@ -501,17 +488,18 @@ class PPOMATHConfig(CommonExperimentConfig):
allocs.pop("critic_train")
if not self.ppo.recompute_logprob:
allocs.pop("actor_inf")
if self.ppo.fuse_rew_ref:
allocs.pop("rew_inf")
return allocs
@property
def datasets(self):
return [
DatasetAbstraction(
"math_prompt",
"math_code_prompt",
args=dict(
dataset_path=self.dataset.path,
max_length=self.dataset.max_prompt_len,
fill_to_max_length=self.dataset.fill_to_max_length,
filter_threshold=self.dataset_filter_threshold,
max_filter_percentage=self.dataset_max_filter_percentage,
),

View File

@@ -0,0 +1,223 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import json
import traceback
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Optional
import numpy as np
import torch.utils.data
from realhf.api.core import data_api
from realhf.base import logging
logger = logging.getLogger("Math Code Dataset")
def check_math_metadata_entries(data):
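"""Validate a math entry: task == "math", a query_id (coerced to str), a string prompt, and a list of string solutions."""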
assert data["task"] == "math"
assert "query_id" in data
data["query_id"] = str(data["query_id"])
assert isinstance(data["prompt"], str)
assert isinstance(data["solutions"], list)
for sol in data["solutions"]:
assert isinstance(sol, str)
return data
def check_code_metadata_entries(data):
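"""Validate a code entry: task == "code", a query_id (coerced to str, reused as problem_id when missing), a string prompt, and an "input_output" JSON string with matching input/output string lists."""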
assert data["task"] == "code"
assert "query_id" in data
data["query_id"] = str(data["query_id"])
if "problem_id" not in data:
data["problem_id"] = data["query_id"]
assert isinstance(data["prompt"], str)
input_output = json.loads(data["input_output"])
assert len(input_output["inputs"]) == len(input_output["outputs"])
for inp, out in zip(input_output["inputs"], input_output["outputs"]):
assert isinstance(inp, str) and isinstance(out, str), (
inp,
out,
input_output.get("fn_name"),
)
return data
def load_metadata(path):
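"""Load a .jsonl metadata file, validate each entry by task type ("math" by default), drop invalid entries, and return (id2info, per-task counts)."""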
assert str(path).endswith(".jsonl"), path
with open(path, "r") as f:
data = [json.loads(l) for l in f.readlines()]
id2info = {}
omit_cnt = defaultdict(int)
task_cnt = defaultdict(int)
for d in data:
assert d["query_id"] not in d, (d["task"], d["query_id"])
try:
if "task" not in d:
d["task"] = "math"
logger.warning(f'Key "task" not found in the dataset. Use math as default task type.')
if d["task"] == "math":
d = check_math_metadata_entries(d)
elif d["task"] == "code":
d = check_code_metadata_entries(d)
except Exception as e:
logger.warning(
f"Data validation failed: query_id {d['query_id']}. "
f"Error: {traceback.format_exc()}. Omit it in the dataset."
)
omit_cnt[d["task"]] += 1
continue
id2info[d["query_id"]] = d
task_cnt[d["task"]] += 1
logger.warning(f"Number of ignored data: {dict(**omit_cnt)}")
return id2info, task_cnt
class MATHCodePromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
"""Required keys: prompt, query_id, task=math/code, solutions.
For code data, an "input_output" key is additionally required.
"""
self._util = util
self.max_length = max_length
id2info, task_cnt = load_metadata(dataset_path)
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
prompts_str = [x["prompt"] for x in data]
self.ids = [x["query_id"] for x in data]
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
prompt_encodings = util.tokenizer(
prompts_str,
truncation=True,
# max_length=max_length,
padding=False,
return_length=True,
return_attention_mask=False,
)
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(
f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data"
)
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
self.ids = [
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
]
if "scores" in data[0]:
self.base_scores = [self.base_scores[idx] for idx in indices]
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
self.active_indices = list(range(len(self.prompts)))
self.filter_threshold = filter_threshold
self.max_filter_percentage = max_filter_percentage
@property
def util(self):
return self._util
def __len__(self):
return len(self.active_indices)
def __getitem__(self, idx):
# print(self.base_scores)
idx = self.active_indices[idx]
data = dict(
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
)
if hasattr(self, "base_scores"):
data["base_scores"] = torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
)
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=data,
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
idx2scores_to_remove = {}
for pop_idx, idx in enumerate(self.active_indices):
data_id = self.ids[idx]
if data_id not in eval_scores:
continue
if eval_scores[data_id] > self.filter_threshold:
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
# Control the number of samples to be removed according to max_filter_percentage.
n = int(len(self.active_indices) * self.max_filter_percentage)
indices_to_remove = sorted(
idx2scores_to_remove.keys(),
key=lambda x: idx2scores_to_remove[x],
reverse=True,
)[:n]
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
f"Max filter percentage: {self.max_filter_percentage}. "
f"Current number of eval scores: {len(eval_scores)}."
)
if __name__ != "__main__":
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
else:
from transformers import AutoTokenizer
dataset = MATHCodePromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=AutoTokenizer.from_pretrained(
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
),
),
max_length=512,
dataset_path='/storage/datasets/full_prompts_for_r1_distilled.jsonl'
)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=4,
shuffle=True,
)
print(f"size: {len(dataset)}")
for d in dataloader:
# print(d.ids)
pass

View File

@@ -4,13 +4,10 @@ import json
import os
import signal
import subprocess
import time
import traceback
import uuid
from typing import *
from realhf.base import logging
from realhf.base.constants import parallelism_rank
logger = logging.getLogger("math parser")
@@ -48,20 +45,7 @@ def loadJson(dataDir):
return samples
headers = {
"Content-Type": "application/json",
}
id2info = None
def parse_line(prompt_str, generated, query_id):
global id2info
if id2info is None:
try:
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
except KeyError as e:
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
def parse_line(id2info, prompt_str, generated, query_id):
info = id2info[query_id.split("@idx:")[0]]
tmp_id = str(uuid.uuid4())
@@ -112,17 +96,12 @@ def parse_line(prompt_str, generated, query_id):
def parse_lines_in_parallel(
id2info,
generateds: List,
query_ids: List,
max_workers=22,
check_xml_format=False,
) -> List:
global id2info
if id2info is None:
try:
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
except KeyError as e:
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
@@ -204,17 +183,19 @@ def parse_lines_in_parallel(
if __name__ == "__main__":
sample = {
"prompt": "",
"query_id": "35ecd821a9e7e31da9ef0663a25347ce",
# "answer_in_box": ["\\boxed{\\frac{1}{2}}", "<think></think><answer>\\boxed{\\frac{1}{2}}</answer>"]
"answer": "<think>\n1. The problem requires us to determine the number of sequences of 144 hand movements such that every position appears exactly once and the hands return to the initial position at the end.\n2. We know that each movement involves one hand moving clockwise to the next number while the other hand stays in place.\n3. Considering the 12-hour clock, we can represent each positioning of the hands as a combination of the positions of both hands. Since both hands can be in any of the 12 positions, there are 12 x 12 = 144 different positionings.\n4. Given that at each position only one hand moves, every single movement is unique, leading to a total of 144 unique movements.\n5. These 144 movements must form a Hamiltonian Cycle, where each edge represents a valid movement between two positions.\n6. The problem thus reduces to finding a Hamiltonian cycle in a directed graph. Since the appearance of each movement is unique, it also determines the direction of the movement.\n7. We consider the edges that are rotations of each other as equivalent. Taking the rotational symmetry into account, we have 144/12 = 12 equivalence classes.\n8. The problem now is to determine the number of ways to arrange these 12 classes of rotations in a circle, which is 11 factorial.\n9. We must find the value of 11! and then compute the result modulo 1000.\n</think>\n<answer>\n320\n</answer>",
"answers": ["-\\frac{2}{3}"],
"solutions": [
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
"\\boxed{-\\frac{2}{3}}",
],
}
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
print(
parse_lines_in_parallel(
# [sample["answer_in_box"][0] for _ in range(50)] + [sample["answer_in_box"][1] for _ in range(50)],
[sample["answer"]] * 100,
[sample["query_id"] for _ in range(100)],
id2info,
sample["answers"] * 100,
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
max_workers=8,
check_xml_format=True,
)

View File

@@ -2,7 +2,6 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import uuid
from typing import Callable, Dict, Hashable, List, Optional
import numpy as np
@@ -85,312 +84,4 @@ class PromptDataset(torch.utils.data.Dataset):
)
class MATHPromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
fill_to_max_length: bool = False,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
"""A dataset with prompts. Usually used for PPO.
Args:
util (api.data.DatasetUtility): .
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
The json/jsonl file should be a list of dictionary. Each element in the list should have
a key "prompt". Defaults to None.
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
A callable that returns a list of dictionary. Defaults to None.
"""
self._util = util
self.max_length = max_length
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
prompts_str = [x["prompt"] for x in data]
self.ids = [x["query_id"] for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
prompt_encodings = util.tokenizer(
prompts_str,
truncation=True,
# max_length=max_length,
padding=False,
return_length=True,
return_attention_mask=False,
)
if fill_to_max_length:
for i in range(len(prompt_encodings["length"])):
x = prompt_encodings["input_ids"][i]
if max_length > len(x):
# Fill with the final non-pad token to the left.
prompt_encodings["input_ids"][i] = [x[-1]] * (
max_length - len(x)
) + x
prompt_encodings["length"][i] = max_length
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(f"{len(indices)} samples remain")
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
self.ids = [self.ids[idx] + f"@idx:{idx}-{util.dp_rank}" for idx in indices]
if "scores" in data[0]:
self.base_scores = [self.base_scores[idx] for idx in indices]
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
self.active_indices = list(range(len(self.prompts)))
self.filter_threshold = filter_threshold
self.max_filter_percentage = max_filter_percentage
@property
def util(self):
return self._util
def __len__(self):
return len(self.active_indices)
def __getitem__(self, idx):
# print(self.base_scores)
idx = self.active_indices[idx]
if hasattr(self, "base_scores"):
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
base_scores=torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
),
),
)
else:
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
),
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
idx2scores_to_remove = {}
for pop_idx, idx in enumerate(self.active_indices):
data_id = self.ids[idx]
if data_id not in eval_scores:
continue
if eval_scores[data_id] > self.filter_threshold:
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
# Control the number of samples to be removed according to max_filter_percentage.
n = int(len(self.active_indices) * self.max_filter_percentage)
indices_to_remove = sorted(
idx2scores_to_remove.keys(),
key=lambda x: idx2scores_to_remove[x],
reverse=True,
)[:n]
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
f"Max filter percentage: {self.max_filter_percentage}. "
f"Current number of eval scores: {len(eval_scores)}."
)
class CODEPromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
fill_to_max_length: bool = False,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
"""A dataset with prompts. Usually used for PPO.
Args:
util (api.data.DatasetUtility): .
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
The json/jsonl file should be a list of dictionary. Each element in the list should have
a key "prompt". Defaults to None.
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
A callable that returns a list of dictionary. Defaults to None.
"""
self._util = util
self.max_length = max_length
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
prompts_str = [x["question"] for x in data]
self.ids = [x["id"] for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
prompt_encodings = util.tokenizer(
prompts_str,
truncation=True,
# max_length=max_length,
padding=False,
return_length=True,
return_attention_mask=False,
)
if fill_to_max_length:
for i in range(len(prompt_encodings["length"])):
x = prompt_encodings["input_ids"][i]
if max_length > len(x):
# Fill with the final non-pad token to the left.
prompt_encodings["input_ids"][i] = [x[-1]] * (
max_length - len(x)
) + x
prompt_encodings["length"][i] = max_length
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(f"{len(indices)} samples remain")
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
self.ids = [
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
]
if "scores" in data[0]:
self.base_scores = [self.base_scores[idx] for idx in indices]
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
self.active_indices = list(range(len(self.prompts)))
self.filter_threshold = filter_threshold
self.max_filter_percentage = max_filter_percentage
@property
def util(self):
return self._util
def __len__(self):
return len(self.active_indices)
def __getitem__(self, idx):
# print(self.base_scores)
print(
f"CODEPromptDataset idx{idx}, use idx: {self.active_indices[idx]}, size: {len(self.ids)}"
)
idx = self.active_indices[idx]
if hasattr(self, "base_scores"):
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
base_scores=torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
),
),
metadata=dict(random_id=[uuid.uuid4()]),
)
else:
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
),
metadata=dict(random_id=[uuid.uuid4()]),
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
idx2scores_to_remove = {}
for pop_idx, idx in enumerate(self.active_indices):
data_id = self.ids[idx]
if data_id not in eval_scores:
continue
if eval_scores[data_id] > self.filter_threshold:
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
# Control the number of samples to be removed according to max_filter_percentage.
n = int(len(self.active_indices) * self.max_filter_percentage)
indices_to_remove = sorted(
idx2scores_to_remove.keys(),
key=lambda x: idx2scores_to_remove[x],
reverse=True,
)[:n]
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Code prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
f"Max filter percentage: {self.max_filter_percentage}. "
f"Current number of eval scores: {len(eval_scores)}."
)
if __name__ != "__main__":
data_api.register_dataset("prompt", PromptDataset)
data_api.register_dataset("math_prompt", MATHPromptDataset)
data_api.register_dataset("code_prompt", CODEPromptDataset)
else:
from transformers import AutoTokenizer
dataset = MATHPromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=AutoTokenizer.from_pretrained(
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
),
),
max_length=512,
dataset_path="/storage/openpsi/data/math/Qwen_RL_training_xss/0101/train_rl@0101_with_qwqsft-7b_score.jsonl",
)
dataset = CODEPromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=AutoTokenizer.from_pretrained(
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
),
),
max_length=512,
dataset_path="/storage/datasets/codeparrot-apps-test.jsonl",
)
data_generator = enumerate(dataset)
dataset_batch_counter, cur_sample = next(data_generator)
print(
f"size: {len(dataset)}, index: {dataset_batch_counter}, cur_sample: {cur_sample}"
)
data_api.register_dataset("prompt", PromptDataset)

View File

@@ -1,429 +0,0 @@
# Copyright 2025 Ant Group Inc.
import collections
import dataclasses
import html
import json
import os
import re
import xml.etree.ElementTree as ET
from ast import parse
from concurrent.futures import ThreadPoolExecutor, as_completed
import colorama
import numpy as np
import torch
import torch.distributed as dist
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.base import constants
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
class CodeVerifierException(Exception):
pass
def extract_python_code(text, min_length=20, strict_syntax=True):
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
# verify code syntax
if strict_syntax:
try:
parse(clean_block, mode="exec")
except (SyntaxError, IndentationError):
continue
valid_blocks.append(clean_block)
if not valid_blocks:
return None
# return the last code block
return valid_blocks[-1]
def check_with_elementtree(text):
def escape_between_tags(text, tags=["think", "answer"]):
"""转义标签之间的内容,但保留标签本身."""
# 构建标签模式
tag_pattern = "|".join(tags)
parts = []
current_pos = 0
# Match opening and closing tags
pattern = f"</?({tag_pattern})[^>]*>"
for match in re.finditer(pattern, text):
# Add the content before the tag (escaped)
if current_pos < match.start():
parts.append(html.escape(text[current_pos : match.start()]))
# Add the tag itself (not escaped)
parts.append(match.group())
current_pos = match.end()
# Add any remaining trailing content
if current_pos < len(text):
parts.append(html.escape(text[current_pos:]))
return "".join(parts)
text = escape_between_tags(text)
if not text.strip().startswith("<think>"):
text = "<think>" + text
try:
xml_text = f"<root>{text}</root>"
x = ET.fromstring(xml_text)
if x.text is not None and x.text.strip() != "":
return False, f"Error: extra content before <think>. {x.text}"
if len(x) != 2:
return False, f"Error: there are {len(x)} tags."
if x[0].tag is None or x[0].tag != "think":
return False, f"Error: <think> tag is missing. {x[0].tag}"
if x[0].tail is not None and x[0].tail.strip() != "":
return (
False,
f"Error: extra content between <think> and <answer>. {x[0].tail}",
)
if x[1].tag is None or x[1].tag != "answer":
return False, f"Error: <answer> tag is missing. {x[1].tag}"
if x[1].tail is not None and x[1].tail.strip() != "":
return False, f"Error: extra content after <answer>, {x[1].tail}"
return True, x[1].text if x[1].text is not None else ""
except ET.ParseError as e:
return False, f"Error: XML格式错误, {str(e)}"
def retokenize(
task,
tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=False,
do_eval=False,
):
input_ids = [
packed_input_ids[start:end]
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
]
prompt_ids = [
prompts[start:end]
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
]
seq_strs = tokenizer.batch_decode(
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
prompt_strs = tokenizer.batch_decode(
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
# query_id_strs = query_ids
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
format_rewards = []
queryid_to_results = collections.defaultdict(list)
# 8 processes on each node, with 10 subprocesses each
if do_eval == True:
_answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
codes = [extract_python_code(_answer) for _answer in _answers]
logger.info(
f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
)
format_rewards = code_verify(codes, query_id_strs)
if check_xml_format:
with ThreadPoolExecutor(max_workers=22) as executor:
futures = [
executor.submit(check_with_elementtree, answer_str)
for answer_str in _answers
]
# xml_rewards = []
for idx, future in enumerate(futures):
xml_reward, _ = future.result()
# xml_rewards.append(xml_reward)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(format_reward)
else:
for query_id_str in query_id_strs:
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(0)
format_rewards.append(0)
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
@dataclasses.dataclass
class PackedCodeRewardInterface(model_api.ModelInterface):
enable_save: bool = False
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
rm_output_scaling: float = 1.0
rm_output_bias: float = 0.0
output_bias: float = 0.0
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
max_sync_length: int = 2048
rw_type: str = "sparse"
task: str = "code" # math or countdown or code
check_xml_format: bool = False
post_process: str = "sigmoid"
group_size: int = 1
check_verifier_status: bool = False
_call_count: int = 0
def __post_init__(self):
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
logger.info(f"rm_output_bias: {self.rm_output_bias}")
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"max_sync_length: {self.max_sync_length}")
logger.info(f"rw_type: {self.rw_type}")
logger.info(f"post_process: {self.post_process}")
while True:
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{self._call_count}r{dist.get_rank()}.txt",
)
if os.path.exists(gen_file_path):
self._call_count += 1
else:
break
logger.info(f"call_count: {self._call_count}")
def inference(
self,
model: model_api.Model,
data_: SequenceSample,
mb_spec,
) -> SequenceSample:
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
input_cu_seqlens = torch.nn.functional.pad(
input_seqlens.cumsum(0), (1, 0)
).int()
packed_prompts = data_.data["packed_prompts"]
prompts = []
prompt_seqlens = []
offset = 0
for x in data_.seqlens["packed_prompts"]:
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
offset += x[0]
prompt_seqlens.extend(x * self.group_size)
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
# non_packed_prompts = copy.deepcopy(prompts)
prompts = torch.cat(prompts)
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
prompt_cu_seqlens = torch.nn.functional.pad(
prompt_seqlens.cumsum(0), (1, 0)
).int()
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
retokenize(
self.task,
self.tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=self.check_xml_format,
do_eval=constants.is_last_pipe_stage(),
)
)
assert self.rw_type == "sparse"
dense_scores = torch.zeros_like(packed_input_ids).float()
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
scores[scores == 0] = -1
if len(scores) == 0:
return None
assert dense_scores.shape == packed_input_ids.shape
scores = (
scores.to(packed_input_ids.device) - self.output_bias
) * self.output_scaling
logger.info(f"Code reward logging info @v{model.version.global_step}")
logger.info(
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
logger.info(
f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
)
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{self._call_count}r{dist.get_rank()}.txt",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
f"v{self._call_count}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
pass_at_k = np.mean(
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
)
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
logger.info(f"reward: {sum(scores) / len(scores)}")
train_pass_monitor_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"training_monitor",
f"v{self._call_count}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
logger.info(
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
)
with open(train_pass_monitor_file_path, "w") as monitor_file:
for key, value in queryid_to_results.items():
pass1 = sum(value) / len(value)
pass8 = int(sum(value) > 0)
monitor_file.write(
json.dumps(
{"query_id": key, "pass1": pass1, "pass8": pass8},
ensure_ascii=False,
)
+ "\n"
)
self._call_count += 1
if scores.dtype != torch.float32:
scores = scores.to(torch.float32)
if dense_scores.dtype != torch.float32:
dense_scores = dense_scores.to(torch.float32)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
trailing_shapes=dict(rewards=(), dense_rewards=()),
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
ids=data_.ids,
seqlens=dict(
rewards=[
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
for x in data_.seqlens["packed_input_ids"]
],
dense_rewards=data_.seqlens["packed_input_ids"],
),
data=dict(rewards=scores, dense_rewards=dense_scores),
)
# record rewards for each piece of data
avg_scores = []
offset = 0
for i in range(data_.bs):
score_lis = scores[
offset : offset + len(data_.seqlens["packed_input_ids"][i])
]
avg_scores.append(score_lis.mean().item())
offset += len(data_.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
res.metadata["scores"] = avg_scores
if self.check_verifier_status:
avg_score = torch.tensor(
np.mean(avg_scores), device=constants.current_device()
)
dist.all_reduce(
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
)
avg_score /= constants.parallelism_group_size()
avg_score = avg_score.item()
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
raise CodeVerifierException(
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
)
return res
model_api.register_interface("rw_code", PackedCodeRewardInterface)

View File

@@ -0,0 +1,71 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict
import torch
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from realhf.api.core.config import ModelInterfaceAbstraction
from realhf.api.core.data_api import RL_TASKS, MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class FusedThreadingForwardInterface(model_api.ModelInterface):
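"""Run the inference() of several wrapped interfaces concurrently (one thread each) and merge their outputs into a single SequenceSample."""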
def __init__(self, interfaces: Dict[str, ModelInterfaceAbstraction]):
self.interfaces = {
key: model_api.make_interface(interface)
for key, interface in interfaces.items()
}
def run_interface(
self,
interface_name: str,
model,
data,
mb_spec,
) -> SequenceSample | None:
tik = time.perf_counter()
res = self.interfaces[interface_name].inference(model, data, mb_spec)
t = time.perf_counter() - tik
logger.info(f"Interface {interface_name} cost {t} s")
return res
def inference(
self,
model: model_api.Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
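# Fan out: run each wrapped interface's inference() in its own thread, then merge
# the non-None results (e.g. rewards and reference logprobs) into one SequenceSample.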
with ThreadPoolExecutor(max_workers=len(self.interfaces)) as executor:
tasks = []
for interface_name in self.interfaces:
task = executor.submit(
self.run_interface, interface_name, model, data, mb_spec
)
tasks.append(task)
final_result = None
for task in as_completed(tasks):
res = task.result()
if res is None:
continue
if final_result is None:
final_result = res
else:
final_result.update_(res)
return final_result
model_api.register_interface("fused-threading", FusedThreadingForwardInterface)

View File

@@ -1,46 +1,73 @@
# Copyright 2025 Ant Group Inc.
import collections
import copy
import ast
import asyncio
import dataclasses
import html
import itertools
import json
import os
import random
import re
import time
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional
from typing import Dict, List, Tuple
import colorama
import numpy as np
import requests
import torch
import torch.distributed as dist
import tqdm
import transformers
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.api.core.data_api import (
RL_TASKS,
MicroBatchSpec,
SequenceSample,
load_hf_tokenizer,
)
from realhf.base import constants
from realhf.base.constants import data_parallel_group, data_parallel_world_size
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel as math_verify_local
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
class MathVerifierException(Exception):
class VerifierException(Exception):
pass
def extract_python_code(text, min_length=20, strict_syntax=True):
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
# verify code syntax
if strict_syntax:
try:
ast.parse(clean_block, mode="exec")
except (SyntaxError, IndentationError):
continue
valid_blocks.append(clean_block)
if not valid_blocks:
logger.warning(f"failed to extract python code from {text}")
return None
# return the last code block
return valid_blocks[-1]
def check_with_elementtree(text):
def escape_between_tags(text, tags=["think", "answer"]):
"""转义标签之间的内容,但保留标签本身."""
@ -94,27 +121,37 @@ def check_with_elementtree(text):
return False, f"Error: XML格式错误, {str(e)}"
def retokenize(
id2info = {}
def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
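"""Route the decoded answers to the math or code verifier and return one reward per answer."""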
global id2info
assert len(answers) == len(query_id_strs)
format_rewards = []
if task == "math":
format_rewards = math_verify_call(id2info, answers, query_id_strs)
elif task == "code":
codes = [extract_python_code(_answer) for _answer in answers]
format_rewards = code_verify_call(id2info, codes, query_id_strs)
assert len(format_rewards) == len(answers), (
task,
len(format_rewards),
len(answers),
answers,
)
return format_rewards
def retokenize_and_verify(
task,
tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
prompt_ids: List[List[int]],
seq_ids: List[List[int]],
query_ids: List[str],
check_xml_format=False,
do_eval=False,
):
input_ids = [
packed_input_ids[start:end]
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
]
prompt_ids = [
prompts[start:end]
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
]
seq_strs = tokenizer.batch_decode(
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
seq_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
prompt_strs = tokenizer.batch_decode(
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
@@ -122,261 +159,185 @@ def retokenize(
# query_id_strs = query_ids
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
format_rewards = []
queryid_to_results = collections.defaultdict(list)
# 8 processes on each node, with 10 subprocesses each
if do_eval == True:
_answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
format_rewards = math_verify_call(_answers, query_id_strs)
if check_xml_format:
with ThreadPoolExecutor(max_workers=22) as executor:
futures = [
executor.submit(check_with_elementtree, answer_str)
for answer_str in _answers
]
# xml_rewards = []
for idx, future in enumerate(futures):
xml_reward, _ = future.result()
# xml_rewards.append(xml_reward)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
format_rewards = dispatch_reward_calculation(task, answers, query_id_strs)
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(format_reward)
else:
for query_id_str in query_id_strs:
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(0)
format_rewards.append(0)
if check_xml_format:
for idx, answer in enumerate(answers):
xml_reward, _ = check_with_elementtree(answer)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
return format_rewards, prompt_strs, seq_strs
@dataclasses.dataclass
class PackedMathRewardInterface(model_api.ModelInterface):
enable_save: bool = False
class MultiTaskRewardInterface(model_api.ModelInterface):
dataset_path: str = ""
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
rm_output_scaling: float = 1.0
rm_output_bias: float = 0.0
output_bias: float = 0.0
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
max_sync_length: int = 2048
rw_type: str = "sparse"
task: str = "math" # math or countdown
check_xml_format: bool = False
post_process: str = "sigmoid"
group_size: int = 1
check_verifier_status: bool = False
def __post_init__(self):
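# Load the verification metadata once per process; the task verifiers read it through the module-level `id2info`.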
global id2info
id2info, _ = load_metadata(self.dataset_path)
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
logger.info(f"rm_output_bias: {self.rm_output_bias}")
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"max_sync_length: {self.max_sync_length}")
logger.info(f"rw_type: {self.rw_type}")
logger.info(f"post_process: {self.post_process}")
if constants.parallelism_rank() == 0:
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"rw_type: {self.rw_type}")
def inference(
self,
model: model_api.Model,
data_: SequenceSample,
mb_spec,
) -> SequenceSample:
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
input_cu_seqlens = torch.nn.functional.pad(
input_seqlens.cumsum(0), (1, 0)
).int()
packed_prompts = data_.data["packed_prompts"]
prompts = []
prompt_seqlens = []
offset = 0
for x in data_.seqlens["packed_prompts"]:
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
offset += x[0]
prompt_seqlens.extend(x * self.group_size)
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
# non_packed_prompts = copy.deepcopy(prompts)
prompts = torch.cat(prompts)
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
prompt_cu_seqlens = torch.nn.functional.pad(
prompt_seqlens.cumsum(0), (1, 0)
).int()
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
retokenize(
self.task,
self.tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=self.check_xml_format,
do_eval=constants.is_last_pipe_stage(),
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
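"""Split the batch into per-task sub-batches by task_ids, remembering the original indices for later reassembly."""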
xs = data.unpack()
dispatched = {}
dispatched_indices = {}
for task_idx, task_name in enumerate(RL_TASKS):
indices = (
(data.data["task_ids"] == task_idx).cpu().numpy().nonzero()[0].tolist()
)
if len(indices) > 0:
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
dispatched_indices[task_name] = indices
return dispatched, dispatched_indices
def _gather_tasks(
self, results: Dict, dispatched_indices: Dict, bs: int
) -> SequenceSample:
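"""Reassemble the per-task results into one batch in the original sample order."""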
xs = [None for _ in range(bs)]
for task_name, indices in dispatched_indices.items():
xxs = results[task_name].unpack()
assert len(indices) == len(xxs), (len(indices), len(xxs))
for i, xx in zip(indices, xxs):
xs[i] = xx
assert all(xs)
return SequenceSample.gather(xs)
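# Reward verification needs no model parallelism, so the batch is split evenly across
# the tensor/pipeline ranks and each rank verifies only its own slice.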
def _dispatch_tp_and_pp(self, data: SequenceSample):
tp_pp_size = constants.tp_and_pp_world_size()
if tp_pp_size == 1:
return data, None
splitted, _, backward_indices = data.split(
mb_spec=MicroBatchSpec(n_mbs=tp_pp_size)
)
tp_pp_rank = constants.tp_and_pp_rank()
print("dispatched batch size", [s.bs for s in splitted], flush=True)
return splitted[tp_pp_rank], backward_indices
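# Gather the per-rank reward lists onto the last-pipe-stage rank of this data-parallel
# group, undo the earlier split order, and rebuild a rewards-only SequenceSample.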
def _gather_tp_and_pp(self, input_, data: SequenceSample, backward_indices):
tp_pp_size = constants.tp_and_pp_world_size()
if tp_pp_size == 1:
return data
local_rank = constants.grid().topo.get_rank(
data=constants.data_parallel_rank(),
model=0,
pipe=constants.pipe_parallel_world_size() - 1,
)
dst = constants.to_global_pg_rank(local_rank)
gather_list = None
if dist.get_rank() == dst:
gather_list = [None for _ in range(tp_pp_size)]
x = data.data["rewards"].cpu().numpy().tolist()
print(x, flush=True)
dist.gather_object(
x, gather_list, dst=dst, group=constants.tp_and_pp_cpu_group()
)
if dist.get_rank() != dst:
return None
gathered = np.array(gather_list).reshape(-1, self.group_size)
assert len(gathered) == len(backward_indices)
rewards = (
np.concatenate([gathered[i] for i in backward_indices]).flatten().tolist()
)
return SequenceSample(
keys=["rewards"],
trailing_shapes=dict(rewards=()),
dtypes=dict(rewards=torch.float32),
ids=input_.ids,
seqlens=dict(
rewards=[[1 for _ in range(self.group_size)] for _ in range(input_.bs)],
),
data=dict(rewards=torch.tensor(rewards, dtype=torch.float32)),
)
assert self.rw_type == "sparse"
dense_scores = torch.zeros_like(packed_input_ids).float()
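# Compute sparse sequence-level rewards for one task type: unpack the packed prompts
# and generations, run the task-specific verifier via retokenize_and_verify, map
# failures to -1, then apply output_bias/output_scaling and dump the logs.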
def calculate_task_reward(
self,
model: model_api.Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
task_type: str,
):
# NOTE: mb_spec is not respected here; the whole batch is verified in one pass.
packed_input_ids: torch.Tensor = data.data["packed_input_ids"]
input_seqlens = flat2d(data.seqlens["packed_input_ids"])
seq_ids = []
offset = 0
for slen in input_seqlens:
seq_ids.append(
packed_input_ids[offset : offset + slen].cpu().numpy().tolist()
)
offset += slen
assert offset == packed_input_ids.shape[0], (offset, packed_input_ids.shape)
prompt_input_ids = data.data["packed_prompts"]
prompt_len = flat2d(data.seqlens["packed_prompts"])
prompt_ids = []
offset = 0
for slen in prompt_len:
p = prompt_input_ids[offset : offset + slen].cpu().numpy().tolist()
prompt_ids += [p] * self.group_size
offset += slen
format_rewards, prompt_strs, seq_strs = retokenize_and_verify(
task_type,
self.tokenizer,
prompt_ids=prompt_ids,
seq_ids=seq_ids,
query_ids=[
str(data_id) for data_id in data.ids for _ in range(self.group_size)
],
check_xml_format=self.check_xml_format,
)
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
scores[scores == 0] = -1
if len(scores) == 0:
return None
assert dense_scores.shape == packed_input_ids.shape
scores = (
scores.to(packed_input_ids.device) - self.output_bias
) * self.output_scaling
logger.info(f"Math reward logging info @v{model.version.global_step}")
logger.info(
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
if constants.is_last_pipe_stage():
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{model.version.global_step}r{dist.get_rank()}.txt",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(
f"Generated samples and rewards will be dumped to: {gen_file_path}"
)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(
f"Generated samples and rewards will be dumped to: {gen_file_path}"
)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
pass_at_k = np.mean(
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
)
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
logger.info(f"reward: {sum(scores) / len(scores)}")
train_pass_monitor_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"training_monitor",
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
logger.info(
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
)
with open(train_pass_monitor_file_path, "w") as monitor_file:
for key, value in queryid_to_results.items():
pass1 = sum(value) / len(value)
pass8 = int(sum(value) > 0)
monitor_file.write(
json.dumps(
{"query_id": key, "pass1": pass1, "pass8": pass8},
ensure_ascii=False,
)
+ "\n"
)
model.inc_version()
if scores.dtype != torch.float32:
scores = scores.to(torch.float32)
if dense_scores.dtype != torch.float32:
dense_scores = dense_scores.to(torch.float32)
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
trailing_shapes=dict(rewards=(), dense_rewards=()),
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
ids=data_.ids,
keys=["rewards"],
trailing_shapes=dict(rewards=()),
dtypes=dict(rewards=torch.float32),
ids=data.ids,
seqlens=dict(
rewards=[
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
for x in data_.seqlens["packed_input_ids"]
[1 for _ in range(len(x))] for x in data.seqlens["packed_input_ids"]
],
dense_rewards=data_.seqlens["packed_input_ids"],
),
data=dict(rewards=scores, dense_rewards=dense_scores),
data=dict(rewards=scores),
)
# record rewards for each piece of data
avg_scores = []
offset = 0
for i in range(data_.bs):
for i in range(data.bs):
score_lis = scores[
offset : offset + len(data_.seqlens["packed_input_ids"][i])
offset : offset + len(data.seqlens["packed_input_ids"][i])
]
avg_scores.append(score_lis.mean().item())
offset += len(data_.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
offset += len(data.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data.seqlens["packed_input_ids"])
res.metadata["scores"] = avg_scores
@ -385,20 +346,165 @@ class PackedMathRewardInterface(model_api.ModelInterface):
np.mean(avg_scores), device=constants.current_device()
)
dist.all_reduce(
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
avg_score, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
)
avg_score /= constants.parallelism_group_size()
avg_score /= dist.get_world_size(group=constants.data_parallel_group())  # average over data-parallel ranks
avg_score = avg_score.item()
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
minimal_score = (-1 - self.output_bias) * self.output_scaling
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
raise MathVerifierException(
raise VerifierException(
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
)
if not constants.is_last_pipe_stage():
return None
return res
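# Dump prompts, generations, and rewards for this task both as colored plain text and
# as JSONL under LOG_ROOT for offline inspection.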
def log_rewards_to_file(
self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
):
tik = time.perf_counter()
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
task_type,
f"v{model.version.global_step}r{dist.get_rank()}.txt",
)
model_api.register_interface("rw_math", PackedMathRewardInterface)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
task_type,
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(f"[{task_type}] number of samples: {len(scores)}, {scores.shape}")
logger.info(f"[{task_type}] avg reward: {sum(scores) / len(scores)}")
logger.info(f"[{task_type}] log to file time: {time.perf_counter()- tik:.2f}s")
def inference(
self,
model: model_api.Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
input_ = data
data, backward_indices = self._dispatch_tp_and_pp(data)
task_data, dispatch_indices = self._dispatch_tasks(data)
assert self.rw_type == "sparse"
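# Wrap each per-task reward function so it is timed and its result is tagged with the
# task type; any exception cancels the surrounding asyncio gather.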
def _task_func(func, task_type: str):
def _wrapped_func(*args, **kwargs):
start_time = time.perf_counter()
try:
result = func(*args, **kwargs)
except Exception as e:
raise asyncio.CancelledError(
f"[{task_type}] task failed: {e}"
) from e
finally:
duration = time.perf_counter() - start_time
logger.info(f"[{task_type}] time cost: {duration:.4f}s")
return task_type, result
return _wrapped_func
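# Run every task's reward computation in its own thread via asyncio.to_thread, so that
# code, math, and reference verification execute in parallel rather than serially.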
async def _run_tasks():
tasks = []
for task_type, d in task_data.items():
task_func = _task_func(self.calculate_task_reward, task_type)
task_args = (model, d, mb_spec, task_type)
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
tasks.append(task)
results = await asyncio.gather(*tasks)
task_results = {}
for res in results:
task_type, result = res
task_results[task_type] = result
return task_results
task_results = asyncio.run(_run_tasks())
final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
final_result = self._gather_tp_and_pp(input_, final_result, backward_indices)
model.inc_version()
return final_result
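# Mock inference for tests: append a deterministic "correct" answer to each prompt
# (a boxed math result or a trivial passing program) so the reward pipeline can be
# exercised without a real generation model.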
def _mock_inference(
self,
model: model_api.Model,
data: SequenceSample,
) -> SequenceSample:
prompt_lens = flat2d(data.seqlens["packed_prompts"])
task_ids = data.data["task_ids"].cpu().numpy().tolist()
seqlens = []
offset = 0
seq = []
for plen, task_id in zip(prompt_lens, task_ids):
seq += [data.data["packed_prompts"][offset : offset + plen]]
offset += plen
if task_id == RL_TASKS.index("math"):
answer_str = (
"something unimportant but the answer is \\boxed{-\\frac{2}{3}}."
)
elif task_id == RL_TASKS.index("code"):
answer_str = (
"```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"
)
else:
answer_str = "something unimportant"
encoding = model.tokenizer(
[answer_str], add_special_tokens=True, return_attention_mask=False
)
ans = torch.tensor(encoding["input_ids"], dtype=torch.long).flatten()
seq += [ans]
seqlens.append(plen + len(ans))
x = SequenceSample.from_default(
seqlens=seqlens,
ids=data.ids,
data=dict(packed_input_ids=torch.cat(seq)),
)
data.update_(x)
return data
model_api.register_interface("rw-math-code", MultiTaskRewardInterface)

View File

@ -18,7 +18,7 @@ import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_generate import concat_prompt_to_generation_output
from realhf.impl.model.utils.functional import (

View File

@ -618,8 +618,6 @@ class RayController:
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),

View File

@ -920,7 +920,7 @@ class ModelWorker(worker_base.Worker):
"dataset_eval_scores.json",
)
eval_scores = {}
if isinstance(res, data_api.SequenceSample):
if isinstance(res, data_api.SequenceSample) and constants.is_dp_head():
if rpc.output_key_remap:
res.remap_keys_(rpc.output_key_remap)
res = res.select(rpc.output_keys)

View File

@ -2,6 +2,7 @@
import os
import shutil
import uuid
from typing import *
import pytest
@ -27,23 +28,25 @@ def model_class(request):
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
def math_dataset(request, save_path):
with open(os.getenv("REAL_MATH_METADATA_PATH"), "r") as f:
query_ids = list(json.load(f).keys())
def math_code_dataset(request, save_path):
size = request.param
max_prompt_len = 8
max_resp_len = 8
dataset = []
for i in range(size):
prompt_len = random.randint(1, max_prompt_len)
n_pairs = random.randint(1, 5)
d = dict(
query_id=query_ids[i],
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
task=random.choice(["math", "code"]),
)
if d["task"] == "math":
d["solutions"] = [generate_random_sentence(max_resp_len)]
elif d["task"] == "code":
d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
dataset.append(d)
with open(str(save_path / "math_dataset.json"), "w") as f:
json.dump(dataset, f)
with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
f.write(json.dumps(d) + "\n")
return dataset
@ -59,7 +62,7 @@ def math_dataset(request, save_path):
def test_ppo_symm(
tmp_path_factory,
tokenizer,
math_dataset,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
@ -97,13 +100,8 @@ def test_ppo_symm(
init_critic_from_actor=True,
backend="mock_train",
),
rew=ModelTrainEvalConfig(
path=str(save_path),
init_critic_from_actor=True,
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
@ -116,6 +114,7 @@ def test_ppo_symm(
use_cuda_graph=False,
),
),
group_size=2,
)
run_test_exp(exp_cfg)
@ -133,7 +132,7 @@ def test_ppo_symm(
def test_ppo_global_reshard(
tmp_path_factory,
tokenizer,
math_dataset,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
@ -180,7 +179,7 @@ def test_ppo_global_reshard(
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
@ -249,7 +248,7 @@ def test_ppo_global_reshard(
def test_ppo_param_realloc_sub_device_mesh(
tmp_path_factory,
tokenizer,
math_dataset,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
@ -292,7 +291,7 @@ def test_ppo_param_realloc_sub_device_mesh(
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
@ -410,7 +409,7 @@ def test_ppo_save(
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=bs,
fill_to_max_length=False,

View File

@ -0,0 +1,113 @@
# Copyright 2025 Ant Group Inc. All Rights Reserved.
import os
import shutil
import uuid
from typing import *
import pytest
import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from realhf.api.core.data_api import (
DatasetUtility,
MicroBatchSpec,
SequenceSample,
load_hf_tokenizer,
)
from realhf.api.core.model_api import FinetuneSpec, Model
from realhf.base import constants, network, testing
from tests.fixtures import *
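# Fixture: write a small synthetic JSONL dataset mixing "code" entries (with
# input/output pairs and reference solutions) and "math" entries (with boxed answers).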
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
def math_code_dataset(request, save_path):
size = request.param
max_prompt_len = 8
dataset = []
for i in range(size):
prompt_len = random.randint(1, max_prompt_len)
n_pairs = random.randint(1, 5)
if random.random() < 0.5:
d = dict(
task="code",
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
problem_id=str(uuid.uuid4()),
input_output=json.dumps(
{"inputs": ["1\n"] * 8, "outputs": ["1\n"] * 8}
),
solutions=json.dumps(
["```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"]
* 3
),
difficulty=random.random() * 10,
)
else:
d = dict(
task="math",
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
answers=["\\boxed{-\\frac{2}{3}}"],
solutions=["\\boxed{-\\frac{2}{3}}"],
)
dataset.append(d)
with open(str(save_path / "math_code_dataset.jsonl"), "w") as f:
f.write("\n".join([json.dumps(d) for d in dataset]))
return dataset
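# End-to-end check: load the synthetic math/code dataset, mock a generation for each
# prompt, and assert that every mocked answer receives a nonzero reward.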
@pytest.mark.parametrize(
"tokenizer_path", ["/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"]
)
def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_dataset):
from realhf.impl.dataset.math_code_dataset import MATHCodePromptDataset
dist.init_process_group(
rank=0, world_size=1, init_method=f"tcp://localhost:{network.find_free_port()}"
)
testing.init_global_constants()
dataset = MATHCodePromptDataset(
DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=load_hf_tokenizer(tokenizer_path),
),
max_length=512,
dataset_path=str(save_path / "math_code_dataset.jsonl"),
)
dataloader = DataLoader(
dataset,
collate_fn=SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=4,
shuffle=True,
)
from realhf.impl.model.interface.rw_interface import MultiTaskRewardInterface
with constants.model_scope(testing.MODEL_NAME):
interface = MultiTaskRewardInterface(
dataset_path=str(save_path / "math_code_dataset.jsonl"),
tokenizer_path=tokenizer_path,
group_size=1,
check_verifier_status=False,
)
model = Model(
name="test",
module=None,
tokenizer=load_hf_tokenizer(tokenizer_path),
device=torch.device("cpu"),
ft_spec=FinetuneSpec(
total_train_epochs=1, dataset_size=100, train_batch_size=3
),
)
for d in dataloader:
d = interface.mock("inference", model, d)
rewards = interface.inference(model, d, mb_spec=MicroBatchSpec())
d.update_(rewards)
assert rewards.data["rewards"].all(), rewards.data["rewards"]
dist.destroy_process_group()