bowei.fw 2025-03-22 12:22:34 +08:00
parent 9dcdb7a684
commit 25c45c7e83
7 changed files with 75 additions and 69 deletions

View File

@@ -1,32 +0,0 @@
import json
import random
data = []
with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
original_keys = list(code_data[0].keys())
print(original_keys)
for d in code_data:
# print(d["starter_code"], type(d["starter_code"]))
# print(json.loads(d["solutions"])[0])
inout = json.loads(d["input_output"])
print(dict(inputs=inout["inputs"][:2], outputs=inout["outputs"][:2]))
exit(0)
d["query_id"] = d["id"]
d["prompt"] = d["question"]
d["task"] = "code"
for k in original_keys:
d.pop(k)
data.append(d)
with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
math_data = [json.loads(l) for l in f.readlines()]
for d in math_data:
data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"]))
random.shuffle(data)
with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
for d in data:
f.write(json.dumps(d, ensure_ascii=False) + "\n")
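
Note: the deleted script above merged APPS code prompts and DeepScaleR math prompts into one shuffled jsonl. A minimal sketch of the unified records it was meant to emit, with illustrative ids, prompts, and an output path (the real files live under /storage/openpsi):

import json

# Illustrative records only; field names follow the deleted script.
code_record = {"query_id": "apps-0001", "prompt": "Write a function that ...", "task": "code"}
math_record = {"query_id": "dsr-0001", "prompt": "Compute 2 + 2.", "task": "math"}

with open("code_math.jsonl", "w") as f:
    for rec in (code_record, math_record):
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")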

View File

@@ -7,7 +7,7 @@ from grader import math_equal
def process_results(answer, solution):
extracted_answer = extract_answer(answer, "math", use_last_number=False)
extracted_solution = solution
extracted_solution = extract_answer(solution, "math", use_last_number=True)
if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
retval = 0
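
For clarity, a hedged sketch of the intended flow after this change: both the model answer and the reference solution are normalized with extract_answer before comparison. The final math_equal call is an assumption about how retval is computed downstream; only the lines above it appear in the hunk.

# Sketch only; assumes extract_answer and math_equal imported as in this file.
def process_results_sketch(answer, solution):
    extracted_answer = extract_answer(answer, "math", use_last_number=False)
    # New: the reference solution also goes through the extractor.
    extracted_solution = extract_answer(solution, "math", use_last_number=True)
    if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
        return 0
    return int(math_equal(extracted_answer, extracted_solution))  # assumed reduction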

View File

@@ -24,7 +24,7 @@ def math_verify(
for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
base_query_id = query_id.split("@idx:")[0]
info = id2info[base_query_id]
for cur_solution in info["answers"]:
for cur_solution in info["solutions"]:
parameters.append((generated, cur_solution, idx))
query_indices.append(idx)
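
A note on the rename above: id2info entries are now expected to carry a "solutions" list (matching the dataset validators in this commit), and every generated answer is paired with every reference solution. A hypothetical entry for illustration:

# Hypothetical id2info entry after the rename.
id2info = {
    "q-123": {"task": "math", "prompt": "...", "solutions": ["42", "\\boxed{42}"]},
}
# A query with N solutions yields N (generated, solution, idx) grading jobs,
# presumably reduced back to a single reward per query index afterwards.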

View File

@@ -302,6 +302,7 @@ class PPOMATHConfig(CommonExperimentConfig):
rw_interface = ModelInterfaceAbstraction(
"reward",
args=dict(
dataset_path=self.dataset.path,
tokenizer_path=self.actor.path,
output_scaling=self.ppo.reward_output_scaling,
output_bias=self.ppo.reward_output_bias,
@@ -336,7 +337,7 @@ class PPOMATHConfig(CommonExperimentConfig):
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_prompts"],
input_keys=["packed_prompts", "task_ids"],
output_keys=rollout_output_keys,
n_seqs=self.dataset.train_bs_n_seqs,
)
@@ -361,20 +362,20 @@ class PPOMATHConfig(CommonExperimentConfig):
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts"],
input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
output_keys=["rewards"],
n_seqs=self.dataset.train_bs_n_seqs,
)
# Optionally fuse the reward inputs/outputs into the ref MFC
inf_ref_inputs = ["packed_input_ids"]
inf_ref_outputs = ["logprobs"]
inf_ref_outputs = ["packed_ref_logprobs"]
if self.ppo.fuse_rew_ref:
inf_ref_inputs += ["packed_prompts"]
inf_ref_inputs += ["packed_prompts", "task_ids"]
inf_ref_outputs += ["rewards"]
inf_ref_logits = MFCDef(
name="ref_rw",
name="ref_inf",
model_name="ref",
mb_spec=self.ref_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
@@ -495,7 +496,6 @@ class PPOMATHConfig(CommonExperimentConfig):
args=dict(
dataset_path=self.dataset.path,
max_length=self.dataset.max_prompt_len,
fill_to_max_length=self.dataset.fill_to_max_length,
filter_threshold=self.dataset_filter_threshold,
max_filter_percentage=self.dataset_max_filter_percentage,
),
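
Reduced sketch of the reward-inference MFC after these changes; the key names mirror the hunks above, while the MFC name, group_size, and train_bs_n_seqs are stand-ins for the corresponding config fields, and fields such as mb_spec are omitted:

# Sketch: "task_ids" now travels with the prompts so the multi-task reward
# interface can tell math prompts from code prompts.
rw_mfc = MFCDef(
    name="rew_inf",  # illustrative name
    model_name="reward",
    interface_type=ModelInterfaceType.INFERENCE,
    interface_impl=rw_interface,
    min_n_seqs_per_pass=1 / group_size,
    input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
    output_keys=["rewards"],
    n_seqs=train_bs_n_seqs,
)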

View File

@@ -3,6 +3,8 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import json
import traceback
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Optional
import numpy as np
@@ -13,17 +15,34 @@ from realhf.base import logging
logger = logging.getLogger("Math Code Dataset")
id2info = {}
def check_math_metadata_entries(data):
assert data["task"] == "math"
assert "query_id" in data
data["query_id"] = str(data["query_id"])
assert isinstance(data["prompt"], str)
assert isinstance(data["solutions"], list)
for sol in data["solutions"]:
assert isinstance(sol, str)
return data
def check_code_metadata_entries(data):
# TODO: check test multi task reward
pass
def check_math_metadata_entries(data):
# TODO: check test multi task reward
pass
assert data["task"] == "code"
assert "query_id" in data
data["query_id"] = str(data["query_id"])
if "problem_id" not in data:
data["problem_id"] = data["query_id"]
assert isinstance(data["prompt"], str)
input_output = json.loads(data["input_output"])
assert len(input_output["inputs"]) == len(input_output["outputs"])
for inp, out in zip(input_output["inputs"], input_output["outputs"]):
assert isinstance(inp, str) and isinstance(out, str), (
inp,
out,
input_output.get("fn_name"),
)
return data
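
Hedged examples of records that the two validators above accept (all values are made up); both validators return the possibly patched record:

# Hypothetical inputs for illustration.
math_entry = {
    "task": "math",
    "query_id": 7,  # coerced to str by the validator
    "prompt": "Compute 2 + 2.",
    "solutions": ["4", "\\boxed{4}"],
}
code_entry = {
    "task": "code",
    "query_id": "42",
    "prompt": "Read two ints and print their sum.",
    "input_output": json.dumps({"inputs": ["1 2\n"], "outputs": ["3\n"]}),
}
math_entry = check_math_metadata_entries(math_entry)
code_entry = check_code_metadata_entries(code_entry)
assert code_entry["problem_id"] == code_entry["query_id"]  # defaulted when absent
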
def load_metadata(path):
@@ -31,14 +50,26 @@ def load_metadata(path):
with open(path, "r") as f:
data = [json.loads(l) for l in f.readlines()]
id2info = {}
omit_cnt = defaultdict(int)
task_cnt = defaultdict(int)
for d in data:
assert d["query_id"] not in d, (d["task"], d["query_id"])
if d["task"] == "math":
check_math_metadata_entries(d)
elif d["task"] == "code":
check_code_metadata_entries(d)
try:
if d["task"] == "math":
d = check_math_metadata_entries(d)
elif d["task"] == "code":
d = check_code_metadata_entries(d)
except Exception as e:
logger.warning(
f"Data validation failed: query_id {d['query_id']}. "
f"Error: {traceback.format_exc()}. Omit it in the dataset."
)
omit_cnt[d["task"]] += 1
continue
id2info[d["query_id"]] = d
return id2info
task_cnt[d["task"]] += 1
logger.warning(f"Number of ignored data: {dict(**omit_cnt)}")
return id2info, dict(task_cnt)
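
Usage note: load_metadata now returns an (id2info, task_cnt) pair, so callers unpack both values (the path below is illustrative):

id2info, task_cnt = load_metadata("/path/to/code_math.jsonl")  # illustrative path
print(f"loaded {len(id2info)} prompts "
      f"({task_cnt.get('math', 0)} math, {task_cnt.get('code', 0)} code)")
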
class MATHCodePromptDataset(torch.utils.data.Dataset):
@@ -59,8 +90,7 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
self._util = util
self.max_length = max_length
global id2info
id2info = load_metadata(dataset_path)
id2info, task_cnt = load_metadata(dataset_path)
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
@@ -83,7 +113,9 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(f"{len(indices)} samples remain")
logger.info(
f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data"
)
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
@@ -183,4 +215,5 @@ else:
)
print(f"size: {len(dataset)}")
for d in dataloader:
print(d.ids)
# print(d.ids)
pass

View File

@@ -190,7 +190,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
def __post_init__(self):
global id2info
id2info = load_metadata(self.dataset_path)
id2info, _ = load_metadata(self.dataset_path)
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
if constants.parallelism_rank() == 0:
logger.info(f"output_scaling: {self.output_scaling}")

View File

@@ -2,6 +2,7 @@
import os
import shutil
import uuid
from typing import *
import pytest
@@ -34,14 +35,18 @@ def math_code_dataset(request, save_path):
dataset = []
for i in range(size):
prompt_len = random.randint(1, max_prompt_len)
n_pairs = random.randint(1, 5)
d = dict(
query_id=query_ids[i],
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
task=random.choice(["math", "code"]),
)
if d["task"] == "math":
d["solutions"] = [generate_random_sentence(max_resp_len)]
elif d["task"] == "code":
d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
dataset.append(d)
with open(str(save_path / "math_code_dataset.json"), "w") as f:
json.dump(dataset, f)
with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
f.write(json.dumps(d) + "\n")
return dataset
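
A small sanity-check sketch for the fixture output above; save_path is the pytest fixture directory already used in this file:

# Sketch: read the jsonl back and confirm each record carries a task label
# plus the task-specific metadata the validators expect.
import json
with open(str(save_path / "math_code_dataset.jsonl")) as f:
    records = [json.loads(line) for line in f]
assert all(r["task"] in ("math", "code") for r in records)
assert all(("solutions" in r) or ("input_output" in r) for r in records)
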
@@ -57,7 +62,7 @@ def math_code_dataset(request, save_path):
def test_ppo_symm(
tmp_path_factory,
tokenizer,
math_dataset,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
@@ -96,7 +101,7 @@ def test_ppo_symm(
backend="mock_train",
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
@@ -127,7 +132,7 @@ def test_ppo_symm(
def test_ppo_global_reshard(
tmp_path_factory,
tokenizer,
math_dataset,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
@@ -174,7 +179,7 @@ def test_ppo_global_reshard(
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
@@ -244,7 +249,7 @@ def test_ppo_global_reshard(
def test_ppo_param_realloc_sub_device_mesh(
tmp_path_factory,
tokenizer,
math_dataset,
math_code_dataset,
save_path,
cpu_hf_model,
mconfig,
@@ -287,7 +292,7 @@ def test_ppo_param_realloc_sub_device_mesh(
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=minbs,
fill_to_max_length=False,
@@ -406,7 +411,7 @@ def test_ppo_save(
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
path=str(save_path / "math_code_dataset.jsonl"),
max_prompt_len=mconfig.n_positions // 2,
train_bs_n_seqs=bs,
fill_to_max_length=False,