mirror of https://github.com/inclusionAI/AReaL
add doc (#68)
This commit is contained in:
parent ab20f940f9
commit fabe59aad1

README.md | 13
@@ -3,7 +3,7 @@
</h1>

<p align="center">
| <a href="https://inclusionai.github.io/AReaL/"><b>Documentation</b></a> | <a href="https://deepwiki.com/inclusionAI/AReaL"><b>Ask DeepWiki</b></a> |
| <a href="https://arxiv.org/pdf/2505.24298"><b>Paper</b></a> | <a href="https://inclusionai.github.io/AReaL/"><b>Documentation</b></a> | <a href="https://deepwiki.com/inclusionAI/AReaL"><b>Ask DeepWiki</b></a> |
</p>

<img align="right" alt="ReaL" src="/assets/logo.png" width="20%">

@@ -115,6 +115,11 @@ python eval_and_aggregate.py \

+ [Installation](https://inclusionai.github.io/AReaL/tutorial/installation.html)
+ [Quickstart](https://inclusionai.github.io/AReaL/tutorial/quickstart.html)
+ [Code Walkthrough](https://inclusionai.github.io/AReaL/developer/overview.html)
+ **Customization Guide**
+ [Dataset](https://inclusionai.github.io/AReaL/customization/dataset.html)
+ [Rollout Behavior (Agentic RL)](https://inclusionai.github.io/AReaL/customization/agent.html)
+ [Training Algorithm](https://inclusionai.github.io/AReaL/customization/algorithm.html)
+ [Contributing](https://inclusionai.github.io/AReaL/contrib.html)

## Future Plan
@@ -122,15 +127,15 @@ AReaL is under active development. We will have major releases in a weekly manne

### System Development
- [x] Support for SGLang.
- [ ] RL training with coding problems.
- [ ] Asynchronous generation and RL training.
- [x] RL training with coding problems.
- [x] Asynchronous generation and RL training.
- [ ] Optimizations for distributed training: expert parallel and zero-bubble pipelining.
- [ ] RL for vision-language models (VLM).
- [ ] Function calling and agent capabilities.

### Algorithm Development
- [x] RL training recipes for 1.5B and 7B models.
- [ ] A complete RL training recipe for 32B models.
- [x] A complete RL training recipe for 32B models.
- [ ] Sample-efficient multi-task RL algorithms.
- [ ] Agentic capabilities with end-to-end RL.
- [ ] Stable RL training for larger MoE models.

@@ -14,6 +14,7 @@ parts:
  chapters:
  - file: customization/dataset
  - file: customization/agent
  - file: customization/algorithm
- caption: Code Walkthrough
  chapters:
  - file: developer/overview

@@ -1,6 +1,8 @@
# Agentic RL
# Rollout and Agentic RL

This guide provides an example of training a multi-turn math agent using end-to-end RL. The math agent will continuously attempt to think through and solve math problems until it reaches the correct answer.
This guide provides an example of modifying the rollout behavior for PPO training.

In particular, we implement a multi-turn math agent using end-to-end RL. The math agent will continuously attempt to think through and solve math problems until it reaches the correct answer.

## Define Your Agent

@@ -0,0 +1,305 @@
# Training Algorithm

An algorithm is encapsulated in a `ModelInterface`, which primarily defines three methods:

```python
# in realhf/api/core/model_api.py
class ModelInterface(abc.ABC):
    """An interface for model training, inference, and generation.

    This interface is designed to follow the dependency injection pattern.
    We pass the model to the interface and call its methods, ensuring that model APIs
    and algorithms are fully decoupled. For example, REINFORCE and PPO can exhibit
    different behaviors during training. Separate interfaces can be written for these
    algorithms while using the same model that provides basic forward-backward-update
    functionality (i.e., :class:`PipelinableEngine`).

    During runtime, the master worker requests model workers to execute a specific
    interface type (e.g., generate) on a specific model. The model worker locates
    the corresponding model, passes it into the requested interface, performs the
    computation, and returns the result.
    """

    def inference(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        raise NotImplementedError()

    def generate(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> SequenceSample | None:
        raise NotImplementedError()

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        raise NotImplementedError()
```

When the dataflow is fixed, it's usually sufficient to modify or add the file that defines the algorithm interface.
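
For a brand-new algorithm, such a file can start from a skeleton like the one below. This is only a minimal sketch: the import paths for `SequenceSample` and `MicroBatchSpec` are assumptions, so check `realhf/api/core` and how `PPOActorInterface` is wired up for the exact details.

```python
import dataclasses
from typing import Dict, List

# NOTE: the data-API import path below is an assumption; verify it in realhf/api/core.
from realhf.api.core.model_api import Model, ModelInterface
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample


@dataclasses.dataclass
class MyAlgorithmInterface(ModelInterface):
    # Hyperparameters become dataclass fields and can be filled in from the
    # experiment configuration (see the `args` dict of ModelInterfaceAbstraction below).
    my_coef: float = 1.0

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        # `model.module` is the wrapped engine providing forward/backward/update
        # (see how PPOActorInterface uses `module.train_batch` later in this guide).
        # Compute your loss here and return a dict of training statistics.
        raise NotImplementedError("Fill in your algorithm's update logic.")
```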

We provide two examples: (1) changing PPO's global advantage normalization to the grouped normalization used in GRPO, and (2) changing the original PPO loss to the decoupled PPO loss from AReaL's paper.

```{note}
We recommend using asynchronous RL, so that you can customize the generation behavior by [modifying your RL agent](agent.md) and don't need to modify the `generate` method of model interfaces.
```

## Grouped Advantage Normalization

The PPO algorithm is implemented in a single file, `ppo_interface.py`. The method we are going to modify is the `train_step` method of `PPOActorInterface`. PPO's global advantage normalization looks like:

```python
@dataclass
class PPOActorInterface(ModelInterface):
    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        ...
        if self.adv_norm:
            advantages = masked_normalization(advantages, loss_mask)
        ...
```
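
For reference, `masked_normalization` whitens the advantages using the mean and standard deviation computed only over unmasked tokens. A minimal single-process sketch (an illustration, not AReaL's code; the real implementation also supports all-reducing the statistics across data-parallel ranks, see the `all_reduce` flag used later) could look like this:

```python
import torch


def masked_normalization_sketch(
    x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    """Whiten ``x`` with statistics taken only over positions where ``mask`` is True."""
    mask = mask.to(x.dtype)
    n = mask.sum().clamp(min=1.0)
    mean = (x * mask).sum() / n
    var = ((x - mean) ** 2 * mask).sum() / n
    return (x - mean) * torch.rsqrt(var + eps)
```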

### An Additional Note on Data Management

Before the implementation, we need to explain how the data in each batch is organized.

Usually, each data batch (i.e., the `data` variable) contains multiple prompts; the number of prompts is the "batch size". Additionally, each prompt may have multiple corresponding answers; the number of answers per prompt is the "group size" (`group_size` in the code). Therefore, there are batch_size × group_size sequences in each batch.

These sequences have different lengths, but they are concatenated (or packed) together into a 1D tensor. Sequences that share the same prompt (a "group") sit next to each other, and the groups of different prompts follow one another. As in flash-attention, we use `cu_seqlens` to mark the boundary of each sequence: it is the cumulative sum of sequence lengths across the batch.

Each token in a sequence has a corresponding reward and advantage, so `advantages` is also a packed 1D tensor, just like the tokens (i.e., `packed_input_ids`). However, each "sequence" of advantages is one step shorter than its token sequence due to the auto-regressive nature of LLMs: the loss can only be computed on every token except the first one in each sequence.
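
As a concrete (made-up) illustration of this layout, with 2 prompts and a group size of 2:

```python
import torch

# Two prompts x two answers, packed in order; the lengths are made up.
seqlens = torch.tensor([5, 4, 7, 6])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))
# tensor([ 0,  5,  9, 16, 22]) -- token boundaries in packed_input_ids

# Advantages have one entry fewer per sequence, hence their own boundary tensor.
# (This is presumably what `short1cu_seqlens` refers to in the implementation below.)
short1cu_seqlens = torch.nn.functional.pad((seqlens - 1).cumsum(0), (1, 0))
# tensor([ 0,  4,  7, 13, 18])

group_size = 2
# Token span and advantage span of the first group (both answers of prompt 0):
s, e = cu_seqlens[0], cu_seqlens[0 + group_size]                      # 0, 9
adv_s, adv_e = short1cu_seqlens[0], short1cu_seqlens[0 + group_size]  # 0, 7
```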

### Implementation

For grouped advantage normalization, we need to partition the advantages into groups and run normalization within the tensor chunk of each group:

```diff
@dataclass
class PPOActorInterface(ModelInterface):
+    group_adv_norm: bool = False

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        ...
        if self.adv_norm:
-            advantages = masked_normalization(advantages, loss_mask)
+            if not self.group_adv_norm:
+                advantages = masked_normalization(advantages, loss_mask)
+            else:
+                n_samples = data.bs
+                adv_list = []
+                for i in range(0, n_samples, self.group_size):
+                    # Start and end of the chunk
+                    s = short1cu_seqlens[i]
+                    e = short1cu_seqlens[i + self.group_size]
+                    # Get advantages within each group of the same prompt
+                    adv = advantages[s: e]
+                    mask = loss_mask[s: e]
+                    # Run normalization
+                    advn = masked_normalization(adv, mask, all_reduce=False)
+                    adv_list.append(advn)
+                advantages = torch.cat(adv_list, 0)
        ...
```
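
As a toy sanity check of what this changes (reusing the illustrative `masked_normalization_sketch` from above): each group of answers is now whitened independently, so the masked mean within every group is numerically zero, whereas global normalization only guarantees this across the whole batch.

```python
import torch

group_size = 2
adv_lens = torch.tensor([4, 3, 6, 5])          # advantage lengths of 2 prompts x 2 answers
cu = torch.nn.functional.pad(adv_lens.cumsum(0), (1, 0))
advantages = torch.randn(int(cu[-1]))
loss_mask = torch.ones(int(cu[-1]), dtype=torch.bool)

adv_list = []
for i in range(0, len(adv_lens), group_size):
    s, e = int(cu[i]), int(cu[i + group_size])
    adv_list.append(masked_normalization_sketch(advantages[s:e], loss_mask[s:e]))
advantages = torch.cat(adv_list, 0)

print(advantages[: int(cu[group_size])].mean())   # ~0: first prompt's group
print(advantages[int(cu[group_size]):].mean())    # ~0: second prompt's group
```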

### Modify Your Experiment Configuration

To make our new argument `group_adv_norm` effective in the CLI args, we make the following changes to `PPOMATHConfig` under `realhf/experiments/common/ppo_math_exp.py`:

```diff
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
+    group_adv_norm: bool = False

    @property
    def rpcs(self):
        ...
        # interfaces
        actor_interface = ModelInterfaceAbstraction(
            "ppo_actor",
            args={
                **copy.deepcopy(self.ppo_kwargs),
+                "group_adv_norm": self.group_adv_norm,
                ...
            },
        )
```
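
The plumbing here follows the dependency-injection pattern described earlier: the experiment config exposes the flag, and its value travels through the interface's `args` dict into the corresponding dataclass field. A toy, self-contained illustration of that pattern (not AReaL's actual factory code):

```python
import dataclasses


@dataclasses.dataclass
class ToyActorInterface:
    group_adv_norm: bool = False


def build_interface(args: dict) -> ToyActorInterface:
    # ModelInterfaceAbstraction("ppo_actor", args=...) plays this role in AReaL:
    # the args dict becomes the keyword arguments of the interface dataclass.
    return ToyActorInterface(**args)


iface = build_interface({"group_adv_norm": True})
assert iface.group_adv_norm is True
```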

## The Decoupled PPO Loss



As described in AReaL's paper, we implement this loss by recomputing the token probabilities just before the mini-batched updates and using these values as π_prox in the loss above.
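
Spelled out from the implementation below (a reconstruction for readability, not a verbatim copy of the paper's equation), the masked, token-averaged objective minimized by the actor is:

```{math}
\mathcal{L}(\theta) = -\frac{1}{\sum_t m_t}\sum_t m_t \,
\frac{\pi_{\mathrm{prox}}(a_t \mid s_t)}{\pi_{\mathrm{behav}}(a_t \mid s_t)}\,
\min\!\Big(r_t(\theta)\,\hat{A}_t,\;
\operatorname{clip}\big(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_t\Big),
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\mathrm{prox}}(a_t \mid s_t)}
```

where m_t is the loss mask, π_behav is the (possibly stale) policy that generated the rollout (its log-probabilities are `old_logp`), and the behavior importance weight π_prox/π_behav can optionally be bounded via `behav_imp_weight_cap`.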

### Probability Recomputation

Recomputation is a single forward pass, which is already implemented by `PPOActorInterface.inference`. We call this method at the beginning of the `train_step` method:

```diff
@dataclass
class PPOActorInterface(ModelInterface):
+    use_decoupled_loss: bool = False

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
+        if self.use_decoupled_loss:
+            s: SequenceSample = self.inference(model, data, mb_spec)
+            prox_logp = s.data["logprobs"]
        ...
```

Next, we need to pass `prox_logp` to the loss computation:

```diff
@dataclass
class PPOActorInterface(ModelInterface):
    ...

    def train_step(
        self,
        model: Model,
        data: SequenceSample,
        mb_spec: MicroBatchSpec,
    ) -> Dict | List[Dict]:
        # Prepare data to be split into mini-batches.
        flat_data = dict(
            advantages=advantages,
            old_logp=old_logp,
            ppo_loss_mask=loss_mask,
            packed_input_ids=input_.data["packed_input_ids"],
            kl_rewards=kl_rewards,
        )
+        if self.use_decoupled_loss:
+            flat_data["prox_logp"] = prox_logp.float()

        flat_input = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            data=flat_data,
            seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
        )
        ...
        datas = flat_input.split_with_spec(spec)
        ...
        for mb_i, data in enumerate(datas):
            train_stat = module.train_batch(
                input_=data,
                mb_spec=mb_spec,
                version_steps=model.version.global_step,
                loss_fn=_loss_fn,
                loss_weight_fn=lambda x: x.data[
                    "ppo_loss_mask"
                ].count_nonzero(),
                token_normalize_scope=self.token_normalize_scope,
            )
```

The `flat_input` variable is divided into mini-batches, and each mini-batch is passed into the `train_batch` method to run distributed training. All of the data in this `SequenceSample` object is forwarded to `_loss_fn`, which in this case is a wrapper over `_ppo_actor_loss_from_model_outputs`:

```python
def _ppo_actor_loss_from_model_outputs(
    logits: torch.FloatTensor,  # [tot_seqlen, vocab_size]
    input_: SequenceSample,
    ...
) -> torch.Tensor:
    ...
```

`logits` is the output of the model's forward pass, and `input_` is exactly the mini-batch we passed into `train_batch`, so we can retrieve `prox_logp` via:

```diff
def _ppo_actor_loss_from_model_outputs(
    logits: torch.FloatTensor,  # [tot_seqlen, vocab_size]
    input_: SequenceSample,
    ...
) -> torch.Tensor:
    ...
+    prox_logp = input_.data["prox_logp"]
    loss, ppo_stat = ppo_functional.actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=ppo_loss_mask,
        c_clip=c_clip,
+        proximal_logprobs=prox_logp,
        behav_imp_weight_cap=behav_imp_weight_cap,
    )
```

We have now recomputed the probabilities and passed them into the loss function. Next, we revise the loss computation itself.

### Modifying the PPO Loss

```diff
def actor_loss_fn(
    logprobs: torch.FloatTensor,
    old_logprobs: torch.FloatTensor,
    advantages: torch.FloatTensor,
    eps_clip: float,
    loss_mask: Optional[torch.BoolTensor] = None,
    c_clip: Optional[float] = None,
+    proximal_logprobs: Optional[torch.FloatTensor] = None,
    behav_imp_weight_cap: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Dict]:
    ...
+    if proximal_logprobs is not None:
+        denorm_logprobs = proximal_logprobs
+    else:
+        denorm_logprobs = old_logprobs
    ...
    loss_mask_count = loss_mask.count_nonzero() or 1
    # For numerical stability.
-    ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
+    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
    ...
+    if proximal_logprobs is not None:
+        behav_kl = proximal_logprobs - old_logprobs
+        behav_imp_weight = behav_kl.exp()
+        behav_kl = torch.where(loss_mask, behav_kl, 0.0)
+        behav_imp_weight = torch.where(loss_mask, behav_imp_weight, 0.0)
+        pg_loss = pg_loss * behav_imp_weight
    ...
    return pg_loss, stat
```
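
Putting the pieces together, here is a self-contained sketch of the decoupled actor loss in plain PyTorch. It mirrors the structure of `ppo_functional.actor_loss_fn` but is an illustration rather than AReaL's exact code; in particular, clamping the behavior weight with `behav_imp_weight_cap` is just one simple way to bound it.

```python
from typing import Dict, Optional, Tuple

import torch


def decoupled_actor_loss_sketch(
    logprobs: torch.Tensor,           # log pi_theta on rollout tokens, packed 1D
    old_logprobs: torch.Tensor,       # log pi_behav (the policy that generated the data)
    proximal_logprobs: torch.Tensor,  # log pi_prox, recomputed before mini-batch updates
    advantages: torch.Tensor,
    eps_clip: float,
    loss_mask: torch.Tensor,          # bool mask over packed tokens
    behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    loss_mask_count = loss_mask.count_nonzero().clamp(min=1)
    # Clipped PPO surrogate with pi_prox as the denominator of the ratio.
    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0.0)
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss = -torch.min(ratio * advantages, clipped * advantages)
    # Behavior importance weight pi_prox / pi_behav corrects for rollout staleness.
    behav_imp_weight = torch.where(
        loss_mask, torch.exp(proximal_logprobs - old_logprobs), 0.0
    )
    if behav_imp_weight_cap is not None:
        behav_imp_weight = behav_imp_weight.clamp(max=behav_imp_weight_cap)
    pg_loss = pg_loss * behav_imp_weight
    loss = torch.where(loss_mask, pg_loss, 0.0).sum() / loss_mask_count
    stat = dict(
        importance_weight=ratio.detach(),
        behav_imp_weight=behav_imp_weight.detach(),
    )
    return loss, stat
```

When `proximal_logprobs` equals `old_logprobs` (fully on-policy data), the behavior weight collapses to 1 and this reduces to the standard clipped PPO loss.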

### Modify the Experiment Configuration

```diff
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
+    use_decoupled_loss: bool = False

    @property
    def rpcs(self):
        ...
        # interfaces
        actor_interface = ModelInterfaceAbstraction(
            "ppo_actor",
            args={
                **copy.deepcopy(self.ppo_kwargs),
+                "use_decoupled_loss": self.use_decoupled_loss,
                ...
            },
        )
```

@@ -1,4 +1,4 @@
# Dataset Customization
# Dataset

This guide provides detailed examples of how to create custom datasets in AReaL for model training.

Binary file not shown. (After: 47 KiB)

@@ -3,6 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").

import json
import os
import sys
import traceback
from collections import defaultdict

@@ -36,9 +37,10 @@ def check_code_metadata_entries(data):
    data["problem_id"] = data["query_id"]
    assert isinstance(data["prompt"], str)
    case_size = sys.getsizeof(data["input_output"])
    assert (
        case_size < 500 * 1024
    ), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."
    if not os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""):
        assert (
            case_size < 500 * 1024
        ), f"'input_output' exceeds 500KB ({case_size} bytes). Use remote testcase instead."
    input_output = json.loads(data["input_output"])
    assert len(input_output["inputs"]) == len(input_output["outputs"])
    for inp, out in zip(input_output["inputs"], input_output["outputs"]):