This commit is contained in:
bowei.fw 2025-03-21 21:22:50 +08:00
parent eb1e8a7592
commit de8243cc78
25 changed files with 547 additions and 1067 deletions

30
create_multitask_data.py Normal file
View File

@ -0,0 +1,30 @@
import json
import random
data = []
with open("/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl", "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
original_keys = list(code_data[0].keys())
print(original_keys)
for d in code_data:
print(d["starter_code"], type(d["starter_code"]))
# print(json.loads(d["solutions"])[0])
exit(0)
d["query_id"] = d["id"]
d["prompt"] = d["question"]
d["task"] = "code"
for k in original_keys:
d.pop(k)
data.append(d)
with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
math_data = [json.loads(l) for l in f.readlines()]
for d in math_data:
data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"]))
random.shuffle(data)
with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
for d in data:
f.write(json.dumps(d, ensure_ascii=False) + "\n")

View File

@ -291,7 +291,6 @@ The descriptions of the important parameters are as follows:
+ `MODE`: It is always `ray`, and do not change it to other values when referring to this tutorial for training.
+ `BASE_MODEL_PATH`: The path of the model.
+ `DATA_PATH`: The path of the dataset jsonl file
+ `REAL_MATH_METADATA_PATH`: Set it to the path of the json file of the math metadata, refer to troubleshooting.
+ `CLUSTER_SPEC_PATH`: Set it to the path of cluster_config.json
+ `n_nodes`: The number of nodes
@ -513,10 +512,5 @@ This error typically occurs during the vLLM generation phase and is another symp
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).
## Others
### How to train with other datasets
The dataset must be in `jsonl` format, with each entry containing two keys: `prompt` (a math problem) and `query_id` (a unique identifier for the problem). After preparing the dataset, update the `REAL_MATH_METADATA_PATH` with the new dataset's metadata. Metadata is a JSON file that records the solutions for each problem. The training code uses this metadata to determine if the model solved a problem correctly.

View File

@ -315,7 +315,6 @@ python3 -m realhf.apps.quickstart ppo-math --show-args
+ MODE总是为 ray参考本教程进行训练时不要改成其他值。
+ BASE_MODEL_PATH模型的路径
+ DATA_PATH数据集 jsonl 文件的路径
+ REAL_MATH_METADATA_PATH设置成数学 metadata 的 json 文件路径参考troubleshooting。
+ CLUSTER_SPEC_PATH设置成 cluster_config.json 的路径
+ n_nodes节点数量
@ -530,8 +529,3 @@ ALL_PARAMS=(
+ 减小训练batch size或者每个prompt生成的答案数量但减小后会降低样本效率、延长训练时间
+ [将vLLM的attention backend换成xformers](https://github.com/vllm-project/vllm/issues/5376)
## 其他
### 如何用其他数据集进行训练
数据集需要是是 jsonl 格式的文件,其中每一条数据需要包含两个 key分别是 prompt即一道数学问题和query_id即这道数学问题的唯一标识符。在准备好数据集后还需要根据数据集中的题目更新REAL_MATH_METADATA_PATH的内容。metadata 是一个 json 文件,记录了每道题目的答案、来源和解法。训练代码需要根据 metadata 来判断模型是否做对了一道题。

View File

@ -18,7 +18,6 @@ BASE_MODEL_PATH="$1"
# original data
DATA_PATH="$2"
REAL_MATH_METADATA_PATH="$3"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -53,7 +52,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -17,7 +17,6 @@ BASE_MODEL_PATH="$1"
# original data
DATA_PATH="$2"
REAL_MATH_METADATA_PATH="$3"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -52,7 +51,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_CODE_METADATA_PATH="/storage/datasets/codeparrot-apps-test.jsonl"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_CODE_METADATA_PATH=${REAL_CODE_METADATA_PATH} \
FUNCTIONCALL_SERVICE_DOMAIN="" \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-code \

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
# original data
DATA_PATH="/storage/datasets/${DATASET_NAME}"
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
# Option 1: The experiment runs locally with subprocesses.
# MODE=local
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
# It's the user's responsibility to tune them appropriately.
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-math \
mode=$MODE \

View File

@ -5,27 +5,14 @@ import sys
import time
import traceback
from io import StringIO
from function.testing_util import run_test
from typing import Dict, List
from functioncall.base import logging
from functioncall.code.function.testing_util import run_test
logger = logging.getLogger("Functioncall")
def try_load_json(data):
"""
Attempts to load the given data as JSON. If successful, returns the parsed JSON.
Otherwise, returns None and logs an error.
"""
try:
loaded_data = json.loads(data)
return loaded_data
except json.JSONDecodeError as e:
logger.info(f"Failed to load JSON: {e}")
return data
def capture_stdout(code):
original_stdout = sys.stdout
fake_stdout = StringIO()
@ -51,7 +38,7 @@ def _temp_run(problem, generation, debug, result):
logger.debug(f"Test completed with result: {result}")
except Exception as e:
if debug:
logger.error(f"Error in _temp_run: {e}, problem:{problem}", exe_info=True)
logger.error(f"Error in _temp_run: {e}, problem:{problem}")
execution_time = time.time() - start_time
logger.info(
@ -65,7 +52,7 @@ def check_correctness(problem, generation, timeout, debug=False):
inside `run_test`"""
if debug:
result = capture_stdout(
"from function.testing_util import run_test\n"
"from functioncall.code.function.testing_util import run_test\n"
+ "run_test(sample=problem, test=generation, debug=debug)"
)
return result[0], result[1]
@ -86,7 +73,7 @@ def check_correctness(problem, generation, timeout, debug=False):
# Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
# so we use 21=the average number of tests for a smaple in the test split instead
avg_number_tests = 21
result = [[-1] * avg_number_tests]
result = [[-1, ""] for _ in range(avg_number_tests)]
if debug:
logger.debug(f"Global timeout occurred, returning default result.")
if debug:
@ -99,96 +86,50 @@ def check_correctness(problem, generation, timeout, debug=False):
return result[0]
def load_problems(path):
problem_map = {}
for idx, line in enumerate(open(path, "rb")):
if line is not None:
line = line.strip().decode("utf-8")
row = json.loads(line)
query_id = str(row["id"])
problem_map[query_id] = {
"problem_id": query_id,
# "question": row["question"],
"solutions": row["solutions"],
"input_output": row["input_output"],
# "difficulty": level,
# "url": row["url"],
# "starter_code": row["starter_code"],
}
return problem_map
global_problems = load_problems(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
)
)
def code_verify(generateds, query_ids, debug=False):
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
)
def code_verify(id2info, generateds, query_ids, debug=True):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]
result = []
global global_problems
for idx, query_id in enumerate(query_ids):
if query_id not in global_problems:
continue
problem = global_problems[query_id]
test_code = generateds[idx]
for query_id, generated, problem in zip(query_ids, generateds, problems):
logger.debug(f"run_batch_code, query_id: {query_id}")
try:
curr_res, metadata = check_correctness(
problem=problem, generation=test_code, timeout=6000, debug=debug
problem=problem, generation=generated, timeout=6000, debug=debug
)
if any(x != True for x in curr_res):
logger.debug(
f'id:{problem["problem_id"]}, Results were not all True: {metadata}'
)
result.append(f"{query_id} failed")
logger.debug(f"id:{query_id}, Results were not all True: {metadata}")
result.append(0)
else:
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
result.append(f"{query_id} success")
result.append(1)
except Exception as e:
logger.error(f"test framework exception = {repr(e)}{e}\n", exe_info=True)
result.append(f"{query_id} failed")
break
finally:
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
# assert isinstance(curr_res, list)
pass
exc_info = sys.exc_info()
logger.error(
f"test framework exception = {repr(e)}{e}\n{traceback.format_exception(*exc_info)}"
)
result.append(0)
return result
if __name__ == "__main__":
path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
data = []
with open(path, "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
def create_test_params(index_list):
global global_problems
codes, query_ids = [], []
problem = code_data[0]
problem["problem_id"] = problem["id"]
id2info = {problem["problem_id"]: problem}
for index in index_list:
if str(index) not in global_problems:
continue
problem = global_problems[str(index)]
if "solutions" not in problem or not problem["solutions"]:
continue
codes.append(try_load_json(problem["solutions"])[0])
query_ids.append(str(index))
return codes, query_ids
codes, query_ids = create_test_params(list(range(805, 806)))
result = code_verify(codes, query_ids, True)
result = code_verify(
id2info,
[json.loads(problem["solutions"])[0]],
[problem["problem_id"]],
debug=False,
)
print(result)

View File

@ -1,5 +1,6 @@
import json
import os
import random
from collections import defaultdict
from functioncall.base import logging
@ -8,29 +9,14 @@ from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")
def try_load_json(data):
"""
Attempts to load the given data as JSON. If successful, returns the parsed JSON.
Otherwise, returns None and logs an error.
"""
try:
loaded_data = json.loads(data)
return loaded_data
except json.JSONDecodeError as e:
# print(f"Failed to load JSON: {e}")
return data
def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=None):
def load_problems_with_testcase_batch(
id2info, query_ids, debug=False, test_case_batch_size=None
):
problem_map = defaultdict(list)
for idx, line in enumerate(open(path, "rb")):
if line is None:
continue
for idx, query_id in enumerate(query_ids):
problem = id2info[query_id]
# parse one problem
row = json.loads(line.strip().decode("utf-8"))
query_id = str(row.get("id", row.get("query_id")))
input_output = json.loads(row["input_output"]) if "input_output" in row else {}
input_output = json.loads(problem["input_output"])
inputs = input_output.get("inputs", [])
outputs = input_output.get("outputs", [])
assert len(inputs) == len(
@ -57,17 +43,14 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
"batche_index": batch_idx,
}
if debug:
sub_problem["solutions"] = row.get("solutions", [])
sub_problem["solutions"] = problem.get("solutions", [])
problem_map[query_id].append(sub_problem)
return problem_map
global_problems = None
def code_verify(
generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
id2info, generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
):
assert len(generateds) == len(query_ids), (
len(generateds),
@ -75,24 +58,13 @@ def code_verify(
)
payload_list = []
global global_problems
if global_problems is None:
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
id2info,
query_ids,
debug=True,
test_case_batch_size=20,
)
for idx, query_id in enumerate(query_ids):
if query_id not in global_problems:
payload_list.append(None)
logger.warning(
f"Invalid query id : {query_id}, type: {type(query_id)}, should be in problem dataset!"
)
continue
problems = global_problems[query_id]
for problem in problems:
payload_list.append(
@ -129,34 +101,28 @@ def code_verify(
if __name__ == "__main__":
path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
data = []
with open(path, "r") as f:
code_data = [json.loads(l) for l in f.readlines()]
id2info = {}
def create_test_params(count=10):
global global_problems
if global_problems is None:
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
debug=True,
test_case_batch_size=20,
)
codes, query_ids = [], []
idx = 0
for query_id, problems in global_problems.items():
if idx >= count:
break
problem = problems[0]
if "solutions" not in problem or not problem["solutions"]:
global id2info
query_ids = []
generateds = []
cnt = 0
while cnt < count:
d = random.choice(code_data)
if not d["solutions"]:
continue
id2info[d["id"]] = d
query_ids.append(d["id"])
generateds.append(d["solutions"][0])
cnt += 1
return generateds, query_ids
codes.append(try_load_json(problem["solutions"])[0])
query_ids.append(query_id)
idx += 1
return codes, query_ids
codes, query_ids = create_test_params(100)
result = code_verify(codes, query_ids, True)
generateds, query_ids = create_test_params(100)
result = code_verify(id2info, generateds, query_ids, True)
print(result)

View File

@ -9,28 +9,9 @@ from functioncall.base.call import batch_function_call
logger = logging.getLogger("Functioncall")
def loadJson(dataDir):
with open(dataDir, "r") as f:
if dataDir.endswith(".jsonl"):
samples = [json.loads(line) for line in f.readlines()]
else:
samples = json.load(f)
return samples
id2info = None
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000) -> List:
global id2info
if id2info is None:
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
def math_verify(
id2info, generateds: List, query_ids: List, batch_size=10, timeout=1000
) -> List:
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
@ -95,27 +76,19 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000)
if __name__ == "__main__":
# sample = {
# "prompt": "",
# "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
# "answer": "\\boxed{-\\frac{1}{30}}",
# }
if id2info is None:
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
answers = []
query_ids = []
for id, value in id2info.items():
answers.append(value["solutions"][0])
query_ids.append(id)
sample = {
"answers": ["-\\frac{2}{3}"],
"solutions": [
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
"\\boxed{-\\frac{2}{3}}",
],
}
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
start_time = time.time()
result = math_verify(answers[:200], query_ids[:200])
result = math_verify(
id2info,
sample["answers"] * 100,
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
)
print(result)

View File

@ -160,8 +160,6 @@ def main_start(args, recover_count: int = 0):
CLUSTER_SPEC_PATH=cluster_spec_path,
REAL_RECOVER_RUN="1" if is_recover_run else "0",
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
)
for k, v in BASE_ENVIRONS.items():

View File

@ -206,9 +206,9 @@ class LocalMultiProcessTest:
def init_global_constants(
num_dp,
num_mp,
num_pp,
num_dp=1,
num_mp=1,
num_pp=1,
topo=None,
model_name=None,
msid2mwid=None,
@ -217,7 +217,12 @@ def init_global_constants(
gradient_accumulation_fusion=False,
max_prompt_len=None,
is_train: bool = True,
expr_name=None,
trial_name=None,
):
expr_name = expr_name if expr_name is not None else _DEFAULT_EXPR_NAME
trial_name = trial_name if trial_name is not None else _DEFAULT_TRIAL_NAME
constants.set_experiment_trial_names(expr_name, trial_name)
model_name = model_name if model_name is not None else MODEL_NAME
if topo is None:

View File

@ -1,562 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import copy
import dataclasses
import math
import os
import pprint
from typing import *
import numpy as np
from omegaconf import DictConfig, OmegaConf
import realhf.base.logging as logging
from realhf.api.core.config import (
DatasetAbstraction,
ModelInterfaceAbstraction,
ModelInterfaceType,
)
from realhf.api.core.dfg import MFCDef, ParamReallocHook
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import DeviceMesh, MFCConfig, RPCAllocation
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
logger = logging.getLogger("PPO Code exp", "colored")
@dataclasses.dataclass
class PPOHyperparameters:
"""Configuration of PPO hyperparameters.
:param gen: Generation hyperparameters.
:type gen: GenerationHyperparameters
:param ppo_n_minibatches: Number of minibatches in each PPO update.
:type ppo_n_minibatches: int
:param kl_ctl: Coefficient of KL divergence rewards.
:type kl_ctl: float
:param discount: Discount factor.
:type discount: float
:param gae_lambda: Lambda factor in GAE.
:type gae_lambda: float
:param eps_clip: PPO actor probability ratio clipping factor.
:type eps_clip: float
:param value_eps_clip: PPO value clipping factor.
:type value_eps_clip: float
:param max_reward_clip: Maximum reward value.
:type max_reward_clip: float
:param reward_output_scaling: Scaling factor of the reward model output.
:type reward_output_scaling: float
:param reward_output_bias: Bias of the reward model output.
The number outputed by the reward model will be
CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip).
:type reward_output_bias: float
:param early_stop_imp_ratio: PPO update will be early stopped if importance ratio
exceeds this maximum value.
:type early_stop_imp_ratio: float
:param use_adaptive_kl_ctl: Whether to use adaptive KL divergence coefficient.
:type use_adaptive_kl_ctl: bool
:param adv_norm: Whether to use advantage normalization.
:type adv_norm: bool
:param value_norm: Whether to denormalize valued and normalize return predictions.
:type value_norm: bool
:param value_norm_type: Type of value normalization.
Either exponential moving average ("exp") or moving average ("ma").
:type value_norm_type: str
:param value_norm_beta: Exponential decay factor
in exponential moving average.
:type value_norm_beta: float
:param value_norm_eps: Epsilon factor in the
denominator of exponential moving average.
:type value_norm_eps: float
:param disable_value: A shortcut option to disable the critic model.
:type disable_value: bool
"""
gen: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
)
ppo_n_minibatches: int = 4
kl_ctl: float = 0.1
discount: float = 1.0
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
disable_value: bool = False
max_reward_clip: float = 20.0
reward_output_scaling: float = 1.0
reward_output_bias: float = 0.0
early_stop_imp_ratio: float = 5.0
use_adaptive_kl_ctl: bool = False
adv_norm: bool = True
value_norm: bool = True
value_norm_type: str = dataclasses.field(
metadata={"choices": ["exp", "ma"]}, default="exp"
)
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
@dataclasses.dataclass
class PPOCODEConfig(CommonExperimentConfig):
"""PPO experiment configuration.
It is a subclass of :class:`CommonExperimentConfig`,
so all CLI options in the base class are available.
We don't implement runtime evaluation for PPO.
We identify that the RLHF process is composed of four
distinct models with independent parameters and six
*model function calls* upon these models.
The four models are\:
- Actor\: The primary LLM that generates text.
- Critic\: The value function that estimates the value of a state.
- Ref\: The reference LLM that provides KL regularization.
- Rew\: The reward model that provides reward signals.
The four model function calls and their dependencies are\:
- Rollout\: Generate text from the actor model.
- InfReward\: Infer rewards from the reward model given generated text.
- InfRef\: Infer log probabilities from the reference model given generated text.
- InfValues\: Infer values from the critic model given generated text.
- TrainActor\: Train the actor model given generated text, rewards, values, and reference log probabilities.
- TrainCritic\: Train the critic model given generated text, rewards, values, and reference log probabilities.
This class resolves these dependencies under the hood.
What the users should specify are the runtime configurations
of models and allocations of *each model function call*.
:param actor: Runtime configuration of the primary LLM.
:type actor: ModelTrainEvalConfig
:param critic: Runtime configuration of the critic model of PPO.
:type critic: ModelTrainEvalConfig
:param ref: Runtime configuration of the reference LLM.
:type ref: ModelTrainEvalConfig
:param rew: Runtime configuration of the reward LLM.
:type rew: ModelTrainEvalConfig
:param actor_train: :class:`MFCConfig` for TrainActor.
:type actor_train: MFCConfig
:param critic_train: :class:`MFCConfig` for TrainCritic.
:type critic_train: MFCConfig
:param actor_gen: :class:`MFCConfig` for Rollout.
:type actor_gen: MFCConfig
:param critic_inf: :class:`MFCConfig` for InfValues.
:type critic_inf: MFCConfig
:param rew_inf: :class:`MFCConfig` for InfReward.
:type rew_inf: MFCConfig
:param ref_inf: :class:`MFCConfig` for InfRef.
:type ref_inf: MFCConfig
:param dataset: Dataset configuration.
:type dataset: PromptOnlyDatasetConfig
:param ppo: Configuration for the PPO algorithm.
:type ppo: PPOHyperparameters
:param group_size: The number of answers remained for each prompt.
:type group_size: int
:param generation_size: The number of answers sampled for each prompt.
Among them, only `group_size` samples are remained according to
the reward score, aka best-of-n sampling. If None, use `group_size`.
:type generation_size: Optional[int]
:param mask_no_eos_with_zero: Whether to mask out the reward if an
answer is truncated due to exceeding the length limit.
:type mask_no_eos_with_zero: bool
:param mask_too_long: Whether to mask out the PPO loss if an
answer is truncated due to exceeding the length limit.
:type mask_too_long: bool
:param check_verifier_status: If True, raise an error
when the reward is all-zero. This usually indicates a bug
of the verifier.
:type check_verifier_status: bool
:param group_adv_norm: Whther to use grouped advantage
normaliztion in GRPO.
:type group_adv_norm: bool
"""
actor: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
critic: ModelTrainEvalConfig = dataclasses.field(
default_factory=ModelTrainEvalConfig
)
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
# for manual allocation only
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
default_factory=PromptOnlyDatasetConfig
)
ppo: PPOHyperparameters = dataclasses.field(default_factory=PPOHyperparameters)
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
rm_output_scaling: Optional[float] = None
ref_ema_eta: Optional[float] = None
group_adv_norm: bool = False
mask_too_long: bool = False
rw_type: Optional[str] = "sparse"
task: str = "code"
check_xml_format: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
check_verifier_status: bool = False
dataset_filter_threshold: float = 100.0
dataset_max_filter_percentage: float = 0.0
def __post_init__(self):
self.ppo_kwargs = dict(
n_minibatches=self.ppo.ppo_n_minibatches,
kl_ctl=self.ppo.kl_ctl,
discount=self.ppo.discount,
gae_lambda=self.ppo.gae_lambda,
eps_clip=self.ppo.eps_clip,
value_eps_clip=self.ppo.value_eps_clip,
max_reward_clip=self.ppo.max_reward_clip,
adaptive_kl_ctl=self.ppo.use_adaptive_kl_ctl,
value_norm=self.ppo.value_norm,
value_norm_type=self.ppo.value_norm_type,
value_norm_beta=self.ppo.value_norm_beta,
value_norm_eps=self.ppo.value_norm_eps,
disable_value=self.ppo.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
)
if self.rm_output_scaling is None:
self.rm_output_scaling = self.ppo.reward_output_scaling
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
# role to config
if self.ppo.disable_value:
return {
"actor": self.actor,
# "critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
else:
return {
"actor": self.actor,
"critic": self.critic,
"ref": self.ref,
"reward": self.rew,
}
@property
def rpcs(self):
if (
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
> self.actor.vllm.max_seq_len_to_capture
):
raise RuntimeError(
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
)
if not os.path.exists(os.getenv("REAL_CODE_METADATA_PATH")):
raise RuntimeError(
"Dataset json path REAL_CODE_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if not (domain.startswith("http://") and ":" in domain):
raise RuntimeError(
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
# interfaces
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",
args={
**copy.deepcopy(self.ppo_kwargs),
# NOTE: to_container converts the object to a dict
# It is used for unifying the profiling API, which requires to
# pass external interface configurations in the launch command.
# Customized dataclass objects will not work in that case.
"generation_config": (
OmegaConf.to_container(self.ppo.gen, resolve=True)
if isinstance(self.ppo.gen, (OmegaConf, DictConfig))
else dataclasses.asdict(self.ppo.gen)
),
"early_stop_imp_ratio": self.ppo.early_stop_imp_ratio,
"adv_norm": self.ppo.adv_norm,
"group_size": self.group_size,
"generation_size": self.generation_size,
"group_adv_norm": self.group_adv_norm,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
ref_interface = copy.deepcopy(actor_interface)
ref_interface.args["enable_save"] = False
critic_interface = ModelInterfaceAbstraction(
"ppo_critic",
args={
**copy.deepcopy(self.ppo_kwargs),
"group_size": self.group_size,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
critic_interface.args.pop("eps_clip")
rw_interface = ModelInterfaceAbstraction(
"rw_code",
args=dict(
rw_type=self.rw_type,
task=self.task,
check_xml_format=self.check_xml_format,
tokenizer_path=self.actor.path,
enable_save=False,
output_scaling=self.ppo.reward_output_scaling,
rm_output_scaling=self.rm_output_scaling,
output_bias=self.ppo.reward_output_bias,
group_size=self.group_size,
check_verifier_status=self.check_verifier_status,
max_sync_length=self.ppo.gen.max_new_tokens
+ self.dataset.max_prompt_len
+ 128,
),
)
rollout = MFCDef(
name="actor_gen",
model_name="actor",
mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=["packed_prompts"],
output_keys=[
"seq_no_eos_mask",
"packed_input_ids",
"packed_logprobs",
"prompt_mask",
],
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_reward = MFCDef(
name="rew_inf",
model_name="reward",
mb_spec=self.rew_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts"],
output_keys=["rewards", "dense_rewards"],
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_ref_inputs = ["packed_input_ids"]
inf_ref_logits = MFCDef(
name="ref_inf",
model_name="ref",
mb_spec=self.ref_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
model_type=self.ref.type,
model_path=self.ref.path,
interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=inf_ref_inputs,
output_keys=["packed_ref_logprobs"],
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_values = MFCDef(
name="critic_inf",
model_name="critic",
mb_spec=self.critic_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "seq_no_eos_mask"],
output_keys=["values"],
n_seqs=self.dataset.train_bs_n_seqs,
)
train_actor_inputs = [
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
]
if self.ppo.disable_value:
train_actor_inputs.remove("values")
train_actor = MFCDef(
name="actor_train",
model_name="actor",
mb_spec=self.actor_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=train_actor_inputs,
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
)
train_critic = MFCDef(
name="critic_train",
model_name="critic",
mb_spec=self.critic_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
input_keys=[
"packed_input_ids",
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
],
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
n_seqs=self.dataset.train_bs_n_seqs,
)
if self.ppo.disable_value:
return {
"actor_gen": rollout,
"actor_train": train_actor,
# "critic_inf": inf_values,
# "critic_train": train_critic,
"ref_inf": inf_ref_logits,
"rew_inf": inf_reward,
}
else:
return {
"actor_gen": rollout,
"actor_train": train_actor,
"critic_inf": inf_values,
"critic_train": train_critic,
"ref_inf": inf_ref_logits,
"rew_inf": inf_reward,
}
@property
def allocations(self):
if self.ppo.disable_value:
return {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
# "critic_inf": self.critic_inf,
# "critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"rew_inf": self.rew_inf,
}
else:
return {
"actor_gen": self.actor_gen,
"actor_train": self.actor_train,
"critic_inf": self.critic_inf,
"critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"rew_inf": self.rew_inf,
}
@property
def datasets(self):
return [
DatasetAbstraction(
"code_prompt",
args=dict(
dataset_path=self.dataset.path,
max_length=self.dataset.max_prompt_len,
fill_to_max_length=self.dataset.fill_to_max_length,
filter_threshold=self.dataset_filter_threshold,
max_filter_percentage=self.dataset_max_filter_percentage,
),
)
]
@property
def tokenizer_name_or_path(self) -> str:
return self.actor.path
@property
def search_kwargs(self):
return {
"num_gen_tokens": self.ppo.gen.max_new_tokens,
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
"seq_len": self.dataset.max_prompt_len,
}
@property
def max_prompt_len(self):
return self.dataset.max_prompt_len
def initial_setup(self) -> ExperimentConfig:
rpc_allocs = self._get_rpc_allocations()
resolve_replica_ids(rpc_allocs, self.models)
resolve_rpc_hooks(
rpc_allocs, self.models
) # inplace modify MFCDefs in rpc allocations
pprint.pprint(rpc_allocs)
######### update ref model using ema, ref_ema_eta = 0 means fixed ref model #########
def _find_rpc(name):
return next(alloc.rpc for alloc in rpc_allocs if alloc.rpc.name == name)
# Remove the offload hook of ref_inf, because
# we need to receive parameters from peer GPUs and update it immediately.
if self.ref_ema_eta is not None:
ref_inf = _find_rpc("ref_inf")
ref_inf._post_hooks = []
# Add an unidirectional parameter reallocation hook.
actor_train = _find_rpc("actor_train")
actor_train.add_post_hook(
ParamReallocHook(
target=ref_inf.model_name,
eta=self.ref_ema_eta,
)
)
######### The main difference from normal PPO #########
model_worker = self._get_model_worker_configs(rpc_allocs)
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,
)
register_quickstart_exp("ppo-code", PPOCODEConfig)

View File

@ -7,8 +7,6 @@ import os
import pprint
from typing import *
from omegaconf import DictConfig, OmegaConf
import realhf.base.logging as logging
from realhf.api.core.config import (
DatasetAbstraction,
@ -21,7 +19,7 @@ from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import (
asdict,
@ -89,8 +87,6 @@ class PPOHyperparameters:
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
disable_value: bool = False
recompute_logprob: bool = False
max_reward_clip: float = 20.0
reward_output_scaling: float = 1.0
reward_output_bias: float = 0.0
@ -104,6 +100,10 @@ class PPOHyperparameters:
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
disable_value: bool = False
recompute_logprob: bool = False
fuse_rew_ref: bool = True
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig):
@ -197,6 +197,7 @@ class PPOMATHConfig(CommonExperimentConfig):
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
dataset: PromptOnlyDatasetConfig = dataclasses.field(
@ -208,15 +209,11 @@ class PPOMATHConfig(CommonExperimentConfig):
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
rm_output_scaling: Optional[float] = None
ref_ema_eta: Optional[float] = None
group_adv_norm: bool = False
mask_too_long: bool = False
rw_type: Optional[str] = "sparse"
task: str = "math"
check_xml_format: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
check_verifier_status: bool = False
@ -242,9 +239,6 @@ class PPOMATHConfig(CommonExperimentConfig):
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
)
if self.rm_output_scaling is None:
self.rm_output_scaling = self.ppo.reward_output_scaling
@property
def models(self) -> Dict[str, ModelTrainEvalConfig]:
# role to config
@ -270,10 +264,6 @@ class PPOMATHConfig(CommonExperimentConfig):
f"smaller than the prompt length + generation length "
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
)
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
raise RuntimeError(
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if domain and (not (domain.startswith("http://") and ":" in domain)):
@ -297,8 +287,6 @@ class PPOMATHConfig(CommonExperimentConfig):
"generation_size": self.generation_size,
"group_adv_norm": self.group_adv_norm,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
@ -308,29 +296,30 @@ class PPOMATHConfig(CommonExperimentConfig):
**copy.deepcopy(self.ppo_kwargs),
"group_size": self.group_size,
"mask_too_long": self.mask_too_long,
"use_dense_reward": self.use_dense_reward,
"reward_delta": self.reward_delta,
},
)
critic_interface.args.pop("eps_clip")
rw_interface = ModelInterfaceAbstraction(
"reward",
args=dict(
rw_type=self.rw_type,
task=self.task,
check_xml_format=self.check_xml_format,
tokenizer_path=self.actor.path,
enable_save=False,
output_scaling=self.ppo.reward_output_scaling,
rm_output_scaling=self.rm_output_scaling,
output_bias=self.ppo.reward_output_bias,
rw_type=self.rw_type,
check_xml_format=self.check_xml_format,
group_size=self.group_size,
check_verifier_status=self.check_verifier_status,
max_sync_length=self.ppo.gen.max_new_tokens
+ self.dataset.max_prompt_len
+ 128,
),
)
ref_interface = copy.deepcopy(actor_interface)
ref_interface.args["enable_save"] = False
if self.ppo.fuse_rew_ref:
ref_interface = ModelInterfaceAbstraction(
"fused-threading",
args=dict(interfaces=dict(rew=rw_interface, ref=ref_interface)),
)
rollout_output_keys = [
"seq_no_eos_mask",
"packed_input_ids",
@ -366,13 +355,23 @@ class PPOMATHConfig(CommonExperimentConfig):
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_reward = MFCDef(
name="rew_inf",
model_name="reward",
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=["packed_input_ids", "packed_prompts"],
output_keys=["rewards"],
n_seqs=self.dataset.train_bs_n_seqs,
)
# add rew param into ref MFC
inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
inf_ref_outputs = ["logprobs", "rewards", "dense_rewards"]
ref_interface = copy.deepcopy(actor_interface)
ref_interface.type_ = "ref_rw"
ref_interface.args["enable_save"] = False
ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)
inf_ref_inputs = ["packed_input_ids"]
inf_ref_outputs = ["logprobs"]
if self.ppo.fuse_rew_ref:
inf_ref_inputs += ["packed_prompts"]
inf_ref_outputs += ["rewards"]
inf_ref_logits = MFCDef(
name="ref_rw",
@ -408,7 +407,6 @@ class PPOMATHConfig(CommonExperimentConfig):
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
@ -442,7 +440,6 @@ class PPOMATHConfig(CommonExperimentConfig):
"packed_logprobs",
"packed_ref_logprobs",
"rewards",
"dense_rewards",
"values",
"prompt_mask",
"seq_no_eos_mask",
@ -466,6 +463,8 @@ class PPOMATHConfig(CommonExperimentConfig):
rpcs.pop("critic_train")
if not self.ppo.recompute_logprob:
rpcs.pop("actor_inf")
if self.ppo.fuse_rew_ref:
rpcs.pop("rew_inf")
return rpcs
@property
@ -484,13 +483,15 @@ class PPOMATHConfig(CommonExperimentConfig):
allocs.pop("critic_train")
if not self.ppo.recompute_logprob:
allocs.pop("actor_inf")
if self.ppo.fuse_rew_ref:
allocs.pop("rew_inf")
return allocs
@property
def datasets(self):
return [
DatasetAbstraction(
"math_prompt",
"math_code_prompt",
args=dict(
dataset_path=self.dataset.path,
max_length=self.dataset.max_prompt_len,

View File

@ -0,0 +1,184 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import json
from typing import Callable, Dict, Hashable, List, Optional
import numpy as np
import torch.utils.data
from realhf.api.core import data_api
from realhf.base import logging
logger = logging.getLogger("Math Code Dataset")
id2info = {}
def check_code_metadata_entries(data):
pass
def check_math_metadata_entries(data):
pass
def load_metadata(path):
assert str(path).endswith(".jsonl"), path
with open(path, "r") as f:
data = [json.loads(l) for l in f.readlines()]
id2info = {}
for d in data:
assert d["query_id"] not in d, (d["task"], d["query_id"])
if d["task"] == "math":
check_math_metadata_entries(d)
elif d["task"] == "code":
check_code_metadata_entries(d)
id2info[d["query_id"]] = d
return id2info
class MATHCodePromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
"""Required keys: prompt, query_id, task=math/code, solutions.
For code dataset, they additionally require an "input_output" key.
"""
self._util = util
self.max_length = max_length
global id2info
id2info = load_metadata(dataset_path)
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
prompts_str = [x["prompt"] for x in data]
self.ids = [x["query_id"] for x in data]
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
prompt_encodings = util.tokenizer(
prompts_str,
truncation=True,
# max_length=max_length,
padding=False,
return_length=True,
return_attention_mask=False,
)
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(f"{len(indices)} samples remain")
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
self.ids = [
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
]
if "scores" in data[0]:
self.base_scores = [self.base_scores[idx] for idx in indices]
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
self.active_indices = list(range(len(self.prompts)))
self.filter_threshold = filter_threshold
self.max_filter_percentage = max_filter_percentage
@property
def util(self):
return self._util
def __len__(self):
return len(self.active_indices)
def __getitem__(self, idx):
# print(self.base_scores)
idx = self.active_indices[idx]
data = dict(
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
)
if hasattr(self, "base_scores"):
data["base_scores"] = torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
)
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=data,
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
idx2scores_to_remove = {}
for pop_idx, idx in enumerate(self.active_indices):
data_id = self.ids[idx]
if data_id not in eval_scores:
continue
if eval_scores[data_id] > self.filter_threshold:
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
# Control the number of samples to be removed according to max_filter_percentage.
n = int(len(self.active_indices) * self.max_filter_percentage)
indices_to_remove = sorted(
idx2scores_to_remove.keys(),
key=lambda x: idx2scores_to_remove[x],
reverse=True,
)[:n]
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
f"Max filter percentage: {self.max_filter_percentage}. "
f"Current number of eval scores: {len(eval_scores)}."
)
if not __name__ == "__main__":
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
else:
from transformers import AutoTokenizer
dataset = MATHCodePromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=AutoTokenizer.from_pretrained(
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
),
),
max_length=512,
dataset_path="/storage/openpsi/users/bowei.fw/data/code_math.jsonl",
)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=4,
shuffle=True,
)
print(f"size: {len(dataset)}")
for d in dataloader:
print(d.ids)

View File

@ -4,13 +4,10 @@ import json
import os
import signal
import subprocess
import time
import traceback
import uuid
from typing import *
from realhf.base import logging
from realhf.base.constants import parallelism_rank
logger = logging.getLogger("math parser")
@ -48,20 +45,7 @@ def loadJson(dataDir):
return samples
headers = {
"Content-Type": "application/json",
}
id2info = None
def parse_line(prompt_str, generated, query_id):
global id2info
if id2info is None:
try:
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
except KeyError as e:
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
def parse_line(id2info, prompt_str, generated, query_id):
info = id2info[query_id.split("@idx:")[0]]
tmp_id = str(uuid.uuid4())
@ -112,17 +96,12 @@ def parse_line(prompt_str, generated, query_id):
def parse_lines_in_parallel(
id2info,
generateds: List,
query_ids: List,
max_workers=22,
check_xml_format=False,
) -> List:
global id2info
if id2info is None:
try:
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
except KeyError as e:
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),
@ -204,17 +183,19 @@ def parse_lines_in_parallel(
if __name__ == "__main__":
sample = {
"prompt": "",
"query_id": "35ecd821a9e7e31da9ef0663a25347ce",
# "answer_in_box": ["\\boxed{\\frac{1}{2}}", "<think></think><answer>\\boxed{\\frac{1}{2}}</answer>"]
"answer": "<think>\n1. The problem requires us to determine the number of sequences of 144 hand movements such that every position appears exactly once and the hands return to the initial position at the end.\n2. We know that each movement involves one hand moving clockwise to the next number while the other hand stays in place.\n3. Considering the 12-hour clock, we can represent each positioning of the hands as a combination of the positions of both hands. Since both hands can be in any of the 12 positions, there are 12 x 12 = 144 different positionings.\n4. Given that at each position only one hand moves, every single movement is unique, leading to a total of 144 unique movements.\n5. These 144 movements must form a Hamiltonian Cycle, where each edge represents a valid movement between two positions.\n6. The problem thus reduces to finding a Hamiltonian cycle in a directed graph. Since the appearance of each movement is unique, it also determines the direction of the movement.\n7. We consider the edges that are rotations of each other as equivalent. Taking the rotational symmetry into account, we have 144/12 = 12 equivalence classes.\n8. The problem now is to determine the number of ways to arrange these 12 classes of rotations in a circle, which is 11 factorial.\n9. We must find the value of 11! and then compute the result modulo 1000.\n</think>\n<answer>\n320\n</answer>",
"answers": ["-\\frac{2}{3}"],
"solutions": [
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
"\\boxed{-\\frac{2}{3}}",
],
}
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
print(
parse_lines_in_parallel(
# [sample["answer_in_box"][0] for _ in range(50)] + [sample["answer_in_box"][1] for _ in range(50)],
[sample["answer"]] * 100,
[sample["query_id"] for _ in range(100)],
id2info,
sample["answers"] * 100,
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
max_workers=8,
check_xml_format=True,
)

View File

@ -2,7 +2,6 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import uuid
from typing import Callable, Dict, Hashable, List, Optional
import numpy as np
@ -85,140 +84,4 @@ class PromptDataset(torch.utils.data.Dataset):
)
class MATHCodePromptDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
filter_threshold: float = 1e4,
max_filter_percentage: float = 0.0,
):
self._util = util
self.max_length = max_length
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
prompts_str = [x["prompt"] for x in data]
self.ids = [x["query_id"] for x in data]
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
if "scores" in data[0]:
self.base_scores = [np.mean(x["scores"]) for x in data]
util.tokenizer.padding_side = "left"
prompt_encodings = util.tokenizer(
prompts_str,
truncation=True,
# max_length=max_length,
padding=False,
return_length=True,
return_attention_mask=False,
)
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
indices = [
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
]
logger.info(f"{len(indices)} samples remain")
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
self.ids = [
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
]
if "scores" in data[0]:
self.base_scores = [self.base_scores[idx] for idx in indices]
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
self.active_indices = list(range(len(self.prompts)))
self.filter_threshold = filter_threshold
self.max_filter_percentage = max_filter_percentage
@property
def util(self):
return self._util
def __len__(self):
return len(self.active_indices)
def __getitem__(self, idx):
# print(self.base_scores)
idx = self.active_indices[idx]
data = dict(
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
)
if hasattr(self, "base_scores"):
data["base_scores"] = torch.tensor(
[self.base_scores[idx]], dtype=torch.float32
)
return data_api.SequenceSample.from_default(
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=data,
)
def filter(self, eval_scores: Dict[Hashable, float]):
# Get all data indices that have a higher score than the threshold.
idx2scores_to_remove = {}
for pop_idx, idx in enumerate(self.active_indices):
data_id = self.ids[idx]
if data_id not in eval_scores:
continue
if eval_scores[data_id] > self.filter_threshold:
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
# Control the number of samples to be removed according to max_filter_percentage.
n = int(len(self.active_indices) * self.max_filter_percentage)
indices_to_remove = sorted(
idx2scores_to_remove.keys(),
key=lambda x: idx2scores_to_remove[x],
reverse=True,
)[:n]
for pop_idx in sorted(indices_to_remove, reverse=True):
self.active_indices.pop(pop_idx)
logger.info(
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
f"Original dataset size: {len(self.prompts)}. "
f"Filter threshold: {self.filter_threshold}. "
f"Max filter percentage: {self.max_filter_percentage}. "
f"Current number of eval scores: {len(eval_scores)}."
)
if not __name__ == "__main__":
data_api.register_dataset("prompt", PromptDataset)
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
else:
from transformers import AutoTokenizer
dataset = MATHCodePromptDataset(
data_api.DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=AutoTokenizer.from_pretrained(
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
),
),
max_length=512,
dataset_path="/storage/openpsi/users/bowei.fw/data/code_math.jsonl",
)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=4,
shuffle=True,
)
print(f"size: {len(dataset)}")
for d in dataloader:
print(d.ids)

View File

@ -67,35 +67,5 @@ class FusedThreadingForwardInterface(model_api.ModelInterface):
return final_result
# Mock methods for profiling only.
def _mock_inference(
self,
model: model_api.Model,
data: SequenceSample,
) -> SequenceSample:
prompt_lens = flat2d(data.seqlens["packed_prompts"])
seqlens = [x + 1024 for x in prompt_lens]
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
mconfig = module.config
packed_input_ids = torch.randint(
0,
mconfig.vocab_size,
(sum(seqlens),),
dtype=torch.long,
device=model.device,
)
n_tasks = len(RL_TASKS)
task_ids = torch.randint(
0, n_tasks, (data.bs,), dtype=torch.long, device=model.device
)
return SequenceSample.from_default(
seqlens=seqlens,
ids=data.ids,
data=dict(packed_input_ids=packed_input_ids, task_ids=task_ids),
)
model_api.register_interface("fused-threading", FusedThreadingForwardInterface)

View File

@ -18,7 +18,7 @@ import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_generate import concat_prompt_to_generation_output
from realhf.impl.model.utils.functional import (

View File

@ -18,6 +18,7 @@ import torch.distributed as dist
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.data_api import (
@ -28,14 +29,14 @@ from realhf.api.core.data_api import (
)
from realhf.base import constants
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import (
parse_lines_in_parallel as math_verify_local,
)
from realhf.impl.dataset.math_code_dataset import id2info
from realhf.impl.dataset.math_parser import parse_lines_in_parallel as math_verify_local
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
class VerifierException(Exception):
@ -61,6 +62,7 @@ def extract_python_code(text, min_length=20, strict_syntax=True):
valid_blocks.append(clean_block)
if not valid_blocks:
logger.warning(f"failed to extract python code from {text}")
return None
# return the last code block
return valid_blocks[-1]
@ -120,14 +122,20 @@ def check_with_elementtree(text):
def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
global id2info
assert len(answers) == len(query_id_strs)
format_rewards = []
if task == "math":
format_rewards = math_verify_call(answers, query_id_strs)
format_rewards = math_verify_call(id2info, answers, query_id_strs)
elif task == "code":
codes = [extract_python_code(_answer) for _answer in answers]
format_rewards = code_verify(codes, query_id_strs)
assert len(format_rewards) == len(answers), task
format_rewards = code_verify_call(id2info, codes, query_id_strs)
assert len(format_rewards) == len(answers), (
task,
len(format_rewards),
len(answers),
answers,
)
return format_rewards
@ -167,10 +175,9 @@ def retokenize_and_verify(
@dataclasses.dataclass
class MultiTaskCPURewardInterface(model_api.ModelInterface):
class MultiTaskRewardInterface(model_api.ModelInterface):
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
output_scaling: float = 1.0
output_bias: float = 0.0
rw_type: str = "sparse"
check_xml_format: bool = False
@ -183,21 +190,15 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"rw_type: {self.rw_type}")
if (
constants.data_parallel_world_size()
< constants.parallelism_group_size()
):
logger.warning(
"There's no reason to use tensor and pipeline parallelism for CPU reward."
)
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
xs = data.unpack()
dispatched = {}
dispatched_indices = {}
for task_idx, task_name in enumerate(RL_TASKS):
indices = (data.data["task_ids"] == task_idx).numpy().tolist()
if any(indices):
indices = (data.data["task_ids"] == task_idx).numpy().nonzero()[0].tolist()
print("============", task_idx, task_name, indices)
if len(indices) > 0:
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
dispatched_indices[task_name] = indices
@ -208,7 +209,9 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
) -> SequenceSample:
xs = [None for _ in range(bs)]
for task_name, indices in dispatched_indices.items():
print(task_name, indices)
xxs = results[task_name].unpack()
assert len(indices) == len(xxs), (len(indices), len(xxs))
for i, xx in zip(indices, xxs):
xs[i] = xx
assert all(xs)
@ -260,7 +263,7 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
# NOTE: a place holder
dense_scores = packed_input_ids.new_zeros(dtype=torch.float32)
dense_scores = torch.zeros_like(packed_input_ids, dtype=torch.float32)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
@ -368,6 +371,8 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
task_data, dispatch_indices = self._dispatch_tasks(data)
for d in task_data.values():
print(d.bs)
assert self.rw_type == "sparse"
@ -377,11 +382,13 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
try:
result = func(*args, **kwargs)
except Exception as e:
raise asyncio.CancelledError(f"{task_type} task failed: {e}") from e
raise asyncio.CancelledError(
f"[{task_type}] task failed: {e}"
) from e
finally:
duration = time.perf_counter() - start_time
logger.info(f"[{task_type}] time cost: {duration:.4f}s")
return result
return task_type, result
return _wrapped_func
@ -391,11 +398,12 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
task_func = _task_func(self.calculate_task_reward, task_type)
task_args = (model, d, mb_spec, task_type)
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
tasks.append((task_type, task))
tasks.append(task)
task_results = {}
results = await asyncio.gather(*tasks)
for task_type, result in results:
task_results = {}
for res in results:
task_type, result = res
task_results[task_type] = result
return task_results
@ -410,5 +418,46 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
return final_result
def _mock_inference(
self,
model: model_api.Model,
data: SequenceSample,
) -> SequenceSample:
from realhf.base.testing import TESTING_MODEL_VOCAB_SIZE
model_api.register_interface("reward", MultiTaskCPURewardInterface)
prompt_lens = flat2d(data.seqlens["packed_prompts"])
task_ids = data.data["task_ids"].cpu().numpy().tolist()
seqlens = []
offset = 0
seq = []
for plen, task_id in zip(prompt_lens, task_ids):
seq += [data.data["packed_prompts"][offset : offset + plen]]
offset += plen
if task_id == RL_TASKS.index("math"):
answer_str = (
"something unimportant but the answer is \\boxed{-\\frac{2}{3}}"
)
elif task_id == RL_TASKS.index("code"):
answer_str = (
"```python\ninput()\nprint(1)\nprint(1)\nprint(1)\nprint(1)\n```"
)
else:
answer_str = "something unimportant"
encoding = model.tokenizer(
[answer_str], add_special_tokens=True, return_attention_mask=False
)
ans = torch.tensor(encoding["input_ids"], dtype=torch.long).flatten()
seq += [ans]
seqlens.append(plen + len(ans))
x = SequenceSample.from_default(
seqlens=seqlens,
ids=data.ids,
data=dict(packed_input_ids=torch.cat(seq)),
)
data.update_(x)
return data
model_api.register_interface("reward", MultiTaskRewardInterface)

View File

@ -618,8 +618,6 @@ class RayController:
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),

View File

@ -27,9 +27,7 @@ def model_class(request):
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
def math_dataset(request, save_path):
with open(os.getenv("REAL_MATH_METADATA_PATH"), "r") as f:
query_ids = list(json.load(f).keys())
def math_code_dataset(request, save_path):
size = request.param
max_prompt_len = 8
max_resp_len = 8
@ -42,7 +40,7 @@ def math_dataset(request, save_path):
prompt=generate_random_sentence(prompt_len),
)
dataset.append(d)
with open(str(save_path / "math_dataset.json"), "w") as f:
with open(str(save_path / "math_code_dataset.json"), "w") as f:
json.dump(dataset, f)
return dataset
@ -51,9 +49,9 @@ def math_dataset(request, save_path):
"dp,pp,mp",
[
(1, 1, 1),
(2, 1, 2),
(1, 2, 1),
(1, 1, 2),
# (2, 1, 2),
# (1, 2, 1),
# (1, 1, 2),
],
)
def test_ppo_symm(
@ -97,11 +95,6 @@ def test_ppo_symm(
init_critic_from_actor=True,
backend="mock_train",
),
rew=ModelTrainEvalConfig(
path=str(save_path),
init_critic_from_actor=True,
init_from_scratch=True,
),
dataset=PromptOnlyDatasetConfig(
path=str(save_path / "math_dataset.json"),
max_prompt_len=mconfig.n_positions // 2,
@ -121,6 +114,7 @@ def test_ppo_symm(
run_test_exp(exp_cfg)
@pytest.mark.skip("")
# The global resharding strategy, where all MFCs
# occupy the same device mesh but with different
# parallelization strategies.
@ -242,6 +236,7 @@ def test_ppo_global_reshard(
run_test_exp(exp_cfg)
@pytest.mark.skip("")
# Actor/critic train and ref_inf/rew_inf are on disjoint
# device meshes and executed concurrently.
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
@ -358,6 +353,7 @@ def test_ppo_param_realloc_sub_device_mesh(
run_test_exp(exp_cfg)
@pytest.mark.skip("")
@pytest.mark.parametrize("freq_step", [3, 4, 7])
@pytest.mark.parametrize("freq_epoch", [1, 2, 3])
@pytest.mark.parametrize("bs", [30, 80, 100])

View File

@ -0,0 +1,111 @@
# Copyright 2025 Ant Group Inc. All Rights Reserved.
import os
import shutil
import uuid
from typing import *
import pytest
import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from realhf.api.core.data_api import (
DatasetUtility,
MicroBatchSpec,
SequenceSample,
load_hf_tokenizer,
)
from realhf.api.core.model_api import FinetuneSpec, Model
from realhf.base import constants, network, testing
from tests.fixtures import *
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
def math_code_dataset(request, save_path):
size = request.param
max_prompt_len = 8
dataset = []
for i in range(size):
prompt_len = random.randint(1, max_prompt_len)
n_pairs = random.randint(1, 5)
if random.random() < 0.5:
d = dict(
task="code",
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
problem_id=str(uuid.uuid4()),
input_output=json.dumps(
{"inputs": [1] * 8, "outputs": ["1\n1\n1\n1\n"] * 8}
),
solutions=json.dumps(
["```python\ninput()\nprint(1)\nprint(1)\nprint(1)\nprint(1)\n```"]
* 3
),
difficulty=random.random() * 10,
)
else:
d = dict(
task="math",
query_id=str(uuid.uuid4()),
prompt=generate_random_sentence(prompt_len),
answers=["-\\frac{2}{3}"],
solutions=["-\\frac{2}{3}"],
)
dataset.append(d)
with open(str(save_path / "math_code_dataset.jsonl"), "w") as f:
f.write("\n".join([json.dumps(d) for d in dataset]))
return dataset
@pytest.mark.parametrize(
"tokenizer_path", ["/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"]
)
def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_dataset):
from realhf.impl.dataset.math_code_dataset import MATHCodePromptDataset
dist.init_process_group(
rank=0, world_size=1, init_method=f"tcp://localhost:{network.find_free_port()}"
)
testing.init_global_constants()
dataset = MATHCodePromptDataset(
DatasetUtility(
seed=0,
dp_rank=0,
world_size=1,
tokenizer=load_hf_tokenizer(tokenizer_path),
),
max_length=512,
dataset_path=str(save_path / "math_code_dataset.jsonl"),
)
dataloader = DataLoader(
dataset,
collate_fn=SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=4,
shuffle=True,
)
from realhf.impl.model.interface.rw_interface import MultiTaskRewardInterface
with constants.model_scope(testing.MODEL_NAME):
interface = MultiTaskRewardInterface(
tokenizer_path=tokenizer_path,
group_size=1,
check_verifier_status=False,
)
model = Model(
name="test",
module=None,
tokenizer=load_hf_tokenizer(tokenizer_path),
device=torch.device("cpu"),
ft_spec=FinetuneSpec(
total_train_epochs=1, dataset_size=100, train_batch_size=3
),
)
for d in dataloader:
d = interface.mock("inference", model, d)
rewards = interface.inference(model, d, mb_spec=MicroBatchSpec())
d.update_(rewards)
print("success")