mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent
eb1e8a7592
commit
de8243cc78
|
@ -0,0 +1,30 @@
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
|
data = []
|
||||||
|
with open("/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl", "r") as f:
|
||||||
|
code_data = [json.loads(l) for l in f.readlines()]
|
||||||
|
|
||||||
|
original_keys = list(code_data[0].keys())
|
||||||
|
print(original_keys)
|
||||||
|
for d in code_data:
|
||||||
|
print(d["starter_code"], type(d["starter_code"]))
|
||||||
|
# print(json.loads(d["solutions"])[0])
|
||||||
|
exit(0)
|
||||||
|
d["query_id"] = d["id"]
|
||||||
|
d["prompt"] = d["question"]
|
||||||
|
d["task"] = "code"
|
||||||
|
for k in original_keys:
|
||||||
|
d.pop(k)
|
||||||
|
data.append(d)
|
||||||
|
|
||||||
|
with open("/storage/openpsi/users/gjx/data/DeepScaleR/prompts.jsonl", "r") as f:
|
||||||
|
math_data = [json.loads(l) for l in f.readlines()]
|
||||||
|
|
||||||
|
for d in math_data:
|
||||||
|
data.append(dict(prompt=d["prompt"], task="math", query_id=d["query_id"]))
|
||||||
|
|
||||||
|
random.shuffle(data)
|
||||||
|
with open("/storage/openpsi/users/bowei.fw/data/code_math.jsonl", "w") as f:
|
||||||
|
for d in data:
|
||||||
|
f.write(json.dumps(d, ensure_ascii=False) + "\n")
|
|
@ -291,7 +291,6 @@ The descriptions of the important parameters are as follows:
|
||||||
+ `MODE`: It is always `ray`, and do not change it to other values when referring to this tutorial for training.
|
+ `MODE`: It is always `ray`, and do not change it to other values when referring to this tutorial for training.
|
||||||
+ `BASE_MODEL_PATH`: The path of the model.
|
+ `BASE_MODEL_PATH`: The path of the model.
|
||||||
+ `DATA_PATH`: The path of the dataset jsonl file
|
+ `DATA_PATH`: The path of the dataset jsonl file
|
||||||
+ `REAL_MATH_METADATA_PATH`: Set it to the path of the json file of the math metadata, refer to troubleshooting.
|
|
||||||
+ `CLUSTER_SPEC_PATH`: Set it to the path of cluster_config.json
|
+ `CLUSTER_SPEC_PATH`: Set it to the path of cluster_config.json
|
||||||
|
|
||||||
+ `n_nodes`: The number of nodes
|
+ `n_nodes`: The number of nodes
|
||||||
|
@ -513,10 +512,5 @@ This error typically occurs during the vLLM generation phase and is another symp
|
||||||
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
|
- Reduce the training batch size or the number of answers generated per prompt. Note that this may lower sample efficiency and extend training time.
|
||||||
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).
|
- [Switch vLLM's attention backend to xformers](https://github.com/vllm-project/vllm/issues/5376).
|
||||||
|
|
||||||
## Others
|
|
||||||
|
|
||||||
### How to train with other datasets
|
|
||||||
|
|
||||||
The dataset must be in `jsonl` format, with each entry containing two keys: `prompt` (a math problem) and `query_id` (a unique identifier for the problem). After preparing the dataset, update the `REAL_MATH_METADATA_PATH` with the new dataset's metadata. Metadata is a JSON file that records the solutions for each problem. The training code uses this metadata to determine if the model solved a problem correctly.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -315,7 +315,6 @@ python3 -m realhf.apps.quickstart ppo-math --show-args
|
||||||
+ MODE:总是为 ray,参考本教程进行训练时不要改成其他值。
|
+ MODE:总是为 ray,参考本教程进行训练时不要改成其他值。
|
||||||
+ BASE_MODEL_PATH:模型的路径
|
+ BASE_MODEL_PATH:模型的路径
|
||||||
+ DATA_PATH:数据集 jsonl 文件的路径
|
+ DATA_PATH:数据集 jsonl 文件的路径
|
||||||
+ REAL_MATH_METADATA_PATH:设置成数学 metadata 的 json 文件路径,参考troubleshooting。
|
|
||||||
+ CLUSTER_SPEC_PATH:设置成 cluster_config.json 的路径
|
+ CLUSTER_SPEC_PATH:设置成 cluster_config.json 的路径
|
||||||
|
|
||||||
+ n_nodes:节点数量
|
+ n_nodes:节点数量
|
||||||
|
@ -530,8 +529,3 @@ ALL_PARAMS=(
|
||||||
+ 减小训练batch size或者每个prompt生成的答案数量,但减小后会降低样本效率、延长训练时间
|
+ 减小训练batch size或者每个prompt生成的答案数量,但减小后会降低样本效率、延长训练时间
|
||||||
+ [将vLLM的attention backend换成xformers](https://github.com/vllm-project/vllm/issues/5376)
|
+ [将vLLM的attention backend换成xformers](https://github.com/vllm-project/vllm/issues/5376)
|
||||||
|
|
||||||
## 其他
|
|
||||||
|
|
||||||
### 如何用其他数据集进行训练
|
|
||||||
|
|
||||||
数据集需要是是 jsonl 格式的文件,其中每一条数据需要包含两个 key,分别是 prompt,即一道数学问题,和query_id,即这道数学问题的唯一标识符。在准备好数据集后,还需要根据数据集中的题目更新REAL_MATH_METADATA_PATH的内容。metadata 是一个 json 文件,记录了每道题目的答案、来源和解法。训练代码需要根据 metadata 来判断模型是否做对了一道题。
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ BASE_MODEL_PATH="$1"
|
||||||
|
|
||||||
# original data
|
# original data
|
||||||
DATA_PATH="$2"
|
DATA_PATH="$2"
|
||||||
REAL_MATH_METADATA_PATH="$3"
|
|
||||||
|
|
||||||
# Option 1: The experiment runs locally with subprocesses.
|
# Option 1: The experiment runs locally with subprocesses.
|
||||||
# MODE=local
|
# MODE=local
|
||||||
|
@ -53,7 +52,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
||||||
# It's the user's responsibility to tune them appropriately.
|
# It's the user's responsibility to tune them appropriately.
|
||||||
unset CLUSTER_SPEC_PATH
|
unset CLUSTER_SPEC_PATH
|
||||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||||
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
|
|
||||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||||
python3 -m realhf.apps.quickstart ppo-math \
|
python3 -m realhf.apps.quickstart ppo-math \
|
||||||
mode=$MODE \
|
mode=$MODE \
|
||||||
|
|
|
@ -17,7 +17,6 @@ BASE_MODEL_PATH="$1"
|
||||||
|
|
||||||
# original data
|
# original data
|
||||||
DATA_PATH="$2"
|
DATA_PATH="$2"
|
||||||
REAL_MATH_METADATA_PATH="$3"
|
|
||||||
|
|
||||||
# Option 1: The experiment runs locally with subprocesses.
|
# Option 1: The experiment runs locally with subprocesses.
|
||||||
# MODE=local
|
# MODE=local
|
||||||
|
@ -52,7 +51,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
||||||
# It's the user's responsibility to tune them appropriately.
|
# It's the user's responsibility to tune them appropriately.
|
||||||
unset CLUSTER_SPEC_PATH
|
unset CLUSTER_SPEC_PATH
|
||||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||||
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
|
|
||||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||||
python3 -m realhf.apps.quickstart ppo-math \
|
python3 -m realhf.apps.quickstart ppo-math \
|
||||||
mode=$MODE \
|
mode=$MODE \
|
||||||
|
|
|
@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
|
||||||
|
|
||||||
# original data
|
# original data
|
||||||
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
||||||
REAL_CODE_METADATA_PATH="/storage/datasets/codeparrot-apps-test.jsonl"
|
|
||||||
|
|
||||||
# Option 1: The experiment runs locally with subprocesses.
|
# Option 1: The experiment runs locally with subprocesses.
|
||||||
# MODE=local
|
# MODE=local
|
||||||
|
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
||||||
# It's the user's responsibility to tune them appropriately.
|
# It's the user's responsibility to tune them appropriately.
|
||||||
unset CLUSTER_SPEC_PATH
|
unset CLUSTER_SPEC_PATH
|
||||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||||
REAL_CODE_METADATA_PATH=${REAL_CODE_METADATA_PATH} \
|
|
||||||
FUNCTIONCALL_SERVICE_DOMAIN="" \
|
FUNCTIONCALL_SERVICE_DOMAIN="" \
|
||||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||||
python3 -m realhf.apps.quickstart ppo-code \
|
python3 -m realhf.apps.quickstart ppo-code \
|
||||||
|
|
|
@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
|
||||||
|
|
||||||
# original data
|
# original data
|
||||||
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
||||||
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
|
|
||||||
|
|
||||||
# Option 1: The experiment runs locally with subprocesses.
|
# Option 1: The experiment runs locally with subprocesses.
|
||||||
# MODE=local
|
# MODE=local
|
||||||
|
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
||||||
# It's the user's responsibility to tune them appropriately.
|
# It's the user's responsibility to tune them appropriately.
|
||||||
unset CLUSTER_SPEC_PATH
|
unset CLUSTER_SPEC_PATH
|
||||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||||
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
|
|
||||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||||
python3 -m realhf.apps.quickstart ppo-math \
|
python3 -m realhf.apps.quickstart ppo-math \
|
||||||
mode=$MODE \
|
mode=$MODE \
|
||||||
|
|
|
@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
|
||||||
|
|
||||||
# original data
|
# original data
|
||||||
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
||||||
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
|
|
||||||
|
|
||||||
# Option 1: The experiment runs locally with subprocesses.
|
# Option 1: The experiment runs locally with subprocesses.
|
||||||
# MODE=local
|
# MODE=local
|
||||||
|
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
||||||
# It's the user's responsibility to tune them appropriately.
|
# It's the user's responsibility to tune them appropriately.
|
||||||
unset CLUSTER_SPEC_PATH
|
unset CLUSTER_SPEC_PATH
|
||||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||||
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
|
|
||||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||||
python3 -m realhf.apps.quickstart ppo-math \
|
python3 -m realhf.apps.quickstart ppo-math \
|
||||||
mode=$MODE \
|
mode=$MODE \
|
||||||
|
|
|
@ -20,7 +20,6 @@ BASE_MODEL_PATH="/storage/models/${MODEL_NAME}"
|
||||||
|
|
||||||
# original data
|
# original data
|
||||||
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
DATA_PATH="/storage/datasets/${DATASET_NAME}"
|
||||||
REAL_MATH_METADATA_PATH="/storage/datasets/id2info.json"
|
|
||||||
|
|
||||||
# Option 1: The experiment runs locally with subprocesses.
|
# Option 1: The experiment runs locally with subprocesses.
|
||||||
# MODE=local
|
# MODE=local
|
||||||
|
@ -55,7 +54,6 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
||||||
# It's the user's responsibility to tune them appropriately.
|
# It's the user's responsibility to tune them appropriately.
|
||||||
unset CLUSTER_SPEC_PATH
|
unset CLUSTER_SPEC_PATH
|
||||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||||
REAL_MATH_METADATA_PATH=${REAL_MATH_METADATA_PATH} \
|
|
||||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||||
python3 -m realhf.apps.quickstart ppo-math \
|
python3 -m realhf.apps.quickstart ppo-math \
|
||||||
mode=$MODE \
|
mode=$MODE \
|
||||||
|
|
|
@ -5,27 +5,14 @@ import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from typing import Dict, List
|
||||||
from function.testing_util import run_test
|
|
||||||
|
|
||||||
from functioncall.base import logging
|
from functioncall.base import logging
|
||||||
|
from functioncall.code.function.testing_util import run_test
|
||||||
|
|
||||||
logger = logging.getLogger("Functioncall")
|
logger = logging.getLogger("Functioncall")
|
||||||
|
|
||||||
|
|
||||||
def try_load_json(data):
|
|
||||||
"""
|
|
||||||
Attempts to load the given data as JSON. If successful, returns the parsed JSON.
|
|
||||||
Otherwise, returns None and logs an error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
loaded_data = json.loads(data)
|
|
||||||
return loaded_data
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.info(f"Failed to load JSON: {e}")
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def capture_stdout(code):
|
def capture_stdout(code):
|
||||||
original_stdout = sys.stdout
|
original_stdout = sys.stdout
|
||||||
fake_stdout = StringIO()
|
fake_stdout = StringIO()
|
||||||
|
@ -51,7 +38,7 @@ def _temp_run(problem, generation, debug, result):
|
||||||
logger.debug(f"Test completed with result: {result}")
|
logger.debug(f"Test completed with result: {result}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if debug:
|
if debug:
|
||||||
logger.error(f"Error in _temp_run: {e}, problem:{problem}", exe_info=True)
|
logger.error(f"Error in _temp_run: {e}, problem:{problem}")
|
||||||
|
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -65,7 +52,7 @@ def check_correctness(problem, generation, timeout, debug=False):
|
||||||
inside `run_test`"""
|
inside `run_test`"""
|
||||||
if debug:
|
if debug:
|
||||||
result = capture_stdout(
|
result = capture_stdout(
|
||||||
"from function.testing_util import run_test\n"
|
"from functioncall.code.function.testing_util import run_test\n"
|
||||||
+ "run_test(sample=problem, test=generation, debug=debug)"
|
+ "run_test(sample=problem, test=generation, debug=debug)"
|
||||||
)
|
)
|
||||||
return result[0], result[1]
|
return result[0], result[1]
|
||||||
|
@ -86,7 +73,7 @@ def check_correctness(problem, generation, timeout, debug=False):
|
||||||
# Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
|
# Remark: ideally we would consider that all tests failed but we can't access number of tests here easily
|
||||||
# so we use 21=the average number of tests for a smaple in the test split instead
|
# so we use 21=the average number of tests for a smaple in the test split instead
|
||||||
avg_number_tests = 21
|
avg_number_tests = 21
|
||||||
result = [[-1] * avg_number_tests]
|
result = [[-1, ""] for _ in range(avg_number_tests)]
|
||||||
if debug:
|
if debug:
|
||||||
logger.debug(f"Global timeout occurred, returning default result.")
|
logger.debug(f"Global timeout occurred, returning default result.")
|
||||||
if debug:
|
if debug:
|
||||||
|
@ -99,96 +86,50 @@ def check_correctness(problem, generation, timeout, debug=False):
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
def load_problems(path):
|
def code_verify(id2info, generateds, query_ids, debug=True):
|
||||||
problem_map = {}
|
assert len(generateds) == len(query_ids)
|
||||||
for idx, line in enumerate(open(path, "rb")):
|
problems = [id2info[qid] for qid in query_ids]
|
||||||
if line is not None:
|
|
||||||
line = line.strip().decode("utf-8")
|
|
||||||
row = json.loads(line)
|
|
||||||
|
|
||||||
query_id = str(row["id"])
|
|
||||||
problem_map[query_id] = {
|
|
||||||
"problem_id": query_id,
|
|
||||||
# "question": row["question"],
|
|
||||||
"solutions": row["solutions"],
|
|
||||||
"input_output": row["input_output"],
|
|
||||||
# "difficulty": level,
|
|
||||||
# "url": row["url"],
|
|
||||||
# "starter_code": row["starter_code"],
|
|
||||||
}
|
|
||||||
|
|
||||||
return problem_map
|
|
||||||
|
|
||||||
|
|
||||||
global_problems = load_problems(
|
|
||||||
os.getenv(
|
|
||||||
"REAL_CODE_METADATA_PATH",
|
|
||||||
"/storage/datasets/codeparrot-apps-test.jsonl",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def code_verify(generateds, query_ids, debug=False):
|
|
||||||
assert len(generateds) == len(query_ids), (
|
|
||||||
len(generateds),
|
|
||||||
len(query_ids),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
global global_problems
|
|
||||||
|
|
||||||
for idx, query_id in enumerate(query_ids):
|
for query_id, generated, problem in zip(query_ids, generateds, problems):
|
||||||
if query_id not in global_problems:
|
|
||||||
continue
|
|
||||||
|
|
||||||
problem = global_problems[query_id]
|
|
||||||
test_code = generateds[idx]
|
|
||||||
logger.debug(f"run_batch_code, query_id: {query_id}")
|
logger.debug(f"run_batch_code, query_id: {query_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
curr_res, metadata = check_correctness(
|
curr_res, metadata = check_correctness(
|
||||||
problem=problem, generation=test_code, timeout=6000, debug=debug
|
problem=problem, generation=generated, timeout=6000, debug=debug
|
||||||
)
|
)
|
||||||
|
|
||||||
if any(x != True for x in curr_res):
|
if any(x != True for x in curr_res):
|
||||||
logger.debug(
|
logger.debug(f"id:{query_id}, Results were not all True: {metadata}")
|
||||||
f'id:{problem["problem_id"]}, Results were not all True: {metadata}'
|
result.append(0)
|
||||||
)
|
|
||||||
result.append(f"{query_id} failed")
|
|
||||||
else:
|
else:
|
||||||
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
|
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
|
||||||
result.append(f"{query_id} success")
|
result.append(1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"test framework exception = {repr(e)}{e}\n", exe_info=True)
|
exc_info = sys.exc_info()
|
||||||
result.append(f"{query_id} failed")
|
logger.error(
|
||||||
break
|
f"test framework exception = {repr(e)}{e}\n{traceback.format_exception(*exc_info)}"
|
||||||
finally:
|
)
|
||||||
# print(f"id:{problem["problem_id"]}, result : {curr_res}")
|
result.append(0)
|
||||||
# assert isinstance(curr_res, list)
|
|
||||||
pass
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
|
||||||
|
data = []
|
||||||
|
with open(path, "r") as f:
|
||||||
|
code_data = [json.loads(l) for l in f.readlines()]
|
||||||
|
|
||||||
def create_test_params(index_list):
|
problem = code_data[0]
|
||||||
global global_problems
|
problem["problem_id"] = problem["id"]
|
||||||
codes, query_ids = [], []
|
id2info = {problem["problem_id"]: problem}
|
||||||
|
|
||||||
for index in index_list:
|
result = code_verify(
|
||||||
if str(index) not in global_problems:
|
id2info,
|
||||||
continue
|
[json.loads(problem["solutions"])[0]],
|
||||||
|
[problem["problem_id"]],
|
||||||
problem = global_problems[str(index)]
|
debug=False,
|
||||||
if "solutions" not in problem or not problem["solutions"]:
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
codes.append(try_load_json(problem["solutions"])[0])
|
|
||||||
query_ids.append(str(index))
|
|
||||||
return codes, query_ids
|
|
||||||
|
|
||||||
codes, query_ids = create_test_params(list(range(805, 806)))
|
|
||||||
result = code_verify(codes, query_ids, True)
|
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from functioncall.base import logging
|
from functioncall.base import logging
|
||||||
|
@ -8,29 +9,14 @@ from functioncall.base.call import batch_function_call
|
||||||
logger = logging.getLogger("Functioncall")
|
logger = logging.getLogger("Functioncall")
|
||||||
|
|
||||||
|
|
||||||
def try_load_json(data):
|
def load_problems_with_testcase_batch(
|
||||||
"""
|
id2info, query_ids, debug=False, test_case_batch_size=None
|
||||||
Attempts to load the given data as JSON. If successful, returns the parsed JSON.
|
):
|
||||||
Otherwise, returns None and logs an error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
loaded_data = json.loads(data)
|
|
||||||
return loaded_data
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
# print(f"Failed to load JSON: {e}")
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=None):
|
|
||||||
problem_map = defaultdict(list)
|
problem_map = defaultdict(list)
|
||||||
for idx, line in enumerate(open(path, "rb")):
|
for idx, query_id in enumerate(query_ids):
|
||||||
if line is None:
|
problem = id2info[query_id]
|
||||||
continue
|
|
||||||
|
|
||||||
# parse one problem
|
# parse one problem
|
||||||
row = json.loads(line.strip().decode("utf-8"))
|
input_output = json.loads(problem["input_output"])
|
||||||
query_id = str(row.get("id", row.get("query_id")))
|
|
||||||
input_output = json.loads(row["input_output"]) if "input_output" in row else {}
|
|
||||||
inputs = input_output.get("inputs", [])
|
inputs = input_output.get("inputs", [])
|
||||||
outputs = input_output.get("outputs", [])
|
outputs = input_output.get("outputs", [])
|
||||||
assert len(inputs) == len(
|
assert len(inputs) == len(
|
||||||
|
@ -57,17 +43,14 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
|
||||||
"batche_index": batch_idx,
|
"batche_index": batch_idx,
|
||||||
}
|
}
|
||||||
if debug:
|
if debug:
|
||||||
sub_problem["solutions"] = row.get("solutions", [])
|
sub_problem["solutions"] = problem.get("solutions", [])
|
||||||
problem_map[query_id].append(sub_problem)
|
problem_map[query_id].append(sub_problem)
|
||||||
|
|
||||||
return problem_map
|
return problem_map
|
||||||
|
|
||||||
|
|
||||||
global_problems = None
|
|
||||||
|
|
||||||
|
|
||||||
def code_verify(
|
def code_verify(
|
||||||
generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
|
id2info, generateds, query_ids, debug=False, timeout=1000, timeout_for_testcase=6
|
||||||
):
|
):
|
||||||
assert len(generateds) == len(query_ids), (
|
assert len(generateds) == len(query_ids), (
|
||||||
len(generateds),
|
len(generateds),
|
||||||
|
@ -75,24 +58,13 @@ def code_verify(
|
||||||
)
|
)
|
||||||
payload_list = []
|
payload_list = []
|
||||||
|
|
||||||
global global_problems
|
global_problems = load_problems_with_testcase_batch(
|
||||||
if global_problems is None:
|
id2info,
|
||||||
global_problems = load_problems_with_testcase_batch(
|
query_ids,
|
||||||
os.getenv(
|
debug=True,
|
||||||
"REAL_CODE_METADATA_PATH",
|
test_case_batch_size=20,
|
||||||
"/storage/datasets/codeparrot-apps-test.jsonl",
|
)
|
||||||
),
|
|
||||||
debug=True,
|
|
||||||
test_case_batch_size=20,
|
|
||||||
)
|
|
||||||
for idx, query_id in enumerate(query_ids):
|
for idx, query_id in enumerate(query_ids):
|
||||||
if query_id not in global_problems:
|
|
||||||
payload_list.append(None)
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid query id : {query_id}, type: {type(query_id)}, should be in problem dataset!"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
problems = global_problems[query_id]
|
problems = global_problems[query_id]
|
||||||
for problem in problems:
|
for problem in problems:
|
||||||
payload_list.append(
|
payload_list.append(
|
||||||
|
@ -129,34 +101,28 @@ def code_verify(
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
path = "/storage/openpsi/data/code/apps/codeparrot-apps-test.jsonl"
|
||||||
|
data = []
|
||||||
|
with open(path, "r") as f:
|
||||||
|
code_data = [json.loads(l) for l in f.readlines()]
|
||||||
|
|
||||||
|
id2info = {}
|
||||||
|
|
||||||
def create_test_params(count=10):
|
def create_test_params(count=10):
|
||||||
global global_problems
|
global id2info
|
||||||
if global_problems is None:
|
query_ids = []
|
||||||
global_problems = load_problems_with_testcase_batch(
|
generateds = []
|
||||||
os.getenv(
|
cnt = 0
|
||||||
"REAL_CODE_METADATA_PATH",
|
while cnt < count:
|
||||||
"/storage/datasets/codeparrot-apps-test.jsonl",
|
d = random.choice(code_data)
|
||||||
),
|
if not d["solutions"]:
|
||||||
debug=True,
|
|
||||||
test_case_batch_size=20,
|
|
||||||
)
|
|
||||||
codes, query_ids = [], []
|
|
||||||
idx = 0
|
|
||||||
for query_id, problems in global_problems.items():
|
|
||||||
if idx >= count:
|
|
||||||
break
|
|
||||||
|
|
||||||
problem = problems[0]
|
|
||||||
if "solutions" not in problem or not problem["solutions"]:
|
|
||||||
continue
|
continue
|
||||||
|
id2info[d["id"]] = d
|
||||||
|
query_ids.append(d["id"])
|
||||||
|
generateds.append(d["solutions"][0])
|
||||||
|
cnt += 1
|
||||||
|
return generateds, query_ids
|
||||||
|
|
||||||
codes.append(try_load_json(problem["solutions"])[0])
|
generateds, query_ids = create_test_params(100)
|
||||||
query_ids.append(query_id)
|
result = code_verify(id2info, generateds, query_ids, True)
|
||||||
idx += 1
|
|
||||||
|
|
||||||
return codes, query_ids
|
|
||||||
|
|
||||||
codes, query_ids = create_test_params(100)
|
|
||||||
result = code_verify(codes, query_ids, True)
|
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
@ -9,28 +9,9 @@ from functioncall.base.call import batch_function_call
|
||||||
logger = logging.getLogger("Functioncall")
|
logger = logging.getLogger("Functioncall")
|
||||||
|
|
||||||
|
|
||||||
def loadJson(dataDir):
|
def math_verify(
|
||||||
with open(dataDir, "r") as f:
|
id2info, generateds: List, query_ids: List, batch_size=10, timeout=1000
|
||||||
if dataDir.endswith(".jsonl"):
|
) -> List:
|
||||||
samples = [json.loads(line) for line in f.readlines()]
|
|
||||||
else:
|
|
||||||
samples = json.load(f)
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
id2info = None
|
|
||||||
|
|
||||||
|
|
||||||
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000) -> List:
|
|
||||||
global id2info
|
|
||||||
if id2info is None:
|
|
||||||
id2info = loadJson(
|
|
||||||
os.getenv(
|
|
||||||
"REAL_MATH_MEATADATA_PATH",
|
|
||||||
"/storage/datasets/id2info.json",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert len(generateds) == len(query_ids), (
|
assert len(generateds) == len(query_ids), (
|
||||||
len(generateds),
|
len(generateds),
|
||||||
len(query_ids),
|
len(query_ids),
|
||||||
|
@ -95,27 +76,19 @@ def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=1000)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# sample = {
|
sample = {
|
||||||
# "prompt": "",
|
"answers": ["-\\frac{2}{3}"],
|
||||||
# "query_id": "fe11b471-1aa9-4867-958f-a0a811c85f92",
|
"solutions": [
|
||||||
# "answer": "\\boxed{-\\frac{1}{30}}",
|
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
|
||||||
# }
|
"\\boxed{-\\frac{2}{3}}",
|
||||||
|
],
|
||||||
if id2info is None:
|
}
|
||||||
id2info = loadJson(
|
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
|
||||||
os.getenv(
|
|
||||||
"REAL_MATH_MEATADATA_PATH",
|
|
||||||
"/storage/datasets/id2info.json",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
answers = []
|
|
||||||
query_ids = []
|
|
||||||
|
|
||||||
for id, value in id2info.items():
|
|
||||||
answers.append(value["solutions"][0])
|
|
||||||
query_ids.append(id)
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
result = math_verify(answers[:200], query_ids[:200])
|
result = math_verify(
|
||||||
|
id2info,
|
||||||
|
sample["answers"] * 100,
|
||||||
|
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
|
||||||
|
)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
@ -160,8 +160,6 @@ def main_start(args, recover_count: int = 0):
|
||||||
CLUSTER_SPEC_PATH=cluster_spec_path,
|
CLUSTER_SPEC_PATH=cluster_spec_path,
|
||||||
REAL_RECOVER_RUN="1" if is_recover_run else "0",
|
REAL_RECOVER_RUN="1" if is_recover_run else "0",
|
||||||
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
|
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
|
||||||
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
|
|
||||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
|
||||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||||
)
|
)
|
||||||
for k, v in BASE_ENVIRONS.items():
|
for k, v in BASE_ENVIRONS.items():
|
||||||
|
|
|
@ -206,9 +206,9 @@ class LocalMultiProcessTest:
|
||||||
|
|
||||||
|
|
||||||
def init_global_constants(
|
def init_global_constants(
|
||||||
num_dp,
|
num_dp=1,
|
||||||
num_mp,
|
num_mp=1,
|
||||||
num_pp,
|
num_pp=1,
|
||||||
topo=None,
|
topo=None,
|
||||||
model_name=None,
|
model_name=None,
|
||||||
msid2mwid=None,
|
msid2mwid=None,
|
||||||
|
@ -217,7 +217,12 @@ def init_global_constants(
|
||||||
gradient_accumulation_fusion=False,
|
gradient_accumulation_fusion=False,
|
||||||
max_prompt_len=None,
|
max_prompt_len=None,
|
||||||
is_train: bool = True,
|
is_train: bool = True,
|
||||||
|
expr_name=None,
|
||||||
|
trial_name=None,
|
||||||
):
|
):
|
||||||
|
expr_name = expr_name if expr_name is not None else _DEFAULT_EXPR_NAME
|
||||||
|
trial_name = trial_name if trial_name is not None else _DEFAULT_TRIAL_NAME
|
||||||
|
constants.set_experiment_trial_names(expr_name, trial_name)
|
||||||
model_name = model_name if model_name is not None else MODEL_NAME
|
model_name = model_name if model_name is not None else MODEL_NAME
|
||||||
|
|
||||||
if topo is None:
|
if topo is None:
|
||||||
|
|
|
@ -1,562 +0,0 @@
|
||||||
# Copyright 2025 Ant Group Inc.
|
|
||||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
||||||
import copy
|
|
||||||
import dataclasses
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import pprint
|
|
||||||
from typing import *
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
|
|
||||||
import realhf.base.logging as logging
|
|
||||||
from realhf.api.core.config import (
|
|
||||||
DatasetAbstraction,
|
|
||||||
ModelInterfaceAbstraction,
|
|
||||||
ModelInterfaceType,
|
|
||||||
)
|
|
||||||
from realhf.api.core.dfg import MFCDef, ParamReallocHook
|
|
||||||
from realhf.api.core.model_api import GenerationHyperparameters
|
|
||||||
from realhf.api.core.system_api import ExperimentConfig
|
|
||||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
|
||||||
from realhf.api.quickstart.device_mesh import DeviceMesh, MFCConfig, RPCAllocation
|
|
||||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
|
||||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
|
||||||
from realhf.experiments.common.common import CommonExperimentConfig
|
|
||||||
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
|
|
||||||
|
|
||||||
logger = logging.getLogger("PPO Code exp", "colored")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class PPOHyperparameters:
|
|
||||||
"""Configuration of PPO hyperparameters.
|
|
||||||
|
|
||||||
:param gen: Generation hyperparameters.
|
|
||||||
:type gen: GenerationHyperparameters
|
|
||||||
:param ppo_n_minibatches: Number of minibatches in each PPO update.
|
|
||||||
:type ppo_n_minibatches: int
|
|
||||||
:param kl_ctl: Coefficient of KL divergence rewards.
|
|
||||||
:type kl_ctl: float
|
|
||||||
:param discount: Discount factor.
|
|
||||||
:type discount: float
|
|
||||||
:param gae_lambda: Lambda factor in GAE.
|
|
||||||
:type gae_lambda: float
|
|
||||||
:param eps_clip: PPO actor probability ratio clipping factor.
|
|
||||||
:type eps_clip: float
|
|
||||||
:param value_eps_clip: PPO value clipping factor.
|
|
||||||
:type value_eps_clip: float
|
|
||||||
:param max_reward_clip: Maximum reward value.
|
|
||||||
:type max_reward_clip: float
|
|
||||||
:param reward_output_scaling: Scaling factor of the reward model output.
|
|
||||||
:type reward_output_scaling: float
|
|
||||||
:param reward_output_bias: Bias of the reward model output.
|
|
||||||
The number outputed by the reward model will be
|
|
||||||
CLIP((x - bias) * scaling, -max_reward_clip, max_reward_clip).
|
|
||||||
:type reward_output_bias: float
|
|
||||||
:param early_stop_imp_ratio: PPO update will be early stopped if importance ratio
|
|
||||||
exceeds this maximum value.
|
|
||||||
:type early_stop_imp_ratio: float
|
|
||||||
:param use_adaptive_kl_ctl: Whether to use adaptive KL divergence coefficient.
|
|
||||||
:type use_adaptive_kl_ctl: bool
|
|
||||||
:param adv_norm: Whether to use advantage normalization.
|
|
||||||
:type adv_norm: bool
|
|
||||||
:param value_norm: Whether to denormalize valued and normalize return predictions.
|
|
||||||
:type value_norm: bool
|
|
||||||
:param value_norm_type: Type of value normalization.
|
|
||||||
Either exponential moving average ("exp") or moving average ("ma").
|
|
||||||
:type value_norm_type: str
|
|
||||||
:param value_norm_beta: Exponential decay factor
|
|
||||||
in exponential moving average.
|
|
||||||
:type value_norm_beta: float
|
|
||||||
:param value_norm_eps: Epsilon factor in the
|
|
||||||
denominator of exponential moving average.
|
|
||||||
:type value_norm_eps: float
|
|
||||||
:param disable_value: A shortcut option to disable the critic model.
|
|
||||||
:type disable_value: bool
|
|
||||||
"""
|
|
||||||
|
|
||||||
gen: GenerationHyperparameters = dataclasses.field(
|
|
||||||
default_factory=GenerationHyperparameters
|
|
||||||
)
|
|
||||||
ppo_n_minibatches: int = 4
|
|
||||||
kl_ctl: float = 0.1
|
|
||||||
discount: float = 1.0
|
|
||||||
gae_lambda: float = 1.0
|
|
||||||
eps_clip: float = 0.2
|
|
||||||
value_eps_clip: float = 0.2
|
|
||||||
disable_value: bool = False
|
|
||||||
max_reward_clip: float = 20.0
|
|
||||||
reward_output_scaling: float = 1.0
|
|
||||||
reward_output_bias: float = 0.0
|
|
||||||
early_stop_imp_ratio: float = 5.0
|
|
||||||
use_adaptive_kl_ctl: bool = False
|
|
||||||
adv_norm: bool = True
|
|
||||||
value_norm: bool = True
|
|
||||||
value_norm_type: str = dataclasses.field(
|
|
||||||
metadata={"choices": ["exp", "ma"]}, default="exp"
|
|
||||||
)
|
|
||||||
value_norm_beta: float = 0.99995
|
|
||||||
value_norm_eps: float = 1e-5
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class PPOCODEConfig(CommonExperimentConfig):
|
|
||||||
"""PPO experiment configuration.
|
|
||||||
|
|
||||||
It is a subclass of :class:`CommonExperimentConfig`,
|
|
||||||
so all CLI options in the base class are available.
|
|
||||||
|
|
||||||
We don't implement runtime evaluation for PPO.
|
|
||||||
|
|
||||||
We identify that the RLHF process is composed of four
|
|
||||||
distinct models with independent parameters and six
|
|
||||||
*model function calls* upon these models.
|
|
||||||
|
|
||||||
The four models are\:
|
|
||||||
|
|
||||||
- Actor\: The primary LLM that generates text.
|
|
||||||
- Critic\: The value function that estimates the value of a state.
|
|
||||||
- Ref\: The reference LLM that provides KL regularization.
|
|
||||||
- Rew\: The reward model that provides reward signals.
|
|
||||||
|
|
||||||
The four model function calls and their dependencies are\:
|
|
||||||
|
|
||||||
- Rollout\: Generate text from the actor model.
|
|
||||||
- InfReward\: Infer rewards from the reward model given generated text.
|
|
||||||
- InfRef\: Infer log probabilities from the reference model given generated text.
|
|
||||||
- InfValues\: Infer values from the critic model given generated text.
|
|
||||||
- TrainActor\: Train the actor model given generated text, rewards, values, and reference log probabilities.
|
|
||||||
- TrainCritic\: Train the critic model given generated text, rewards, values, and reference log probabilities.
|
|
||||||
|
|
||||||
This class resolves these dependencies under the hood.
|
|
||||||
What the users should specify are the runtime configurations
|
|
||||||
of models and allocations of *each model function call*.
|
|
||||||
|
|
||||||
:param actor: Runtime configuration of the primary LLM.
|
|
||||||
:type actor: ModelTrainEvalConfig
|
|
||||||
:param critic: Runtime configuration of the critic model of PPO.
|
|
||||||
:type critic: ModelTrainEvalConfig
|
|
||||||
:param ref: Runtime configuration of the reference LLM.
|
|
||||||
:type ref: ModelTrainEvalConfig
|
|
||||||
:param rew: Runtime configuration of the reward LLM.
|
|
||||||
:type rew: ModelTrainEvalConfig
|
|
||||||
:param actor_train: :class:`MFCConfig` for TrainActor.
|
|
||||||
:type actor_train: MFCConfig
|
|
||||||
:param critic_train: :class:`MFCConfig` for TrainCritic.
|
|
||||||
:type critic_train: MFCConfig
|
|
||||||
:param actor_gen: :class:`MFCConfig` for Rollout.
|
|
||||||
:type actor_gen: MFCConfig
|
|
||||||
:param critic_inf: :class:`MFCConfig` for InfValues.
|
|
||||||
:type critic_inf: MFCConfig
|
|
||||||
:param rew_inf: :class:`MFCConfig` for InfReward.
|
|
||||||
:type rew_inf: MFCConfig
|
|
||||||
:param ref_inf: :class:`MFCConfig` for InfRef.
|
|
||||||
:type ref_inf: MFCConfig
|
|
||||||
:param dataset: Dataset configuration.
|
|
||||||
:type dataset: PromptOnlyDatasetConfig
|
|
||||||
:param ppo: Configuration for the PPO algorithm.
|
|
||||||
:type ppo: PPOHyperparameters
|
|
||||||
:param group_size: The number of answers remained for each prompt.
|
|
||||||
:type group_size: int
|
|
||||||
:param generation_size: The number of answers sampled for each prompt.
|
|
||||||
Among them, only `group_size` samples are remained according to
|
|
||||||
the reward score, aka best-of-n sampling. If None, use `group_size`.
|
|
||||||
:type generation_size: Optional[int]
|
|
||||||
:param mask_no_eos_with_zero: Whether to mask out the reward if an
|
|
||||||
answer is truncated due to exceeding the length limit.
|
|
||||||
:type mask_no_eos_with_zero: bool
|
|
||||||
:param mask_too_long: Whether to mask out the PPO loss if an
|
|
||||||
answer is truncated due to exceeding the length limit.
|
|
||||||
:type mask_too_long: bool
|
|
||||||
:param check_verifier_status: If True, raise an error
|
|
||||||
when the reward is all-zero. This usually indicates a bug
|
|
||||||
of the verifier.
|
|
||||||
:type check_verifier_status: bool
|
|
||||||
:param group_adv_norm: Whther to use grouped advantage
|
|
||||||
normaliztion in GRPO.
|
|
||||||
:type group_adv_norm: bool
|
|
||||||
"""
|
|
||||||
|
|
||||||
actor: ModelTrainEvalConfig = dataclasses.field(
|
|
||||||
default_factory=ModelTrainEvalConfig
|
|
||||||
)
|
|
||||||
critic: ModelTrainEvalConfig = dataclasses.field(
|
|
||||||
default_factory=ModelTrainEvalConfig
|
|
||||||
)
|
|
||||||
ref: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
|
||||||
rew: ModelTrainEvalConfig = dataclasses.field(default_factory=ModelTrainEvalConfig)
|
|
||||||
|
|
||||||
# for manual allocation only
|
|
||||||
actor_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
|
||||||
critic_train: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
|
||||||
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
|
||||||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
|
||||||
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
|
||||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
|
||||||
|
|
||||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
|
||||||
default_factory=PromptOnlyDatasetConfig
|
|
||||||
)
|
|
||||||
|
|
||||||
ppo: PPOHyperparameters = dataclasses.field(default_factory=PPOHyperparameters)
|
|
||||||
|
|
||||||
group_size: int = 1
|
|
||||||
generation_size: Optional[int] = None
|
|
||||||
mask_no_eos_with_zero: bool = False
|
|
||||||
rm_output_scaling: Optional[float] = None
|
|
||||||
ref_ema_eta: Optional[float] = None
|
|
||||||
group_adv_norm: bool = False
|
|
||||||
mask_too_long: bool = False
|
|
||||||
rw_type: Optional[str] = "sparse"
|
|
||||||
task: str = "code"
|
|
||||||
check_xml_format: bool = False
|
|
||||||
use_dense_reward: bool = False
|
|
||||||
reward_delta: bool = True
|
|
||||||
|
|
||||||
check_verifier_status: bool = False
|
|
||||||
|
|
||||||
dataset_filter_threshold: float = 100.0
|
|
||||||
dataset_max_filter_percentage: float = 0.0
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
|
|
||||||
self.ppo_kwargs = dict(
|
|
||||||
n_minibatches=self.ppo.ppo_n_minibatches,
|
|
||||||
kl_ctl=self.ppo.kl_ctl,
|
|
||||||
discount=self.ppo.discount,
|
|
||||||
gae_lambda=self.ppo.gae_lambda,
|
|
||||||
eps_clip=self.ppo.eps_clip,
|
|
||||||
value_eps_clip=self.ppo.value_eps_clip,
|
|
||||||
max_reward_clip=self.ppo.max_reward_clip,
|
|
||||||
adaptive_kl_ctl=self.ppo.use_adaptive_kl_ctl,
|
|
||||||
value_norm=self.ppo.value_norm,
|
|
||||||
value_norm_type=self.ppo.value_norm_type,
|
|
||||||
value_norm_beta=self.ppo.value_norm_beta,
|
|
||||||
value_norm_eps=self.ppo.value_norm_eps,
|
|
||||||
disable_value=self.ppo.disable_value,
|
|
||||||
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.rm_output_scaling is None:
|
|
||||||
self.rm_output_scaling = self.ppo.reward_output_scaling
|
|
||||||
|
|
||||||
@property
|
|
||||||
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
|
||||||
# role to config
|
|
||||||
if self.ppo.disable_value:
|
|
||||||
return {
|
|
||||||
"actor": self.actor,
|
|
||||||
# "critic": self.critic,
|
|
||||||
"ref": self.ref,
|
|
||||||
"reward": self.rew,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"actor": self.actor,
|
|
||||||
"critic": self.critic,
|
|
||||||
"ref": self.ref,
|
|
||||||
"reward": self.rew,
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def rpcs(self):
|
|
||||||
if (
|
|
||||||
self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens
|
|
||||||
> self.actor.vllm.max_seq_len_to_capture
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"vllm max seq len to capture {self.actor.vllm.max_seq_len_to_capture} is "
|
|
||||||
f"smaller than the prompt length + generation length {self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
|
||||||
)
|
|
||||||
if not os.path.exists(os.getenv("REAL_CODE_METADATA_PATH")):
|
|
||||||
raise RuntimeError(
|
|
||||||
"Dataset json path REAL_CODE_METADATA_PATH does not exist."
|
|
||||||
)
|
|
||||||
|
|
||||||
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
|
||||||
if not (domain.startswith("http://") and ":" in domain):
|
|
||||||
raise RuntimeError(
|
|
||||||
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
|
|
||||||
)
|
|
||||||
|
|
||||||
# interfaces
|
|
||||||
actor_interface = ModelInterfaceAbstraction(
|
|
||||||
"ppo_actor",
|
|
||||||
args={
|
|
||||||
**copy.deepcopy(self.ppo_kwargs),
|
|
||||||
# NOTE: to_container converts the object to a dict
|
|
||||||
# It is used for unifying the profiling API, which requires to
|
|
||||||
# pass external interface configurations in the launch command.
|
|
||||||
# Customized dataclass objects will not work in that case.
|
|
||||||
"generation_config": (
|
|
||||||
OmegaConf.to_container(self.ppo.gen, resolve=True)
|
|
||||||
if isinstance(self.ppo.gen, (OmegaConf, DictConfig))
|
|
||||||
else dataclasses.asdict(self.ppo.gen)
|
|
||||||
),
|
|
||||||
"early_stop_imp_ratio": self.ppo.early_stop_imp_ratio,
|
|
||||||
"adv_norm": self.ppo.adv_norm,
|
|
||||||
"group_size": self.group_size,
|
|
||||||
"generation_size": self.generation_size,
|
|
||||||
"group_adv_norm": self.group_adv_norm,
|
|
||||||
"mask_too_long": self.mask_too_long,
|
|
||||||
"use_dense_reward": self.use_dense_reward,
|
|
||||||
"reward_delta": self.reward_delta,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
ref_interface = copy.deepcopy(actor_interface)
|
|
||||||
ref_interface.args["enable_save"] = False
|
|
||||||
|
|
||||||
critic_interface = ModelInterfaceAbstraction(
|
|
||||||
"ppo_critic",
|
|
||||||
args={
|
|
||||||
**copy.deepcopy(self.ppo_kwargs),
|
|
||||||
"group_size": self.group_size,
|
|
||||||
"mask_too_long": self.mask_too_long,
|
|
||||||
"use_dense_reward": self.use_dense_reward,
|
|
||||||
"reward_delta": self.reward_delta,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
critic_interface.args.pop("eps_clip")
|
|
||||||
rw_interface = ModelInterfaceAbstraction(
|
|
||||||
"rw_code",
|
|
||||||
args=dict(
|
|
||||||
rw_type=self.rw_type,
|
|
||||||
task=self.task,
|
|
||||||
check_xml_format=self.check_xml_format,
|
|
||||||
tokenizer_path=self.actor.path,
|
|
||||||
enable_save=False,
|
|
||||||
output_scaling=self.ppo.reward_output_scaling,
|
|
||||||
rm_output_scaling=self.rm_output_scaling,
|
|
||||||
output_bias=self.ppo.reward_output_bias,
|
|
||||||
group_size=self.group_size,
|
|
||||||
check_verifier_status=self.check_verifier_status,
|
|
||||||
max_sync_length=self.ppo.gen.max_new_tokens
|
|
||||||
+ self.dataset.max_prompt_len
|
|
||||||
+ 128,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
rollout = MFCDef(
|
|
||||||
name="actor_gen",
|
|
||||||
model_name="actor",
|
|
||||||
mb_spec=self.actor_gen.mb_spec,
|
|
||||||
interface_type=ModelInterfaceType.GENERATE,
|
|
||||||
model_type=self.actor.type,
|
|
||||||
model_path=self.actor.path,
|
|
||||||
interface_impl=actor_interface,
|
|
||||||
input_keys=["packed_prompts"],
|
|
||||||
output_keys=[
|
|
||||||
"seq_no_eos_mask",
|
|
||||||
"packed_input_ids",
|
|
||||||
"packed_logprobs",
|
|
||||||
"prompt_mask",
|
|
||||||
],
|
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
inf_reward = MFCDef(
|
|
||||||
name="rew_inf",
|
|
||||||
model_name="reward",
|
|
||||||
mb_spec=self.rew_inf.mb_spec,
|
|
||||||
interface_type=ModelInterfaceType.INFERENCE,
|
|
||||||
interface_impl=rw_interface,
|
|
||||||
model_type=self.rew.type,
|
|
||||||
model_path=self.rew.path,
|
|
||||||
min_n_seqs_per_pass=1 / self.group_size,
|
|
||||||
input_keys=["packed_input_ids", "packed_prompts"],
|
|
||||||
output_keys=["rewards", "dense_rewards"],
|
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
inf_ref_inputs = ["packed_input_ids"]
|
|
||||||
inf_ref_logits = MFCDef(
|
|
||||||
name="ref_inf",
|
|
||||||
model_name="ref",
|
|
||||||
mb_spec=self.ref_inf.mb_spec,
|
|
||||||
interface_type=ModelInterfaceType.INFERENCE,
|
|
||||||
model_type=self.ref.type,
|
|
||||||
model_path=self.ref.path,
|
|
||||||
interface_impl=ref_interface,
|
|
||||||
min_n_seqs_per_pass=1 / self.group_size,
|
|
||||||
input_keys=inf_ref_inputs,
|
|
||||||
output_keys=["packed_ref_logprobs"],
|
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
inf_values = MFCDef(
|
|
||||||
name="critic_inf",
|
|
||||||
model_name="critic",
|
|
||||||
mb_spec=self.critic_inf.mb_spec,
|
|
||||||
interface_type=ModelInterfaceType.INFERENCE,
|
|
||||||
interface_impl=critic_interface,
|
|
||||||
model_type=self.critic.type,
|
|
||||||
model_path=self.critic.path,
|
|
||||||
min_n_seqs_per_pass=1 / self.group_size,
|
|
||||||
input_keys=["packed_input_ids", "seq_no_eos_mask"],
|
|
||||||
output_keys=["values"],
|
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
train_actor_inputs = [
|
|
||||||
"packed_input_ids",
|
|
||||||
"packed_logprobs",
|
|
||||||
"packed_ref_logprobs",
|
|
||||||
"rewards",
|
|
||||||
"dense_rewards",
|
|
||||||
"values",
|
|
||||||
"prompt_mask",
|
|
||||||
"seq_no_eos_mask",
|
|
||||||
]
|
|
||||||
if self.ppo.disable_value:
|
|
||||||
train_actor_inputs.remove("values")
|
|
||||||
train_actor = MFCDef(
|
|
||||||
name="actor_train",
|
|
||||||
model_name="actor",
|
|
||||||
mb_spec=self.actor_train.mb_spec,
|
|
||||||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
|
||||||
model_type=self.actor.type,
|
|
||||||
model_path=self.actor.path,
|
|
||||||
interface_impl=actor_interface,
|
|
||||||
input_keys=train_actor_inputs,
|
|
||||||
log_return_value=True,
|
|
||||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
train_critic = MFCDef(
|
|
||||||
name="critic_train",
|
|
||||||
model_name="critic",
|
|
||||||
mb_spec=self.critic_train.mb_spec,
|
|
||||||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
|
||||||
interface_impl=critic_interface,
|
|
||||||
model_type=self.critic.type,
|
|
||||||
model_path=self.critic.path,
|
|
||||||
input_keys=[
|
|
||||||
"packed_input_ids",
|
|
||||||
"packed_logprobs",
|
|
||||||
"packed_ref_logprobs",
|
|
||||||
"rewards",
|
|
||||||
"dense_rewards",
|
|
||||||
"values",
|
|
||||||
"prompt_mask",
|
|
||||||
"seq_no_eos_mask",
|
|
||||||
],
|
|
||||||
log_return_value=True,
|
|
||||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
|
||||||
)
|
|
||||||
if self.ppo.disable_value:
|
|
||||||
return {
|
|
||||||
"actor_gen": rollout,
|
|
||||||
"actor_train": train_actor,
|
|
||||||
# "critic_inf": inf_values,
|
|
||||||
# "critic_train": train_critic,
|
|
||||||
"ref_inf": inf_ref_logits,
|
|
||||||
"rew_inf": inf_reward,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"actor_gen": rollout,
|
|
||||||
"actor_train": train_actor,
|
|
||||||
"critic_inf": inf_values,
|
|
||||||
"critic_train": train_critic,
|
|
||||||
"ref_inf": inf_ref_logits,
|
|
||||||
"rew_inf": inf_reward,
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def allocations(self):
|
|
||||||
if self.ppo.disable_value:
|
|
||||||
return {
|
|
||||||
"actor_gen": self.actor_gen,
|
|
||||||
"actor_train": self.actor_train,
|
|
||||||
# "critic_inf": self.critic_inf,
|
|
||||||
# "critic_train": self.critic_train,
|
|
||||||
"ref_inf": self.ref_inf,
|
|
||||||
"rew_inf": self.rew_inf,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"actor_gen": self.actor_gen,
|
|
||||||
"actor_train": self.actor_train,
|
|
||||||
"critic_inf": self.critic_inf,
|
|
||||||
"critic_train": self.critic_train,
|
|
||||||
"ref_inf": self.ref_inf,
|
|
||||||
"rew_inf": self.rew_inf,
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def datasets(self):
|
|
||||||
return [
|
|
||||||
DatasetAbstraction(
|
|
||||||
"code_prompt",
|
|
||||||
args=dict(
|
|
||||||
dataset_path=self.dataset.path,
|
|
||||||
max_length=self.dataset.max_prompt_len,
|
|
||||||
fill_to_max_length=self.dataset.fill_to_max_length,
|
|
||||||
filter_threshold=self.dataset_filter_threshold,
|
|
||||||
max_filter_percentage=self.dataset_max_filter_percentage,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tokenizer_name_or_path(self) -> str:
|
|
||||||
return self.actor.path
|
|
||||||
|
|
||||||
@property
|
|
||||||
def search_kwargs(self):
|
|
||||||
return {
|
|
||||||
"num_gen_tokens": self.ppo.gen.max_new_tokens,
|
|
||||||
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
|
|
||||||
"seq_len": self.dataset.max_prompt_len,
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_prompt_len(self):
|
|
||||||
return self.dataset.max_prompt_len
|
|
||||||
|
|
||||||
def initial_setup(self) -> ExperimentConfig:
|
|
||||||
rpc_allocs = self._get_rpc_allocations()
|
|
||||||
|
|
||||||
resolve_replica_ids(rpc_allocs, self.models)
|
|
||||||
resolve_rpc_hooks(
|
|
||||||
rpc_allocs, self.models
|
|
||||||
) # inplace modify MFCDefs in rpc allocations
|
|
||||||
|
|
||||||
pprint.pprint(rpc_allocs)
|
|
||||||
|
|
||||||
######### update ref model using ema, ref_ema_eta = 0 means fixed ref model #########
|
|
||||||
def _find_rpc(name):
|
|
||||||
return next(alloc.rpc for alloc in rpc_allocs if alloc.rpc.name == name)
|
|
||||||
|
|
||||||
# Remove the offload hook of ref_inf, because
|
|
||||||
# we need to receive parameters from peer GPUs and update it immediately.
|
|
||||||
if self.ref_ema_eta is not None:
|
|
||||||
|
|
||||||
ref_inf = _find_rpc("ref_inf")
|
|
||||||
ref_inf._post_hooks = []
|
|
||||||
|
|
||||||
# Add an unidirectional parameter reallocation hook.
|
|
||||||
actor_train = _find_rpc("actor_train")
|
|
||||||
actor_train.add_post_hook(
|
|
||||||
ParamReallocHook(
|
|
||||||
target=ref_inf.model_name,
|
|
||||||
eta=self.ref_ema_eta,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
######### The main difference from normal PPO #########
|
|
||||||
|
|
||||||
model_worker = self._get_model_worker_configs(rpc_allocs)
|
|
||||||
|
|
||||||
return ExperimentConfig(
|
|
||||||
exp_ctrl=self.exp_ctrl,
|
|
||||||
wandb=self.wandb,
|
|
||||||
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
|
|
||||||
model_worker=model_worker,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_quickstart_exp("ppo-code", PPOCODEConfig)
|
|
|
@ -7,8 +7,6 @@ import os
|
||||||
import pprint
|
import pprint
|
||||||
from typing import *
|
from typing import *
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
|
|
||||||
import realhf.base.logging as logging
|
import realhf.base.logging as logging
|
||||||
from realhf.api.core.config import (
|
from realhf.api.core.config import (
|
||||||
DatasetAbstraction,
|
DatasetAbstraction,
|
||||||
|
@ -21,7 +19,7 @@ from realhf.api.core.system_api import ExperimentConfig
|
||||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||||
from realhf.experiments.common.common import CommonExperimentConfig
|
from realhf.experiments.common.common import CommonExperimentConfig
|
||||||
from realhf.experiments.common.utils import (
|
from realhf.experiments.common.utils import (
|
||||||
asdict,
|
asdict,
|
||||||
|
@ -89,8 +87,6 @@ class PPOHyperparameters:
|
||||||
gae_lambda: float = 1.0
|
gae_lambda: float = 1.0
|
||||||
eps_clip: float = 0.2
|
eps_clip: float = 0.2
|
||||||
value_eps_clip: float = 0.2
|
value_eps_clip: float = 0.2
|
||||||
disable_value: bool = False
|
|
||||||
recompute_logprob: bool = False
|
|
||||||
max_reward_clip: float = 20.0
|
max_reward_clip: float = 20.0
|
||||||
reward_output_scaling: float = 1.0
|
reward_output_scaling: float = 1.0
|
||||||
reward_output_bias: float = 0.0
|
reward_output_bias: float = 0.0
|
||||||
|
@ -104,6 +100,10 @@ class PPOHyperparameters:
|
||||||
value_norm_beta: float = 0.99995
|
value_norm_beta: float = 0.99995
|
||||||
value_norm_eps: float = 1e-5
|
value_norm_eps: float = 1e-5
|
||||||
|
|
||||||
|
disable_value: bool = False
|
||||||
|
recompute_logprob: bool = False
|
||||||
|
fuse_rew_ref: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class PPOMATHConfig(CommonExperimentConfig):
|
class PPOMATHConfig(CommonExperimentConfig):
|
||||||
|
@ -197,6 +197,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
actor_gen: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||||
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
critic_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||||
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
ref_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||||
|
rew_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||||
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
actor_inf: MFCConfig = dataclasses.field(default_factory=MFCConfig)
|
||||||
|
|
||||||
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
dataset: PromptOnlyDatasetConfig = dataclasses.field(
|
||||||
|
@ -208,15 +209,11 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
group_size: int = 1
|
group_size: int = 1
|
||||||
generation_size: Optional[int] = None
|
generation_size: Optional[int] = None
|
||||||
mask_no_eos_with_zero: bool = False
|
mask_no_eos_with_zero: bool = False
|
||||||
rm_output_scaling: Optional[float] = None
|
|
||||||
ref_ema_eta: Optional[float] = None
|
ref_ema_eta: Optional[float] = None
|
||||||
group_adv_norm: bool = False
|
group_adv_norm: bool = False
|
||||||
mask_too_long: bool = False
|
mask_too_long: bool = False
|
||||||
rw_type: Optional[str] = "sparse"
|
rw_type: Optional[str] = "sparse"
|
||||||
task: str = "math"
|
|
||||||
check_xml_format: bool = False
|
check_xml_format: bool = False
|
||||||
use_dense_reward: bool = False
|
|
||||||
reward_delta: bool = True
|
|
||||||
|
|
||||||
check_verifier_status: bool = False
|
check_verifier_status: bool = False
|
||||||
|
|
||||||
|
@ -242,9 +239,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.rm_output_scaling is None:
|
|
||||||
self.rm_output_scaling = self.ppo.reward_output_scaling
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
def models(self) -> Dict[str, ModelTrainEvalConfig]:
|
||||||
# role to config
|
# role to config
|
||||||
|
@ -270,10 +264,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
f"smaller than the prompt length + generation length "
|
f"smaller than the prompt length + generation length "
|
||||||
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
f"{self.dataset.max_prompt_len + self.ppo.gen.max_new_tokens}"
|
||||||
)
|
)
|
||||||
if not os.path.exists(os.getenv("REAL_MATH_METADATA_PATH")):
|
|
||||||
raise RuntimeError(
|
|
||||||
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
|
|
||||||
)
|
|
||||||
|
|
||||||
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
||||||
if domain and (not (domain.startswith("http://") and ":" in domain)):
|
if domain and (not (domain.startswith("http://") and ":" in domain)):
|
||||||
|
@ -297,8 +287,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
"generation_size": self.generation_size,
|
"generation_size": self.generation_size,
|
||||||
"group_adv_norm": self.group_adv_norm,
|
"group_adv_norm": self.group_adv_norm,
|
||||||
"mask_too_long": self.mask_too_long,
|
"mask_too_long": self.mask_too_long,
|
||||||
"use_dense_reward": self.use_dense_reward,
|
|
||||||
"reward_delta": self.reward_delta,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -308,29 +296,30 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
**copy.deepcopy(self.ppo_kwargs),
|
**copy.deepcopy(self.ppo_kwargs),
|
||||||
"group_size": self.group_size,
|
"group_size": self.group_size,
|
||||||
"mask_too_long": self.mask_too_long,
|
"mask_too_long": self.mask_too_long,
|
||||||
"use_dense_reward": self.use_dense_reward,
|
|
||||||
"reward_delta": self.reward_delta,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
critic_interface.args.pop("eps_clip")
|
critic_interface.args.pop("eps_clip")
|
||||||
rw_interface = ModelInterfaceAbstraction(
|
rw_interface = ModelInterfaceAbstraction(
|
||||||
"reward",
|
"reward",
|
||||||
args=dict(
|
args=dict(
|
||||||
rw_type=self.rw_type,
|
|
||||||
task=self.task,
|
|
||||||
check_xml_format=self.check_xml_format,
|
|
||||||
tokenizer_path=self.actor.path,
|
tokenizer_path=self.actor.path,
|
||||||
enable_save=False,
|
|
||||||
output_scaling=self.ppo.reward_output_scaling,
|
output_scaling=self.ppo.reward_output_scaling,
|
||||||
rm_output_scaling=self.rm_output_scaling,
|
|
||||||
output_bias=self.ppo.reward_output_bias,
|
output_bias=self.ppo.reward_output_bias,
|
||||||
|
rw_type=self.rw_type,
|
||||||
|
check_xml_format=self.check_xml_format,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
check_verifier_status=self.check_verifier_status,
|
check_verifier_status=self.check_verifier_status,
|
||||||
max_sync_length=self.ppo.gen.max_new_tokens
|
|
||||||
+ self.dataset.max_prompt_len
|
|
||||||
+ 128,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ref_interface = copy.deepcopy(actor_interface)
|
||||||
|
ref_interface.args["enable_save"] = False
|
||||||
|
if self.ppo.fuse_rew_ref:
|
||||||
|
ref_interface = ModelInterfaceAbstraction(
|
||||||
|
"fused-threading",
|
||||||
|
args=dict(interfaces=dict(rew=rw_interface, ref=ref_interface)),
|
||||||
|
)
|
||||||
|
|
||||||
rollout_output_keys = [
|
rollout_output_keys = [
|
||||||
"seq_no_eos_mask",
|
"seq_no_eos_mask",
|
||||||
"packed_input_ids",
|
"packed_input_ids",
|
||||||
|
@ -366,13 +355,23 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
n_seqs=self.dataset.train_bs_n_seqs,
|
n_seqs=self.dataset.train_bs_n_seqs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
inf_reward = MFCDef(
|
||||||
|
name="rew_inf",
|
||||||
|
model_name="reward",
|
||||||
|
interface_type=ModelInterfaceType.INFERENCE,
|
||||||
|
interface_impl=rw_interface,
|
||||||
|
min_n_seqs_per_pass=1 / self.group_size,
|
||||||
|
input_keys=["packed_input_ids", "packed_prompts"],
|
||||||
|
output_keys=["rewards"],
|
||||||
|
n_seqs=self.dataset.train_bs_n_seqs,
|
||||||
|
)
|
||||||
|
|
||||||
# add rew param into ref MFC
|
# add rew param into ref MFC
|
||||||
inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
|
inf_ref_inputs = ["packed_input_ids"]
|
||||||
inf_ref_outputs = ["logprobs", "rewards", "dense_rewards"]
|
inf_ref_outputs = ["logprobs"]
|
||||||
ref_interface = copy.deepcopy(actor_interface)
|
if self.ppo.fuse_rew_ref:
|
||||||
ref_interface.type_ = "ref_rw"
|
inf_ref_inputs += ["packed_prompts"]
|
||||||
ref_interface.args["enable_save"] = False
|
inf_ref_outputs += ["rewards"]
|
||||||
ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)
|
|
||||||
|
|
||||||
inf_ref_logits = MFCDef(
|
inf_ref_logits = MFCDef(
|
||||||
name="ref_rw",
|
name="ref_rw",
|
||||||
|
@ -408,7 +407,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
"packed_logprobs",
|
"packed_logprobs",
|
||||||
"packed_ref_logprobs",
|
"packed_ref_logprobs",
|
||||||
"rewards",
|
"rewards",
|
||||||
"dense_rewards",
|
|
||||||
"values",
|
"values",
|
||||||
"prompt_mask",
|
"prompt_mask",
|
||||||
"seq_no_eos_mask",
|
"seq_no_eos_mask",
|
||||||
|
@ -442,7 +440,6 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
"packed_logprobs",
|
"packed_logprobs",
|
||||||
"packed_ref_logprobs",
|
"packed_ref_logprobs",
|
||||||
"rewards",
|
"rewards",
|
||||||
"dense_rewards",
|
|
||||||
"values",
|
"values",
|
||||||
"prompt_mask",
|
"prompt_mask",
|
||||||
"seq_no_eos_mask",
|
"seq_no_eos_mask",
|
||||||
|
@ -466,6 +463,8 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
rpcs.pop("critic_train")
|
rpcs.pop("critic_train")
|
||||||
if not self.ppo.recompute_logprob:
|
if not self.ppo.recompute_logprob:
|
||||||
rpcs.pop("actor_inf")
|
rpcs.pop("actor_inf")
|
||||||
|
if self.ppo.fuse_rew_ref:
|
||||||
|
rpcs.pop("rew_inf")
|
||||||
return rpcs
|
return rpcs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -484,13 +483,15 @@ class PPOMATHConfig(CommonExperimentConfig):
|
||||||
allocs.pop("critic_train")
|
allocs.pop("critic_train")
|
||||||
if not self.ppo.recompute_logprob:
|
if not self.ppo.recompute_logprob:
|
||||||
allocs.pop("actor_inf")
|
allocs.pop("actor_inf")
|
||||||
|
if self.ppo.fuse_rew_ref:
|
||||||
|
allocs.pop("rew_inf")
|
||||||
return allocs
|
return allocs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def datasets(self):
|
def datasets(self):
|
||||||
return [
|
return [
|
||||||
DatasetAbstraction(
|
DatasetAbstraction(
|
||||||
"math_prompt",
|
"math_code_prompt",
|
||||||
args=dict(
|
args=dict(
|
||||||
dataset_path=self.dataset.path,
|
dataset_path=self.dataset.path,
|
||||||
max_length=self.dataset.max_prompt_len,
|
max_length=self.dataset.max_prompt_len,
|
||||||
|
|
|
@ -0,0 +1,184 @@
|
||||||
|
# Copyright 2025 Ant Group Inc.
|
||||||
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Callable, Dict, Hashable, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
from realhf.api.core import data_api
|
||||||
|
from realhf.base import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("Math Code Dataset")
|
||||||
|
|
||||||
|
id2info = {}
|
||||||
|
|
||||||
|
|
||||||
|
def check_code_metadata_entries(data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def check_math_metadata_entries(data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def load_metadata(path):
|
||||||
|
assert str(path).endswith(".jsonl"), path
|
||||||
|
with open(path, "r") as f:
|
||||||
|
data = [json.loads(l) for l in f.readlines()]
|
||||||
|
id2info = {}
|
||||||
|
for d in data:
|
||||||
|
assert d["query_id"] not in d, (d["task"], d["query_id"])
|
||||||
|
if d["task"] == "math":
|
||||||
|
check_math_metadata_entries(d)
|
||||||
|
elif d["task"] == "code":
|
||||||
|
check_code_metadata_entries(d)
|
||||||
|
id2info[d["query_id"]] = d
|
||||||
|
return id2info
|
||||||
|
|
||||||
|
|
||||||
|
class MATHCodePromptDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
util: data_api.DatasetUtility,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
dataset_path: Optional[str] = None,
|
||||||
|
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||||
|
filter_threshold: float = 1e4,
|
||||||
|
max_filter_percentage: float = 0.0,
|
||||||
|
):
|
||||||
|
"""Required keys: prompt, query_id, task=math/code, solutions.
|
||||||
|
|
||||||
|
For code dataset, they additionally require an "input_output" key.
|
||||||
|
"""
|
||||||
|
self._util = util
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
global id2info
|
||||||
|
id2info = load_metadata(dataset_path)
|
||||||
|
|
||||||
|
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||||
|
|
||||||
|
prompts_str = [x["prompt"] for x in data]
|
||||||
|
self.ids = [x["query_id"] for x in data]
|
||||||
|
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
|
||||||
|
if "scores" in data[0]:
|
||||||
|
self.base_scores = [np.mean(x["scores"]) for x in data]
|
||||||
|
util.tokenizer.padding_side = "left"
|
||||||
|
prompt_encodings = util.tokenizer(
|
||||||
|
prompts_str,
|
||||||
|
truncation=True,
|
||||||
|
# max_length=max_length,
|
||||||
|
padding=False,
|
||||||
|
return_length=True,
|
||||||
|
return_attention_mask=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
||||||
|
indices = [
|
||||||
|
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
||||||
|
]
|
||||||
|
logger.info(f"{len(indices)} samples remain")
|
||||||
|
|
||||||
|
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
|
||||||
|
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
|
||||||
|
self.ids = [
|
||||||
|
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
|
||||||
|
]
|
||||||
|
if "scores" in data[0]:
|
||||||
|
self.base_scores = [self.base_scores[idx] for idx in indices]
|
||||||
|
|
||||||
|
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
|
||||||
|
|
||||||
|
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
|
||||||
|
|
||||||
|
self.active_indices = list(range(len(self.prompts)))
|
||||||
|
self.filter_threshold = filter_threshold
|
||||||
|
self.max_filter_percentage = max_filter_percentage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def util(self):
|
||||||
|
return self._util
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.active_indices)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# print(self.base_scores)
|
||||||
|
idx = self.active_indices[idx]
|
||||||
|
data = dict(
|
||||||
|
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
|
||||||
|
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
||||||
|
)
|
||||||
|
if hasattr(self, "base_scores"):
|
||||||
|
data["base_scores"] = torch.tensor(
|
||||||
|
[self.base_scores[idx]], dtype=torch.float32
|
||||||
|
)
|
||||||
|
return data_api.SequenceSample.from_default(
|
||||||
|
ids=[self.ids[idx]],
|
||||||
|
seqlens=[self.prompt_lengths[idx]],
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter(self, eval_scores: Dict[Hashable, float]):
|
||||||
|
# Get all data indices that have a higher score than the threshold.
|
||||||
|
idx2scores_to_remove = {}
|
||||||
|
for pop_idx, idx in enumerate(self.active_indices):
|
||||||
|
data_id = self.ids[idx]
|
||||||
|
if data_id not in eval_scores:
|
||||||
|
continue
|
||||||
|
if eval_scores[data_id] > self.filter_threshold:
|
||||||
|
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
|
||||||
|
|
||||||
|
# Control the number of samples to be removed according to max_filter_percentage.
|
||||||
|
n = int(len(self.active_indices) * self.max_filter_percentage)
|
||||||
|
indices_to_remove = sorted(
|
||||||
|
idx2scores_to_remove.keys(),
|
||||||
|
key=lambda x: idx2scores_to_remove[x],
|
||||||
|
reverse=True,
|
||||||
|
)[:n]
|
||||||
|
|
||||||
|
for pop_idx in sorted(indices_to_remove, reverse=True):
|
||||||
|
self.active_indices.pop(pop_idx)
|
||||||
|
logger.info(
|
||||||
|
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
|
||||||
|
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
||||||
|
f"Original dataset size: {len(self.prompts)}. "
|
||||||
|
f"Filter threshold: {self.filter_threshold}. "
|
||||||
|
f"Max filter percentage: {self.max_filter_percentage}. "
|
||||||
|
f"Current number of eval scores: {len(eval_scores)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if not __name__ == "__main__":
|
||||||
|
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
|
||||||
|
else:
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
dataset = MATHCodePromptDataset(
|
||||||
|
data_api.DatasetUtility(
|
||||||
|
seed=0,
|
||||||
|
dp_rank=0,
|
||||||
|
world_size=1,
|
||||||
|
tokenizer=AutoTokenizer.from_pretrained(
|
||||||
|
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
max_length=512,
|
||||||
|
dataset_path="/storage/openpsi/users/bowei.fw/data/code_math.jsonl",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
collate_fn=data_api.SequenceSample.gather,
|
||||||
|
# NOTE: This is *NOT* the actual batch size for training.
|
||||||
|
# It is just a proper size to load data to workers.
|
||||||
|
batch_size=4,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
print(f"size: {len(dataset)}")
|
||||||
|
for d in dataloader:
|
||||||
|
print(d.ids)
|
|
@ -4,13 +4,10 @@ import json
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import *
|
from typing import *
|
||||||
|
|
||||||
from realhf.base import logging
|
from realhf.base import logging
|
||||||
from realhf.base.constants import parallelism_rank
|
|
||||||
|
|
||||||
logger = logging.getLogger("math parser")
|
logger = logging.getLogger("math parser")
|
||||||
|
|
||||||
|
@ -48,20 +45,7 @@ def loadJson(dataDir):
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
headers = {
|
def parse_line(id2info, prompt_str, generated, query_id):
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
id2info = None
|
|
||||||
|
|
||||||
|
|
||||||
def parse_line(prompt_str, generated, query_id):
|
|
||||||
global id2info
|
|
||||||
if id2info is None:
|
|
||||||
try:
|
|
||||||
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
|
|
||||||
except KeyError as e:
|
|
||||||
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
|
|
||||||
info = id2info[query_id.split("@idx:")[0]]
|
info = id2info[query_id.split("@idx:")[0]]
|
||||||
|
|
||||||
tmp_id = str(uuid.uuid4())
|
tmp_id = str(uuid.uuid4())
|
||||||
|
@ -112,17 +96,12 @@ def parse_line(prompt_str, generated, query_id):
|
||||||
|
|
||||||
|
|
||||||
def parse_lines_in_parallel(
|
def parse_lines_in_parallel(
|
||||||
|
id2info,
|
||||||
generateds: List,
|
generateds: List,
|
||||||
query_ids: List,
|
query_ids: List,
|
||||||
max_workers=22,
|
max_workers=22,
|
||||||
check_xml_format=False,
|
check_xml_format=False,
|
||||||
) -> List:
|
) -> List:
|
||||||
global id2info
|
|
||||||
if id2info is None:
|
|
||||||
try:
|
|
||||||
id2info = loadJson(os.environ["REAL_MATH_METADATA_PATH"])
|
|
||||||
except KeyError as e:
|
|
||||||
raise KeyError("The json file REAL_MATH_METADATA_PATH is not set") from e
|
|
||||||
assert len(generateds) == len(query_ids), (
|
assert len(generateds) == len(query_ids), (
|
||||||
len(generateds),
|
len(generateds),
|
||||||
len(query_ids),
|
len(query_ids),
|
||||||
|
@ -204,17 +183,19 @@ def parse_lines_in_parallel(
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sample = {
|
sample = {
|
||||||
"prompt": "",
|
"answers": ["-\\frac{2}{3}"],
|
||||||
"query_id": "35ecd821a9e7e31da9ef0663a25347ce",
|
"solutions": [
|
||||||
# "answer_in_box": ["\\boxed{\\frac{1}{2}}", "<think></think><answer>\\boxed{\\frac{1}{2}}</answer>"]
|
"1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
|
||||||
"answer": "<think>\n1. The problem requires us to determine the number of sequences of 144 hand movements such that every position appears exactly once and the hands return to the initial position at the end.\n2. We know that each movement involves one hand moving clockwise to the next number while the other hand stays in place.\n3. Considering the 12-hour clock, we can represent each positioning of the hands as a combination of the positions of both hands. Since both hands can be in any of the 12 positions, there are 12 x 12 = 144 different positionings.\n4. Given that at each position only one hand moves, every single movement is unique, leading to a total of 144 unique movements.\n5. These 144 movements must form a Hamiltonian Cycle, where each edge represents a valid movement between two positions.\n6. The problem thus reduces to finding a Hamiltonian cycle in a directed graph. Since the appearance of each movement is unique, it also determines the direction of the movement.\n7. We consider the edges that are rotations of each other as equivalent. Taking the rotational symmetry into account, we have 144/12 = 12 equivalence classes.\n8. The problem now is to determine the number of ways to arrange these 12 classes of rotations in a circle, which is 11 factorial.\n9. We must find the value of 11! and then compute the result modulo 1000.\n</think>\n<answer>\n320\n</answer>",
|
"\\boxed{-\\frac{2}{3}}",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
|
||||||
|
|
||||||
print(
|
print(
|
||||||
parse_lines_in_parallel(
|
parse_lines_in_parallel(
|
||||||
# [sample["answer_in_box"][0] for _ in range(50)] + [sample["answer_in_box"][1] for _ in range(50)],
|
id2info,
|
||||||
[sample["answer"]] * 100,
|
sample["answers"] * 100,
|
||||||
[sample["query_id"] for _ in range(100)],
|
["fe11b471-1aa9-4867-958f-a0a811c85f92" for _ in range(100)],
|
||||||
max_workers=8,
|
max_workers=8,
|
||||||
check_xml_format=True,
|
check_xml_format=True,
|
||||||
)
|
)
|
|
@ -2,7 +2,6 @@
|
||||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||||
|
|
||||||
import uuid
|
|
||||||
from typing import Callable, Dict, Hashable, List, Optional
|
from typing import Callable, Dict, Hashable, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -85,140 +84,4 @@ class PromptDataset(torch.utils.data.Dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MATHCodePromptDataset(torch.utils.data.Dataset):
|
data_api.register_dataset("prompt", PromptDataset)
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
util: data_api.DatasetUtility,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
dataset_path: Optional[str] = None,
|
|
||||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
|
||||||
filter_threshold: float = 1e4,
|
|
||||||
max_filter_percentage: float = 0.0,
|
|
||||||
):
|
|
||||||
self._util = util
|
|
||||||
self.max_length = max_length
|
|
||||||
|
|
||||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
|
||||||
|
|
||||||
prompts_str = [x["prompt"] for x in data]
|
|
||||||
self.ids = [x["query_id"] for x in data]
|
|
||||||
self.tasks_ids = [data_api.RL_TASKS.index(x["task"]) for x in data]
|
|
||||||
if "scores" in data[0]:
|
|
||||||
self.base_scores = [np.mean(x["scores"]) for x in data]
|
|
||||||
util.tokenizer.padding_side = "left"
|
|
||||||
prompt_encodings = util.tokenizer(
|
|
||||||
prompts_str,
|
|
||||||
truncation=True,
|
|
||||||
# max_length=max_length,
|
|
||||||
padding=False,
|
|
||||||
return_length=True,
|
|
||||||
return_attention_mask=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"{len(data)} samples, checking lengths (max_length={max_length})")
|
|
||||||
indices = [
|
|
||||||
i for i, x in enumerate(prompt_encodings["length"]) if x <= max_length
|
|
||||||
]
|
|
||||||
logger.info(f"{len(indices)} samples remain")
|
|
||||||
|
|
||||||
self.prompt_lengths = [int(prompt_encodings["length"][idx]) for idx in indices]
|
|
||||||
self.prompts = [prompt_encodings["input_ids"][idx] for idx in indices]
|
|
||||||
self.ids = [
|
|
||||||
str(self.ids[idx]) + f"@idx:{idx}-{util.dp_rank}" for idx in indices
|
|
||||||
]
|
|
||||||
if "scores" in data[0]:
|
|
||||||
self.base_scores = [self.base_scores[idx] for idx in indices]
|
|
||||||
|
|
||||||
assert all(len(x) == l for x, l in zip(self.prompts, self.prompt_lengths))
|
|
||||||
|
|
||||||
logger.info(f"Number of prompts in the dataset: {len(self.prompts)}")
|
|
||||||
|
|
||||||
self.active_indices = list(range(len(self.prompts)))
|
|
||||||
self.filter_threshold = filter_threshold
|
|
||||||
self.max_filter_percentage = max_filter_percentage
|
|
||||||
|
|
||||||
@property
|
|
||||||
def util(self):
|
|
||||||
return self._util
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.active_indices)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
# print(self.base_scores)
|
|
||||||
idx = self.active_indices[idx]
|
|
||||||
data = dict(
|
|
||||||
task_ids=torch.tensor([self.tasks_ids[idx]], dtype=torch.long),
|
|
||||||
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long),
|
|
||||||
)
|
|
||||||
if hasattr(self, "base_scores"):
|
|
||||||
data["base_scores"] = torch.tensor(
|
|
||||||
[self.base_scores[idx]], dtype=torch.float32
|
|
||||||
)
|
|
||||||
return data_api.SequenceSample.from_default(
|
|
||||||
ids=[self.ids[idx]],
|
|
||||||
seqlens=[self.prompt_lengths[idx]],
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter(self, eval_scores: Dict[Hashable, float]):
|
|
||||||
# Get all data indices that have a higher score than the threshold.
|
|
||||||
idx2scores_to_remove = {}
|
|
||||||
for pop_idx, idx in enumerate(self.active_indices):
|
|
||||||
data_id = self.ids[idx]
|
|
||||||
if data_id not in eval_scores:
|
|
||||||
continue
|
|
||||||
if eval_scores[data_id] > self.filter_threshold:
|
|
||||||
idx2scores_to_remove[pop_idx] = eval_scores[data_id]
|
|
||||||
|
|
||||||
# Control the number of samples to be removed according to max_filter_percentage.
|
|
||||||
n = int(len(self.active_indices) * self.max_filter_percentage)
|
|
||||||
indices_to_remove = sorted(
|
|
||||||
idx2scores_to_remove.keys(),
|
|
||||||
key=lambda x: idx2scores_to_remove[x],
|
|
||||||
reverse=True,
|
|
||||||
)[:n]
|
|
||||||
|
|
||||||
for pop_idx in sorted(indices_to_remove, reverse=True):
|
|
||||||
self.active_indices.pop(pop_idx)
|
|
||||||
logger.info(
|
|
||||||
f"Math prompt dataset DP rank {self.util.dp_rank} filtered"
|
|
||||||
f" {len(indices_to_remove)} samples, {len(self.active_indices)} samples remain. "
|
|
||||||
f"Original dataset size: {len(self.prompts)}. "
|
|
||||||
f"Filter threshold: {self.filter_threshold}. "
|
|
||||||
f"Max filter percentage: {self.max_filter_percentage}. "
|
|
||||||
f"Current number of eval scores: {len(eval_scores)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if not __name__ == "__main__":
|
|
||||||
data_api.register_dataset("prompt", PromptDataset)
|
|
||||||
data_api.register_dataset("math_code_prompt", MATHCodePromptDataset)
|
|
||||||
else:
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
dataset = MATHCodePromptDataset(
|
|
||||||
data_api.DatasetUtility(
|
|
||||||
seed=0,
|
|
||||||
dp_rank=0,
|
|
||||||
world_size=1,
|
|
||||||
tokenizer=AutoTokenizer.from_pretrained(
|
|
||||||
"/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
max_length=512,
|
|
||||||
dataset_path="/storage/openpsi/users/bowei.fw/data/code_math.jsonl",
|
|
||||||
)
|
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
collate_fn=data_api.SequenceSample.gather,
|
|
||||||
# NOTE: This is *NOT* the actual batch size for training.
|
|
||||||
# It is just a proper size to load data to workers.
|
|
||||||
batch_size=4,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
print(f"size: {len(dataset)}")
|
|
||||||
for d in dataloader:
|
|
||||||
print(d.ids)
|
|
||||||
|
|
|
@ -67,35 +67,5 @@ class FusedThreadingForwardInterface(model_api.ModelInterface):
|
||||||
|
|
||||||
return final_result
|
return final_result
|
||||||
|
|
||||||
# Mock methods for profiling only.
|
|
||||||
def _mock_inference(
|
|
||||||
self,
|
|
||||||
model: model_api.Model,
|
|
||||||
data: SequenceSample,
|
|
||||||
) -> SequenceSample:
|
|
||||||
prompt_lens = flat2d(data.seqlens["packed_prompts"])
|
|
||||||
seqlens = [x + 1024 for x in prompt_lens]
|
|
||||||
module = model.module
|
|
||||||
if not isinstance(module, ReaLModel):
|
|
||||||
module = module.module
|
|
||||||
mconfig = module.config
|
|
||||||
packed_input_ids = torch.randint(
|
|
||||||
0,
|
|
||||||
mconfig.vocab_size,
|
|
||||||
(sum(seqlens),),
|
|
||||||
dtype=torch.long,
|
|
||||||
device=model.device,
|
|
||||||
)
|
|
||||||
n_tasks = len(RL_TASKS)
|
|
||||||
task_ids = torch.randint(
|
|
||||||
0, n_tasks, (data.bs,), dtype=torch.long, device=model.device
|
|
||||||
)
|
|
||||||
|
|
||||||
return SequenceSample.from_default(
|
|
||||||
seqlens=seqlens,
|
|
||||||
ids=data.ids,
|
|
||||||
data=dict(packed_input_ids=packed_input_ids, task_ids=task_ids),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
model_api.register_interface("fused-threading", FusedThreadingForwardInterface)
|
model_api.register_interface("fused-threading", FusedThreadingForwardInterface)
|
||||||
|
|
|
@ -18,7 +18,7 @@ import realhf.base.logging as logging
|
||||||
import realhf.impl.model.utils.ppo_functional as ppo_functional
|
import realhf.impl.model.utils.ppo_functional as ppo_functional
|
||||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
|
||||||
from realhf.base.datapack import flat2d
|
from realhf.base.datapack import flat2d
|
||||||
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
|
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
|
||||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||||
from realhf.impl.model.nn.real_llm_generate import concat_prompt_to_generation_output
|
from realhf.impl.model.nn.real_llm_generate import concat_prompt_to_generation_output
|
||||||
from realhf.impl.model.utils.functional import (
|
from realhf.impl.model.utils.functional import (
|
||||||
|
|
|
@ -18,6 +18,7 @@ import torch.distributed as dist
|
||||||
|
|
||||||
import realhf.api.core.model_api as model_api
|
import realhf.api.core.model_api as model_api
|
||||||
import realhf.base.logging as logging
|
import realhf.base.logging as logging
|
||||||
|
from functioncall.code.local_verify import code_verify as local_code_verify
|
||||||
from functioncall.code.verify import code_verify
|
from functioncall.code.verify import code_verify
|
||||||
from functioncall.math.verify import math_verify
|
from functioncall.math.verify import math_verify
|
||||||
from realhf.api.core.data_api import (
|
from realhf.api.core.data_api import (
|
||||||
|
@ -28,14 +29,14 @@ from realhf.api.core.data_api import (
|
||||||
)
|
)
|
||||||
from realhf.base import constants
|
from realhf.base import constants
|
||||||
from realhf.base.datapack import flat2d
|
from realhf.base.datapack import flat2d
|
||||||
from realhf.impl.model.interface.math_parser import (
|
from realhf.impl.dataset.math_code_dataset import id2info
|
||||||
parse_lines_in_parallel as math_verify_local,
|
from realhf.impl.dataset.math_parser import parse_lines_in_parallel as math_verify_local
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
|
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
|
||||||
|
|
||||||
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
|
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
|
||||||
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
|
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
|
||||||
|
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
|
||||||
|
|
||||||
|
|
||||||
class VerifierException(Exception):
|
class VerifierException(Exception):
|
||||||
|
@ -61,6 +62,7 @@ def extract_python_code(text, min_length=20, strict_syntax=True):
|
||||||
valid_blocks.append(clean_block)
|
valid_blocks.append(clean_block)
|
||||||
|
|
||||||
if not valid_blocks:
|
if not valid_blocks:
|
||||||
|
logger.warning(f"failed to extract python code from {text}")
|
||||||
return None
|
return None
|
||||||
# return the last code block
|
# return the last code block
|
||||||
return valid_blocks[-1]
|
return valid_blocks[-1]
|
||||||
|
@ -120,14 +122,20 @@ def check_with_elementtree(text):
|
||||||
|
|
||||||
|
|
||||||
def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
|
def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
|
||||||
|
global id2info
|
||||||
assert len(answers) == len(query_id_strs)
|
assert len(answers) == len(query_id_strs)
|
||||||
format_rewards = []
|
format_rewards = []
|
||||||
if task == "math":
|
if task == "math":
|
||||||
format_rewards = math_verify_call(answers, query_id_strs)
|
format_rewards = math_verify_call(id2info, answers, query_id_strs)
|
||||||
elif task == "code":
|
elif task == "code":
|
||||||
codes = [extract_python_code(_answer) for _answer in answers]
|
codes = [extract_python_code(_answer) for _answer in answers]
|
||||||
format_rewards = code_verify(codes, query_id_strs)
|
format_rewards = code_verify_call(id2info, codes, query_id_strs)
|
||||||
assert len(format_rewards) == len(answers), task
|
assert len(format_rewards) == len(answers), (
|
||||||
|
task,
|
||||||
|
len(format_rewards),
|
||||||
|
len(answers),
|
||||||
|
answers,
|
||||||
|
)
|
||||||
return format_rewards
|
return format_rewards
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,10 +175,9 @@ def retokenize_and_verify(
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
class MultiTaskRewardInterface(model_api.ModelInterface):
|
||||||
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
|
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
|
||||||
output_scaling: float = 1.0
|
output_scaling: float = 1.0
|
||||||
output_scaling: float = 1.0
|
|
||||||
output_bias: float = 0.0
|
output_bias: float = 0.0
|
||||||
rw_type: str = "sparse"
|
rw_type: str = "sparse"
|
||||||
check_xml_format: bool = False
|
check_xml_format: bool = False
|
||||||
|
@ -183,21 +190,15 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
logger.info(f"output_scaling: {self.output_scaling}")
|
logger.info(f"output_scaling: {self.output_scaling}")
|
||||||
logger.info(f"output_bias: {self.output_bias}")
|
logger.info(f"output_bias: {self.output_bias}")
|
||||||
logger.info(f"rw_type: {self.rw_type}")
|
logger.info(f"rw_type: {self.rw_type}")
|
||||||
if (
|
|
||||||
constants.data_parallel_world_size()
|
|
||||||
< constants.parallelism_group_size()
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"There's no reason to use tensor and pipeline parallelism for CPU reward."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
|
def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
|
||||||
xs = data.unpack()
|
xs = data.unpack()
|
||||||
dispatched = {}
|
dispatched = {}
|
||||||
dispatched_indices = {}
|
dispatched_indices = {}
|
||||||
for task_idx, task_name in enumerate(RL_TASKS):
|
for task_idx, task_name in enumerate(RL_TASKS):
|
||||||
indices = (data.data["task_ids"] == task_idx).numpy().tolist()
|
indices = (data.data["task_ids"] == task_idx).numpy().nonzero()[0].tolist()
|
||||||
if any(indices):
|
print("============", task_idx, task_name, indices)
|
||||||
|
if len(indices) > 0:
|
||||||
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
|
dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
|
||||||
dispatched_indices[task_name] = indices
|
dispatched_indices[task_name] = indices
|
||||||
|
|
||||||
|
@ -208,7 +209,9 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
) -> SequenceSample:
|
) -> SequenceSample:
|
||||||
xs = [None for _ in range(bs)]
|
xs = [None for _ in range(bs)]
|
||||||
for task_name, indices in dispatched_indices.items():
|
for task_name, indices in dispatched_indices.items():
|
||||||
|
print(task_name, indices)
|
||||||
xxs = results[task_name].unpack()
|
xxs = results[task_name].unpack()
|
||||||
|
assert len(indices) == len(xxs), (len(indices), len(xxs))
|
||||||
for i, xx in zip(indices, xxs):
|
for i, xx in zip(indices, xxs):
|
||||||
xs[i] = xx
|
xs[i] = xx
|
||||||
assert all(xs)
|
assert all(xs)
|
||||||
|
@ -260,7 +263,7 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
|
self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)
|
||||||
|
|
||||||
# NOTE: a place holder
|
# NOTE: a place holder
|
||||||
dense_scores = packed_input_ids.new_zeros(dtype=torch.float32)
|
dense_scores = torch.zeros_like(packed_input_ids, dtype=torch.float32)
|
||||||
|
|
||||||
res = SequenceSample(
|
res = SequenceSample(
|
||||||
keys=["rewards", "dense_rewards"],
|
keys=["rewards", "dense_rewards"],
|
||||||
|
@ -368,6 +371,8 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
mb_spec: MicroBatchSpec,
|
mb_spec: MicroBatchSpec,
|
||||||
) -> SequenceSample | None:
|
) -> SequenceSample | None:
|
||||||
task_data, dispatch_indices = self._dispatch_tasks(data)
|
task_data, dispatch_indices = self._dispatch_tasks(data)
|
||||||
|
for d in task_data.values():
|
||||||
|
print(d.bs)
|
||||||
|
|
||||||
assert self.rw_type == "sparse"
|
assert self.rw_type == "sparse"
|
||||||
|
|
||||||
|
@ -377,11 +382,13 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
try:
|
try:
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise asyncio.CancelledError(f"{task_type} task failed: {e}") from e
|
raise asyncio.CancelledError(
|
||||||
|
f"[{task_type}] task failed: {e}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
duration = time.perf_counter() - start_time
|
duration = time.perf_counter() - start_time
|
||||||
logger.info(f"[{task_type}] time cost: {duration:.4f}s")
|
logger.info(f"[{task_type}] time cost: {duration:.4f}s")
|
||||||
return result
|
return task_type, result
|
||||||
|
|
||||||
return _wrapped_func
|
return _wrapped_func
|
||||||
|
|
||||||
|
@ -391,11 +398,12 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
task_func = _task_func(self.calculate_task_reward, task_type)
|
task_func = _task_func(self.calculate_task_reward, task_type)
|
||||||
task_args = (model, d, mb_spec, task_type)
|
task_args = (model, d, mb_spec, task_type)
|
||||||
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
|
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
|
||||||
tasks.append((task_type, task))
|
tasks.append(task)
|
||||||
|
|
||||||
task_results = {}
|
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
for task_type, result in results:
|
task_results = {}
|
||||||
|
for res in results:
|
||||||
|
task_type, result = res
|
||||||
task_results[task_type] = result
|
task_results[task_type] = result
|
||||||
|
|
||||||
return task_results
|
return task_results
|
||||||
|
@ -410,5 +418,46 @@ class MultiTaskCPURewardInterface(model_api.ModelInterface):
|
||||||
|
|
||||||
return final_result
|
return final_result
|
||||||
|
|
||||||
|
def _mock_inference(
|
||||||
|
self,
|
||||||
|
model: model_api.Model,
|
||||||
|
data: SequenceSample,
|
||||||
|
) -> SequenceSample:
|
||||||
|
from realhf.base.testing import TESTING_MODEL_VOCAB_SIZE
|
||||||
|
|
||||||
model_api.register_interface("reward", MultiTaskCPURewardInterface)
|
prompt_lens = flat2d(data.seqlens["packed_prompts"])
|
||||||
|
task_ids = data.data["task_ids"].cpu().numpy().tolist()
|
||||||
|
seqlens = []
|
||||||
|
offset = 0
|
||||||
|
seq = []
|
||||||
|
for plen, task_id in zip(prompt_lens, task_ids):
|
||||||
|
seq += [data.data["packed_prompts"][offset : offset + plen]]
|
||||||
|
offset += plen
|
||||||
|
if task_id == RL_TASKS.index("math"):
|
||||||
|
answer_str = (
|
||||||
|
"something unimportant but the answer is \\boxed{-\\frac{2}{3}}"
|
||||||
|
)
|
||||||
|
elif task_id == RL_TASKS.index("code"):
|
||||||
|
answer_str = (
|
||||||
|
"```python\ninput()\nprint(1)\nprint(1)\nprint(1)\nprint(1)\n```"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
answer_str = "something unimportant"
|
||||||
|
encoding = model.tokenizer(
|
||||||
|
[answer_str], add_special_tokens=True, return_attention_mask=False
|
||||||
|
)
|
||||||
|
|
||||||
|
ans = torch.tensor(encoding["input_ids"], dtype=torch.long).flatten()
|
||||||
|
seq += [ans]
|
||||||
|
seqlens.append(plen + len(ans))
|
||||||
|
|
||||||
|
x = SequenceSample.from_default(
|
||||||
|
seqlens=seqlens,
|
||||||
|
ids=data.ids,
|
||||||
|
data=dict(packed_input_ids=torch.cat(seq)),
|
||||||
|
)
|
||||||
|
data.update_(x)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
model_api.register_interface("reward", MultiTaskRewardInterface)
|
||||||
|
|
|
@ -618,8 +618,6 @@ class RayController:
|
||||||
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
|
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
|
||||||
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
||||||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||||
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
|
|
||||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
|
||||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||||
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
|
REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
|
||||||
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
|
REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
|
||||||
|
|
|
@ -27,9 +27,7 @@ def model_class(request):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
|
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
|
||||||
def math_dataset(request, save_path):
|
def math_code_dataset(request, save_path):
|
||||||
with open(os.getenv("REAL_MATH_METADATA_PATH"), "r") as f:
|
|
||||||
query_ids = list(json.load(f).keys())
|
|
||||||
size = request.param
|
size = request.param
|
||||||
max_prompt_len = 8
|
max_prompt_len = 8
|
||||||
max_resp_len = 8
|
max_resp_len = 8
|
||||||
|
@ -42,7 +40,7 @@ def math_dataset(request, save_path):
|
||||||
prompt=generate_random_sentence(prompt_len),
|
prompt=generate_random_sentence(prompt_len),
|
||||||
)
|
)
|
||||||
dataset.append(d)
|
dataset.append(d)
|
||||||
with open(str(save_path / "math_dataset.json"), "w") as f:
|
with open(str(save_path / "math_code_dataset.json"), "w") as f:
|
||||||
json.dump(dataset, f)
|
json.dump(dataset, f)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
@ -51,9 +49,9 @@ def math_dataset(request, save_path):
|
||||||
"dp,pp,mp",
|
"dp,pp,mp",
|
||||||
[
|
[
|
||||||
(1, 1, 1),
|
(1, 1, 1),
|
||||||
(2, 1, 2),
|
# (2, 1, 2),
|
||||||
(1, 2, 1),
|
# (1, 2, 1),
|
||||||
(1, 1, 2),
|
# (1, 1, 2),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_ppo_symm(
|
def test_ppo_symm(
|
||||||
|
@ -97,11 +95,6 @@ def test_ppo_symm(
|
||||||
init_critic_from_actor=True,
|
init_critic_from_actor=True,
|
||||||
backend="mock_train",
|
backend="mock_train",
|
||||||
),
|
),
|
||||||
rew=ModelTrainEvalConfig(
|
|
||||||
path=str(save_path),
|
|
||||||
init_critic_from_actor=True,
|
|
||||||
init_from_scratch=True,
|
|
||||||
),
|
|
||||||
dataset=PromptOnlyDatasetConfig(
|
dataset=PromptOnlyDatasetConfig(
|
||||||
path=str(save_path / "math_dataset.json"),
|
path=str(save_path / "math_dataset.json"),
|
||||||
max_prompt_len=mconfig.n_positions // 2,
|
max_prompt_len=mconfig.n_positions // 2,
|
||||||
|
@ -121,6 +114,7 @@ def test_ppo_symm(
|
||||||
run_test_exp(exp_cfg)
|
run_test_exp(exp_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("")
|
||||||
# The global resharding strategy, where all MFCs
|
# The global resharding strategy, where all MFCs
|
||||||
# occupy the same device mesh but with different
|
# occupy the same device mesh but with different
|
||||||
# parallelization strategies.
|
# parallelization strategies.
|
||||||
|
@ -242,6 +236,7 @@ def test_ppo_global_reshard(
|
||||||
run_test_exp(exp_cfg)
|
run_test_exp(exp_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("")
|
||||||
# Actor/critic train and ref_inf/rew_inf are on disjoint
|
# Actor/critic train and ref_inf/rew_inf are on disjoint
|
||||||
# device meshes and executed concurrently.
|
# device meshes and executed concurrently.
|
||||||
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
|
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
|
||||||
|
@ -358,6 +353,7 @@ def test_ppo_param_realloc_sub_device_mesh(
|
||||||
run_test_exp(exp_cfg)
|
run_test_exp(exp_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("")
|
||||||
@pytest.mark.parametrize("freq_step", [3, 4, 7])
|
@pytest.mark.parametrize("freq_step", [3, 4, 7])
|
||||||
@pytest.mark.parametrize("freq_epoch", [1, 2, 3])
|
@pytest.mark.parametrize("freq_epoch", [1, 2, 3])
|
||||||
@pytest.mark.parametrize("bs", [30, 80, 100])
|
@pytest.mark.parametrize("bs", [30, 80, 100])
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
# Copyright 2025 Ant Group Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import uuid
|
||||||
|
from typing import *
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.utils.data.dataloader import DataLoader
|
||||||
|
|
||||||
|
from realhf.api.core.data_api import (
|
||||||
|
DatasetUtility,
|
||||||
|
MicroBatchSpec,
|
||||||
|
SequenceSample,
|
||||||
|
load_hf_tokenizer,
|
||||||
|
)
|
||||||
|
from realhf.api.core.model_api import FinetuneSpec, Model
|
||||||
|
from realhf.base import constants, network, testing
|
||||||
|
from tests.fixtures import *
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[testing.TESTING_DATASET_SIZE])
|
||||||
|
def math_code_dataset(request, save_path):
|
||||||
|
size = request.param
|
||||||
|
max_prompt_len = 8
|
||||||
|
dataset = []
|
||||||
|
for i in range(size):
|
||||||
|
prompt_len = random.randint(1, max_prompt_len)
|
||||||
|
n_pairs = random.randint(1, 5)
|
||||||
|
if random.random() < 0.5:
|
||||||
|
d = dict(
|
||||||
|
task="code",
|
||||||
|
query_id=str(uuid.uuid4()),
|
||||||
|
prompt=generate_random_sentence(prompt_len),
|
||||||
|
problem_id=str(uuid.uuid4()),
|
||||||
|
input_output=json.dumps(
|
||||||
|
{"inputs": [1] * 8, "outputs": ["1\n1\n1\n1\n"] * 8}
|
||||||
|
),
|
||||||
|
solutions=json.dumps(
|
||||||
|
["```python\ninput()\nprint(1)\nprint(1)\nprint(1)\nprint(1)\n```"]
|
||||||
|
* 3
|
||||||
|
),
|
||||||
|
difficulty=random.random() * 10,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
d = dict(
|
||||||
|
task="math",
|
||||||
|
query_id=str(uuid.uuid4()),
|
||||||
|
prompt=generate_random_sentence(prompt_len),
|
||||||
|
answers=["-\\frac{2}{3}"],
|
||||||
|
solutions=["-\\frac{2}{3}"],
|
||||||
|
)
|
||||||
|
dataset.append(d)
|
||||||
|
with open(str(save_path / "math_code_dataset.jsonl"), "w") as f:
|
||||||
|
f.write("\n".join([json.dumps(d) for d in dataset]))
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tokenizer_path", ["/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"]
|
||||||
|
)
|
||||||
|
def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_dataset):
|
||||||
|
from realhf.impl.dataset.math_code_dataset import MATHCodePromptDataset
|
||||||
|
|
||||||
|
dist.init_process_group(
|
||||||
|
rank=0, world_size=1, init_method=f"tcp://localhost:{network.find_free_port()}"
|
||||||
|
)
|
||||||
|
testing.init_global_constants()
|
||||||
|
|
||||||
|
dataset = MATHCodePromptDataset(
|
||||||
|
DatasetUtility(
|
||||||
|
seed=0,
|
||||||
|
dp_rank=0,
|
||||||
|
world_size=1,
|
||||||
|
tokenizer=load_hf_tokenizer(tokenizer_path),
|
||||||
|
),
|
||||||
|
max_length=512,
|
||||||
|
dataset_path=str(save_path / "math_code_dataset.jsonl"),
|
||||||
|
)
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
collate_fn=SequenceSample.gather,
|
||||||
|
# NOTE: This is *NOT* the actual batch size for training.
|
||||||
|
# It is just a proper size to load data to workers.
|
||||||
|
batch_size=4,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
from realhf.impl.model.interface.rw_interface import MultiTaskRewardInterface
|
||||||
|
|
||||||
|
with constants.model_scope(testing.MODEL_NAME):
|
||||||
|
interface = MultiTaskRewardInterface(
|
||||||
|
tokenizer_path=tokenizer_path,
|
||||||
|
group_size=1,
|
||||||
|
check_verifier_status=False,
|
||||||
|
)
|
||||||
|
model = Model(
|
||||||
|
name="test",
|
||||||
|
module=None,
|
||||||
|
tokenizer=load_hf_tokenizer(tokenizer_path),
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
ft_spec=FinetuneSpec(
|
||||||
|
total_train_epochs=1, dataset_size=100, train_batch_size=3
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
for d in dataloader:
|
||||||
|
d = interface.mock("inference", model, d)
|
||||||
|
rewards = interface.inference(model, d, mb_spec=MicroBatchSpec())
|
||||||
|
d.update_(rewards)
|
||||||
|
print("success")
|
Loading…
Reference in New Issue