# Rollout and Agentic RL

This guide shows you how to customize rollout behavior for RL training by building a
multi-turn math agent with **AReaLite**. The example agent keeps trying to solve a math
problem until it reaches the correct answer.

You can find the complete implementation in `arealite/workflow/multi_turn.py`.

## Approach: Using AReaLite (Recommended)

### Step 1: Define Your Workflow

AReaLite takes a flexible approach to agent definition. Rather than using rigid `Agent`
classes that might limit your agentic capabilities, AReaLite captures all rollout
behavior in a `RolloutWorkflow` class. This design gives you complete freedom to
customize your agent's behavior.

```python
# arealite/api/workflow_api.py
class RolloutWorkflow:

    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
        """Run a single episode and collect its trajectory data (abridged)."""
        raise NotImplementedError()
```

The workflow exposes a single `arun_episode` method that runs and collects data from a
single episode. This method takes two key arguments:

1. **InferenceEngine**: Provides the `agenerate` method for generating responses to user
   inputs
2. **data**: The prompt data for this episode, loaded from your dataset

Within this method, you have complete control over how your agent and environment
interact.

> **Note**: Each `arun_episode` call takes a single prompt and outputs the trajectories
> generated from that prompt—it's not batched. However, you can generate multiple
> trajectories from a single prompt (for example, with GRPO or tree search), as sketched
> below.
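
For example, a GRPO-style workflow could run its single-trajectory logic several times
for the same prompt and batch the results. A minimal sketch, where `GroupedWorkflow`,
`n_samples`, and `_run_one_trajectory` are illustrative names rather than part of the
AReaLite API:

```python
import asyncio


class GroupedWorkflow(RolloutWorkflow):
    async def arun_episode(self, engine: InferenceEngine, data):
        # Run the (hypothetical) per-trajectory helper n_samples times in parallel.
        trajs = await asyncio.gather(
            *(self._run_one_trajectory(engine, data) for _ in range(self.n_samples))
        )
        # Each trajectory is a batch-size-1 padded TensorDict; batch them together.
        return concat_padded_tensors(list(trajs))
```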

#### Setting Up the Multi-Turn Math Workflow

Let's build a multi-turn rollout workflow for solving math problems. First, we'll define
the `__init__` method to capture the utilities we need during rollout:

> **Note**: You have complete flexibility in defining the `__init__` method. Pass any
> arguments needed to construct your workflow. If you want to use tools, pass the
> corresponding environment here so your agent can call it in the `arun_episode` method.

```python
class MultiTurnWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,  # aka sampling_params
        tokenizer: PreTrainedTokenizerFast,
        max_turns: int,
        turn_discount: float,
    ):
        # Store the utilities needed during rollout (abridged; see
        # arealite/workflow/multi_turn.py for the full version).
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        self.turn_discount = turn_discount
```

#### Implementing the Episode Logic

Now let's implement the `arun_episode` method. We'll start by tokenizing the prompt data
and converting it into an `LLMRequest` object for the inference engine:

```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... __init__ as above ...

    async def arun_episode(self, engine: InferenceEngine, data):
        # Tokenize the prompt messages and build an LLMRequest for the inference
        # engine (abridged; see arealite/workflow/multi_turn.py for the full version).
        ...
        # ... continue processing ...
```

> **Note**: This example accesses the "messages" key from the prompt data to get
> OpenAI-compatible messages. This isn't mandatory—the key and prompt format depend
> entirely on your implementation. For instance, if your dataset stores prompt strings
> in a "prompt" column, you could get input token IDs via
> `self.tokenizer.encode(data["prompt"])`.

> **Note**: The `rid` field in `LLMRequest` is the request ID. Requests with the same ID
> will reuse the LLM inference server's KV caches for efficiency.
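
Putting the two notes together, the first-turn request construction might look roughly
like the following. Everything except `rid`, `agenerate`, and the tokenizer call is an
illustrative assumption rather than AReaL's exact API:

```python
import uuid

# Sketch only: field names on LLMRequest are assumed for illustration.
input_ids = self.tokenizer.apply_chat_template(
    data["messages"], add_generation_prompt=True, tokenize=True
)
req = LLMRequest(rid=uuid.uuid4().hex, input_ids=input_ids, gconfig=self.gconfig)
resp = await engine.agenerate(req)
# Reuse the same `rid` on later turns so the server can reuse its KV cache.
```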

#### Handling Multi-Turn Conversations

Next, we'll evaluate whether the current answer is correct using our `reward_fn`. This
function should return 1 for correct answers and 0 otherwise. When the answer is wrong,
we'll apply a discount, add feedback to the conversation, and let the model try again:

```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... previous methods ...

    async def arun_episode(self, engine: InferenceEngine, data):
        # ... generation and reward evaluation shown above ...
        # When the answer is wrong, append feedback to the conversation and
        # discount the reward before the next attempt (abridged).
        ...
        discount *= self.turn_discount
```
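
The feedback step mirrors the legacy example later in this guide: wrap a correction
message with the chat template and append its tokens before the next attempt. A rough
sketch, with variable names chosen for illustration:

```python
# Sketch only: append a feedback turn, then retry with a discounted reward.
feedback = "Unfortunately your answer is wrong. Let's try again."
feedback_text = self.tokenizer.apply_chat_template(
    [{"content": feedback, "role": "user"}],
    add_generation_prompt=True,
    tokenize=False,
)
input_ids += self.tokenizer(feedback_text)["input_ids"]
discount *= self.turn_discount
```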

#### Reward Function Signature

For convenience when switching between different reward functions, we recommend
following this predefined signature:

```python
def reward_fn(
    # ... recommended arguments (abridged; see arealite/workflow/multi_turn.py) ...
):
    """Return 1 for a correct answer and 0 otherwise."""
```

While this signature is convenient, there are no strict restrictions on reward functions
in custom workflows—modify them as needed for your specific use case.
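
For concreteness, here is what a simple math reward function might look like. This is an
illustrative sketch rather than the `gsm8k_reward_fn` used later in this guide; the
argument names and the answer-extraction rule are assumptions:

```python
import re


def naive_math_reward_fn(completion: str, answer: str, **kwargs) -> float:
    """Reward 1.0 if the last number in the completion matches the reference answer."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion)
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == answer.strip() else 0.0
```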

#### Collecting Training Data

Finally, let's complete the implementation by collecting trajectories in the
`TensorDict` format:

```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... previous methods ...

    async def arun_episode(self, engine: InferenceEngine, data):
        # ... episode logic above ...

        while reward == 0 and t < self.max_turns:
            # ... generation and evaluation ...

            # Collect trajectory data
            input_len = len(resp.input_tokens) - len(seq)
            seq += resp.input_tokens[-input_len:] + resp.output_tokens
            logprobs += [0.0] * input_len + resp.output_logprobs
            loss_mask += [0] * input_len + [1] * resp.output_len
            versions += [-1] * input_len + resp.output_versions

        # Package results
        res = dict(
            input_ids=torch.tensor(seq),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask),
            versions=torch.tensor(versions),
            rewards=torch.tensor(float(reward * discount)),
            attention_mask=torch.ones(len(seq), dtype=torch.bool),
        )
        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return concat_padded_tensors([res])
```

> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
> to automatically batch multiple trajectories for the training engine. Since this
> example returns a single trajectory, we use `unsqueeze(0)` to create a size-1 batch.

> **Note**: There are no restrictions on the keys in your `TensorDict`—different
> algorithms require different keys. This example targets the GRPO algorithm, so we
> include `input_ids`, `loss_mask`, `attention_mask`, and `logprobs` (needed for
> computing importance ratios).
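
To illustrate the expected shapes: if one trajectory has 3 tokens and another has 5,
the batched `input_ids` tensor is padded to length 5 along the sequence dimension.
A minimal, library-agnostic sketch of that padding (`concat_padded_tensors` handles
this for you; the snippet below is purely illustrative):

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[1, 2, 3]])        # shape [1, 3]
b = torch.tensor([[4, 5, 6, 7, 8]])  # shape [1, 5]

max_len = max(a.shape[1], b.shape[1])
# Right-pad each sequence to the longest length, then stack into a [2, 5] batch.
batch_input_ids = torch.cat(
    [F.pad(a, (0, max_len - a.shape[1])), F.pad(b, (0, max_len - b.shape[1]))]
)
# The attention_mask marks which positions are real tokens versus padding.
```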

### Step 2: Training with Your Custom Workflow

Using your custom workflow is straightforward—just construct it in your training script
and pass it to the `rollout_batch` or `prepare_batch` method:

```python
def main(args):
    # ... setup code ...

    # Create your custom workflow
    workflow = MultiTurnWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        turn_discount=0.9,
        max_turns=5,
    )

    # Run training—no other changes needed!
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)
        # ... continue with training loop ...
```

That's it! Your custom multi-turn math agent is now ready to train with reinforcement
learning. The workflow will automatically handle the multi-turn conversations, reward
computation, and data collection needed for effective RL training.

## Alternative Approach: Using the Legacy Version (Not Recommended)

While we strongly recommend using AReaLite for new projects, you might encounter legacy
code that uses the older Agent-based approach. Here's how it works for reference, though
we suggest migrating to the workflow-based system when possible.

### Step 1: Define Your Agent Class

Create a new file under `realhf/impl/agent/`, such as `math_multi_turn_agent.py`. Your
`Agent` must implement the interface defined in `realhf/api/core/agent.py`, which
requires a single method: `collect_trajectory`.

```python
class MathMultiTurnAgent(Agent):
    async def collect_trajectory(
        self,
        prompt,
        env,
        obs_queue,
        act_queue,
    ):
        # Signature abridged; see realhf/api/core/agent.py for the full interface.
        # Implementation goes here
        ...
```

### Step 2: Implement the Trajectory Collection Logic

The `collect_trajectory` method takes a task prompt, an environment, and two
communication queues. Within this method, you control the data flow between your agent
and the inference engine using these queues:

- **obs_queue**: Send observations (token IDs and generation config) to the inference
  engine
- **act_queue**: Receive actions (generated responses) from the inference engine

Here's how the multi-turn conversation works:

```python
for turn in range(self.num_turns):
    # Send the current state to the inference engine
    await obs_queue.put((qid, token_ids, self.gconfig))

    # Get the generated response
    act: BundledGenerationOutputs = await act_queue.get()

    # Evaluate the response through the environment
    success, rewards = await env.step((qid, answers))
    # ... process results ...
```

#### Environment Integration

The environment follows a
[Gym-like interface](https://github.com/Farama-Foundation/Gymnasium) with `reset` and
`step` methods, but uses asynchronous implementations to prevent blocking across
different environment instances.

For math problems, the environment is typically stateless and acts as a wrapper around
your reward function:

```python
class MathCodeSingleStepEnv(EnvironmentService):
    async def step(self, action: Tuple[str, List[str]]):
        qid, answers = action
        # ... setup code ...

        # Run reward computation asynchronously
        format_rewards = await asyncio.to_thread(
            math_verify_call,
            answers,
            # ... other parameters ...
        )
        return None, format_rewards, True, False, {}
```

#### Handling Multi-Turn Feedback

After receiving the reward from `env.step`, check if the answer is correct. If not,
provide feedback and continue to the next turn:

```python
for turn in range(self.num_turns):
    # ... generation and evaluation code ...

    # Provide feedback based on the result
    if success[0]:
        feedback = "Congratulations! You are correct!"
    else:
        feedback = "Unfortunately your answer is wrong. Let's try again."

    # Format feedback as a user message
    feedback = "\n" + self.tokenizer.apply_chat_template(
        [{"content": feedback, "role": "user"}],
        add_generation_prompt=True,
        tokenize=False,
    )

    # Add feedback tokens to the conversation
    feedback_tokens = self.tokenizer(feedback)["input_ids"]
    token_ids.extend(feedback_tokens)
```

### Step 3: Register and Configure Your Agent

First, register your agent implementation:

```python
# in realhf/impl/agent/__init__.py
import realhf.impl.agent.math_multi_turn_agent
```

Then update your experiment configuration in
`realhf/experiments/async_exp/async_math_ppo.py`:

```python
@dataclasses.dataclass
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
    # Add any new CLI arguments your agent needs
    my_param: float = 1.0

    @property
    def agent(self) -> AgentAbstraction:
        return AgentAbstraction(
            "math-multi-turn",  # Your registered agent name
            args=dict(
                # Pass any arguments needed for your __init__ method
                my_param=self.my_param,
                # ... other configuration ...
            ),
        )

    @property
    def env(self) -> EnvServiceAbstraction:
        # Update to use your custom environment if needed
        return EnvServiceAbstraction(
            "math-code-single-step",
            args=dict(dataset_path=self.dataset.path),
        )
```

### Step 4: Run Training

Follow the standard training procedure outlined in the
[quickstart guide](../tutorial/quickstart.md). Launch your experiment with:

```bash
python3 training/main_async_ppo.py my_param=5.0  # plus any additional CLI arguments
```

### Training Results

Here's an example of the training reward curve from our multi-turn math agent:



The agent successfully learns to solve math problems with improved accuracy over time,
demonstrating the effectiveness of the multi-turn approach.

______________________________________________________________________

**Note**: While this legacy approach works, we strongly recommend using the AReaLite
workflow system for new projects. It provides better flexibility, cleaner abstractions,
and easier maintenance. Consider migrating existing legacy agents to the workflow-based
approach when possible.

Happy coding!