From 25c45c7e836340f6719bca7b040ba11bb1aee109 Mon Sep 17 00:00:00 2001 From: "bowei.fw" Date: Sat, 22 Mar 2025 12:22:34 +0800 Subject: [PATCH] . --- create_multitask_data.py | 32 ---------- functioncall/math/function/handler.py | 2 +- functioncall/math/verify.py | 2 +- realhf/experiments/common/ppo_math_exp.py | 12 ++-- realhf/impl/dataset/math_code_dataset.py | 67 +++++++++++++++------ realhf/impl/model/interface/rw_interface.py | 2 +- tests/experiments/test_math_ppo.py | 27 +++++---- 7 files changed, 75 insertions(+), 69 deletions(-) delete mode 100644 create_multitask_data.py diff --git a/create_multitask_data.py b/create_multitask_data.py deleted file mode 100644 index 831b9d9..0000000 --- a/create_multitask_data.py +++ /dev/null @@ -1,32 +0,0 @@ -import json -import random - -data = [] -with open("/storage/openpsi/data/code/apps/test.jsonl", "r") as f: - code_data = [json.loads(l) for l in f.readlines()] - -original_keys = list(code_data[0].keys()) -print(original_keys) -for d in code_data: - # print(d["starter_code"], type(d["starter_code"])) - # print(json.loads(d["solutions"])[0]) - inout = json.loads(d["input_output"]) - print(dict(inputs=inout["inputs"][:2], outputs=inout["outputs"][:2])) - exit(0) - d["query_id"] = d["id"] - d["prompt"] = d["question"] - d["task"] = "code" - for k in original_keys: - d.pop(k) - data.append(d) - -with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f: - math_data = [json.loads(l) for l in f.readlines()] - -for d in math_data: - data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"])) - -random.shuffle(data) -with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f: - for d in data: - f.write(json.dumps(d, ensure_ascii=False) + "\n") diff --git a/functioncall/math/function/handler.py b/functioncall/math/function/handler.py index 2d2e45e..263e0b7 100644 --- a/functioncall/math/function/handler.py +++ b/functioncall/math/function/handler.py @@ -7,7 +7,7 @@ from grader import math_equal def process_results(answer, solution): extracted_answer = extract_answer(answer, "math", use_last_number=False) - extracted_solution = solution + extracted_solution = extract_answer(solution, "math", use_last_number=True) if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]: retval = 0 diff --git a/functioncall/math/verify.py b/functioncall/math/verify.py index 1218432..2eff2c9 100644 --- a/functioncall/math/verify.py +++ b/functioncall/math/verify.py @@ -24,7 +24,7 @@ def math_verify( for idx, (query_id, generated) in enumerate(zip(query_ids, generateds)): base_query_id = query_id.split("@idx:")[0] info = id2info[base_query_id] - for cur_solution in info["answers"]: + for cur_solution in info["solutions"]: parameters.append((generated, cur_solution, idx)) query_indices.append(idx) diff --git a/realhf/experiments/common/ppo_math_exp.py b/realhf/experiments/common/ppo_math_exp.py index 748b8bb..171c0a2 100644 --- a/realhf/experiments/common/ppo_math_exp.py +++ b/realhf/experiments/common/ppo_math_exp.py @@ -302,6 +302,7 @@ class PPOMATHConfig(CommonExperimentConfig): rw_interface = ModelInterfaceAbstraction( "reward", args=dict( + dataset_path=self.dataset.path, tokenizer_path=self.actor.path, output_scaling=self.ppo.reward_output_scaling, output_bias=self.ppo.reward_output_bias, @@ -336,7 +337,7 @@ class PPOMATHConfig(CommonExperimentConfig): model_type=self.actor.type, model_path=self.actor.path, interface_impl=actor_interface, - input_keys=["packed_prompts"], + 
input_keys=["packed_prompts", "task_ids"], output_keys=rollout_output_keys, n_seqs=self.dataset.train_bs_n_seqs, ) @@ -361,20 +362,20 @@ class PPOMATHConfig(CommonExperimentConfig): interface_type=ModelInterfaceType.INFERENCE, interface_impl=rw_interface, min_n_seqs_per_pass=1 / self.group_size, - input_keys=["packed_input_ids", "packed_prompts"], + input_keys=["packed_input_ids", "packed_prompts", "task_ids"], output_keys=["rewards"], n_seqs=self.dataset.train_bs_n_seqs, ) # add rew param into ref MFC inf_ref_inputs = ["packed_input_ids"] - inf_ref_outputs = ["logprobs"] + inf_ref_outputs = ["packed_ref_logprobs"] if self.ppo.fuse_rew_ref: - inf_ref_inputs += ["packed_prompts"] + inf_ref_inputs += ["packed_prompts", "task_ids"] inf_ref_outputs += ["rewards"] inf_ref_logits = MFCDef( - name="ref_rw", + name="ref_inf", model_name="ref", mb_spec=self.ref_inf.mb_spec, interface_type=ModelInterfaceType.INFERENCE, @@ -495,7 +496,6 @@ class PPOMATHConfig(CommonExperimentConfig): args=dict( dataset_path=self.dataset.path, max_length=self.dataset.max_prompt_len, - fill_to_max_length=self.dataset.fill_to_max_length, filter_threshold=self.dataset_filter_threshold, max_filter_percentage=self.dataset_max_filter_percentage, ), diff --git a/realhf/impl/dataset/math_code_dataset.py b/realhf/impl/dataset/math_code_dataset.py index 95dc06d..e829ce6 100644 --- a/realhf/impl/dataset/math_code_dataset.py +++ b/realhf/impl/dataset/math_code_dataset.py @@ -3,6 +3,8 @@ # Licensed under the Apache License, Version 2.0 (the "License"). import json +import traceback +from collections import defaultdict from typing import Callable, Dict, Hashable, List, Optional import numpy as np @@ -13,17 +15,34 @@ from realhf.base import logging logger = logging.getLogger("Math Code Dataset") -id2info = {} + +def check_math_metadata_entries(data): + assert data["task"] == "math" + assert "query_id" in data + data["query_id"] = str(data["query_id"]) + assert isinstance(data["prompt"], str) + assert isinstance(data["solutions"], list) + for sol in data["solutions"]: + assert isinstance(sol, str) + return data def check_code_metadata_entries(data): - # TODO: check test multi task reward - pass - - -def check_math_metadata_entries(data): - # TODO: check test multi task reward - pass + assert data["task"] == "code" + assert "query_id" in data + data["query_id"] = str(data["query_id"]) + if "problem_id" not in data: + data["problem_id"] = data["query_id"] + assert isinstance(data["prompt"], str) + input_output = json.loads(data["input_output"]) + assert len(input_output["inputs"]) == len(input_output["outputs"]) + for inp, out in zip(input_output["inputs"], input_output["outputs"]): + assert isinstance(inp, str) and isinstance(out, str), ( + inp, + out, + input_output.get("fn_name"), + ) + return data def load_metadata(path): @@ -31,14 +50,26 @@ def load_metadata(path): with open(path, "r") as f: data = [json.loads(l) for l in f.readlines()] id2info = {} + omit_cnt = defaultdict(int) + task_cnt = defaultdict(int) for d in data: assert d["query_id"] not in d, (d["task"], d["query_id"]) - if d["task"] == "math": - check_math_metadata_entries(d) - elif d["task"] == "code": - check_code_metadata_entries(d) + try: + if d["task"] == "math": + d = check_math_metadata_entries(d) + elif d["task"] == "code": + d = check_code_metadata_entries(d) + except Exception as e: + logger.warning( + f"Data validation failed: query_id {d['query_id']}. " + f"Error: {traceback.format_exc()}. Omit it in the dataset." 
+ ) + omit_cnt[d["task"]] += 1 + continue id2info[d["query_id"]] = d - return id2info + task_cnt[d["task"]] += 1 + logger.warning(f"Number of ignored data: {dict(**omit_cnt)}") + return id2info, dict(task_cnt) class MATHCodePromptDataset(torch.utils.data.Dataset): @@ -59,8 +90,7 @@ class MATHCodePromptDataset(torch.utils.data.Dataset): self._util = util self.max_length = max_length - global id2info - id2info = load_metadata(dataset_path) + id2info, task_cnt = load_metadata(dataset_path) data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder) @@ -83,7 +113,9 @@ class MATHCodePromptDataset(torch.utils.data.Dataset): indices = [ i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length ] - logger.info(f"{len(indices)} samples remain") + logger.info( + f"{len(indices)} samples remain, among them {task_cnt['math']} are math data and {task_cnt['code']} are code data" + ) self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices] self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices] @@ -183,4 +215,5 @@ else: ) print(f"size: {len(dataset)}") for d in dataloader: - print(d.ids) + # print(d.ids) + pass diff --git a/realhf/impl/model/interface/rw_interface.py b/realhf/impl/model/interface/rw_interface.py index c9140e5..9acca33 100644 --- a/realhf/impl/model/interface/rw_interface.py +++ b/realhf/impl/model/interface/rw_interface.py @@ -190,7 +190,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface): def __post_init__(self): global id2info - id2info = load_metadata(self.dataset_path) + id2info, _ = load_metadata(self.dataset_path) self.tokenizer = load_hf_tokenizer(self.tokenizer_path) if constants.parallelism_rank() == 0: logger.info(f"output_scaling: {self.output_scaling}") diff --git a/tests/experiments/test_math_ppo.py b/tests/experiments/test_math_ppo.py index b556449..7bb89fa 100644 --- a/tests/experiments/test_math_ppo.py +++ b/tests/experiments/test_math_ppo.py @@ -2,6 +2,7 @@ import os import shutil +import uuid from typing import * import pytest @@ -34,14 +35,18 @@ def math_code_dataset(request, save_path): dataset = [] for i in range(size): prompt_len = random.randint(1, max_prompt_len) - n_pairs = random.randint(1, 5) d = dict( - query_id=query_ids[i], + query_id=str(uuid.uuid4()), prompt=generate_random_sentence(prompt_len), + task=random.choice(["math", "code"]), ) + if d["task"] == "math": + d["solutions"] = [generate_random_sentence(max_resp_len)] + elif d["task"] == "code": + d["input_output"] = json.dumps(dict(inputs=["the\n"], outputs=["the\n"])) dataset.append(d) - with open(str(save_path / "math_code_dataset.json"), "w") as f: - json.dump(dataset, f) + with open(str(save_path / "math_code_dataset.jsonl"), "a") as f: + f.write(json.dumps(d) + "\n") return dataset @@ -57,7 +62,7 @@ def math_code_dataset(request, save_path): def test_ppo_symm( tmp_path_factory, tokenizer, - math_dataset, + math_code_dataset, save_path, cpu_hf_model, mconfig, @@ -96,7 +101,7 @@ def test_ppo_symm( backend="mock_train", ), dataset=PromptOnlyDatasetConfig( - path=str(save_path / "math_dataset.json"), + path=str(save_path / "math_code_dataset.jsonl"), max_prompt_len=mconfig.n_positions // 2, train_bs_n_seqs=minbs, fill_to_max_length=False, @@ -127,7 +132,7 @@ def test_ppo_symm( def test_ppo_global_reshard( tmp_path_factory, tokenizer, - math_dataset, + math_code_dataset, save_path, cpu_hf_model, mconfig, @@ -174,7 +179,7 @@ def test_ppo_global_reshard( init_from_scratch=True, ), dataset=PromptOnlyDatasetConfig( - 
path=str(save_path / "math_dataset.json"), + path=str(save_path / "math_code_dataset.jsonl"), max_prompt_len=mconfig.n_positions // 2, train_bs_n_seqs=minbs, fill_to_max_length=False, @@ -244,7 +249,7 @@ def test_ppo_global_reshard( def test_ppo_param_realloc_sub_device_mesh( tmp_path_factory, tokenizer, - math_dataset, + math_code_dataset, save_path, cpu_hf_model, mconfig, @@ -287,7 +292,7 @@ def test_ppo_param_realloc_sub_device_mesh( init_from_scratch=True, ), dataset=PromptOnlyDatasetConfig( - path=str(save_path / "math_dataset.json"), + path=str(save_path / "math_code_dataset.jsonl"), max_prompt_len=mconfig.n_positions // 2, train_bs_n_seqs=minbs, fill_to_max_length=False, @@ -406,7 +411,7 @@ def test_ppo_save( init_from_scratch=True, ), dataset=PromptOnlyDatasetConfig( - path=str(save_path / "math_dataset.json"), + path=str(save_path / "math_code_dataset.jsonl"), max_prompt_len=mconfig.n_positions // 2, train_bs_n_seqs=bs, fill_to_max_length=False,
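
The dataset format implied by the new validators can be summarized with a short sketch. This is not part of the patch: the field names and constraints come from check_math_metadata_entries / check_code_metadata_entries above, while the output path and sample values are purely illustrative.

import json

# Minimal sketch of a JSONL file that the updated load_metadata() would accept.
# One "math" entry and one "code" entry; all values are made up.
entries = [
    {
        "task": "math",
        "query_id": "math-0001",        # coerced to str by the validator
        "prompt": "What is 1 + 1?",
        "solutions": ["\\boxed{2}"],    # list of strings; math_verify() iterates these
    },
    {
        "task": "code",
        "query_id": "code-0001",        # problem_id defaults to query_id if absent
        "prompt": "Read a line and print it back.",
        # JSON string whose inputs/outputs are equal-length lists of strings
        "input_output": json.dumps({"inputs": ["hello\n"], "outputs": ["hello\n"]}),
    },
]

with open("/tmp/code_math_example.jsonl", "w") as f:   # hypothetical path
    for entry in entries:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")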
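
The reward-side change in handler.py means a reference solution no longer has to be a bare final answer: both the generation and the stored solution are reduced to an extracted answer before comparison, and math_verify() now reads info["solutions"]. Below is a self-contained sketch of that comparison using a toy stand-in for extract_answer, not the real grader code.

import re

def extract_answer(text, use_last_number):
    # Toy stand-in: prefer the last \boxed{...}, else optionally the last number.
    boxed = re.findall(r"\\boxed\{([^}]*)\}", text)
    if boxed:
        return boxed[-1]
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    return numbers[-1] if (use_last_number and numbers) else None

generated = "So the result is \\boxed{2}."
solution = "1 + 1 = 2, hence the answer is 2."   # a worked solution, not just "2"

extracted_answer = extract_answer(generated, use_last_number=False)
extracted_solution = extract_answer(solution, use_last_number=True)
print(extracted_answer == extracted_solution)    # True for this toy example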