create ref-rw inference async mode

Jun Mo 2025-03-13 18:06:38 +08:00
parent 234e3dd3a0
commit 8732811c73
6 changed files with 651 additions and 446 deletions

View File

@@ -19,7 +19,7 @@ def check_is_realhf_native_impl(_cls):
 def check_is_realhf_native_model_interface(name):
     # NOTE: we should not import interfaces here,
     # such that we can avoid CUDA initialization.
-    return name in ["ppo_actor", "ppo_critic", "sft"]
+    return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]
 def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):

View File

@@ -253,7 +253,7 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor": self.actor,
                 # "critic": self.critic,
                 "ref": self.ref,
-                "reward": self.rew,
+                # "reward": self.rew,
             }
         else:
             return {
@@ -304,8 +304,6 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "reward_delta": self.reward_delta,
             },
         )
-        ref_interface = copy.deepcopy(actor_interface)
-        ref_interface.args["enable_save"] = False
         critic_interface = ModelInterfaceAbstraction(
             "ppo_critic",
@@ -319,7 +317,7 @@ class PPOMATHConfig(CommonExperimentConfig):
         )
         critic_interface.args.pop("eps_clip")
         rw_interface = ModelInterfaceAbstraction(
-            "rw_math",
+            "reward",
             args=dict(
                 rw_type=self.rw_type,
                 task=self.task,
@@ -368,9 +366,16 @@ class PPOMATHConfig(CommonExperimentConfig):
             n_seqs=self.dataset.train_bs_n_seqs,
         )
-        inf_ref_inputs = ["packed_input_ids"]
+        # add rew param into ref MFC
+        inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
+        inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
+        ref_interface = copy.deepcopy(actor_interface)
+        ref_interface.type_ = "ref_rw"
+        ref_interface.args["enable_save"] = False
+        ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)
         inf_ref_logits = MFCDef(
-            name="ref_inf",
+            name="ref_rw",
             model_name="ref",
             mb_spec=self.ref_inf.mb_spec,
             interface_type=ModelInterfaceType.INFERENCE,
@@ -379,7 +384,7 @@ class PPOMATHConfig(CommonExperimentConfig):
             interface_impl=ref_interface,
             min_n_seqs_per_pass=1 / self.group_size,
             input_keys=inf_ref_inputs,
-            output_keys=["packed_ref_logprobs"],
+            output_keys=inf_ref_outputs,
             n_seqs=self.dataset.train_bs_n_seqs,
         )
@@ -451,8 +456,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor_train": train_actor,
                 # "critic_inf": inf_values,
                 # "critic_train": train_critic,
-                "ref_inf": inf_ref_logits,
-                "rew_inf": inf_reward,
+                # "ref_inf": inf_ref_logits,
+                # "rew_inf": inf_reward,
+                "ref_rw": inf_ref_logits,
             }
         else:
             return {
@@ -472,8 +478,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor_train": self.actor_train,
                 # "critic_inf": self.critic_inf,
                 # "critic_train": self.critic_train,
-                "ref_inf": self.ref_inf,
-                "rew_inf": self.rew_inf,
+                # "ref_inf": self.ref_inf,
+                # "rew_inf": self.rew_inf,
+                "ref_rw": self.ref_inf,
             }
         else:
             return {
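Taken together, these hunks fuse the former `ref_inf` and `rew_inf` dataflow nodes into a single `ref_rw` inference node: the reference model's MFC now also receives the prompts and the reward interface's arguments, and emits rewards alongside reference log-probs. A consolidated sketch of the new wiring, with names and keys taken verbatim from the hunks above and unrelated `MFCDef` fields elided:

# Sketch only: the reference MFC doubles as the reward MFC.
ref_interface = copy.deepcopy(actor_interface)
ref_interface.type_ = "ref_rw"
ref_interface.args["enable_save"] = False
ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)

inf_ref_logits = MFCDef(
    name="ref_rw",
    model_name="ref",
    interface_type=ModelInterfaceType.INFERENCE,
    interface_impl=ref_interface,
    input_keys=["packed_input_ids", "packed_prompts"],
    output_keys=["packed_ref_logprobs", "rewards", "dense_rewards"],
    n_seqs=self.dataset.train_bs_n_seqs,
)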

View File

@@ -1,429 +0,0 @@
# Copyright 2025 Ant Group Inc.
import collections
import dataclasses
import html
import json
import os
import re
import xml.etree.ElementTree as ET
from ast import parse
from concurrent.futures import ThreadPoolExecutor, as_completed
import colorama
import numpy as np
import torch
import torch.distributed as dist
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.verify import code_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.base import constants
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
class CodeVerifierException(Exception):
pass
def extract_python_code(text, min_length=20, strict_syntax=True):
code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
# verify code syntax
if strict_syntax:
try:
parse(clean_block, mode="exec")
except (SyntaxError, IndentationError):
continue
valid_blocks.append(clean_block)
if not valid_blocks:
return None
# return the last code block
return valid_blocks[-1]
def check_with_elementtree(text):
def escape_between_tags(text, tags=["think", "answer"]):
"""转义标签之间的内容,但保留标签本身."""
# 构建标签模式
tag_pattern = "|".join(tags)
parts = []
current_pos = 0
# 匹配开始和结束标签
pattern = f"</?({tag_pattern})[^>]*>"
for match in re.finditer(pattern, text):
# 添加标签之前的内容(需要转义)
if current_pos < match.start():
parts.append(html.escape(text[current_pos : match.start()]))
# 添加标签本身(不转义)
parts.append(match.group())
current_pos = match.end()
# 添加最后剩余的内容
if current_pos < len(text):
parts.append(html.escape(text[current_pos:]))
return "".join(parts)
text = escape_between_tags(text)
if not text.strip().startswith("<think>"):
text = "<think>" + text
try:
xml_text = f"<root>{text}</root>"
x = ET.fromstring(xml_text)
if x.text is not None and x.text.strip() != "":
return False, f"Error: extra content before <think>. {x.text}"
if len(x) != 2:
return False, f"Error: there are {len(x)} tags."
if x[0].tag is None or x[0].tag != "think":
return False, f"Error: <think> tag is missing. {x[0].tag}"
if x[0].tail is not None and x[0].tail.strip() != "":
return (
False,
f"Error: extra content between <think> and <answer>. {x[0].tail}",
)
if x[1].tag is None or x[1].tag != "answer":
return False, f"Error: <answer> tag is missing. {x[1].tag}"
if x[1].tail is not None and x[1].tail.strip() != "":
return False, f"Error: extra content after <answer>, {x[1].tail}"
return True, x[1].text if x[1].text is not None else ""
except ET.ParseError as e:
return False, f"Error: XML格式错误, {str(e)}"
def retokenize(
task,
tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=False,
do_eval=False,
):
input_ids = [
packed_input_ids[start:end]
for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
]
prompt_ids = [
prompts[start:end]
for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
]
seq_strs = tokenizer.batch_decode(
input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
prompt_strs = tokenizer.batch_decode(
prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
)
# query_id_strs = query_ids
query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
format_rewards = []
queryid_to_results = collections.defaultdict(list)
# 8 processes on each node, with 10 subprocesses each
    if do_eval:
_answers = [
seq_str.split(prompt_str)[1]
for seq_str, prompt_str in zip(seq_strs, prompt_strs)
]
codes = [extract_python_code(_answer) for _answer in _answers]
logger.info(
f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
)
format_rewards = code_verify(codes, query_id_strs)
if check_xml_format:
with ThreadPoolExecutor(max_workers=22) as executor:
futures = [
executor.submit(check_with_elementtree, answer_str)
for answer_str in _answers
]
# xml_rewards = []
for idx, future in enumerate(futures):
xml_reward, _ = future.result()
# xml_rewards.append(xml_reward)
if xml_reward == 1 and format_rewards[idx] == 0:
format_rewards[idx] = -0.8
elif xml_reward == 0 and format_rewards[idx] == 0:
format_rewards[idx] = -1
for query_id_str, format_reward in zip(query_id_strs, format_rewards):
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(format_reward)
else:
for query_id_str in query_id_strs:
if query_id_str not in queryid_to_results:
queryid_to_results[query_id_str] = []
queryid_to_results[query_id_str].append(0)
format_rewards.append(0)
return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
@dataclasses.dataclass
class PackedCodeRewardInterface(model_api.ModelInterface):
enable_save: bool = False
tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
output_scaling: float = 1.0
rm_output_scaling: float = 1.0
rm_output_bias: float = 0.0
output_bias: float = 0.0
loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
max_sync_length: int = 2048
rw_type: str = "sparse"
task: str = "code" # math or countdown or code
check_xml_format: bool = False
post_process: str = "sigmoid"
group_size: int = 1
check_verifier_status: bool = False
_call_count: int = 0
def __post_init__(self):
self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
logger.info(f"rm_output_bias: {self.rm_output_bias}")
logger.info(f"output_scaling: {self.output_scaling}")
logger.info(f"output_bias: {self.output_bias}")
logger.info(f"max_sync_length: {self.max_sync_length}")
logger.info(f"rw_type: {self.rw_type}")
logger.info(f"post_process: {self.post_process}")
while True:
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{self._call_count}r{dist.get_rank()}.txt",
)
if os.path.exists(gen_file_path):
self._call_count += 1
else:
break
logger.info(f"call_count: {self._call_count}")
def inference(
self,
model: model_api.Model,
data_: SequenceSample,
mb_spec,
) -> SequenceSample:
packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
input_cu_seqlens = torch.nn.functional.pad(
input_seqlens.cumsum(0), (1, 0)
).int()
packed_prompts = data_.data["packed_prompts"]
prompts = []
prompt_seqlens = []
offset = 0
for x in data_.seqlens["packed_prompts"]:
prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
offset += x[0]
prompt_seqlens.extend(x * self.group_size)
assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
# non_packed_prompts = copy.deepcopy(prompts)
prompts = torch.cat(prompts)
prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
prompt_cu_seqlens = torch.nn.functional.pad(
prompt_seqlens.cumsum(0), (1, 0)
).int()
query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
retokenize(
self.task,
self.tokenizer,
packed_input_ids,
input_cu_seqlens,
prompts,
prompt_cu_seqlens,
query_ids,
check_xml_format=self.check_xml_format,
do_eval=constants.is_last_pipe_stage(),
)
)
assert self.rw_type == "sparse"
dense_scores = torch.zeros_like(packed_input_ids).float()
scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
scores[scores == 0] = -1
if len(scores) == 0:
return None
assert dense_scores.shape == packed_input_ids.shape
scores = (
scores.to(packed_input_ids.device) - self.output_bias
) * self.output_scaling
logger.info(f"Code reward logging info @v{model.version.global_step}")
logger.info(
f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
logger.info(
f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
)
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated",
f"v{self._call_count}r{dist.get_rank()}.txt",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
info = "\n".join(
[
f"idx: {idx} / {len(scores)}",
f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
]
)
_f.write(info + "\n")
gen_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"generated_jsonl",
f"v{self._call_count}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
with open(gen_file_path, "w") as _f:
for idx, (score, prompt_str, seq_str) in enumerate(
zip(scores, prompt_strs, seq_strs)
):
_f.write(
json.dumps(
{
"prompt": prompt_str,
"generated": seq_str.split(prompt_str)[1],
"reward": score.item(),
},
ensure_ascii=False,
)
+ "\n"
)
logger.info(
f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
)
pass_at_k = np.mean(
[sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
)
avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
logger.info(f"number of samples: {len(scores)}, {scores.shape}")
logger.info(f"reward: {sum(scores) / len(scores)}")
train_pass_monitor_file_path = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"training_monitor",
f"v{self._call_count}r{dist.get_rank()}.jsonl",
)
os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
logger.info(
f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
)
with open(train_pass_monitor_file_path, "w") as monitor_file:
for key, value in queryid_to_results.items():
pass1 = sum(value) / len(value)
pass8 = int(sum(value) > 0)
monitor_file.write(
json.dumps(
{"query_id": key, "pass1": pass1, "pass8": pass8},
ensure_ascii=False,
)
+ "\n"
)
self._call_count += 1
if scores.dtype != torch.float32:
scores = scores.to(torch.float32)
if dense_scores.dtype != torch.float32:
dense_scores = dense_scores.to(torch.float32)
res = SequenceSample(
keys=["rewards", "dense_rewards"],
trailing_shapes=dict(rewards=(), dense_rewards=()),
dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
ids=data_.ids,
seqlens=dict(
rewards=[
torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
for x in data_.seqlens["packed_input_ids"]
],
dense_rewards=data_.seqlens["packed_input_ids"],
),
data=dict(rewards=scores, dense_rewards=dense_scores),
)
# record rewards for each piece of data
avg_scores = []
offset = 0
for i in range(data_.bs):
score_lis = scores[
offset : offset + len(data_.seqlens["packed_input_ids"][i])
]
avg_scores.append(score_lis.mean().item())
offset += len(data_.seqlens["packed_input_ids"][i])
assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
res.metadata["scores"] = avg_scores
if self.check_verifier_status:
avg_score = torch.tensor(
np.mean(avg_scores), device=constants.current_device()
)
dist.all_reduce(
avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
)
avg_score /= constants.parallelism_group_size()
avg_score = avg_score.item()
minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
raise CodeVerifierException(
"All rewards are at minimal value. Probably there are something wrong with the verifier!"
)
return res
model_api.register_interface("rw_code", PackedCodeRewardInterface)

View File

@@ -37,10 +37,34 @@ ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else
 math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
-class MathVerifierException(Exception):
+class VerifierException(Exception):
     pass
+def extract_python_code(text, min_length=20, strict_syntax=True):
+    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
+    code_blocks = re.findall(code_pattern, text, re.DOTALL)
+    valid_blocks = []
+    for block in code_blocks:
+        clean_block = block.strip()
+        if len(clean_block) < min_length:
+            continue
+        # verify code syntax
+        if strict_syntax:
+            try:
+                parse(clean_block, mode="exec")
+            except (SyntaxError, IndentationError):
+                continue
+        valid_blocks.append(clean_block)
+    if not valid_blocks:
+        return None
+    # return the last code block
+    return valid_blocks[-1]
 def check_with_elementtree(text):
     def escape_between_tags(text, tags=["think", "answer"]):
         """Escape the content between tags, but keep the tags themselves."""
@@ -94,6 +118,19 @@ def check_with_elementtree(text):
         return False, f"Error: malformed XML, {str(e)}"
+def reward_caculate(task, _answers, query_id_strs):
+    format_rewards = []
+    if task == "math":
+        format_rewards = math_verify_call(_answers, query_id_strs)
+    else:
+        codes = [extract_python_code(_answer) for _answer in _answers]
+        format_rewards = code_verify(codes, query_id_strs)
+    logger.info(
+        f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
+    )
+    return format_rewards
 def retokenize(
     task,
     tokenizer,
@@ -122,6 +159,10 @@ def retokenize(
     # query_id_strs = query_ids
     query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
+    logger.info(
+        f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
+    )
     format_rewards = []
     queryid_to_results = collections.defaultdict(list)
     # 8 processes on each node, with 10 subprocesses each
@@ -162,8 +203,7 @@ def retokenize(
 @dataclasses.dataclass
-class PackedMathRewardInterface(model_api.ModelInterface):
+class PackedRewardInterface(model_api.ModelInterface):
     enable_save: bool = False
     tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
     output_scaling: float = 1.0
@@ -392,7 +432,7 @@ class PackedMathRewardInterface(model_api.ModelInterface):
         minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
         if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
-            raise MathVerifierException(
+            raise VerifierException(
                 "All rewards are at minimal value. Probably there are something wrong with the verifier!"
             )
@@ -401,4 +441,4 @@ class PackedMathRewardInterface(model_api.ModelInterface):
         return res
-model_api.register_interface("rw_math", PackedMathRewardInterface)
+model_api.register_interface("reward", PackedRewardInterface)

View File

@@ -0,0 +1,270 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import dataclasses
import functools
import itertools
import time
from typing import Dict, Literal, Optional, Tuple
import torch
import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
import realhf.impl.model.utils.ppo_functional as ppo_functional
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils.functional import (
gather_packed_shifted_log_probs,
masked_normalization,
)
from realhf.impl.model.interface.rw_interface import PackedRewardInterface
logger = logging.getLogger("RefRwInterface")
TASK_TYPE_REF: Literal["ref"] = "ref"
TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"
@dataclasses.dataclass
class RefRwInterface(model_api.ModelInterface):
n_minibatches: int = 4
# Use dict here to allow argument passing through commandline.
generation_config: Dict = dataclasses.field(default_factory=dict)
kl_ctl: float = 0.1
adv_norm: bool = True
discount: float = 1.0
gae_lambda: float = 1.0
eps_clip: float = 0.2
value_eps_clip: float = 0.2
max_reward_clip: float = 5.0
disable_value: bool = False
early_stop_kl: Optional[float] = None # e.g. 0.1
early_stop_imp_ratio: Optional[float] = None # e.g., 10.0
adaptive_kl_ctl: bool = False
adaptive_kl_target: Optional[float] = 6
adaptive_kl_horizon: Optional[float] = 10000
enable_save: bool = True
value_norm: bool = False
value_norm_type: str = dataclasses.field(
metadata={"choices": ["exp", "ma"]}, default="exp"
)
value_norm_beta: float = 0.99995
value_norm_eps: float = 1e-5
group_size: int = 1
generation_size: Optional[int] = None
mask_no_eos_with_zero: bool = False
group_adv_norm: bool = False
mask_too_long: bool = False
use_dense_reward: bool = False
reward_delta: bool = True
token_normalize_scope: Literal["global", "dp"] = "global"
rew_inf_args: Dict = dataclasses.field(default_factory=dict)
def __post_init__(self):
if self.adaptive_kl_ctl:
assert self.adaptive_kl_target is not None
assert self.adaptive_kl_horizon is not None
self.kl_adapter = ppo_functional.AdaptiveKLController(
self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
)
else:
self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
if self.value_norm:
from realhf.impl.model.modules import (
ExponentialRunningMeanStd,
MovingAverageRunningMeanStd,
)
if self.value_norm_type == "exp":
self.rms = ExponentialRunningMeanStd(
beta=self.value_norm_beta, epsilon=self.value_norm_eps
)
elif self.value_norm_type == "ma":
self.rms = MovingAverageRunningMeanStd()
else:
raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
self.kl_ctl = None
self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
if self.generation_size is not None:
assert self.generation_size >= self.group_size
else:
self.generation_size = self.group_size
self.gconfig.n = self.generation_size
def save(self, model: model_api.Model, save_dir: str):
if not self.enable_save:
return
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
module.save_to_hf(
tokenizer=model.tokenizer,
save_dir=save_dir,
)
    def _dispatch_tasks(self, data):
        # Placeholder dispatch: every task type currently receives the full batch.
        math_data, code_data, rlhf_data, ref_data = data, data, data, data
        return math_data, code_data, rlhf_data, ref_data

    def _gather_tasks(self, data_map):
        # TODO: merge the SequenceSamples from the math, code, rlhf, and ref
        # tasks; for now only the reference-model result is returned.
        return data_map.get(TASK_TYPE_REF, None)
@torch.no_grad()
def ref_inference(
self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
):
module = model.module
module.eval()
# This post_hook will gather log probabilities in mini-batches,
# reducing peak memory usage.
def calc_logprobs(logits, input_):
logits /= self.gconfig.temperature
input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
logprobs = gather_packed_shifted_log_probs(
logits, cu_seqlens, input_.data["packed_input_ids"]
)
return logprobs
        input_flattened = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            seqlens=flat2d(input_.seqlens["packed_input_ids"]),
            data=dict(packed_input_ids=input_.data["packed_input_ids"]),
        )
        # Attach the post-hook so the full logits are never materialized.
        logprobs = module.forward(
            input_=input_flattened,
            post_hook=calc_logprobs,
            output_seqlens=[
                [x - 1 for x in slens]
                for slens in input_flattened.seqlens["packed_input_ids"]
            ],
            mb_spec=mb_spec,
        )
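        # NOTE: log-probs are next-token predictions, so each output sequence is
        # one token shorter than its input (hence `x - 1` here and in `res` below).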
res = SequenceSample(
keys=["packed_ref_logprobs"],
ids=input_.ids,
dtypes=dict(packed_ref_logprobs=model.module.dtype),
trailing_shapes=dict(packed_ref_logprobs=()),
data=dict(packed_ref_logprobs=logprobs),
seqlens=dict(
packed_ref_logprobs=[
[x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
]
),
)
return res
def inference(
self,
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample:
math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)
        if not isinstance(self.rew_inf_args, dict):
            raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
        reward_interface = PackedRewardInterface(**self.rew_inf_args)
        logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")
        task_map = {
            TASK_TYPE_REF: (self.ref_inference, ref_data),
            TASK_TYPE_RW_MATH: (reward_interface.inference, math_data),
            TASK_TYPE_RW_CODE: (reward_interface.inference, code_data),
        }
        def _task_func(func, task_type: str):
            def _wrapped_func(*args, **kwargs):
                start_time = time.perf_counter()
                logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    logger.error(f"{task_type} ref_rw task failed: {e}")
                    # Re-raise so the awaiting caller records the failure;
                    # otherwise `result` would be unbound below.
                    raise
                finally:
                    duration = time.perf_counter() - start_time
                    logger.info(
                        f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
                    )
                return result

            return _wrapped_func
async def _run_tasks() -> dict:
tasks = []
for task_type, (func, data) in task_map.items():
if not data:
continue
task_func = _task_func(func, task_type)
task_args = (model, data, mb_spec)
task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
tasks.append((task_type, task))
results = {}
for task_type, task in tasks:
try:
results[task_type] = await task
except Exception as e:
logger.error(f"{task_type} task failed: {e}")
results[task_type] = None
return results
task_results = asyncio.run(_run_tasks())
final_result = self._gather_tasks(task_results)
return final_result
# Mock methods for profiling only.
def _mock_inference(
self,
model: model_api.Model,
dataset_input: SequenceSample,
) -> SequenceSample:
prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
module = model.module
if not isinstance(module, ReaLModel):
module = module.module
mconfig = module.config
packed_input_ids = torch.randint(
0,
mconfig.vocab_size,
(sum(seqlens),),
dtype=torch.long,
device=model.device,
)
return SequenceSample.from_default(
seqlens=seqlens,
ids=dataset_input.ids,
data=dict(packed_input_ids=packed_input_ids),
)
model_api.register_interface("ref_rw", RefRwInterface)

View File

@@ -0,0 +1,317 @@
import gc
import json
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(root_dir)
import sys
sys.path.insert(0, root_dir)
from realhf.api.core import data_api, dfg, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.network import find_free_port
from realhf.base.testing import (
init_global_constants,
_DEFAULT_EXPR_NAME,
_DEFAULT_TRIAL_NAME,
)
logger = logging.getLogger("test async ref-rw")
os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"
def loadJson():
dataDir = os.environ["REAL_MATH_METADATA_PATH"]
with open(dataDir, "r") as f:
if dataDir.endswith(".jsonl"):
samples = [json.loads(line) for line in f.readlines()]
else:
samples = json.load(f)
return samples
def _mock_input(batch_size: int, seq_len):
vocab_size = 100
torch.manual_seed(1)
seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
samples = loadJson()
id_list = list(samples.keys())
    # id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # encode ids via hash values
return data_api.SequenceSample.from_default(
seqlens=[seq_len for _ in range(seqs.shape[0])],
ids=[id_list[i] for i in range(seqs.shape[0])],
data=dict(
packed_input_ids=seqs.view(-1),
# prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
),
)
def function_call(
rpc_name: str,
rank: int,
world_size: int,
model_path: str,
model_family_name: str,
dp: int,
pp: int,
tp: int,
interface_type: dfg.ModelInterfaceType,
interface_impl: dfg.ModelInterfaceAbstraction,
batch_size: int,
prompt_len: int,
input_: data_api.SequenceSample | None,
port: int,
):
# assert not torch.cuda.is_initialized()
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
torch.cuda.set_device(0)
assert world_size == (
dp * pp * tp
), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
assert batch_size % dp == 0, (batch_size, dp)
# Initialize distributed environment.
model_name = ModelName("default", 0)
if not dist.is_initialized():
logger.info("Setting up distributed environment...")
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
init_method=f"tcp://localhost:{port}",
)
logger.info("Initialized distributed environment.")
init_global_constants(
num_dp=dp,
num_mp=tp,
num_pp=pp,
sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
model_name=model_name,
max_prompt_len=prompt_len,
)
torch.cuda.set_device(0)
# NOTE: import here to avoid CUDA re-initialization
from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
# Call a method like `config_from_llama` to get the config.
mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
transformers.AutoConfig.from_pretrained(model_path)
)
is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
mconfig.is_critic = is_critic
with constants.model_scope(model_name):
# Construct the model.
logger.info(f"Loading model from {model_path}...")
module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda")
setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}"))
setattr(
ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}")
)
module._instantiation_hooks.append(
lambda: getattr(module, f"from_{model_family_name}")(
load_dir=model_path,
init_critic_from_actor=is_critic,
)
)
add_helper_functions(module)
module.instantiate()
module.eval()
tokenizer = data_api.load_hf_tokenizer(model_path)
model = model_api.Model(
name=model_name,
module=module,
tokenizer=tokenizer,
device=module.device,
dtype=module.dtype,
)
if interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
from realhf.impl.model.backend.megatron import MegatronTrainBackend
backend = MegatronTrainBackend()
else:
from realhf.impl.model.backend.inference import PipelineInferenceBackend
backend = PipelineInferenceBackend()
logger.info("Running backend initialization...")
ft_spec = model_api.FinetuneSpec(
total_train_epochs=1,
dataset_size=128,
train_batch_size=128,
)
model = backend.initialize(model, ft_spec)
interface = model_api.make_interface(interface_impl)
if input_ is None:
input_ = _mock_input(batch_size, prompt_len)
input_ = input_.cuda()
mb_spec = model_api.MicroBatchSpec()
logger.info("Running interface computation...")
start = time.perf_counter_ns()
if interface_type == dfg.ModelInterfaceType.GENERATE:
res = interface.generate(model, input_, mb_spec)
elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
res = interface.train_step(model, input_)
else:
res = interface.inference(model, input_, mb_spec)
if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage():
if isinstance(res, data_api.SequenceSample):
res = res.cpu()
        consumed = time.perf_counter_ns() - start
        logger.info(f"{rpc_name} Computation done. {consumed} ns")
return res
def run_function_call(
rpc_name: str,
model_path: str,
model_family_name: str,
batch_size: int,
prompt_len: int,
gen_len: int,
input_: data_api.SequenceSample | None,
) -> data_api.SequenceSample | None:
assert rpc_name in [
"actor_gen",
"actor_train",
"critic_inf",
"rew_inf",
"critic_train",
"ref_inf",
"ref_rw",
]
ref_rw_interface = dfg.ModelInterfaceAbstraction(
"ref_rw",
args=dict(
generation_config=dict(
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
),
rew_inf_args=dict(
tokenizer_path=model_path,
),
),
)
ppo_actor_interface = dfg.ModelInterfaceAbstraction(
"ppo_actor",
args=dict(
generation_config=dict(
max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True
),
rew_inf_args=dict(
tokenizer_path=model_path,
),
),
)
ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic")
rw_interface = dfg.ModelInterfaceAbstraction(
"paired_rw",
)
if rpc_name == "actor_gen":
interface_type = dfg.ModelInterfaceType.GENERATE
interface_impl = ppo_actor_interface
elif rpc_name == "actor_train":
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
interface_impl = ppo_actor_interface
elif rpc_name == "critic_inf":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ppo_critic_interface
elif rpc_name == "ref_inf":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ppo_actor_interface
elif rpc_name == "ref_rw":
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = ref_rw_interface
elif rpc_name == "critic_train":
interface_type = dfg.ModelInterfaceType.TRAIN_STEP
interface_impl = ppo_critic_interface
else:
interface_type = dfg.ModelInterfaceType.INFERENCE
interface_impl = rw_interface
logger.info(f"Running RPC {rpc_name}...")
port = find_free_port()
    res = function_call(
rank=0,
rpc_name=rpc_name,
world_size=1,
model_path=model_path,
model_family_name=model_family_name,
dp=1,
pp=1,
tp=1,
interface_type=interface_type,
interface_impl=interface_impl,
batch_size=batch_size,
prompt_len=prompt_len,
input_=input_,
port=port,
)
gc.collect()
torch.cuda.empty_cache()
gc.collect()
if isinstance(res, data_api.SequenceSample):
return res
else:
logger.info(f"RPC {rpc_name} stats: {res}")
def main():
mp.set_start_method("spawn", force=True)
model_family_name = "qwen2"
batch_size = 16
prompt_len = 128
gen_len = 4096
model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B"
constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME)
for i in range(2):
ref_rw_res = run_function_call(
"ref_rw",
model_family_name=model_family_name,
model_path=model_path,
batch_size=batch_size,
prompt_len=prompt_len,
gen_len=gen_len,
input_=None,
)
if __name__ == "__main__":
main()
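A quick sanity check one might append inside the loop above (hypothetical; `run_function_call` only returns a `SequenceSample` on rank 0 of the last pipeline stage, and the fused interface currently gathers only the reference branch):

if ref_rw_res is not None:
    assert "packed_ref_logprobs" in ref_rw_res.data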