# Rollout and Agentic RL
This guide demonstrates how to customize rollout behavior for PPO training by
implementing a multi-turn math agent with end-to-end reinforcement learning. The example
agent keeps trying to solve a math problem until it reaches the correct answer.
You can find the complete implementation in `arealite/workflow/multi_turn.py`.
## Approach: Using AReaLite (Recommended)
### Step 1: Define Your Workflow
AReaLite takes a flexible approach to agent definition. Rather than using rigid `Agent`
classes that might limit your agent's capabilities, AReaLite captures all rollout
behavior in a `RolloutWorkflow` class. This design gives you complete freedom to
customize your agent's behavior.
```python
# arealite/api/workflow_api.py
class RolloutWorkflow:
    async def arun_episode(self, engine: InferenceEngine, data):
        """Run one rollout episode and return the collected trajectories."""
        raise NotImplementedError()
```
The workflow exposes a single `arun_episode` method that runs one episode and collects
its data. This method takes two key arguments:
1. **InferenceEngine**: Provides the `agenerate` method for generating responses to user
inputs
2. **Prompt data**: The data for a single prompt, typically a dictionary loaded from
   your dataset (for example, OpenAI-compatible messages stored under a "messages" key)
Within this method, you have complete control over how your agent and environment
interact.
> **Note**: Each `arun_episode` call takes a single prompt and outputs the trajectories
> generated from that prompt—it's not batched. However, you can generate multiple
> trajectories from a single prompt (for example, with GRPO or tree search).
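
For example, to produce several trajectories per prompt (as in GRPO), you can launch them
concurrently inside `arun_episode` and batch the results. Below is a minimal sketch, not
the AReaLite implementation: `GroupedWorkflow`, `self.n_samples`, and `self._rollout_one`
are hypothetical names, and imports of AReaLite symbols are omitted.

```python
import asyncio

class GroupedWorkflow(RolloutWorkflow):
    async def arun_episode(self, engine, data):
        # Roll out several independent trajectories for the same prompt
        results = await asyncio.gather(
            *(self._rollout_one(engine, data) for _ in range(self.n_samples))
        )
        # Each result is a padded TensorDict with batch size 1; concatenating
        # them yields `n_samples` trajectories for this single prompt.
        return concat_padded_tensors(results)
```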
#### Setting Up the Multi-Turn Math Workflow
Let's build a multi-turn rollout workflow for solving math problems. First, we'll define
the `__init__` method to capture the utilities we need during rollout:
> **Note**: You have complete flexibility in defining the `__init__` method. Pass any
> arguments needed to construct your workflow. If you want to use tools, pass the
> corresponding environment here so your agent can call it in the `arun_episode` method.
```python
class MultiTurnWorkflow(RolloutWorkflow):
    def __init__(self, reward_fn, gconfig, tokenizer, max_turns, turn_discount):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        self.turn_discount = turn_discount
```
#### Implementing the Episode Logic
Now let's implement the `arun_episode` method. We'll start by tokenizing the prompt data
and converting it into an `LLMRequest` object for the inference engine:
```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... __init__ as above ...
    async def arun_episode(self, engine: InferenceEngine, data):
        # Tokenize the OpenAI-compatible messages into input token IDs
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"], add_generation_prompt=True, tokenize=True
        )
        # Wrap the tokens into an LLMRequest and query the inference engine
        # (see arealite/workflow/multi_turn.py for the full request fields)
        req = LLMRequest(rid=uuid.uuid4().hex, input_ids=input_ids, gconfig=self.gconfig)
        resp = await engine.agenerate(req)
        # ... continue processing ...
```
> **Note**: This example accesses the "messages" key from the prompt data to get
> OpenAI-compatible messages. This isn't mandatory—the key and prompt format depend
> entirely on your implementation. For instance, if your dataset stores prompt strings
> in a "prompt" column, you could get input token IDs with
> in a "prompt" column, you could get input token IDs via
> `self.tokenizer.encode(data["prompt"])`.
> **Note**: The `rid` field in `LLMRequest` is the request ID. Requests with the same ID
> will reuse the LLM inference server's KV caches for efficiency.
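
In a multi-turn episode you can therefore keep a single `rid` for the whole episode so
that successive turns reuse the cached prefix. The sketch below illustrates the idea;
`generate_turns` is a hypothetical helper, and the `input_ids`/`gconfig` field names of
`LLMRequest` are assumptions (only `rid` is described above).

```python
import uuid

async def generate_turns(engine, tokenizer, messages, gconfig, max_turns):
    # One request ID per episode so the inference server can reuse the KV
    # cache of the shared conversation prefix across turns.
    rid = uuid.uuid4().hex
    token_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True
    )
    for _ in range(max_turns):
        req = LLMRequest(rid=rid, input_ids=token_ids, gconfig=gconfig)
        resp = await engine.agenerate(req)
        token_ids = token_ids + resp.output_tokens
        # ... score the answer, append feedback tokens, or stop ...
```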
#### Handling Multi-Turn Conversations
Next, we'll evaluate whether the current answer is correct using our `reward_fn`. This
function should return 1 for correct answers and 0 otherwise. When the answer is wrong,
we'll apply a discount, add feedback to the conversation, and let the model try again:
```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... inside arun_episode ...
    while reward == 0 and t < self.max_turns:
        # ... generate a completion and score it with self.reward_fn ...
        if reward == 0:
            # Wrong answer: append feedback as a new user turn and retry
            t += 1
            discount *= self.turn_discount
```
#### Reward Function Signature
To make it easy to switch between different reward functions, we recommend following
this signature:
```python
def reward_fn(
    # ... see the full parameter list and docstring in the AReaLite codebase ...
):
    """Return 1 for a correct answer and 0 otherwise."""
```
While this signature is convenient, there are no strict restrictions on reward functions
in custom workflows—modify them as needed for your specific use case.
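
For instance, a GSM8K-style reward function in this spirit might look like the sketch
below. The parameter list and the `answer` keyword are illustrative assumptions rather
than the exact signature of the `gsm8k_reward_fn` referenced later in this guide.

```python
import re
from typing import List

def gsm8k_reward_fn(
    prompt: str,
    completions: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    answer: str = "",  # reference answer passed through from the dataset (assumed)
    **kwargs,
) -> float:
    """Return 1.0 if the last number in the completion matches the reference answer."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completions.replace(",", ""))
    if not numbers:
        return 0.0
    return float(numbers[-1] == str(answer).strip())
```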
#### Collecting Training Data
Finally, let's complete the implementation by collecting trajectories in the
`TensorDict` format:
```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... previous methods ...
    async def arun_episode(self, engine: InferenceEngine, data):
        # ... episode logic above ...
        while reward == 0 and t < self.max_turns:
            # ... generation and evaluation ...
            # Collect trajectory data
            input_len = len(resp.input_tokens) - len(seq)
            seq += resp.input_tokens[-input_len:] + resp.output_tokens
            logprobs += [0.0] * input_len + resp.output_logprobs
            loss_mask += [0] * input_len + [1] * resp.output_len
            versions += [-1] * input_len + resp.output_versions

        # Package results
        res = dict(
            input_ids=torch.tensor(seq),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask),
            versions=torch.tensor(versions),
            rewards=torch.tensor(float(reward * discount)),
            attention_mask=torch.ones(len(seq), dtype=torch.bool),
        )
        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return concat_padded_tensors([res])
```
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
> to automatically batch multiple trajectories for the training engine. Since this
> example returns a single trajectory, we use `unsqueeze(0)` to create a size-1 batch.
> **Note**: There are no restrictions on the keys in your `TensorDict`—different
> algorithms require different keys. This example targets the GRPO algorithm, so we
> include `input_ids`, `loss_mask`, `attention_mask`, and `logprobs` (needed for
> computing importance ratios).
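
To make the shape convention concrete, here is a small illustration (plain PyTorch, not
AReaLite code) of right-padding two variable-length trajectories into a
`[batch_size, sequence_length]` batch; `concat_padded_tensors` handles this kind of
batching for you.

```python
import torch
import torch.nn.functional as F

# Two single-trajectory batches with different sequence lengths
seq_a = torch.randint(0, 1000, (1, 7))  # shape [1, 7]
seq_b = torch.randint(0, 1000, (1, 5))  # shape [1, 5]

# Right-pad to the longest sequence, then stack along the batch dimension
max_len = max(seq_a.shape[1], seq_b.shape[1])
batch = torch.cat(
    [
        F.pad(seq_a, (0, max_len - seq_a.shape[1])),
        F.pad(seq_b, (0, max_len - seq_b.shape[1])),
    ],
    dim=0,
)
print(batch.shape)  # torch.Size([2, 7])
```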
### Step 2: Training with Your Custom Workflow
Using your custom workflow is straightforward—just construct it in your training script
and pass it to the `rollout_batch` or `prepare_batch` method:
```python
def main(args):
    # ... setup code ...

    # Create your custom workflow
    workflow = MultiTurnWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        turn_discount=0.9,
        max_turns=5,
    )

    # Run training—no other changes needed!
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)

        # ... continue with training loop ...
```
That's it! Your custom multi-turn math agent is now ready to train with reinforcement
learning. The workflow will automatically handle the multi-turn conversations, reward
computation, and data collection needed for effective RL training.
## Alternative Approach: Using the Legacy Version (Not Recommended)
While we strongly recommend using AReaLite for new projects, you might encounter legacy
code that uses the older Agent-based approach. Here's how it works for reference, though
we suggest migrating to the workflow-based system when possible.
### Step 1: Define Your Agent Class
Create a new file under `realhf/impl/agent/`, such as `math_multi_turn_agent.py`. Your
`Agent` must implement the interface defined in `realhf/api/core/agent.py`, which
requires a single method: `collect_trajectory`.
```python
class MathMultiTurnAgent(Agent):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,  # aka sampling_params
        tokenizer: PreTrainedTokenizerFast,
        max_turns: int,
        turn_discount: float,
    ):
        # Store the utilities needed during trajectory collection
        ...

    async def collect_trajectory(
        self,
        prompt,
        env: EnvironmentService,
        obs_queue: asyncio.Queue,
        act_queue: asyncio.Queue,
    ):
        # Implementation goes here
        ...
```
### Step 2: Implement the Trajectory Collection Logic
The `collect_trajectory` method takes a task prompt, an environment, and two
communication queues. Within this method, you control the data flow between your agent
and the inference engine using these queues:
- **obs_queue**: Send observations (token IDs and generation config) to the inference
engine
- **act_queue**: Receive actions (generated responses) from the inference engine
Here's how the multi-turn conversation works:
```python
for turn in range(self.num_turns):
    # Send the current state to the inference engine
    await obs_queue.put((qid, token_ids, self.gconfig))

    # Get the generated response
    act: BundledGenerationOutputs = await act_queue.get()

    # Evaluate the response through the environment
    success, rewards = await env.step((qid, answers))

    # ... process results ...
```
#### Environment Integration
The environment follows a
[Gym-like interface](https://github.com/Farama-Foundation/Gymnasium) with `reset` and
`step` methods, but uses asynchronous implementations to prevent blocking across
different environment instances.
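
The rough shape of that interface is sketched below, based on the description above and
the return tuple used in the next example; it is not the exact base-class definition.

```python
class MyMathEnv(EnvironmentService):
    async def reset(self, *args, **kwargs):
        # Prepare the environment for a new episode
        ...

    async def step(self, action):
        # Returns a Gymnasium-style tuple:
        # (observation, reward, terminated, truncated, info)
        ...
```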
For math problems, the environment is typically stateless and acts as a wrapper around
your reward function:
```python
class MathCodeSingleStepEnv(EnvironmentService):
    async def step(self, action: Tuple[str, List[str]]):
        qid, answers = action
        # ... setup code ...

        # Run reward computation asynchronously
        format_rewards = await asyncio.to_thread(
            math_verify_call,
            answers,
            # ... other parameters ...
        )
        return None, format_rewards, True, False, {}
```
#### Handling Multi-Turn Feedback
After receiving the reward from `env.step`, check if the answer is correct. If not,
provide feedback and continue to the next turn:
```python
for turn in range(self.num_turns):
    # ... generation and evaluation code ...

    # Provide feedback based on the result
    if success[0]:
        feedback = "Congratulations! You are correct!"
    else:
        feedback = "Unfortunately your answer is wrong. Let's try again."

    # Format feedback as a user message
    feedback = "\n" + self.tokenizer.apply_chat_template(
        [{"content": feedback, "role": "user"}],
        add_generation_prompt=True,
        tokenize=False,
    )

    # Add feedback tokens to the conversation
    feedback_tokens = self.tokenizer(feedback)["input_ids"]
    token_ids.extend(feedback_tokens)
```
### Step 3: Register and Configure Your Agent
First, register your agent implementation:
```python
# in realhf/impl/agent/__init__.py
import realhf.impl.agent.math_multi_turn_agent
```
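
The agent module itself is expected to register the class under the name that the
experiment configuration refers to ("math-multi-turn" below). A hypothetical sketch of
that registration (the actual helper name in `realhf` may differ):

```python
# in realhf/impl/agent/math_multi_turn_agent.py (illustrative)
register_agent("math-multi-turn", MathMultiTurnAgent)
```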
Then update your experiment configuration in
`realhf/experiments/async_exp/async_math_ppo.py`:
```python
@dataclasses.dataclass
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
    # Add any new CLI arguments your agent needs
    my_param: float = 1.0

    @property
    def agent(self) -> AgentAbstraction:
        return AgentAbstraction(
            "math-multi-turn",  # Your registered agent name
            args=dict(
                # Pass any arguments needed for your __init__ method
                my_param=self.my_param,
                # ... other configuration ...
            ),
        )

    @property
    def env(self) -> EnvServiceAbstraction:
        # Update to use your custom environment if needed
        return EnvServiceAbstraction(
            "math-code-single-step",
            args=dict(dataset_path=self.dataset.path),
        )
```
### Step 4: Run Training
Follow the standard training procedure outlined in the
[quickstart guide](../tutorial/quickstart.md). Launch your experiment with:
```bash
python3 training/main_async_ppo.py my_param=5.0 # plus any additional CLI arguments
```
### Training Results
Here's an example of the training reward curve from our multi-turn math agent:
![Multi-turn Training Rewards](multiturn_reward.png)
The agent successfully learns to solve math problems with improved accuracy over time,
demonstrating the effectiveness of the multi-turn approach.
______________________________________________________________________
**Note**: While this legacy approach works, we strongly recommend using the AReaLite
workflow system for new projects. It provides better flexibility, cleaner abstractions,
and easier maintenance. Consider migrating existing legacy agents to the workflow-based
approach when possible.
Happy coding!