mirror of https://github.com/inclusionAI/AReaL

Commit abfe8bd30f (parent fc79f21622): update dataset for fused-refrw

 README.md | 14 ++++++--------

@@ -54,22 +54,21 @@ Please refer to the [tutorial](/examples/README.md) under the `examples` directory
 # Download the dataset
 DATA_PATH=/storage/datasets/
 cd $DATA_PATH
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true

 # Training in a Ray cluster with 16 nodes

 # stage 1
 MODEL_PATH=${path_to_DeepSeek-R1-Distill-Qwen-1.5B}
-bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 8192
+bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 8192

 # stage 2
 MODEL_PATH=${model_path_from_stage_1}
-bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 16384
+bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 16384

 # stage 3
 MODEL_PATH=${model_path_from_stage_2}
-bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_r1_distilled.jsonl $DATA_PATH/id2info.json 24000
+bash ./examples/train_1.5B_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_prompts_for_r1_distilled.jsonl 24000

 ```
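
The fused dataset replaces the separate `id2info.json` lookup: each JSONL line is expected to carry its solutions inline. A minimal sketch of one record, with field names taken from the dataset toolkit added later in this commit and purely illustrative values:

```python
import json

# Assumed shape of one line in full_prompts_for_r1_distilled.jsonl:
# "task", "query_id", "prompt", and inline "solutions" (math items).
sample = {
    "task": "math",
    "query_id": "sample-0",  # hypothetical id
    "prompt": "Compute 1 + 1.",
    "solutions": ["\\boxed{2}"],
}
print(json.dumps(sample))
```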

@@ -134,11 +133,10 @@ Table 2. Evaluation on AIME 2024, AMC 2023, and MATH-500. The results are reported

 # Training in a Ray cluster with 16 nodes
 DATA_PATH=/storage/datasets/
 cd $DATA_PATH
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true

 MODEL_PATH=${path_to_Qwen2.5-7B}
-bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/prompts_for_zero.jsonl $DATA_PATH/id2info.json 24000
+bash ./examples/train_7B_zero_n16_on_ray.sh $MODEL_PATH $DATA_PATH/full_orz_zero.jsonl 24000

 ```
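
Since each record now carries its solutions, the training scripts no longer take an `id2info.json` argument. If an id-to-info mapping is still needed, it can be rebuilt from the fused file itself; a sketch under the assumed record shape above:

```python
import json

# Hypothetical reconstruction of the old id2info mapping from the fused file.
with open("/storage/datasets/full_orz_zero.jsonl", "r", encoding="utf-8") as f:
    id2info = {
        d["query_id"]: {"solutions": d.get("solutions", [])}
        for d in map(json.loads, f)
    }
print(f"{len(id2info)} entries")
```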

@@ -114,9 +114,8 @@ We provide a dataset for training. Download the dataset and place it in `/storage/datasets/`

 ```bash
 mkdir -p /storage/datasets/
 cd /storage/datasets/
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
 ```

 ## Model

@@ -120,18 +120,8 @@ cd /storage/ray/

 ```bash
 mkdir -p /storage/datasets/
 cd /storage/datasets/
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_r1_distilled.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/prompts_for_zero.jsonl?download=true
-wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/id2info.json?download=true
-```
-
-If you cannot access `huggingface.co`, you can also download from ModelScope:
-```bash
-mkdir -p /storage/datasets/
-cd /storage/datasets/
-wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_r1_distilled.jsonl
-wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/prompts_for_zero.jsonl
-wget https://www.modelscope.cn/datasets/inclusionAI/AReaL-RL-Data/resolve/master/data/id2info.json
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_prompts_for_r1_distilled.jsonl?download=true
+wget https://huggingface.co/datasets/inclusionAI/AReaL-RL-Data/resolve/main/data/full_orz_zero.jsonl?download=true
 ```

 ## Model

@@ -146,15 +136,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

 You can also download using the Hugging Face CLI after installing huggingface_hub from PyPI; see the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for details.

-If you cannot access `huggingface.co`, you can also download from ModelScope (make sure Git LFS is installed):
-
-```
-mkdir -p /storage/models
-cd /storage/models
-git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B.git
-git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.git
-```
-

 ## Launch the Ray Cluster

@@ -0,0 +1,259 @@

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

"""
Dataset Toolkit - Process and validate code/math datasets with flexible input support
"""
import argparse
import json
import logging
import random
from collections import defaultdict
from typing import Dict, List, Tuple

# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)


def load_jsonl(file_path: str) -> List[Dict]:
    """Load JSONL file with validation"""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f]
    except FileNotFoundError:
        print(f"ERROR: JSONL file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def process_code_data(file_path: str) -> List[Dict]:
    """Process code dataset from JSONL file"""
    if not file_path:
        return []

    raw_data = load_jsonl(file_path)
    processed = []

    for item in raw_data:
        try:
            # Field extraction and transformation
            input_output = json.loads(item["input_output"])
            processed.append(
                {
                    "task": "code",
                    "query_id": str(item["id"]),
                    "prompt": item["question"],
                    "input_output": {
                        "inputs": input_output.get("inputs", []),
                        "outputs": input_output.get("outputs", []),
                        "fn_name": item.get("metadata", {}).get("fn_name", ""),
                    },
                }
            )
        except KeyError as e:
            logger.warning(f"Skipping code item: Missing key {e}")

    return processed


def process_math_data(file_path: str) -> List[Dict]:
    """Process math dataset from JSONL file"""
    if not file_path:
        return []

    raw_data = load_jsonl(file_path)
    processed = []

    for item in raw_data:
        try:
            processed.append(
                {
                    "task": "math",
                    "query_id": str(item["query_id"]),
                    "prompt": item["prompt"],
                    "solutions": item["solutions"],
                }
            )
        except KeyError as e:
            logger.warning(f"Skipping math item: Missing key {e}")

    return processed


def validate_raw_code(item: Dict) -> Tuple[bool, List[str]]:
    """Validate raw code item structure"""
    errors = []
    required = {
        "task": str,
        "query_id": str,
        "prompt": str,
        "input_output": (str, dict),
    }

    code_input_output_required = {
        "inputs": list,
        "outputs": list,
        # "fn_name": str  # optional
    }

    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            type_names = (
                [t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
            )
            errors.append(f"Invalid type for {field}: expected {type_names}")

    # Bail out before inspecting input_output if it is missing or unparsable
    if "input_output" not in item:
        return (not errors, errors)
    input_output = item["input_output"]
    if isinstance(input_output, str):
        try:
            input_output = json.loads(input_output)
        except json.JSONDecodeError:
            errors.append("input_output is not valid JSON")
            return (not errors, errors)

    for io_field, io_typ in code_input_output_required.items():
        if io_field not in input_output:
            errors.append(f"Missing required field: {io_field} in input_output")
        elif not isinstance(input_output[io_field], io_typ):
            errors.append(
                f"Invalid type for {io_field}: expected {io_typ.__name__} in input_output"
            )

    return (not errors, errors)


def validate_raw_math(item: Dict) -> Tuple[bool, List[str]]:
    """Validate raw math item structure"""
    errors = []
    required = {"query_id": str, "prompt": str, "solutions": list}

    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            type_names = (
                [t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
            )
            errors.append(f"Invalid type for {field}: expected {type_names}")

    return (not errors, errors)


def validate_raw_item(item: Dict) -> Tuple[bool, List[str]]:
    """Dispatch validation based on the item's task type"""
    if item.get("task", "") == "math":
        return validate_raw_math(item)
    else:
        return validate_raw_code(item)


def check_item_valid(raw_data: list):
    if not raw_data:
        return defaultdict(int)
    logger.info("Validating code-math dataset...")
    total_stats = defaultdict(int)
    try:
        for index, item in enumerate(raw_data):
            valid, errors = validate_raw_item(item)
            if valid:
                total_stats["valid"] += 1
            else:
                total_stats["invalid"] += 1
                logger.debug(f"item {index}: {errors}")
    except Exception as e:
        logger.error(f"validation failed: {str(e)}")
        return defaultdict(int)
    return total_stats


def filter_shuffle(shuffle: bool, processed_data):
    # Final validation pass before the output is written
    valid_output = []
    invalid_count = 0
    for item in processed_data:
        valid, errors = validate_raw_item(item)
        if valid:
            valid_output.append(item)
        else:
            invalid_count += 1
            logger.error(
                f'{item["task"]} item {item.get("query_id")} is invalid: {errors}'
            )
    if shuffle:
        random.shuffle(valid_output)
    return valid_output, invalid_count


def save_file(output_path: str, processed_data: list):
    try:
        with open(output_path, "w") as f:
            for item in processed_data:
                f.write(json.dumps(item) + "\n")
    except IOError as e:
        logger.error(f"Failed to write output: {str(e)}")
        return


def main():
    parser = argparse.ArgumentParser(
        description="Dataset Toolkit: Process and validate STEM datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--code", help="Path to code dataset (JSONL)")
    parser.add_argument("--math", help="Path to math dataset (JSONL)")
    parser.add_argument("--output", help="Output file path (JSONL)")
    parser.add_argument(
        "--mode",
        choices=["check", "process"],
        default="process",
        help="Operation mode: check raw data or process datasets",
    )
    parser.add_argument(
        "--shuffle", action="store_true", help="Shuffle output data in process mode"
    )

    args = parser.parse_args()
    if args.mode == "check":
        # Validate raw data structure
        raw_data = (load_jsonl(args.code) if args.code else []) + (
            load_jsonl(args.math) if args.math else []
        )
        total_stats = check_item_valid(raw_data)

        # Print validation summary
        logger.info(
            f"\nValidation Summary: {total_stats['valid']} valid, {total_stats['invalid']} invalid"
        )
    elif args.mode == "process":
        # Process and merge datasets
        if not args.output:
            logger.error("Output file required in process mode")
            return
        processed_data = []
        stats = defaultdict(int)
        if args.code:
            code_data = process_code_data(args.code)
            logger.info(f"Loaded {len(code_data)} code items")
            processed_data.extend(code_data)
            stats["code"] = len(code_data)
        if args.math:
            math_data = process_math_data(args.math)
            logger.info(f"Loaded {len(math_data)} math items")
            processed_data.extend(math_data)
            stats["math"] = len(math_data)

        processed_data, invalid_count = filter_shuffle(args.shuffle, processed_data)
        stats["invalid"] = invalid_count
        save_file(args.output, processed_data)
        logger.info("\nProcessing Complete:")
        logger.info(f"Total items: {len(processed_data)}")
        logger.info(f"Code items: {stats['code']}")
        logger.info(f"Math items: {stats['math']}")
        logger.info(f"Invalid items: {stats['invalid']}")


if __name__ == "__main__":
    main()
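
A short sketch of how the toolkit above could be driven programmatically, assuming it is saved as `dataset_toolkit.py` (the file name and input paths are hypothetical; the CLI equivalent would be `python dataset_toolkit.py --code code.jsonl --math math.jsonl --output fused.jsonl --shuffle`):

```python
# Fuse a code and a math JSONL file into one shuffled, validated dataset,
# using the functions defined above.
from dataset_toolkit import (
    filter_shuffle,
    process_code_data,
    process_math_data,
    save_file,
)

merged = process_code_data("code.jsonl") + process_math_data("math.jsonl")
valid_items, invalid_count = filter_shuffle(True, merged)  # shuffle=True
save_file("fused.jsonl", valid_items)
print(f"wrote {len(valid_items)} items, dropped {invalid_count} invalid")
```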

@@ -0,0 +1,119 @@

#!/usr/bin/env python3
import argparse
import json
import random
import sys
from pathlib import Path
from typing import Dict, List


def load_jsonl(file_path: Path) -> List[Dict]:
    """Load JSONL file with validation"""
    try:
        with file_path.open("r", encoding="utf-8") as f:
            return [json.loads(line) for line in f]
    except FileNotFoundError:
        print(f"ERROR: JSONL file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def load_id2info(file_path: Path) -> Dict:
    """Load ID mapping file with structure validation"""
    try:
        with file_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        if not all(isinstance(v, dict) and "solutions" in v for v in data.values()):
            raise ValueError("Invalid id2info structure")
        return data
    except FileNotFoundError:
        print(f"ERROR: ID mapping file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def process_data(prompts_data: List[Dict], id2info: Dict, output_path: Path) -> None:
    """Process and save transformed data"""
    processed = []
    missing_ids = 0
    for item in prompts_data:
        query_id = item.get("query_id")
        if not query_id or query_id not in id2info:
            missing_ids += 1
            continue
        processed.append(
            {
                "prompt": item.get("prompt", ""),
                "task": "math",
                "query_id": query_id,
                "solutions": id2info[query_id].get("solutions", []),
            }
        )
    if missing_ids:
        print(f"WARNING: Found {missing_ids} items with missing/invalid query_id")
    if not processed:
        print("ERROR: No valid data to process")
        sys.exit(1)
    random.shuffle(processed)

    try:
        with output_path.open("w", encoding="utf-8") as f:
            for item in processed:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
        print(f"SUCCESS: Wrote {len(processed)} items to {output_path}")
    except IOError as e:
        print(f"ERROR: Failed to write output: {str(e)}")
        raise


def main():
    """
    Command line entry point

    Example usage:
        python math_process.py \
            --prompts_path ./input/prompts.jsonl \
            --id2info_path ./input/id2info.json \
            --output_path ./output/processed.jsonl
    """
    parser = argparse.ArgumentParser(
        description="Math dataset processing tool",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--prompts_path",
        type=Path,
        required=True,
        help="Path to prompts.jsonl input file",
    )
    parser.add_argument(
        "--id2info_path",
        type=Path,
        required=True,
        help="Path to id2info.json input file",
    )
    parser.add_argument(
        "--output_path",
        type=Path,
        default="output.jsonl",
        help="Path for processed output file",
    )

    args = parser.parse_args()
    print("Starting data processing...")
    try:
        prompts_data = load_jsonl(args.prompts_path)
        id2info = load_id2info(args.id2info_path)
        process_data(prompts_data, id2info, args.output_path)
    except Exception as e:
        print(f"FATAL: Processing failed - {str(e)}")
        sys.exit(1)
    print("Operation completed successfully")


if __name__ == "__main__":
    main()
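
A tiny end-to-end sketch of the inputs this script expects, using hypothetical file contents: `prompts.jsonl` lines carry `prompt` and `query_id`, and `id2info.json` maps each `query_id` to a dict containing at least `solutions` (the structure `load_id2info` validates):

```python
import json
from pathlib import Path

# Write minimal hypothetical inputs for math_process.py.
Path("prompts.jsonl").write_text(
    json.dumps({"query_id": "q1", "prompt": "Compute 2 + 2."}) + "\n"
)
Path("id2info.json").write_text(
    json.dumps({"q1": {"solutions": ["\\boxed{4}"]}})
)
# Then: python math_process.py --prompts_path prompts.jsonl \
#           --id2info_path id2info.json --output_path fused_math.jsonl
```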

@@ -6,7 +6,7 @@ TRAIN_BATCH_SIZE="1024"
 GROUP_SIZE="8"
 NODES="16"
 ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"
-MAX_NEW_TOKENS=$4
+MAX_NEW_TOKENS=$3
 MAX_NUM_SEQS=128
 PPO_MBS=4
 KL_CTL=0.001

@@ -5,7 +5,7 @@ TRAIN_BATCH_SIZE="512"
 GROUP_SIZE="64"
 NODES="16"
 ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"
-MAX_NEW_TOKENS=$4
+MAX_NEW_TOKENS=$3
 MAX_NUM_SEQS=128
 PPO_MBS=4
 KL_CTL=0.0

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

 EXP_NAME=ppo-zero-distill-1.5B-n1
 MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
 NODES=1
 ALLOCATION_MODE="actor_gen:d4p1m2,*:d4p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

 EXP_NAME=ppo-zero-distill-1.5B-n16
 MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
 NODES=16
 ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

 EXP_NAME=ppo-zero-distill-1.5B-n4
 MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
 NODES=4
 ALLOCATION_MODE="vllm.d16p1m1+d8p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

 EXP_NAME=ppo-zero-distill-7B-n16
 MODEL_NAME="DeepSeek-R1-Distill-Qwen-7B"
-DATASET_NAME="prompts_for_r1_distilled.jsonl"
+DATASET_NAME="full_prompts_for_r1_distilled.jsonl"
 NODES=16
 ALLOCATION_MODE="vllm.d16p1m4+d32p2m1"

@@ -4,7 +4,7 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

 EXP_NAME=ppo-zero-7B-zero-n16
 MODEL_NAME="Qwen2.5-7B"
-DATASET_NAME="prompts_for_zero.jsonl"
+DATASET_NAME="full_orz_zero.jsonl"
 NODES=16
 ALLOCATION_MODE="vllm.d64p1m1+d32p2m1"

@@ -723,9 +723,6 @@ def load_shuffle_split_dataset(
         if dataset_path.endswith(".jsonl"):
             with open(dataset_path, "r") as f:
                 data = [json.loads(ff) for ff in f]
-        elif dataset_path.endswith(".json"):
-            with open(dataset_path, "r") as f:
-                data = json.load(f)
         else:
             raise NotImplementedError(f"Unknown dataset extension: {dataset_path}")
     else:

@@ -49,12 +49,16 @@ def load_metadata(path):
     assert str(path).endswith(".jsonl"), path
     with open(path, "r") as f:
         data = [json.loads(l) for l in f.readlines()]

     id2info = {}
     omit_cnt = defaultdict(int)
+    task_cnt = defaultdict(int)
     for d in data:
         assert d["query_id"] not in id2info, (d["task"], d["query_id"])
         try:
+            if "task" not in d:
+                d["task"] = "math"
+                logger.warning('Key "task" not found in the dataset. Using "math" as the default task type.')
             if d["task"] == "math":
                 d = check_math_metadata_entries(d)
             elif d["task"] == "code":

@@ -202,7 +206,7 @@ else:
             ),
         ),
         max_length=512,
-        dataset_path="/storage/openpsi/users/bowei.fw/data/math.jsonl",
+        dataset_path="/storage/datasets/full_prompts_for_r1_distilled.jsonl",
     )

     dataloader = torch.utils.data.DataLoader(
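
For a quick sanity check of the swapped-in dataset path, the fused JSONL can also be iterated with a plain DataLoader, independent of the repo's dataset class; a minimal sketch with an identity collate function (the batch size and path usage here are illustrative only):

```python
import json

import torch

# Load the fused dataset and batch it without tensor collation.
with open("/storage/datasets/full_prompts_for_r1_distilled.jsonl") as f:
    rows = [json.loads(line) for line in f]

loader = torch.utils.data.DataLoader(rows, batch_size=4, collate_fn=lambda b: b)
batch = next(iter(loader))
print(len(batch), batch[0]["query_id"])
```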