create ref-rw inference async mode

Jun Mo 2025-03-13 18:06:38 +08:00
parent 234e3dd3a0
commit 8732811c73
6 changed files with 651 additions and 446 deletions

View File

@ -19,7 +19,7 @@ def check_is_realhf_native_impl(_cls):
def check_is_realhf_native_model_interface(name):
# NOTE: we should not import interfaces here,
# such that we can avoid CUDA initialization.
return name in ["ppo_actor", "ppo_critic", "sft"]
return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]
def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
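The whitelist above now recognizes the new combined interface. A minimal sanity check, assuming check_is_realhf_native_model_interface is in scope (the module path is not shown in this view):

assert check_is_realhf_native_model_interface("ref_rw")          # newly whitelisted by this commit
assert not check_is_realhf_native_model_interface("my_custom")   # other names still resolve as non-native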

View File

@ -253,7 +253,7 @@ class PPOMATHConfig(CommonExperimentConfig):
"actor": self.actor,
# "critic": self.critic,
"ref": self.ref,
"reward": self.rew,
# "reward": self.rew,
}
else:
return {
@ -304,8 +304,6 @@ class PPOMATHConfig(CommonExperimentConfig):
"reward_delta": self.reward_delta,
},
)
ref_interface = copy.deepcopy(actor_interface)
ref_interface.args["enable_save"] = False
critic_interface = ModelInterfaceAbstraction(
"ppo_critic",
@ -319,7 +317,7 @@ class PPOMATHConfig(CommonExperimentConfig):
)
critic_interface.args.pop("eps_clip")
rw_interface = ModelInterfaceAbstraction(
"rw_math",
"reward",
args=dict(
rw_type=self.rw_type,
task=self.task,
@ -368,9 +366,16 @@ class PPOMATHConfig(CommonExperimentConfig):
n_seqs=self.dataset.train_bs_n_seqs,
)
inf_ref_inputs = ["packed_input_ids"]
# add reward-related inputs and outputs to the ref MFC
inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
ref_interface = copy.deepcopy(actor_interface)
ref_interface.type_ = "ref_rw"
ref_interface.args["enable_save"] = False
ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)
inf_ref_logits = MFCDef(
name="ref_inf",
name="ref_rw",
model_name="ref",
mb_spec=self.ref_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
@ -379,7 +384,7 @@ class PPOMATHConfig(CommonExperimentConfig):
interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=inf_ref_inputs,
output_keys=["packed_ref_logprobs"],
output_keys=inf_ref_outputs,
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -451,8 +456,9 @@ class PPOMATHConfig(CommonExperimentConfig):
"actor_train": train_actor,
# "critic_inf": inf_values,
# "critic_train": train_critic,
"ref_inf": inf_ref_logits,
"rew_inf": inf_reward,
# "ref_inf": inf_ref_logits,
# "rew_inf": inf_reward,
"ref_rw": inf_ref_logits,
}
else:
return {
@ -472,8 +478,9 @@ class PPOMATHConfig(CommonExperimentConfig):
"actor_train": self.actor_train,
# "critic_inf": self.critic_inf,
# "critic_train": self.critic_train,
"ref_inf": self.ref_inf,
"rew_inf": self.rew_inf,
# "ref_inf": self.ref_inf,
# "rew_inf": self.rew_inf,
"ref_rw": self.ref_inf,
}
else:
return {

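Taken together, the hunks above fold the separate ref_inf and rew_inf MFCs into a single ref_rw inference call. A condensed sketch of the resulting wiring, using only names that appear in this diff (mb_spec and min_n_seqs_per_pass elided):

inf_ref_inputs = ["packed_input_ids", "packed_prompts"]                # prompts are required for reward verification
inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]  # ref logprobs plus both reward tensors

ref_interface = copy.deepcopy(actor_interface)
ref_interface.type_ = "ref_rw"                                         # route to the combined ref+reward interface
ref_interface.args["enable_save"] = False
ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)  # forward the reward interface's arguments

inf_ref_logits = MFCDef(
    name="ref_rw",
    model_name="ref",
    interface_type=ModelInterfaceType.INFERENCE,
    interface_impl=ref_interface,
    input_keys=inf_ref_inputs,
    output_keys=inf_ref_outputs,
    n_seqs=self.dataset.train_bs_n_seqs,
)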
View File

@ -1,429 +0,0 @@
# Copyright 2025 Ant Group Inc.
import collections
import dataclasses
import html
import json
import os
import re
import xml.etree.ElementTree as ET
from ast import parse
from concurrent.futures import ThreadPoolExecutor, as_completed
import colorama
import numpy as np
import torch
import torch.distributed as dist
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.base import constants
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
class CodeVerifierException(Exception):
pass
def extract_python_code(text, min_length=20, strict_syntax=True):
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
# verify code syntax
if strict_syntax:
try:
parse(clean_block, mode="exec")
except (SyntaxError, IndentationError):
continue
valid_blocks.append(clean_block)
if not valid_blocks:
return None
# return the last code block
return valid_blocks[-1]
def check_with_elementtree(text):
def escape_between_tags(text, tags=["think", "answer"]):
"""Escape the content between tags while keeping the tags themselves."""
# Build the tag pattern
tag_pattern = "|".join(tags)
parts = []
current_pos = 0
# Match opening and closing tags
pattern = f"</?({tag_pattern})[^>]*>"
for match in re.finditer(pattern, text):
# Append the content before the tag (needs escaping)
if current_pos < match.start():
parts.append(html.escape(text[current_pos : match.start()]))
# Append the tag itself (without escaping)
parts.append(match.group())
current_pos = match.end()
# Append the remaining content at the end
if current_pos < len(text):
parts.append(html.escape(text[current_pos:]))
return "".join(parts)
text = escape_between_tags(text)
if not text.strip().startswith("<think>"):
text = "<think>" + text
try:
xml_text = f"<root>{text}</root>"
x = ET.fromstring(xml_text)
if x.text is not None and x.text.strip() != "":
return False, f"Error: extra content before <think>. {x.text}"
if len(x) != 2:
return False, f"Error: there are {len(x)} tags."
if x[0].tag is None or x[0].tag != "think":
return False, f"Error: <think> tag is missing. {x[0].tag}"
if x[0].tail is not None and x[0].tail.strip() != "":
return (
False,
f"Error: extra content between <think> and <answer>. {x[0].tail}",
)
if x[1].tag is None or x[1].tag != "answer":
return False, f"Error: <answer> tag is missing. {x[1].tag}"
if x[1].tail is not None and x[1].tail.strip() != "":
return False, f"Error: extra content after <answer>, {x[1].tail}"
return True, x[1].text if x[1].text is not None else ""
except ET.ParseError as e:
return False, f"Error: malformed XML, {str(e)}"
def retokenize(
task,
tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=False,
do_eval=False,
):
input_ids = [
packed_input_ids[start:end]
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
]
prompt_ids = [
prompts[start:end]
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
]
seq_strs = tokenizer.batch_decode(
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
prompt_strs = tokenizer.batch_decode(
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
# query_id_strs = query_ids
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
format_rewards = []
queryid_to_results = collections.defaultdict(list)
# 8 processes on each node, with 10 subprocesses each
if do_eval == True:
_answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
codes = [extract_python_code(_answer) for _answer in _answers]
logger.info(
f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
)
format_rewards = code_verify(codes, query_id_strs)
if check_xml_format:
with ThreadPoolExecutor(max_workers=22) as executor:
futures = [
executor.submit(check_with_elementtree, answer_str)
for answer_str in _answers
]
# xml_rewards = []
for idx, future in enumerate(futures):
xml_reward, _ = future.result()
# xml_rewards.append(xml_reward)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(format_reward)
else:
for query_id_str in query_id_strs:
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(0)
format_rewards.append(0)
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
@dataclasses.dataclass
class PackedCodeRewardInterface(model_api.ModelInterface):
enable_save: bool = False
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
rm_output_scaling: float = 1.0
rm_output_bias: float = 0.0
output_bias: float = 0.0
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
max_sync_length: int = 2048
rw_type: str = "sparse"
task: str = "code" # math or countdown or code
check_xml_format: bool = False
post_process: str = "sigmoid"
group_size: int = 1
check_verifier_status: bool = False
_call_count: int = 0
def __post_init__(self):
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
logger.info(f"rm_output_bias: {self.rm_output_bias}")
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"max_sync_length: {self.max_sync_length}")
logger.info(f"rw_type: {self.rw_type}")
logger.info(f"post_process: {self.post_process}")
while True:
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{self._call_count}r{dist.get_rank()}.txt",
)
if os.path.exists(gen_file_path):
self._call_count += 1
else:
break
logger.info(f"call_count: {self._call_count}")
def inference(
self,
model: model_api.Model,
data_: SequenceSample,
mb_spec,
) -> SequenceSample:
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
input_cu_seqlens = torch.nn.functional.pad(
input_seqlens.cumsum(0), (1, 0)
).int()
packed_prompts = data_.data["packed_prompts"]
prompts = []
prompt_seqlens = []
offset = 0
for x in data_.seqlens["packed_prompts"]:
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
offset += x[0]
prompt_seqlens.extend(x * self.group_size)
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
# non_packed_prompts = copy.deepcopy(prompts)
prompts = torch.cat(prompts)
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
prompt_cu_seqlens = torch.nn.functional.pad(
prompt_seqlens.cumsum(0), (1, 0)
).int()
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
retokenize(
self.task,
self.tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=self.check_xml_format,
do_eval=constants.is_last_pipe_stage(),
)
)
assert self.rw_type == "sparse"
dense_scores = torch.zeros_like(packed_input_ids).float()
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
scores[scores == 0] = -1
if len(scores) == 0:
return None
assert dense_scores.shape == packed_input_ids.shape
scores = (
scores.to(packed_input_ids.device) - self.output_bias
) * self.output_scaling
logger.info(f"Code reward logging info @v{model.version.global_step}")
logger.info(
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
logger.info(
f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
)
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{self._call_count}r{dist.get_rank()}.txt",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
f"v{self._call_count}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
pass_at_k = np.mean(
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
)
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
logger.info(f"reward: {sum(scores) / len(scores)}")
train_pass_monitor_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"training_monitor",
f"v{self._call_count}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
logger.info(
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
)
with open(train_pass_monitor_file_path, "w") as monitor_file:
for key, value in queryid_to_results.items():
pass1 = sum(value) / len(value)
pass8 = int(sum(value) > 0)
monitor_file.write(
json.dumps(
{"query_id": key, "pass1": pass1, "pass8": pass8},
ensure_ascii=False,
)
+ "\n"
)
self._call_count += 1
if scores.dtype != torch.float32:
scores = scores.to(torch.float32)
if dense_scores.dtype != torch.float32:
dense_scores = dense_scores.to(torch.float32)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
trailing_shapes=dict(rewards=(), dense_rewards=()),
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
ids=data_.ids,
seqlens=dict(
rewards=[
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
for x in data_.seqlens["packed_input_ids"]
],
dense_rewards=data_.seqlens["packed_input_ids"],
),
data=dict(rewards=scores, dense_rewards=dense_scores),
)
# record rewards for each piece of data
avg_scores = []
offset = 0
for i in range(data_.bs):
score_lis = scores[
offset : offset + len(data_.seqlens["packed_input_ids"][i])
]
avg_scores.append(score_lis.mean().item())
offset += len(data_.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
res.metadata["scores"] = avg_scores
if self.check_verifier_status:
avg_score = torch.tensor(
np.mean(avg_scores), device=constants.current_device()
)
dist.all_reduce(
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
)
avg_score /= constants.parallelism_group_size()
avg_score = avg_score.item()
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
raise CodeVerifierException(
"All rewards are at the minimal value. Something is probably wrong with the verifier!"
)
return res
model_api.register_interface("rw_code", PackedCodeRewardInterface)

View File

@ -37,10 +37,34 @@ ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
class MathVerifierException(Exception):
class VerifierException(Exception):
pass
def extract_python_code(text, min_length=20, strict_syntax=True):
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
# verify code syntax
if strict_syntax:
try:
parse(clean_block, mode="exec")
except (SyntaxError, IndentationError):
continue
valid_blocks.append(clean_block)
if not valid_blocks:
return None
# return the last code block
return valid_blocks[-1]
def check_with_elementtree(text):
def escape_between_tags(text, tags=["think", "answer"]):
"""Escape the content between tags while keeping the tags themselves."""
@ -94,6 +118,19 @@ def check_with_elementtree(text):
return False, f"Error: malformed XML, {str(e)}"
def reward_caculate(task, _answers, query_id_strs):
format_rewards = []
if task == "math":
format_rewards = math_verify_call(_answers, query_id_strs)
else:
codes = [extract_python_code(_answer) for _answer in _answers]
format_rewards = code_verify(codes, query_id_strs)
logger.info(
f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
)
return format_rewards
def retokenize(
task,
tokenizer,
@ -122,6 +159,10 @@ def retokenize(
# query_id_strs = query_ids
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
logger.info(
f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
)
format_rewards = []
queryid_to_results = collections.defaultdict(list)
# 8 processes on each node, with 10 subprocesses each
@ -162,8 +203,7 @@ def retokenize(
@dataclasses.dataclass
class PackedMathRewardInterface(model_api.ModelInterface):
class PackedRewardInterface(model_api.ModelInterface):
enable_save: bool = False
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
@ -392,7 +432,7 @@ class PackedMathRewardInterface(model_api.ModelInterface):
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
raise MathVerifierException(
raise VerifierException(
"All rewards are at the minimal value. Something is probably wrong with the verifier!"
)
@ -401,4 +441,4 @@ class PackedMathRewardInterface(model_api.ModelInterface):
return res
model_api.register_interface("rw_math", PackedMathRewardInterface)
model_api.register_interface("reward", PackedRewardInterface)

View File

@ -0,0 +1,270 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import dataclasses
import functools
import itertools
import time
from typing import Dict, Literal, Optional, Tuple
import torch
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
gather_packed_shifted_log_probs,
masked_normalization,
)
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
logger = logging.getLogger("RefRwInterface")
TASK_TYPE_REF: Literal["ref"] = "ref"
TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"
@dataclasses.dataclass
class RefRwInterface(model_api.ModelInterface):
n_minibatches: int = 4
# Use dict here to allow argument passing through commandline.
generation_config: Dict = dataclasses.field(default_factory=dict)
kl_ctl: float = 0.1
adv_norm: bool = True
discount: float = 1.0
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
max_reward_clip: float = 5.0
disable_value: bool = False
early_stop_kl: Optional[float] = None # e.g. 0.1
early_stop_imp_ratio: Optional[float] = None # e.g., 10.0
adaptive_kl_ctl: bool = False
adaptive_kl_target: Optional[float] = 6
adaptive_kl_horizon: Optional[float] = 10000
enable_save: bool = True
value_norm: bool = False
value_norm_type: str = dataclasses.field(
metadata={"choices": ["exp", "ma"]}, default="exp"
)
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
group_adv_norm: bool = False
mask_too_long: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
rew_inf_args: Dict = dataclasses.field(default_factory=dict)
def __post_init__(self):
if self.adaptive_kl_ctl:
assert self.adaptive_kl_target is not None
assert self.adaptive_kl_horizon is not None
self.kl_adapter = ppo_functional.AdaptiveKLController(
self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
)
else:
self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
if self.value_norm:
from realhf.impl.model.modules import (
ExponentialRunningMeanStd,
MovingAverageRunningMeanStd,
)
if self.value_norm_type == "exp":
self.rms = ExponentialRunningMeanStd(
beta=self.value_norm_beta, epsilon=self.value_norm_eps
)
elif self.value_norm_type == "ma":
self.rms = MovingAverageRunningMeanStd()
else:
raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
self.kl_ctl = None
self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
if self.generation_size is not None:
assert self.generation_size >= self.group_size
else:
self.generation_size = self.group_size
self.gconfig.n = self.generation_size
def save(self, model: model_api.Model, save_dir: str):
if not self.enable_save:
return
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
module.save_to_hf(
tokenizer=model.tokenizer,
save_dir=save_dir,
)
def _dispatch_tasks(self, data):
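# NOTE: placeholder dispatch; every task currently receives the same, unsplit batch.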
math_data, code_data, rlhf_data, ref_data = data, data, data, data
return math_data, code_data, rlhf_data, ref_data
def _gather_tasks(self, data_map):
# merge SequenceSamples from math_data, code_data, rlhf_data, ref_data
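# NOTE: only the reference-inference result is returned for now; the reward results are not yet merged in.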
return data_map.get(TASK_TYPE_REF, None)
@torch.no_grad()
def ref_inference(
self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
):
module = model.module
module.eval()
# This post_hook will gather log probabilities in mini-batches,
# reducing peak memory usage.
def calc_logprobs(logits, input_):
logits /= self.gconfig.temperature
input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
logprobs = gather_packed_shifted_log_probs(
logits, cu_seqlens, input_.data["packed_input_ids"]
)
return logprobs
input_flattened = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
seqlens=flat2d(input_.seqlens["packed_input_ids"]),
data=dict(packed_input_ids=input_.data["packed_input_ids"]),
)
# add posthook to avoid storing full logits
logprobs = module.forward(
input_=input_flattened,
post_hook=calc_logprobs,
output_seqlens=[
[x - 1 for x in slens]
for slens in input_flattened.seqlens["packed_input_ids"]
],
mb_spec=mb_spec,
)
res = SequenceSample(
keys=["packed_ref_logprobs"],
ids=input_.ids,
dtypes=dict(packed_ref_logprobs=model.module.dtype),
trailing_shapes=dict(packed_ref_logprobs=()),
data=dict(packed_ref_logprobs=logprobs),
seqlens=dict(
packed_ref_logprobs=[
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
]
),
)
return res
def inference(
self,
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample:
math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)
if not hasattr(self, "rew_inf_args") or not isinstance(self.rew_inf_args, dict):
raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
rewardInterface = PackedRewardInterface(**self.rew_inf_args)
logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")
task_map = {
TASK_TYPE_REF: (self.ref_inference, ref_data),
TASK_TYPE_RW_MATH: (rewardInterface.inference, math_data),
TASK_TYPE_RW_CODE: (rewardInterface.inference, code_data),
}
def _task_func(func, task_type: str):
def _wrapped_func(*args, **kwargs):
start_time = time.perf_counter()
logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
try:
result = func(*args, **kwargs)
except Exception as e:
logger.error(f"{task_type} ref_rw task failed: {e}")
raise
finally:
duration = time.perf_counter() - start_time
logger.info(
f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
)
return result
return _wrapped_func
async def _run_tasks() -> dict:
tasks = []
for task_type, (func, data) in task_map.items():
if not data:
continue
task_func = _task_func(func, task_type)
task_args = (model, data, mb_spec)
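# asyncio.to_thread runs the blocking interface call in a worker thread, letting ref and reward inference overlap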
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
tasks.append((task_type, task))
results = {}
for task_type, task in tasks:
try:
results[task_type] = await task
except Exception as e:
logger.error(f"{task_type} task failed: {e}")
results[task_type] = None
return results
task_results = asyncio.run(_run_tasks())
final_result = self._gather_tasks(task_results)
return final_result
# Mock methods for profiling only.
def _mock_inference(
self,
model: model_api.Model,
dataset_input: SequenceSample,
) -> SequenceSample:
prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
mconfig = module.config
packed_input_ids = torch.randint(
0,
mconfig.vocab_size,
(sum(seqlens),),
dtype=torch.long,
device=model.device,
)
return SequenceSample.from_default(
seqlens=seqlens,
ids=dataset_input.ids,
data=dict(packed_input_ids=packed_input_ids),
)
model_api.register_interface("ref_rw", RefRwInterface)

View File

@ -0,0 +1,317 @@
import functools
import gc
import os
import pickle
import json
import time
from typing import *
import numpy as np
import pynvml
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from torch.cuda import is_initialized
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(root_dir)
import sys
sys.path.insert(0, root_dir)
from realhf.api.core import data_api, dfg, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
init_global_constants,
_DEFAULT_EXPR_NAME,
_DEFAULT_TRIAL_NAME,
)
logger = logging.getLogger("test async ref-rew")
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"
def loadJson():
dataDir = os.environ["REAL_MATH_METADATA_PATH"]
with open(dataDir, "r") as f:
if dataDir.endswith(".jsonl"):
samples = [json.loads(line) for line in f.readlines()]
else:
samples = json.load(f)
return samples
def _mock_input(batch_size: int, seq_len):
vocab_size = 100
torch.manual_seed(1)
seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
samples = loadJson()
id_list = list(samples.keys())
# id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # encode ids using hash values
return data_api.SequenceSample.from_default(
seqlens=[seq_len for _ in range(seqs.shape[0])],
ids=[id_list[i] for i in range(seqs.shape[0])],
data=dict(
packed_input_ids=seqs.view(-1),
# prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
),
)
def function_call(
rpc_name: str,
rank: int,
world_size: int,
model_path: str,
model_family_name: str,
dp: int,
pp: int,
tp: int,
interface_type: dfg.ModelInterfaceType,
interface_impl: dfg.ModelInterfaceAbstraction,
batch_size: int,
prompt_len: int,
input_: data_api.SequenceSample | None,
port: int,
):
# assert not torch.cuda.is_initialized()
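# Pin this worker to a single GPU: only device `rank` stays visible, so cuda:0 maps to that physical device.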
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
torch.cuda.set_device(0)
assert world_size == (
dp * pp * tp
), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
assert batch_size % dp == 0, (batch_size, dp)
# Initialize distributed environment.
model_name = ModelName("default", 0)
if not dist.is_initialized():
logger.info("Setting up distributed environment...")
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
init_method=f"tcp://localhost:{port}",
)
logger.info("Initialized distributed environment.")
init_global_constants(
num_dp=dp,
num_mp=tp,
num_pp=pp,
sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
model_name=model_name,
max_prompt_len=prompt_len,
)
torch.cuda.set_device(0)
# NOTE: import here to avoid CUDA re-initialization
from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
# Call a method like `config_from_llama` to get the config.
mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
transformers.AutoConfig.from_pretrained(model_path)
)
is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
mconfig.is_critic = is_critic
with constants.model_scope(model_name):
# Construct the model.
logger.info(f"Loading model from {model_path}...")
module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
setattr(
ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
)
module._instantiation_hooks.append(
lambda: getattr(module, f"from_{model_family_name}")(
load_dir=model_path,
init_critic_from_actor=is_critic,
)
)
add_helper_functions(module)
module.instantiate()
module.eval()
tokenizer = data_api.load_hf_tokenizer(model_path)
model = model_api.Model(
name=model_name,
module=module,
tokenizer=tokenizer,
device=module.device,
dtype=module.dtype,
)
if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
from realhf.impl.model.backend.megatron import MegatronTrainBackend
backend = MegatronTrainBackend()
else:
from realhf.impl.model.backend.inference import PipelineInferenceBackend
backend = PipelineInferenceBackend()
logger.info("Running backend initialization...")
ft_spec = model_api.FinetuneSpec(
total_train_epochs=1,
dataset_size=128,
train_batch_size=128,
)
model = backend.initialize(model, ft_spec)
interface = model_api.make_interface(interface_impl)
if input_ is None:
input_ = _mock_input(batch_size, prompt_len)
input_ = input_.cuda()
mb_spec = model_api.MicroBatchSpec()
logger.info("Running interface computation...")
start = time.perf_counter_ns()
if interface_type == dfg.ModelInterfaceType.GENERATE:
res = interface.generate(model, input_, mb_spec)
elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
res = interface.train_step(model, input_)
else:
res = interface.inference(model, input_, mb_spec)
if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
if isinstance(res, data_api.SequenceSample):
res = res.cpu()
consumed = time.perf_counter_ns() - start
logger.info(f"{rpc_name} Computation done. {consumed} ns")
return res
def run_function_call(
rpc_name: str,
model_path: str,
model_family_name: str,
batch_size: int,
prompt_len: int,
gen_len: int,
input_: data_api.SequenceSample | None,
) -> data_api.SequenceSample | None:
assert rpc_name in [
"actor_gen",
"actor_train",
"critic_inf",
"rew_inf",
"critic_train",
"ref_inf",
"ref_rw",
]
ref_rw_interface = dfg.ModelInterfaceAbstraction(
"ref_rw",
args=dict(
generation_config=dict(
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
),
rew_inf_args=dict(
tokenizer_path=model_path,
),
),
)
ppo_actor_interface = dfg.ModelInterfaceAbstraction(
"ppo_actor",
args=dict(
generation_config=dict(
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
),
rew_inf_args=dict(
tokenizer_path=model_path,
),
),
)
ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
rw_interface = dfg.ModelInterfaceAbstraction(
"paired_rw",
)
if rpc_name == "actor_gen":
interface_type = dfg.ModelInterfaceType.GENERATE
interface_impl = ppo_actor_interface
elif rpc_name == "actor_train":
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
interface_impl = ppo_actor_interface
elif rpc_name == "critic_inf":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ppo_critic_interface
elif rpc_name == "ref_inf":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ppo_actor_interface
elif rpc_name == "ref_rw":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ref_rw_interface
elif rpc_name == "critic_train":
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
interface_impl = ppo_critic_interface
else:
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = rw_interface
logger.info(f"Running RPC {rpc_name}...")
port = find_free_port()
res = function_call(
rank=0,
rpc_name=rpc_name,
world_size=1,
model_path=model_path,
model_family_name=model_family_name,
dp=1,
pp=1,
tp=1,
interface_type=interface_type,
interface_impl=interface_impl,
batch_size=batch_size,
prompt_len=prompt_len,
input_=input_,
port=port,
)
gc.collect()
torch.cuda.empty_cache()
gc.collect()
if isinstance(res, data_api.SequenceSample):
return res
else:
logger.info(f"RPC {rpc_name} stats: {res}")
def main():
mp.set_start_method("spawn", force=True)
model_family_name = "qwen2"
batch_size = 16
prompt_len = 128
gen_len = 4096
model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"
constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)
for i in range(2):
ref_rw_res = run_function_call(
"ref_rw",
model_family_name=model_family_name,
model_path=model_path,
batch_size=batch_size,
prompt_len=prompt_len,
gen_len=gen_len,
input_=None,
)
if __name__ == "__main__":
main()