bowei.fw 2025-03-22 12:22:34 +08:00
parent 9dcdb7a684
commit 25c45c7e83
7 changed files with 75 additions and 69 deletions

View File

@@ -1,32 +0,0 @@
-import json
-import random
-
-data = []
-with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f:
-    code_data = [json.loads(l) for l in f.readlines()]
-original_keys = list(code_data[0].keys())
-print(original_keys)
-
-for d in code_data:
-    # print(d["starter_code"], type(d["starter_code"]))
-    # print(json.loads(d["solutions"])[0])
-    inout = json.loads(d["input_output"])
-    print(dict(inputs=inout["inputs"][:2], outputs=inout["outputs"][:2]))
-    exit(0)
-    d["query_id"] = d["id"]
-    d["prompt"] = d["question"]
-    d["task"] = "code"
-    for k in original_keys:
-        d.pop(k)
-    data.append(d)
-
-with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
-    math_data = [json.loads(l) for l in f.readlines()]
-for d in math_data:
-    data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"]))
-
-random.shuffle(data)
-with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
-    for d in data:
-        f.write(json.dumps(d, ensure_ascii=False) + "\n")
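For illustration only (not part of this commit): the shape of one merged line that the deleted conversion script wrote to code_math.jsonl; the keys follow the script above, and the field values are made up.

# Illustrative sketch of the merged prompt format (made-up values).
import json

math_line = {"prompt": "Prove that 1 + 1 = 2.", "task": "math", "query_id": "deepscaler-0001"}
code_line = {"query_id": "apps-0001", "prompt": "Read an integer and print its square.", "task": "code"}

print(json.dumps(math_line, ensure_ascii=False))
print(json.dumps(code_line, ensure_ascii=False))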

View File

@@ -7,7 +7,7 @@ from grader import math_equal


 def process_results(answer, solution):
     extracted_answer = extract_answer(answer, "math", use_last_number=False)
-    extracted_solution = solution
+    extracted_solution = extract_answer(solution, "math", use_last_number=True)
     if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
         retval = 0
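For illustration only (not part of this commit): a minimal sketch of what this change does, reusing the extract_answer call shown above; extract_answer and math_equal are the module's existing helpers, and the sample strings are made up.

# Illustrative sketch, assuming the module's extract_answer/math_equal helpers.
answer = "Adding the two terms gives \\boxed{42}."
solution = "Careful bookkeeping shows the total is 42."

extracted_answer = extract_answer(answer, "math", use_last_number=False)
extracted_solution = extract_answer(solution, "math", use_last_number=True)

# Before this commit, extracted_solution was the raw solution text, so a correct
# "42" could still fail the comparison against the full reference sentence.
print(math_equal(extracted_answer, extracted_solution))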

View File

@@ -24,7 +24,7 @@ def math_verify(
     for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
         base_query_id = query_id.split("@idx:")[0]
         info = id2info[base_query_id]
-        for cur_solution in info["answers"]:
+        for cur_solution in info["solutions"]:
             parameters.append((generated, cur_solution, idx))
             query_indices.append(idx)
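For illustration only: query ids can carry an "@idx:<n>" suffix, and only the base id indexes id2info, whose math entries now store reference answers under "solutions" (the id and values below are made up).

# Illustrative sketch of the id handling above (made-up values).
query_id = "a1b2c3@idx:3"
base_query_id = query_id.split("@idx:")[0]    # -> "a1b2c3"
info = {"solutions": ["\\boxed{2}"]}          # shape of a math id2info entry
for cur_solution in info["solutions"]:
    print(base_query_id, cur_solution)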

View File

@@ -302,6 +302,7 @@ class PPOMATHConfig(CommonExperimentConfig):
         rw_interface = ModelInterfaceAbstraction(
             "reward",
             args=dict(
+                dataset_path=self.dataset.path,
                 tokenizer_path=self.actor.path,
                 output_scaling=self.ppo.reward_output_scaling,
                 output_bias=self.ppo.reward_output_bias,
@@ -336,7 +337,7 @@ class PPOMATHConfig(CommonExperimentConfig):
             model_type=self.actor.type,
             model_path=self.actor.path,
             interface_impl=actor_interface,
-            input_keys=["packed_prompts"],
+            input_keys=["packed_prompts", "task_ids"],
             output_keys=rollout_output_keys,
             n_seqs=self.dataset.train_bs_n_seqs,
         )
@@ -361,20 +362,20 @@
             interface_type=ModelInterfaceType.INFERENCE,
             interface_impl=rw_interface,
             min_n_seqs_per_pass=1 / self.group_size,
-            input_keys=["packed_input_ids", "packed_prompts"],
+            input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
             output_keys=["rewards"],
             n_seqs=self.dataset.train_bs_n_seqs,
         )

         # add rew param into ref MFC
         inf_ref_inputs = ["packed_input_ids"]
-        inf_ref_outputs = ["logprobs"]
+        inf_ref_outputs = ["packed_ref_logprobs"]
         if self.ppo.fuse_rew_ref:
-            inf_ref_inputs += ["packed_prompts"]
+            inf_ref_inputs += ["packed_prompts", "task_ids"]
             inf_ref_outputs += ["rewards"]

         inf_ref_logits = MFCDef(
-            name="ref_rw",
+            name="ref_inf",
             model_name="ref",
             mb_spec=self.ref_inf.mb_spec,
             interface_type=ModelInterfaceType.INFERENCE,
@@ -495,7 +496,6 @@ class PPOMATHConfig(CommonExperimentConfig):
             args=dict(
                 dataset_path=self.dataset.path,
                 max_length=self.dataset.max_prompt_len,
-                fill_to_max_length=self.dataset.fill_to_max_length,
                 filter_threshold=self.dataset_filter_threshold,
                 max_filter_percentage=self.dataset_max_filter_percentage,
             ),
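For illustration only: a standalone sketch of the fused ref/reward key wiring after this change, with fuse_rew_ref as a plain boolean standing in for self.ppo.fuse_rew_ref.

# Illustrative sketch (plain Python, outside the config class).
fuse_rew_ref = True  # stands in for self.ppo.fuse_rew_ref

inf_ref_inputs = ["packed_input_ids"]
inf_ref_outputs = ["packed_ref_logprobs"]
if fuse_rew_ref:
    # The fused MFC also consumes the prompts and per-sample task ids so it can
    # emit rewards alongside the reference logprobs.
    inf_ref_inputs += ["packed_prompts", "task_ids"]
    inf_ref_outputs += ["rewards"]

print(inf_ref_inputs)   # ['packed_input_ids', 'packed_prompts', 'task_ids']
print(inf_ref_outputs)  # ['packed_ref_logprobs', 'rewards']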

View File

@@ -3,6 +3,8 @@
 # Licensed under the Apache License, Version 2.0 (the "License").

 import json
+import traceback
+from collections import defaultdict
 from typing import Callable, Dict, Hashable, List, Optional

 import numpy as np
@@ -13,17 +15,34 @@ from realhf.base import logging

 logger = logging.getLogger("Math Code Dataset")

-id2info = {}
-
-
-def check_code_metadata_entries(data):
-    # TODO: check test multi task reward
-    pass
-
-
-def check_math_metadata_entries(data):
-    # TODO: check test multi task reward
-    pass
+
+def check_math_metadata_entries(data):
+    assert data["task"] == "math"
+    assert "query_id" in data
+    data["query_id"] = str(data["query_id"])
+    assert isinstance(data["prompt"], str)
+    assert isinstance(data["solutions"], list)
+    for sol in data["solutions"]:
+        assert isinstance(sol, str)
+    return data
+
+
+def check_code_metadata_entries(data):
+    assert data["task"] == "code"
+    assert "query_id" in data
+    data["query_id"] = str(data["query_id"])
+    if "problem_id" not in data:
+        data["problem_id"] = data["query_id"]
+    assert isinstance(data["prompt"], str)
+    input_output = json.loads(data["input_output"])
+    assert len(input_output["inputs"]) == len(input_output["outputs"])
+    for inp, out in zip(input_output["inputs"], input_output["outputs"]):
+        assert isinstance(inp, str) and isinstance(out, str), (
+            inp,
+            out,
+            input_output.get("fn_name"),
+        )
+    return data


 def load_metadata(path):
@@ -31,14 +50,26 @@ def load_metadata(path):
     with open(path, "r") as f:
         data = [json.loads(l) for l in f.readlines()]
     id2info = {}
+    omit_cnt = defaultdict(int)
+    task_cnt = defaultdict(int)
     for d in data:
         assert d["query_id"] not in d, (d["task"], d["query_id"])
-        if d["task"] == "math":
-            check_math_metadata_entries(d)
-        elif d["task"] == "code":
-            check_code_metadata_entries(d)
+        try:
+            if d["task"] == "math":
+                d = check_math_metadata_entries(d)
+            elif d["task"] == "code":
+                d = check_code_metadata_entries(d)
+        except Exception as e:
+            logger.warning(
+                f"Data validation failed: query_id {d['query_id']}. "
+                f"Error: {traceback.format_exc()}. Omit it in the dataset."
+            )
+            omit_cnt[d["task"]] += 1
+            continue
         id2info[d["query_id"]] = d
-    return id2info
+        task_cnt[d["task"]] += 1
+    logger.warning(f"Number of ignored data: {dict(**omit_cnt)}")
+    return id2info, dict(task_cnt)


 class MATHCodePromptDataset(torch.utils.data.Dataset):
@@ -59,8 +90,7 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
         self._util = util
         self.max_length = max_length

-        global id2info
-        id2info = load_metadata(dataset_path)
+        id2info, task_cnt = load_metadata(dataset_path)

         data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
@@ -83,7 +113,9 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
         indices = [
             i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
         ]
-        logger.info(f"{len(indices)} samples remain")
+        logger.info(
+            f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data"
+        )

         self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
         self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
@@ -183,4 +215,5 @@ else:
     )
     print(f"size: {len(dataset)}")
     for d in dataloader:
-        print(d.ids)
+        # print(d.ids)
+        pass
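For illustration only: metadata lines that would pass the validators above; the required keys and types follow the assertions in check_math_metadata_entries and check_code_metadata_entries, and all values are made up.

# Illustrative sketch of accepted metadata entries (made-up values).
import json

math_entry = {
    "task": "math",
    "query_id": "math-0001",
    "prompt": "Compute 1 + 1.",
    "solutions": ["\\boxed{2}"],  # list of solution strings
}
code_entry = {
    "task": "code",
    "query_id": "code-0001",  # problem_id defaults to query_id if absent
    "prompt": "Echo stdin to stdout.",
    "input_output": json.dumps({"inputs": ["the\n"], "outputs": ["the\n"]}),
}

# For a jsonl file containing these two lines, load_metadata(path) would return
# ({"math-0001": ..., "code-0001": ...}, {"math": 1, "code": 1}).
print(json.dumps(math_entry, ensure_ascii=False))
print(json.dumps(code_entry, ensure_ascii=False))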

View File

@@ -190,7 +190,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
     def __post_init__(self):
         global id2info
-        id2info = load_metadata(self.dataset_path)
+        id2info, _ = load_metadata(self.dataset_path)
         self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
         if constants.parallelism_rank() == 0:
             logger.info(f"output_scaling: {self.output_scaling}")

View File

@@ -2,6 +2,7 @@

 import os
 import shutil
+import uuid
 from typing import *

 import pytest
@@ -34,14 +35,18 @@ def math_code_dataset(request, save_path):
     dataset = []
     for i in range(size):
         prompt_len = random.randint(1, max_prompt_len)
-        n_pairs = random.randint(1, 5)
         d = dict(
-            query_id=query_ids[i],
+            query_id=str(uuid.uuid4()),
             prompt=generate_random_sentence(prompt_len),
+            task=random.choice(["math", "code"]),
         )
+        if d["task"] == "math":
+            d["solutions"] = [generate_random_sentence(max_resp_len)]
+        elif d["task"] == "code":
+            d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
         dataset.append(d)
-    with open(str(save_path / "math_code_dataset.json"), "w") as f:
-        json.dump(dataset, f)
+        with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
+            f.write(json.dumps(d) + "\n")
     return dataset
@@ -57,7 +62,7 @@ def math_code_dataset(request, save_path):
 def test_ppo_symm(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -96,7 +101,7 @@ def test_ppo_symm(
             backend="mock_train",
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
            max_prompt_len=mconfig.n_positions // 2,
            train_bs_n_seqs=minbs,
            fill_to_max_length=False,
@@ -127,7 +132,7 @@ def test_ppo_symm(
 def test_ppo_global_reshard(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -174,7 +179,7 @@ def test_ppo_global_reshard(
            init_from_scratch=True,
        ),
        dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
            max_prompt_len=mconfig.n_positions // 2,
            train_bs_n_seqs=minbs,
            fill_to_max_length=False,
@@ -244,7 +249,7 @@ def test_ppo_global_reshard(
 def test_ppo_param_realloc_sub_device_mesh(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -287,7 +292,7 @@ def test_ppo_param_realloc_sub_device_mesh(
            init_from_scratch=True,
        ),
        dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
            max_prompt_len=mconfig.n_positions // 2,
            train_bs_n_seqs=minbs,
            fill_to_max_length=False,
@@ -406,7 +411,7 @@ def test_ppo_save(
            init_from_scratch=True,
        ),
        dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
            max_prompt_len=mconfig.n_positions // 2,
            train_bs_n_seqs=bs,
            fill_to_max_length=False,