mirror of https://github.com/inclusionAI/AReaL
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
"""
|
|
Dataset Toolkit - Process and validate code/math datasets with flexible input support
|
|
"""

import argparse
import json
import logging
import random
import sys
from collections import defaultdict
from typing import Dict, List, Tuple

# Configure console logging
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSONL file with validation."""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            # Skip blank lines so stray empty lines do not break parsing.
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        print(f"ERROR: JSONL file not found: {file_path}")
        raise
    except json.JSONDecodeError as e:
        print(f"ERROR: JSON parsing failed in {file_path}: {str(e)}")
        raise


def process_code_data(file_path: str) -> List[Dict]:
    """Process a code dataset from a JSONL file."""
    if not file_path:
        return []

    raw_data = load_jsonl(file_path)
    processed = []

    for item in raw_data:
        # Field extraction and transformation
        input_output = json.loads(item["input_output"])
        processed.append(
            {
                "task": "code",
                "query_id": str(item["id"]),
                "prompt": item["question"],
                "solutions": item.get("solutions", []),
                "input_output": json.dumps(
                    {
                        "inputs": input_output.get("inputs", []),
                        "outputs": input_output.get("outputs", []),
                        "fn_name": item.get("metadata", {}).get("fn_name", ""),
                        "remote": False,
                    }
                ),
                "language": item.get("language", "PYTHON"),
            }
        )

        # Guard against oversized inline test cases; large cases should be stored remotely.
        case_size = sys.getsizeof(processed[-1]["input_output"])
        assert (
            case_size < 500 * 1024
        ), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."

    return processed
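
# Expected shape of one raw code record, inferred from the field accesses in
# process_code_data above (an illustrative sketch, not an official schema):
#   {
#       "id": 1,
#       "question": "<problem statement>",
#       "solutions": ["<reference solution>"],
#       "input_output": "{\"inputs\": [...], \"outputs\": [...]}",
#       "metadata": {"fn_name": "<function name, optional>"},
#       "language": "PYTHON"
#   }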


def process_math_data(file_path: str) -> List[Dict]:
    """Process a math dataset from a JSONL file."""
    if not file_path:
        return []

    raw_data = load_jsonl(file_path)
    processed = []

    for item in raw_data:
        processed.append(
            {
                "task": "math",
                "query_id": str(item["query_id"]),
                "prompt": item["prompt"],
                "solutions": item.get("solutions", []),
            }
        )

    return processed
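
# Expected shape of one raw math record, inferred from the field accesses in
# process_math_data above (an illustrative sketch, not an official schema):
#   {"query_id": "7", "prompt": "<problem statement>", "solutions": ["<answer>"]}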


def validate_raw_code(item: Dict) -> Tuple[bool, List[str]]:
    """Validate raw code item structure."""
    errors = []
    required = {
        "task": str,
        "query_id": str,
        "prompt": str,
        "input_output": str,
    }

    code_input_output_required = {
        "inputs": list,
        "outputs": list,
        # 'fn_name': str
    }

    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            errors.append(f"Invalid type for {field}: expected {typ.__name__}")

    # Only inspect the nested test cases when the field is present as a string;
    # otherwise the missing/invalid field has already been reported above.
    if isinstance(item.get("input_output"), str):
        input_output = json.loads(item["input_output"])

        for io_field, io_typ in code_input_output_required.items():
            if io_field not in input_output:
                errors.append(f"Missing required field: {io_field} in input_output")
            elif not isinstance(input_output[io_field], io_typ):
                errors.append(
                    f"Invalid type for {io_field}: expected {io_typ.__name__} in input_output"
                )

    return (not errors, errors)


def validate_raw_math(item: Dict) -> Tuple[bool, List[str]]:
    """Validate raw math item structure."""
    errors = []
    required = {"query_id": str, "prompt": str, "solutions": list}

    for field, typ in required.items():
        if field not in item:
            errors.append(f"Missing required field: {field}")
        elif not isinstance(item[field], typ):
            type_names = (
                [t.__name__ for t in typ] if isinstance(typ, tuple) else typ.__name__
            )
            errors.append(f"Invalid type for {field}: expected {type_names}")

    return (not errors, errors)


def validate_raw_item(item: Dict) -> Tuple[bool, List[str]]:
    """Dispatch validation based on the item's task type (math or code)."""
    if item.get("task", "") == "math":
        return validate_raw_math(item)
    else:
        return validate_raw_code(item)


def check_item_valid(raw_data: list):
    """Validate each raw item and tally valid/invalid counts."""
    if not raw_data:
        return defaultdict(int)
    logger.info("Validating code-math dataset...")
    total_stats = defaultdict(int)
    for index, item in enumerate(raw_data):
        valid, errors = validate_raw_item(item)
        if valid:
            total_stats["valid"] += 1
        else:
            total_stats["invalid"] += 1
            logger.warning(f"item {index}: {errors}")
    return total_stats


def filter_shuffle(shuffle: bool, processed_data):
    """Drop invalid processed items and optionally shuffle the remaining ones."""
    valid_output = []
    invalid_count = 0
    for item in processed_data:
        valid, errors = validate_raw_item(item)
        if valid:
            valid_output.append(item)
        else:
            invalid_count += 1
            logger.error(
                f'{item["task"]} item {item.get("query_id")} is invalid: {errors}'
            )
    if shuffle:
        random.shuffle(valid_output)
    return valid_output, invalid_count


def save_file(output_path: str, processed_data: list):
    """Write processed items to a JSONL file, one JSON object per line."""
    with open(output_path, "w") as f:
        for item in processed_data:
            f.write(json.dumps(item) + "\n")
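
# In process mode, each output line is one processed item as built above: code items carry
# task/query_id/prompt/solutions/input_output/language, math items carry
# task/query_id/prompt/solutions. Values below are illustrative placeholders:
#   {"task": "math", "query_id": "7", "prompt": "<problem statement>", "solutions": ["<answer>"]}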


def main():
    parser = argparse.ArgumentParser(
        description="Dataset Toolkit: Process and validate STEM datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--code", help="Path to code dataset (JSONL)")
    parser.add_argument("--math", help="Path to math dataset (JSONL)")
    parser.add_argument("--output", help="Output file path (JSONL)")
    parser.add_argument(
        "--mode",
        choices=["check", "process"],
        default="check",
        help="Operation mode: check raw data or process datasets",
    )
    parser.add_argument(
        "--shuffle", action="store_true", help="Shuffle output data in process mode"
    )

    args = parser.parse_args()
    if args.mode == "check":
        # Validate raw data structure
        code_data = load_jsonl(args.code) if args.code else []
        math_data = load_jsonl(args.math) if args.math else []
        total_stats = check_item_valid(code_data + math_data)

        # Log validation summary
        logger.info(
            f"\nValidation Summary: {total_stats['valid']} valid, {total_stats['invalid']} invalid"
        )
    elif args.mode == "process":
        # Process and merge datasets
        if not args.output:
            logger.error("Output file required in process mode")
            return
        processed_data = []
        stats = defaultdict(int)
        if args.code:
            code_data = process_code_data(args.code)
            logger.info(f"Loaded {len(code_data)} code items")
            processed_data.extend(code_data)
            stats["code"] = len(code_data)
        if args.math:
            math_data = process_math_data(args.math)
            logger.info(f"Loaded {len(math_data)} math items")
            processed_data.extend(math_data)
            stats["math"] = len(math_data)

        processed_data, invalid_count = filter_shuffle(args.shuffle, processed_data)
        stats["invalid"] = invalid_count
        save_file(args.output, processed_data)
        logger.info("\nProcessing Complete:")
        logger.info(f"Total items: {len(processed_data)}")
        logger.info(f"Code items: {stats['code']}")
        logger.info(f"Math items: {stats['math']}")
        logger.info(f"Invalid items: {stats['invalid']}")


if __name__ == "__main__":
    main()