mirror of https://github.com/inclusionAI/AReaL
create ref-rw inference async mode
This commit is contained in:
parent 234e3dd3a0
commit 8732811c73
@@ -19,7 +19,7 @@ def check_is_realhf_native_impl(_cls)
 def check_is_realhf_native_model_interface(name):
     # NOTE: we should not import iterfaces here,
     # such that we can avoid CUDA initialization.
-    return name in ["ppo_actor", "ppo_critic", "sft"]
+    return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]


 def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
@@ -253,7 +253,7 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor": self.actor,
                 # "critic": self.critic,
                 "ref": self.ref,
-                "reward": self.rew,
+                # "reward": self.rew,
             }
         else:
             return {
@@ -304,8 +304,6 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "reward_delta": self.reward_delta,
             },
         )
-        ref_interface = copy.deepcopy(actor_interface)
-        ref_interface.args["enable_save"] = False

         critic_interface = ModelInterfaceAbstraction(
             "ppo_critic",
@@ -319,7 +317,7 @@ class PPOMATHConfig(CommonExperimentConfig):
         )
         critic_interface.args.pop("eps_clip")
         rw_interface = ModelInterfaceAbstraction(
-            "rw_math",
+            "reward",
             args=dict(
                 rw_type=self.rw_type,
                 task=self.task,
@@ -368,9 +366,16 @@ class PPOMATHConfig(CommonExperimentConfig):
             n_seqs=self.dataset.train_bs_n_seqs,
         )

-        inf_ref_inputs = ["packed_input_ids"]
+        # add rew param into ref MFC
+        inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
+        inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
+        ref_interface = copy.deepcopy(actor_interface)
+        ref_interface.type_ = "ref_rw"
+        ref_interface.args["enable_save"] = False
+        ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)

         inf_ref_logits = MFCDef(
-            name="ref_inf",
+            name="ref_rw",
             model_name="ref",
             mb_spec=self.ref_inf.mb_spec,
             interface_type=ModelInterfaceType.INFERENCE,
@@ -379,7 +384,7 @@ class PPOMATHConfig(CommonExperimentConfig):
             interface_impl=ref_interface,
             min_n_seqs_per_pass=1 / self.group_size,
             input_keys=inf_ref_inputs,
-            output_keys=["packed_ref_logprobs"],
+            output_keys=inf_ref_outputs,
             n_seqs=self.dataset.train_bs_n_seqs,
         )
@@ -451,8 +456,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor_train": train_actor,
                 # "critic_inf": inf_values,
                 # "critic_train": train_critic,
-                "ref_inf": inf_ref_logits,
-                "rew_inf": inf_reward,
+                # "ref_inf": inf_ref_logits,
+                # "rew_inf": inf_reward,
+                "ref_rw": inf_ref_logits,
             }
         else:
             return {
@@ -472,8 +478,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor_train": self.actor_train,
                 # "critic_inf": self.critic_inf,
                 # "critic_train": self.critic_train,
-                "ref_inf": self.ref_inf,
-                "rew_inf": self.rew_inf,
+                # "ref_inf": self.ref_inf,
+                # "rew_inf": self.rew_inf,
+                "ref_rw": self.ref_inf,
             }
         else:
             return {
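Editor's note: in the hunks above, rw_interface.args (rw_type, task, ...) no longer configure a standalone reward MFC; they are copied into ref_interface.args["rew_inf_args"] and handed to the combined interface, which splats them back into the reward implementation. A short sketch of that hand-off, using only names that appear in this commit (illustrative, not additional repo code):

# experiment-config side (as in the hunk above):
ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)

# interface side (see RefRwInterface.inference later in this diff):
rewardInterface = PackedRewardInterface(**self.rew_inf_args)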
@@ -1,429 +0,0 @@
# Copyright 2025 Ant Group Inc.
import collections
import dataclasses
import html
import json
import os
import re
import xml.etree.ElementTree as ET
from ast import parse
from concurrent.futures import ThreadPoolExecutor, as_completed

import colorama
import numpy as np
import torch
import torch.distributed as dist

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.base import constants

logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")


class CodeVerifierException(Exception):
    pass


def extract_python_code(text, min_length=20, strict_syntax=True):
    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
    code_blocks = re.findall(code_pattern, text, re.DOTALL)
    valid_blocks = []
    for block in code_blocks:
        clean_block = block.strip()
        if len(clean_block) < min_length:
            continue

        # verify code syntax
        if strict_syntax:
            try:
                parse(clean_block, mode="exec")
            except (SyntaxError, IndentationError):
                continue

        valid_blocks.append(clean_block)

    if not valid_blocks:
        return None
    # return the last code block
    return valid_blocks[-1]


def check_with_elementtree(text):
    def escape_between_tags(text, tags=["think", "answer"]):
        """Escape the content between tags while keeping the tags themselves."""
        # build the tag pattern
        tag_pattern = "|".join(tags)
        parts = []
        current_pos = 0

        # match opening and closing tags
        pattern = f"</?({tag_pattern})[^>]*>"

        for match in re.finditer(pattern, text):
            # append the content before the tag (escaped)
            if current_pos < match.start():
                parts.append(html.escape(text[current_pos : match.start()]))

            # append the tag itself (not escaped)
            parts.append(match.group())
            current_pos = match.end()

        # append the remaining content
        if current_pos < len(text):
            parts.append(html.escape(text[current_pos:]))

        return "".join(parts)

    text = escape_between_tags(text)
    if not text.strip().startswith("<think>"):
        text = "<think>" + text
    try:
        xml_text = f"<root>{text}</root>"
        x = ET.fromstring(xml_text)
        if x.text is not None and x.text.strip() != "":
            return False, f"Error: extra content before <think>. {x.text}"
        if len(x) != 2:
            return False, f"Error: there are {len(x)} tags."
        if x[0].tag is None or x[0].tag != "think":
            return False, f"Error: <think> tag is missing. {x[0].tag}"
        if x[0].tail is not None and x[0].tail.strip() != "":
            return (
                False,
                f"Error: extra content between <think> and <answer>. {x[0].tail}",
            )
        if x[1].tag is None or x[1].tag != "answer":
            return False, f"Error: <answer> tag is missing. {x[1].tag}"
        if x[1].tail is not None and x[1].tail.strip() != "":
            return False, f"Error: extra content after <answer>, {x[1].tail}"

        return True, x[1].text if x[1].text is not None else ""
    except ET.ParseError as e:
        return False, f"Error: malformed XML, {str(e)}"


def retokenize(
    task,
    tokenizer,
    packed_input_ids,
    input_cu_seqlens,
    prompts,
    prompt_cu_seqlens,
    query_ids,
    check_xml_format=False,
    do_eval=False,
):
    input_ids = [
        packed_input_ids[start:end]
        for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
    ]
    prompt_ids = [
        prompts[start:end]
        for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
    ]
    seq_strs = tokenizer.batch_decode(
        input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    prompt_strs = tokenizer.batch_decode(
        prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    # query_id_strs = query_ids
    query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

    format_rewards = []

    queryid_to_results = collections.defaultdict(list)
    # 8 processes on each node, with 10 subprocesses each
    if do_eval == True:
        _answers = [
            seq_str.split(prompt_str)[1]
            for seq_str, prompt_str in zip(seq_strs, prompt_strs)
        ]

        codes = [extract_python_code(_answer) for _answer in _answers]
        logger.info(
            f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
        )
        format_rewards = code_verify(codes, query_id_strs)

        if check_xml_format:
            with ThreadPoolExecutor(max_workers=22) as executor:
                futures = [
                    executor.submit(check_with_elementtree, answer_str)
                    for answer_str in _answers
                ]
                # xml_rewards = []
                for idx, future in enumerate(futures):
                    xml_reward, _ = future.result()
                    # xml_rewards.append(xml_reward)
                    if xml_reward == 1 and format_rewards[idx] == 0:
                        format_rewards[idx] = -0.8
                    elif xml_reward == 0 and format_rewards[idx] == 0:
                        format_rewards[idx] = -1

        for query_id_str, format_reward in zip(query_id_strs, format_rewards):
            if query_id_str not in queryid_to_results:
                queryid_to_results[query_id_str] = []
            queryid_to_results[query_id_str].append(format_reward)
    else:
        for query_id_str in query_id_strs:
            if query_id_str not in queryid_to_results:
                queryid_to_results[query_id_str] = []
            queryid_to_results[query_id_str].append(0)
            format_rewards.append(0)

    return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results


@dataclasses.dataclass
class PackedCodeRewardInterface(model_api.ModelInterface):

    enable_save: bool = False
    tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
    output_scaling: float = 1.0
    rm_output_scaling: float = 1.0
    rm_output_bias: float = 0.0
    output_bias: float = 0.0
    loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
    max_sync_length: int = 2048
    rw_type: str = "sparse"
    task: str = "code"  # math or countdown or code
    check_xml_format: bool = False
    post_process: str = "sigmoid"
    group_size: int = 1
    check_verifier_status: bool = False

    _call_count: int = 0

    def __post_init__(self):
        self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
        logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
        logger.info(f"rm_output_bias: {self.rm_output_bias}")
        logger.info(f"output_scaling: {self.output_scaling}")
        logger.info(f"output_bias: {self.output_bias}")
        logger.info(f"max_sync_length: {self.max_sync_length}")
        logger.info(f"rw_type: {self.rw_type}")
        logger.info(f"post_process: {self.post_process}")

        while True:
            gen_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "generated",
                f"v{self._call_count}r{dist.get_rank()}.txt",
            )
            if os.path.exists(gen_file_path):
                self._call_count += 1
            else:
                break
        logger.info(f"call_count: {self._call_count}")

    def inference(
        self,
        model: model_api.Model,
        data_: SequenceSample,
        mb_spec,
    ) -> SequenceSample:

        packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()

        input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
        input_cu_seqlens = torch.nn.functional.pad(
            input_seqlens.cumsum(0), (1, 0)
        ).int()

        packed_prompts = data_.data["packed_prompts"]
        prompts = []
        prompt_seqlens = []
        offset = 0
        for x in data_.seqlens["packed_prompts"]:
            prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
            offset += x[0]
            prompt_seqlens.extend(x * self.group_size)

        assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
        # non_packed_prompts = copy.deepcopy(prompts)
        prompts = torch.cat(prompts)
        prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
        prompt_cu_seqlens = torch.nn.functional.pad(
            prompt_seqlens.cumsum(0), (1, 0)
        ).int()

        query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]

        format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
            retokenize(
                self.task,
                self.tokenizer,
                packed_input_ids,
                input_cu_seqlens,
                prompts,
                prompt_cu_seqlens,
                query_ids,
                check_xml_format=self.check_xml_format,
                do_eval=constants.is_last_pipe_stage(),
            )
        )

        assert self.rw_type == "sparse"
        dense_scores = torch.zeros_like(packed_input_ids).float()
        scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
        scores[scores == 0] = -1

        if len(scores) == 0:
            return None

        assert dense_scores.shape == packed_input_ids.shape
        scores = (
            scores.to(packed_input_ids.device) - self.output_bias
        ) * self.output_scaling

        logger.info(f"Code reward logging info @v{model.version.global_step}")
        logger.info(
            f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
        )
        logger.info(
            f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
        )

        gen_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "generated",
            f"v{self._call_count}r{dist.get_rank()}.txt",
        )
        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
        with open(gen_file_path, "w") as _f:
            for idx, (score, prompt_str, seq_str) in enumerate(
                zip(scores, prompt_strs, seq_strs)
            ):
                info = "\n".join(
                    [
                        f"idx: {idx} / {len(scores)}",
                        f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
                        f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
                    ]
                )
                _f.write(info + "\n")

        gen_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "generated_jsonl",
            f"v{self._call_count}r{dist.get_rank()}.jsonl",
        )
        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
        with open(gen_file_path, "w") as _f:
            for idx, (score, prompt_str, seq_str) in enumerate(
                zip(scores, prompt_strs, seq_strs)
            ):
                _f.write(
                    json.dumps(
                        {
                            "prompt": prompt_str,
                            "generated": seq_str.split(prompt_str)[1],
                            "reward": score.item(),
                        },
                        ensure_ascii=False,
                    )
                    + "\n"
                )

        logger.info(
            f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
        )

        pass_at_k = np.mean(
            [sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
        )
        avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
        logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")

        logger.info(f"number of samples: {len(scores)}, {scores.shape}")

        logger.info(f"reward: {sum(scores) / len(scores)}")

        train_pass_monitor_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "training_monitor",
            f"v{self._call_count}r{dist.get_rank()}.jsonl",
        )
        os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
        logger.info(
            f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
        )
        with open(train_pass_monitor_file_path, "w") as monitor_file:
            for key, value in queryid_to_results.items():
                pass1 = sum(value) / len(value)
                pass8 = int(sum(value) > 0)
                monitor_file.write(
                    json.dumps(
                        {"query_id": key, "pass1": pass1, "pass8": pass8},
                        ensure_ascii=False,
                    )
                    + "\n"
                )

        self._call_count += 1

        if scores.dtype != torch.float32:
            scores = scores.to(torch.float32)
        if dense_scores.dtype != torch.float32:
            dense_scores = dense_scores.to(torch.float32)

        res = SequenceSample(
            keys=["rewards", "dense_rewards"],
            trailing_shapes=dict(rewards=(), dense_rewards=()),
            dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
            ids=data_.ids,
            seqlens=dict(
                rewards=[
                    torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
                    for x in data_.seqlens["packed_input_ids"]
                ],
                dense_rewards=data_.seqlens["packed_input_ids"],
            ),
            data=dict(rewards=scores, dense_rewards=dense_scores),
        )

        # record rewards for each piece of data
        avg_scores = []
        offset = 0
        for i in range(data_.bs):
            score_lis = scores[
                offset : offset + len(data_.seqlens["packed_input_ids"][i])
            ]
            avg_scores.append(score_lis.mean().item())
            offset += len(data_.seqlens["packed_input_ids"][i])
        assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])

        res.metadata["scores"] = avg_scores

        if self.check_verifier_status:
            avg_score = torch.tensor(
                np.mean(avg_scores), device=constants.current_device()
            )
            dist.all_reduce(
                avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
            )
            avg_score /= constants.parallelism_group_size()
            avg_score = avg_score.item()
            minimal_score = (-1 - self.output_bias) * self.rm_output_scaling

            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
                raise CodeVerifierException(
                    "All rewards are at minimal value. Probably there are something wrong with the verifier!"
                )
        return res


model_api.register_interface("rw_code", PackedCodeRewardInterface)
@@ -37,10 +37,34 @@ ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else
 math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel


-class MathVerifierException(Exception):
+class VerifierException(Exception):
     pass


+def extract_python_code(text, min_length=20, strict_syntax=True):
+    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
+    code_blocks = re.findall(code_pattern, text, re.DOTALL)
+    valid_blocks = []
+    for block in code_blocks:
+        clean_block = block.strip()
+        if len(clean_block) < min_length:
+            continue
+
+        # verify code syntax
+        if strict_syntax:
+            try:
+                parse(clean_block, mode="exec")
+            except (SyntaxError, IndentationError):
+                continue
+
+        valid_blocks.append(clean_block)
+
+    if not valid_blocks:
+        return None
+    # return the last code block
+    return valid_blocks[-1]
+
+
 def check_with_elementtree(text):
     def escape_between_tags(text, tags=["think", "answer"]):
         """Escape the content between tags while keeping the tags themselves."""
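Editor's note: the extract_python_code helper added above keeps only fenced blocks that parse as Python and returns the last one. A small self-contained illustration of the fence-matching regex (example input is assumed, not repo code):

import re

code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
text = "prose\n```python\nprint('first block')\n```\nmore prose\n```py\nprint('second block')\n```"
blocks = re.findall(code_pattern, text, re.DOTALL)
print(blocks)      # ["print('first block')", "print('second block')"]
print(blocks[-1])  # the interface keeps only the last block that passes the syntax check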
@@ -94,6 +118,19 @@ def check_with_elementtree(text):
         return False, f"Error: malformed XML, {str(e)}"


+def reward_caculate(task, _answers, query_id_strs):
+    format_rewards = []
+    if task == "math":
+        format_rewards = math_verify_call(_answers, query_id_strs)
+    else:
+        codes = [extract_python_code(_answer) for _answer in _answers]
+        format_rewards = code_verify(codes, query_id_strs)
+    logger.info(
+        f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
+    )
+    return format_rewards
+
+
 def retokenize(
     task,
     tokenizer,
@@ -122,6 +159,10 @@ def retokenize(
     # query_id_strs = query_ids
     query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

+    logger.info(
+        f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
+    )
+
     format_rewards = []
     queryid_to_results = collections.defaultdict(list)
     # 8 processes on each node, with 10 subprocesses each
@@ -162,8 +203,7 @@ def retokenize(


 @dataclasses.dataclass
-class PackedMathRewardInterface(model_api.ModelInterface):
-
+class PackedRewardInterface(model_api.ModelInterface):
     enable_save: bool = False
     tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
     output_scaling: float = 1.0
@@ -392,7 +432,7 @@ class PackedMathRewardInterface(model_api.ModelInterface):
             minimal_score = (-1 - self.output_bias) * self.rm_output_scaling

             if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
-                raise MathVerifierException(
+                raise VerifierException(
                     "All rewards are at minimal value. Probably there are something wrong with the verifier!"
                 )

@@ -401,4 +441,4 @@ class PackedMathRewardInterface(model_api.ModelInterface):
         return res


-model_api.register_interface("rw_math", PackedMathRewardInterface)
+model_api.register_interface("reward", PackedRewardInterface)
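Editor's note: with the registration renamed from "rw_math" to "reward", both math and code verification go through one interface, and the combined ref_rw interface simply forwards its rew_inf_args to it. A hedged configuration sketch (argument values are placeholders; the symbols are the ones used elsewhere in this commit):

# Direct use of the unified reward interface:
rw = ModelInterfaceAbstraction(
    "reward", args=dict(rw_type="sparse", task="math")
)
# Or indirectly, through the combined ref+reward interface added by this commit:
ref_rw = ModelInterfaceAbstraction(
    "ref_rw", args=dict(rew_inf_args=dict(rw_type="sparse", task="math"))
)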
@@ -0,0 +1,270 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import dataclasses
import functools
import itertools
import time
from typing import Dict, Literal, Optional, Tuple

import torch

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
    gather_packed_shifted_log_probs,
    masked_normalization,
)
from realhf.impl.model.interface.rw_interface import PackedRewardInterface

logger = logging.getLogger("RefRwInterface")

TASK_TYPE_REF: Literal["ref"] = "ref"
TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"


@dataclasses.dataclass
class RefRwInterface(model_api.ModelInterface):
    n_minibatches: int = 4

    # Use dict here to allow argument passing through commandline.
    generation_config: Dict = dataclasses.field(default_factory=dict)

    kl_ctl: float = 0.1

    adv_norm: bool = True
    discount: float = 1.0
    gae_lambda: float = 1.0

    eps_clip: float = 0.2
    value_eps_clip: float = 0.2
    max_reward_clip: float = 5.0

    disable_value: bool = False

    early_stop_kl: Optional[float] = None  # e.g. 0.1
    early_stop_imp_ratio: Optional[float] = None  # e.g., 10.0

    adaptive_kl_ctl: bool = False
    adaptive_kl_target: Optional[float] = 6
    adaptive_kl_horizon: Optional[float] = 10000

    enable_save: bool = True

    value_norm: bool = False
    value_norm_type: str = dataclasses.field(
        metadata={"choices": ["exp", "ma"]}, default="exp"
    )
    value_norm_beta: float = 0.99995
    value_norm_eps: float = 1e-5

    group_size: int = 1
    generation_size: Optional[int] = None
    mask_no_eos_with_zero: bool = False
    group_adv_norm: bool = False
    mask_too_long: bool = False
    use_dense_reward: bool = False
    reward_delta: bool = True
    token_normalize_scope: Literal["global", "dp"] = "global"
    rew_inf_args: Dict = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.adaptive_kl_ctl:
            assert self.adaptive_kl_target is not None
            assert self.adaptive_kl_horizon is not None
            self.kl_adapter = ppo_functional.AdaptiveKLController(
                self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
            )
        else:
            self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
        if self.value_norm:
            from realhf.impl.model.modules import (
                ExponentialRunningMeanStd,
                MovingAverageRunningMeanStd,
            )

            if self.value_norm_type == "exp":
                self.rms = ExponentialRunningMeanStd(
                    beta=self.value_norm_beta, epsilon=self.value_norm_eps
                )
            elif self.value_norm_type == "ma":
                self.rms = MovingAverageRunningMeanStd()
            else:
                raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
        self.kl_ctl = None

        self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
        if self.generation_size is not None:
            assert self.generation_size >= self.group_size
        else:
            self.generation_size = self.group_size
        self.gconfig.n = self.generation_size

    def save(self, model: model_api.Model, save_dir: str):
        if not self.enable_save:
            return
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        module.save_to_hf(
            tokenizer=model.tokenizer,
            save_dir=save_dir,
        )

    def _dispatch_tasks(self, data):
        math_data, code_data, rlhf_data, ref_data = data, data, data, data
        return math_data, code_data, rlhf_data, ref_data

    def _gather_tasks(self, data_map):
        # merge SequenceSamples from math_data, code_data, rlhf_data, ref_data
        return data_map.get(TASK_TYPE_REF, None)

    @torch.no_grad()
    def ref_inference(
        self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
    ):
        module = model.module
        module.eval()

        # This post_hook will gather log probabilities in mini-batches,
        # reducing peak memory usage.
        def calc_logprobs(logits, input_):
            logits /= self.gconfig.temperature

            input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
            cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()

            logprobs = gather_packed_shifted_log_probs(
                logits, cu_seqlens, input_.data["packed_input_ids"]
            )
            return logprobs

        input_flattend = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            seqlens=flat2d(input_.seqlens["packed_input_ids"]),
            data=dict(packed_input_ids=input_.data["packed_input_ids"]),
        )
        # add posthook to avoid storing full logits
        logprobs = module.forward(
            input_=input_flattend,
            post_hook=calc_logprobs,
            output_seqlens=[
                [x - 1 for x in slens]
                for slens in input_flattend.seqlens["packed_input_ids"]
            ],
            mb_spec=mb_spec,
        )

        res = SequenceSample(
            keys=["packed_ref_logprobs"],
            ids=input_.ids,
            dtypes=dict(packed_ref_logprobs=model.module.dtype),
            trailing_shapes=dict(packed_ref_logprobs=()),
            data=dict(packed_ref_logprobs=logprobs),
            seqlens=dict(
                packed_ref_logprobs=[
                    [x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
                ]
            ),
        )

        return res

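Editor's note: ref_inference above emits packed_ref_logprobs with per-sequence lengths of x - 1 because each position scores the next token, so the last token of every sequence has no target. A self-contained toy illustration of that off-by-one (assumed shapes, not repo code):

import torch
import torch.nn.functional as F

# One packed sequence of 5 tokens with a toy vocabulary of 8.
logits = torch.randn(5, 8)
tokens = torch.randint(0, 8, (5,))
logprobs = F.log_softmax(logits, dim=-1)
# Position t scores token t+1, so the gathered log-probs are one shorter.
shifted = logprobs[:-1].gather(1, tokens[1:].unsqueeze(1)).squeeze(1)
print(shifted.shape)  # torch.Size([4]) == seqlen - 1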
    def inference(
        self,
        model: model_api.Model,
        input_: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample:
        math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)

        if not hasattr(self, "rew_inf_args") or not isinstance(self.rew_inf_args, dict):
            raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
        rewardInterface = PackedRewardInterface(**self.rew_inf_args)
        logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")

        task_map = {
            TASK_TYPE_REF: (self.ref_inference, ref_data),
            TASK_TYPE_RW_MATH: (rewardInterface.inference, math_data),
            TASK_TYPE_RW_CODE: (rewardInterface.inference, code_data),
        }

        def _task_func(func, task_type: str):
            def _wrapped_func(*args, **kwargs):
                start_time = time.perf_counter()
                logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    logger.error(f"{task_type} ref_rw task failed: {e}")
                finally:
                    duration = time.perf_counter() - start_time
                    logger.info(
                        f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
                    )
                return result

            return _wrapped_func

        async def _run_tasks() -> dict:
            tasks = []
            for task_type, (func, data) in task_map.items():
                if not data:
                    continue
                task_func = _task_func(func, task_type)
                task_args = (model, data, mb_spec)
                task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
                tasks.append((task_type, task))

            results = {}
            for task_type, task in tasks:
                try:
                    results[task_type] = await task
                except Exception as e:
                    logger.error(f"{task_type} task failed: {e}")
                    results[task_type] = None
            return results

        task_results = asyncio.run(_run_tasks())
        final_result = self._gather_tasks(task_results)
        return final_result

    # Mock methods for profiling only.
    def _mock_inference(
        self,
        model: model_api.Model,
        dataset_input: SequenceSample,
    ) -> SequenceSample:
        prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
        seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        mconfig = module.config
        packed_input_ids = torch.randint(
            0,
            mconfig.vocab_size,
            (sum(seqlens),),
            dtype=torch.long,
            device=model.device,
        )

        return SequenceSample.from_default(
            seqlens=seqlens,
            ids=dataset_input.ids,
            data=dict(packed_input_ids=packed_input_ids),
        )


model_api.register_interface("ref_rw", RefRwInterface)
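Editor's note: RefRwInterface.inference fans the reference-logprob pass and the reward passes out as blocking calls wrapped in asyncio.to_thread and awaits them together. A minimal, self-contained sketch of that pattern (toy tasks only, nothing from realhf is assumed; here each worker either returns a value or records None, so the caller always gets an entry per task):

import asyncio
import time

def blocking_task(name: str, seconds: float) -> str:
    time.sleep(seconds)  # stands in for a GPU forward pass or a verifier call
    return f"{name} done"

async def run_all() -> dict:
    task_map = {"ref": 0.2, "rw_math": 0.1, "rw_code": 0.1}
    tasks = {
        name: asyncio.create_task(asyncio.to_thread(blocking_task, name, sec))
        for name, sec in task_map.items()
    }
    results = {}
    for name, task in tasks.items():
        try:
            results[name] = await task
        except Exception:
            results[name] = None  # mirror the interface's per-task failure handling
    return results

print(asyncio.run(run_all()))  # {'ref': 'ref done', 'rw_math': 'rw_math done', 'rw_code': 'rw_code done'}

asyncio.to_thread releases the event loop while each blocking call runs in the default thread pool, which is what lets the reference and reward passes overlap instead of running back to back.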
@@ -0,0 +1,317 @@
import functools
import gc
import os
import pickle
import json
import time
from typing import *

import numpy as np
import pynvml
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from torch.cuda import is_initialized

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(root_dir)
import sys

sys.path.insert(0, root_dir)

from realhf.api.core import data_api, dfg, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
    init_global_constants,
    _DEFAULT_EXPR_NAME,
    _DEFAULT_TRIAL_NAME,
)

logger = logging.getLogger("test async ref-rew")
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"


def loadJson():
    dataDir = os.environ["REAL_MATH_METADATA_PATH"]
    with open(dataDir, "r") as f:
        if dataDir.endswith(".jsonl"):
            samples = [json.loads(line) for line in f.readlines()]
        else:
            samples = json.load(f)

    return samples


def _mock_input(batch_size: int, seq_len):
    vocab_size = 100
    torch.manual_seed(1)
    seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)

    samples = loadJson()
    id_list = list(samples.keys())
    # id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # encode using hash values

    return data_api.SequenceSample.from_default(
        seqlens=[seq_len for _ in range(seqs.shape[0])],
        ids=[id_list[i] for i in range(seqs.shape[0])],
        data=dict(
            packed_input_ids=seqs.view(-1),
            # prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
            packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
        ),
    )


def funcion_call(
    rpc_name: str,
    rank: int,
    world_size: int,
    model_path: str,
    model_family_name: str,
    dp: int,
    pp: int,
    tp: int,
    interface_type: dfg.ModelInterfaceType,
    interface_impl: dfg.ModelInterfaceAbstraction,
    batch_size: int,
    prompt_len: int,
    input_: data_api.SequenceSample | None,
    port: int,
):

    # assert not torch.cuda.is_initialized()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    torch.cuda.set_device(0)
    assert world_size == (
        dp * pp * tp
    ), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
    assert batch_size % dp == 0, (batch_size, dp)

    # Initialize distributed environment.
    model_name = ModelName("default", 0)
    if not dist.is_initialized():
        logger.info("Setting up distributed environment...")
        dist.init_process_group(
            "nccl",
            rank=rank,
            world_size=world_size,
            init_method=f"tcp://localhost:{port}",
        )
        logger.info("Initialized distributed environment.")
    init_global_constants(
        num_dp=dp,
        num_mp=tp,
        num_pp=pp,
        sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
        model_name=model_name,
        max_prompt_len=prompt_len,
    )
    torch.cuda.set_device(0)

    # NOTE: import here to avoid CUDA re-initialization

    from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions

    # Call a method like `config_from_llama` to get the config.
    mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
        transformers.AutoConfig.from_pretrained(model_path)
    )
    is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
    mconfig.is_critic = is_critic
    with constants.model_scope(model_name):
        # Construct the model.
        logger.info(f"Loading model from {model_path}...")
        module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
        setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
        setattr(
            ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
        )
        module._instantiation_hooks.append(
            lambda: getattr(module, f"from_{model_family_name}")(
                load_dir=model_path,
                init_critic_from_actor=is_critic,
            )
        )
        add_helper_functions(module)
        module.instantiate()
        module.eval()

        tokenizer = data_api.load_hf_tokenizer(model_path)

        model = model_api.Model(
            name=model_name,
            module=module,
            tokenizer=tokenizer,
            device=module.device,
            dtype=module.dtype,
        )
        if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            from realhf.impl.model.backend.megatron import MegatronTrainBackend

            backend = MegatronTrainBackend()
        else:
            from realhf.impl.model.backend.inference import PipelineInferenceBackend

            backend = PipelineInferenceBackend()

        logger.info("Running backend initialization...")
        ft_spec = model_api.FinetuneSpec(
            total_train_epochs=1,
            dataset_size=128,
            train_batch_size=128,
        )
        model = backend.initialize(model, ft_spec)

        interface = model_api.make_interface(interface_impl)

        if input_ is None:
            input_ = _mock_input(batch_size, prompt_len)

        input_ = input_.cuda()

        mb_spec = model_api.MicroBatchSpec()

        logger.info("Running interface computation...")
        start = time.perf_counter_ns()
        if interface_type == dfg.ModelInterfaceType.GENERATE:
            res = interface.generate(model, input_, mb_spec)
        elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            res = interface.train_step(model, input_)
        else:
            res = interface.inference(model, input_, mb_spec)

        if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
            if isinstance(res, data_api.SequenceSample):
                res = res.cpu()

        comsumed = time.perf_counter_ns() - start
        logger.info(f"{rpc_name} Computation done. {comsumed} ns")
        return res


def run_function_call(
    rpc_name: str,
    model_path: str,
    model_family_name: str,
    batch_size: int,
    prompt_len: int,
    gen_len: int,
    input_: data_api.SequenceSample | None,
) -> data_api.SequenceSample | None:
    assert rpc_name in [
        "actor_gen",
        "actor_train",
        "critic_inf",
        "rew_inf",
        "critic_train",
        "ref_inf",
        "ref_rw",
    ]

    ref_rw_interface = dfg.ModelInterfaceAbstraction(
        "ref_rw",
        args=dict(
            generation_config=dict(
                max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
            ),
            rew_inf_args=dict(
                tokenizer_path=model_path,
            ),
        ),
    )

    ppo_actor_interface = dfg.ModelInterfaceAbstraction(
        "ppo_actor",
        args=dict(
            generation_config=dict(
                max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
            ),
            rew_inf_args=dict(
                tokenizer_path=model_path,
            ),
        ),
    )
    ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
    rw_interface = dfg.ModelInterfaceAbstraction(
        "paired_rw",
    )
    if rpc_name == "actor_gen":
        interface_type = dfg.ModelInterfaceType.GENERATE
        interface_impl = ppo_actor_interface
    elif rpc_name == "actor_train":
        interface_type = dfg.ModelInterfaceType.TRAIN_STEP
        interface_impl = ppo_actor_interface
    elif rpc_name == "critic_inf":
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = ppo_critic_interface
    elif rpc_name == "ref_inf":
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = ppo_actor_interface
    elif rpc_name == "ref_rw":
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = ref_rw_interface
    elif rpc_name == "critic_train":
        interface_type = dfg.ModelInterfaceType.TRAIN_STEP
        interface_impl = ppo_critic_interface
    else:
        interface_type = dfg.ModelInterfaceType.INFERENCE
        interface_impl = rw_interface

    logger.info(f"Running RPC {rpc_name}...")

    port = find_free_port()
    res = funcion_call(
        rank=0,
        rpc_name=rpc_name,
        world_size=1,
        model_path=model_path,
        model_family_name=model_family_name,
        dp=1,
        pp=1,
        tp=1,
        interface_type=interface_type,
        interface_impl=interface_impl,
        batch_size=batch_size,
        prompt_len=prompt_len,
        input_=input_,
        port=port,
    )

    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
    if isinstance(res, data_api.SequenceSample):
        return res
    else:
        logger.info(f"RPC {rpc_name} stats: {res}")


def main():
    mp.set_start_method("spawn", force=True)

    model_family_name = "qwen2"
    batch_size = 16
    prompt_len = 128
    gen_len = 4096
    model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"

    constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)

    for i in range(2):
        ref_rw_res = run_function_call(
            "ref_rw",
            model_family_name=model_family_name,
            model_path=model_path,
            batch_size=batch_size,
            prompt_len=prompt_len,
            gen_len=gen_len,
            input_=None,
        )


if __name__ == "__main__":
    main()
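Editor's note: main() above exercises only the combined "ref_rw" RPC, twice in a row. If one wanted to sanity-check another path with the same helpers, a call such as the following fits the run_function_call signature; the argument values are just the placeholders already used in main(), and the extra call is illustrative, not part of the commit:

ref_only_res = run_function_call(
    "ref_inf",
    model_family_name="qwen2",
    model_path="/storage/models/DeepSeek-R1-Distill-Qwen-1.5B",
    batch_size=16,
    prompt_len=128,
    gen_len=4096,
    input_=None,
)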