mirror of https://github.com/inclusionAI/AReaL
# Copyright 2025 Ant Group Inc.

import ast
import asyncio
import dataclasses
import html
import json
import os
import re
import time
import xml.etree.ElementTree as ET
from typing import Dict, List, Tuple

import colorama
import numpy as np
import torch
import torch.distributed as dist

import realhf.api.core.model_api as model_api
import realhf.base.logging as logging
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.data_api import (
    RL_TASKS,
    MicroBatchSpec,
    SequenceSample,
    load_hf_tokenizer,
)
from realhf.base import constants
from realhf.base.datapack import flat2d
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel as math_verify_local

logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
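# Use the remote function-call service for verification when
# FUNCTIONCALL_SERVICE_DOMAIN is set; otherwise verify locally in-process.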
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else math_verify_local
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify


class VerifierException(Exception):
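    """Raised when all rewards collapse to the minimal value, which usually
    indicates a broken verifier rather than uniformly bad generations."""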
    pass


def extract_python_code(text, min_length=20, strict_syntax=True):
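    """Extract the last syntactically valid Python code block from `text`.

    Markdown-fenced blocks shorter than `min_length` characters are skipped.
    When `strict_syntax` is True, blocks that fail `ast.parse` are skipped as
    well. Returns None if no valid block is found.
    """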
    code_pattern = r"(?i)```(?:python|py)?\s*\n?(.*?)\n?```"
    code_blocks = re.findall(code_pattern, text, re.DOTALL)
    valid_blocks = []
    for block in code_blocks:
        clean_block = block.strip()
        if len(clean_block) < min_length:
            continue

        # verify code syntax
        if strict_syntax:
            try:
                ast.parse(clean_block, mode="exec")
            except (SyntaxError, IndentationError):
                continue

        valid_blocks.append(clean_block)

    if not valid_blocks:
        # logger.warning(f"failed to extract python code from {text}")
        return None
    # return the last code block
    return valid_blocks[-1]


def check_with_elementtree(text):
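    """Validate the `<think>...</think><answer>...</answer>` response format.

    Content between the tags is HTML-escaped so it cannot break XML parsing.
    Returns (True, answer_text) when the format is correct, and
    (False, error_message) otherwise.
    """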
    def escape_between_tags(text, tags=["think", "answer"]):
        """Escape the content between tags while keeping the tags themselves."""
        # Build the tag pattern.
        tag_pattern = "|".join(tags)
        parts = []
        current_pos = 0

        # Match opening and closing tags.
        pattern = f"</?({tag_pattern})[^>]*>"

        for match in re.finditer(pattern, text):
            # Append the content before the tag (escaped).
            if current_pos < match.start():
                parts.append(html.escape(text[current_pos : match.start()]))

            # Append the tag itself (unescaped).
            parts.append(match.group())
            current_pos = match.end()

        # Append whatever remains after the last tag (escaped).
        if current_pos < len(text):
            parts.append(html.escape(text[current_pos:]))

        return "".join(parts)

    text = escape_between_tags(text)
    if not text.strip().startswith("<think>"):
        text = "<think>" + text
    try:
        xml_text = f"<root>{text}</root>"
        x = ET.fromstring(xml_text)
        if x.text is not None and x.text.strip() != "":
            return False, f"Error: extra content before <think>. {x.text}"
        if len(x) != 2:
            return False, f"Error: there are {len(x)} tags."
        if x[0].tag is None or x[0].tag != "think":
            return False, f"Error: <think> tag is missing. {x[0].tag}"
        if x[0].tail is not None and x[0].tail.strip() != "":
            return (
                False,
                f"Error: extra content between <think> and <answer>. {x[0].tail}",
            )
        if x[1].tag is None or x[1].tag != "answer":
            return False, f"Error: <answer> tag is missing. {x[1].tag}"
        if x[1].tail is not None and x[1].tail.strip() != "":
            return False, f"Error: extra content after <answer>, {x[1].tail}"

        return True, x[1].text if x[1].text is not None else ""
    except ET.ParseError as e:
        return False, f"Error: malformed XML, {str(e)}"


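# Maps query ids to ground-truth metadata (e.g., reference answers or test
# cases); populated from the dataset in MultiTaskRewardInterface.__post_init__.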
id2info = {}


def dispatch_reward_calculation(task, answers, query_id_strs) -> List:
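    """Route answers to the verifier for `task`, returning one reward per answer.

    Math and STEM answers are verified directly; code answers are first reduced
    to their last valid Python code block via `extract_python_code`.
    """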
    global id2info
    assert len(answers) == len(query_id_strs)
    format_rewards = []
    if task == "math" or task == "stem":
        format_rewards = math_verify_call(id2info, answers, query_id_strs)
    elif task == "code":
        codes = [extract_python_code(_answer) for _answer in answers]
        format_rewards = code_verify_call(id2info, codes, query_id_strs)
    assert len(format_rewards) == len(answers), (
        task,
        len(format_rewards),
        len(answers),
        answers,
    )
    return format_rewards


def retokenize_and_verify(
    task,
    tokenizer,
    prompt_ids: List[List[int]],
    seq_ids: List[List[int]],
    query_ids: List[str],
    check_xml_format=False,
):
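    """Decode prompts and sequences, strip each prompt from its sequence, and verify.

    Query ids may carry an "@<index>" suffix; only the part before "@" is used
    to look up metadata. When `check_xml_format` is True, unverified answers
    (reward 0) are penalized less harshly (-0.8 instead of -1) if they at least
    follow the <think>/<answer> XML format.
    """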
    seq_strs = tokenizer.batch_decode(
        seq_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    prompt_strs = tokenizer.batch_decode(
        prompt_ids, clean_up_tokenization_spaces=False, skip_special_tokens=True
    )
    # query_id_strs = query_ids
    query_id_strs = [query_id.split("@")[0] for query_id in query_ids]

    answers = [
        seq_str.split(prompt_str)[1]
        for seq_str, prompt_str in zip(seq_strs, prompt_strs)
    ]

    format_rewards = dispatch_reward_calculation(task, answers, query_id_strs)

    if check_xml_format:
        for idx, answer in enumerate(answers):
            xml_reward, _ = check_with_elementtree(answer)
            if xml_reward == 1 and format_rewards[idx] == 0:
                format_rewards[idx] = -0.8
            elif xml_reward == 0 and format_rewards[idx] == 0:
                format_rewards[idx] = -1

    return format_rewards, prompt_strs, seq_strs


@dataclasses.dataclass
class MultiTaskRewardInterface(model_api.ModelInterface):
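    """Rule-based reward interface for multi-task (math/code/STEM) RL training.

    Each sequence is verified by the verifier of its task; the resulting sparse
    rewards are shifted by `output_bias` and scaled by `output_scaling`.
    """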
    dataset_path: str = ""
    tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
    output_scaling: float = 1.0
    output_bias: float = 0.0
    rw_type: str = "sparse"
    check_xml_format: bool = False
    group_size: int = 1
    check_verifier_status: bool = False

    def __post_init__(self):
        global id2info
        id2info, _ = load_metadata(self.dataset_path)
        self.tokenizer = load_hf_tokenizer(self.tokenizer_path)
        if constants.parallelism_rank() == 0:
            logger.info(f"output_scaling: {self.output_scaling}")
            logger.info(f"output_bias: {self.output_bias}")
            logger.info(f"rw_type: {self.rw_type}")

    def _dispatch_tasks(self, data: SequenceSample) -> Tuple[Dict, Dict]:
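        """Split a batch into per-task sub-batches keyed by task name.

        Also returns the original index of every dispatched sample so results
        can be scattered back into batch order afterwards.
        """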
        xs = data.unpack()
        dispatched = {}
        dispatched_indices = {}
        for task_idx, task_name in enumerate(RL_TASKS):
            indices = (
                (data.data["task_ids"] == task_idx).cpu().numpy().nonzero()[0].tolist()
            )
            if len(indices) > 0:
                dispatched[task_name] = SequenceSample.gather([xs[i] for i in indices])
                dispatched_indices[task_name] = indices

        return dispatched, dispatched_indices

    def _gather_tasks(
        self, results: Dict, dispatched_indices: Dict, bs: int
    ) -> SequenceSample:
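        """Inverse of `_dispatch_tasks`: restore per-task results to batch order."""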
        xs = [None for _ in range(bs)]
        for task_name, indices in dispatched_indices.items():
            xxs = results[task_name].unpack()
            assert len(indices) == len(xxs), (len(indices), len(xxs))
            for i, xx in zip(indices, xxs):
                xs[i] = xx
        assert all(xs)
        return SequenceSample.gather(xs)

    def _dispatch_tp_and_pp(self, data: SequenceSample):
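        """Split the batch evenly across TP/PP ranks; each rank verifies its share.

        Returns (data, None) unchanged when there is only one TP/PP rank.
        """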
        tp_pp_size = constants.tp_and_pp_world_size()
        if tp_pp_size == 1:
            return data, None
        splitted, _, backward_indices = data.split(
            mb_spec=MicroBatchSpec(n_mbs=tp_pp_size)
        )
        tp_pp_rank = constants.tp_and_pp_rank()
        print("dispatched batch size", [s.bs for s in splitted], flush=True)
        return splitted[tp_pp_rank], backward_indices

    def _gather_tp_and_pp(self, input_, data: SequenceSample, backward_indices):
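        """Gather per-rank rewards onto the last pipeline stage of this DP group.

        Returns the reassembled sample on the destination rank, None elsewhere.
        """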
        tp_pp_size = constants.tp_and_pp_world_size()
        if tp_pp_size == 1:
            return data
        local_rank = constants.grid().topo.get_rank(
            data=constants.data_parallel_rank(),
            tensor=0,
            pipe=constants.pipe_parallel_world_size() - 1,
        )
        dst = constants.to_global_pg_rank(local_rank)
        gather_list = None
        if dist.get_rank() == dst:
            gather_list = [None for _ in range(tp_pp_size)]
        x = data.data["rewards"].cpu().numpy().tolist()
        print(x, flush=True)
        dist.gather_object(
            x, gather_list, dst=dst, group=constants.tp_and_pp_cpu_group()
        )
        if dist.get_rank() != dst:
            return None
        gathered = np.array(gather_list).reshape(-1, self.group_size)
        assert len(gathered) == len(backward_indices)
        rewards = (
            np.concatenate([gathered[i] for i in backward_indices]).flatten().tolist()
        )
        return SequenceSample(
            keys=["rewards"],
            trailing_shapes=dict(rewards=()),
            dtypes=dict(rewards=torch.float32),
            ids=input_.ids,
            seqlens=dict(
                rewards=[[1 for _ in range(self.group_size)] for _ in range(input_.bs)],
            ),
            data=dict(rewards=torch.tensor(rewards, dtype=torch.float32)),
        )

    def calculate_task_reward(
        self,
        model: model_api.Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
        task_type: str,
    ):
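        """Verify all sequences of one task type and return their sparse rewards.

        Raises VerifierException if `check_verifier_status` is set and the
        average reward across data-parallel ranks sits at the minimal value.
        """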
        # NOTE: mb_spec is ignored here.
        packed_input_ids: torch.Tensor = data.data["packed_input_ids"]
        input_seqlens = flat2d(data.seqlens["packed_input_ids"])
        seq_ids = []
        offset = 0
        for slen in input_seqlens:
            seq_ids.append(
                packed_input_ids[offset : offset + slen].cpu().numpy().tolist()
            )
            offset += slen
        assert offset == packed_input_ids.shape[0], (offset, packed_input_ids.shape)
        prompt_input_ids = data.data["packed_prompts"]
        prompt_len = flat2d(data.seqlens["packed_prompts"])
        prompt_ids = []
        offset = 0
        for slen in prompt_len:
            p = prompt_input_ids[offset : offset + slen].cpu().numpy().tolist()
            prompt_ids += [p] * self.group_size
            offset += slen
        format_rewards, prompt_strs, seq_strs = retokenize_and_verify(
            task_type,
            self.tokenizer,
            prompt_ids=prompt_ids,
            seq_ids=seq_ids,
            query_ids=[
                str(data_id) for data_id in data.ids for _ in range(self.group_size)
            ],
            check_xml_format=self.check_xml_format,
        )
        scores = torch.FloatTensor(format_rewards).to(packed_input_ids.device)
        scores[scores == 0] = -1

        scores = (
            scores.to(packed_input_ids.device) - self.output_bias
        ) * self.output_scaling

        self.log_rewards_to_file(task_type, model, prompt_strs, seq_strs, scores)

        res = SequenceSample(
            keys=["rewards"],
            trailing_shapes=dict(rewards=()),
            dtypes=dict(rewards=torch.float32),
            ids=data.ids,
            seqlens=dict(
                rewards=[
                    [1 for _ in range(len(x))] for x in data.seqlens["packed_input_ids"]
                ],
            ),
            data=dict(rewards=scores),
        )

        # record rewards for each piece of data
        avg_scores = []
        offset = 0
        for i in range(data.bs):
            score_lis = scores[
                offset : offset + len(data.seqlens["packed_input_ids"][i])
            ]
            avg_scores.append(score_lis.mean().item())
            offset += len(data.seqlens["packed_input_ids"][i])
        assert offset == sum(len(x) for x in data.seqlens["packed_input_ids"])

        res.metadata["scores"] = avg_scores

        if self.check_verifier_status:
            avg_score = torch.tensor(
                np.mean(avg_scores), device=constants.current_device()
            )
            dist.all_reduce(
                avg_score, op=dist.ReduceOp.SUM, group=constants.data_parallel_group()
            )
            avg_score /= dist.get_world_size(group=constants.data_parallel_group())
            avg_score = avg_score.item()
            minimal_score = (-1 - self.output_bias) * self.output_scaling

            if avg_score <= minimal_score or np.isclose(avg_score, minimal_score):
                raise VerifierException(
                    "All rewards are at the minimal value. "
                    "Something is probably wrong with the verifier!"
                )
        return res

    def log_rewards_to_file(
        self, task_type: str, model: model_api.Model, prompt_strs, seq_strs, scores
    ):
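        """Dump prompts, generations, and rewards to per-task .txt and .jsonl logs."""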
        tik = time.perf_counter()
        gen_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "generated",
            task_type,
            f"v{model.version.global_step}r{dist.get_rank()}.txt",
        )

        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
        with open(gen_file_path, "w") as _f:
            for idx, (score, prompt_str, seq_str) in enumerate(
                zip(scores, prompt_strs, seq_strs)
            ):
                info = "\n".join(
                    [
                        f"idx: {idx} / {len(scores)}",
                        f"reward is {score.item()}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt_str}{colorama.Style.RESET_ALL}",
                        f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{seq_str.split(prompt_str)[1]}{colorama.Style.RESET_ALL}.",
                    ]
                )
                _f.write(info + "\n")

        gen_file_path = os.path.join(
            constants.LOG_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
            "generated_jsonl",
            task_type,
            f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
        )
        os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
        with open(gen_file_path, "w") as _f:
            for idx, (score, prompt_str, seq_str) in enumerate(
                zip(scores, prompt_strs, seq_strs)
            ):
                _f.write(
                    json.dumps(
                        {
                            "prompt": prompt_str,
                            "generated": seq_str.split(prompt_str)[1],
                            "reward": score.item(),
                        },
                        ensure_ascii=False,
                    )
                    + "\n"
                )

        logger.info(f"[{task_type}] number of samples: {len(scores)}, {scores.shape}")
        logger.info(f"[{task_type}] avg reward: {sum(scores) / len(scores)}")
        logger.info(f"[{task_type}] log to file time: {time.perf_counter() - tik:.2f}s")

    def inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
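        """Entry point: shard by TP/PP rank, verify each task concurrently, regather.

        Per-task verification runs in worker threads via `asyncio.to_thread`,
        so slow verifier backends overlap instead of running sequentially.
        """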
        input_ = data
        data, backward_indices = self._dispatch_tp_and_pp(data)
        task_data, dispatch_indices = self._dispatch_tasks(data)

        assert self.rw_type == "sparse"

        def _task_func(func, task_type: str):
            def _wrapped_func(*args, **kwargs):
                start_time = time.perf_counter()
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    raise asyncio.CancelledError(
                        f"[{task_type}] task failed: {e}"
                    ) from e
                finally:
                    duration = time.perf_counter() - start_time
                    logger.info(f"[{task_type}] time cost: {duration:.4f}s")
                return task_type, result

            return _wrapped_func

        async def _run_tasks():
            tasks = []
            for task_type, d in task_data.items():
                task_func = _task_func(self.calculate_task_reward, task_type)
                task_args = (model, d, mb_spec, task_type)
                task = asyncio.create_task(asyncio.to_thread(task_func, *task_args))
                tasks.append(task)

            results = await asyncio.gather(*tasks)
            task_results = {}
            for res in results:
                task_type, result = res
                task_results[task_type] = result

            return task_results

        task_results = asyncio.run(_run_tasks())
        final_result = self._gather_tasks(task_results, dispatch_indices, data.bs)
        final_result = self._gather_tp_and_pp(input_, final_result, backward_indices)

        model.inc_version()

        return final_result

    def _mock_inference(
        self,
        model: model_api.Model,
        data: SequenceSample,
    ) -> SequenceSample:
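        """Append canned answers to each prompt to produce mock rollouts for tests."""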
        prompt_lens = flat2d(data.seqlens["packed_prompts"])
        task_ids = data.data["task_ids"].cpu().numpy().tolist()
        seqlens = []
        offset = 0
        seq = []
        for plen, task_id in zip(prompt_lens, task_ids):
            seq += [data.data["packed_prompts"][offset : offset + plen]]
            offset += plen
            if task_id == RL_TASKS.index("math"):
                answer_str = (
                    "something unimportant but the answer is \\boxed{-\\frac{2}{3}}."
                )
            elif task_id == RL_TASKS.index("code"):
                answer_str = (
                    "```python\ninput()\nimport time\ntime.sleep(1e-3)\nprint(1)\n```"
                )
            else:
                answer_str = "something unimportant"
            encoding = model.tokenizer(
                [answer_str], add_special_tokens=True, return_attention_mask=False
            )

            ans = torch.tensor(encoding["input_ids"], dtype=torch.long).flatten()
            seq += [ans]
            seqlens.append(plen + len(ans))

        x = SequenceSample.from_default(
            seqlens=seqlens,
            ids=data.ids,
            data=dict(packed_input_ids=torch.cat(seq)),
        )
        data.update_(x)
        return data


model_api.register_interface("rw-math-code", MultiTaskRewardInterface)
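# Experiment configurations can refer to this interface by the registered name
# "rw-math-code".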