mirror of https://github.com/inclusionAI/AReaL
PullRequest: 49 Enable Parallel Execution for Code, Math, and Reference Tasks
Merge branch async-ref-rew of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/49 Signed-off-by: 博惟 <bowei.fw@antgroup.com>
This commit is contained in:
commit f8afa97484

README.md | 14

@@ -54,22 +54,21 @@ Please refer to the [tutorial](/examples/README.md) under the `examples` directory
# Download the dataset
DATA_PATH=/storage/datasets/
cd $DATA_PATH
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true

# Training in a Ray cluster with 16 nodes

# stage 1
MODEL_PATH=${path_to_DeepSeek-R1-Distill-Qwen-1.5B}
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 8192
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 8192

# stage 2
MODEL_PATH=${model_path_from_stage_1}
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 16384
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 16384

# stage 3
MODEL_PATH=${model_path_from_stage_2}
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 24000
bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 24000

```

@@ -134,11 +133,10 @@ Table 2. Evaluation on AIME 2024, AMC 2023, and MATH-500. The results are reported
# Training in a Ray cluster with 16 nodes
DATA_PATH=/storage/datasets/
cd $DATA_PATH
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true

MODEL_PATH=${path_to_Qwen2.5-7B}
bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_zero.jsonl $DATA_PATH/id2info.json 24000
bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_orz_zero.jsonl 24000

```

@@ -114,9 +114,8 @@ We provide a dataset for training. Download the dataset and place it in `/storage/datasets/`
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
```

## Model

@@ -291,7 +290,6 @@ The descriptions of the important parameters are as follows:
+ `MODE`: It is always `ray`; do not change it to other values when following this tutorial for training.
+ `BASE_MODEL_PATH`: The path of the model.
+ `DATA_PATH`: The path of the dataset jsonl file.
+ `REAL_MATH_METADATA_PATH`: Set it to the path of the math metadata JSON file; refer to the troubleshooting section.
+ `CLUSTER_SPEC_PATH`: Set it to the path of `cluster_config.json`.

+ `n_nodes`: The number of nodes.

@@ -513,10 +511,5 @@ This error typically occurs during the vLLM generation phase and is another symptom
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).

## Others

### How to train with other datasets

The dataset must be in `jsonl` format, with each entry containing two keys: `prompt` (a math problem) and `query_id` (a unique identifier for the problem). After preparing the dataset, update `REAL_MATH_METADATA_PATH` to point to the new dataset's metadata. The metadata is a JSON file that records the solutions for each problem; the training code uses it to determine whether the model solved a problem correctly.

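For illustration, here is a minimal sketch of what one dataset line and the matching metadata entry might look like. The `query_id` and problem are invented, and the metadata may carry additional fields beyond `solutions` in the released `id2info.json`.

```python
import json

# Hypothetical dataset entry (one JSON object per line in the jsonl file).
prompt_entry = {
    "prompt": "What is 1 + 1?",   # the math problem
    "query_id": "example-0001",   # unique identifier, made up here
}

# Hypothetical metadata entry in the REAL_MATH_METADATA_PATH JSON file,
# keyed by query_id and recording reference solutions used for reward checking.
id2info = {"example-0001": {"solutions": ["\\boxed{2}"]}}

print(json.dumps(prompt_entry))       # one line of the dataset jsonl
print(json.dumps(id2info, indent=2))  # contents of the metadata json
```
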
@@ -120,18 +120,8 @@ cd /storage/ray/
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
```

If you cannot access `huggingface.co`, you can also download the datasets from ModelScope:
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_r1_distilled.jsonl
wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_zero.jsonl
wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/id2info.json
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
```

## Model

@@ -146,15 +136,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-D

You can also use the huggingface CLI to download the models after installing huggingface_hub from PyPI; see the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli).

If you cannot access `huggingface.co`, you can also download the models from ModelScope (make sure Git LFS is installed):

```
mkdir -p /storage/models
cd /storage/models
git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B.git
git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.git
```

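As a rough sketch of the huggingface_hub route mentioned above (not part of the original docs), `snapshot_download` can fetch a model into `/storage/models`; the repo id below is the one used elsewhere in this tutorial, and the snippet assumes `pip install huggingface_hub`.

```python
# Minimal sketch: download a model snapshot into the tutorial's model directory.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    local_dir="/storage/models/DeepSeek-R1-Distill-Qwen-1.5B",
)
```
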

## Launch the Ray Cluster

@@ -315,7 +296,6 @@ python3 -m realhf.apps.quickstart ppo-math --show-args
+ MODE: always `ray`; do not change it to other values when following this tutorial for training.
+ BASE_MODEL_PATH: the path of the model.
+ DATA_PATH: the path of the dataset jsonl file.
+ REAL_MATH_METADATA_PATH: set it to the path of the math metadata JSON file; refer to the troubleshooting section.
+ CLUSTER_SPEC_PATH: set it to the path of cluster_config.json.

+ n_nodes: the number of nodes.

@@ -530,8 +510,3 @@ ALL_PARAMS=(
+ Reduce the training batch size or the number of answers generated per prompt; note that this lowers sample efficiency and extends training time.
+ [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).

## Others

### How to train with other datasets

The dataset must be a `jsonl` file in which every entry contains two keys: `prompt`, a math problem, and `query_id`, a unique identifier of that problem. After preparing the dataset, you also need to update the contents of REAL_MATH_METADATA_PATH according to the problems in the new dataset. The metadata is a JSON file that records the answer, source, and solutions of each problem; the training code relies on it to judge whether the model solved a problem correctly.

@@ -0,0 +1,240 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

"""
Dataset Toolkit - Process and validate code/math datasets with flexible input support
"""
import json
import argparse
import logging
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict

# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)


def load_jsonl(file_path: str) -> List[Dict]:
    """Load JSONL file with validation"""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f]
    except FileNotFoundError:
        print(f"ERROR: JSONL file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def process_code_data(file_path: str) -> List[Dict]:
    """Process code dataset from JSONL file"""
    if not file_path:
        return []

    raw_data = load_jsonl(file_path)
    processed = []

    for item in raw_data:
        # Field extraction and transformation
        input_output = json.loads(item["input_output"])
        processed.append(
            {
                "task": "code",
                "query_id": str(item["id"]),
                "prompt": item["question"],
                "input_output": json.dumps(
                    {
                        "inputs": input_output.get("inputs", []),
                        "outputs": input_output.get("outputs", []),
                        "fn_name": item.get("metadata", {}).get("fn_name", ""),
                    }
                ),
            }
        )

    return processed


def process_math_data(file_path: str) -> List[Dict]:
    """Process math dataset from JSON/JSONL file"""
    if not file_path:
        return []

    raw_data = load_jsonl(file_path)
    processed = []

    for item in raw_data:
        processed.append(
            {
                "task": "math",
                "query_id": str(item["query_id"]),
                "prompt": item["prompt"],
                "solutions": item["solutions"],
            }
        )

    return processed


def validate_raw_code(item: Dict) -> Tuple[bool, List[str]]:
    """Validate raw code item structure"""
    errors = []
    required = {
        "task": str,
        "query_id": str,
        "prompt": str,
        "input_output": str,
    }

    code_input_output_required = {
        "inputs": list,
        "outputs": list,
        #'fn_name': str
    }

    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            errors.append(f"Invalid type for {field}: expected {typ.__name__}")

    # Bail out before parsing if "input_output" is missing or not a JSON string,
    # so a malformed item is reported as invalid instead of raising a KeyError.
    if "input_output" not in item or not isinstance(item["input_output"], str):
        return (not errors, errors)

    input_output = json.loads(item["input_output"])

    for io_field, io_typ in code_input_output_required.items():
        if io_field not in input_output:
            errors.append(f"Missing required field: {io_field} in input_output")
        elif not isinstance(input_output[io_field], io_typ):
            errors.append(
                f"Invalid type for {io_field}: expected {io_typ.__name__} in input_output"
            )

    return (not errors, errors)


def validate_raw_math(item: Dict) -> Tuple[bool, List[str]]:
    """Validate raw math item structure"""
    errors = []
    required = {"query_id": str, "prompt": str, "solutions": list}

    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            type_names = (
                [t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
            )
            errors.append(f"Invalid type for {field}: expected {type_names}")

    return (not errors, errors)


def validate_raw_item(item: Dict) -> Tuple[bool, List[str]]:
    """Dispatch validation based on the item's task type"""
    if item.get("task", "") == "math":
        return validate_raw_math(item)
    else:
        return validate_raw_code(item)


def check_item_valid(raw_data: list):
    if not raw_data:
        return defaultdict(int)
    logger.info("Validating code-math dataset...")
    total_stats = defaultdict(int)
    for index, item in enumerate(raw_data):
        valid, errors = validate_raw_item(item)
        if valid:
            total_stats["valid"] += 1
        else:
            total_stats["invalid"] += 1
            logger.warning(f"item {index}: {errors}")
    return total_stats


def filter_shuffle(shuffle: bool, processed_data):
    # Final validation and write output
    valid_output = []
    invalid_count = 0
    for item in processed_data:
        valid, errors = validate_raw_item(item)
        if valid:
            valid_output.append(item)
        else:
            invalid_count += 1
            logger.error(
                f'{item["task"]} item {item.get("query_id")} is invalid: {errors}'
            )
    if shuffle:
        random.shuffle(valid_output)
    return valid_output, invalid_count


def save_file(output_path: str, processed_data: list):
    with open(output_path, "w") as f:
        for item in processed_data:
            f.write(json.dumps(item) + "\n")


def main():
    parser = argparse.ArgumentParser(
        description="Dataset Toolkit: Process and validate STEM datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--code", help="Path to code dataset (JSONL)")
    parser.add_argument("--math", help="Path to math dataset (JSONL)")
    parser.add_argument("--output", help="Output file path (JSONL)")
    parser.add_argument(
        "--mode",
        choices=["check", "process"],
        default="check",
        help="Operation mode: check raw data or process datasets",
    )
    parser.add_argument(
        "--shuffle", action="store_true", help="Shuffle output data in process mode"
    )

    args = parser.parse_args()
    if args.mode == "check":
        # Validate raw data structure
        code_data = load_jsonl(args.code) if args.code else []
        math_data = load_jsonl(args.math) if args.math else []
        total_stats = check_item_valid(code_data + math_data)

        # Print validation summary
        logger.info(
            f"\nValidation Summary: {total_stats['valid']} valid, {total_stats['invalid']} invalid"
        )
    elif args.mode == "process":
        # Process and merge datasets
        if not args.output:
            logger.error("Output file required in process mode")
            return
        processed_data = []
        stats = defaultdict(int)
        if args.code:
            code_data = process_code_data(args.code)
            logger.info(f"Loaded {len(code_data)} code items")
            processed_data.extend(code_data)
            stats["code"] = len(code_data)
        if args.math:
            math_data = process_math_data(args.math)
            logger.info(f"Loaded {len(math_data)} math items")
            processed_data.extend(math_data)
            stats["math"] = len(math_data)

        processed_data, invalid_count = filter_shuffle(args.shuffle, processed_data)
        stats["invalid"] = invalid_count
        save_file(args.output, processed_data)
        logger.info("\nProcessing Complete:")
        logger.info(f"Total items: {len(processed_data)}")
        logger.info(f"Code items: {stats['code']}")
        logger.info(f"Math items: {stats['math']}")
        logger.info(f"Invalid items: {stats['invalid']}")


if __name__ == "__main__":
    main()

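For reference, a record that would pass `validate_raw_code` above might look like the following sketch (the id, prompt, and test case are invented). Note that `input_output` is itself a JSON-encoded string rather than a nested object.

```python
import json

# Hypothetical merged "code" record as emitted by the toolkit (values are made up).
code_item = {
    "task": "code",
    "query_id": "12345",
    "prompt": "Write a program that adds two integers read from stdin.",
    "input_output": json.dumps(
        {"inputs": ["1 2\n"], "outputs": ["3\n"], "fn_name": ""}
    ),
}
print(json.dumps(code_item))  # one line of the output jsonl
```
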
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
import argparse
import json
import random
import sys
from pathlib import Path
from typing import List, Dict, Optional


def load_jsonl(file_path: Path) -> List[Dict]:
    """Load JSONL file with validation"""
    try:
        with file_path.open("r", encoding="utf-8") as f:
            return [json.loads(line) for line in f]
    except FileNotFoundError:
        print(f"ERROR: JSONL file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def load_id2info(file_path: Path) -> Dict:
    """Load ID mapping file with structure validation"""
    try:
        with file_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
            if not all(isinstance(v, dict) and "solutions" in v for v in data.values()):
                raise ValueError("Invalid id2info structure")
            return data
    except FileNotFoundError:
        print(f"ERROR: ID mapping file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def process_data(prompts_data: List[Dict], id2info: Dict, output_path: Path) -> None:
    """Process and save transformed data"""
    processed = []
    missing_ids = 0
    for item in prompts_data:
        query_id = item.get("query_id")
        if not query_id or query_id not in id2info:
            missing_ids += 1
            continue
        processed.append(
            {
                "prompt": item.get("prompt", ""),
                "task": "math",
                "query_id": query_id,
                "solutions": id2info[query_id].get("solutions", []),
            }
        )
    if missing_ids:
        print(f"WARNING: Found {missing_ids} items with missing/invalid query_id")
    if not processed:
        print("ERROR: No valid data to process")
        sys.exit(1)
    random.shuffle(processed)

    try:
        with output_path.open("w", encoding="utf-8") as f:
            for item in processed:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
        print(f"SUCCESS: Wrote {len(processed)} items to {output_path}")
    except IOError as e:
        print(f"ERROR: Failed to write output: {str(e)}")
        raise


def main():
    """
    Command line entry point

    Example usage:
    python math_process.py \
        --prompts_path ./input/prompts.jsonl \
        --id2info_path ./input/id2info.json \
        --output_path ./output/processed.jsonl
    """
    parser = argparse.ArgumentParser(
        description="Math dataset processing tool",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--prompts_path",
        type=Path,
        required=True,
        help="Path to prompts.jsonl input file",
    )
    parser.add_argument(
        "--id2info_path",
        type=Path,
        required=True,
        help="Path to id2info.json input file",
    )
    parser.add_argument(
        "--output_path",
        type=Path,
        default="output.jsonl",
        help="Path for processed output file",
    )

    args = parser.parse_args()
    print("Starting data processing...")
    try:
        prompts_data = load_jsonl(args.prompts_path)
        id2info = load_id2info(args.id2info_path)
        process_data(prompts_data, id2info, args.output_path)
    except Exception as e:
        print(f"FATAL: Processing failed - {str(e)}")
        sys.exit(1)
    print("Operation completed successfully")


if __name__ == "__main__":
    main()

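In other words, the script joins each prompt with its metadata on `query_id` and emits self-contained math records. A rough sketch of the resulting record shape, using an invented id:

```python
# Hypothetical input/output shapes for the script above (ids and values are made up).
prompts_line = {"prompt": "What is 1 + 1?", "query_id": "example-0001"}
id2info = {"example-0001": {"solutions": ["\\boxed{2}"]}}

processed_line = {
    "prompt": prompts_line["prompt"],
    "task": "math",
    "query_id": prompts_line["query_id"],
    "solutions": id2info[prompts_line["query_id"]]["solutions"],
}
```
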
@@ -6,7 +6,7 @@ TRAIN_BATCH_SIZE="1024"
GROUP_SIZE="8"
NODES="16"
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"
MAX_NEW_TOKENS=$4
MAX_NEW_TOKENS=$3
MAX_NUM_SEQS=128
PPO_MBS=4
KL_CTL=0.001

@@ -18,7 +18,6 @@ BASE_MODEL_PATH="$1"

# original data
DATA_PATH="$2"
REAL_MATH_METADATA_PATH="$3"

# Option 1: The experiment runs locally with subprocesses.
# MODE=local

@@ -53,7 +52,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
    mode=$MODE \

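A note on `ALLOCATION_MODE` strings such as `vllm.d64p1m1+d32p2m1`: my reading (an assumption, not stated in this diff) is that each `dXpYmZ` triple encodes data-, pipeline-, and model/tensor-parallel degrees, with the `vllm.`-prefixed part describing generation and the other part training. A tiny parser under that assumption:

```python
import re

def parse_alloc(spec: str) -> dict:
    """Parse 'dXpYmZ' into parallel degrees. Assumes d=data, p=pipeline, m=model/tensor."""
    d, p, m = map(int, re.fullmatch(r"d(\d+)p(\d+)m(\d+)", spec).groups())
    return {"data": d, "pipeline": p, "model": m, "gpus": d * p * m}

gen, train = "vllm.d64p1m1+d32p2m1".split("+")
print(parse_alloc(gen.removeprefix("vllm.")))  # {'data': 64, 'pipeline': 1, 'model': 1, 'gpus': 64}
print(parse_alloc(train))                      # {'data': 32, 'pipeline': 2, 'model': 1, 'gpus': 64}
# 64 + 64 GPUs matches the 16 nodes x 8 GPUs used by this script.
```
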
@@ -5,7 +5,7 @@ TRAIN_BATCH_SIZE="512"
GROUP_SIZE="64"
NODES="16"
ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"
MAX_NEW_TOKENS=$4
MAX_NEW_TOKENS=$3
MAX_NUM_SEQS=128
PPO_MBS=4
KL_CTL=0.0

@@ -17,7 +17,6 @@ BASE_MODEL_PATH="$1"

# original data
DATA_PATH="$2"
REAL_MATH_METADATA_PATH="$3"

# Option 1: The experiment runs locally with subprocesses.
# MODE=local

@@ -52,7 +51,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
    mode=$MODE \

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

EXP_NAME=ppo-zero-distill-1.5B-n1
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=1
ALLOCATION_MODE="actor_gen:d4p1m2,*:d4p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

EXP_NAME=ppo-zero-distill-1.5B-n16
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

EXP_NAME=ppo-zero-distill-1.5B-n4
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=4
ALLOCATION_MODE="vllm.d16p1m1+d8p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

EXP_NAME=ppo-zero-distill-7B-n16
MODEL_NAME="DeepSeek-R1-Distill-Qwen-7B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

EXP_NAME=ppo-zero-7B-zero-n16
MODEL_NAME="Qwen2.5-7B"
DATASET_NAME="prompts_for_zero.jsonl"
DATASET_NAME="full_orz_zero.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

@@ -1,52 +0,0 @@
#!/bin/bash

SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

EXP_NAME=ppo-zero-distill-1.5B-n16-jun1
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="codeparrot-apps-test.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

LOG_DIR="/storage/ray/train_batch_logs/${EXP_NAME}/$(date +'%Y%m%d-%H%M%S')"
mkdir -p ${LOG_DIR}
echo "Log Dir: ${LOG_DIR}"

MAX_WORKERS=$(expr 16 / ${NODES})

FIFO_NAME=$(mktemp -u)
mkfifo "$FIFO_NAME"
exec 3<>"$FIFO_NAME"
rm -f "$FIFO_NAME"

for ((i=0; i<MAX_WORKERS; i++)); do
    echo >&3
done


ALL_PARAMS=(
    "${EXP_NAME} ${MODEL_NAME} ${DATASET_NAME} 1024 8 ${NODES} ${ALLOCATION_MODE} 16384 128 1 0.001"
    #"${EXP_NAME} ${MODEL_NAME} ${DATASET_NAME} 1024 8 ${NODES} ${ALLOCATION_MODE} 16384 128 1 0.001"
    #"${EXP_NAME} ${MODEL_NAME} ${DATASET_NAME} 1024 8 ${NODES} ${ALLOCATION_MODE} 16384 128 1 0.001"
)

echo "Task Count: ${#ALL_PARAMS[@]}"

for ((i=0; i<${#ALL_PARAMS[@]}; i++)); do
    read -u3

    {
        echo "$(date +"%Y-%m-%d %H:%M.%S") Task $i started: ${ALL_PARAMS[$i]}"
        bash -c "bash ${SCRIPT_DIR}/train_code_small_on_ray.sh ${ALL_PARAMS[$i]} &> ${LOG_DIR}/${i}.log"
        echo "$(date +"%Y-%m-%d %H:%M.%S") Task $i completed with exit code: $?, ${ALL_PARAMS[$i]}"
        sleep 120
        echo >&3
    } &

    sleep 120
done

wait

exec 3>&-
echo "All tasks completed"

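The deleted launcher above throttles concurrency with a FIFO used as a counting semaphore: each task consumes a token with `read -u3` and returns it with `echo >&3`. For readers less familiar with that shell idiom, a rough Python equivalent of the same idea (script name and parameter string are placeholders, not part of the repository):

```python
import subprocess
from concurrent.futures import ThreadPoolExecutor

MAX_WORKERS = 2  # plays the role of the FIFO token count
ALL_PARAMS = ["<exp> <model> <dataset> 1024 8 16 <alloc> 16384 128 1 0.001"]  # placeholders

def run_task(i: int, params: str) -> int:
    # Each worker launches one training run and blocks until it exits.
    return subprocess.run(["bash", "train_code_small_on_ray.sh", *params.split()]).returncode

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
    exit_codes = list(pool.map(run_task, range(len(ALL_PARAMS)), ALL_PARAMS))
print("All tasks completed", exit_codes)
```
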
@@ -1,123 +0,0 @@
#!/bin/sh
MODEL_FAMILY=qwen2

EXP_NAME="$1"
MODEL_NAME="$2"
DATASET_NAME="$3"
TRAIN_BATCH_SIZE="$4"
GROUP_SIZE="$5"
NODES="$6"
ALLOCATION_MODE="$7"
MAX_NEW_TOKENS=$8
MAX_NUM_SEQS=$9
PPO_MBS=${10}
KL_CTL=${11}

MAX_TOKEN_PER_MB=$(expr 2048 + ${MAX_NEW_TOKENS} + 1024)
MAX_SEQ_LEN_TO_CAPTURE=$(expr 2048 + ${MAX_NEW_TOKENS})

BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"

# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_CODE_METADATA_PATH="/storage/datasets/codeparrot-apps-test.jsonl"

# Option 1: The experiment runs locally with subprocesses.
# MODE=local
# Option 2: The experiment runs in a Ray cluster
# MODE=ray
# Option 3: The experiment runs in a SLURM + pyxis cluster
# Using the slurm mode requires a cluster spec file
# and setting CLUSTER_SPEC_PATH to the path of it.
MODE=ray

# `experiment_name` and `trial_name` can be arbitrary.
# Logs and saved checkpoints will be indexed by them.
#EXP_NAME=ppo-zero--${MODEL_NAME}--${DATASET_NAME}
#EXP_NAME=ppo-zero-distill-1.5B-default
TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"

# We use the "heuristic" allocation mode here to automatically determine the parallelism strategy
# for each model function call, i.e., actor generation, critic inference, actor train, etc.
# The number of GPUs is `n_nodes` * `n_gpus_per_node` (not set explicitly here, defaults to 8).
# ReaL will make full use of these available GPUs to design allocations.
# This does not ensure the optimal throughput, but it is a good starting point.

# The `heuristic` allocation mode is not ensured to run with every model configuration.
# For example, if the vocabulary size is an odd number, the model parallelism may not work.
# In these cases, you can use the `ppo_manual.sh` to specify the parallelism strategy manually.

# The `ppo` subcommand specifies that this is a PPO experiment.
# The `save_freq_steps` is set to `null` to disable saving checkpoints.
# Enable it if you want to save checkpoints.
# The `ppo` option is used to control the generation and PPO algorithm hyperparameters.
# Note that the performance of PPO is sensitive to the pre-trained model and hyperparameters.
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_CODE_METADATA_PATH=${REAL_CODE_METADATA_PATH} \
FUNCTIONCALL_SERVICE_DOMAIN="" \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-code \
    mode=$MODE \
    experiment_name=$EXP_NAME \
    trial_name=$TRIAL_NAME \
    wandb.mode=disabled \
    exp_ctrl.total_train_epochs=1 \
    exp_ctrl.save_freq_epochs=1 \
    exp_ctrl.ckpt_freq_secs=600 \
    group_size=${GROUP_SIZE} \
    group_adv_norm=False \
    use_dense_reward=False \
    reward_delta=True \
    rw_type=sparse \
    check_xml_format=False \
    actor.type._class=$MODEL_FAMILY \
    actor.path=$BASE_MODEL_PATH \
    actor.vllm.hybrid_train=False \
    actor.vllm.enforce_eager=False \
    actor.vllm.max_seq_len_to_capture=${MAX_SEQ_LEN_TO_CAPTURE} \
    actor.vllm.max_num_seqs=${MAX_NUM_SEQS} \
    actor.vllm.gpu_memory_utilization=1 \
    actor.vllm.swap_space=64 \
    critic.type._class=$MODEL_FAMILY \
    critic.type.is_critic=True \
    critic.init_critic_from_actor=True \
    critic.path=$BASE_MODEL_PATH \
    ref.type._class=$MODEL_FAMILY \
    ref.path=$BASE_MODEL_PATH \
    rew.type._class=$MODEL_FAMILY \
    rew.type.is_critic=True \
    rew.init_critic_from_actor=True \
    rew.path=$BASE_MODEL_PATH \
    dataset.path=$DATA_PATH \
    dataset.max_prompt_len=2048 \
    dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \
    ppo.gen.max_new_tokens=${MAX_NEW_TOKENS} \
    ppo.gen.min_new_tokens=0 \
    ppo.disable_value=True \
    ppo.gen.top_p=1 ppo.gen.top_k=1000000 \
    ppo.ppo_n_minibatches=${PPO_MBS} \
    ppo.gen.temperature=0.6 \
    ppo.kl_ctl=${KL_CTL} \
    ppo.value_eps_clip=0.2 \
    ppo.reward_output_scaling=5 \
    ppo.reward_output_bias=0.0 \
    ppo.adv_norm=True ppo.value_norm=True \
    mask_too_long=False \
    ppo.discount=1.0 \
    actor.optimizer.lr=1e-6 \
    critic.optimizer.lr=5e-6 \
    actor.optimizer.lr_scheduler_type=constant \
    actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    cache_clear_freq=1 \
    n_nodes=${NODES} \
    allocation_mode="'${ALLOCATION_MODE}'" n_gpus_per_node=8 \
    recover_mode=auto \
    recover_retries=10 \
    torch_cache_mysophobia=True

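As a quick sanity check of the token-budget arithmetic in the deleted script (a 2048-token prompt budget plus the generation length, with an extra 1024 tokens of headroom for the micro-batch cap):

```python
# Worked example mirroring the expr lines, for MAX_NEW_TOKENS=16384 as used by the launcher.
MAX_NEW_TOKENS = 16384
MAX_TOKEN_PER_MB = 2048 + MAX_NEW_TOKENS + 1024   # 19456
MAX_SEQ_LEN_TO_CAPTURE = 2048 + MAX_NEW_TOKENS    # 18432
print(MAX_TOKEN_PER_MB, MAX_SEQ_LEN_TO_CAPTURE)
```
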
@@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"

# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"

# Option 1: The experiment runs locally with subprocesses.
# MODE=local

@@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
    mode=$MODE \

@@ -67,8 +65,6 @@ python3 -m realhf.apps.quickstart ppo-math \
    exp_ctrl.ckpt_freq_secs=600 \
    group_size=${GROUP_SIZE} \
    group_adv_norm=False \
    use_dense_reward=False \
    reward_delta=True \
    rw_type=sparse \
    check_xml_format=False \
    actor.type._class=$MODEL_FAMILY \

@@ -85,10 +81,6 @@ python3 -m realhf.apps.quickstart ppo-math \
    critic.path=$BASE_MODEL_PATH \
    ref.type._class=$MODEL_FAMILY \
    ref.path=$BASE_MODEL_PATH \
    rew.type._class=$MODEL_FAMILY \
    rew.type.is_critic=True \
    rew.init_critic_from_actor=True \
    rew.path=$BASE_MODEL_PATH \
    dataset.path=$DATA_PATH \
    dataset.max_prompt_len=2048 \
    dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \

@@ -103,6 +95,7 @@ python3 -m realhf.apps.quickstart ppo-math \
    ppo.reward_output_scaling=5 \
    ppo.reward_output_bias=0.0 \
    ppo.adv_norm=True ppo.value_norm=True \
    ppo.fuse_rew_ref=False \
    mask_too_long=False \
    ppo.discount=1.0 \
    actor.optimizer.lr=1e-6 \

@@ -110,7 +103,7 @@ python3 -m realhf.apps.quickstart ppo-math \
    actor.optimizer.lr_scheduler_type=constant \
    actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \

@@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"

# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"

# Option 1: The experiment runs locally with subprocesses.
# MODE=local

@@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
    mode=$MODE \

@@ -67,8 +65,6 @@ python3 -m realhf.apps.quickstart ppo-math \
    exp_ctrl.ckpt_freq_secs=600 \
    group_size=${GROUP_SIZE} \
    group_adv_norm=False \
    use_dense_reward=False \
    reward_delta=True \
    rw_type=sparse \
    check_xml_format=False \
    actor.type._class=$MODEL_FAMILY \

@@ -85,10 +81,6 @@ python3 -m realhf.apps.quickstart ppo-math \
    critic.path=$BASE_MODEL_PATH \
    ref.type._class=$MODEL_FAMILY \
    ref.path=$BASE_MODEL_PATH \
    rew.type._class=$MODEL_FAMILY \
    rew.type.is_critic=True \
    rew.init_critic_from_actor=True \
    rew.path=$BASE_MODEL_PATH \
    dataset.path=$DATA_PATH \
    dataset.max_prompt_len=2048 \
    dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \

@@ -103,6 +95,7 @@ python3 -m realhf.apps.quickstart ppo-math \
    ppo.reward_output_scaling=5 \
    ppo.reward_output_bias=0.0 \
    ppo.adv_norm=True ppo.value_norm=True \
    ppo.fuse_rew_ref=False \
    mask_too_long=False \
    ppo.discount=1.0 \
    actor.optimizer.lr=1e-6 \

@@ -110,7 +103,7 @@ python3 -m realhf.apps.quickstart ppo-math \
    actor.optimizer.lr_scheduler_type=constant \
    actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \

@@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"

# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"

# Option 1: The experiment runs locally with subprocesses.
# MODE=local

@@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
    mode=$MODE \

@@ -67,8 +65,6 @@ python3 -m realhf.apps.quickstart ppo-math \
    exp_ctrl.ckpt_freq_secs=600 \
    group_size=${GROUP_SIZE} \
    group_adv_norm=False \
    use_dense_reward=False \
    reward_delta=True \
    rw_type=sparse \
    check_xml_format=True \
    actor.type._class=$MODEL_FAMILY \

@@ -85,10 +81,6 @@ python3 -m realhf.apps.quickstart ppo-math \
    critic.path=$BASE_MODEL_PATH \
    ref.type._class=$MODEL_FAMILY \
    ref.path=$BASE_MODEL_PATH \
    rew.type._class=$MODEL_FAMILY \
    rew.type.is_critic=True \
    rew.init_critic_from_actor=True \
    rew.path=$BASE_MODEL_PATH \
    dataset.path=$DATA_PATH \
    dataset.max_prompt_len=2048 \
    dataset.train_bs_n_seqs=${TRAIN_BATCH_SIZE} \

@@ -103,6 +95,7 @@ python3 -m realhf.apps.quickstart ppo-math \
    ppo.reward_output_scaling=0.5 \
    ppo.reward_output_bias=-1.0 \
    ppo.adv_norm=True ppo.value_norm=True \
    ppo.fuse_rew_ref=False \
    mask_too_long=False \
    ppo.discount=1.0 \
    actor.optimizer.lr=1e-6 \

@@ -110,7 +103,7 @@ python3 -m realhf.apps.quickstart ppo-math \
    actor.optimizer.lr_scheduler_type=constant \
    actor_gen.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    ref_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    rew_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_inf.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    actor_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \
    critic_train.mb_spec.max_tokens_per_mb=${MAX_TOKEN_PER_MB} \

@@ -5,38 +5,25 @@ import sys
import time
import traceback
from io import StringIO

from function.testing_util import run_test
from typing import Dict, List

from functioncall.base import logging
from functioncall.code.function.testing_util import run_test

logger = logging.getLogger("Functioncall")


def try_load_json(data):
    """
    Attempts to load the given data as JSON. If successful, returns the parsed JSON.
    Otherwise, returns None and logs an error.
    """
    try:
        loaded_data = json.loads(data)
        return loaded_data
    except json.JSONDecodeError as e:
        logger.info(f"Failed to load JSON: {e}")
        return data


def capture_stdout(code):
    original_stdout = sys.stdout
    fake_stdout = StringIO()

    try:
        sys.stdout = fake_stdout  # redirect stdout
        exec(code, {"__builtins__": __builtins__})  # execute in an isolated environment
        sys.stdout = fake_stdout
        exec(code, {"__builtins__": __builtins__})
    except Exception as e:
        return f"error: {str(e)}, traceback: {traceback.format_exc()}"
    finally:
        sys.stdout = original_stdout  # restore the original stdout
        sys.stdout = original_stdout
    return fake_stdout.getvalue()



@@ -46,12 +33,17 @@ def _temp_run(problem, generation, debug, result):
    try:
        if debug:
            logger.debug(f"Running test for problem: {problem}")
        result.append(run_test(sample=problem, test=generation, debug=debug))
        r = run_test(sample=problem, test=generation, debug=debug)
        result.append(r)
        if debug:
            logger.debug(f"Test completed with result: {result}")
    except Exception as e:
        if debug:
            logger.error(f"Error in _temp_run: {e}, problem:{problem}", exe_info=True)

        logger.warning(
            f"Error in _temp_run: {e}\n"
            f"traceback: {''.join(traceback.format_exception(*sys.exc_info()))}\n"
            f"problem:{problem}"
        )

    execution_time = time.time() - start_time
    logger.info(

@@ -64,8 +56,9 @@ def check_correctness(problem, generation, timeout, debug=False):
    The global timeout is to catch some extreme/rare cases not handled by the timeouts
    inside `run_test`"""
    if debug:
        # FIXME: the variables "problem", "generation", and "debug" are not defined inside the exec'd code
        result = capture_stdout(
            "from function.testing_util import run_test\n"
            "from functioncall.code.function.testing_util import run_test\n"
            + "run_test(sample=problem, test=generation, debug=debug)"
        )
        return result[0], result[1]

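Regarding the FIXME above: the string passed to `capture_stdout` references `problem`, `generation`, and `debug`, but `exec` runs it with only `__builtins__` in scope, so those names are undefined there. One possible fix, shown as a sketch rather than what the PR does, is to pass them in explicitly through the exec globals:

```python
# Hypothetical variant of the debug path that supplies the needed variables to exec().
import sys
import traceback
from io import StringIO

def capture_stdout_with_env(code: str, env: dict) -> str:
    original_stdout, fake_stdout = sys.stdout, StringIO()
    try:
        sys.stdout = fake_stdout
        exec(code, {"__builtins__": __builtins__, **env})  # names from `env` are now visible
    except Exception:
        return f"error traceback: {traceback.format_exc()}"
    finally:
        sys.stdout = original_stdout
    return fake_stdout.getvalue()

# usage sketch:
# capture_stdout_with_env(code_str, {"problem": problem, "generation": generation, "debug": debug})
```
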
@@ -86,7 +79,7 @@ def check_correctness(problem, generation, timeout, debug=False):
        # Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
        # so we use 21=the average number of tests for a sample in the test split instead
        avg_number_tests = 21
        result = [[-1] * avg_number_tests]
        result = [[-1 for _ in range(avg_number_tests)], {}]
        if debug:
            logger.debug(f"Global timeout occurred, returning default result.")
    if debug:

@@ -99,96 +92,56 @@ def check_correctness(problem, generation, timeout, debug=False):
    return result[0]


def load_problems(path):
    problem_map = {}
    for idx, line in enumerate(open(path, "rb")):
        if line is not None:
            line = line.strip().decode("utf-8")
        row = json.loads(line)

        query_id = str(row["id"])
        problem_map[query_id] = {
            "problem_id": query_id,
            # "question": row["question"],
            "solutions": row["solutions"],
            "input_output": row["input_output"],
            # "difficulty": level,
            # "url": row["url"],
            # "starter_code": row["starter_code"],
        }

    return problem_map


global_problems = load_problems(
    os.getenv(
        "REAL_CODE_METADATA_PATH",
        "/storage/datasets/codeparrot-apps-test.jsonl",
    )
)


def code_verify(generateds, query_ids, debug=False):
    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),
    )
def code_verify(id2info, generateds, query_ids, debug=False):
    assert len(generateds) == len(query_ids)
    problems = [id2info[qid] for qid in query_ids]

    result = []
    global global_problems

    for idx, query_id in enumerate(query_ids):
        if query_id not in global_problems:
            continue

        problem = global_problems[query_id]
        test_code = generateds[idx]
    for query_id, generated, problem in zip(query_ids, generateds, problems):
        logger.debug(f"run_batch_code, query_id: {query_id}")

        try:
            curr_res, metadata = check_correctness(
                problem=problem, generation=test_code, timeout=6000, debug=debug
                problem=problem, generation=generated, timeout=6000, debug=debug
            )

            if any(x != True for x in curr_res):
                logger.debug(
                    f'id:{problem["problem_id"]}, Results were not all True: {metadata}'
                )
                result.append(f"{query_id} failed")
                logger.debug(f"id:{query_id}, Results were not all True: {metadata}")
                result.append(0)
            else:
                # print(f"id:{problem['problem_id']}, result : {curr_res}")
                result.append(f"{query_id} success")
                result.append(1)

        except Exception as e:
            logger.error(f"test framework exception = {repr(e)}{e}\n", exe_info=True)
            result.append(f"{query_id} failed")
            break
        finally:
            # print(f"id:{problem['problem_id']}, result : {curr_res}")
            # assert isinstance(curr_res, list)
            pass
            exc_info = sys.exc_info()
            logger.error(
                f"test framework exception = {repr(e)}{e}\n{traceback.format_exception(*exc_info)}"
            )
            result.append(0)

    return result


if __name__ == "__main__":
    path = "/storage/openpsi/data/code/apps/test.jsonl"
    data = []
    with open(path, "r") as f:
        code_data = [json.loads(l) for l in f.readlines()]

    def create_test_params(index_list):
        global global_problems
        codes, query_ids = [], []
    id2info = {}
    solutions = []
    query_ids = []
    for i in range(10):
        problem = code_data[i]
        problem["problem_id"] = problem["id"]
        id2info[problem["problem_id"]] = problem
        solutions.append(json.loads(problem["solutions"])[0])
        query_ids.append(problem["id"])

        for index in index_list:
            if str(index) not in global_problems:
                continue

            problem = global_problems[str(index)]
            if "solutions" not in problem or not problem["solutions"]:
                continue

            codes.append(try_load_json(problem["solutions"])[0])
            query_ids.append(str(index))
        return codes, query_ids

    codes, query_ids = create_test_params(list(range(805, 806)))
    result = code_verify(codes, query_ids, True)
    result = code_verify(
        id2info,
        solutions,
        query_ids,
        debug=False,
    )
    print(result)

@@ -1,5 +1,6 @@
import json
import os
import random
from collections import defaultdict

from functioncall.base import logging

@@ -8,29 +9,14 @@ from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")


def try_load_json(data):
    """
    Attempts to load the given data as JSON. If successful, returns the parsed JSON.
    Otherwise, returns None and logs an error.
    """
    try:
        loaded_data = json.loads(data)
        return loaded_data
    except json.JSONDecodeError as e:
        # print(f"Failed to load JSON: {e}")
        return data


def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=None):
def load_problems_with_testcase_batch(
    id2info, query_ids, debug=False, test_case_batch_size=None
):
    problem_map = defaultdict(list)
    for idx, line in enumerate(open(path, "rb")):
        if line is None:
            continue

    for idx, query_id in enumerate(query_ids):
        problem = id2info[query_id]
        # parse one problem
        row = json.loads(line.strip().decode("utf-8"))
        query_id = str(row.get("id", row.get("query_id")))
        input_output = json.loads(row["input_output"]) if "input_output" in row else {}
        input_output = json.loads(problem["input_output"])
        inputs = input_output.get("inputs", [])
        outputs = input_output.get("outputs", [])
        assert len(inputs) == len(

@@ -57,17 +43,14 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=None
                "batche_index": batch_idx,
            }
            if debug:
                sub_problem["solutions"] = row.get("solutions", [])
                sub_problem["solutions"] = problem.get("solutions", [])
            problem_map[query_id].append(sub_problem)

    return problem_map


global_problems = None


def code_verify(
    generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
    id2info, generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
):
    assert len(generateds) == len(query_ids), (
        len(generateds),

@@ -75,24 +58,13 @@ def code_verify(
    )
    payload_list = []

    global global_problems
    if global_problems is None:
        global_problems = load_problems_with_testcase_batch(
            os.getenv(
                "REAL_CODE_METADATA_PATH",
                "/storage/datasets/codeparrot-apps-test.jsonl",
            ),
            debug=True,
            test_case_batch_size=20,
        )
    global_problems = load_problems_with_testcase_batch(
        id2info,
        query_ids,
        debug=True,
        test_case_batch_size=20,
    )
    for idx, query_id in enumerate(query_ids):
        if query_id not in global_problems:
            payload_list.append(None)
            logger.warning(
                f"Invalid query id : {query_id}, type: {type(query_id)}, should be in problem dataset!"
            )
            continue

        problems = global_problems[query_id]
        for problem in problems:
            payload_list.append(

@@ -129,34 +101,28 @@ def code_verify(


if __name__ == "__main__":
    path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
    data = []
    with open(path, "r") as f:
        code_data = [json.loads(l) for l in f.readlines()]

    id2info = {}

    def create_test_params(count=10):
        global global_problems
        if global_problems is None:
            global_problems = load_problems_with_testcase_batch(
                os.getenv(
                    "REAL_CODE_METADATA_PATH",
                    "/storage/datasets/codeparrot-apps-test.jsonl",
                ),
                debug=True,
                test_case_batch_size=20,
            )
        codes, query_ids = [], []
        idx = 0
        for query_id, problems in global_problems.items():
            if idx >= count:
                break

            problem = problems[0]
            if "solutions" not in problem or not problem["solutions"]:
        global id2info
        query_ids = []
        generateds = []
        cnt = 0
        while cnt < count:
            d = random.choice(code_data)
            if not d["solutions"]:
                continue
            id2info[d["id"]] = d
            query_ids.append(d["id"])
            generateds.append(d["solutions"][0])
            cnt += 1
        return generateds, query_ids

            codes.append(try_load_json(problem["solutions"])[0])
            query_ids.append(query_id)
            idx += 1

        return codes, query_ids

    codes, query_ids = create_test_params(100)
    result = code_verify(codes, query_ids, True)
    generateds, query_ids = create_test_params(100)
    result = code_verify(id2info, generateds, query_ids, True)
    print(result)

@@ -7,7 +7,7 @@ from grader import math_equal

def process_results(answer, solution):
    extracted_answer = extract_answer(answer, "math", use_last_number=False)
    extracted_solution = solution
    extracted_solution = extract_answer(solution, "math", use_last_number=True)

    if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
        retval = 0

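The change matters because stored solutions are typically full worked solutions rather than bare answers, so the reference also has to be reduced to its final boxed value before comparison. A simplified illustration follows; the regex stands in for the real `extract_answer`, which is not reproduced here.

```python
import re

def extract_boxed(text: str):
    """Toy stand-in for extract_answer: return the last \\boxed{...} payload (no nested braces)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

solution = "Adding the parts gives 40 + 2, so the total is \\boxed{42}."
answer = "The answer is \\boxed{42}"
print(extract_boxed(answer), extract_boxed(solution))  # 42 42, now comparable by math_equal
```
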
@@ -9,28 +9,9 @@ from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")


def loadJson(dataDir):
    with open(dataDir, "r") as f:
        if dataDir.endswith(".jsonl"):
            samples = [json.loads(line) for line in f.readlines()]
        else:
            samples = json.load(f)

    return samples


id2info = None


def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000) -> List:
    global id2info
    if id2info is None:
        id2info = loadJson(
            os.getenv(
                "REAL_MATH_MEATADATA_PATH",
                "/storage/datasets/id2info.json",
            )
        )
def math_verify(
    id2info, generateds: List, query_ids: List, batch_size=10, timeout=1000
) -> List:
    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),

@@ -43,7 +24,7 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000)
    for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
        base_query_id = query_id.split("@idx:")[0]
        info = id2info[base_query_id]
        for cur_solution in info["answers"]:
        for cur_solution in info["solutions"]:
            parameters.append((generated, cur_solution, idx))
            query_indices.append(idx)


@@ -95,27 +76,19 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000)


if __name__ == "__main__":
    # sample = {
    #     "prompt": "",
    #     "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
    #     "answer": "\\boxed{-\\frac{1}{30}}",
    # }

    if id2info is None:
        id2info = loadJson(
            os.getenv(
                "REAL_MATH_MEATADATA_PATH",
                "/storage/datasets/id2info.json",
            )
        )

    answers = []
    query_ids = []

    for id, value in id2info.items():
        answers.append(value["solutions"][0])
        query_ids.append(id)
    sample = {
        "answers": ["-\\frac{2}{3}"],
        "solutions": [
            "1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
            "\\boxed{-\\frac{2}{3}}",
        ],
    }
    id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}

    start_time = time.time()
    result = math_verify(answers[:200], query_ids[:200])
    result = math_verify(
        id2info,
        sample["answers"] * 100,
        ["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
    )
    print(result)

@@ -43,6 +43,8 @@ from realhf.base.cluster import spec as cluster_spec

logger = logging.getLogger("api.data")

RL_TASKS = ["math", "code", "rlhf"]


def load_hf_tokenizer(
    model_name_or_path: str,

@@ -529,6 +531,7 @@ class SequenceSample:
            "rewards",
            "greedy_rewards",
            "base_scores",
            "task_ids",
        ]:
            return [[1] for _ in seqlens]
        elif key in [

@@ -720,9 +723,6 @@ def load_shuffle_split_dataset(
        if dataset_path.endswith(".jsonl"):
            with open(dataset_path, "r") as f:
                data = [json.loads(ff) for ff in f]
        elif dataset_path.endswith(".json"):
            with open(dataset_path, "r") as f:
                data = json.load(f)
        else:
            raise NotImplementedError(f"Unknown dataset extension: {dataset_path}")
    else:

@@ -160,8 +160,6 @@ def main_start(args, recover_count: int = 0):
        CLUSTER_SPEC_PATH=cluster_spec_path,
        REAL_RECOVER_RUN="1" if is_recover_run else "0",
        REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
        REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
        REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
        FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
        REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
    )

@@ -513,6 +513,10 @@ def tp_and_pp_group():
    return grid().get_model_parallel_group()


def tp_and_pp_cpu_group():
    return grid().ds_model_proc_group_gloo


def tp_and_pp_rank():
    """Used as the rank in the world group of vLLM."""
    return grid().get_model_parallel_rank()

@@ -206,9 +206,9 @@ class LocalMultiProcessTest:


def init_global_constants(
    num_dp,
    num_mp,
    num_pp,
    num_dp=1,
    num_mp=1,
    num_pp=1,
    topo=None,
    model_name=None,
    msid2mwid=None,

@@ -217,7 +217,12 @@ def init_global_constants(
    gradient_accumulation_fusion=False,
    max_prompt_len=None,
    is_train: bool = True,
    expr_name=None,
    trial_name=None,
):
    expr_name = expr_name if expr_name is not None else _DEFAULT_EXPR_NAME
    trial_name = trial_name if trial_name is not None else _DEFAULT_TRIAL_NAME
    constants.set_experiment_trial_names(expr_name, trial_name)
    model_name = model_name if model_name is not None else MODEL_NAME

    if topo is None:

@@ -21,7 +21,7 @@ def check_is_realhf_native_impl(_cls):
def check_is_realhf_native_model_interface(name):
    # NOTE: we should not import interfaces here,
    # such that we can avoid CUDA initialization.
    return name in ["ppo_actor", "ppo_critic", "sft"]
    return name in ["ppo_actor", "ppo_critic", "sft", "rw-math-code", "fused-threading"]


def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):

@ -1,562 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
import copy
|
||||
import dataclasses
|
||||
import math
|
||||
import os
|
||||
import pprint
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
ModelInterfaceAbstraction,
|
||||
ModelInterfaceType,
|
||||
)
|
||||
from realhf.api.core.dfg import MFCDef, ParamReallocHook
|
||||
from realhf.api.core.model_api import GenerationHyperparameters
|
||||
from realhf.api.core.system_api import ExperimentConfig
|
||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import DeviceMesh, MFCConfig, RPCAllocation
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
|
||||
|
||||
logger = logging.getLogger("PPO Code exp", "colored")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PPOHyperparameters:
|
||||
"""Configuration of PPO hyperparameters.
|
||||
|
||||
:param gen: Generation hyperparameters.
|
||||
:type gen: GenerationHyperparameters
|
||||
:param ppo_n_minibatches: Number of minibatches in each PPO update.
|
||||
:type ppo_n_minibatches: int
|
||||
:param kl_ctl: Coefficient of KL divergence rewards.
|
||||
:type kl_ctl: float
|
||||
:param discount: Discount factor.
|
||||
:type discount: float
|
||||
:param gae_lambda: Lambda factor in GAE.
|
||||
:type gae_lambda: float
|
||||
:param eps_clip: PPO actor probability ratio clipping factor.
|
||||
:type eps_clip: float
|
||||
:param value_eps_clip: PPO value clipping factor.
|
||||
:type value_eps_clip: float
|
||||
:param max_reward_clip: Maximum reward value.
|
||||
:type max_reward_clip: float
|
||||
:param reward_output_scaling: Scaling factor of the reward model output.
|
||||
:type reward_output_scaling: float
|
||||
:param reward_output_bias: Bias of the reward model output.
|
||||
The number output by the reward model will be
|
||||
CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip).
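For example (worked from the formula above), with scaling 0.5, bias 1.0, and
max_reward_clip 5.0, a raw reward-model output of 13 becomes
CLIP((13 - 1.0) * 0.5, -5, 5) = 5.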
|
||||
:type reward_output_bias: float
|
||||
:param early_stop_imp_ratio: PPO update will be early stopped if importance ratio
|
||||
exceeds this maximum value.
|
||||
:type early_stop_imp_ratio: float
|
||||
:param use_adaptive_kl_ctl: Whether to use adaptive KL divergence coefficient.
|
||||
:type use_adaptive_kl_ctl: bool
|
||||
:param adv_norm: Whether to use advantage normalization.
|
||||
:type adv_norm: bool
|
||||
:param value_norm: Whether to denormalize values and normalize return predictions.
|
||||
:type value_norm: bool
|
||||
:param value_norm_type: Type of value normalization.
|
||||
Either exponential moving average ("exp") or moving average ("ma").
|
||||
:type value_norm_type: str
|
||||
:param value_norm_beta: Exponential decay factor
|
||||
in exponential moving average.
|
||||
:type value_norm_beta: float
|
||||
:param value_norm_eps: Epsilon factor in the
|
||||
denominator of exponential moving average.
|
||||
:type value_norm_eps: float
|
||||
:param disable_value: A shortcut option to disable the critic model.
|
||||
:type disable_value: bool
|
||||
"""
|
||||
|
||||
gen: GenerationHyperparameters = dataclasses.field(
|
||||
default_factory=GenerationHyperparameters
|
||||
)
|
||||
ppo_n_minibatches: int = 4
|
||||
kl_ctl: float = 0.1
|
||||
discount: float = 1.0
|
||||
gae_lambda: float = 1.0
|
||||
eps_clip: float = 0.2
|
||||
value_eps_clip: float = 0.2
|
||||
disable_value: bool = False
|
||||
max_reward_clip: float = 20.0
|
||||
reward_output_scaling: float = 1.0
|
||||
reward_output_bias: float = 0.0
|
||||
early_stop_imp_ratio: float = 5.0
|
||||
use_adaptive_kl_ctl: bool = False
|
||||
adv_norm: bool = True
|
||||
value_norm: bool = True
|
||||
value_norm_type: str = dataclasses.field(
|
||||
metadata={"choices": ["exp", "ma"]}, default="exp"
|
||||
)
|
||||
value_norm_beta: float = 0.99995
|
||||
value_norm_eps: float = 1e-5
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PPOCODEConfig(CommonExperimentConfig):
|
||||
"""PPO experiment configuration.
|
||||
|
||||
It is a subclass of :class:`CommonExperimentConfig`,
|
||||
so all CLI options in the base class are available.
|
||||
|
||||
We don't implement runtime evaluation for PPO.
|
||||
|
||||
We identify that the RLHF process is composed of four
|
||||
distinct models with independent parameters and six
|
||||
*model function calls* upon these models.
|
||||
|
||||
The four models are\:
|
||||
|
||||
- Actor\: The primary LLM that generates text.
|
||||
- Critic\: The value function that estimates the value of a state.
|
||||
- Ref\: The reference LLM that provides KL regularization.
|
||||
- Rew\: The reward model that provides reward signals.
|
||||
|
||||
The six model function calls and their dependencies are\:
|
||||
|
||||
- Rollout\: Generate text from the actor model.
|
||||
- InfReward\: Infer rewards from the reward model given generated text.
|
||||
- InfRef\: Infer log probabilities from the reference model given generated text.
|
||||
- InfValues\: Infer values from the critic model given generated text.
|
||||
- TrainActor\: Train the actor model given generated text, rewards, values, and reference log probabilities.
|
||||
- TrainCritic\: Train the critic model given generated text, rewards, values, and reference log probabilities.
|
||||
|
||||
This class resolves these dependencies under the hood.
|
||||
What the users should specify are the runtime configurations
|
||||
of models and allocations of *each model function call*.
|
||||
|
||||
:param actor: Runtime configuration of the primary LLM.
|
||||
:type actor: ModelTrainEvalConfig
|
||||
:param critic: Runtime configuration of the critic model of PPO.
|
||||
:type critic: ModelTrainEvalConfig
|
||||
:param ref: Runtime configuration of the reference LLM.
|
||||
:type ref: ModelTrainEvalConfig
|
||||
:param rew: Runtime configuration of the reward LLM.
|
||||
:type rew: ModelTrainEvalConfig
|
||||
:param actor_train: :class:`MFCConfig` for TrainActor.
|
||||
:type actor_train: MFCConfig
|
||||
:param critic_train: :class:`MFCConfig` for TrainCritic.
|
||||
:type critic_train: MFCConfig
|
||||
:param actor_gen: :class:`MFCConfig` for Rollout.
|
||||
:type actor_gen: MFCConfig
|
||||
:param critic_inf: :class:`MFCConfig` for InfValues.
|
||||
:type critic_inf: MFCConfig
|
||||
:param rew_inf: :class:`MFCConfig` for InfReward.
|
||||
:type rew_inf: MFCConfig
|
||||
:param ref_inf: :class:`MFCConfig` for InfRef.
|
||||
:type ref_inf: MFCConfig
|
||||
:param dataset: Dataset configuration.
|
||||
:type dataset: PromptOnlyDatasetConfig
|
||||
:param ppo: Configuration for the PPO algorithm.
|
||||
:type ppo: PPOHyperparameters
|
||||
:param group_size: The number of answers retained for each prompt.
|
||||
:type group_size: int
|
||||
:param generation_size: The number of answers sampled for each prompt.
|
||||
Among them, only `group_size` samples are retained according to
|
||||
the reward score, aka best-of-n sampling. If None, use `group_size`.
|
||||
:type generation_size: Optional[int]
|
||||
:param mask_no_eos_with_zero: Whether to mask out the reward if an
|
||||
answer is truncated due to exceeding the length limit.
|
||||
:type mask_no_eos_with_zero: bool
|
||||
:param mask_too_long: Whether to mask out the PPO loss if an
|
||||
answer is truncated due to exceeding the length limit.
|
||||
:type mask_too_long: bool
|
||||
:param check_verifier_status: If True, raise an error
|
||||
when the reward is all-zero. This usually indicates a bug
|
||||
in the verifier.
|
||||
:type check_verifier_status: bool
|
||||
:param group_adv_norm: Whether to use grouped advantage
|
||||
normalization in GRPO.
|
||||
:type group_adv_norm: bool
|
||||
"""
|
||||
|
||||
actor: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
critic: ModelTrainEvalConfig = dataclasses.field(
|
||||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
|
||||
# for manual allocation only
|
||||
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
|
||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||
default_factory=PromptOnlyDatasetConfig
|
||||
)
|
||||
|
||||
ppo: PPOHyperparameters = dataclasses.field(default_factory=PPOHyperparameters)
|
||||
|
||||
group_size: int = 1
|
||||
generation_size: Optional[int] = None
|
||||
mask_no_eos_with_zero: bool = False
|
||||
rm_output_scaling: Optional[float] = None
|
||||
ref_ema_eta: Optional[float] = None
|
||||
group_adv_norm: bool = False
|
||||
mask_too_long: bool = False
|
||||
rw_type: Optional[str] = "sparse"
|
||||
task: str = "code"
|
||||
check_xml_format: bool = False
|
||||
use_dense_reward: bool = False
|
||||
reward_delta: bool = True
|
||||
|
||||
check_verifier_status: bool = False
|
||||
|
||||
dataset_filter_threshold: float = 100.0
|
||||
dataset_max_filter_percentage: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
self.ppo_kwargs = dict(
|
||||
n_minibatches=self.ppo.ppo_n_minibatches,
|
||||
kl_ctl=self.ppo.kl_ctl,
|
||||
discount=self.ppo.discount,
|
||||
gae_lambda=self.ppo.gae_lambda,
|
||||
eps_clip=self.ppo.eps_clip,
|
||||
value_eps_clip=self.ppo.value_eps_clip,
|
||||
max_reward_clip=self.ppo.max_reward_clip,
|
||||
adaptive_kl_ctl=self.ppo.use_adaptive_kl_ctl,
|
||||
value_norm=self.ppo.value_norm,
|
||||
value_norm_type=self.ppo.value_norm_type,
|
||||
value_norm_beta=self.ppo.value_norm_beta,
|
||||
value_norm_eps=self.ppo.value_norm_eps,
|
||||
disable_value=self.ppo.disable_value,
|
||||
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
||||
)
|
||||
|
||||
if self.rm_output_scaling is None:
|
||||
self.rm_output_scaling = self.ppo.reward_output_scaling
|
||||
|
||||
@property
|
||||
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
||||
# role to config
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor": self.actor,
|
||||
# "critic": self.critic,
|
||||
"ref": self.ref,
|
||||
"reward": self.rew,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor": self.actor,
|
||||
"critic": self.critic,
|
||||
"ref": self.ref,
|
||||
"reward": self.rew,
|
||||
}
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
if (
|
||||
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
|
||||
> self.actor.vllm.max_seq_len_to_capture
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
|
||||
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||
)
|
||||
if not os.path.exists(os.getenv("REAL_CODE_METADATA_PATH")):
|
||||
raise RuntimeError(
|
||||
"Dataset json path REAL_CODE_METADATA_PATH does not exist."
|
||||
)
|
||||
|
||||
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
||||
if not (domain.startswith("http://") and ":" in domain):
|
||||
raise RuntimeError(
|
||||
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
|
||||
)
|
||||
|
||||
# interfaces
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
args={
|
||||
**copy.deepcopy(self.ppo_kwargs),
|
||||
# NOTE: to_container converts the object to a dict
|
||||
# It is used for unifying the profiling API, which requires to
|
||||
# pass external interface configurations in the launch command.
|
||||
# Customized dataclass objects will not work in that case.
|
||||
"generation_config": (
|
||||
OmegaConf.to_container(self.ppo.gen, resolve=True)
|
||||
if isinstance(self.ppo.gen, (OmegaConf, DictConfig))
|
||||
else dataclasses.asdict(self.ppo.gen)
|
||||
),
|
||||
"early_stop_imp_ratio": self.ppo.early_stop_imp_ratio,
|
||||
"adv_norm": self.ppo.adv_norm,
|
||||
"group_size": self.group_size,
|
||||
"generation_size": self.generation_size,
|
||||
"group_adv_norm": self.group_adv_norm,
|
||||
"mask_too_long": self.mask_too_long,
|
||||
"use_dense_reward": self.use_dense_reward,
|
||||
"reward_delta": self.reward_delta,
|
||||
},
|
||||
)
|
||||
ref_interface = copy.deepcopy(actor_interface)
|
||||
ref_interface.args["enable_save"] = False
|
||||
|
||||
critic_interface = ModelInterfaceAbstraction(
|
||||
"ppo_critic",
|
||||
args={
|
||||
**copy.deepcopy(self.ppo_kwargs),
|
||||
"group_size": self.group_size,
|
||||
"mask_too_long": self.mask_too_long,
|
||||
"use_dense_reward": self.use_dense_reward,
|
||||
"reward_delta": self.reward_delta,
|
||||
},
|
||||
)
|
||||
critic_interface.args.pop("eps_clip")
|
||||
rw_interface = ModelInterfaceAbstraction(
|
||||
"rw_code",
|
||||
args=dict(
|
||||
rw_type=self.rw_type,
|
||||
task=self.task,
|
||||
check_xml_format=self.check_xml_format,
|
||||
tokenizer_path=self.actor.path,
|
||||
enable_save=False,
|
||||
output_scaling=self.ppo.reward_output_scaling,
|
||||
rm_output_scaling=self.rm_output_scaling,
|
||||
output_bias=self.ppo.reward_output_bias,
|
||||
group_size=self.group_size,
|
||||
check_verifier_status=self.check_verifier_status,
|
||||
max_sync_length=self.ppo.gen.max_new_tokens
|
||||
+ self.dataset.max_prompt_len
|
||||
+ 128,
|
||||
),
|
||||
)
|
||||
rollout = MFCDef(
|
||||
name="actor_gen",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_gen.mb_spec,
|
||||
interface_type=ModelInterfaceType.GENERATE,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_prompts"],
|
||||
output_keys=[
|
||||
"seq_no_eos_mask",
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"prompt_mask",
|
||||
],
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
inf_reward = MFCDef(
|
||||
name="rew_inf",
|
||||
model_name="reward",
|
||||
mb_spec=self.rew_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=rw_interface,
|
||||
model_type=self.rew.type,
|
||||
model_path=self.rew.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=["packed_input_ids", "packed_prompts"],
|
||||
output_keys=["rewards", "dense_rewards"],
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
inf_ref_inputs = ["packed_input_ids"]
|
||||
inf_ref_logits = MFCDef(
|
||||
name="ref_inf",
|
||||
model_name="ref",
|
||||
mb_spec=self.ref_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
model_type=self.ref.type,
|
||||
model_path=self.ref.path,
|
||||
interface_impl=ref_interface,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=inf_ref_inputs,
|
||||
output_keys=["packed_ref_logprobs"],
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
inf_values = MFCDef(
|
||||
name="critic_inf",
|
||||
model_name="critic",
|
||||
mb_spec=self.critic_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=critic_interface,
|
||||
model_type=self.critic.type,
|
||||
model_path=self.critic.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=["packed_input_ids", "seq_no_eos_mask"],
|
||||
output_keys=["values"],
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
train_actor_inputs = [
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"packed_ref_logprobs",
|
||||
"rewards",
|
||||
"dense_rewards",
|
||||
"values",
|
||||
"prompt_mask",
|
||||
"seq_no_eos_mask",
|
||||
]
|
||||
if self.ppo.disable_value:
|
||||
train_actor_inputs.remove("values")
|
||||
train_actor = MFCDef(
|
||||
name="actor_train",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_train.mb_spec,
|
||||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=train_actor_inputs,
|
||||
log_return_value=True,
|
||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
train_critic = MFCDef(
|
||||
name="critic_train",
|
||||
model_name="critic",
|
||||
mb_spec=self.critic_train.mb_spec,
|
||||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
interface_impl=critic_interface,
|
||||
model_type=self.critic.type,
|
||||
model_path=self.critic.path,
|
||||
input_keys=[
|
||||
"packed_input_ids",
|
||||
"packed_logprobs",
|
||||
"packed_ref_logprobs",
|
||||
"rewards",
|
||||
"dense_rewards",
|
||||
"values",
|
||||
"prompt_mask",
|
||||
"seq_no_eos_mask",
|
||||
],
|
||||
log_return_value=True,
|
||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
# "critic_inf": inf_values,
|
||||
# "critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor_gen": rollout,
|
||||
"actor_train": train_actor,
|
||||
"critic_inf": inf_values,
|
||||
"critic_train": train_critic,
|
||||
"ref_inf": inf_ref_logits,
|
||||
"rew_inf": inf_reward,
|
||||
}
|
||||
|
||||
@property
|
||||
def allocations(self):
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
# "critic_inf": self.critic_inf,
|
||||
# "critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor_gen": self.actor_gen,
|
||||
"actor_train": self.actor_train,
|
||||
"critic_inf": self.critic_inf,
|
||||
"critic_train": self.critic_train,
|
||||
"ref_inf": self.ref_inf,
|
||||
"rew_inf": self.rew_inf,
|
||||
}
|
||||
|
||||
@property
|
||||
def datasets(self):
|
||||
return [
|
||||
DatasetAbstraction(
|
||||
"code_prompt",
|
||||
args=dict(
|
||||
dataset_path=self.dataset.path,
|
||||
max_length=self.dataset.max_prompt_len,
|
||||
fill_to_max_length=self.dataset.fill_to_max_length,
|
||||
filter_threshold=self.dataset_filter_threshold,
|
||||
max_filter_percentage=self.dataset_max_filter_percentage,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@property
|
||||
def tokenizer_name_or_path(self) -> str:
|
||||
return self.actor.path
|
||||
|
||||
@property
|
||||
def search_kwargs(self):
|
||||
return {
|
||||
"num_gen_tokens": self.ppo.gen.max_new_tokens,
|
||||
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
|
||||
"seq_len": self.dataset.max_prompt_len,
|
||||
}
|
||||
|
||||
@property
|
||||
def max_prompt_len(self):
|
||||
return self.dataset.max_prompt_len
|
||||
|
||||
def initial_setup(self) -> ExperimentConfig:
|
||||
rpc_allocs = self._get_rpc_allocations()
|
||||
|
||||
resolve_replica_ids(rpc_allocs, self.models)
|
||||
resolve_rpc_hooks(
|
||||
rpc_allocs, self.models
|
||||
) # inplace modify MFCDefs in rpc allocations
|
||||
|
||||
pprint.pprint(rpc_allocs)
|
||||
|
||||
######### update ref model using ema, ref_ema_eta = 0 means fixed ref model #########
|
||||
def _find_rpc(name):
|
||||
return next(alloc.rpc for alloc in rpc_allocs if alloc.rpc.name == name)
|
||||
|
||||
# Remove the offload hook of ref_inf, because
|
||||
# we need to receive parameters from peer GPUs and update it immediately.
|
||||
if self.ref_ema_eta is not None:
|
||||
|
||||
ref_inf = _find_rpc("ref_inf")
|
||||
ref_inf._post_hooks = []
|
||||
|
||||
# Add an unidirectional parameter reallocation hook.
|
||||
actor_train = _find_rpc("actor_train")
|
||||
actor_train.add_post_hook(
|
||||
ParamReallocHook(
|
||||
target=ref_inf.model_name,
|
||||
eta=self.ref_ema_eta,
|
||||
)
|
||||
)
|
||||
######### The main difference from normal PPO #########
|
||||
|
||||
model_worker = self._get_model_worker_configs(rpc_allocs)
|
||||
|
||||
return ExperimentConfig(
|
||||
exp_ctrl=self.exp_ctrl,
|
||||
wandb=self.wandb,
|
||||
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
|
||||
model_worker=model_worker,
|
||||
)
|
||||
|
||||
|
||||
register_quickstart_exp("ppo-code", PPOCODEConfig)
@ -7,8 +7,6 @@ import os
|
|||
import pprint
|
||||
from typing import *
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.core.config import (
|
||||
DatasetAbstraction,
|
||||
|
@ -89,8 +87,6 @@ class PPOHyperparameters:
|
|||
gae_lambda: float = 1.0
|
||||
eps_clip: float = 0.2
|
||||
value_eps_clip: float = 0.2
|
||||
disable_value: bool = False
|
||||
recompute_logprob: bool = False
|
||||
max_reward_clip: float = 20.0
|
||||
reward_output_scaling: float = 1.0
|
||||
reward_output_bias: float = 0.0
|
||||
|
@ -104,6 +100,10 @@ class PPOHyperparameters:
|
|||
value_norm_beta: float = 0.99995
|
||||
value_norm_eps: float = 1e-5
|
||||
|
||||
disable_value: bool = False
|
||||
recompute_logprob: bool = False
|
||||
fuse_rew_ref: bool = True
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PPOMATHConfig(CommonExperimentConfig):
|
||||
|
@ -190,7 +190,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
default_factory=ModelTrainEvalConfig
|
||||
)
|
||||
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
||||
|
||||
# for manual allocation only
|
||||
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||
|
@ -210,15 +209,11 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
group_size: int = 1
|
||||
generation_size: Optional[int] = None
|
||||
mask_no_eos_with_zero: bool = False
|
||||
rm_output_scaling: Optional[float] = None
|
||||
ref_ema_eta: Optional[float] = None
|
||||
group_adv_norm: bool = False
|
||||
mask_too_long: bool = False
|
||||
rw_type: Optional[str] = "sparse"
|
||||
task: str = "math"
|
||||
check_xml_format: bool = False
|
||||
use_dense_reward: bool = False
|
||||
reward_delta: bool = True
|
||||
|
||||
check_verifier_status: bool = False
|
||||
|
||||
|
@ -244,26 +239,21 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
||||
)
|
||||
|
||||
if self.rm_output_scaling is None:
|
||||
self.rm_output_scaling = self.ppo.reward_output_scaling
|
||||
|
||||
@property
|
||||
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
||||
# role to config
|
||||
reward = copy.deepcopy(self.actor)
|
||||
models = {
|
||||
"actor": self.actor,
|
||||
"critic": self.critic,
|
||||
"ref": self.ref,
|
||||
"reward": reward,
|
||||
}
|
||||
if self.ppo.disable_value:
|
||||
return {
|
||||
"actor": self.actor,
|
||||
# "critic": self.critic,
|
||||
"ref": self.ref,
|
||||
"reward": self.rew,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"actor": self.actor,
|
||||
"critic": self.critic,
|
||||
"ref": self.ref,
|
||||
"reward": self.rew,
|
||||
}
|
||||
models.pop("critic")
|
||||
if self.ppo.fuse_rew_ref:
|
||||
models.pop("reward")
|
||||
return models
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
|
@ -278,10 +268,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
f"smaller than the prompt length + generation length "
|
||||
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||
)
|
||||
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
|
||||
raise RuntimeError(
|
||||
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
|
||||
)
|
||||
|
||||
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
||||
if domain and (not (domain.startswith("http://") and ":" in domain)):
|
||||
|
@ -305,12 +291,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
"generation_size": self.generation_size,
|
||||
"group_adv_norm": self.group_adv_norm,
|
||||
"mask_too_long": self.mask_too_long,
|
||||
"use_dense_reward": self.use_dense_reward,
|
||||
"reward_delta": self.reward_delta,
|
||||
},
|
||||
)
|
||||
ref_interface = copy.deepcopy(actor_interface)
|
||||
ref_interface.args["enable_save"] = False
|
||||
|
||||
critic_interface = ModelInterfaceAbstraction(
|
||||
"ppo_critic",
|
||||
|
@ -318,29 +300,31 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
**copy.deepcopy(self.ppo_kwargs),
|
||||
"group_size": self.group_size,
|
||||
"mask_too_long": self.mask_too_long,
|
||||
"use_dense_reward": self.use_dense_reward,
|
||||
"reward_delta": self.reward_delta,
|
||||
},
|
||||
)
|
||||
critic_interface.args.pop("eps_clip")
|
||||
rw_interface = ModelInterfaceAbstraction(
|
||||
"rw_math",
|
||||
"rw-math-code",
|
||||
args=dict(
|
||||
rw_type=self.rw_type,
|
||||
task=self.task,
|
||||
check_xml_format=self.check_xml_format,
|
||||
dataset_path=self.dataset.path,
|
||||
tokenizer_path=self.actor.path,
|
||||
enable_save=False,
|
||||
output_scaling=self.ppo.reward_output_scaling,
|
||||
rm_output_scaling=self.rm_output_scaling,
|
||||
output_bias=self.ppo.reward_output_bias,
|
||||
rw_type=self.rw_type,
|
||||
check_xml_format=self.check_xml_format,
|
||||
group_size=self.group_size,
|
||||
check_verifier_status=self.check_verifier_status,
|
||||
max_sync_length=self.ppo.gen.max_new_tokens
|
||||
+ self.dataset.max_prompt_len
|
||||
+ 128,
|
||||
),
|
||||
)
|
||||
|
||||
ref_interface = copy.deepcopy(actor_interface)
|
||||
ref_interface.args["enable_save"] = False
|
||||
if self.ppo.fuse_rew_ref:
|
||||
ref_interface = ModelInterfaceAbstraction(
|
||||
"fused-threading",
|
||||
args=dict(interfaces=dict(rew=rw_interface, ref=ref_interface)),
|
||||
)
|
||||
|
||||
rollout_output_keys = [
|
||||
"seq_no_eos_mask",
|
||||
"packed_input_ids",
|
||||
|
@ -357,7 +341,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=["packed_prompts"],
|
||||
input_keys=["packed_prompts", "task_ids"],
|
||||
output_keys=rollout_output_keys,
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
@ -379,18 +363,21 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
inf_reward = MFCDef(
|
||||
name="rew_inf",
|
||||
model_name="reward",
|
||||
mb_spec=self.rew_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=rw_interface,
|
||||
model_type=self.rew.type,
|
||||
model_path=self.rew.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=["packed_input_ids", "packed_prompts"],
|
||||
output_keys=["rewards", "dense_rewards"],
|
||||
input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
|
||||
output_keys=["rewards"],
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
# Add reward inputs/outputs to the ref MFC when the reward and ref models are fused
|
||||
inf_ref_inputs = ["packed_input_ids"]
|
||||
inf_ref_outputs = ["packed_ref_logprobs"]
|
||||
if self.ppo.fuse_rew_ref:
|
||||
inf_ref_inputs += ["packed_prompts", "task_ids"]
|
||||
inf_ref_outputs += ["rewards"]
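# Illustrative summary (comments only, an interpretation of the config above):
# with fuse_rew_ref=True the "fused-threading" interface wraps the reward and
# reference interfaces in a single MFC, so ref_inf also consumes packed_prompts
# and task_ids and emits "rewards", and the standalone rew_inf MFC is popped
# from rpcs/allocations below.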
|
||||
|
||||
inf_ref_logits = MFCDef(
|
||||
name="ref_inf",
|
||||
model_name="ref",
|
||||
|
@ -401,7 +388,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
interface_impl=ref_interface,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=inf_ref_inputs,
|
||||
output_keys=["packed_ref_logprobs"],
|
||||
output_keys=inf_ref_outputs,
|
||||
output_key_remap=dict(logprobs="packed_ref_logprobs"),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
@ -425,7 +412,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
"packed_logprobs",
|
||||
"packed_ref_logprobs",
|
||||
"rewards",
|
||||
"dense_rewards",
|
||||
"values",
|
||||
"prompt_mask",
|
||||
"seq_no_eos_mask",
|
||||
|
@ -459,7 +445,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
"packed_logprobs",
|
||||
"packed_ref_logprobs",
|
||||
"rewards",
|
||||
"dense_rewards",
|
||||
"values",
|
||||
"prompt_mask",
|
||||
"seq_no_eos_mask",
|
||||
|
@ -483,6 +468,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
rpcs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
rpcs.pop("actor_inf")
|
||||
if self.ppo.fuse_rew_ref:
|
||||
rpcs.pop("rew_inf")
|
||||
return rpcs
|
||||
|
||||
@property
|
||||
|
@ -501,17 +488,18 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
allocs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
allocs.pop("actor_inf")
|
||||
if self.ppo.fuse_rew_ref:
|
||||
allocs.pop("rew_inf")
|
||||
return allocs
|
||||
|
||||
@property
|
||||
def datasets(self):
|
||||
return [
|
||||
DatasetAbstraction(
|
||||
"math_prompt",
|
||||
"math_code_prompt",
|
||||
args=dict(
|
||||
dataset_path=self.dataset.path,
|
||||
max_length=self.dataset.max_prompt_len,
|
||||
fill_to_max_length=self.dataset.fill_to_max_length,
|
||||
filter_threshold=self.dataset_filter_threshold,
|
||||
max_filter_percentage=self.dataset_max_filter_percentage,
|
||||
),
|
||||
|
@ -0,0 +1,223 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Hashable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch.utils.data
|
||||
|
||||
from realhf.api.core import data_api
|
||||
from realhf.base import logging
|
||||
|
||||
logger = logging.getLogger("Math Code Dataset")
|
||||
|
||||
|
||||
def check_math_metadata_entries(data):
|
||||
assert data["task"] == "math"
|
||||
assert "query_id" in data
|
||||
data["query_id"] = str(data["query_id"])
|
||||
assert isinstance(data["prompt"], str)
|
||||
assert isinstance(data["solutions"], list)
|
||||
for sol in data["solutions"]:
|
||||
assert isinstance(sol, str)
|
||||
return data
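# Illustrative example (hypothetical values) of a valid math entry:
# {"task": "math", "query_id": "q-001", "prompt": "Compute 1 + 1.",
#  "solutions": ["\\boxed{2}"]}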
|
||||
|
||||
|
||||
def check_code_metadata_entries(data):
|
||||
assert data["task"] == "code"
|
||||
assert "query_id" in data
|
||||
data["query_id"] = str(data["query_id"])
|
||||
if "problem_id" not in data:
|
||||
data["problem_id"] = data["query_id"]
|
||||
assert isinstance(data["prompt"], str)
|
||||
input_output = json.loads(data["input_output"])
|
||||
assert len(input_output["inputs"]) == len(input_output["outputs"])
|
||||
for inp, out in zip(input_output["inputs"], input_output["outputs"]):
|
||||
assert isinstance(inp, str) and isinstance(out, str), (
|
||||
inp,
|
||||
out,
|
||||
input_output.get("fn_name"),
|
||||
)
|
||||
return data
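# Illustrative example (hypothetical values) of a valid code entry; note that
# "input_output" is a JSON string whose "inputs" and "outputs" lists must have
# equal length:
# {"task": "code", "query_id": "q-002", "prompt": "Sum two integers.",
#  "input_output": "{\"inputs\": [\"1 2\\n\"], \"outputs\": [\"3\\n\"]}"}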
|
||||
|
||||
|
||||
def load_metadata(path):
|
||||
assert str(path).endswith(".jsonl"), path
|
||||
with open(path, "r") as f:
|
||||
data = [json.loads(l) for l in f.readlines()]
|
||||
|
||||
id2info = {}
|
||||
omit_cnt = defaultdict(int)
|
||||
task_cnt = defaultdict(int)
|
||||
for d in data:
|
||||
assert d["query_id"] not in d, (d["task"], d["query_id"])
|
||||
try:
|
||||
if "task" not in d:
|
||||
d["task"] = "math"
|
||||
logger.warning(f'Key "task" not found in the dataset. Use math as default task type.')
|
||||
if d["task"] == "math":
|
||||
d = check_math_metadata_entries(d)
|
||||
elif d["task"] == "code":
|
||||
d = check_code_metadata_entries(d)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Data validation failed: query_id {d['query_id']}. "
|
||||
f"Error: {traceback.format_exc()}. Omit it in the dataset."
|
||||
)
|
||||
omit_cnt[d["task"]] += 1
|
||||
continue
|
||||
id2info[d["query_id"]] = d
|
||||
task_cnt[d["task"]] += 1
|
||||
logger.warning(f"Number of ignored data: {dict(**omit_cnt)}")
|
||||
return id2info, task_cnt
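# Minimal usage sketch (the counts below are made up for illustration):
# id2info, task_cnt = load_metadata(
#     "/storage/datasets/full_prompts_for_r1_distilled.jsonl"
# )
# print(task_cnt)  # e.g. defaultdict(int, {"math": 12000, "code": 3000})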
|
||||
|
||||
|
||||
class MATHCodePromptDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
util: data_api.DatasetUtility,
|
||||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
filter_threshold: float = 1e4,
|
||||
max_filter_percentage: float = 0.0,
|
||||
):
|
||||
"""Required keys: prompt, query_id, task=math/code, solutions.
|
||||
|
||||
Code data additionally requires an "input_output" key.
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
id2info, task_cnt = load_metadata(dataset_path)
|
||||
|
||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||
|
||||
prompts_str = [x["prompt"] for x in data]
|
||||
self.ids = [x["query_id"] for x in data]
|
||||
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [np.mean(x["scores"]) for x in data]
|
||||
util.tokenizer.padding_side = "left"
|
||||
prompt_encodings = util.tokenizer(
|
||||
prompts_str,
|
||||
truncation=True,
|
||||
# max_length=max_length,
|
||||
padding=False,
|
||||
return_length=True,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
||||
indices = [
|
||||
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
||||
]
|
||||
logger.info(
|
||||
f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data"
|
||||
)
|
||||
|
||||
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
|
||||
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
|
||||
self.ids = [
|
||||
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
|
||||
]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [self.base_scores[idx] for idx in indices]
|
||||
|
||||
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
|
||||
|
||||
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
|
||||
|
||||
self.active_indices = list(range(len(self.prompts)))
|
||||
self.filter_threshold = filter_threshold
|
||||
self.max_filter_percentage = max_filter_percentage
|
||||
|
||||
@property
|
||||
def util(self):
|
||||
return self._util
|
||||
|
||||
def __len__(self):
|
||||
return len(self.active_indices)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# print(self.base_scores)
|
||||
idx = self.active_indices[idx]
|
||||
data = dict(
|
||||
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||
)
|
||||
if hasattr(self, "base_scores"):
|
||||
data["base_scores"] = torch.tensor(
|
||||
[self.base_scores[idx]], dtype=torch.float32
|
||||
)
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=data,
|
||||
)
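# Illustrative shape of one item (hypothetical values): a SequenceSample with
# ids=["q-001@idx:3-0"], seqlens=[57], and data holding task_ids (shape [1])
# and packed_prompts (shape [57]); base_scores is attached only when the raw
# data provides a "scores" field.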
|
||||
|
||||
def filter(self, eval_scores: Dict[Hashable, float]):
|
||||
# Get all data indices that have a higher score than the threshold.
|
||||
idx2scores_to_remove = {}
|
||||
for pop_idx, idx in enumerate(self.active_indices):
|
||||
data_id = self.ids[idx]
|
||||
if data_id not in eval_scores:
|
||||
continue
|
||||
if eval_scores[data_id] > self.filter_threshold:
|
||||
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
|
||||
|
||||
# Control the number of samples to be removed according to max_filter_percentage.
|
||||
n = int(len(self.active_indices) * self.max_filter_percentage)
|
||||
indices_to_remove = sorted(
|
||||
idx2scores_to_remove.keys(),
|
||||
key=lambda x: idx2scores_to_remove[x],
|
||||
reverse=True,
|
||||
)[:n]
|
||||
|
||||
for pop_idx in sorted(indices_to_remove, reverse=True):
|
||||
self.active_indices.pop(pop_idx)
|
||||
logger.info(
|
||||
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
||||
f"Original dataset size: {len(self.prompts)}. "
|
||||
f"Filter threshold: {self.filter_threshold}. "
|
||||
f"Max filter percentage: {self.max_filter_percentage}. "
|
||||
f"Current number of eval scores: {len(eval_scores)}."
|
||||
)
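# Worked sketch of filter() (hypothetical numbers): with filter_threshold=0.9,
# max_filter_percentage=0.1, and 100 active samples, at most int(100 * 0.1) = 10
# of the samples whose eval score exceeds 0.9 are dropped, highest scores first.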
|
||||
|
||||
|
||||
if not __name__ == "__main__":
|
||||
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
|
||||
else:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
dataset = MATHCodePromptDataset(
|
||||
data_api.DatasetUtility(
|
||||
seed=0,
|
||||
dp_rank=0,
|
||||
world_size=1,
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
|
||||
),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path='/storage/datasets/full_prompts_for_r1_distilled.jsonl'
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
collate_fn=data_api.SequenceSample.gather,
|
||||
# NOTE: This is *NOT* the actual batch size for training.
|
||||
# It is just a proper size to load data to workers.
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
print(f"size: {len(dataset)}")
|
||||
for d in dataloader:
|
||||
# print(d.ids)
|
||||
pass
@ -4,13 +4,10 @@ import json
|
|||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import *
|
||||
|
||||
from realhf.base import logging
|
||||
from realhf.base.constants import parallelism_rank
|
||||
|
||||
logger = logging.getLogger("math parser")
|
||||
|
||||
|
@ -48,20 +45,7 @@ def loadJson(dataDir):
|
|||
return samples
|
||||
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
id2info = None
|
||||
|
||||
|
||||
def parse_line(prompt_str, generated, query_id):
|
||||
global id2info
|
||||
if id2info is None:
|
||||
try:
|
||||
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
|
||||
except KeyError as e:
|
||||
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
|
||||
def parse_line(id2info, prompt_str, generated, query_id):
|
||||
info = id2info[query_id.split("@idx:")[0]]
|
||||
|
||||
tmp_id = str(uuid.uuid4())
|
||||
|
@ -112,17 +96,12 @@ def parse_line(prompt_str, generated, query_id):
|
|||
|
||||
|
||||
def parse_lines_in_parallel(
|
||||
id2info,
|
||||
generateds: List,
|
||||
query_ids: List,
|
||||
max_workers=22,
|
||||
check_xml_format=False,
|
||||
) -> List:
|
||||
global id2info
|
||||
if id2info is None:
|
||||
try:
|
||||
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
|
||||
except KeyError as e:
|
||||
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
|
||||
assert len(generateds) == len(query_ids), (
|
||||
len(generateds),
|
||||
len(query_ids),
|
||||
|
@ -204,17 +183,19 @@ def parse_lines_in_parallel(
|
|||
|
||||
if __name__ == "__main__":
|
||||
sample = {
|
||||
"prompt": "",
|
||||
"query_id": "35ecd821a9e7e31da9ef0663a25347ce",
|
||||
# "answer_in_box": ["\\boxed{\\frac{1}{2}}", "<think></think><answer>\\boxed{\\frac{1}{2}}</answer>"]
|
||||
"answer": "<think>\n1. The problem requires us to determine the number of sequences of 144 hand movements such that every position appears exactly once and the hands return to the initial position at the end.\n2. We know that each movement involves one hand moving clockwise to the next number while the other hand stays in place.\n3. Considering the 12-hour clock, we can represent each positioning of the hands as a combination of the positions of both hands. Since both hands can be in any of the 12 positions, there are 12 x 12 = 144 different positionings.\n4. Given that at each position only one hand moves, every single movement is unique, leading to a total of 144 unique movements.\n5. These 144 movements must form a Hamiltonian Cycle, where each edge represents a valid movement between two positions.\n6. The problem thus reduces to finding a Hamiltonian cycle in a directed graph. Since the appearance of each movement is unique, it also determines the direction of the movement.\n7. We consider the edges that are rotations of each other as equivalent. Taking the rotational symmetry into account, we have 144/12 = 12 equivalence classes.\n8. The problem now is to determine the number of ways to arrange these 12 classes of rotations in a circle, which is 11 factorial.\n9. We must find the value of 11! and then compute the result modulo 1000.\n</think>\n<answer>\n320\n</answer>",
|
||||
"answers": ["-\\frac{2}{3}"],
|
||||
"solutions": [
|
||||
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
|
||||
"\\boxed{-\\frac{2}{3}}",
|
||||
],
|
||||
}
|
||||
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
|
||||
|
||||
print(
|
||||
parse_lines_in_parallel(
|
||||
# [sample["answer_in_box"][0] for _ in range(50)] + [sample["answer_in_box"][1] for _ in range(50)],
|
||||
[sample["answer"]] * 100,
|
||||
[sample["query_id"] for _ in range(100)],
|
||||
id2info,
|
||||
sample["answers"] * 100,
|
||||
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
|
||||
max_workers=8,
|
||||
check_xml_format=True,
|
||||
)
|
|
@ -2,7 +2,6 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import uuid
|
||||
from typing import Callable, Dict, Hashable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -85,312 +84,4 @@ class PromptDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
|
||||
|
||||
class MATHPromptDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
util: data_api.DatasetUtility,
|
||||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
fill_to_max_length: bool = False,
|
||||
filter_threshold: float = 1e4,
|
||||
max_filter_percentage: float = 0.0,
|
||||
):
|
||||
"""A dataset with prompts. Usually used for PPO.
|
||||
|
||||
Args:
|
||||
util (api.data.DatasetUtility): .
|
||||
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
|
||||
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
|
||||
The json/jsonl file should be a list of dictionaries. Each element in the list should have
|
||||
a key "prompt". Defaults to None.
|
||||
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
|
||||
A callable that returns a list of dictionaries. Defaults to None.
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||
|
||||
prompts_str = [x["prompt"] for x in data]
|
||||
self.ids = [x["query_id"] for x in data]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [np.mean(x["scores"]) for x in data]
|
||||
util.tokenizer.padding_side = "left"
|
||||
prompt_encodings = util.tokenizer(
|
||||
prompts_str,
|
||||
truncation=True,
|
||||
# max_length=max_length,
|
||||
padding=False,
|
||||
return_length=True,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
if fill_to_max_length:
|
||||
for i in range(len(prompt_encodings["length"])):
|
||||
x = prompt_encodings["input_ids"][i]
|
||||
if max_length > len(x):
|
||||
# Fill with the final non-pad token to the left.
|
||||
prompt_encodings["input_ids"][i] = [x[-1]] * (
|
||||
max_length - len(x)
|
||||
) + x
|
||||
prompt_encodings["length"][i] = max_length
|
||||
|
||||
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
||||
indices = [
|
||||
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
||||
]
|
||||
logger.info(f"{len(indices)} samples remain")
|
||||
|
||||
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
|
||||
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
|
||||
self.ids = [self.ids[idx] + f"@idx:{idx}-{util.dp_rank}" for idx in indices]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [self.base_scores[idx] for idx in indices]
|
||||
|
||||
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
|
||||
|
||||
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
|
||||
|
||||
self.active_indices = list(range(len(self.prompts)))
|
||||
self.filter_threshold = filter_threshold
|
||||
self.max_filter_percentage = max_filter_percentage
|
||||
|
||||
@property
|
||||
def util(self):
|
||||
return self._util
|
||||
|
||||
def __len__(self):
|
||||
return len(self.active_indices)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# print(self.base_scores)
|
||||
idx = self.active_indices[idx]
|
||||
if hasattr(self, "base_scores"):
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||
base_scores=torch.tensor(
|
||||
[self.base_scores[idx]], dtype=torch.float32
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
|
||||
),
|
||||
)
|
||||
|
||||
def filter(self, eval_scores: Dict[Hashable, float]):
|
||||
# Get all data indices that have a higher score than the threshold.
|
||||
idx2scores_to_remove = {}
|
||||
for pop_idx, idx in enumerate(self.active_indices):
|
||||
data_id = self.ids[idx]
|
||||
if data_id not in eval_scores:
|
||||
continue
|
||||
if eval_scores[data_id] > self.filter_threshold:
|
||||
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
|
||||
|
||||
# Control the number of samples to be removed according to max_filter_percentage.
|
||||
n = int(len(self.active_indices) * self.max_filter_percentage)
|
||||
indices_to_remove = sorted(
|
||||
idx2scores_to_remove.keys(),
|
||||
key=lambda x: idx2scores_to_remove[x],
|
||||
reverse=True,
|
||||
)[:n]
|
||||
|
||||
for pop_idx in sorted(indices_to_remove, reverse=True):
|
||||
self.active_indices.pop(pop_idx)
|
||||
logger.info(
|
||||
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
||||
f"Original dataset size: {len(self.prompts)}. "
|
||||
f"Filter threshold: {self.filter_threshold}. "
|
||||
f"Max filter percentage: {self.max_filter_percentage}. "
|
||||
f"Current number of eval scores: {len(eval_scores)}."
|
||||
)
|
||||
|
||||
|
||||
class CODEPromptDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
util: data_api.DatasetUtility,
|
||||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
fill_to_max_length: bool = False,
|
||||
filter_threshold: float = 1e4,
|
||||
max_filter_percentage: float = 0.0,
|
||||
):
|
||||
"""A dataset with prompts. Usually used for PPO.
|
||||
|
||||
Args:
|
||||
util (api.data.DatasetUtility): .
|
||||
max_length (Optional[int], optional): The maximum length of each sequence in the batch.
|
||||
dataset_path (Optional[str], optional): Path to the dataset json/jsonl file.
|
||||
The json/jsonl file should be a list of dictionaries. Each element in the list should have
|
||||
a key "prompt". Defaults to None.
|
||||
dataset_builder (Optional[Callable[[], List[Dict]]], optional): Alternative to dataset_path.
|
||||
A callable that returns a list of dictionaries. Defaults to None.
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||
|
||||
prompts_str = [x["question"] for x in data]
|
||||
self.ids = [x["id"] for x in data]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [np.mean(x["scores"]) for x in data]
|
||||
util.tokenizer.padding_side = "left"
|
||||
prompt_encodings = util.tokenizer(
|
||||
prompts_str,
|
||||
truncation=True,
|
||||
# max_length=max_length,
|
||||
padding=False,
|
||||
return_length=True,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
if fill_to_max_length:
|
||||
for i in range(len(prompt_encodings["length"])):
|
||||
x = prompt_encodings["input_ids"][i]
|
||||
if max_length > len(x):
|
||||
# Fill with the final non-pad token to the left.
|
||||
prompt_encodings["input_ids"][i] = [x[-1]] * (
|
||||
max_length - len(x)
|
||||
) + x
|
||||
prompt_encodings["length"][i] = max_length
|
||||
|
||||
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
||||
indices = [
|
||||
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
||||
]
|
||||
logger.info(f"{len(indices)} samples remain")
|
||||
|
||||
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
|
||||
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
|
||||
self.ids = [
|
||||
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
|
||||
]
|
||||
if "scores" in data[0]:
|
||||
self.base_scores = [self.base_scores[idx] for idx in indices]
|
||||
|
||||
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
|
||||
|
||||
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
|
||||
|
||||
self.active_indices = list(range(len(self.prompts)))
|
||||
self.filter_threshold = filter_threshold
|
||||
self.max_filter_percentage = max_filter_percentage
|
||||
|
||||
@property
|
||||
def util(self):
|
||||
return self._util
|
||||
|
||||
def __len__(self):
|
||||
return len(self.active_indices)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# print(self.base_scores)
|
||||
print(
|
||||
f"CODEPromptDataset idx{idx}, use idx: {self.active_indices[idx]}, size: {len(self.ids)}"
|
||||
)
|
||||
idx = self.active_indices[idx]
|
||||
|
||||
if hasattr(self, "base_scores"):
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||
base_scores=torch.tensor(
|
||||
[self.base_scores[idx]], dtype=torch.float32
|
||||
),
|
||||
),
|
||||
metadata=dict(random_id=[uuid.uuid4()]),
|
||||
)
|
||||
else:
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[self.ids[idx]],
|
||||
seqlens=[self.prompt_lengths[idx]],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
|
||||
),
|
||||
metadata=dict(random_id=[uuid.uuid4()]),
|
||||
)
|
||||
|
||||
def filter(self, eval_scores: Dict[Hashable, float]):
|
||||
# Get all data indices that have a higher score than the threshold.
|
||||
idx2scores_to_remove = {}
|
||||
for pop_idx, idx in enumerate(self.active_indices):
|
||||
data_id = self.ids[idx]
|
||||
if data_id not in eval_scores:
|
||||
continue
|
||||
if eval_scores[data_id] > self.filter_threshold:
|
||||
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
|
||||
|
||||
# Control the number of samples to be removed according to max_filter_percentage.
|
||||
n = int(len(self.active_indices) * self.max_filter_percentage)
|
||||
indices_to_remove = sorted(
|
||||
idx2scores_to_remove.keys(),
|
||||
key=lambda x: idx2scores_to_remove[x],
|
||||
reverse=True,
|
||||
)[:n]
|
||||
|
||||
for pop_idx in sorted(indices_to_remove, reverse=True):
|
||||
self.active_indices.pop(pop_idx)
|
||||
logger.info(
|
||||
f"Code prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
||||
f"Original dataset size: {len(self.prompts)}. "
|
||||
f"Filter threshold: {self.filter_threshold}. "
|
||||
f"Max filter percentage: {self.max_filter_percentage}. "
|
||||
f"Current number of eval scores: {len(eval_scores)}."
|
||||
)
|
||||
|
||||
|
||||
if not __name__ == "__main__":
|
||||
data_api.register_dataset("prompt", PromptDataset)
|
||||
data_api.register_dataset("math_prompt", MATHPromptDataset)
|
||||
data_api.register_dataset("code_prompt", CODEPromptDataset)
|
||||
else:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
dataset = MATHPromptDataset(
|
||||
data_api.DatasetUtility(
|
||||
seed=0,
|
||||
dp_rank=0,
|
||||
world_size=1,
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
|
||||
),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path="/storage/openpsi/data/math/Qwen_RL_training_xss/0101/train_rl@0101_with_qwqsft-7b_score.jsonl",
|
||||
)
|
||||
|
||||
dataset = CODEPromptDataset(
|
||||
data_api.DatasetUtility(
|
||||
seed=0,
|
||||
dp_rank=0,
|
||||
world_size=1,
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
|
||||
),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path="/storage/datasets/codeparrot-apps-test.jsonl",
|
||||
)
|
||||
data_generator = enumerate(dataset)
|
||||
dataset_batch_counter, cur_sample = next(data_generator)
|
||||
print(
|
||||
f"size: {len(dataset)}, index: {dataset_batch_counter}, cur_sample: {cur_sample}"
|
||||
)
|
||||
data_api.register_dataset("prompt", PromptDataset)
|
||||
@ -1,429 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
import collections
|
||||
import dataclasses
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from ast import parse
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import colorama
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
import realhf.base.logging as logging
|
||||
from functioncall.code.verify import code_verify
|
||||
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
|
||||
from realhf.base import constants
|
||||
|
||||
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
|
||||
|
||||
|
||||
class CodeVerifierException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def extract_python_code(text, min_length=20, strict_syntax=True):
|
||||
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
|
||||
code_blocks = re.findall(code_pattern, text, re.DOTALL)
|
||||
valid_blocks = []
|
||||
for block in code_blocks:
|
||||
clean_block = block.strip()
|
||||
if len(clean_block) < min_length:
|
||||
continue
|
||||
|
||||
# verify code syntax
|
||||
if strict_syntax:
|
||||
try:
|
||||
parse(clean_block, mode="exec")
|
||||
except (SyntaxError, IndentationError):
|
||||
continue
|
||||
|
||||
valid_blocks.append(clean_block)
|
||||
|
||||
if not valid_blocks:
|
||||
return None
|
||||
# return the last code block
|
||||
return valid_blocks[-1]
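# Illustrative behavior (hypothetical input): for a completion that contains a
# ```python ... ``` fenced block such as print(sum(map(int, input().split()))),
# the last syntactically valid block is returned as a string; None is returned
# when no block passes the minimum-length and syntax checks.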
|
||||
|
||||
|
||||
def check_with_elementtree(text):
|
||||
def escape_between_tags(text, tags=["think", "answer"]):
|
||||
"""转义标签之间的内容,但保留标签本身."""
|
||||
# Build the tag pattern
|
||||
tag_pattern = "|".join(tags)
|
||||
parts = []
|
||||
current_pos = 0
|
||||
|
||||
# Match opening and closing tags
|
||||
pattern = f"</?({tag_pattern})[^>]*>"
|
||||
|
||||
for match in re.finditer(pattern, text):
|
||||
# Add the content before the tag (needs escaping)
|
||||
if current_pos < match.start():
|
||||
parts.append(html.escape(text[current_pos : match.start()]))
|
||||
|
||||
# Add the tag itself (not escaped)
|
||||
parts.append(match.group())
|
||||
current_pos = match.end()
|
||||
|
||||
# Add the remaining trailing content
|
||||
if current_pos < len(text):
|
||||
parts.append(html.escape(text[current_pos:]))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
text = escape_between_tags(text)
|
||||
if not text.strip().startswith("<think>"):
|
||||
text = "<think>" + text
|
||||
try:
|
||||
xml_text = f"<root>{text}</root>"
|
||||
x = ET.fromstring(xml_text)
|
||||
if x.text is not None and x.text.strip() != "":
|
||||
return False, f"Error: extra content before <think>. {x.text}"
|
||||
if len(x) != 2:
|
||||
return False, f"Error: there are {len(x)} tags."
|
||||
if x[0].tag is None or x[0].tag != "think":
|
||||
return False, f"Error: <think> tag is missing. {x[0].tag}"
|
||||
if x[0].tail is not None and x[0].tail.strip() != "":
|
||||
return (
|
||||
False,
|
||||
f"Error: extra content between <think> and <answer>. {x[0].tail}",
|
||||
)
|
||||
if x[1].tag is None or x[1].tag != "answer":
|
||||
return False, f"Error: <answer> tag is missing. {x[1].tag}"
|
||||
if x[1].tail is not None and x[1].tail.strip() != "":
|
||||
return False, f"Error: extra content after <answer>, {x[1].tail}"
|
||||
|
||||
return True, x[1].text if x[1].text is not None else ""
|
||||
except ET.ParseError as e:
|
||||
return False, f"Error: XML格式错误, {str(e)}"
|
||||
|
||||
|
||||
def retokenize(
|
||||
task,
|
||||
tokenizer,
|
||||
packed_input_ids,
|
||||
input_cu_seqlens,
|
||||
prompts,
|
||||
prompt_cu_seqlens,
|
||||
query_ids,
|
||||
check_xml_format=False,
|
||||
do_eval=False,
|
||||
):
|
||||
input_ids = [
|
||||
packed_input_ids[start:end]
|
||||
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
|
||||
]
|
||||
prompt_ids = [
|
||||
prompts[start:end]
|
||||
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
|
||||
]
|
||||
seq_strs = tokenizer.batch_decode(
|
||||
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
||||
)
|
||||
prompt_strs = tokenizer.batch_decode(
|
||||
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
|
||||
)
|
||||
# query_id_strs = query_ids
|
||||
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
|
||||
|
||||
format_rewards = []
|
||||
|
||||
queryid_to_results = collections.defaultdict(list)
|
||||
# 8 processes on each node, with 10 subprocesses each
|
||||
if do_eval == True:
|
||||
_answers = [
|
||||
seq_str.split(prompt_str)[1]
|
||||
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
|
||||
]
|
||||
|
||||
codes = [extract_python_code(_answer) for _answer in _answers]
|
||||
logger.info(
|
||||
f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
|
||||
)
|
||||
format_rewards = code_verify(codes, query_id_strs)
|
||||
|
||||
if check_xml_format:
|
||||
with ThreadPoolExecutor(max_workers=22) as executor:
|
||||
futures = [
|
||||
executor.submit(check_with_elementtree, answer_str)
|
||||
for answer_str in _answers
|
||||
]
|
||||
# xml_rewards = []
|
||||
for idx, future in enumerate(futures):
|
||||
xml_reward, _ = future.result()
|
||||
# xml_rewards.append(xml_reward)
|
||||
if xml_reward == 1 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -0.8
|
||||
elif xml_reward == 0 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -1
|
||||
|
||||
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
|
||||
if query_id_str not in queryid_to_results:
|
||||
queryid_to_results[query_id_str] = []
|
||||
queryid_to_results[query_id_str].append(format_reward)
|
||||
else:
|
||||
for query_id_str in query_id_strs:
|
||||
if query_id_str not in queryid_to_results:
|
||||
queryid_to_results[query_id_str] = []
|
||||
queryid_to_results[query_id_str].append(0)
|
||||
format_rewards.append(0)
|
||||
|
||||
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PackedCodeRewardInterface(model_api.ModelInterface):
|
||||
|
||||
enable_save: bool = False
|
||||
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
|
||||
output_scaling: float = 1.0
|
||||
rm_output_scaling: float = 1.0
|
||||
rm_output_bias: float = 0.0
|
||||
output_bias: float = 0.0
|
||||
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
max_sync_length: int = 2048
|
||||
rw_type: str = "sparse"
|
||||
task: str = "code" # math or countdown or code
|
||||
check_xml_format: bool = False
|
||||
post_process: str = "sigmoid"
|
||||
group_size: int = 1
|
||||
check_verifier_status: bool = False
|
||||
|
||||
_call_count: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
|
||||
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
|
||||
logger.info(f"rm_output_bias: {self.rm_output_bias}")
|
||||
logger.info(f"output_scaling: {self.output_scaling}")
|
||||
logger.info(f"output_bias: {self.output_bias}")
|
||||
logger.info(f"max_sync_length: {self.max_sync_length}")
|
||||
logger.info(f"rw_type: {self.rw_type}")
|
||||
logger.info(f"post_process: {self.post_process}")
|
||||
|
||||
while True:
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{self._call_count}r{dist.get_rank()}.txt",
|
||||
)
|
||||
if os.path.exists(gen_file_path):
|
||||
self._call_count += 1
|
||||
else:
|
||||
break
|
||||
logger.info(f"call_count: {self._call_count}")
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
data_: SequenceSample,
|
||||
mb_spec,
|
||||
) -> SequenceSample:
|
||||
|
||||
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
|
||||
|
||||
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
|
||||
input_cu_seqlens = torch.nn.functional.pad(
|
||||
input_seqlens.cumsum(0), (1, 0)
|
||||
).int()
|
||||
|
||||
packed_prompts = data_.data["packed_prompts"]
|
||||
prompts = []
|
||||
prompt_seqlens = []
|
||||
offset = 0
|
||||
for x in data_.seqlens["packed_prompts"]:
|
||||
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
|
||||
offset += x[0]
|
||||
prompt_seqlens.extend(x * self.group_size)
|
||||
|
||||
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
|
||||
# non_packed_prompts = copy.deepcopy(prompts)
|
||||
prompts = torch.cat(prompts)
|
||||
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
|
||||
prompt_cu_seqlens = torch.nn.functional.pad(
|
||||
prompt_seqlens.cumsum(0), (1, 0)
|
||||
).int()
|
||||
|
||||
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
|
||||
|
||||
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
|
||||
retokenize(
|
||||
self.task,
|
||||
self.tokenizer,
|
||||
packed_input_ids,
|
||||
input_cu_seqlens,
|
||||
prompts,
|
||||
prompt_cu_seqlens,
|
||||
query_ids,
|
||||
check_xml_format=self.check_xml_format,
|
||||
do_eval=constants.is_last_pipe_stage(),
|
||||
)
|
||||
)
|
||||
|
||||
assert self.rw_type == "sparse"
|
||||
dense_scores = torch.zeros_like(packed_input_ids).float()
|
||||
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
|
||||
scores[scores == 0] = -1
|
||||
|
||||
if len(scores) == 0:
|
||||
return None
|
||||
|
||||
assert dense_scores.shape == packed_input_ids.shape
|
||||
scores = (
|
||||
scores.to(packed_input_ids.device) - self.output_bias
|
||||
) * self.output_scaling
|
||||
|
||||
logger.info(f"Code reward logging info @v{model.version.global_step}")
|
||||
logger.info(
|
||||
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
logger.info(
|
||||
f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
|
||||
)
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{self._call_count}r{dist.get_rank()}.txt",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
f"v{self._call_count}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
|
||||
pass_at_k = np.mean(
|
||||
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
|
||||
)
|
||||
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
|
||||
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
|
||||
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
|
||||
logger.info(f"reward: {sum(scores) / len(scores)}")
|
||||
|
||||
train_pass_monitor_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"training_monitor",
|
||||
f"v{self._call_count}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
|
||||
)
|
||||
with open(train_pass_monitor_file_path, "w") as monitor_file:
|
||||
for key, value in queryid_to_results.items():
|
||||
pass1 = sum(value) / len(value)
|
||||
pass8 = int(sum(value) > 0)
|
||||
monitor_file.write(
|
||||
json.dumps(
|
||||
{"query_id": key, "pass1": pass1, "pass8": pass8},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
self._call_count += 1
|
||||
|
||||
if scores.dtype != torch.float32:
|
||||
scores = scores.to(torch.float32)
|
||||
if dense_scores.dtype != torch.float32:
|
||||
dense_scores = dense_scores.to(torch.float32)
|
||||
|
||||
res = SequenceSample(
|
||||
keys=["rewards", "dense_rewards"],
|
||||
trailing_shapes=dict(rewards=(), dense_rewards=()),
|
||||
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
|
||||
ids=data_.ids,
|
||||
seqlens=dict(
|
||||
rewards=[
|
||||
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
|
||||
for x in data_.seqlens["packed_input_ids"]
|
||||
],
|
||||
dense_rewards=data_.seqlens["packed_input_ids"],
|
||||
),
|
||||
data=dict(rewards=scores, dense_rewards=dense_scores),
|
||||
)
|
||||
|
||||
# record rewards for each piece of data
|
||||
avg_scores = []
|
||||
offset = 0
|
||||
for i in range(data_.bs):
|
||||
score_lis = scores[
|
||||
offset : offset + len(data_.seqlens["packed_input_ids"][i])
|
||||
]
|
||||
avg_scores.append(score_lis.mean().item())
|
||||
offset += len(data_.seqlens["packed_input_ids"][i])
|
||||
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
|
||||
|
||||
res.metadata["scores"] = avg_scores
|
||||
|
||||
if self.check_verifier_status:
|
||||
avg_score = torch.tensor(
|
||||
np.mean(avg_scores), device=constants.current_device()
|
||||
)
|
||||
dist.all_reduce(
|
||||
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
|
||||
)
|
||||
avg_score /= constants.parallelism_group_size()
|
||||
avg_score = avg_score.item()
|
||||
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
|
||||
|
||||
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
|
||||
raise CodeVerifierException(
|
||||
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
model_api.register_interface("rw_code", PackedCodeRewardInterface)
|
|
@@ -0,0 +1,71 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict

import torch

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from realhf.api.core.config import ModelInterfaceAbstraction
from realhf.api.core.data_api import RL_TASKS, MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class FusedThreadingForwardInterface(model_api.ModelInterface):

    def __init__(self, interfaces: Dict[str, ModelInterfaceAbstraction]):
        self.interfaces = {
            key: model_api.make_interface(interface)
            for key, interface in interfaces.items()
        }

    def run_interface(
        self,
        interface_name: str,
        model,
        data,
        mb_spec,
    ) -> SequenceSample | None:
        tik = time.perf_counter()
        res = self.interfaces[interface_name].inference(model, data, mb_spec)
        t = time.perf_counter() - tik
        logger.info(f"Interface {interface_name} cost {t} s")
        return res

    def inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        with ThreadPoolExecutor(max_workers=len(self.interfaces)) as executor:
            tasks = []
            for interface_name in self.interfaces:
                task = executor.submit(
                    self.run_interface, interface_name, model, data, mb_spec
                )
                tasks.append(task)

            final_result = None
            for task in as_completed(tasks):
                res = task.result()
                if res is None:
                    continue
                if final_result is None:
                    final_result = res
                else:
                    final_result.update_(res)

        return final_result


model_api.register_interface("fused-threading", FusedThreadingForwardInterface)
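A minimal sketch of how this fused interface might be wired up. The sub-interface keys and the `ModelInterfaceAbstraction` specs below are illustrative assumptions rather than configuration taken from this PR, and `model`, `data`, and `mb_spec` are assumed to be provided by the caller.

# Hypothetical wiring: run two reward computations concurrently over the same batch.
# The keys ("math_rw", "code_rw") and the interface specs are placeholders.
fused = FusedThreadingForwardInterface(
    interfaces=dict(
        math_rw=ModelInterfaceAbstraction("rw-math-code"),
        code_rw=ModelInterfaceAbstraction("rw-math-code"),
    )
)
# inference() submits each sub-interface to its own thread and merges every
# non-None SequenceSample result with update_().
merged = fused.inference(model, data, mb_spec)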
@ -1,46 +1,73 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import ast
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import html
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import xml.etree.ElementTree as ET
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import colorama
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import tqdm
|
||||
import transformers
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
import realhf.base.logging as logging
|
||||
from functioncall.code.local_verify import code_verify as local_code_verify
|
||||
from functioncall.code.verify import code_verify
|
||||
from functioncall.math.verify import math_verify
|
||||
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
|
||||
from realhf.api.core.data_api import (
|
||||
RL_TASKS,
|
||||
MicroBatchSpec,
|
||||
SequenceSample,
|
||||
load_hf_tokenizer,
|
||||
)
|
||||
from realhf.base import constants
|
||||
from realhf.base.constants import data_parallel_group, data_parallel_world_size
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
from realhf.impl.dataset.math_code_dataset import load_metadata
|
||||
from realhf.impl.dataset.math_parser import parse_lines_in_parallel as math_verify_local
|
||||
|
||||
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
|
||||
|
||||
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
|
||||
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
|
||||
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
|
||||
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
|
||||
|
||||
|
||||
class MathVerifierException(Exception):
|
||||
class VerifierException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def extract_python_code(text, min_length=20, strict_syntax=True):
|
||||
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
|
||||
code_blocks = re.findall(code_pattern, text, re.DOTALL)
|
||||
valid_blocks = []
|
||||
for block in code_blocks:
|
||||
clean_block = block.strip()
|
||||
if len(clean_block) < min_length:
|
||||
continue
|
||||
|
||||
# verify code syntax
|
||||
if strict_syntax:
|
||||
try:
|
||||
ast.parse(clean_block, mode="exec")
|
||||
except (SyntaxError, IndentationError):
|
||||
continue
|
||||
|
||||
valid_blocks.append(clean_block)
|
||||
|
||||
if not valid_blocks:
|
||||
logger.warning(f"failed to extract python code from {text}")
|
||||
return None
|
||||
# return the last code block
|
||||
return valid_blocks[-1]
|
||||
|
||||
|
||||
def check_with_elementtree(text):
|
||||
def escape_between_tags(text, tags=["think", "answer"]):
|
||||
"""转义标签之间的内容,但保留标签本身."""
|
||||
|
@@ -94,27 +121,37 @@ def check_with_elementtree(text):
        return False, f"Error: XML format error, {str(e)}"


def retokenize(
id2info = {}


def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
    global id2info
    assert len(answers) == len(query_id_strs)
    format_rewards = []
    if task == "math":
        format_rewards = math_verify_call(id2info, answers, query_id_strs)
    elif task == "code":
        codes = [extract_python_code(_answer) for _answer in answers]
        format_rewards = code_verify_call(id2info, codes, query_id_strs)
    assert len(format_rewards) == len(answers), (
        task,
        len(format_rewards),
        len(answers),
        answers,
    )
    return format_rewards


def retokenize_and_verify(
    task,
    tokenizer,
    packed_input_ids,
    input_cu_seqlens,
    prompts,
    prompt_cu_seqlens,
    query_ids,
    prompt_ids: List[List[int]],
    seq_ids: List[List[int]],
    query_ids: List[str],
    check_xml_format=False,
    do_eval=False,
):
    input_ids = [
        packed_input_ids[start:end]
        for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
    ]
    prompt_ids = [
        prompts[start:end]
        for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
    ]
    seq_strs = tokenizer.batch_decode(
        input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
        seq_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    prompt_strs = tokenizer.batch_decode(
        prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True

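# --- Illustrative example (not part of the diff) ------------------------------
# A hypothetical call to dispatch_reward_calculation() above, assuming id2info
# has already been populated by load_metadata() and contains an entry for the
# placeholder query id "q-123". For "math" the boxed answer is checked against
# the metadata; for "code" the last fenced python block is extracted and run
# against the stored test cases. Judging from the downstream
# `scores[scores == 0] = -1`, each returned entry is expected to be 1 for a
# verified sample and 0 otherwise.
#
#   rewards = dispatch_reward_calculation(
#       task="math",
#       answers=["... so the answer is \\boxed{-\\frac{2}{3}}."],
#       query_id_strs=["q-123"],
#   )
#   assert len(rewards) == 1
# -------------------------------------------------------------------------------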
@ -122,261 +159,185 @@ def retokenize(
|
|||
# query_id_strs = query_ids
|
||||
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
|
||||
|
||||
format_rewards = []
|
||||
queryid_to_results = collections.defaultdict(list)
|
||||
# 8 processes on each node, with 10 subprocesses each
|
||||
if do_eval == True:
|
||||
_answers = [
|
||||
seq_str.split(prompt_str)[1]
|
||||
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
|
||||
]
|
||||
answers = [
|
||||
seq_str.split(prompt_str)[1]
|
||||
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
|
||||
]
|
||||
|
||||
format_rewards = math_verify_call(_answers, query_id_strs)
|
||||
if check_xml_format:
|
||||
with ThreadPoolExecutor(max_workers=22) as executor:
|
||||
futures = [
|
||||
executor.submit(check_with_elementtree, answer_str)
|
||||
for answer_str in _answers
|
||||
]
|
||||
# xml_rewards = []
|
||||
for idx, future in enumerate(futures):
|
||||
xml_reward, _ = future.result()
|
||||
# xml_rewards.append(xml_reward)
|
||||
if xml_reward == 1 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -0.8
|
||||
elif xml_reward == 0 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -1
|
||||
format_rewards = dispatch_reward_calculation(task, answers, query_id_strs)
|
||||
|
||||
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
|
||||
if query_id_str not in queryid_to_results:
|
||||
queryid_to_results[query_id_str] = []
|
||||
queryid_to_results[query_id_str].append(format_reward)
|
||||
else:
|
||||
for query_id_str in query_id_strs:
|
||||
if query_id_str not in queryid_to_results:
|
||||
queryid_to_results[query_id_str] = []
|
||||
queryid_to_results[query_id_str].append(0)
|
||||
format_rewards.append(0)
|
||||
if check_xml_format:
|
||||
for idx, answer in enumerate(answers):
|
||||
xml_reward, _ = check_with_elementtree(answer)
|
||||
if xml_reward == 1 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -0.8
|
||||
elif xml_reward == 0 and format_rewards[idx] == 0:
|
||||
format_rewards[idx] = -1
|
||||
|
||||
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
|
||||
return format_rewards, prompt_strs, seq_strs
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PackedMathRewardInterface(model_api.ModelInterface):
|
||||
|
||||
enable_save: bool = False
|
||||
class MultiTaskRewardInterface(model_api.ModelInterface):
|
||||
dataset_path: str = ""
|
||||
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
|
||||
output_scaling: float = 1.0
|
||||
rm_output_scaling: float = 1.0
|
||||
rm_output_bias: float = 0.0
|
||||
output_bias: float = 0.0
|
||||
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
max_sync_length: int = 2048
|
||||
rw_type: str = "sparse"
|
||||
task: str = "math" # math or countdown
|
||||
check_xml_format: bool = False
|
||||
post_process: str = "sigmoid"
|
||||
group_size: int = 1
|
||||
check_verifier_status: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
global id2info
|
||||
id2info, _ = load_metadata(self.dataset_path)
|
||||
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
|
||||
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
|
||||
logger.info(f"rm_output_bias: {self.rm_output_bias}")
|
||||
logger.info(f"output_scaling: {self.output_scaling}")
|
||||
logger.info(f"output_bias: {self.output_bias}")
|
||||
logger.info(f"max_sync_length: {self.max_sync_length}")
|
||||
logger.info(f"rw_type: {self.rw_type}")
|
||||
logger.info(f"post_process: {self.post_process}")
|
||||
if constants.parallelism_rank() == 0:
|
||||
logger.info(f"output_scaling: {self.output_scaling}")
|
||||
logger.info(f"output_bias: {self.output_bias}")
|
||||
logger.info(f"rw_type: {self.rw_type}")
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
data_: SequenceSample,
|
||||
mb_spec,
|
||||
) -> SequenceSample:
|
||||
|
||||
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
|
||||
|
||||
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
|
||||
input_cu_seqlens = torch.nn.functional.pad(
|
||||
input_seqlens.cumsum(0), (1, 0)
|
||||
).int()
|
||||
|
||||
packed_prompts = data_.data["packed_prompts"]
|
||||
prompts = []
|
||||
prompt_seqlens = []
|
||||
offset = 0
|
||||
for x in data_.seqlens["packed_prompts"]:
|
||||
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
|
||||
offset += x[0]
|
||||
prompt_seqlens.extend(x * self.group_size)
|
||||
|
||||
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
|
||||
# non_packed_prompts = copy.deepcopy(prompts)
|
||||
prompts = torch.cat(prompts)
|
||||
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
|
||||
prompt_cu_seqlens = torch.nn.functional.pad(
|
||||
prompt_seqlens.cumsum(0), (1, 0)
|
||||
).int()
|
||||
|
||||
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
|
||||
|
||||
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
|
||||
retokenize(
|
||||
self.task,
|
||||
self.tokenizer,
|
||||
packed_input_ids,
|
||||
input_cu_seqlens,
|
||||
prompts,
|
||||
prompt_cu_seqlens,
|
||||
query_ids,
|
||||
check_xml_format=self.check_xml_format,
|
||||
do_eval=constants.is_last_pipe_stage(),
|
||||
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
|
||||
xs = data.unpack()
|
||||
dispatched = {}
|
||||
dispatched_indices = {}
|
||||
for task_idx, task_name in enumerate(RL_TASKS):
|
||||
indices = (
|
||||
(data.data["task_ids"] == task_idx).cpu().numpy().nonzero()[0].tolist()
|
||||
)
|
||||
if len(indices) > 0:
|
||||
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
|
||||
dispatched_indices[task_name] = indices
|
||||
|
||||
return dispatched, dispatched_indices
|
||||
|
||||
def _gather_tasks(
|
||||
self, results: Dict, dispatched_indices: Dict, bs: int
|
||||
) -> SequenceSample:
|
||||
xs = [None for _ in range(bs)]
|
||||
for task_name, indices in dispatched_indices.items():
|
||||
xxs = results[task_name].unpack()
|
||||
assert len(indices) == len(xxs), (len(indices), len(xxs))
|
||||
for i, xx in zip(indices, xxs):
|
||||
xs[i] = xx
|
||||
assert all(xs)
|
||||
return SequenceSample.gather(xs)
|
||||
|
||||
def _dispatch_tp_and_pp(self, data: SequenceSample):
|
||||
tp_pp_size = constants.tp_and_pp_world_size()
|
||||
if tp_pp_size == 1:
|
||||
return data, None
|
||||
splitted, _, backward_indices = data.split(
|
||||
mb_spec=MicroBatchSpec(n_mbs=tp_pp_size)
|
||||
)
|
||||
tp_pp_rank = constants.tp_and_pp_rank()
|
||||
print("dispatched batch size", [s.bs for s in splitted], flush=True)
|
||||
return splitted[tp_pp_rank], backward_indices
|
||||
|
||||
def _gather_tp_and_pp(self, input_, data: SequenceSample, backward_indices):
|
||||
tp_pp_size = constants.tp_and_pp_world_size()
|
||||
if tp_pp_size == 1:
|
||||
return data
|
||||
local_rank = constants.grid().topo.get_rank(
|
||||
data=constants.data_parallel_rank(),
|
||||
model=0,
|
||||
pipe=constants.pipe_parallel_world_size() - 1,
|
||||
)
|
||||
dst = constants.to_global_pg_rank(local_rank)
|
||||
gather_list = None
|
||||
if dist.get_rank() == dst:
|
||||
gather_list = [None for _ in range(tp_pp_size)]
|
||||
x = data.data["rewards"].cpu().numpy().tolist()
|
||||
print(x, flush=True)
|
||||
dist.gather_object(
|
||||
x, gather_list, dst=dst, group=constants.tp_and_pp_cpu_group()
|
||||
)
|
||||
if dist.get_rank() != dst:
|
||||
return None
|
||||
gathered = np.array(gather_list).reshape(-1, self.group_size)
|
||||
assert len(gathered) == len(backward_indices)
|
||||
rewards = (
|
||||
np.concatenate([gathered[i] for i in backward_indices]).flatten().tolist()
|
||||
)
|
||||
return SequenceSample(
|
||||
keys=["rewards"],
|
||||
trailing_shapes=dict(rewards=()),
|
||||
dtypes=dict(rewards=torch.float32),
|
||||
ids=input_.ids,
|
||||
seqlens=dict(
|
||||
rewards=[[1 for _ in range(self.group_size)] for _ in range(input_.bs)],
|
||||
),
|
||||
data=dict(rewards=torch.tensor(rewards, dtype=torch.float32)),
|
||||
)
|
||||
|
||||
assert self.rw_type == "sparse"
|
||||
dense_scores = torch.zeros_like(packed_input_ids).float()
|
||||
def calculate_task_reward(
|
||||
self,
|
||||
model: model_api.Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
task_type: str,
|
||||
):
|
||||
# mb_spec is disrespected here
|
||||
packed_input_ids: torch.Tensor = data.data["packed_input_ids"]
|
||||
input_seqlens = flat2d(data.seqlens["packed_input_ids"])
|
||||
seq_ids = []
|
||||
offset = 0
|
||||
for slen in input_seqlens:
|
||||
seq_ids.append(
|
||||
packed_input_ids[offset : offset + slen].cpu().numpy().tolist()
|
||||
)
|
||||
offset += slen
|
||||
assert offset == packed_input_ids.shape[0], (offset, packed_input_ids.shape)
|
||||
prompt_input_ids = data.data["packed_prompts"]
|
||||
prompt_len = flat2d(data.seqlens["packed_prompts"])
|
||||
prompt_ids = []
|
||||
offset = 0
|
||||
for slen in prompt_len:
|
||||
p = prompt_input_ids[offset : offset + slen].cpu().numpy().tolist()
|
||||
prompt_ids += [p] * self.group_size
|
||||
offset += slen
|
||||
format_rewards, prompt_strs, seq_strs = retokenize_and_verify(
|
||||
task_type,
|
||||
self.tokenizer,
|
||||
prompt_ids=prompt_ids,
|
||||
seq_ids=seq_ids,
|
||||
query_ids=[
|
||||
str(data_id) for data_id in data.ids for _ in range(self.group_size)
|
||||
],
|
||||
check_xml_format=self.check_xml_format,
|
||||
)
|
||||
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
|
||||
scores[scores == 0] = -1
|
||||
|
||||
if len(scores) == 0:
|
||||
return None
|
||||
|
||||
assert dense_scores.shape == packed_input_ids.shape
|
||||
scores = (
|
||||
scores.to(packed_input_ids.device) - self.output_bias
|
||||
) * self.output_scaling
|
||||
|
||||
logger.info(f"Math reward logging info @v{model.version.global_step}")
|
||||
logger.info(
|
||||
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
if constants.is_last_pipe_stage():
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.txt",
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"Generated samples and rewards will be dumped to: {gen_file_path}"
|
||||
)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"Generated samples and rewards will be dumped to: {gen_file_path}"
|
||||
)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
|
||||
)
|
||||
|
||||
pass_at_k = np.mean(
|
||||
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
|
||||
)
|
||||
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
|
||||
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
|
||||
|
||||
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
|
||||
|
||||
logger.info(f"reward: {sum(scores) / len(scores)}")
|
||||
|
||||
train_pass_monitor_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"training_monitor",
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
|
||||
logger.info(
|
||||
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
|
||||
)
|
||||
with open(train_pass_monitor_file_path, "w") as monitor_file:
|
||||
for key, value in queryid_to_results.items():
|
||||
pass1 = sum(value) / len(value)
|
||||
pass8 = int(sum(value) > 0)
|
||||
monitor_file.write(
|
||||
json.dumps(
|
||||
{"query_id": key, "pass1": pass1, "pass8": pass8},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
model.inc_version()
|
||||
|
||||
if scores.dtype != torch.float32:
|
||||
scores = scores.to(torch.float32)
|
||||
if dense_scores.dtype != torch.float32:
|
||||
dense_scores = dense_scores.to(torch.float32)
|
||||
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
|
||||
|
||||
res = SequenceSample(
|
||||
keys=["rewards", "dense_rewards"],
|
||||
trailing_shapes=dict(rewards=(), dense_rewards=()),
|
||||
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
|
||||
ids=data_.ids,
|
||||
keys=["rewards"],
|
||||
trailing_shapes=dict(rewards=()),
|
||||
dtypes=dict(rewards=torch.float32),
|
||||
ids=data.ids,
|
||||
seqlens=dict(
|
||||
rewards=[
|
||||
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
|
||||
for x in data_.seqlens["packed_input_ids"]
|
||||
[1 for _ in range(len(x))] for x in data.seqlens["packed_input_ids"]
|
||||
],
|
||||
dense_rewards=data_.seqlens["packed_input_ids"],
|
||||
),
|
||||
data=dict(rewards=scores, dense_rewards=dense_scores),
|
||||
data=dict(rewards=scores),
|
||||
)
|
||||
|
||||
# record rewards for each piece of data
|
||||
avg_scores = []
|
||||
offset = 0
|
||||
for i in range(data_.bs):
|
||||
for i in range(data.bs):
|
||||
score_lis = scores[
|
||||
offset : offset + len(data_.seqlens["packed_input_ids"][i])
|
||||
offset : offset + len(data.seqlens["packed_input_ids"][i])
|
||||
]
|
||||
avg_scores.append(score_lis.mean().item())
|
||||
offset += len(data_.seqlens["packed_input_ids"][i])
|
||||
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
|
||||
offset += len(data.seqlens["packed_input_ids"][i])
|
||||
assert offset == sum(len(x) for x in data.seqlens["packed_input_ids"])
|
||||
|
||||
res.metadata["scores"] = avg_scores
|
||||
|
||||
|
@@ -385,20 +346,165 @@ class PackedMathRewardInterface(model_api.ModelInterface):
                np.mean(avg_scores), device=constants.current_device()
            )
            dist.all_reduce(
                avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
                avg_score, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
            )
            avg_score /= constants.parallelism_group_size()
            avg_score /= constants.data_parallel_world_size()
            avg_score = avg_score.item()
            minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
            minimal_score = (-1 - self.output_bias) * self.output_scaling

            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
                raise MathVerifierException(
                raise VerifierException(
                    "All rewards are at the minimal value. Something is probably wrong with the verifier!"
                )

        if not constants.is_last_pipe_stage():
            return None
        return res
||||
|
||||
def log_rewards_to_file(
|
||||
self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
|
||||
):
|
||||
tik = time.perf_counter()
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated",
|
||||
task_type,
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.txt",
|
||||
)
|
||||
|
||||
model_api.register_interface("rw_math", PackedMathRewardInterface)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
info = "\n".join(
|
||||
[
|
||||
f"idx: {idx} / {len(scores)}",
|
||||
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
|
||||
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
|
||||
]
|
||||
)
|
||||
_f.write(info + "\n")
|
||||
|
||||
gen_file_path = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"generated_jsonl",
|
||||
task_type,
|
||||
f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
|
||||
)
|
||||
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
|
||||
with open(gen_file_path, "w") as _f:
|
||||
for idx, (score, prompt_str, seq_str) in enumerate(
|
||||
zip(scores, prompt_strs, seq_strs)
|
||||
):
|
||||
_f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": prompt_str,
|
||||
"generated": seq_str.split(prompt_str)[1],
|
||||
"reward": score.item(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
logger.info(f"[{task_type}] number of samples: {len(scores)}, {scores.shape}")
|
||||
logger.info(f"[{task_type}] avg reward: {sum(scores) / len(scores)}")
|
||||
logger.info(f"[{task_type}] log to file time: {time.perf_counter()- tik:.2f}s")
|

    def inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        input_ = data
        data, backward_indices = self._dispatch_tp_and_pp(data)
        task_data, dispatch_indices = self._dispatch_tasks(data)

        assert self.rw_type == "sparse"

        def _task_func(func, task_type: str):
            def _wrapped_func(*args, **kwargs):
                start_time = time.perf_counter()
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    raise asyncio.CancelledError(
                        f"[{task_type}] task failed: {e}"
                    ) from e
                finally:
                    duration = time.perf_counter() - start_time
                    logger.info(f"[{task_type}] time cost: {duration:.4f}s")
                return task_type, result

            return _wrapped_func

        async def _run_tasks():
            tasks = []
            for task_type, d in task_data.items():
                task_func = _task_func(self.calculate_task_reward, task_type)
                task_args = (model, d, mb_spec, task_type)
                task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
                tasks.append(task)

            results = await asyncio.gather(*tasks)
            task_results = {}
            for res in results:
                task_type, result = res
                task_results[task_type] = result

            return task_results

        task_results = asyncio.run(_run_tasks())
        final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
        final_result = self._gather_tp_and_pp(input_, final_result, backward_indices)

        model.inc_version()

        return final_result

    def _mock_inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
    ) -> SequenceSample:

        prompt_lens = flat2d(data.seqlens["packed_prompts"])
        task_ids = data.data["task_ids"].cpu().numpy().tolist()
        seqlens = []
        offset = 0
        seq = []
        for plen, task_id in zip(prompt_lens, task_ids):
            seq += [data.data["packed_prompts"][offset : offset + plen]]
            offset += plen
            if task_id == RL_TASKS.index("math"):
                answer_str = (
                    "something unimportant but the answer is \\boxed{-\\frac{2}{3}}."
                )
            elif task_id == RL_TASKS.index("code"):
                answer_str = (
                    "```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"
                )
            else:
                answer_str = "something unimportant"
            encoding = model.tokenizer(
                [answer_str], add_special_tokens=True, return_attention_mask=False
            )

            ans = torch.tensor(encoding["input_ids"], dtype=torch.long).flatten()
            seq += [ans]
            seqlens.append(plen + len(ans))

        x = SequenceSample.from_default(
            seqlens=seqlens,
            ids=data.ids,
            data=dict(packed_input_ids=torch.cat(seq)),
        )
        data.update_(x)
        return data


model_api.register_interface("rw-math-code", MultiTaskRewardInterface)

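A condensed sketch of the intended call path for the new multi-task interface, mirroring the test added later in this diff; the dataset and tokenizer paths below are placeholders, and `model` and `data` are assumed to be prepared by the caller.

# Placeholder paths; the real dataset/tokenizer locations come from the experiment config.
interface = MultiTaskRewardInterface(
    dataset_path="/path/to/math_code_dataset.jsonl",
    tokenizer_path="/path/to/tokenizer",
    group_size=1,
    check_verifier_status=False,
)
# `data` is a SequenceSample carrying packed_prompts, packed_input_ids, and
# per-sample task_ids; math and code samples are verified concurrently (one
# thread per task via asyncio.to_thread) and the per-task rewards are gathered
# back into the original sample order.
rewards = interface.inference(model, data, mb_spec=MicroBatchSpec())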
@ -18,7 +18,7 @@ import realhf.base.logging as logging
|
|||
import realhf.impl.model.utils.ppo_functional as ppo_functional
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
|
||||
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
from realhf.impl.model.nn.real_llm_generate import concat_prompt_to_generation_output
|
||||
from realhf.impl.model.utils.functional import (
|
||||
|
|
|
@ -618,8 +618,6 @@ class RayController:
|
|||
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
|
||||
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
||||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
|
||||
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
|
||||
|
|
|
@ -920,7 +920,7 @@ class ModelWorker(worker_base.Worker):
|
|||
"dataset_eval_scores.json",
|
||||
)
|
||||
eval_scores = {}
|
||||
if isinstance(res, data_api.SequenceSample):
|
||||
if isinstance(res, data_api.SequenceSample) and constants.is_dp_head():
|
||||
if rpc.output_key_remap:
|
||||
res.remap_keys_(rpc.output_key_remap)
|
||||
res = res.select(rpc.output_keys)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import *
|
||||
|
||||
import pytest
|
||||
|
@ -27,23 +28,25 @@ def model_class(request):
|
|||
|
||||
|
||||
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
|
||||
def math_dataset(request, save_path):
|
||||
with open(os.getenv("REAL_MATH_METADATA_PATH"), "r") as f:
|
||||
query_ids = list(json.load(f).keys())
|
||||
def math_code_dataset(request, save_path):
|
||||
size = request.param
|
||||
max_prompt_len = 8
|
||||
max_resp_len = 8
|
||||
dataset = []
|
||||
for i in range(size):
|
||||
prompt_len = random.randint(1, max_prompt_len)
|
||||
n_pairs = random.randint(1, 5)
|
||||
d = dict(
|
||||
query_id=query_ids[i],
|
||||
query_id=str(uuid.uuid4()),
|
||||
prompt=generate_random_sentence(prompt_len),
|
||||
task=random.choice(["math", "code"]),
|
||||
)
|
||||
if d["task"] == "math":
|
||||
d["solutions"] = [generate_random_sentence(max_resp_len)]
|
||||
elif d["task"] == "code":
|
||||
d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
|
||||
dataset.append(d)
|
||||
with open(str(save_path / "math_dataset.json"), "w") as f:
|
||||
json.dump(dataset, f)
|
||||
with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
|
||||
f.write(json.dumps(d) + "\n")
|
||||
return dataset
|
||||
|
||||
|
||||
|
@ -59,7 +62,7 @@ def math_dataset(request, save_path):
|
|||
def test_ppo_symm(
|
||||
tmp_path_factory,
|
||||
tokenizer,
|
||||
math_dataset,
|
||||
math_code_dataset,
|
||||
save_path,
|
||||
cpu_hf_model,
|
||||
mconfig,
|
||||
|
@ -97,13 +100,8 @@ def test_ppo_symm(
|
|||
init_critic_from_actor=True,
|
||||
backend="mock_train",
|
||||
),
|
||||
rew=ModelTrainEvalConfig(
|
||||
path=str(save_path),
|
||||
init_critic_from_actor=True,
|
||||
init_from_scratch=True,
|
||||
),
|
||||
dataset=PromptOnlyDatasetConfig(
|
||||
path=str(save_path / "math_dataset.json"),
|
||||
path=str(save_path / "math_code_dataset.jsonl"),
|
||||
max_prompt_len=mconfig.n_positions // 2,
|
||||
train_bs_n_seqs=minbs,
|
||||
fill_to_max_length=False,
|
||||
|
@ -116,6 +114,7 @@ def test_ppo_symm(
|
|||
use_cuda_graph=False,
|
||||
),
|
||||
),
|
||||
group_size=2,
|
||||
)
|
||||
|
||||
run_test_exp(exp_cfg)
|
||||
|
@ -133,7 +132,7 @@ def test_ppo_symm(
|
|||
def test_ppo_global_reshard(
|
||||
tmp_path_factory,
|
||||
tokenizer,
|
||||
math_dataset,
|
||||
math_code_dataset,
|
||||
save_path,
|
||||
cpu_hf_model,
|
||||
mconfig,
|
||||
|
@ -180,7 +179,7 @@ def test_ppo_global_reshard(
|
|||
init_from_scratch=True,
|
||||
),
|
||||
dataset=PromptOnlyDatasetConfig(
|
||||
path=str(save_path / "math_dataset.json"),
|
||||
path=str(save_path / "math_code_dataset.jsonl"),
|
||||
max_prompt_len=mconfig.n_positions // 2,
|
||||
train_bs_n_seqs=minbs,
|
||||
fill_to_max_length=False,
|
||||
|
@ -249,7 +248,7 @@ def test_ppo_global_reshard(
|
|||
def test_ppo_param_realloc_sub_device_mesh(
|
||||
tmp_path_factory,
|
||||
tokenizer,
|
||||
math_dataset,
|
||||
math_code_dataset,
|
||||
save_path,
|
||||
cpu_hf_model,
|
||||
mconfig,
|
||||
|
@ -292,7 +291,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
|||
init_from_scratch=True,
|
||||
),
|
||||
dataset=PromptOnlyDatasetConfig(
|
||||
path=str(save_path / "math_dataset.json"),
|
||||
path=str(save_path / "math_code_dataset.jsonl"),
|
||||
max_prompt_len=mconfig.n_positions // 2,
|
||||
train_bs_n_seqs=minbs,
|
||||
fill_to_max_length=False,
|
||||
|
@ -410,7 +409,7 @@ def test_ppo_save(
|
|||
init_from_scratch=True,
|
||||
),
|
||||
dataset=PromptOnlyDatasetConfig(
|
||||
path=str(save_path / "math_dataset.json"),
|
||||
path=str(save_path / "math_code_dataset.jsonl"),
|
||||
max_prompt_len=mconfig.n_positions // 2,
|
||||
train_bs_n_seqs=bs,
|
||||
fill_to_max_length=False,
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright 2025 Ant Group Inc. All Rights Reserved.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import *
|
||||
|
||||
import pytest
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from realhf.api.core.data_api import (
|
||||
DatasetUtility,
|
||||
MicroBatchSpec,
|
||||
SequenceSample,
|
||||
load_hf_tokenizer,
|
||||
)
|
||||
from realhf.api.core.model_api import FinetuneSpec, Model
|
||||
from realhf.base import constants, network, testing
|
||||
from tests.fixtures import *
|
||||
|
||||
|
||||
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
|
||||
def math_code_dataset(request, save_path):
|
||||
size = request.param
|
||||
max_prompt_len = 8
|
||||
dataset = []
|
||||
for i in range(size):
|
||||
prompt_len = random.randint(1, max_prompt_len)
|
||||
n_pairs = random.randint(1, 5)
|
||||
if random.random() < 0.5:
|
||||
d = dict(
|
||||
task="code",
|
||||
query_id=str(uuid.uuid4()),
|
||||
prompt=generate_random_sentence(prompt_len),
|
||||
problem_id=str(uuid.uuid4()),
|
||||
input_output=json.dumps(
|
||||
{"inputs": ["1\n"] * 8, "outputs": ["1\n"] * 8}
|
||||
),
|
||||
solutions=json.dumps(
|
||||
["```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"]
|
||||
* 3
|
||||
),
|
||||
difficulty=random.random() * 10,
|
||||
)
|
||||
else:
|
||||
d = dict(
|
||||
task="math",
|
||||
query_id=str(uuid.uuid4()),
|
||||
prompt=generate_random_sentence(prompt_len),
|
||||
answers=["\\boxed{-\\frac{2}{3}}"],
|
||||
solutions=["\\boxed{-\\frac{2}{3}}"],
|
||||
)
|
||||
dataset.append(d)
|
||||
with open(str(save_path / "math_code_dataset.jsonl"), "w") as f:
|
||||
f.write("\n".join([json.dumps(d) for d in dataset]))
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokenizer_path", ["/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"]
|
||||
)
|
||||
def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_dataset):
|
||||
from realhf.impl.dataset.math_code_dataset import MATHCodePromptDataset
|
||||
|
||||
dist.init_process_group(
|
||||
rank=0, world_size=1, init_method=f"tcp://localhost:{network.find_free_port()}"
|
||||
)
|
||||
testing.init_global_constants()
|
||||
|
||||
dataset = MATHCodePromptDataset(
|
||||
DatasetUtility(
|
||||
seed=0,
|
||||
dp_rank=0,
|
||||
world_size=1,
|
||||
tokenizer=load_hf_tokenizer(tokenizer_path),
|
||||
),
|
||||
max_length=512,
|
||||
dataset_path=str(save_path / "math_code_dataset.jsonl"),
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
collate_fn=SequenceSample.gather,
|
||||
# NOTE: This is *NOT* the actual batch size for training.
|
||||
# It is just a proper size to load data to workers.
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
from realhf.impl.model.interface.rw_interface import MultiTaskRewardInterface
|
||||
|
||||
with constants.model_scope(testing.MODEL_NAME):
|
||||
interface = MultiTaskRewardInterface(
|
||||
dataset_path=str(save_path / "math_code_dataset.jsonl"),
|
||||
tokenizer_path=tokenizer_path,
|
||||
group_size=1,
|
||||
check_verifier_status=False,
|
||||
)
|
||||
model = Model(
|
||||
name="test",
|
||||
module=None,
|
||||
tokenizer=load_hf_tokenizer(tokenizer_path),
|
||||
device=torch.device("cpu"),
|
||||
ft_spec=FinetuneSpec(
|
||||
total_train_epochs=1, dataset_size=100, train_batch_size=3
|
||||
),
|
||||
)
|
||||
|
||||
for d in dataloader:
|
||||
d = interface.mock("inference", model, d)
|
||||
rewards = interface.inference(model, d, mb_spec=MicroBatchSpec())
|
||||
d.update_(rewards)
|
||||
assert rewards.data["rewards"].all(), rewards.data["rewards"]
|
||||
dist.destroy_process_group()
|