mirror of https://github.com/inclusionAI/AReaL
parent d1554585a4
commit 9dcdb7a684
@@ -2,14 +2,16 @@ import json
 import random

 data = []
-with open("/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl", "r") as f:
+with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f:
     code_data = [json.loads(l) for l in f.readlines()]

 original_keys = list(code_data[0].keys())
 print(original_keys)
 for d in code_data:
-    print(d["starter_code"], type(d["starter_code"]))
+    # print(d["starter_code"], type(d["starter_code"]))
     # print(json.loads(d["solutions"])[0])
+    inout = json.loads(d["input_output"])
+    print(dict(inputs=inout["inputs"][:2], outputs=inout["outputs"][:2]))
     exit(0)
     d["query_id"] = d["id"]
     d["prompt"] = d["question"]
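For orientation, a minimal sketch of how one of these APPS-style records could be inspected, assuming each JSONL line carries the fields referenced in the script above (`id`, `question`, `starter_code`, `solutions`, `input_output`); the local path and any field not shown in the diff are illustrative, not guaranteed.

import json

# Hypothetical local path; the real dataset lives under /storage/openpsi/data/code/apps/.
path = "test.jsonl"

with open(path, "r") as f:
    records = [json.loads(line) for line in f]

sample = records[0]
print(sorted(sample.keys()))                  # fields used by the script above
print(type(sample["starter_code"]))           # typically a (possibly empty) string
inout = json.loads(sample["input_output"])    # JSON-encoded dict with "inputs"/"outputs" lists
print(len(inout["inputs"]), len(inout["outputs"]))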
@@ -18,12 +18,12 @@ def capture_stdout(code):
     fake_stdout = StringIO()

     try:
-        sys.stdout = fake_stdout  # redirect stdout
-        exec(code, {"__builtins__": __builtins__})  # execute in an isolated namespace
+        sys.stdout = fake_stdout
+        exec(code, {"__builtins__": __builtins__})
     except Exception as e:
         return f"error: {str(e)}, traceback: {traceback.format_exc()}"
     finally:
-        sys.stdout = original_stdout  # restore the original stdout
+        sys.stdout = original_stdout
     return fake_stdout.getvalue()
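For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the same stdout-capture idea using `contextlib.redirect_stdout`, which restores `sys.stdout` automatically. This is an illustration, not the project's implementation.

import contextlib
import io
import traceback

def capture_stdout_sketch(code: str) -> str:
    """Run `code` and return whatever it printed, or an error string."""
    buf = io.StringIO()
    try:
        with contextlib.redirect_stdout(buf):  # sys.stdout is restored on exit
            exec(code, {"__builtins__": __builtins__})
    except Exception as e:
        return f"error: {e}, traceback: {traceback.format_exc()}"
    return buf.getvalue()

print(capture_stdout_sketch("print(1 + 1)"))  # -> "2\n"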
@@ -33,12 +33,17 @@ def _temp_run(problem, generation, debug, result):
     try:
         if debug:
             logger.debug(f"Running test for problem: {problem}")
-        result.append(run_test(sample=problem, test=generation, debug=debug))
+        r = run_test(sample=problem, test=generation, debug=debug)
+        result.append(r)
         if debug:
             logger.debug(f"Test completed with result: {result}")
     except Exception as e:
-        if debug:
-            logger.error(f"Error in _temp_run: {e}, problem:{problem}")
+        logger.warning(
+            f"Error in _temp_run: {e}\n"
+            f"traceback: {''.join(traceback.format_exception(*sys.exc_info()))}\n"
+            f"problem:{problem}"
+        )

     execution_time = time.time() - start_time
     logger.info(
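`_temp_run` appends its outcome to a shared list so a parent process can enforce a wall-clock limit on the whole test run. A minimal sketch of that pattern, assuming nothing about the project's actual `check_correctness` beyond what the diff shows; all names below are illustrative.

import multiprocessing

def _worker(x, result):
    # Stand-in for run_test(sample=..., test=..., debug=...).
    result.append(x * x)

def run_with_timeout(x, timeout=6.0):
    manager = multiprocessing.Manager()
    result = manager.list()  # shared list the child process appends into
    p = multiprocessing.Process(target=_worker, args=(x, result))
    p.start()
    p.join(timeout)          # global timeout for the whole child process
    if p.is_alive():
        p.kill()             # child hung: kill it and fall back to a default result
        return None
    return result[0] if result else None

if __name__ == "__main__":
    print(run_with_timeout(3))  # -> 9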
@@ -51,6 +56,7 @@ def check_correctness(problem, generation, timeout, debug=False):
     The global timeout is to catch some extreme/rare cases not handled by the timeouts
     inside `run_test`"""
     if debug:
+        # FIXME: "problem" and "generation" are not defined in the namespace passed to exec
         result = capture_stdout(
             "from functioncall.code.function.testing_util import run_test\n"
             + "run_test(sample=problem, test=generation, debug=debug)"
@@ -73,7 +79,7 @@ def check_correctness(problem, generation, timeout, debug=False):
         # Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
         # so we use 21 = the average number of tests for a sample in the test split instead
         avg_number_tests = 21
-        result = [[-1, ""] for _ in range(avg_number_tests)]
+        result = [[-1 for _ in range(avg_number_tests)], {}]
         if debug:
             logger.debug(f"Global timeout occurred, returning default result.")
     if debug:
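The new default mirrors the shape `run_test` presumably returns on success: a list of per-test outcomes plus a metadata dict, instead of 21 separate `[-1, ""]` pairs. A hedged sketch of how a caller might score either shape; the scoring rule and the meaning of the sentinel values are illustrative, following common APPS-style conventions rather than anything shown in this diff.

def all_tests_passed(result) -> bool:
    """Illustrative scoring: element 0 holds per-test outcomes, where True/1
    means pass and negative sentinels mean timeout or crash."""
    per_test = result[0]
    _metadata = result[1] if len(result) > 1 else {}
    return all(r is True or r == 1 for r in per_test)

# Default returned after a global timeout: every test marked -1, empty metadata.
avg_number_tests = 21
default = [[-1 for _ in range(avg_number_tests)], {}]
print(all_tests_passed(default))  # -> False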
@@ -86,7 +92,7 @@ def check_correctness(problem, generation, timeout, debug=False):
     return result[0]


-def code_verify(id2info, generateds, query_ids, debug=True):
+def code_verify(id2info, generateds, query_ids, debug=False):
     assert len(generateds) == len(query_ids)
     problems = [id2info[qid] for qid in query_ids]
@@ -117,19 +123,25 @@ def code_verify(id2info, generateds, query_ids, debug=True):


 if __name__ == "__main__":
-    path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
+    path = "/storage/openpsi/data/code/apps/test.jsonl"
     data = []
     with open(path, "r") as f:
         code_data = [json.loads(l) for l in f.readlines()]

-    problem = code_data[0]
-    problem["problem_id"] = problem["id"]
-    id2info = {problem["problem_id"]: problem}
+    id2info = {}
+    solutions = []
+    query_ids = []
+    for i in range(10):
+        problem = code_data[i]
+        problem["problem_id"] = problem["id"]
+        id2info[problem["problem_id"]] = problem
+        solutions.append(json.loads(problem["solutions"])[0])
+        query_ids.append(problem["id"])

     result = code_verify(
         id2info,
-        [json.loads(problem["solutions"])[0]],
-        [problem["problem_id"]],
+        solutions,
+        query_ids,
         debug=False,
     )
     print(result)
@@ -11,7 +11,11 @@ def process_results(answer, solution):

     if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
         retval = 0
-    elif extracted_solution is None or extracted_solution.strip() in ["None", "none", ""]:
+    elif extracted_solution is None or extracted_solution.strip() in [
+        "None",
+        "none",
+        "",
+    ]:
         retval = 0
     elif math_equal(extracted_answer, extracted_solution, timeout=True):
         retval = 1
@@ -18,7 +18,11 @@ def process_results(answer, solution):
         # raise
     if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
         retval = 0
-    elif extracted_solution is None or extracted_solution.strip() in ["None", "none", ""]:
+    elif extracted_solution is None or extracted_solution.strip() in [
+        "None",
+        "none",
+        "",
+    ]:
         retval = 0
     elif math_equal(extracted_answer, extracted_solution, timeout=False):
         # elif call_with_timeout(math_equal, extracted_answer, extracted_solution):
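Both hunks keep the same guard order: a missing or empty extraction fails immediately, before the potentially slow equivalence check runs. A minimal, self-contained sketch of that shape, with a plain string comparison standing in for `math_equal`, which is not shown in this diff.

def grade_sketch(extracted_answer, extracted_solution) -> int:
    """Return 1 if the extracted answer matches the reference, else 0."""
    def empty(s):
        return s is None or s.strip() in ("None", "none", "")
    if empty(extracted_answer) or empty(extracted_solution):
        return 0
    # Stand-in for math_equal(...): the real check is symbolic equivalence,
    # possibly wrapped in a timeout guard.
    return int(extracted_answer.strip() == extracted_solution.strip())

print(grade_sketch("-\\frac{2}{3}", "-\\frac{2}{3}"))  # -> 1
print(grade_sketch(None, "-\\frac{2}{3}"))             # -> 0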
@@ -17,10 +17,12 @@ id2info = {}


 def check_code_metadata_entries(data):
+    # TODO: check test multi task reward
     pass


 def check_math_metadata_entries(data):
+    # TODO: check test multi task reward
     pass

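The TODO stubs suggest these helpers will eventually validate the metadata loaded for each task. Purely as an illustration of what such checks might assert, based only on the key names that appear elsewhere in this diff; none of this is the repository's actual implementation.

import json

def check_math_metadata_entries_sketch(data: dict) -> None:
    # Assumed check: math entries elsewhere in this diff carry non-empty
    # "answers"/"solutions" lists.
    assert isinstance(data.get("solutions"), list) and data["solutions"], data

def check_code_metadata_entries_sketch(data: dict) -> None:
    # Assumed check: code entries carry a JSON-encoded "input_output" dict
    # with parallel "inputs"/"outputs" lists.
    inout = json.loads(data["input_output"])
    assert len(inout["inputs"]) == len(inout["outputs"]), data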
@@ -29,7 +29,7 @@ from realhf.api.core.data_api import (
 )
 from realhf.base import constants
 from realhf.base.datapack import flat2d
-from realhf.impl.dataset.math_code_dataset import id2info
+from realhf.impl.dataset.math_code_dataset import load_metadata
 from realhf.impl.dataset.math_parser import parse_lines_in_parallel as math_verify_local

 logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
@@ -121,6 +123,9 @@ def check_with_elementtree(text):
         return False, f"Error: invalid XML format, {str(e)}"


+id2info = {}
+
+
 def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
     global id2info
     assert len(answers) == len(query_id_strs)
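`dispatch_reward_calculation` routes a batch of answers to the verifier for one task, looking problems up in the module-level `id2info`. A minimal sketch of that routing idea with plain callables; the verifier names and signatures here are illustrative stand-ins, not the repository's.

from typing import Callable, Dict, List

def math_verify_stub(answers: List[str], qids: List[str]) -> List[int]:
    return [1 for _ in answers]   # pretend every math answer is correct

def code_verify_stub(answers: List[str], qids: List[str]) -> List[int]:
    return [0 for _ in answers]   # pretend every code answer fails

VERIFIERS: Dict[str, Callable[[List[str], List[str]], List[int]]] = {
    "math": math_verify_stub,
    "code": code_verify_stub,
}

def dispatch_sketch(task: str, answers: List[str], query_id_strs: List[str]) -> List[int]:
    assert len(answers) == len(query_id_strs)
    return VERIFIERS[task](answers, query_id_strs)

print(dispatch_sketch("math", ["\\boxed{1}"], ["qid-0"]))  # -> [1]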
@@ -176,6 +179,7 @@ def retokenize_and_verify(

 @dataclasses.dataclass
 class MultiTaskRewardInterface(model_api.ModelInterface):
+    dataset_path: str = ""
     tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
     output_scaling: float = 1.0
     output_bias: float = 0.0
@@ -185,6 +189,8 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
     check_verifier_status: bool = False

     def __post_init__(self):
+        global id2info
+        id2info = load_metadata(self.dataset_path)
         self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
         if constants.parallelism_rank() == 0:
             logger.info(f"output_scaling: {self.output_scaling}")
@@ -197,7 +203,6 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
         dispatched_indices = {}
         for task_idx, task_name in enumerate(RL_TASKS):
             indices = (data.data["task_ids"] == task_idx).numpy().nonzero()[0].tolist()
-            print("============", task_idx, task_name, indices)
             if len(indices) > 0:
                 dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
                 dispatched_indices[task_name] = indices
@@ -209,7 +214,6 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
     ) -> SequenceSample:
         xs = [None for _ in range(bs)]
         for task_name, indices in dispatched_indices.items():
-            print(task_name, indices)
             xxs = results[task_name].unpack()
             assert len(indices) == len(xxs), (len(indices), len(xxs))
             for i, xx in zip(indices, xxs):
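Together, the two methods above split a packed batch by task id and later stitch the per-task results back into the original batch order. A small sketch of that split/regather pattern on plain Python lists, with no `SequenceSample`; the task list and reward values are illustrative.

RL_TASKS = ["math", "code", "rlhf"]  # illustrative task order

def split_by_task(task_ids):
    """Map task name -> positions of that task in the batch."""
    dispatched_indices = {}
    for task_idx, task_name in enumerate(RL_TASKS):
        indices = [i for i, t in enumerate(task_ids) if t == task_idx]
        if indices:
            dispatched_indices[task_name] = indices
    return dispatched_indices

def regather(results, dispatched_indices, bs):
    """Place per-task results back at their original batch positions."""
    out = [None] * bs
    for task_name, indices in dispatched_indices.items():
        per_task = results[task_name]
        assert len(indices) == len(per_task)
        for i, r in zip(indices, per_task):
            out[i] = r
    return out

task_ids = [0, 1, 0, 1]                       # math, code, math, code
idx = split_by_task(task_ids)                 # {'math': [0, 2], 'code': [1, 3]}
rewards = {"math": [1.0, 0.0], "code": [0.5, 0.5]}
print(regather(rewards, idx, bs=len(task_ids)))  # -> [1.0, 0.5, 0.0, 0.5]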
@@ -361,8 +365,8 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
                 + "\n"
             )

-        logger.info(f"number of samples: {len(scores)}, {scores.shape}")
-        logger.info(f"reward: {sum(scores) / len(scores)}")
+        logger.info(f"[{task_type}] number of samples: {len(scores)}, {scores.shape}")
+        logger.info(f"[{task_type}] avg reward: {sum(scores) / len(scores)}")

     def inference(
         self,
@@ -371,8 +375,6 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
         mb_spec: MicroBatchSpec,
     ) -> SequenceSample | None:
         task_data, dispatch_indices = self._dispatch_tasks(data)
-        for d in task_data.values():
-            print(d.bs)

         assert self.rw_type == "sparse"

@@ -423,7 +425,6 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
         model: model_api.Model,
         data: SequenceSample,
     ) -> SequenceSample:
-        from realhf.base.testing import TESTING_MODEL_VOCAB_SIZE

         prompt_lens = flat2d(data.seqlens["packed_prompts"])
         task_ids = data.data["task_ids"].cpu().numpy().tolist()
@@ -435,11 +436,11 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
             offset += plen
             if task_id == RL_TASKS.index("math"):
                 answer_str = (
-                    "something unimportant but the answer is \\boxed{-\\frac{2}{3}}"
+                    "something unimportant but the answer is \\boxed{-\\frac{2}{3}}."
                 )
             elif task_id == RL_TASKS.index("code"):
                 answer_str = (
-                    "```python\ninput()\nprint(1)\nprint(1)\nprint(1)\nprint(1)\n```"
+                    "```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"
                 )
             else:
                 answer_str = "something unimportant"
@@ -35,10 +35,10 @@ def math_code_dataset(request, save_path):
                 prompt=generate_random_sentence(prompt_len),
                 problem_id=str(uuid.uuid4()),
                 input_output=json.dumps(
-                    {"inputs": [1] * 8, "outputs": ["1\n1\n1\n1\n"] * 8}
+                    {"inputs": ["1\n"] * 8, "outputs": ["1\n"] * 8}
                 ),
                 solutions=json.dumps(
-                    ["```python\ninput()\nprint(1)\nprint(1)\nprint(1)\nprint(1)\n```"]
+                    ["```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"]
                     * 3
                 ),
                 difficulty=random.random() * 10,
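The fixture change makes the mock code problem self-consistent: the solution reads one line from stdin and prints `1`, so inputs must be newline-terminated strings and each expected output a single `1\n`. A quick way to sanity-check such a pair outside the test harness; this is a standalone sketch, not the repository's judge.

import json
import subprocess
import sys

input_output = json.dumps({"inputs": ["1\n"] * 8, "outputs": ["1\n"] * 8})
solution = "input()\nimport time\ntime.sleep(1e-3)\nprint(1)\n"

inout = json.loads(input_output)
for stdin_data, expected in zip(inout["inputs"], inout["outputs"]):
    proc = subprocess.run(
        [sys.executable, "-c", solution],
        input=stdin_data, capture_output=True, text=True, timeout=5,
    )
    assert proc.stdout == expected, (proc.stdout, expected)
print("all mock tests pass")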
@@ -48,8 +48,8 @@ def math_code_dataset(request, save_path):
                 task="math",
                 query_id=str(uuid.uuid4()),
                 prompt=generate_random_sentence(prompt_len),
-                answers=["-\\frac{2}{3}"],
-                solutions=["-\\frac{2}{3}"],
+                answers=["\\boxed{-\\frac{2}{3}}"],
+                solutions=["\\boxed{-\\frac{2}{3}}"],
             )
         dataset.append(d)
     with open(str(save_path / "math_code_dataset.jsonl"), "w") as f:
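Wrapping the reference solutions in `\boxed{...}` matches the boxed form used by the mock math response in the interface hunk above, so both sides parse to the same string. A minimal sketch of boxed-answer extraction with a simple brace matcher; this is illustrative and not the project's `math_parser`.

def extract_boxed(text: str) -> str | None:
    """Return the contents of the last \\boxed{...} in `text`, handling nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i, depth = start + len("\\boxed{"), 1
    out = []
    while i < len(text) and depth:
        c = text[i]
        depth += (c == "{") - (c == "}")
        if depth:
            out.append(c)
        i += 1
    return "".join(out) if depth == 0 else None

print(extract_boxed("the answer is \\boxed{-\\frac{2}{3}}."))  # -> -\frac{2}{3}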
@@ -90,6 +90,7 @@ def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_datase

     with constants.model_scope(testing.MODEL_NAME):
         interface = MultiTaskRewardInterface(
+            dataset_path=str(save_path / "math_code_dataset.jsonl"),
             tokenizer_path=tokenizer_path,
             group_size=1,
             check_verifier_status=False,
@@ -108,4 +109,5 @@ def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_datase
         d = interface.mock("inference", model, d)
         rewards = interface.inference(model, d, mb_spec=MicroBatchSpec())
         d.update_(rewards)
         print("success")
+        assert rewards.data["rewards"].all(), rewards.data["rewards"]
         dist.destroy_process_group()