mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent 9dcdb7a684
commit 25c45c7e83
@@ -1,32 +0,0 @@
-import json
-import random
-
-data = []
-with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f:
-    code_data = [json.loads(l) for l in f.readlines()]
-
-original_keys = list(code_data[0].keys())
-print(original_keys)
-for d in code_data:
-    # print(d["starter_code"], type(d["starter_code"]))
-    # print(json.loads(d["solutions"])[0])
-    inout = json.loads(d["input_output"])
-    print(dict(inputs=inout["inputs"][:2], outputs=inout["outputs"][:2]))
-    exit(0)
-    d["query_id"] = d["id"]
-    d["prompt"] = d["question"]
-    d["task"] = "code"
-    for k in original_keys:
-        d.pop(k)
-    data.append(d)
-
-with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
-    math_data = [json.loads(l) for l in f.readlines()]
-
-for d in math_data:
-    data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"]))
-
-random.shuffle(data)
-with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
-    for d in data:
-        f.write(json.dumps(d, ensure_ascii=False) + "\n")
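The deleted script above merged APPS-style code prompts with DeepScaleR math prompts into one shuffled JSONL; as committed it still contained a debug print and an early exit(0), so it never reached the merge. A minimal, hypothetical sketch of the same merge without the debug lines (the paths and field names are taken from the diff and are not guaranteed to exist):

# Hypothetical re-creation of the removed preprocessing script, debug lines dropped.
import json
import random

data = []

# Code prompts: keep only query_id / prompt / task fields.
with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f:
    for line in f:
        d = json.loads(line)
        data.append(dict(query_id=d["id"], prompt=d["question"], task="code"))

# Math prompts already carry query_id and prompt.
with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
    for line in f:
        d = json.loads(line)
        data.append(dict(query_id=d["query_id"], prompt=d["prompt"], task="math"))

random.shuffle(data)
with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
    for d in data:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")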
@@ -7,7 +7,7 @@ from grader import math_equal

 def process_results(answer, solution):
     extracted_answer = extract_answer(answer, "math", use_last_number=False)
-    extracted_solution = solution
+    extracted_solution = extract_answer(solution, "math", use_last_number=True)

     if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
         retval = 0
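The change above extracts a canonical answer from the reference solution (taking the last number) instead of comparing against the raw solution string. A standalone, hypothetical sketch of that idea; extract_last_number below stands in for extract_answer(..., use_last_number=True) and is not the repository's implementation:

import re

def extract_last_number(text):
    # Stand-in for extract_answer(text, "math", use_last_number=True).
    nums = re.findall(r"-?\d+(?:\.\d+)?", text)
    return nums[-1] if nums else None

def process_results(answer, solution):
    extracted_answer = extract_last_number(answer)
    extracted_solution = extract_last_number(solution)
    if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
        return 0
    return int(extracted_answer == extracted_solution)

print(process_results("The result is 42.", "So the answer is 42"))  # -> 1
print(process_results("No numbers here", "42"))                     # -> 0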
@@ -24,7 +24,7 @@ def math_verify(
     for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
         base_query_id = query_id.split("@idx:")[0]
         info = id2info[base_query_id]
-        for cur_solution in info["answers"]:
+        for cur_solution in info["solutions"]:
             parameters.append((generated, cur_solution, idx))
             query_indices.append(idx)
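Each rollout query_id carries an "@idx:<n>" suffix, and every generated answer is paired with every reference solution of its base query before verification. A self-contained illustration with made-up data; only the pairing logic mirrors the loop above:

id2info = {"q1": {"solutions": ["42", "the answer is 42"]}}
query_ids = ["q1@idx:0", "q1@idx:1"]
generateds = ["I think it is 42", "maybe 41"]

parameters, query_indices = [], []
for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
    base_query_id = query_id.split("@idx:")[0]
    info = id2info[base_query_id]
    for cur_solution in info["solutions"]:
        parameters.append((generated, cur_solution, idx))
        query_indices.append(idx)

print(len(parameters), query_indices)  # 4 (generated, solution, idx) triples; [0, 0, 1, 1]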
@@ -302,6 +302,7 @@ class PPOMATHConfig(CommonExperimentConfig):
         rw_interface = ModelInterfaceAbstraction(
             "reward",
             args=dict(
                 dataset_path=self.dataset.path,
                 tokenizer_path=self.actor.path,
                 output_scaling=self.ppo.reward_output_scaling,
                 output_bias=self.ppo.reward_output_bias,
@@ -336,7 +337,7 @@ class PPOMATHConfig(CommonExperimentConfig):
             model_type=self.actor.type,
             model_path=self.actor.path,
             interface_impl=actor_interface,
-            input_keys=["packed_prompts"],
+            input_keys=["packed_prompts", "task_ids"],
             output_keys=rollout_output_keys,
             n_seqs=self.dataset.train_bs_n_seqs,
         )
@@ -361,20 +362,20 @@ class PPOMATHConfig(CommonExperimentConfig):
             interface_type=ModelInterfaceType.INFERENCE,
             interface_impl=rw_interface,
             min_n_seqs_per_pass=1 / self.group_size,
-            input_keys=["packed_input_ids", "packed_prompts"],
+            input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
             output_keys=["rewards"],
             n_seqs=self.dataset.train_bs_n_seqs,
         )

         # add rew param into ref MFC
         inf_ref_inputs = ["packed_input_ids"]
-        inf_ref_outputs = ["logprobs"]
+        inf_ref_outputs = ["packed_ref_logprobs"]
         if self.ppo.fuse_rew_ref:
-            inf_ref_inputs += ["packed_prompts"]
+            inf_ref_inputs += ["packed_prompts", "task_ids"]
             inf_ref_outputs += ["rewards"]

         inf_ref_logits = MFCDef(
-            name="ref_rw",
+            name="ref_inf",
             model_name="ref",
             mb_spec=self.ref_inf.mb_spec,
             interface_type=ModelInterfaceType.INFERENCE,
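Threading "task_ids" into the rollout, reward, and (when fuse_rew_ref is set) reference MFCs lets the reward side pick the right verifier per sequence: an answer checker for math prompts and a test-case runner for code prompts. A rough, hypothetical dispatch sketch; it is not the repository's MultiTaskRewardInterface, and the task-id numbering is an assumption:

# Hypothetical per-task reward dispatch; task id 0 = math, 1 = code are assumptions.
def math_reward(prompt, completion):
    return 1.0  # placeholder for an answer-equivalence check

def code_reward(prompt, completion):
    return 0.0  # placeholder for running test cases

REWARD_FNS = {0: math_reward, 1: code_reward}

def compute_rewards(prompts, completions, task_ids):
    return [REWARD_FNS[t](p, c) for p, c, t in zip(prompts, completions, task_ids)]

print(compute_rewards(
    ["What is 1 + 1?", "Write add(a, b)."],
    ["2", "def add(a, b): return a + b"],
    [0, 1],
))  # -> [1.0, 0.0]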
@@ -495,7 +496,6 @@ class PPOMATHConfig(CommonExperimentConfig):
             args=dict(
                 dataset_path=self.dataset.path,
                 max_length=self.dataset.max_prompt_len,
                 fill_to_max_length=self.dataset.fill_to_max_length,
                 filter_threshold=self.dataset_filter_threshold,
                 max_filter_percentage=self.dataset_max_filter_percentage,
             ),
@@ -3,6 +3,8 @@
 # Licensed under the Apache License, Version 2.0 (the "License").

+import json
+import traceback
 from collections import defaultdict
 from typing import Callable, Dict, Hashable, List, Optional

 import numpy as np
@@ -13,17 +15,34 @@ from realhf.base import logging

 logger = logging.getLogger("Math Code Dataset")

 id2info = {}

+def check_math_metadata_entries(data):
+    assert data["task"] == "math"
+    assert "query_id" in data
+    data["query_id"] = str(data["query_id"])
+    assert isinstance(data["prompt"], str)
+    assert isinstance(data["solutions"], list)
+    for sol in data["solutions"]:
+        assert isinstance(sol, str)
+    return data
+
+
 def check_code_metadata_entries(data):
-    # TODO: check test multi task reward
-    pass
-
-
-def check_math_metadata_entries(data):
-    # TODO: check test multi task reward
-    pass
+    assert data["task"] == "code"
+    assert "query_id" in data
+    data["query_id"] = str(data["query_id"])
+    if "problem_id" not in data:
+        data["problem_id"] = data["query_id"]
+    assert isinstance(data["prompt"], str)
+    input_output = json.loads(data["input_output"])
+    assert len(input_output["inputs"]) == len(input_output["outputs"])
+    for inp, out in zip(input_output["inputs"], input_output["outputs"]):
+        assert isinstance(inp, str) and isinstance(out, str), (
+            inp,
+            out,
+            input_output.get("fn_name"),
+        )
+    return data


 def load_metadata(path):
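For reference, two records of the shape these validators appear to expect, one per task; all field values below are made up:

import json

# Hypothetical math record: "solutions" is a list of reference answer strings.
math_entry = {
    "task": "math",
    "query_id": "math-0001",
    "prompt": "What is 2 + 2?",
    "solutions": ["\\boxed{4}"],
}

# Hypothetical code record: "input_output" is a JSON string of test cases.
code_entry = {
    "task": "code",
    "query_id": "code-0001",
    "prompt": "Read an integer n and print n * 2.",
    "input_output": json.dumps({"inputs": ["3\n"], "outputs": ["6\n"]}),
}

print(json.dumps(math_entry))
print(json.dumps(code_entry))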
@@ -31,14 +50,26 @@ def load_metadata(path):
     with open(path, "r") as f:
         data = [json.loads(l) for l in f.readlines()]
     id2info = {}
+    omit_cnt = defaultdict(int)
+    task_cnt = defaultdict(int)
     for d in data:
         assert d["query_id"] not in d, (d["task"], d["query_id"])
+        try:
             if d["task"] == "math":
-                check_math_metadata_entries(d)
+                d = check_math_metadata_entries(d)
             elif d["task"] == "code":
-                check_code_metadata_entries(d)
+                d = check_code_metadata_entries(d)
+        except Exception as e:
+            logger.warning(
+                f"Data validation failed: query_id {d['query_id']}. "
+                f"Error: {traceback.format_exc()}. Omit it in the dataset."
+            )
+            omit_cnt[d["task"]] += 1
+            continue
         id2info[d["query_id"]] = d
-    return id2info
+        task_cnt[d["task"]] += 1
+    logger.warning(f"Number of ignored data: {dict(**omit_cnt)}")
+    return id2info, dict(task_cnt)


 class MATHCodePromptDataset(torch.utils.data.Dataset):
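The loop now validates every record, drops entries that fail validation while counting them in omit_cnt, and tallies valid entries per task. A standalone miniature of that skip-and-count pattern with two hypothetical records, one deliberately malformed:

import json
from collections import defaultdict

records = [
    {"task": "math", "query_id": "m1", "prompt": "2 + 2?", "solutions": ["4"]},
    {"task": "code", "query_id": "c1", "prompt": "echo", "input_output": "not json"},
]

id2info, omit_cnt, task_cnt = {}, defaultdict(int), defaultdict(int)
for d in records:
    try:
        if d["task"] == "code":
            json.loads(d["input_output"])  # raises for the malformed code entry
    except Exception:
        omit_cnt[d["task"]] += 1
        continue
    id2info[d["query_id"]] = d
    task_cnt[d["task"]] += 1

print(dict(task_cnt), dict(omit_cnt))  # {'math': 1} {'code': 1}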
@@ -59,8 +90,7 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
         self._util = util
         self.max_length = max_length

         global id2info
-        id2info = load_metadata(dataset_path)
+        id2info, task_cnt = load_metadata(dataset_path)

         data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
@@ -83,7 +113,9 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
         indices = [
             i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
         ]
-        logger.info(f"{len(indices)} samples remain")
+        logger.info(
+            f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data"
+        )

         self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
         self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
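The log line reports how many prompts survive the max-length filter, now split by task using the counts returned by load_metadata. A tiny illustration of the filter itself with made-up encodings:

# Hypothetical tokenized prompts; only the filtering logic mirrors the code above.
prompt_encodings = {
    "length": [5, 40, 12],
    "input_ids": [[1] * 5, [1] * 40, [1] * 12],
}
max_length = 16
indices = [i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length]
prompt_lengths = [int(prompt_encodings["length"][i]) for i in indices]
prompts = [prompt_encodings["input_ids"][i] for i in indices]
print(len(prompts), prompt_lengths)  # 2 prompts remain; lengths [5, 12]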
@@ -183,4 +215,5 @@ else:
     )
     print(f"size: {len(dataset)}")
     for d in dataloader:
-        print(d.ids)
+        # print(d.ids)
+        pass
@@ -190,7 +190,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):

     def __post_init__(self):
         global id2info
-        id2info = load_metadata(self.dataset_path)
+        id2info, _ = load_metadata(self.dataset_path)
         self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
         if constants.parallelism_rank() == 0:
             logger.info(f"output_scaling: {self.output_scaling}")
@@ -2,6 +2,7 @@

 import os
 import shutil
+import uuid
 from typing import *

 import pytest
@@ -34,14 +35,18 @@ def math_code_dataset(request, save_path):
     dataset = []
     for i in range(size):
         prompt_len = random.randint(1, max_prompt_len)
-        n_pairs = random.randint(1, 5)
         d = dict(
-            query_id=query_ids[i],
+            query_id=str(uuid.uuid4()),
             prompt=generate_random_sentence(prompt_len),
+            task=random.choice(["math", "code"]),
        )
+        if d["task"] == "math":
+            d["solutions"] = [generate_random_sentence(max_resp_len)]
+        elif d["task"] == "code":
+            d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
         dataset.append(d)
-    with open(str(save_path / "math_code_dataset.json"), "w") as f:
-        json.dump(dataset, f)
+        with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
+            f.write(json.dumps(d) + "\n")
     return dataset
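The fixture now writes one JSON object per line (JSONL, opened in append mode inside the loop) instead of dumping the whole list as a single JSON array. A small, hypothetical contrast between the two formats, using temporary files rather than the fixture's save_path:

import json
import os
import tempfile

records = [{"query_id": "a", "task": "math"}, {"query_id": "b", "task": "code"}]
tmpdir = tempfile.mkdtemp()

# Single JSON array: the old fixture format.
with open(os.path.join(tmpdir, "example.json"), "w") as f:
    json.dump(records, f)

# JSONL: one record per line, which line-oriented loaders can stream.
with open(os.path.join(tmpdir, "example.jsonl"), "w") as f:
    for r in records:
        f.write(json.dumps(r) + "\n")

with open(os.path.join(tmpdir, "example.jsonl")) as f:
    print(sum(1 for _ in f))  # 2 lines, one per record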
@@ -57,7 +62,7 @@ def math_code_dataset(request, save_path):
 def test_ppo_symm(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -96,7 +101,7 @@ def test_ppo_symm(
             backend="mock_train",
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=minbs,
             fill_to_max_length=False,
@@ -127,7 +132,7 @@ def test_ppo_symm(
 def test_ppo_global_reshard(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -174,7 +179,7 @@ def test_ppo_global_reshard(
             init_from_scratch=True,
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=minbs,
             fill_to_max_length=False,
@@ -244,7 +249,7 @@ def test_ppo_global_reshard(
 def test_ppo_param_realloc_sub_device_mesh(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -287,7 +292,7 @@ def test_ppo_param_realloc_sub_device_mesh(
             init_from_scratch=True,
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=minbs,
             fill_to_max_length=False,
@@ -406,7 +411,7 @@ def test_ppo_save(
             init_from_scratch=True,
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=bs,
             fill_to_max_length=False,