# Copyright 2025 Ant Group Inc.

import ast
import collections
import copy
import dataclasses
import html
import itertools
import json
import os
import random
import re
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional

import colorama
import numpy as np
import requests
import torch
import torch.distributed as dist
import tqdm
import transformers

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
# NOTE: import path assumed for the code verifier used in `reward_caculate` below;
# adjust if the functioncall package exposes `code_verify` elsewhere.
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.base import constants
from realhf.base.constants import data_parallel_group, data_parallel_world_size
from realhf.base.datapack import flat2d
from realhf.impl.model.interface.math_parser import parse_lines_in_parallel
from realhf.impl.model.nn.real_llm_api import ReaLModel

logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")

ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
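# Dispatch is decided once at import time: with FUNCTIONCALL_SERVICE_DOMAIN set, math
# answers are checked through the remote function-call verifier (`math_verify`);
# otherwise the local parser (`parse_lines_in_parallel`) is used. Both are invoked
# below as `math_verify_call(answers, query_ids)` and return one 0/1 reward per answer.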


class VerifierException(Exception):
    pass


def extract_python_code(text, min_length=20, strict_syntax=True):
    """Extract the last valid fenced Python code block from `text`, or None."""
    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
    code_blocks = re.findall(code_pattern, text, re.DOTALL)
    valid_blocks = []
    for block in code_blocks:
        clean_block = block.strip()
        if len(clean_block) < min_length:
            continue

        # Verify code syntax.
        if strict_syntax:
            try:
                ast.parse(clean_block, mode="exec")
            except (SyntaxError, IndentationError):
                continue

        valid_blocks.append(clean_block)

    if not valid_blocks:
        return None
    # Return the last code block.
    return valid_blocks[-1]
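
# Illustrative example (not in the original source):
#   extract_python_code("```python\nprint(1 + 1)\n```", min_length=5) -> "print(1 + 1)"
# Blocks shorter than `min_length` or blocks that fail to parse are skipped; None is
# returned when no block qualifies.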


def check_with_elementtree(text):
    def escape_between_tags(text, tags=["think", "answer"]):
        """Escape the content between tags while keeping the tags themselves."""
        # Build the tag pattern.
        tag_pattern = "|".join(tags)
        parts = []
        current_pos = 0

        # Match opening and closing tags.
        pattern = f"</?({tag_pattern})[^>]*>"

        for match in re.finditer(pattern, text):
            # Append the content before the tag (escaped).
            if current_pos < match.start():
                parts.append(html.escape(text[current_pos : match.start()]))

            # Append the tag itself (not escaped).
            parts.append(match.group())
            current_pos = match.end()

        # Append the remaining content.
        if current_pos < len(text):
            parts.append(html.escape(text[current_pos:]))

        return "".join(parts)

    text = escape_between_tags(text)
    if not text.strip().startswith("<think>"):
        text = "<think>" + text
    try:
        xml_text = f"<root>{text}</root>"
        x = ET.fromstring(xml_text)
        if x.text is not None and x.text.strip() != "":
            return False, f"Error: extra content before <think>. {x.text}"
        if len(x) != 2:
            return False, f"Error: there are {len(x)} tags."
        if x[0].tag is None or x[0].tag != "think":
            return False, f"Error: <think> tag is missing. {x[0].tag}"
        if x[0].tail is not None and x[0].tail.strip() != "":
            return (
                False,
                f"Error: extra content between <think> and <answer>. {x[0].tail}",
            )
        if x[1].tag is None or x[1].tag != "answer":
            return False, f"Error: <answer> tag is missing. {x[1].tag}"
        if x[1].tail is not None and x[1].tail.strip() != "":
            return False, f"Error: extra content after <answer>, {x[1].tail}"

        return True, x[1].text if x[1].text is not None else ""
    except ET.ParseError as e:
        return False, f"Error: malformed XML, {str(e)}"


def reward_caculate(task, _answers, query_id_strs):
    """Compute per-answer verifier rewards for a batch of the given task type."""
    format_rewards = []
    if task == "math":
        format_rewards = math_verify_call(_answers, query_id_strs)
    else:
        codes = [extract_python_code(_answer) for _answer in _answers]
        format_rewards = code_verify(codes, query_id_strs)
    logger.info(
        f"reward_caculate, task: {task}, size: {len(query_id_strs)}, query_id_0: {query_id_strs[0]}"
    )
    return format_rewards


def retokenize(
    task,
    tokenizer,
    packed_input_ids,
    input_cu_seqlens,
    prompts,
    prompt_cu_seqlens,
    query_ids,
    check_xml_format=False,
    do_eval=False,
):
    """Decode packed sequences and compute per-sample format rewards."""
    input_ids = [
        packed_input_ids[start:end]
        for start, end in zip(input_cu_seqlens[:-1], input_cu_seqlens[1:])
    ]
    prompt_ids = [
        prompts[start:end]
        for start, end in zip(prompt_cu_seqlens[:-1], prompt_cu_seqlens[1:])
    ]
    seq_strs = tokenizer.batch_decode(
        input_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    prompt_strs = tokenizer.batch_decode(
        prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    # query_id_strs = query_ids
    query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

    logger.info(
        f"retokenize, query_id_strs:{query_id_strs}, seq_strs:{seq_strs}, prompt_strs:{prompt_strs}"
    )

    format_rewards = []
    queryid_to_results = collections.defaultdict(list)
    # 8 processes on each node, with 10 subprocesses each
    if do_eval:
        _answers = [
            seq_str.split(prompt_str)[1]
            for seq_str, prompt_str in zip(seq_strs, prompt_strs)
        ]

        format_rewards = math_verify_call(_answers, query_id_strs)
        if check_xml_format:
            with ThreadPoolExecutor(max_workers=22) as executor:
                futures = [
                    executor.submit(check_with_elementtree, answer_str)
                    for answer_str in _answers
                ]
                for idx, future in enumerate(futures):
                    xml_reward, _ = future.result()
                    if xml_reward == 1 and format_rewards[idx] == 0:
                        format_rewards[idx] = -0.8
                    elif xml_reward == 0 and format_rewards[idx] == 0:
                        format_rewards[idx] = -1

        for query_id_str, format_reward in zip(query_id_strs, format_rewards):
            queryid_to_results[query_id_str].append(format_reward)
    else:
        for query_id_str in query_id_strs:
            queryid_to_results[query_id_str].append(0)
            format_rewards.append(0)

    return format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results
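
# Reward encoding produced by `retokenize` when `do_eval` is True: 1 for a verified
# answer and 0 for a failed one; with `check_xml_format`, failed answers are remapped
# to -0.8 when the <think>/<answer> structure is valid and to -1 otherwise. With
# `do_eval` False, every sample receives a placeholder reward of 0.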


@dataclasses.dataclass
class PackedRewardInterface(model_api.ModelInterface):
    """Rule-based reward interface that verifies packed generations and emits sparse rewards."""

    enable_save: bool = False
    tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
    output_scaling: float = 1.0
    rm_output_scaling: float = 1.0
    rm_output_bias: float = 0.0
    output_bias: float = 0.0
    loss_fun = torch.nn.CrossEntropyLoss(reduction="none")
    max_sync_length: int = 2048
    rw_type: str = "sparse"
    task: str = "math"  # math or countdown
    check_xml_format: bool = False
    post_process: str = "sigmoid"
    group_size: int = 1
    check_verifier_status: bool = False

    def __post_init__(self):
        self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
        logger.info(f"rm_output_scaling: {self.rm_output_scaling}")
        logger.info(f"rm_output_bias: {self.rm_output_bias}")
        logger.info(f"output_scaling: {self.output_scaling}")
        logger.info(f"output_bias: {self.output_bias}")
        logger.info(f"max_sync_length: {self.max_sync_length}")
        logger.info(f"rw_type: {self.rw_type}")
        logger.info(f"post_process: {self.post_process}")

    def inference(
        self,
        model: model_api.Model,
        data_: SequenceSample,
        mb_spec,
    ) -> SequenceSample:

        packed_input_ids: torch.Tensor = data_.data["packed_input_ids"].squeeze()

        input_seqlens = torch.tensor(data_.seqlens["packed_input_ids"]).view(-1)
        input_cu_seqlens = torch.nn.functional.pad(
            input_seqlens.cumsum(0), (1, 0)
        ).int()

        packed_prompts = data_.data["packed_prompts"]
        prompts = []
        prompt_seqlens = []
        offset = 0
        for x in data_.seqlens["packed_prompts"]:
            prompts += [packed_prompts[offset : offset + x[0]]] * self.group_size
            offset += x[0]
            prompt_seqlens.extend(x * self.group_size)

        assert offset == sum(x[0] for x in data_.seqlens["packed_prompts"])
        # non_packed_prompts = copy.deepcopy(prompts)
        prompts = torch.cat(prompts)
        prompt_seqlens = torch.tensor(prompt_seqlens).view(-1)
        prompt_cu_seqlens = torch.nn.functional.pad(
            prompt_seqlens.cumsum(0), (1, 0)
        ).int()

        query_ids = [data_id for data_id in data_.ids for _ in range(self.group_size)]
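        # `prompts`, `prompt_seqlens`, and `query_ids` have each been replicated
        # `group_size` times so that they line up one-to-one with the sampled
        # responses packed in `packed_input_ids`.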

        format_rewards, prompt_strs, prompt_ids, seq_strs, queryid_to_results = (
            retokenize(
                self.task,
                self.tokenizer,
                packed_input_ids,
                input_cu_seqlens,
                prompts,
                prompt_cu_seqlens,
                query_ids,
                check_xml_format=self.check_xml_format,
                do_eval=constants.is_last_pipe_stage(),
            )
        )

        assert self.rw_type == "sparse"
        dense_scores = torch.zeros_like(packed_input_ids).float()
        scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
        scores[scores == 0] = -1

        if len(scores) == 0:
            return None

        assert dense_scores.shape == packed_input_ids.shape
        scores = (
            scores.to(packed_input_ids.device) - self.output_bias
        ) * self.output_scaling

        logger.info(f"Math reward logging info @v{model.version.global_step}")
        logger.info(
            f"before: Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
        )
        logger.info(f"number of samples: {len(scores)}, {scores.shape}")
        if constants.is_last_pipe_stage():

            gen_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "generated",
                f"v{model.version.global_step}r{dist.get_rank()}.txt",
            )

            os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
            logger.info(
                f"Generated samples and rewards will be dumped to: {gen_file_path}"
            )
            with open(gen_file_path, "w") as _f:
                for idx, (score, prompt_str, seq_str) in enumerate(
                    zip(scores, prompt_strs, seq_strs)
                ):
                    info = "\n".join(
                        [
                            f"idx: {idx} / {len(scores)}",
                            f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
                            f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
                        ]
                    )
                    _f.write(info + "\n")

            gen_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "generated_jsonl",
                f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
            )
            os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
            logger.info(
                f"Generated samples and rewards will be dumped to: {gen_file_path}"
            )
            with open(gen_file_path, "w") as _f:
                for idx, (score, prompt_str, seq_str) in enumerate(
                    zip(scores, prompt_strs, seq_strs)
                ):
                    _f.write(
                        json.dumps(
                            {
                                "prompt": prompt_str,
                                "generated": seq_str.split(prompt_str)[1],
                                "reward": score.item(),
                            },
                            ensure_ascii=False,
                        )
                        + "\n"
                    )

            logger.info(
                f"Format success rate: {torch.FloatTensor(format_rewards).mean().item()}"
            )

            pass_at_k = np.mean(
                [sum([xx == 1 for xx in x]) > 0 for x in queryid_to_results.values()]
            )
            avg_num_samples = np.mean([len(x) for x in queryid_to_results.values()])
            logger.info(f"pass@k: {pass_at_k}, num_samples: {avg_num_samples}")

            logger.info(f"number of samples: {len(scores)}, {scores.shape}")

            logger.info(f"reward: {sum(scores) / len(scores)}")

            train_pass_monitor_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "training_monitor",
                f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
            )
            os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
            logger.info(
                f"pass monitor result will be dumped to: {train_pass_monitor_file_path}"
            )
            with open(train_pass_monitor_file_path, "w") as monitor_file:
                for key, value in queryid_to_results.items():
                    pass1 = sum(value) / len(value)
                    pass8 = int(sum(value) > 0)
                    monitor_file.write(
                        json.dumps(
                            {"query_id": key, "pass1": pass1, "pass8": pass8},
                            ensure_ascii=False,
                        )
                        + "\n"
                    )

        model.inc_version()

        if scores.dtype != torch.float32:
            scores = scores.to(torch.float32)
        if dense_scores.dtype != torch.float32:
            dense_scores = dense_scores.to(torch.float32)

        res = SequenceSample(
            keys=["rewards", "dense_rewards"],
            trailing_shapes=dict(rewards=(), dense_rewards=()),
            dtypes=dict(rewards=torch.float32, dense_rewards=torch.float32),
            ids=data_.ids,
            seqlens=dict(
                rewards=[
                    torch.tensor([1 for _ in range(len(x))], dtype=torch.int32)
                    for x in data_.seqlens["packed_input_ids"]
                ],
                dense_rewards=data_.seqlens["packed_input_ids"],
            ),
            data=dict(rewards=scores, dense_rewards=dense_scores),
        )
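        # Layout of the returned sample: "rewards" carries one sparse scalar per
        # generated sequence (a length-1 entry for each sequence of a sample), while
        # "dense_rewards" keeps the per-token layout of `packed_input_ids` and stays
        # all zeros here since `rw_type` is "sparse".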

        # record rewards for each piece of data
        avg_scores = []
        offset = 0
        for i in range(data_.bs):
            score_lis = scores[
                offset : offset + len(data_.seqlens["packed_input_ids"][i])
            ]
            avg_scores.append(score_lis.mean().item())
            offset += len(data_.seqlens["packed_input_ids"][i])
        assert offset == sum(len(x) for x in data_.seqlens["packed_input_ids"])

        res.metadata["scores"] = avg_scores

        if self.check_verifier_status:
            avg_score = torch.tensor(
                np.mean(avg_scores), device=constants.current_device()
            )
            dist.all_reduce(
                avg_score, op=dist.ReduceOp.SUM, group=constants.parallelism_group()
            )
            avg_score /= constants.parallelism_group_size()
            avg_score = avg_score.item()
            minimal_score = (-1 - self.output_bias) * self.rm_output_scaling

            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
                raise VerifierException(
                    "All rewards are at the minimal value. "
                    "Something is probably wrong with the verifier!"
                )

        if not constants.is_last_pipe_stage():
            return None
        return res


model_api.register_interface("reward", PackedRewardInterface)