mirror of https://github.com/inclusionAI/AReaL
create ref-rw inference async mode
This commit is contained in:
parent 234e3dd3a0
commit 8732811c73
@@ -19,7 +19,7 @@ def check_is_realhf_native_impl(_cls):
def check_is_realhf_native_model_interface(name):
    # NOTE: we should not import iterfaces here,
    # such that we can avoid CUDA initialization.
-    return name in ["ppo_actor", "ppo_critic", "sft"]
+    return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]


def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
@@ -253,7 +253,7 @@ class PPOMATHConfig(CommonExperimentConfig):
                "actor": self.actor,
                # "critic": self.critic,
                "ref": self.ref,
-                "reward": self.rew,
+                # "reward": self.rew,
            }
        else:
            return {
@@ -304,8 +304,6 @@ class PPOMATHConfig(CommonExperimentConfig):
                "reward_delta": self.reward_delta,
            },
        )
-        ref_interface = copy.deepcopy(actor_interface)
-        ref_interface.args["enable_save"] = False

        critic_interface = ModelInterfaceAbstraction(
            "ppo_critic",
@@ -319,7 +317,7 @@ class PPOMATHConfig(CommonExperimentConfig):
        )
        critic_interface.args.pop("eps_clip")
        rw_interface = ModelInterfaceAbstraction(
-            "rw_math",
+            "reward",
            args=dict(
                rw_type=self.rw_type,
                task=self.task,
@@ -368,9 +366,16 @@ class PPOMATHConfig(CommonExperimentConfig):
            n_seqs=self.dataset.train_bs_n_seqs,
        )

-        inf_ref_inputs = ["packed_input_ids"]
+        # add rew param into ref MFC
+        inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
+        inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
+        ref_interface = copy.deepcopy(actor_interface)
+        ref_interface.type_ = "ref_rw"
+        ref_interface.args["enable_save"] = False
+        ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)

        inf_ref_logits = MFCDef(
-            name="ref_inf",
+            name="ref_rw",
            model_name="ref",
            mb_spec=self.ref_inf.mb_spec,
            interface_type=ModelInterfaceType.INFERENCE,
@@ -379,7 +384,7 @@ class PPOMATHConfig(CommonExperimentConfig):
            interface_impl=ref_interface,
            min_n_seqs_per_pass=1 / self.group_size,
            input_keys=inf_ref_inputs,
-            output_keys=["packed_ref_logprobs"],
+            output_keys=inf_ref_outputs,
            n_seqs=self.dataset.train_bs_n_seqs,
        )

@@ -451,8 +456,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                "actor_train": train_actor,
                # "critic_inf": inf_values,
                # "critic_train": train_critic,
-                "ref_inf": inf_ref_logits,
-                "rew_inf": inf_reward,
+                # "ref_inf": inf_ref_logits,
+                # "rew_inf": inf_reward,
+                "ref_rw": inf_ref_logits,
            }
        else:
            return {
@@ -472,8 +478,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                "actor_train": self.actor_train,
                # "critic_inf": self.critic_inf,
                # "critic_train": self.critic_train,
-                "ref_inf": self.ref_inf,
-                "rew_inf": self.rew_inf,
+                # "ref_inf": self.ref_inf,
+                # "rew_inf": self.rew_inf,
+                "ref_rw": self.ref_inf,
            }
        else:
            return {
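For orientation, a minimal sketch (plain Python dicts, not AReaL APIs) of how the dataflow wiring above changes: the separate ref_inf and rew_inf MFCs are collapsed into a single ref_rw MFC whose input and output keys are the union of the two. The rew_inf keys shown here are an assumption inferred from the reward interface deleted below; they do not appear in these hunks.

# Hedged summary of the MFC rewiring in this commit; plain data, runnable anywhere.
# The "rew_inf" keys are an assumption, inferred from the deleted reward interface.
old_mfcs = {
    "ref_inf": {"inputs": ["packed_input_ids"], "outputs": ["packed_ref_logprobs"]},
    "rew_inf": {
        "inputs": ["packed_input_ids", "packed_prompts"],
        "outputs": ["rewards", "dense_rewards"],
    },
}
new_mfcs = {
    "ref_rw": {
        "inputs": ["packed_input_ids", "packed_prompts"],
        "outputs": ["packed_ref_logprobs", "rewards", "dense_rewards"],
    },
}
# The merged node emits everything the two old nodes emitted.
assert set(new_mfcs["ref_rw"]["outputs"]) == set(
    old_mfcs["ref_inf"]["outputs"] + old_mfcs["rew_inf"]["outputs"]
)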
@@ -1,429 +0,0 @@
# Copyright 2025 Ant Group Inc.

import collections
import dataclasses
import html
import json
import os
import re
import xml.etree.ElementTree as ET
from ast import parse
from concurrent.futures import ThreadPoolExecutor, as_completed

import colorama
import numpy as np
import torch
import torch.distributed as dist

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.base import constants

logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")


class CodeVerifierException(Exception):
    pass


def extract_python_code(text, min_length=20, strict_syntax=True):
    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
    code_blocks = re.findall(code_pattern, text, re.DOTALL)
    valid_blocks = []
    for block in code_blocks:
        clean_block = block.strip()
        if len(clean_block) < min_length:
            continue

        # verify code syntax
        if strict_syntax:
            try:
                parse(clean_block, mode="exec")
            except (SyntaxError, IndentationError):
                continue

        valid_blocks.append(clean_block)

    if not valid_blocks:
        return None
    # return the last code block
    return valid_blocks[-1]


def check_with_elementtree(text):
    def escape_between_tags(text, tags=["think", "answer"]):
        """Escape the content between tags while keeping the tags themselves."""
        # build the tag pattern
        tag_pattern = "|".join(tags)
        parts = []
        current_pos = 0

        # match opening and closing tags
        pattern = f"</?({tag_pattern})[^>]*>"

        for match in re.finditer(pattern, text):
            # append the content before the tag (escaped)
            if current_pos < match.start():
                parts.append(html.escape(text[current_pos : match.start()]))

            # append the tag itself (unescaped)
            parts.append(match.group())
            current_pos = match.end()

        # append the remaining content
        if current_pos < len(text):
            parts.append(html.escape(text[current_pos:]))

        return "".join(parts)

    text = escape_between_tags(text)
    if not text.strip().startswith("<think>"):
        text = "<think>" + text
    try:
        xml_text = f"<root>{text}</root>"
        x = ET.fromstring(xml_text)
        if x.text is not None and x.text.strip() != "":
            return False, f"Error: extra content before <think>. {x.text}"
        if len(x) != 2:
            return False, f"Error: there are {len(x)} tags."
        if x[0].tag is None or x[0].tag != "think":
            return False, f"Error: <think> tag is missing. {x[0].tag}"
        if x[0].tail is not None and x[0].tail.strip() != "":
            return (
                False,
                f"Error: extra content between <think> and <answer>. {x[0].tail}",
            )
        if x[1].tag is None or x[1].tag != "answer":
            return False, f"Error: <answer> tag is missing. {x[1].tag}"
        if x[1].tail is not None and x[1].tail.strip() != "":
            return False, f"Error: extra content after <answer>, {x[1].tail}"

        return True, x[1].text if x[1].text is not None else ""
    except ET.ParseError as e:
        return False, f"Error: malformed XML, {str(e)}"


def retokenize(
    task,
    tokenizer,
    packed_input_ids,
    input_cu_seqlens,
    prompts,
    prompt_cu_seqlens,
    query_ids,
    check_xml_format=False,
    do_eval=False,
):
    input_ids = [
        packed_input_ids[start:end]
        for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
    ]
    prompt_ids = [
        prompts[start:end]
        for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
    ]
    seq_strs = tokenizer.batch_decode(
        input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    prompt_strs = tokenizer.batch_decode(
        prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    # query_id_strs = query_ids
    query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

    format_rewards = []

    queryid_to_results = collections.defaultdict(list)
    # 8 processes on each node, with 10 subprocesses each
    if do_eval == True:
        _answers = [
            seq_str.split(prompt_str)[1]
            for seq_str, prompt_str in zip(seq_strs, prompt_strs)
        ]

        codes = [extract_python_code(_answer) for _answer in _answers]
        logger.info(
            f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
        )
        format_rewards = code_verify(codes, query_id_strs)

        if check_xml_format:
            with ThreadPoolExecutor(max_workers=22) as executor:
                futures = [
                    executor.submit(check_with_elementtree, answer_str)
                    for answer_str in _answers
                ]
                # xml_rewards = []
                for idx, future in enumerate(futures):
                    xml_reward, _ = future.result()
                    # xml_rewards.append(xml_reward)
                    if xml_reward == 1 and format_rewards[idx] == 0:
                        format_rewards[idx] = -0.8
                    elif xml_reward == 0 and format_rewards[idx] == 0:
                        format_rewards[idx] = -1

        for query_id_str, format_reward in zip(query_id_strs, format_rewards):
            if query_id_str not in queryid_to_results:
                queryid_to_results[query_id_str] = []
            queryid_to_results[query_id_str].append(format_reward)
    else:
        for query_id_str in query_id_strs:
            if query_id_str not in queryid_to_results:
                queryid_to_results[query_id_str] = []
            queryid_to_results[query_id_str].append(0)
            format_rewards.append(0)

    return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results


@dataclasses.dataclass
class PackedCodeRewardInterface(model_api.ModelInterface):

    enable_save: bool = False
    tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
    output_scaling: float = 1.0
    rm_output_scaling: float = 1.0
    rm_output_bias: float = 0.0
    output_bias: float = 0.0
    loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
    max_sync_length: int = 2048
    rw_type: str = "sparse"
    task: str = "code"  # math or countdown or code
    check_xml_format: bool = False
    post_process: str = "sigmoid"
    group_size: int = 1
    check_verifier_status: bool = False

    _call_count: int = 0

    def __post_init__(self):
        self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
        logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
        logger.info(f"rm_output_bias: {self.rm_output_bias}")
        logger.info(f"output_scaling: {self.output_scaling}")
        logger.info(f"output_bias: {self.output_bias}")
        logger.info(f"max_sync_length: {self.max_sync_length}")
        logger.info(f"rw_type: {self.rw_type}")
        logger.info(f"post_process: {self.post_process}")

        while True:
            gen_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "generated",
                f"v{self._call_count}r{dist.get_rank()}.txt",
            )
            if os.path.exists(gen_file_path):
                self._call_count += 1
            else:
                break
        logger.info(f"call_count: {self._call_count}")

    def inference(
        self,
        model: model_api.Model,
        data_: SequenceSample,
        mb_spec,
    ) -> SequenceSample:

        packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()

        input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
        input_cu_seqlens = torch.nn.functional.pad(
            input_seqlens.cumsum(0), (1, 0)
        ).int()

        packed_prompts = data_.data["packed_prompts"]
        prompts = []
        prompt_seqlens = []
        offset = 0
        for x in data_.seqlens["packed_prompts"]:
            prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
            offset += x[0]
            prompt_seqlens.extend(x * self.group_size)

        assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
        # non_packed_prompts = copy.deepcopy(prompts)
        prompts = torch.cat(prompts)
        prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
        prompt_cu_seqlens = torch.nn.functional.pad(
            prompt_seqlens.cumsum(0), (1, 0)
        ).int()

        query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]

        format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
            retokenize(
                self.task,
                self.tokenizer,
                packed_input_ids,
                input_cu_seqlens,
                prompts,
                prompt_cu_seqlens,
                query_ids,
                check_xml_format=self.check_xml_format,
                do_eval=constants.is_last_pipe_stage(),
            )
        )

        assert self.rw_type == "sparse"
        dense_scores = torch.zeros_like(packed_input_ids).float()
        scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
        scores[scores == 0] = -1

        if len(scores) == 0:
            return None

        assert dense_scores.shape == packed_input_ids.shape
        scores = (
            scores.to(packed_input_ids.device) - self.output_bias
        ) * self.output_scaling

        logger.info(f"Code reward logging info @v{model.version.global_step}")
        logger.info(
            f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
        )
        logger.info(
            f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
        )

        gen_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "generated",
            f"v{self._call_count}r{dist.get_rank()}.txt",
        )
        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
        with open(gen_file_path, "w") as _f:
            for idx, (score, prompt_str, seq_str) in enumerate(
                zip(scores, prompt_strs, seq_strs)
            ):
                info = "\n".join(
                    [
                        f"idx: {idx} / {len(scores)}",
                        f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
                        f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
                    ]
                )
                _f.write(info + "\n")

        gen_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "generated_jsonl",
            f"v{self._call_count}r{dist.get_rank()}.jsonl",
        )
        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
        with open(gen_file_path, "w") as _f:
            for idx, (score, prompt_str, seq_str) in enumerate(
                zip(scores, prompt_strs, seq_strs)
            ):
                _f.write(
                    json.dumps(
                        {
                            "prompt": prompt_str,
                            "generated": seq_str.split(prompt_str)[1],
                            "reward": score.item(),
                        },
                        ensure_ascii=False,
                    )
                    + "\n"
                )

        logger.info(
            f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
        )

        pass_at_k = np.mean(
            [sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
        )
        avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
        logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")

        logger.info(f"number of samples: {len(scores)}, {scores.shape}")

        logger.info(f"reward: {sum(scores) / len(scores)}")

        train_pass_monitor_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "training_monitor",
            f"v{self._call_count}r{dist.get_rank()}.jsonl",
        )
        os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
        logger.info(
            f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
        )
        with open(train_pass_monitor_file_path, "w") as monitor_file:
            for key, value in queryid_to_results.items():
                pass1 = sum(value) / len(value)
                pass8 = int(sum(value) > 0)
                monitor_file.write(
                    json.dumps(
                        {"query_id": key, "pass1": pass1, "pass8": pass8},
                        ensure_ascii=False,
                    )
                    + "\n"
                )

        self._call_count += 1

        if scores.dtype != torch.float32:
            scores = scores.to(torch.float32)
        if dense_scores.dtype != torch.float32:
            dense_scores = dense_scores.to(torch.float32)

        res = SequenceSample(
            keys=["rewards", "dense_rewards"],
            trailing_shapes=dict(rewards=(), dense_rewards=()),
            dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
            ids=data_.ids,
            seqlens=dict(
                rewards=[
                    torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
                    for x in data_.seqlens["packed_input_ids"]
                ],
                dense_rewards=data_.seqlens["packed_input_ids"],
            ),
            data=dict(rewards=scores, dense_rewards=dense_scores),
        )

        # record rewards for each piece of data
        avg_scores = []
        offset = 0
        for i in range(data_.bs):
            score_lis = scores[
                offset : offset + len(data_.seqlens["packed_input_ids"][i])
            ]
            avg_scores.append(score_lis.mean().item())
            offset += len(data_.seqlens["packed_input_ids"][i])
        assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])

        res.metadata["scores"] = avg_scores

        if self.check_verifier_status:
            avg_score = torch.tensor(
                np.mean(avg_scores), device=constants.current_device()
            )
            dist.all_reduce(
                avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
            )
            avg_score /= constants.parallelism_group_size()
            avg_score = avg_score.item()
            minimal_score = (-1 - self.output_bias) * self.rm_output_scaling

            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
                raise CodeVerifierException(
                    "All rewards are at minimal value. Probably there are something wrong with the verifier!"
                )
        return res


model_api.register_interface("rw_code", PackedCodeRewardInterface)
@@ -37,10 +37,34 @@ ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel


-class MathVerifierException(Exception):
+class VerifierException(Exception):
    pass


+def extract_python_code(text, min_length=20, strict_syntax=True):
+    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
+    code_blocks = re.findall(code_pattern, text, re.DOTALL)
+    valid_blocks = []
+    for block in code_blocks:
+        clean_block = block.strip()
+        if len(clean_block) < min_length:
+            continue
+
+        # verify code syntax
+        if strict_syntax:
+            try:
+                parse(clean_block, mode="exec")
+            except (SyntaxError, IndentationError):
+                continue
+
+        valid_blocks.append(clean_block)
+
+    if not valid_blocks:
+        return None
+    # return the last code block
+    return valid_blocks[-1]
+
+
def check_with_elementtree(text):
    def escape_between_tags(text, tags=["think", "answer"]):
        """Escape the content between tags while keeping the tags themselves."""
@@ -94,6 +118,19 @@ def check_with_elementtree(text):
        return False, f"Error: malformed XML, {str(e)}"


+def reward_caculate(task, _answers, query_id_strs):
+    format_rewards = []
+    if task == "math":
+        format_rewards = math_verify_call(_answers, query_id_strs)
+    else:
+        codes = [extract_python_code(_answer) for _answer in _answers]
+        format_rewards = code_verify(codes, query_id_strs)
+    logger.info(
+        f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
+    )
+    return format_rewards
+
+
def retokenize(
    task,
    tokenizer,
@@ -122,6 +159,10 @@ def retokenize(
    # query_id_strs = query_ids
    query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

+    logger.info(
+        f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
+    )
+
    format_rewards = []
    queryid_to_results = collections.defaultdict(list)
    # 8 processes on each node, with 10 subprocesses each
@@ -162,8 +203,7 @@ def retokenize(


@dataclasses.dataclass
-class PackedMathRewardInterface(model_api.ModelInterface):
+class PackedRewardInterface(model_api.ModelInterface):

    enable_save: bool = False
    tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
    output_scaling: float = 1.0
@@ -392,7 +432,7 @@ class PackedMathRewardInterface(model_api.ModelInterface):
            minimal_score = (-1 - self.output_bias) * self.rm_output_scaling

            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
-                raise MathVerifierException(
+                raise VerifierException(
                    "All rewards are at minimal value. Probably there are something wrong with the verifier!"
                )

@@ -401,4 +441,4 @@ class PackedMathRewardInterface(model_api.ModelInterface):
        return res


-model_api.register_interface("rw_math", PackedMathRewardInterface)
+model_api.register_interface("reward", PackedRewardInterface)
@@ -0,0 +1,270 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import dataclasses
import functools
import itertools
import time
from typing import Dict, Literal, Optional, Tuple

import torch

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
    gather_packed_shifted_log_probs,
    masked_normalization,
)
from realhf.impl.model.interface.rw_interface import PackedRewardInterface

logger = logging.getLogger("RefRwInterface")

TASK_TYPE_REF: Literal["ref"] = "ref"
TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"


@dataclasses.dataclass
class RefRwInterface(model_api.ModelInterface):
    n_minibatches: int = 4

    # Use dict here to allow argument passing through commandline.
    generation_config: Dict = dataclasses.field(default_factory=dict)

    kl_ctl: float = 0.1

    adv_norm: bool = True
    discount: float = 1.0
    gae_lambda: float = 1.0

    eps_clip: float = 0.2
    value_eps_clip: float = 0.2
    max_reward_clip: float = 5.0

    disable_value: bool = False

    early_stop_kl: Optional[float] = None  # e.g. 0.1
    early_stop_imp_ratio: Optional[float] = None  # e.g., 10.0

    adaptive_kl_ctl: bool = False
    adaptive_kl_target: Optional[float] = 6
    adaptive_kl_horizon: Optional[float] = 10000

    enable_save: bool = True

    value_norm: bool = False
    value_norm_type: str = dataclasses.field(
        metadata={"choices": ["exp", "ma"]}, default="exp"
    )
    value_norm_beta: float = 0.99995
    value_norm_eps: float = 1e-5

    group_size: int = 1
    generation_size: Optional[int] = None
    mask_no_eos_with_zero: bool = False
    group_adv_norm: bool = False
    mask_too_long: bool = False
    use_dense_reward: bool = False
    reward_delta: bool = True
    token_normalize_scope: Literal["global", "dp"] = "global"
    rew_inf_args: Dict = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.adaptive_kl_ctl:
            assert self.adaptive_kl_target is not None
            assert self.adaptive_kl_horizon is not None
            self.kl_adapter = ppo_functional.AdaptiveKLController(
                self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
            )
        else:
            self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
        if self.value_norm:
            from realhf.impl.model.modules import (
                ExponentialRunningMeanStd,
                MovingAverageRunningMeanStd,
            )

            if self.value_norm_type == "exp":
                self.rms = ExponentialRunningMeanStd(
                    beta=self.value_norm_beta, epsilon=self.value_norm_eps
                )
            elif self.value_norm_type == "ma":
                self.rms = MovingAverageRunningMeanStd()
            else:
                raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
        self.kl_ctl = None

        self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
        if self.generation_size is not None:
            assert self.generation_size >= self.group_size
        else:
            self.generation_size = self.group_size
        self.gconfig.n = self.generation_size

    def save(self, model: model_api.Model, save_dir: str):
        if not self.enable_save:
            return
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        module.save_to_hf(
            tokenizer=model.tokenizer,
            save_dir=save_dir,
        )

    def _dispatch_tasks(self, data):
        math_data, code_data, rlhf_data, ref_data = data, data, data, data
        return math_data, code_data, rlhf_data, ref_data

    def _gather_tasks(self, data_map):
        # merge SequenceSamples from math_data, code_data, rlhf_data, ref_data
        return data_map.get(TASK_TYPE_REF, None)

    @torch.no_grad()
    def ref_inference(
        self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
    ):
        module = model.module
        module.eval()

        # This post_hook will gather log probabilities in mini-batches,
        # reducing peak memory usage.
        def calc_logprobs(logits, input_):
            logits /= self.gconfig.temperature

            input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
            cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()

            logprobs = gather_packed_shifted_log_probs(
                logits, cu_seqlens, input_.data["packed_input_ids"]
            )
            return logprobs

        input_flattend = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            seqlens=flat2d(input_.seqlens["packed_input_ids"]),
            data=dict(packed_input_ids=input_.data["packed_input_ids"]),
        )
        # add posthook to avoid storing full logits
        logprobs = module.forward(
            input_=input_flattend,
            post_hook=calc_logprobs,
            output_seqlens=[
                [x - 1 for x in slens]
                for slens in input_flattend.seqlens["packed_input_ids"]
            ],
            mb_spec=mb_spec,
        )

        res = SequenceSample(
            keys=["packed_ref_logprobs"],
            ids=input_.ids,
            dtypes=dict(packed_ref_logprobs=model.module.dtype),
            trailing_shapes=dict(packed_ref_logprobs=()),
            data=dict(packed_ref_logprobs=logprobs),
            seqlens=dict(
                packed_ref_logprobs=[
                    [x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
                ]
            ),
        )

        return res

    def inference(
        self,
        model: model_api.Model,
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample:
        math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)

        if not hasattr(self, "rew_inf_args") or not isinstance(self.rew_inf_args, dict):
            raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
        rewardInterface = PackedRewardInterface(**self.rew_inf_args)
        logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")

        task_map = {
            TASK_TYPE_REF: (self.ref_inference, ref_data),
            TASK_TYPE_RW_MATH: (rewardInterface.inference, math_data),
            TASK_TYPE_RW_CODE: (rewardInterface.inference, code_data),
        }

        def _task_func(func, task_type: str):
            def _wrapped_func(*args, **kwargs):
                start_time = time.perf_counter()
                logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    logger.error(f"{task_type} ref_rw task failed: {e}")
                finally:
                    duration = time.perf_counter() - start_time
                    logger.info(
                        f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
                    )
                return result

            return _wrapped_func

        async def _run_tasks() -> dict:
            tasks = []
            for task_type, (func, data) in task_map.items():
                if not data:
                    continue
                task_func = _task_func(func, task_type)
                task_args = (model, data, mb_spec)
                task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
                tasks.append((task_type, task))

            results = {}
            for task_type, task in tasks:
                try:
                    results[task_type] = await task
                except Exception as e:
                    logger.error(f"{task_type} task failed: {e}")
                    results[task_type] = None
            return results

        task_results = asyncio.run(_run_tasks())
        final_result = self._gather_tasks(task_results)
        return final_result

    # Mock methods for profiling only.
    def _mock_inference(
        self,
        model: model_api.Model,
        dataset_input: SequenceSample,
    ) -> SequenceSample:
        prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
        seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        mconfig = module.config
        packed_input_ids = torch.randint(
            0,
            mconfig.vocab_size,
            (sum(seqlens),),
            dtype=torch.long,
            device=model.device,
        )

        return SequenceSample.from_default(
            seqlens=seqlens,
            ids=dataset_input.ids,
            data=dict(packed_input_ids=packed_input_ids),
        )


model_api.register_interface("ref_rw", RefRwInterface)
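The inference entry point above overlaps the reference log-prob pass and the reward passes by pushing each blocking call onto a worker thread via asyncio.to_thread. A minimal, self-contained sketch of that pattern follows; the task names and the sleeping stand-in function are made up for illustration and are not AReaL call signatures.

import asyncio
import time


def blocking_task(name: str, seconds: float) -> str:
    # Stand-in for a blocking forward pass (ref logprobs or reward scoring).
    time.sleep(seconds)
    return f"{name} done"


async def run_all() -> dict:
    # Each blocking call gets its own worker thread, so the passes overlap
    # instead of running back to back (mirrors asyncio.to_thread in the diff).
    tasks = {
        name: asyncio.create_task(asyncio.to_thread(blocking_task, name, 1.0))
        for name in ("ref", "rw_math", "rw_code")
    }
    return {name: await task for name, task in tasks.items()}


if __name__ == "__main__":
    print(asyncio.run(run_all()))  # all three finish in roughly 1s, not 3s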
@@ -0,0 +1,317 @@
import functools
import gc
import os
import pickle
import json
import time
from typing import *

import numpy as np
import pynvml
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from torch.cuda import is_initialized

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(root_dir)
import sys

sys.path.insert(0, root_dir)

from realhf.api.core import data_api, dfg, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
    init_global_constants,
    _DEFAULT_EXPR_NAME,
    _DEFAULT_TRIAL_NAME,
)

logger = logging.getLogger("test async ref-rew")
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"


def loadJson():
    dataDir = os.environ["REAL_MATH_METADATA_PATH"]
    with open(dataDir, "r") as f:
        if dataDir.endswith(".jsonl"):
            samples = [json.loads(line) for line in f.readlines()]
        else:
            samples = json.load(f)

    return samples


def _mock_input(batch_size: int, seq_len):
    vocab_size = 100
    torch.manual_seed(1)
    seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)

    samples = loadJson()
    id_list = list(samples.keys())
    # id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # encode using hash values

    return data_api.SequenceSample.from_default(
        seqlens=[seq_len for _ in range(seqs.shape[0])],
        ids=[id_list[i] for i in range(seqs.shape[0])],
        data=dict(
            packed_input_ids=seqs.view(-1),
            # prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
            packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
        ),
    )


def funcion_call(
    rpc_name: str,
    rank: int,
    world_size: int,
    model_path: str,
    model_family_name: str,
    dp: int,
    pp: int,
    tp: int,
    interface_type: dfg.ModelInterfaceType,
    interface_impl: dfg.ModelInterfaceAbstraction,
    batch_size: int,
    prompt_len: int,
    input_: data_api.SequenceSample | None,
    port: int,
):

    # assert not torch.cuda.is_initialized()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    torch.cuda.set_device(0)
    assert world_size == (
        dp * pp * tp
    ), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
    assert batch_size % dp == 0, (batch_size, dp)

    # Initialize distributed environment.
    model_name = ModelName("default", 0)
    if not dist.is_initialized():
        logger.info("Setting up distributed environment...")
        dist.init_process_group(
            "nccl",
            rank=rank,
            world_size=world_size,
            init_method=f"tcp://localhost:{port}",
        )
        logger.info("Initialized distributed environment.")
    init_global_constants(
        num_dp=dp,
        num_mp=tp,
        num_pp=pp,
        sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
        model_name=model_name,
        max_prompt_len=prompt_len,
    )
    torch.cuda.set_device(0)

    # NOTE: import here to avoid CUDA re-initialization

    from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions

    # Call a method like `config_from_llama` to get the config.
    mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
        transformers.AutoConfig.from_pretrained(model_path)
    )
    is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
    mconfig.is_critic = is_critic
    with constants.model_scope(model_name):
        # Construct the model.
        logger.info(f"Loading model from {model_path}...")
        module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
        setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
        setattr(
            ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
        )
        module._instantiation_hooks.append(
            lambda: getattr(module, f"from_{model_family_name}")(
                load_dir=model_path,
                init_critic_from_actor=is_critic,
            )
        )
        add_helper_functions(module)
        module.instantiate()
        module.eval()

        tokenizer = data_api.load_hf_tokenizer(model_path)

        model = model_api.Model(
            name=model_name,
            module=module,
            tokenizer=tokenizer,
            device=module.device,
            dtype=module.dtype,
        )
        if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            from realhf.impl.model.backend.megatron import MegatronTrainBackend

            backend = MegatronTrainBackend()
        else:
            from realhf.impl.model.backend.inference import PipelineInferenceBackend

            backend = PipelineInferenceBackend()

        logger.info("Running backend initialization...")
        ft_spec = model_api.FinetuneSpec(
            total_train_epochs=1,
            dataset_size=128,
            train_batch_size=128,
        )
        model = backend.initialize(model, ft_spec)

        interface = model_api.make_interface(interface_impl)

        if input_ is None:
            input_ = _mock_input(batch_size, prompt_len)

        input_ = input_.cuda()

        mb_spec = model_api.MicroBatchSpec()

        logger.info("Running interface computation...")
        start = time.perf_counter_ns()
        if interface_type == dfg.ModelInterfaceType.GENERATE:
            res = interface.generate(model, input_, mb_spec)
        elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            res = interface.train_step(model, input_)
        else:
            res = interface.inference(model, input_, mb_spec)

        if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
            if isinstance(res, data_api.SequenceSample):
                res = res.cpu()

        comsumed = time.perf_counter_ns() - start
        logger.info(f"{rpc_name} Computation done. {comsumed} ns")
        return res


def run_function_call(
    rpc_name: str,
    model_path: str,
    model_family_name: str,
    batch_size: int,
    prompt_len: int,
    gen_len: int,
    input_: data_api.SequenceSample | None,
) -> data_api.SequenceSample | None:
    assert rpc_name in [
        "actor_gen",
        "actor_train",
        "critic_inf",
        "rew_inf",
        "critic_train",
        "ref_inf",
        "ref_rw",
    ]

    ref_rw_interface = dfg.ModelInterfaceAbstraction(
        "ref_rw",
        args=dict(
            generation_config=dict(
                max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
            ),
            rew_inf_args=dict(
                tokenizer_path=model_path,
            ),
        ),
    )

    ppo_actor_interface = dfg.ModelInterfaceAbstraction(
        "ppo_actor",
        args=dict(
            generation_config=dict(
                max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
            ),
            rew_inf_args=dict(
                tokenizer_path=model_path,
            ),
        ),
    )
    ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
    rw_interface = dfg.ModelInterfaceAbstraction(
        "paired_rw",
    )
    if rpc_name == "actor_gen":
        interface_type = dfg.ModelInterfaceType.GENERATE
        interface_impl = ppo_actor_interface
    elif rpc_name == "actor_train":
        interface_type = dfg.ModelInterfaceType.TRAIN_STEP
        interface_impl = ppo_actor_interface
    elif rpc_name == "critic_inf":
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = ppo_critic_interface
    elif rpc_name == "ref_inf":
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = ppo_actor_interface
    elif rpc_name == "ref_rw":
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = ref_rw_interface
    elif rpc_name == "critic_train":
        interface_type = dfg.ModelInterfaceType.TRAIN_STEP
        interface_impl = ppo_critic_interface
    else:
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = rw_interface

    logger.info(f"Running RPC {rpc_name}...")

    port = find_free_port()
    res = funcion_call(
        rank=0,
        rpc_name=rpc_name,
        world_size=1,
        model_path=model_path,
        model_family_name=model_family_name,
        dp=1,
        pp=1,
        tp=1,
        interface_type=interface_type,
        interface_impl=interface_impl,
        batch_size=batch_size,
        prompt_len=prompt_len,
        input_=input_,
        port=port,
    )

    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
    if isinstance(res, data_api.SequenceSample):
        return res
    else:
        logger.info(f"RPC {rpc_name} stats: {res}")


def main():
    mp.set_start_method("spawn", force=True)

    model_family_name = "qwen2"
    batch_size = 16
    prompt_len = 128
    gen_len = 4096
    model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"

    constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)

    for i in range(2):
        ref_rw_res = run_function_call(
            "ref_rw",
            model_family_name=model_family_name,
            model_path=model_path,
            batch_size=batch_size,
            prompt_len=prompt_len,
            gen_len=gen_len,
            input_=None,
        )


if __name__ == "__main__":
    main()