update dataset for fused-refrw

This commit is contained in:
meijun.mei 2025-03-27 16:23:10 +08:00
parent fc79f21622
commit abfe8bd30f
14 changed files with 400 additions and 43 deletions

View File

@@ -54,22 +54,21 @@ Please refer to the [tutorial](/examples/README.md) under the `examples` directo
# Download the dataset
DATA_PATH=/storage/datasets/
cd $DATA_PATH
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
# Training in a Ray cluster with 16 nodes
# stage 1
MODEL_PATH=${path_to_DeepSeek-R1-Distill-Qwen-1.5B}
-bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 8192
+bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 8192
# stage 2
MODEL_PATH=${model_path_from_stage_1}
-bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 16384
+bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 16384
# stage 3
MODEL_PATH=${model_path_from_stage_2}
-bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 24000
+bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 24000
```
@@ -134,11 +133,10 @@ Table 2. Evaluation on AIME 2024, AMC 2023, and MATH-500. The results are report
# Training in a Ray cluster with 16 nodes
DATA_PATH=/storage/datasets/
cd $DATA_PATH
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
MODEL_PATH=${path_to_Qwen2.5-7B}
-bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_zero.jsonl $DATA_PATH/id2info.json 24000
+bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_orz_zero.jsonl 24000
```
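The two hunks above replace the old `prompts_for_*.jsonl` + `id2info.json` pair with a single fused file, so the training scripts now take one dataset argument instead of two. Going by the processing scripts added later in this commit, each line of the fused `.jsonl` is a self-contained record; a minimal sketch with hypothetical values (the field names follow `math_process.py` below, the example content is made up):

```python
# One fused dataset record (hypothetical values); solutions are now inlined
# per record instead of being looked up in id2info.json by query_id.
import json

record = {
    "prompt": "Compute 1 + 1.",   # question text fed to the model
    "task": "math",               # "math" or "code"
    "query_id": "example-0",      # unique ID, formerly the key into id2info.json
    "solutions": ["\\boxed{2}"],  # reference answers
}
print(json.dumps(record, ensure_ascii=False))
```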

View File

@@ -114,9 +114,8 @@ We provide a dataset for training. Download the dataset and place it in `/storag
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
```
## Model

View File

@@ -120,18 +120,8 @@ cd /storage/ray/
```bash
mkdir -p /storage/datasets/
cd /storage/datasets/
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
-```
-If you cannot access `huggingface.co`, you can also download from ModelScope:
-```bash
-mkdir -p /storage/datasets/
-cd /storage/datasets/
-wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_r1_distilled.jsonl
-wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_zero.jsonl
-wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/id2info.json
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
```
## Model
@@ -146,15 +136,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-D
You can also download with the huggingface CLI after installing huggingface_hub from PyPI; see the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.
-If you cannot access `huggingface.co`, you can also download from ModelScope (make sure Git LFS is installed):
-```
-mkdir -p /storage/models
-cd /storage/models
-git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B.git
-git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.git
-```
## Start the Ray Cluster

View File

@@ -0,0 +1,259 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
"""
Dataset Toolkit - Process and validate code/math datasets with flexible input support
"""
import json
import argparse
import logging
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)

def load_jsonl(file_path: str) -> List[Dict]:
"""Load JSONL file with validation"""
try:
with open(file_path, "r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise

def process_code_data(file_path: str) -> List[Dict]:
"""Process code dataset from JSONL file"""
if not file_path:
return []
raw_data = load_jsonl(file_path)
processed = []
for item in raw_data:
try:
# Field extraction and transformation
            input_output = item["input_output"]
            if isinstance(input_output, str):
                input_output = json.loads(input_output)
processed.append(
{
"task": "code",
"query_id": str(item["id"]),
"prompt": item["question"],
"input_output": {
"inputs": input_output.get("inputs", []),
"outputs": input_output.get("outputs", []),
"fn_name": item.get("metadata", {}).get("fn_name", ""),
},
}
)
except KeyError as e:
logger.warning(f"Skipping code item: Missing key {e}")
return processed

def process_math_data(file_path: str) -> List[Dict]:
"""Process math dataset from JSON/JSONL file"""
if not file_path:
return []
raw_data = load_jsonl(file_path)
processed = []
for item in raw_data:
try:
processed.append(
{
"task": "math",
"query_id": str(item["query_id"]),
"prompt": item["prompt"],
"solutions": item["solutions"],
}
)
except KeyError as e:
logger.warning(f"Skipping math item: Missing key {e}")
return processed

def validate_raw_code(item: Dict) -> Tuple[bool, List[str]]:
"""Validate raw code item structure"""
errors = []
required = {
"task": str,
"query_id": str,
"prompt": str,
"input_output": (str, dict),
}
code_input_output_required = {
"inputs": list,
"outputs": list,
#'fn_name': str
}
    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            type_names = (
                [t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
            )
            errors.append(f"Invalid type for {field}: expected {type_names}")
    input_output = item.get("input_output", {})
    if isinstance(input_output, str):
        input_output = json.loads(input_output)
for io_field, io_typ in code_input_output_required.items():
if io_field not in input_output:
errors.append(f"Missing required field: {io_field} in input_output")
elif not isinstance(input_output[io_field], io_typ):
errors.append(
f"Invalid type for {io_field}: expected {io_typ.__name__} in input_output"
)
return (not errors, errors)

def validate_raw_math(item: Dict) -> Tuple[bool, List[str]]:
"""Validate raw math item structure"""
errors = []
required = {"query_id": str, "prompt": str, "solutions": list}
for field, typ in required.items():
if field not in item:
errors.append(f"Missing required field: {field}")
elif not isinstance(item[field], typ):
type_names = (
[t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
)
errors.append(f"Invalid type for {field}: expected {type_names}")
return (not errors, errors)

def validate_raw_item(item: Dict) -> Tuple[bool, List[str]]:
    """Dispatch validation based on the item's task type"""
if item.get("task", "") == "math":
return validate_raw_math(item)
else:
return validate_raw_code(item)

def check_item_valid(raw_data: list):
if not raw_data:
return defaultdict(int)
logger.info(f"Validating code-math dataset...")
total_stats = defaultdict(int)
try:
for index, item in enumerate(raw_data):
valid, errors = validate_raw_item(item)
if valid:
total_stats["valid"] += 1
else:
total_stats["invalid"] += 1
logger.debug(f"item {index}: {errors}")
except Exception as e:
logger.error(f"validation failed: {str(e)}")
return defaultdict(int)
return total_stats

def filter_shuffle(shuffle: bool, processed_data):
# Final validation and write output
valid_output = []
invalid_count = 0
for item in processed_data:
valid, errors = validate_raw_item(item)
if valid:
valid_output.append(item)
else:
invalid_count += 1
logger.error(
f'{item["task"]} item {item.get("query_id")} is invalid: {errors}'
)
if shuffle:
random.shuffle(valid_output)
return valid_output, invalid_count

def save_file(output_path: str, processed_data: list):
try:
with open(output_path, "w") as f:
for item in processed_data:
f.write(json.dumps(item) + "\n")
except IOError as e:
logger.error(f"Failed to write output: {str(e)}")
return

def main():
parser = argparse.ArgumentParser(
description="Dataset Toolkit: Process and validate STEM datasets",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--code", help="Path to code dataset (JSONL)")
parser.add_argument("--math", help="Path to math dataset (JSONL)")
parser.add_argument("--output", help="Output file path (JSONL)")
parser.add_argument(
"--mode",
choices=["check", "process"],
default="process",
help="Operation mode: check raw data or process datasets",
)
parser.add_argument(
"--shuffle", action="store_true", help="Shuffle output data in process mode"
)
args = parser.parse_args()
if args.mode == "check":
# Validate raw data structure
        raw_data = (load_jsonl(args.code) if args.code else []) + (
            load_jsonl(args.math) if args.math else []
        )
total_stats = check_item_valid(raw_data)
# Print validation summary
logger.info(
f"\nValidation Summary: {total_stats['valid']} valid, {total_stats['invalid']} invalid"
)
elif args.mode == "process":
# Process and merge datasets
if not args.output:
logger.error("Output file required in process mode")
return
processed_data = []
stats = defaultdict(int)
if args.code:
code_data = process_code_data(args.code)
logger.info(f"Loaded {len(code_data)} code items")
processed_data.extend(code_data)
stats["code"] = len(code_data)
if args.math:
math_data = process_math_data(args.math)
logger.info(f"Loaded {len(math_data)} math items")
processed_data.extend(math_data)
stats["math"] = len(math_data)
processed_data, invalid_count = filter_shuffle(args.shuffle, processed_data)
stats["invalid"] = invalid_count
save_file(args.output, processed_data)
logger.info("\nProcessing Complete:")
logger.info(f"Total items: {len(processed_data)}")
logger.info(f"Code items: {stats['code']}")
logger.info(f"Math items: {stats['math']}")
logger.info(f"Invalid items: {stats['invalid']}")

if __name__ == "__main__":
main()
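A quick way to see the schema this toolkit enforces is to feed hand-built items through `validate_raw_item`. A minimal sketch, assuming the file above is saved as `dataset_toolkit.py` and importable; the sample items are hypothetical:

```python
# Validate one math item and one code item against the toolkit's schema;
# both should print "valid" given the required fields above.
from dataset_toolkit import validate_raw_item

math_item = {
    "task": "math",
    "query_id": "m-0",
    "prompt": "Compute 1 + 1.",
    "solutions": ["\\boxed{2}"],
}
code_item = {
    "task": "code",
    "query_id": "c-0",
    "prompt": "Write a function add(a, b).",
    "input_output": {"inputs": [["1 2"]], "outputs": [["3"]], "fn_name": "add"},
}
for item in (math_item, code_item):
    ok, errors = validate_raw_item(item)
    print(item["task"], "valid" if ok else errors)
```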

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python3
import argparse
import json
import random
import sys
from pathlib import Path
from typing import List, Dict, Optional

def load_jsonl(file_path: Path) -> List[Dict]:
"""Load JSONL file with validation"""
try:
with file_path.open("r", encoding="utf-8") as f:
return [json.loads(line) for line in f]
except FileNotFoundError:
print(f"ERROR: JSONL file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise

def load_id2info(file_path: Path) -> Dict:
"""Load ID mapping file with structure validation"""
try:
with file_path.open("r", encoding="utf-8") as f:
data = json.load(f)
if not all(isinstance(v, dict) and "solutions" in v for v in data.values()):
raise ValueError("Invalid id2info structure")
return data
except FileNotFoundError:
print(f"ERROR: ID mapping file not found: {file_path}")
raise
except json.JSONDecodeError as e:
print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
raise

def process_data(prompts_data: List[Dict], id2info: Dict, output_path: Path) -> None:
"""Process and save transformed data"""
processed = []
missing_ids = 0
for item in prompts_data:
query_id = item.get("query_id")
if not query_id or query_id not in id2info:
missing_ids += 1
continue
processed.append(
{
"prompt": item.get("prompt", ""),
"task": "math",
"query_id": query_id,
"solutions": id2info[query_id].get("solutions", []),
}
)
if missing_ids:
print(f"WARNING: Found {missing_ids} items with missing/invalid query_id")
if not processed:
print("ERROR: No valid data to process")
sys.exit(1)
random.shuffle(processed)
try:
with output_path.open("w", encoding="utf-8") as f:
for item in processed:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
print(f"SUCCESS: Wrote {len(processed)} items to {output_path}")
except IOError as e:
print(f"ERROR: Failed to write output: {str(e)}")
raise

def main():
"""
Command line entry point
Example usage:
python math_process.py \
--prompts_path ./input/prompts.jsonl \
--id2info_path ./input/id2info.json \
--output_path ./output/processed.jsonl
"""
parser = argparse.ArgumentParser(
description="Math dataset processing tool",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--prompts_path",
type=Path,
required=True,
help="Path to prompts.jsonl input file",
)
parser.add_argument(
"--id2info_path",
type=Path,
required=True,
help="Path to id2info.json input file",
)
parser.add_argument(
"--output_path",
type=Path,
default="output.jsonl",
help="Path for processed output file",
)
args = parser.parse_args()
print("Starting data processing...")
try:
prompts_data = load_jsonl(args.prompts_path)
id2info = load_id2info(args.id2info_path)
process_data(prompts_data, id2info, args.output_path)
except Exception as e:
print(f"FATAL: Processing failed - {str(e)}")
sys.exit(1)
print("Operation completed successfully")

if __name__ == "__main__":
main()
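For reference, here is the shape of the two inputs this script joins and the record it emits, as a minimal sketch with hypothetical values; the output fields mirror `process_data` above:

```python
# prompts.jsonl rows carry query_id + prompt; id2info.json maps each query_id
# to its solutions. process_data merges them into one fused record per line.
import json

prompt_row = {"query_id": "m-0", "prompt": "Compute 1 + 1."}
id2info = {"m-0": {"solutions": ["\\boxed{2}"]}}

merged = {
    "prompt": prompt_row["prompt"],
    "task": "math",
    "query_id": prompt_row["query_id"],
    "solutions": id2info[prompt_row["query_id"]]["solutions"],
}
print(json.dumps(merged, ensure_ascii=False))
```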

View File

@@ -6,7 +6,7 @@ TRAIN_BATCH_SIZE="1024"
GROUP_SIZE="8"
NODES="16"
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"
-MAX_NEW_TOKENS=$4
+MAX_NEW_TOKENS=$3
MAX_NUM_SEQS=128
PPO_MBS=4
KL_CTL=0.001

View File

@@ -5,7 +5,7 @@ TRAIN_BATCH_SIZE="512"
GROUP_SIZE="64"
NODES="16"
ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"
-MAX_NEW_TOKENS=$4
+MAX_NEW_TOKENS=$3
MAX_NUM_SEQS=128
PPO_MBS=4
KL_CTL=0.0

View File

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n1
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
DATASET_NAME="prompts_for_r1_distilled.jsonl"
DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=1
ALLOCATION_MODE="actor_gen:d4p1m2,*:d4p2m1"

View File

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n16
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

View File

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-1.5B-n4
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=4
ALLOCATION_MODE="vllm.d16p1m1+d8p2m1"

View File

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-distill-7B-n16
MODEL_NAME="DeepSeek-R1-Distill-Qwen-7B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"

View File

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXP_NAME=ppo-zero-7B-zero-n16
MODEL_NAME="Qwen2.5-7B"
DATASET_NAME="prompts_for_zero.jsonl"
DATASET_NAME="full_orz_zero.jsonl"
NODES=16
ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

View File

@@ -723,9 +723,6 @@ def load_shuffle_split_dataset(
if dataset_path.endswith(".jsonl"):
with open(dataset_path, "r") as f:
data = [json.loads(ff) for ff in f]
-        elif dataset_path.endswith(".json"):
-            with open(dataset_path, "r") as f:
-                data = json.load(f)
else:
raise NotImplementedError(f"Unknown dataset extension: {dataset_path}")
else:

View File

@@ -49,12 +49,16 @@ def load_metadata(path):
assert str(path).endswith(".jsonl"), path
with open(path, "r") as f:
data = [json.loads(l) for l in f.readlines()]
id2info = {}
omit_cnt = defaultdict(int)
task_cnt = defaultdict(int)
for d in data:
assert d["query_id"] not in d, (d["task"], d["query_id"])
try:
if "task" not in d:
d["task"] = "math"
                logger.warning('Key "task" not found in the dataset; using "math" as the default task type.')
if d["task"] == "math":
d = check_math_metadata_entries(d)
elif d["task"] == "code":
@@ -202,7 +206,7 @@ else:
),
),
max_length=512,
dataset_path="/storage/openpsi/users/bowei.fw/data/math.jsonl",
dataset_path='/storage/datasets/full_prompts_for_r1_distilled.jsonl'
)
dataloader = torch.utils.data.DataLoader(