diff --git a/realhf/experiments/common/check.py b/realhf/experiments/common/check.py
index 33266cf..86524ba 100644
--- a/realhf/experiments/common/check.py
+++ b/realhf/experiments/common/check.py
@@ -19,7 +19,7 @@ def check_is_realhf_native_impl(_cls):
 def check_is_realhf_native_model_interface(name):
     # NOTE: we should not import iterfaces here,
     # such that we can avoid CUDA initialization.
-    return name in ["ppo_actor", "ppo_critic", "sft"]
+    return name in ["ppo_actor", "ppo_critic", "sft", "ref_rw"]


 def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation]):
diff --git a/realhf/experiments/common/ppo_math_exp.py b/realhf/experiments/common/ppo_math_exp.py
index 727b35c..3a5e66a 100644
--- a/realhf/experiments/common/ppo_math_exp.py
+++ b/realhf/experiments/common/ppo_math_exp.py
@@ -253,7 +253,7 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor": self.actor,
                 # "critic": self.critic,
                 "ref": self.ref,
-                "reward": self.rew,
+                # "reward": self.rew,
             }
         else:
             return {
@@ -304,8 +304,6 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "reward_delta": self.reward_delta,
             },
         )
-        ref_interface = copy.deepcopy(actor_interface)
-        ref_interface.args["enable_save"] = False

         critic_interface = ModelInterfaceAbstraction(
             "ppo_critic",
@@ -319,7 +317,7 @@ class PPOMATHConfig(CommonExperimentConfig):
         )
         critic_interface.args.pop("eps_clip")
         rw_interface = ModelInterfaceAbstraction(
-            "rw_math",
+            "reward",
             args=dict(
                 rw_type=self.rw_type,
                 task=self.task,
@@ -368,9 +366,16 @@ class PPOMATHConfig(CommonExperimentConfig):
             n_seqs=self.dataset.train_bs_n_seqs,
         )

-        inf_ref_inputs = ["packed_input_ids"]
+        # add rew param into ref MFC
+        inf_ref_inputs = ["packed_input_ids", "packed_prompts"]
+        inf_ref_outputs = ["packed_ref_logprobs", "rewards", "dense_rewards"]
+        ref_interface = copy.deepcopy(actor_interface)
+        ref_interface.type_ = "ref_rw"
+        ref_interface.args["enable_save"] = False
+        ref_interface.args["rew_inf_args"] = copy.deepcopy(rw_interface.args)
+
         inf_ref_logits = MFCDef(
-            name="ref_inf",
+            name="ref_rw",
             model_name="ref",
             mb_spec=self.ref_inf.mb_spec,
             interface_type=ModelInterfaceType.INFERENCE,
@@ -379,7 +384,7 @@ class PPOMATHConfig(CommonExperimentConfig):
             interface_impl=ref_interface,
             min_n_seqs_per_pass=1 / self.group_size,
             input_keys=inf_ref_inputs,
-            output_keys=["packed_ref_logprobs"],
+            output_keys=inf_ref_outputs,
             n_seqs=self.dataset.train_bs_n_seqs,
         )

@@ -451,8 +456,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor_train": train_actor,
                 # "critic_inf": inf_values,
                 # "critic_train": train_critic,
-                "ref_inf": inf_ref_logits,
-                "rew_inf": inf_reward,
+                # "ref_inf": inf_ref_logits,
+                # "rew_inf": inf_reward,
+                "ref_rw": inf_ref_logits,
             }
         else:
             return {
@@ -472,8 +478,9 @@ class PPOMATHConfig(CommonExperimentConfig):
                 "actor_train": self.actor_train,
                 # "critic_inf": self.critic_inf,
                 # "critic_train": self.critic_train,
-                "ref_inf": self.ref_inf,
-                "rew_inf": self.rew_inf,
+                # "ref_inf": self.ref_inf,
+                # "rew_inf": self.rew_inf,
+                "ref_rw": self.ref_inf,
             }
         else:
             return {
diff --git a/realhf/impl/model/interface/code_rw_interface.py b/realhf/impl/model/interface/code_rw_interface.py
deleted file mode 100644
index 2f83aca..0000000
--- a/realhf/impl/model/interface/code_rw_interface.py
+++ /dev/null
@@ -1,429 +0,0 @@
-# Copyright 2025 Ant Group Inc.
-
-import collections
-import dataclasses
-import html
-import json
-import os
-import re
-import xml.etree.ElementTree as ET
-from ast import parse
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-import colorama
-import numpy as np
-import torch
-import torch.distributed as dist
-
-import realhf.api.core.model_api as model_api
-import realhf.base.logging as logging
-from functioncall.code.verify import code_verify
-from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
-from realhf.base import constants
-
-logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
-
-
-class CodeVerifierException(Exception):
-    pass
-
-
-def extract_python_code(text, min_length=20, strict_syntax=True):
-    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
-    code_blocks = re.findall(code_pattern, text, re.DOTALL)
-    valid_blocks = []
-    for block in code_blocks:
-        clean_block = block.strip()
-        if len(clean_block) < min_length:
-            continue
-
-        # verify code syntax
-        if strict_syntax:
-            try:
-                parse(clean_block, mode="exec")
-            except (SyntaxError, IndentationError):
-                continue
-
-        valid_blocks.append(clean_block)
-
-    if not valid_blocks:
-        return None
-    # return the last code block
-    return valid_blocks[-1]
-
-
-def check_with_elementtree(text):
-    def escape_between_tags(text, tags=["think", "answer"]):
-        """转义标签之间的内容,但保留标签本身."""
-        # 构建标签模式
-        tag_pattern = "|".join(tags)
-        parts = []
-        current_pos = 0
-
-        # 匹配开始和结束标签
-        pattern = f"</?({tag_pattern})[^>]*>"
-
-        for match in re.finditer(pattern, text):
-            # 添加标签之前的内容(需要转义)
-            if current_pos < match.start():
-                parts.append(html.escape(text[current_pos : match.start()]))
-
-            # 添加标签本身(不转义)
-            parts.append(match.group())
-            current_pos = match.end()
-
-        # 添加最后剩余的内容
-        if current_pos < len(text):
-            parts.append(html.escape(text[current_pos:]))
-
-        return "".join(parts)
-
-    text = escape_between_tags(text)
-    if not text.strip().startswith("<think>"):
-        text = "<think>" + text
-    try:
-        xml_text = f"<root>{text}</root>"
-        x = ET.fromstring(xml_text)
-        if x.text is not None and x.text.strip() != "":
-            return False, f"Error: extra content before <think>. {x.text}"
-        if len(x) != 2:
-            return False, f"Error: there are {len(x)} tags."
-        if x[0].tag is None or x[0].tag != "think":
-            return False, f"Error: <think> tag is missing. {x[0].tag}"
-        if x[0].tail is not None and x[0].tail.strip() != "":
-            return (
-                False,
-                f"Error: extra content between <think> and <answer>. {x[0].tail}",
-            )
-        if x[1].tag is None or x[1].tag != "answer":
-            return False, f"Error: <answer> tag is missing. {x[1].tag}"
-        if x[1].tail is not None and x[1].tail.strip() != "":
-            return False, f"Error: extra content after <answer>, {x[1].tail}"
-
-        return True, x[1].text if x[1].text is not None else ""
-    except ET.ParseError as e:
-        return False, f"Error: XML格式错误, {str(e)}"
-
-
-def retokenize(
-    task,
-    tokenizer,
-    packed_input_ids,
-    input_cu_seqlens,
-    prompts,
-    prompt_cu_seqlens,
-    query_ids,
-    check_xml_format=False,
-    do_eval=False,
-):
-    input_ids = [
-        packed_input_ids[start:end]
-        for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
-    ]
-    prompt_ids = [
-        prompts[start:end]
-        for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
-    ]
-    seq_strs = tokenizer.batch_decode(
-        input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
-    )
-    prompt_strs = tokenizer.batch_decode(
-        prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
-    )
-    # query_id_strs = query_ids
-    query_id_strs = [query_id.split("@")[0] for query_id in query_ids]
-
-    format_rewards = []
-
-    queryid_to_results = collections.defaultdict(list)
-    # 8 processes on each node, with 10 subprocesses each
-    if do_eval == True:
-        _answers = [
-            seq_str.split(prompt_str)[1]
-            for seq_str, prompt_str in zip(seq_strs, prompt_strs)
-        ]
-
-        codes = [extract_python_code(_answer) for _answer in _answers]
-        logger.info(
-            f"code_rw_interface, size: {len(query_id_strs)}, valid code size: {len(codes)}, query_id_0: {query_id_strs[0]}"
-        )
-        format_rewards = code_verify(codes, query_id_strs)
-
-        if check_xml_format:
-            with ThreadPoolExecutor(max_workers=22) as executor:
-                futures = [
-                    executor.submit(check_with_elementtree, answer_str)
-                    for answer_str in _answers
-                ]
-                # xml_rewards = []
-                for idx, future in enumerate(futures):
-                    xml_reward, _ = future.result()
-                    # xml_rewards.append(xml_reward)
-                    if xml_reward == 1 and format_rewards[idx] == 0:
-                        format_rewards[idx] = -0.8
-                    elif xml_reward == 0 and format_rewards[idx] == 0:
-                        format_rewards[idx] = -1
-
-        for query_id_str, format_reward in zip(query_id_strs, format_rewards):
-            if query_id_str not in queryid_to_results:
-                queryid_to_results[query_id_str] = []
-            queryid_to_results[query_id_str].append(format_reward)
-    else:
-        for query_id_str in query_id_strs:
-            if query_id_str not in queryid_to_results:
-                queryid_to_results[query_id_str] = []
-            queryid_to_results[query_id_str].append(0)
-            format_rewards.append(0)
-
-    return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
-
-
-@dataclasses.dataclass
-class PackedCodeRewardInterface(model_api.ModelInterface):
-
-    enable_save: bool = False
-    tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
-    output_scaling: float = 1.0
-    rm_output_scaling: float = 1.0
-    rm_output_bias: float = 0.0
-    output_bias: float = 0.0
-    loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
-    max_sync_length: int = 2048
-    rw_type: str = "sparse"
-    task: str = "code"  # math or countdown or code
-    check_xml_format: bool = False
-    post_process: str = "sigmoid"
-    group_size: int = 1
-    check_verifier_status: bool = False
-
-    _call_count: int = 0
-
-    def __post_init__(self):
-        self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
-        logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
-        logger.info(f"rm_output_bias: {self.rm_output_bias}")
-        logger.info(f"output_scaling: {self.output_scaling}")
-        logger.info(f"output_bias: {self.output_bias}")
-        logger.info(f"max_sync_length: {self.max_sync_length}")
-        logger.info(f"rw_type: {self.rw_type}")
-        logger.info(f"post_process: {self.post_process}")
-
-        while True:
-            gen_file_path = os.path.join(
-                constants.LOG_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
-                "generated",
-                f"v{self._call_count}r{dist.get_rank()}.txt",
-            )
-            if os.path.exists(gen_file_path):
-                self._call_count += 1
-            else:
-                break
-        logger.info(f"call_count: {self._call_count}")
-
-    def inference(
-        self,
-        model: model_api.Model,
-        data_: SequenceSample,
-        mb_spec,
-    ) -> SequenceSample:
-
-        packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()
-
-        input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
-        input_cu_seqlens = torch.nn.functional.pad(
-            input_seqlens.cumsum(0), (1, 0)
-        ).int()
-
-        packed_prompts = data_.data["packed_prompts"]
-        prompts = []
-        prompt_seqlens = []
-        offset = 0
-        for x in data_.seqlens["packed_prompts"]:
-            prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
-            offset += x[0]
-            prompt_seqlens.extend(x * self.group_size)
-
-        assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
-        # non_packed_prompts = copy.deepcopy(prompts)
-        prompts = torch.cat(prompts)
-        prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
-        prompt_cu_seqlens = torch.nn.functional.pad(
-            prompt_seqlens.cumsum(0), (1, 0)
-        ).int()
-
-        query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
-
-        format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
-            retokenize(
-                self.task,
-                self.tokenizer,
-                packed_input_ids,
-                input_cu_seqlens,
-                prompts,
-                prompt_cu_seqlens,
-                query_ids,
-                check_xml_format=self.check_xml_format,
-                do_eval=constants.is_last_pipe_stage(),
-            )
-        )
-
-        assert self.rw_type == "sparse"
-        dense_scores = torch.zeros_like(packed_input_ids).float()
-        scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
-        scores[scores == 0] = -1
-
-        if len(scores) == 0:
-            return None
-
-        assert dense_scores.shape == packed_input_ids.shape
-        scores = (
-            scores.to(packed_input_ids.device) - self.output_bias
-        ) * self.output_scaling
-
-        logger.info(f"Code reward logging info @v{model.version.global_step}")
-        logger.info(
-            f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
-        )
-        logger.info(
-            f"number of samples: {len(scores)}, {scores.shape}, group_size: {self.group_size}"
-        )
-
-        gen_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "generated",
-            f"v{self._call_count}r{dist.get_rank()}.txt",
-        )
-        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
-        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
-        with open(gen_file_path, "w") as _f:
-            for idx, (score, prompt_str, seq_str) in enumerate(
-                zip(scores, prompt_strs, seq_strs)
-            ):
-                info = "\n".join(
-                    [
-                        f"idx: {idx} / {len(scores)}",
-                        f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
-                        f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
-                    ]
-                )
-                _f.write(info + "\n")
-
-        gen_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "generated_jsonl",
-            f"v{self._call_count}r{dist.get_rank()}.jsonl",
-        )
-        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
-        logger.info(f"Generated samples and rewards will be dumped to: {gen_file_path}")
-        with open(gen_file_path, "w") as _f:
-            for idx, (score, prompt_str, seq_str) in enumerate(
-                zip(scores, prompt_strs, seq_strs)
-            ):
-                _f.write(
-                    json.dumps(
-                        {
-                            "prompt": prompt_str,
-                            "generated": seq_str.split(prompt_str)[1],
-                            "reward": score.item(),
-                        },
-                        ensure_ascii=False,
-                    )
-                    + "\n"
-                )
-
-        logger.info(
-            f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
-        )
-
-        pass_at_k = np.mean(
-            [sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
-        )
-        avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
-        logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")
-
-        logger.info(f"number of samples: {len(scores)}, {scores.shape}")
-
-        logger.info(f"reward: {sum(scores) / len(scores)}")
-
-        train_pass_monitor_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "training_monitor",
-            f"v{self._call_count}r{dist.get_rank()}.jsonl",
-        )
-        os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
-        logger.info(
-            f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
-        )
-        with open(train_pass_monitor_file_path, "w") as monitor_file:
-            for key, value in queryid_to_results.items():
-                pass1 = sum(value) / len(value)
-                pass8 = int(sum(value) > 0)
-                monitor_file.write(
-                    json.dumps(
-                        {"query_id": key, "pass1": pass1, "pass8": pass8},
-                        ensure_ascii=False,
-                    )
-                    + "\n"
-                )
-
-        self._call_count += 1
-
-        if scores.dtype != torch.float32:
-            scores = scores.to(torch.float32)
-        if dense_scores.dtype != torch.float32:
-            dense_scores = dense_scores.to(torch.float32)
-
-        res = SequenceSample(
-            keys=["rewards", "dense_rewards"],
-            trailing_shapes=dict(rewards=(), dense_rewards=()),
-            dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
-            ids=data_.ids,
-            seqlens=dict(
-                rewards=[
-                    torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
-                    for x in data_.seqlens["packed_input_ids"]
-                ],
-                dense_rewards=data_.seqlens["packed_input_ids"],
-            ),
-            data=dict(rewards=scores, dense_rewards=dense_scores),
-        )
-
-        # record rewards for each piece of data
-        avg_scores = []
-        offset = 0
-        for i in range(data_.bs):
-            score_lis = scores[
-                offset : offset + len(data_.seqlens["packed_input_ids"][i])
-            ]
-            avg_scores.append(score_lis.mean().item())
-            offset += len(data_.seqlens["packed_input_ids"][i])
-        assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])
-
-        res.metadata["scores"] = avg_scores
-
-        if self.check_verifier_status:
-            avg_score = torch.tensor(
-                np.mean(avg_scores), device=constants.current_device()
-            )
-            dist.all_reduce(
-                avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
-            )
-            avg_score /= constants.parallelism_group_size()
-            avg_score = avg_score.item()
-            minimal_score = (-1 - self.output_bias) * self.rm_output_scaling
-
-            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
-                raise CodeVerifierException(
-                    "All rewards are at minimal value. Probably there are something wrong with the verifier!"
-                )
-        return res
-
-
-model_api.register_interface("rw_code", PackedCodeRewardInterface)
diff --git a/realhf/impl/model/interface/math_rw_interface.py b/realhf/impl/model/interface/rw_interface.py
similarity index 91%
rename from realhf/impl/model/interface/math_rw_interface.py
rename to realhf/impl/model/interface/rw_interface.py
index d8425a0..6fd8bbc 100644
--- a/realhf/impl/model/interface/math_rw_interface.py
+++ b/realhf/impl/model/interface/rw_interface.py
@@ -37,10 +37,34 @@ ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else
 math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel


-class MathVerifierException(Exception):
+class VerifierException(Exception):
     pass


+def extract_python_code(text, min_length=20, strict_syntax=True):
+    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
+    code_blocks = re.findall(code_pattern, text, re.DOTALL)
+    valid_blocks = []
+    for block in code_blocks:
+        clean_block = block.strip()
+        if len(clean_block) < min_length:
+            continue
+
+        # verify code syntax
+        if strict_syntax:
+            try:
+                parse(clean_block, mode="exec")
+            except (SyntaxError, IndentationError):
+                continue
+
+        valid_blocks.append(clean_block)
+
+    if not valid_blocks:
+        return None
+    # return the last code block
+    return valid_blocks[-1]
+
+
 def check_with_elementtree(text):
     def escape_between_tags(text, tags=["think", "answer"]):
         """转义标签之间的内容,但保留标签本身."""
@@ -94,6 +118,19 @@ def check_with_elementtree(text):
         return False, f"Error: XML格式错误, {str(e)}"


+def reward_caculate(task, _answers, query_id_strs):
+    format_rewards = []
+    if task == "math":
+        format_rewards = math_verify_call(_answers, query_id_strs)
+    else:
+        codes = [extract_python_code(_answer) for _answer in _answers]
+        format_rewards = code_verify(codes, query_id_strs)
+    logger.info(
+        f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
+    )
+    return format_rewards
+
+
 def retokenize(
     task,
     tokenizer,
@@ -122,6 +159,10 @@ def retokenize(
     # query_id_strs = query_ids
     query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

+    logger.info(
+        f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
+    )
+
     format_rewards = []
     queryid_to_results = collections.defaultdict(list)
     # 8 processes on each node, with 10 subprocesses each
@@ -162,8 +203,7 @@ def retokenize(


 @dataclasses.dataclass
-class PackedMathRewardInterface(model_api.ModelInterface):
-
+class PackedRewardInterface(model_api.ModelInterface):
     enable_save: bool = False
     tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
     output_scaling: float = 1.0
@@ -392,7 +432,7 @@ class PackedMathRewardInterface(model_api.ModelInterface):
             minimal_score = (-1 - self.output_bias) * self.rm_output_scaling

             if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
-                raise MathVerifierException(
+                raise VerifierException(
                     "All rewards are at minimal value. Probably there are something wrong with the verifier!"
                 )
@@ -401,4 +441,4 @@ class PackedMathRewardInterface(model_api.ModelInterface):
         return res


-model_api.register_interface("rw_math", PackedMathRewardInterface)
+model_api.register_interface("reward", PackedRewardInterface)
diff --git a/realhf/impl/model/interface/rw_ref_interface.py b/realhf/impl/model/interface/rw_ref_interface.py
new file mode 100644
index 0000000..4fd31b1
--- /dev/null
+++ b/realhf/impl/model/interface/rw_ref_interface.py
@@ -0,0 +1,270 @@
+# Copyright 2025 Ant Group Inc.
+# Copyright 2024 Wei Fu & Zhiyu Mei
+# Licensed under the Apache License, Version 2.0 (the "License").
+
+import asyncio
+import dataclasses
+import functools
+import itertools
+import time
+from typing import Dict, Literal, Optional, Tuple
+
+import torch
+
+
+import realhf.api.core.model_api as model_api
+
+import realhf.base.logging as logging
+import realhf.impl.model.utils.ppo_functional as ppo_functional
+from realhf.api.core.data_api import MicroBatchSpec, SequenceSample
+from realhf.base.datapack import flat2d
+
+from realhf.impl.model.nn.real_llm_api import ReaLModel
+
+from realhf.impl.model.utils.functional import (
+    gather_packed_shifted_log_probs,
+    masked_normalization,
+)
+from realhf.impl.model.interface.rw_interface import PackedRewardInterface
+
+logger = logging.getLogger("RefRwInterface")
+
+TASK_TYPE_REF: Literal["ref"] = "ref"
+TASK_TYPE_RW_MATH: Literal["rw_math"] = "rw_math"
+TASK_TYPE_RW_CODE: Literal["rw_code"] = "rw_code"
+
+
+@dataclasses.dataclass
+class RefRwInterface(model_api.ModelInterface):
+    n_minibatches: int = 4
+
+    # Use dict here to allow argument passing through commandline.
+    generation_config: Dict = dataclasses.field(default_factory=dict)
+
+    kl_ctl: float = 0.1
+
+    adv_norm: bool = True
+    discount: float = 1.0
+    gae_lambda: float = 1.0
+
+    eps_clip: float = 0.2
+    value_eps_clip: float = 0.2
+    max_reward_clip: float = 5.0
+
+    disable_value: bool = False
+
+    early_stop_kl: Optional[float] = None  # e.g. 0.1
+    early_stop_imp_ratio: Optional[float] = None  # e.g., 10.0
+
+    adaptive_kl_ctl: bool = False
+    adaptive_kl_target: Optional[float] = 6
+    adaptive_kl_horizon: Optional[float] = 10000
+
+    enable_save: bool = True
+
+    value_norm: bool = False
+    value_norm_type: str = dataclasses.field(
+        metadata={"choices": ["exp", "ma"]}, default="exp"
+    )
+    value_norm_beta: float = 0.99995
+    value_norm_eps: float = 1e-5
+
+    group_size: int = 1
+    generation_size: Optional[int] = None
+    mask_no_eos_with_zero: bool = False
+    group_adv_norm: bool = False
+    mask_too_long: bool = False
+    use_dense_reward: bool = False
+    reward_delta: bool = True
+    token_normalize_scope: Literal["global", "dp"] = "global"
+    rew_inf_args: Dict = dataclasses.field(default_factory=dict)
+
+    def __post_init__(self):
+        if self.adaptive_kl_ctl:
+            assert self.adaptive_kl_target is not None
+            assert self.adaptive_kl_horizon is not None
+            self.kl_adapter = ppo_functional.AdaptiveKLController(
+                self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
+            )
+        else:
+            self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
+        if self.value_norm:
+            from realhf.impl.model.modules import (
+                ExponentialRunningMeanStd,
+                MovingAverageRunningMeanStd,
+            )
+
+            if self.value_norm_type == "exp":
+                self.rms = ExponentialRunningMeanStd(
+                    beta=self.value_norm_beta, epsilon=self.value_norm_eps
+                )
+            elif self.value_norm_type == "ma":
+                self.rms = MovingAverageRunningMeanStd()
+            else:
+                raise ValueError(f"Unknown value_norm_type {self.value_norm_type}")
+        self.kl_ctl = None
+
+        self.gconfig = model_api.GenerationHyperparameters(**self.generation_config)
+        if self.generation_size is not None:
+            assert self.generation_size >= self.group_size
+        else:
+            self.generation_size = self.group_size
+        self.gconfig.n = self.generation_size
+
+    def save(self, model: model_api.Model, save_dir: str):
+        if not self.enable_save:
+            return
+        module = model.module
+        if not isinstance(module, ReaLModel):
+            module = module.module
+        module.save_to_hf(
+            tokenizer=model.tokenizer,
+            save_dir=save_dir,
+        )
+
+    def _dispatch_tasks(self, data):
+        math_data, code_data, rlhf_data, ref_data = data, data, data, data
+        return math_data, code_data, rlhf_data, ref_data
+
+    def _gather_tasks(self, data_map):
+        # merge SequenceSamples from math_data, code_data, rlhf_data, ref_data
+        return data_map.get(TASK_TYPE_REF, None)
+
+    @torch.no_grad()
+    def ref_inference(
+        self, model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec
+    ):
+        module = model.module
+        module.eval()
+
+        # This post_hook will gather log probabilities in mini-batches,
+        # reducing peak memory usage.
+        def calc_logprobs(logits, input_):
+            logits /= self.gconfig.temperature
+
+            input_lens = torch.tensor(input_.seqlens["packed_input_ids"]).view(-1)
+            cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
+
+            logprobs = gather_packed_shifted_log_probs(
+                logits, cu_seqlens, input_.data["packed_input_ids"]
+            )
+            return logprobs
+
+        input_flattend = SequenceSample.from_default(
+            ids=list(range(input_.bs * self.group_size)),
+            seqlens=flat2d(input_.seqlens["packed_input_ids"]),
+            data=dict(packed_input_ids=input_.data["packed_input_ids"]),
+        )
+        # add posthook to avoid storing full logits
+        logprobs = module.forward(
+            input_=input_flattend,
+            post_hook=calc_logprobs,
+            output_seqlens=[
+                [x - 1 for x in slens]
+                for slens in input_flattend.seqlens["packed_input_ids"]
+            ],
+            mb_spec=mb_spec,
+        )
+
+        res = SequenceSample(
+            keys=["packed_ref_logprobs"],
+            ids=input_.ids,
+            dtypes=dict(packed_ref_logprobs=model.module.dtype),
+            trailing_shapes=dict(packed_ref_logprobs=()),
+            data=dict(packed_ref_logprobs=logprobs),
+            seqlens=dict(
+                packed_ref_logprobs=[
+                    [x - 1 for x in slen] for slen in input_.seqlens["packed_input_ids"]
+                ]
+            ),
+        )
+
+        return res
+
+    def inference(
+        self,
+        model: model_api.Model,
+        input_: SequenceSample,
+        mb_spec: MicroBatchSpec,
+    ) -> SequenceSample:
+        math_data, code_data, rlhf_data, ref_data = self._dispatch_tasks(input_)
+
+        if not hasattr(self, "rew_inf_args") or not isinstance(self.rew_inf_args, dict):
+            raise ValueError("Invalid rew_inf_args. Expected a dictionary.")
+        rewardInterface = PackedRewardInterface(**self.rew_inf_args)
+        logger.info(f"self.rew_inf_args: {self.rew_inf_args}, input_: {input_}")
+
+        task_map = {
+            TASK_TYPE_REF: (self.ref_inference, ref_data),
+            TASK_TYPE_RW_MATH: (rewardInterface.inference, math_data),
+            TASK_TYPE_RW_CODE: (rewardInterface.inference, code_data),
+        }
+
+        def _task_func(func, task_type: str):
+            def _wrapped_func(*args, **kwargs):
+                start_time = time.perf_counter()
+                logger.info(f"[{task_type}] ref_rw task start @ {start_time:.4f}")
+                try:
+                    result = func(*args, **kwargs)
+                except Exception as e:
+                    logger.error(f"{task_type} ref_rw task failed: {e}")
+                finally:
+                    duration = time.perf_counter() - start_time
+                    logger.info(
+                        f"[{task_type}] ref_rw task cost: {duration:.4f}s, start @ {start_time:.4f}"
+                    )
+                return result
+
+            return _wrapped_func
+
+        async def _run_tasks() -> dict:
+            tasks = []
+            for task_type, (func, data) in task_map.items():
+                if not data:
+                    continue
+                task_func = _task_func(func, task_type)
+                task_args = (model, data, mb_spec)
+                task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
+                tasks.append((task_type, task))
+
+            results = {}
+            for task_type, task in tasks:
+                try:
+                    results[task_type] = await task
+                except Exception as e:
+                    logger.error(f"{task_type} task failed: {e}")
+                    results[task_type] = None
+            return results
+
+        task_results = asyncio.run(_run_tasks())
+        final_result = self._gather_tasks(task_results)
+        return final_result
+
+    # Mock methods for profiling only.
+    def _mock_inference(
+        self,
+        model: model_api.Model,
+        dataset_input: SequenceSample,
+    ) -> SequenceSample:
+        prompt_lens = flat2d(dataset_input.seqlens["packed_prompts"])
+        seqlens = [x + self.gconfig.max_new_tokens for x in prompt_lens]
+        module = model.module
+        if not isinstance(module, ReaLModel):
+            module = module.module
+        mconfig = module.config
+        packed_input_ids = torch.randint(
+            0,
+            mconfig.vocab_size,
+            (sum(seqlens),),
+            dtype=torch.long,
+            device=model.device,
+        )
+
+        return SequenceSample.from_default(
+            seqlens=seqlens,
+            ids=dataset_input.ids,
+            data=dict(packed_input_ids=packed_input_ids),
+        )
+
+
+model_api.register_interface("ref_rw", RefRwInterface)
diff --git a/tests/experiments/test_async_refrw.py b/tests/experiments/test_async_refrw.py
new file mode 100644
index 0000000..116abee
--- /dev/null
+++ b/tests/experiments/test_async_refrw.py
@@ -0,0 +1,317 @@
+import functools
+import gc
+import os
+import pickle
+import json
+import time
+from typing import *
+
+import numpy as np
+import pynvml
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import transformers
+from torch.cuda import is_initialized
+
+root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+print(root_dir)
+import sys
+
+sys.path.insert(0, root_dir)
+
+from realhf.api.core import data_api, dfg, model_api
+from realhf.api.core.config import ModelName
+from realhf.api.core.model_api import ReaLModelConfig
+from realhf.base import constants, logging
+from realhf.base.network import find_free_port
+from realhf.base.testing import (
+    init_global_constants,
+    _DEFAULT_EXPR_NAME,
+    _DEFAULT_TRIAL_NAME,
+)
+
+logger = logging.getLogger("test async ref-rew")
+os.environ["REAL_MATH_METADATA_PATH"] = "/storage/datasets/id2info.json"
+
+
+def loadJson():
+    dataDir = os.environ["REAL_MATH_METADATA_PATH"]
+    with open(dataDir, "r") as f:
+        if dataDir.endswith(".jsonl"):
+            samples = [json.loads(line) for line in f.readlines()]
+        else:
+            samples = json.load(f)
+
+    return samples
+
+
+def _mock_input(batch_size: int, seq_len):
+    vocab_size = 100
+    torch.manual_seed(1)
+    seqs = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
+
+    samples = loadJson()
+    id_list = list(samples.keys())
+    # id_tensor = torch.tensor([id_list[i] for i in range(seqs.shape[0])], dtype=torch.long)  # 使用哈希值编码
+
+    return data_api.SequenceSample.from_default(
+        seqlens=[seq_len for _ in range(seqs.shape[0])],
+        ids=[id_list[i] for i in range(seqs.shape[0])],
+        data=dict(
+            packed_input_ids=seqs.view(-1),
+            # prompt_mask=torch.zeros_like(seqs.view(-1), dtype=torch.bool),
+            packed_prompts=seqs[:, :seq_len].contiguous().view(-1),
+        ),
+    )
+
+
+def funcion_call(
+    rpc_name: str,
+    rank: int,
+    world_size: int,
+    model_path: str,
+    model_family_name: str,
+    dp: int,
+    pp: int,
+    tp: int,
+    interface_type: dfg.ModelInterfaceType,
+    interface_impl: dfg.ModelInterfaceAbstraction,
+    batch_size: int,
+    prompt_len: int,
+    input_: data_api.SequenceSample | None,
+    port: int,
+):
+
+    # assert not torch.cuda.is_initialized()
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
+    torch.cuda.set_device(0)
+    assert world_size == (
+        dp * pp * tp
+    ), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
+    assert batch_size % dp == 0, (batch_size, dp)
+
+    # Initialize distributed environment.
+    model_name = ModelName("default", 0)
+    if not dist.is_initialized():
+        logger.info("Setting up distributed environment...")
+        dist.init_process_group(
+            "nccl",
+            rank=rank,
+            world_size=world_size,
+            init_method=f"tcp://localhost:{port}",
+        )
+        logger.info("Initialized distributed environment.")
+    init_global_constants(
+        num_dp=dp,
+        num_mp=tp,
+        num_pp=pp,
+        sequence_parallel=interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
+        model_name=model_name,
+        max_prompt_len=prompt_len,
+    )
+    torch.cuda.set_device(0)
+
+    # NOTE: import here to avoid CUDA re-initialization
+    from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
+
+    # Call a method like `config_from_llama` to get the config.
+    mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
+        transformers.AutoConfig.from_pretrained(model_path)
+    )
+    is_critic = rpc_name in ["critic_inf", "critic_train", "rew_inf"]
+    mconfig.is_critic = is_critic
+
+    with constants.model_scope(model_name):
+        # Construct the model.
+ logger.info(f"Loading model from {model_path}...") + module = ReaLModel(mconfig, dtype=torch.bfloat16, device="cuda") + setattr(ReaLModel, "save_to_hf", getattr(ReaLModel, f"to_{model_family_name}")) + setattr( + ReaLModel, "load_from_hf", getattr(ReaLModel, f"from_{model_family_name}") + ) + module._instantiation_hooks.append( + lambda: getattr(module, f"from_{model_family_name}")( + load_dir=model_path, + init_critic_from_actor=is_critic, + ) + ) + add_helper_functions(module) + module.instantiate() + module.eval() + + tokenizer = data_api.load_hf_tokenizer(model_path) + + model = model_api.Model( + name=model_name, + module=module, + tokenizer=tokenizer, + device=module.device, + dtype=module.dtype, + ) + if interface_type == dfg.ModelInterfaceType.TRAIN_STEP: + from realhf.impl.model.backend.megatron import MegatronTrainBackend + + backend = MegatronTrainBackend() + else: + from realhf.impl.model.backend.inference import PipelineInferenceBackend + + backend = PipelineInferenceBackend() + + logger.info("Running backend initialization...") + ft_spec = model_api.FinetuneSpec( + total_train_epochs=1, + dataset_size=128, + train_batch_size=128, + ) + model = backend.initialize(model, ft_spec) + + interface = model_api.make_interface(interface_impl) + + if input_ is None: + input_ = _mock_input(batch_size, prompt_len) + + input_ = input_.cuda() + + mb_spec = model_api.MicroBatchSpec() + + logger.info("Running interface computation...") + start = time.perf_counter_ns() + if interface_type == dfg.ModelInterfaceType.GENERATE: + res = interface.generate(model, input_, mb_spec) + elif interface_type == dfg.ModelInterfaceType.TRAIN_STEP: + res = interface.train_step(model, input_) + else: + res = interface.inference(model, input_, mb_spec) + + if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage(): + if isinstance(res, data_api.SequenceSample): + res = res.cpu() + + comsumed = time.perf_counter_ns() - start + logger.info(f"{rpc_name} Computation done. 
{comsumed} ns") + return res + + +def run_function_call( + rpc_name: str, + model_path: str, + model_family_name: str, + batch_size: int, + prompt_len: int, + gen_len: int, + input_: data_api.SequenceSample | None, +) -> data_api.SequenceSample | None: + assert rpc_name in [ + "actor_gen", + "actor_train", + "critic_inf", + "rew_inf", + "critic_train", + "ref_inf", + "ref_rw", + ] + + ref_rw_interface = dfg.ModelInterfaceAbstraction( + "ref_rw", + args=dict( + generation_config=dict( + max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True + ), + rew_inf_args=dict( + tokenizer_path=model_path, + ), + ), + ) + + ppo_actor_interface = dfg.ModelInterfaceAbstraction( + "ppo_actor", + args=dict( + generation_config=dict( + max_new_tokens=gen_len, min_new_tokens=gen_len, greedy=True + ), + rew_inf_args=dict( + tokenizer_path=model_path, + ), + ), + ) + ppo_critic_interface = dfg.ModelInterfaceAbstraction("ppo_critic") + rw_interface = dfg.ModelInterfaceAbstraction( + "paired_rw", + ) + if rpc_name == "actor_gen": + interface_type = dfg.ModelInterfaceType.GENERATE + interface_impl = ppo_actor_interface + elif rpc_name == "actor_train": + interface_type = dfg.ModelInterfaceType.TRAIN_STEP + interface_impl = ppo_actor_interface + elif rpc_name == "critic_inf": + interface_type = dfg.ModelInterfaceType.INFERENCE + interface_impl = ppo_critic_interface + elif rpc_name == "ref_inf": + interface_type = dfg.ModelInterfaceType.INFERENCE + interface_impl = ppo_actor_interface + elif rpc_name == "ref_rw": + interface_type = dfg.ModelInterfaceType.INFERENCE + interface_impl = ref_rw_interface + elif rpc_name == "critic_train": + interface_type = dfg.ModelInterfaceType.TRAIN_STEP + interface_impl = ppo_critic_interface + else: + interface_type = dfg.ModelInterfaceType.INFERENCE + interface_impl = rw_interface + + logger.info(f"Running RPC {rpc_name}...") + + port = find_free_port() + res = funcion_call( + rank=0, + rpc_name=rpc_name, + world_size=1, + model_path=model_path, + model_family_name=model_family_name, + dp=1, + pp=1, + tp=1, + interface_type=interface_type, + interface_impl=interface_impl, + batch_size=batch_size, + prompt_len=prompt_len, + input_=input_, + port=port, + ) + + gc.collect() + torch.cuda.empty_cache() + gc.collect() + if isinstance(res, data_api.SequenceSample): + return res + else: + logger.info(f"RPC {rpc_name} stats: {res}") + + +def main(): + mp.set_start_method("spawn", force=True) + + model_family_name = "qwen2" + batch_size = 16 + prompt_len = 128 + gen_len = 4096 + model_path = "/storage/models/DeepSeek-R1-Distill-Qwen-1.5B" + + constants.set_experiment_trial_names(_DEFAULT_EXPR_NAME, _DEFAULT_TRIAL_NAME) + + for i in range(2): + ref_rw_res = run_function_call( + "ref_rw", + model_family_name=model_family_name, + model_path=model_path, + batch_size=batch_size, + prompt_len=prompt_len, + gen_len=gen_len, + input_=None, + ) + + +if __name__ == "__main__": + main()