mirror of https://github.com/inclusionAI/AReaL
parent 9dcdb7a684
commit 25c45c7e83
@@ -1,32 +0,0 @@
-import json
-import random
-
-data = []
-with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f:
-    code_data = [json.loads(l) for l in f.readlines()]
-
-original_keys = list(code_data[0].keys())
-print(original_keys)
-for d in code_data:
-    # print(d["starter_code"], type(d["starter_code"]))
-    # print(json.loads(d["solutions"])[0])
-    inout = json.loads(d["input_output"])
-    print(dict(inputs=inout["inputs"][:2], outputs=inout["outputs"][:2]))
-    exit(0)
-    d["query_id"] = d["id"]
-    d["prompt"] = d["question"]
-    d["task"] = "code"
-    for k in original_keys:
-        d.pop(k)
-    data.append(d)
-
-with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
-    math_data = [json.loads(l) for l in f.readlines()]
-
-for d in math_data:
-    data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"]))
-
-random.shuffle(data)
-with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
-    for d in data:
-        f.write(json.dumps(d, ensure_ascii=False) + "\n")
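Note: the deleted one-off script above merged APPS-style code problems with DeepScaleR math prompts into a single code_math.jsonl. As a sketch only, the records below illustrate the per-line layout that the new metadata checks later in this commit expect; all concrete values are invented for illustration.

# Hypothetical merged records, one JSON object per line of code_math.jsonl:
math_record = {
    "prompt": "Compute 1 + 1.",
    "task": "math",
    "query_id": "math-0001",
    "solutions": ["2"],  # list of reference answers, per check_math_metadata_entries
}
code_record = {
    "prompt": "Read a line from stdin and echo it.",
    "task": "code",
    "query_id": "code-0001",
    # JSON-encoded test cases, per check_code_metadata_entries
    "input_output": "{\"inputs\": [\"the\\n\"], \"outputs\": [\"the\\n\"]}",
}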
@@ -7,7 +7,7 @@ from grader import math_equal

 def process_results(answer, solution):
     extracted_answer = extract_answer(answer, "math", use_last_number=False)
-    extracted_solution = solution
+    extracted_solution = extract_answer(solution, "math", use_last_number=True)

     if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
         retval = 0
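Note: after this change the reference solution is normalized the same way as the model answer before the math_equal comparison. A minimal sketch of the intent, assuming extract_answer returns the final boxed expression and, with use_last_number=True, falls back to the last number in the text; the strings below are illustrative only.

solution = "We add the two numbers, so the result is 4."
answer = "The sum is \\boxed{4}."
# Previously the raw solution string was compared against the extracted answer,
# roughly math_equal("4", "We add the two numbers, so the result is 4.") -> False.
# Now both sides are reduced to a comparable form first,
# roughly math_equal("4", "4") -> True.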
@@ -24,7 +24,7 @@ def math_verify(
     for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)):
         base_query_id = query_id.split("@idx:")[0]
         info = id2info[base_query_id]
-        for cur_solution in info["answers"]:
+        for cur_solution in info["solutions"]:
             parameters.append((generated, cur_solution, idx))
             query_indices.append(idx)
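Note: the lookup key now matches the "solutions" field enforced by check_math_metadata_entries, and query ids may carry an "@idx:<n>" suffix that is stripped before the lookup. A small illustrative sketch; the ids and values are hypothetical.

id2info = {"q-42": {"solutions": ["4", "2+2"]}}
query_id = "q-42@idx:3"
base_query_id = query_id.split("@idx:")[0]  # -> "q-42"
# every reference solution of the base query is checked against the generation
for cur_solution in id2info[base_query_id]["solutions"]:
    print(cur_solution)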
@@ -302,6 +302,7 @@ class PPOMATHConfig(CommonExperimentConfig):
         rw_interface = ModelInterfaceAbstraction(
             "reward",
             args=dict(
+                dataset_path=self.dataset.path,
                 tokenizer_path=self.actor.path,
                 output_scaling=self.ppo.reward_output_scaling,
                 output_bias=self.ppo.reward_output_bias,
@@ -336,7 +337,7 @@ class PPOMATHConfig(CommonExperimentConfig):
             model_type=self.actor.type,
             model_path=self.actor.path,
             interface_impl=actor_interface,
-            input_keys=["packed_prompts"],
+            input_keys=["packed_prompts", "task_ids"],
             output_keys=rollout_output_keys,
             n_seqs=self.dataset.train_bs_n_seqs,
         )
@@ -361,20 +362,20 @@ class PPOMATHConfig(CommonExperimentConfig):
             interface_type=ModelInterfaceType.INFERENCE,
             interface_impl=rw_interface,
             min_n_seqs_per_pass=1 / self.group_size,
-            input_keys=["packed_input_ids", "packed_prompts"],
+            input_keys=["packed_input_ids", "packed_prompts", "task_ids"],
             output_keys=["rewards"],
             n_seqs=self.dataset.train_bs_n_seqs,
         )

         # add rew param into ref MFC
         inf_ref_inputs = ["packed_input_ids"]
-        inf_ref_outputs = ["logprobs"]
+        inf_ref_outputs = ["packed_ref_logprobs"]
         if self.ppo.fuse_rew_ref:
-            inf_ref_inputs += ["packed_prompts"]
+            inf_ref_inputs += ["packed_prompts", "task_ids"]
             inf_ref_outputs += ["rewards"]

         inf_ref_logits = MFCDef(
-            name="ref_rw",
+            name="ref_inf",
             model_name="ref",
             mb_spec=self.ref_inf.mb_spec,
             interface_type=ModelInterfaceType.INFERENCE,
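Note: the multi-task reward has to know whether each prompt is a math or a code problem, so "task_ids" is threaded through the rollout, reward, and (when fuse_rew_ref is set) reference MFC inputs alongside "packed_prompts". A minimal sketch of the kind of per-sequence encoding this implies; the integer mapping below is an assumption for illustration, not the repository's actual definition.

# Hypothetical per-prompt task encoding carried under the "task_ids" key.
TASK_TO_ID = {"math": 0, "code": 1}

def encode_task_ids(tasks):
    """Map a list of per-prompt task names to integer ids."""
    return [TASK_TO_ID[t] for t in tasks]

print(encode_task_ids(["math", "code", "math"]))  # [0, 1, 0]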
@@ -495,7 +496,6 @@ class PPOMATHConfig(CommonExperimentConfig):
             args=dict(
                 dataset_path=self.dataset.path,
                 max_length=self.dataset.max_prompt_len,
-                fill_to_max_length=self.dataset.fill_to_max_length,
                 filter_threshold=self.dataset_filter_threshold,
                 max_filter_percentage=self.dataset_max_filter_percentage,
             ),
@@ -3,6 +3,8 @@
 # Licensed under the Apache License, Version 2.0 (the "License").

 import json
+import traceback
+from collections import defaultdict
 from typing import Callable, Dict, Hashable, List, Optional

 import numpy as np
@@ -13,17 +15,34 @@ from realhf.base import logging

 logger = logging.getLogger("Math Code Dataset")

-id2info = {}
+
+def check_math_metadata_entries(data):
+    assert data["task"] == "math"
+    assert "query_id" in data
+    data["query_id"] = str(data["query_id"])
+    assert isinstance(data["prompt"], str)
+    assert isinstance(data["solutions"], list)
+    for sol in data["solutions"]:
+        assert isinstance(sol, str)
+    return data


 def check_code_metadata_entries(data):
-    # TODO: check test multi task reward
-    pass
-
-
-def check_math_metadata_entries(data):
-    # TODO: check test multi task reward
-    pass
+    assert data["task"] == "code"
+    assert "query_id" in data
+    data["query_id"] = str(data["query_id"])
+    if "problem_id" not in data:
+        data["problem_id"] = data["query_id"]
+    assert isinstance(data["prompt"], str)
+    input_output = json.loads(data["input_output"])
+    assert len(input_output["inputs"]) == len(input_output["outputs"])
+    for inp, out in zip(input_output["inputs"], input_output["outputs"]):
+        assert isinstance(inp, str) and isinstance(out, str), (
+            inp,
+            out,
+            input_output.get("fn_name"),
+        )
+    return data


 def load_metadata(path):
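Note: the two validators are now real checks rather than TODO stubs, and they normalize records in place (query_id is cast to str, and a code record's problem_id defaults to its query_id). A usage sketch with made-up records, assuming the two functions above are importable.

import json

math_d = check_math_metadata_entries(
    {"task": "math", "query_id": 7, "prompt": "Compute 2+2.", "solutions": ["4"]}
)
print(math_d["query_id"])  # "7" (stringified)

code_d = check_code_metadata_entries(
    {
        "task": "code",
        "query_id": "c-1",
        "prompt": "Echo the input line.",
        "input_output": json.dumps({"inputs": ["the\n"], "outputs": ["the\n"]}),
    }
)
print(code_d["problem_id"])  # "c-1" (filled from query_id)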
@@ -31,14 +50,26 @@ def load_metadata(path):
     with open(path, "r") as f:
         data = [json.loads(l) for l in f.readlines()]
     id2info = {}
+    omit_cnt = defaultdict(int)
+    task_cnt = defaultdict(int)
     for d in data:
         assert d["query_id"] not in d, (d["task"], d["query_id"])
-        if d["task"] == "math":
-            check_math_metadata_entries(d)
-        elif d["task"] == "code":
-            check_code_metadata_entries(d)
+        try:
+            if d["task"] == "math":
+                d = check_math_metadata_entries(d)
+            elif d["task"] == "code":
+                d = check_code_metadata_entries(d)
+        except Exception as e:
+            logger.warning(
+                f"Data validation failed: query_id {d['query_id']}. "
+                f"Error: {traceback.format_exc()}. Omit it in the dataset."
+            )
+            omit_cnt[d["task"]] += 1
+            continue
         id2info[d["query_id"]] = d
-    return id2info
+        task_cnt[d["task"]] += 1
+    logger.warning(f"Number of ignored data: {dict(**omit_cnt)}")
+    return id2info, dict(task_cnt)


 class MATHCodePromptDataset(torch.utils.data.Dataset):
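Note: load_metadata now returns a pair, so both call sites in this commit unpack it; the prompt dataset uses the per-task counts for logging, while the reward interface discards them. Illustrative usage, with a placeholder path and example counts.

# In MATHCodePromptDataset.__init__:
id2info, task_cnt = load_metadata("/path/to/code_math.jsonl")
print(task_cnt)  # e.g. {"math": 40000, "code": 10000}

# In MultiTaskRewardInterface.__post_init__:
id2info, _ = load_metadata("/path/to/code_math.jsonl")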
@@ -59,8 +90,7 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
         self._util = util
         self.max_length = max_length

-        global id2info
-        id2info = load_metadata(dataset_path)
+        id2info, task_cnt = load_metadata(dataset_path)

         data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)

@@ -83,7 +113,9 @@ class MATHCodePromptDataset(torch.utils.data.Dataset):
         indices = [
             i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
         ]
-        logger.info(f"{len(indices)} samples remain")
+        logger.info(
+            f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data"
+        )

         self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
         self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
@@ -183,4 +215,5 @@ else:
     )
     print(f"size: {len(dataset)}")
     for d in dataloader:
-        print(d.ids)
+        # print(d.ids)
+        pass
@@ -190,7 +190,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):

     def __post_init__(self):
         global id2info
-        id2info = load_metadata(self.dataset_path)
+        id2info, _ = load_metadata(self.dataset_path)
         self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
         if constants.parallelism_rank() == 0:
             logger.info(f"output_scaling: {self.output_scaling}")
@@ -2,6 +2,7 @@

 import os
 import shutil
+import uuid
 from typing import *

 import pytest
@@ -34,14 +35,18 @@ def math_code_dataset(request, save_path):
     dataset = []
     for i in range(size):
         prompt_len = random.randint(1, max_prompt_len)
-        n_pairs = random.randint(1, 5)
         d = dict(
-            query_id=query_ids[i],
+            query_id=str(uuid.uuid4()),
             prompt=generate_random_sentence(prompt_len),
+            task=random.choice(["math", "code"]),
         )
+        if d["task"] == "math":
+            d["solutions"] = [generate_random_sentence(max_resp_len)]
+        elif d["task"] == "code":
+            d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"]))
         dataset.append(d)
-    with open(str(save_path / "math_code_dataset.json"), "w") as f:
-        json.dump(dataset, f)
+        with open(str(save_path / "math_code_dataset.jsonl"), "a") as f:
+            f.write(json.dumps(d) + "\n")
     return dataset
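Note: the test fixture now appends one JSON object per line to math_code_dataset.jsonl, mirroring the production metadata format validated above. A representative line as a Python dict; all values are invented.

example_line = {
    "query_id": "0f9d2e5c-1d2a-4b3c-8e7f-0a1b2c3d4e5f",  # from uuid.uuid4()
    "prompt": "lorem ipsum dolor",
    "task": "code",
    "input_output": "{\"inputs\": [\"the\\n\"], \"outputs\": [\"the\\n\"]}",
}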
@@ -57,7 +62,7 @@ def math_code_dataset(request, save_path):
 def test_ppo_symm(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -96,7 +101,7 @@ def test_ppo_symm(
             backend="mock_train",
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=minbs,
             fill_to_max_length=False,
|
||||||
def test_ppo_global_reshard(
|
def test_ppo_global_reshard(
|
||||||
tmp_path_factory,
|
tmp_path_factory,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
math_dataset,
|
math_code_dataset,
|
||||||
save_path,
|
save_path,
|
||||||
cpu_hf_model,
|
cpu_hf_model,
|
||||||
mconfig,
|
mconfig,
|
||||||
|
@@ -174,7 +179,7 @@ def test_ppo_global_reshard(
             init_from_scratch=True,
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=minbs,
             fill_to_max_length=False,
@@ -244,7 +249,7 @@ def test_ppo_global_reshard(
 def test_ppo_param_realloc_sub_device_mesh(
     tmp_path_factory,
     tokenizer,
-    math_dataset,
+    math_code_dataset,
     save_path,
     cpu_hf_model,
     mconfig,
@@ -287,7 +292,7 @@ def test_ppo_param_realloc_sub_device_mesh(
             init_from_scratch=True,
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=minbs,
             fill_to_max_length=False,
@@ -406,7 +411,7 @@ def test_ppo_save(
             init_from_scratch=True,
         ),
         dataset=PromptOnlyDatasetConfig(
-            path=str(save_path / "math_dataset.json"),
+            path=str(save_path / "math_code_dataset.jsonl"),
             max_prompt_len=mconfig.n_positions // 2,
             train_bs_n_seqs=bs,
             fill_to_max_length=False,