[Feature] Add multi-turn agent example (#72)

* update benchmark script

* .

* add benchmark docs

* PullRequest: 178 multi turn math agent training

Merge branch gjx/multi-turn-math of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/178?tab=diff

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* multi turn math agent training
* training data logging and clean math multi-turn exp
* fix
* .

* format

* .

---------

Co-authored-by: 步偶 <sam.gjx@antgroup.com>
Wei Fu 2025-06-08 20:36:52 +08:00 committed by GitHub
parent 448bb05a3d
commit 54daceb6dc
2 changed files with 300 additions and 0 deletions

realhf/impl/agent/__init__.py

@@ -1,2 +1,3 @@
import realhf.impl.agent.math_multi_turn_agent
import realhf.impl.agent.math_single_step_agent
import realhf.impl.agent.null_agent

realhf/impl/agent/math_multi_turn_agent.py

@@ -0,0 +1,299 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import json
import os
from datetime import datetime
from typing import List

import colorama
import numpy as np
import torch

from realhf.api.core.agent_api import Agent, register_agent
from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
from realhf.api.core.env_api import EnvironmentService
from realhf.api.core.model_api import BundledGenerationOutputs
from realhf.base import constants, logging

logger = logging.getLogger("Math Code Agent")


class MathMultiTurnAgent(Agent):
"""A multi-turn reasoning agent for mathematical tasks.
In each turn the agent produces an answer and receives evaluation results from the environment.
By default, we use 4 turns with a token budget=1K at each round.
"""
    def __init__(
        self,
        gconfig,
        tokenizer_path,
        reward_scaling=1.0,
        reward_bias=0.0,
        turn_level_discount: float = 1.0,
        num_turns: int = 5,
    ):
        self.gconfig = gconfig.new(n=1)
        self.tokenizer = load_hf_tokenizer(tokenizer_path)
        self.reward_scaling = reward_scaling
        self.reward_bias = reward_bias
        self.turn_level_discount = turn_level_discount
        self.num_turns = num_turns

    async def collect_trajectory(
        self,
        prompt: SequenceSample,
        env: EnvironmentService,
        obs_queue: asyncio.Queue,
        act_queue: asyncio.Queue,
    ) -> List[SequenceSample]:
        # reset does nothing, just to make it like multi-step environments
        await env.reset()
        assert prompt.bs == 1
        assert self.gconfig.n == 1
        prompt_token_ids = prompt.data["packed_prompts"].cpu().numpy().tolist()
        qid = prompt.ids[0]
        birth_time = int(datetime.now().timestamp() * 1000)
        prompt_str = self.tokenizer.batch_decode(
            [prompt_token_ids],
            clean_up_tokenization_spaces=False,
            skip_special_tokens=True,
        )[0]
        token_ids = prompt_token_ids
        all_rewards = []
        all_answers = []
        all_success = []
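        # Assemble a single flat SequenceSample for the whole trajectory:
        # per-turn tensors are appended into `data`, and `seqlens` records
        # how many tokens each turn contributes to the packed tensors.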
        x = dict(
            keys=[
                "packed_input_ids",
                "prompt_mask",
                "packed_logprobs",
                "seq_no_eos_mask",
                "packed_prompts",
                "version_start",
                "version_end",
                "rewards",
                "birth_time",
            ],
            ids=[qid],
            dtypes=dict(
                packed_prompts=torch.long,
                packed_input_ids=torch.long,
                prompt_mask=torch.bool,
                seq_no_eos_mask=torch.bool,
                version_start=torch.int,
                version_end=torch.int,
                packed_logprobs=torch.float32,
                rewards=torch.float32,
                birth_time=torch.long,
            ),
            trailing_shapes=dict(
                packed_input_ids=(),
                prompt_mask=(),
                seq_no_eos_mask=(),
                packed_prompts=(),
                version_end=(),
                version_start=(),
                packed_logprobs=(),
                rewards=(),
                birth_time=(),
            ),
            seqlens=dict(
                packed_input_ids=[[]],
                packed_logprobs=[[]],
                packed_prompts=[[len(prompt_token_ids)]],
                prompt_mask=[[]],
                seq_no_eos_mask=[[1 for _ in range(self.num_turns)]],
                rewards=[[1 for _ in range(self.num_turns)]],
                version_start=[[1 for _ in range(self.num_turns)]],
                version_end=[[1 for _ in range(self.num_turns)]],
                birth_time=[[1]],
            ),
            data=dict(
                packed_prompts=list(prompt_token_ids),
                packed_logprobs=[],
                packed_input_ids=[],
                seq_no_eos_mask=[],
                rewards=[],
                version_start=[],
                version_end=[],
                birth_time=torch.tensor([birth_time], dtype=torch.long),
                prompt_mask=[],
            ),
        )
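        # Interaction loop: submit the current token prefix for generation,
        # wait for the model's output, score it with the environment, and
        # append evaluator feedback so the next turn sees the full history.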
        for turn in range(self.num_turns):
            await obs_queue.put((qid, token_ids, self.gconfig))
            act: BundledGenerationOutputs = await act_queue.get()
            seq_strs = self.tokenizer.batch_decode(
                act.seqs,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=True,
            )
            prompt_str = self.tokenizer.batch_decode(
                [act.prompt_ids],
                clean_up_tokenization_spaces=False,
                skip_special_tokens=True,
            )[0]
            answers = [seq_str.split(prompt_str)[1] for seq_str in seq_strs]
            # single-step env for evaluating generated solutions
            _, success, *_ = await env.step((qid, answers))
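            # Map the binary success flag to a symmetric reward in {-1, +1},
            # then apply the configured bias and scaling.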
            rewards = [
                ((float(r) - 0.5) * 2 - self.reward_bias) * self.reward_scaling
                for r in success
            ]
            all_success.extend(success)
            all_answers.extend(answers)
            x["data"]["packed_input_ids"].extend(list(act.seqs[0]))
            x["data"]["packed_logprobs"].extend(list(act.logprobs[0]))
            x["data"]["seq_no_eos_mask"].append(act.no_eos[0])
            all_rewards.append(rewards[0])
            x["data"]["prompt_mask"].extend(
                [1] * act.prompt_len + [0] * (act.seqlens[0] - act.prompt_len)
            )
            x["data"]["version_start"].extend(list(act.version_start))
            x["data"]["version_end"].extend(list(act.version_end))
            x["seqlens"]["packed_input_ids"][0].append(act.seqlens[0])
            x["seqlens"]["packed_logprobs"][0].append(act.seqlens[0] - 1)
            x["seqlens"]["prompt_mask"][0].append(act.seqlens[0])
            token_ids = list(act.seqs[0])
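            # Build the next-turn prompt: re-encode the evaluator's verdict as
            # a user message via the chat template and append it to the tokens.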
            feedback = None
            if success[0]:
                feedback = "Congratulations! You are correct!"
            else:
                feedback = "Unfortunately your answer is wrong. Let's try again."
            feedback = "\n" + self.tokenizer.apply_chat_template(
                [dict(content=feedback, role="user")],
                add_generation_prompt=True,
                tokenize=False,
            )
            feedback = self.tokenizer(feedback)["input_ids"]
            token_ids.extend(feedback)
        self.log_rewards_to_file(
            str(qid),
            prompt_str,
            seqlens=x["seqlens"]["packed_input_ids"][0],
            answers=all_answers,
            prompt_len=len(prompt_token_ids),
            rewards=all_rewards,
            success=all_success,
            version_starts=x["data"]["version_start"],
            version_ends=x["data"]["version_end"],
        )
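        # Propagate credit backwards through the turns: each turn's reward also
        # receives the (discounted) rewards of all later turns.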
        for i in reversed(range(len(all_rewards) - 1)):
            all_rewards[i] = (
                all_rewards[i] + all_rewards[i + 1] * self.turn_level_discount
            )
        x["data"]["rewards"] = all_rewards
for k in x["keys"]:
if not isinstance(x["data"][k], torch.Tensor):
x["data"][k] = torch.tensor(x["data"][k], dtype=x["dtypes"][k])
x = SequenceSample(**x)
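        # Carry task_ids over from the prompt sample, if present.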
if "task_ids" in prompt.keys:
y = SequenceSample(
keys=["task_ids"],
ids=[qid],
dtypes=dict(task_ids=torch.long),
trailing_shapes=dict(task_ids=()),
seqlens=dict(task_ids=[[1]]),
data=dict(task_ids=prompt.data["task_ids"]),
)
x.update_(y)
return [x]

    def log_rewards_to_file(
        self,
        qid: str,
        prompt: str,
        prompt_len: int,
        answers: List[str],
        seqlens: List[int],
        rewards: List[float],
        success: List[bool],
        version_starts: List[int],
        version_ends: List[int],
    ):
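        # Two log streams per model version: a human-readable text dump of the
        # prompt and answers, and a JSONL file for automated training monitoring.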
        group_size = len(answers)
        for group_idx in range(group_size):
            # NOTE: we can ensure that only one process is logging this query id
            gen_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "generated",
                str(version_starts[group_idx]),
                f"{qid}.txt",
            )
            os.makedirs(os.path.dirname(gen_file_path), exist_ok=True)
            version_start = version_starts[group_idx]
            version_end = version_ends[group_idx]
            reward = rewards[group_idx]
            answer = answers[group_idx]
            seqlen = seqlens[group_idx]
            with open(gen_file_path, "a") as _f:
                info = "\n".join(
                    [
                        f"idx: {group_idx + 1} / {group_size}, seqlen: {seqlen}, "
                        f"head version: {version_start}, tail version: {version_end}.",
                        f"reward is {reward}, prompt is {colorama.Fore.YELLOW + colorama.Style.DIM}{prompt}{colorama.Style.RESET_ALL}",
                        f"sequence is: {colorama.Fore.YELLOW + colorama.Style.DIM}{answer}{colorama.Style.RESET_ALL}.",
                    ]
                )
                _f.write(info + "\n")
            train_pass_monitor_file_path = os.path.join(
                constants.LOG_ROOT,
                constants.experiment_name(),
                constants.trial_name(),
                "training_monitor",
                str(version_starts[group_idx]),
                f"{qid}.jsonl",
            )
            os.makedirs(os.path.dirname(train_pass_monitor_file_path), exist_ok=True)
            with open(train_pass_monitor_file_path, "a") as monitor_file:
                monitor_file.write(
                    json.dumps(
                        {
                            "version_start": int(version_start),
                            "version_end": int(version_end),
                            # log the per-turn success flag, not the whole list
                            "success": bool(success[group_idx]),
                            "prompt_len": prompt_len,
                            "answer_len": seqlen - prompt_len,
                        },
                        ensure_ascii=False,
                    )
                    + "\n"
                )
register_agent("math-multi-turn", MathMultiTurnAgent)