[doc] [lite] Add customization docs for AReaLite. (#191)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher

Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/370

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* .
* .
* .
* fix

* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.

Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/392

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* .
* fix
* .
* qwen2 grpo works
* fix
* fix
* async works
* fix
* slurm launcher not tested
* fix arg parse
* .
* sglang server wrapper
* .
* .
* slurm run
* ready for boba
* debug
* 32k run
* .
* .
* fix
* .
* .
* .
* .
* .
* fix
* .
* fix
* .
* .
* .
* .
* fix
* .
* .
* .
* .
* .
* .
* .
* refactor train engine
* refactor train engine
* .
* fix update weight error
* .
* .
* match train
* format
* .
* fix
* seems to work
* .
* .
* .
* .

* .

* .

* .

* .

* .

* .

* .

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Wei Fu 2025-07-22 15:43:31 +08:00 committed by GitHub
parent ba16d4ef44
commit 6239633213
10 changed files with 1197 additions and 498 deletions

View File

@@ -0,0 +1,87 @@
import uuid
import torch
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.data import concat_padded_tensors
class MultiTurnWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        max_turns: int,
        turn_discount: float,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        self.turn_discount = turn_discount

    async def arun_episode(self, engine: InferenceEngine, data):
        # Placeholders for the results
        seq, logprobs, loss_mask, versions = [], [], [], []
        messages = data["messages"]
        # Run multi-turn rollout until correct
        t = reward = 0
        discount = 1.0
        rid = uuid.uuid4().hex
        while reward == 0 and t < self.max_turns:
            # Amend a prompt if the previous answer is incorrect
            if t > 0:
                messages += [
                    {"role": "assistant", "content": completions_str},
                    {
                        "role": "user",
                        "content": "Your answer is not correct. Please try to answer it again.",
                    },
                ]
            # Convert the prompt into input_ids
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
            )
            # Send generate request to get the response.
            req = LLMRequest(
                rid=rid,
                input_ids=input_ids,
                gconfig=self.gconfig.new(n_samples=1),
            )
            resp = await engine.agenerate(req)
            # Compute reward: 1 for correct and 0 otherwise
            prompt_str = self.tokenizer.decode(input_ids)
            completions_str = self.tokenizer.decode(resp.output_tokens)
            reward = self.reward_fn(
                prompt=prompt_str,
                completions=completions_str,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            # Amend results
            input_len = len(resp.input_tokens) - len(seq)
            seq += resp.input_tokens[-input_len:] + resp.output_tokens
            logprobs += [0.0] * input_len + resp.output_logprobs
            loss_mask += [0] * input_len + [1] * resp.output_len
            versions += [-1] * input_len + resp.output_versions
            # Increase counter and apply the per-turn discount
            t += 1
            discount *= self.turn_discount
        res = dict(
            seq=torch.tensor(seq),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask),
            versions=torch.tensor(versions),
            rewards=torch.tensor([float(reward * discount)]),
            attention_mask=torch.ones(len(seq), dtype=torch.bool),
        )
        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return concat_padded_tensors([res])

View File

@@ -39,3 +39,8 @@ parts:
- caption: Contributing
chapters:
- file: contrib
- caption: Customization (Legacy)
chapters:
- file: legacy/customization/dataset
- file: legacy/customization/agent
- file: legacy/customization/algorithm

View File

@@ -1,135 +1,267 @@
# Rollout and Agentic RL
This guide shows you how to create custom rollout behaviors for RL training by building
a multi-turn math agent with **AReaLite**. This agent keeps trying to solve math
problems until it finds the correct answer.
You can find the complete implementation in `arealite/workflow/multi_turn.py`.
## Step 1: Define Your Workflow
AReaLite gives you flexibility in how you design your agents. Instead of rigid `Agent`
classes that might constrain your agent's capabilities, AReaLite captures all rollout
behavior in a `RolloutWorkflow` class. This approach lets you customize your agent's
behavior however you need.
```python
# arealite/api/workflow_api.py
class RolloutWorkflow:
    async def arun_episode(
        self, engine: InferenceEngine, data: Dict[str, Any]
    ) -> TensorDict:
        """Run a single episode of the workflow.

        See concrete example implementations under the `arealite/workflow` directory.
        """
        raise NotImplementedError()
```
The workflow exposes an `arun_episode` method that runs and collects data from a single
episode. This method takes two key arguments:
1. **InferenceEngine**: Provides the `agenerate` method for generating responses to user
inputs
1. **data**: The prompt data loaded from your RL dataset
Within this method, you have complete control over how your agent and environment
interact.
> **Note**: Each `arun_episode` call takes a single prompt and outputs the trajectories
> generated from that prompt—it's not batched. However, you can generate multiple
> trajectories from a single prompt (for example, with GRPO or tree search).
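For intuition, here is a minimal, hypothetical sketch of fanning out several samples from one prompt inside `arun_episode`, using the same `LLMRequest`/`agenerate` API introduced below; the helper name `sample_group` is illustrative and not part of AReaLite:
```python
import asyncio
import uuid

from arealite.api.io_struct import LLMRequest


async def sample_group(engine, input_ids, gconfig, n_samples: int):
    """Illustrative helper: issue n_samples independent generation requests
    for one prompt and await them concurrently."""
    reqs = [
        LLMRequest(rid=uuid.uuid4().hex, input_ids=input_ids, gconfig=gconfig.new(n_samples=1))
        for _ in range(n_samples)
    ]
    return await asyncio.gather(*(engine.agenerate(req) for req in reqs))
```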
### Setting Up the Multi-Turn Math Workflow
Let's build a multi-turn rollout workflow for solving math problems. First, we'll define
the `__init__` method to set up what we need during rollout:
> **Note**: You have complete flexibility in defining the `__init__` method. Pass
> whatever arguments you need to construct your workflow. If you want to use tools, pass
> the corresponding environment here so your agent can call it in the `arun_episode`
> method.
```python
class MultiTurnWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,  # aka sampling_params
        tokenizer: PreTrainedTokenizerFast,
        max_turns: int,
        turn_discount: float,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        # Discount rewards if the agent takes longer to find the correct answer
        self.turn_discount = turn_discount
```
### Implementing the Episode Logic
Now let's implement the `arun_episode` method. We'll start by tokenizing the prompt data
and converting it into an `LLMRequest` object for the inference engine:
```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... __init__ method above ...

    async def arun_episode(self, engine: InferenceEngine, data):
        # Initialize result containers
        seq, logprobs, loss_mask, versions = [], [], [], []
        messages = data["messages"]
        # Run multi-turn rollout until we get the correct answer
        t = reward = 0
        discount = 1.0
        rid = uuid.uuid4().hex
        while reward == 0 and t < self.max_turns:
            # Convert the conversation into input tokens
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
            )
            # Generate response from the model
            req = LLMRequest(
                rid=rid,
                input_ids=input_ids,
                gconfig=self.gconfig.new(n_samples=1),
            )
            resp = await engine.agenerate(req)
            # ... continue processing ...
```
> **Note**: This example uses the "messages" key from the prompt data to get
> OpenAI-compatible messages. This isn't required—the key and prompt format depend
> entirely on your implementation. For instance, if your dataset stores prompt strings
> in a "prompt" column, you could get input token IDs with
> `self.tokenizer.encode(data["prompt"])`.
> **Note**: The `rid` field in `LLMRequest` is the request ID. Requests with the same ID
> will reuse the LLM inference server's KV caches for better efficiency.
### Handling Multi-Turn Conversations
Next, we'll check if the current answer is correct using our `reward_fn`. This function
should return 1 for correct answers and 0 otherwise. When the answer is wrong, we'll
apply a discount, add feedback to the conversation, and let the model try again:
```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... previous methods ...

    async def arun_episode(self, engine: InferenceEngine, data):
        # ... initialization code ...
        while reward == 0 and t < self.max_turns:
            # Add feedback if the previous answer was incorrect
            if t > 0:
                messages += [
                    {"role": "assistant", "content": completions_str},
                    {
                        "role": "user",
                        "content": "Your answer is not correct. Please try to answer it again.",
                    },
                ]
            # Generate response (code from above)
            # ...
            # Evaluate the response
            prompt_str = self.tokenizer.decode(input_ids)
            completions_str = self.tokenizer.decode(resp.output_tokens)
            reward = self.reward_fn(
                prompt=prompt_str,
                completions=completions_str,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            # Update counters
            t += 1
            discount *= self.turn_discount
```
### Reward Function Signature
To make it easier to switch between different reward functions, we recommend following
this signature:
```python
def reward_fn(
prompt: str,
completions: str,
prompt_ids: List[int],
completion_ids: List[int],
**kwargs,
):
"""Reward function for evaluating agent performance.
This signature is recommended for compatibility with predefined workflows,
but you can modify it freely in custom implementations.
Args:
prompt: The task description string
completions: The agent's response string
prompt_ids: Tokenized prompt
completion_ids: Tokenized response
**kwargs: Additional dataset attributes (solutions, input_outputs, etc.)
Returns:
float: Reward value (typically 1.0 for correct, 0.0 for incorrect)
"""
```
While this signature is convenient, you're not restricted to it in custom
workflows—modify as needed for your specific use case.
### Collecting Training Data
Finally, let's complete the implementation by collecting trajectories in the
`TensorDict` format:
```python
class MultiTurnWorkflow(RolloutWorkflow):
    # ... previous methods ...

    async def arun_episode(self, engine: InferenceEngine, data):
        # ... episode logic above ...
        while reward == 0 and t < self.max_turns:
            # ... generation and evaluation ...
            # Collect trajectory data
            input_len = len(resp.input_tokens) - len(seq)
            seq += resp.input_tokens[-input_len:] + resp.output_tokens
            logprobs += [0.0] * input_len + resp.output_logprobs
            loss_mask += [0] * input_len + [1] * resp.output_len
            versions += [-1] * input_len + resp.output_versions
        # Package results
        res = dict(
            input_ids=torch.tensor(seq),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask),
            versions=torch.tensor(versions),
            rewards=torch.tensor(float(reward * discount)),
            attention_mask=torch.ones(len(seq), dtype=torch.bool),
        )
        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return concat_padded_tensors([res])
```
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
> to automatically batch multiple trajectories for training. Since this example returns
> a single trajectory, we use `unsqueeze(0)` to create a batch of size 1.
> **Note**: You're not restricted to specific keys in your `TensorDict`—different
> algorithms need different keys. This example targets the GRPO algorithm, so we include
> `input_ids`, `loss_mask`, `attention_mask`, and `logprobs` (needed for computing
> importance ratios).
## Step 2: Training with Your Custom Workflow
Using your custom workflow is straightforward—just create it in your training script and
pass it to the `rollout_batch` or `prepare_batch` method:
```python
def main(args):
    # ... setup code ...

    # Create your custom workflow
    workflow = MultiTurnWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        turn_discount=0.9,
        max_turns=5,
    )

    # Run training—no other changes needed!
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)
        # ... continue with training loop ...
```
That's it! Your custom multi-turn math agent is now ready for reinforcement learning
training. The workflow will automatically handle the multi-turn conversations, reward
computation, and data collection needed for effective RL training.

View File

@@ -1,305 +1,203 @@
# Training Algorithm
> **Note**: We recommend the user to first read the
> [agent customization guide](agent.md).
**AReaLite** structures RL algorithms around two core components:
- **RolloutWorkflow**: Defines what data to generate during rollouts
- **TrainEngine**: Defines how to process the generated data for training
We'll demonstrate this by implementing an RL algorithm similar to ReMax.
## Step 1: Implementing the RolloutWorkflow
The rollout workflow generates both greedy and sampled completions, then uses the reward
difference as the final training signal:
```python
class ReMaxRLVRWorkflow(RolloutWorkflow):
    async def arun_episode(self, engine: InferenceEngine, data):
        # Prepare input tokens from chat messages
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )

        n_samples = self.gconfig.n_samples
        rid = uuid.uuid4().hex

        # Create requests for both sampled and greedy generation
        sample_req = LLMRequest(
            rid=rid,
            input_ids=input_ids,
            gconfig=self.gconfig,
        )
        greedy_req = LLMRequest(
            rid=rid,
            input_ids=input_ids,
            gconfig=self.gconfig.new(greedy=True),
        )

        # Generate both responses concurrently
        resp, greedy_resp = await asyncio.gather(
            engine.agenerate(sample_req),
            engine.agenerate(greedy_req),
        )

        # Calculate rewards for both completions
        prompt_str = self.tokenizer.decode(input_ids)
        completions_str = self.tokenizer.decode(resp.output_tokens)

        sample_reward = self.reward_fn(
            prompt=prompt_str,
            completions=completions_str,
            prompt_ids=resp.input_tokens,
            completion_ids=resp.output_tokens,
            **data,
        )
        greedy_completions = self.tokenizer.decode(greedy_resp.output_tokens)
        greedy_reward = self.reward_fn(
            prompt=prompt_str,
            completions=greedy_completions,
            prompt_ids=greedy_resp.input_tokens,
            completion_ids=greedy_resp.output_tokens,
            **data,
        )

        # Package results for training
        res = dict(
            # Add batch dimension
            input_ids=torch.tensor(resp.input_tokens + resp.output_tokens).unsqueeze(0),
            loss_mask=torch.tensor([0] * resp.input_len + [1] * resp.output_len).unsqueeze(0),
            versions=torch.tensor([-1] * resp.input_len + resp.output_versions).unsqueeze(0),
            attention_mask=torch.ones(resp.input_len + resp.output_len, dtype=torch.bool).unsqueeze(0),
            # Use reward difference across all tokens
            rewards=torch.tensor(
                [float(sample_reward - greedy_reward)] * (resp.input_len + resp.output_len)
            ).unsqueeze(0),
        )
        return TensorDict(res, batch_size=[1])
```
> **Note**: For detailed guidance on customizing rollout workflows, see the
> [agent customization guide](agent.md).
## Step 2: Implementing the REINFORCE Training Algorithm
Training algorithms are implemented by subclassing `TrainEngine` and using its atomic
operations like `forward`, `train_batch`, and `eval_batch`.
First, let's define the REINFORCE loss function:
```python
def reinforce_loss_fn(logits, data):
input_ids = data["input_ids"]
loss_mask = data["loss_mask"].bool()
rewards = data["rewards"]
logprobs = gather_logprobs(
logits, torch.roll(input_ids, shifts=-1, dims=-1)
)
loss = -logprobs * rewards
loss = torch.where(loss_mask, loss, 0.0)
return loss.sum() / loss_mask.count_nonzero()
```
```{note}
To decrease memory usage, AReaLite automatically packs multiple sequences into a 1D
tensor before forward passes. Hence, the loss function should expect 1D *packed*
tensors instead of *padded* tensors.
```
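For reference, here is a minimal sketch of what a `gather_logprobs`-style helper (used by `reinforce_loss_fn` above) computes over packed tensors; the actual AReaLite utility may differ in details such as casting or temperature handling:
```python
import torch


def gather_logprobs_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Log-probability of each label token under the model's output distribution.

    logits: [total_tokens, vocab_size] packed logits
    labels: [total_tokens] packed (already shifted) label token IDs
    """
    logp = torch.log_softmax(logits.float(), dim=-1)
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
```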
Next, we implement the training engine. We use a two-class design to maintain backend
compatibility:
```python
class ReinforceActor:
    def __init__(self, engine: TrainEngine):
        self.engine = engine

    def train_reinforce(self, data: TensorDict):
        # Enable gradient checkpointing
        self.engine.train()
        return self.engine.train_batch(
            data,
            loss_fn=reinforce_loss_fn,
            loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
        )


class FSDPReinforceActor(FSDPEngine):
    def __init__(self):
        self.actor = ReinforceActor(self)

    def train_reinforce(self, *args, **kwargs):
        return self.actor.train_reinforce(*args, **kwargs)
```
**Why two classes?** This design separates concerns:
1. **Backend Agnostic Logic**: `ReinforceActor` contains the core REINFORCE algorithm
   that works with any backend (FSDP, DeepSpeed, Megatron) since they share the same
   `train_batch` API.
1. **Backend-Specific Features**: `FSDPReinforceActor` inherits from `FSDPEngine` to
   provide backend-specific utilities like `save`, `load`, and `upload_weights`. For
   other backends, you'd create `MegatronReinforceActor`, etc.
> **Note**: This pattern is similar to interfaces in Go or traits in Rust, adapted for
> Python's object model.
## Step 3: Composing the Complete Training Loop
The main training loop brings everything together:
```python
def main(args):
    # Initialize inference engine for rollouts
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)

    # Initialize training engine
    actor = FSDPReinforceActor(config=config.actor)
    actor.initialize(None, ft_spec)

    # Create rollout workflow
    workflow = ReMaxRLVRWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
    )

    # Main training loop
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        # Generate training data
        with stats_tracker.record_timing("rollout"):
            try:
                data = next(data_generator)
            except StopIteration:
                data_generator = iter(train_dataloader)
                data = next(data_generator)
            batch = rollout.rollout_batch(data, workflow=workflow)
            batch = batch.to(actor.device)

        # Synchronize all processes
        dist.barrier()
        torch.cuda.synchronize()

        # Training step
        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("actor"),
        ):
            stats = actor.train_reinforce(batch)
            actor.step_lr_scheduler()

        # Update model weights
        with stats_tracker.record_timing("update_weights"):
            # Weight update logic here
            ...
```

View File

@@ -1,134 +1,144 @@
# Dataset
**AReaLite** directly integrates with the `Dataset` class from the HuggingFace
`datasets` package. This gives you full flexibility to load, process, and filter your
data before training.
The required columns in your dataset depend on the specific implementation of the
`RolloutWorkflow` (for online reinforcement learning) or the training engines (for
offline training, such as `LMEngine` for Supervised Fine-Tuning (SFT)).
Here are two concrete examples from the existing implementation:
## SFT (Offline Training)
In the SFT example, we see that the loaded data is directly passed to the `train_lm`
method:
```python
# examples/arealite/gsm8k_sft.py
def main(args):
    ...
    # Create dataset and dataloaders
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", tokenizer, rank, world_size),
        collate_fn=pad_sequences_to_tensors,
    )
    ...
    # Run training loop
    for epoch in range(total_epochs):
        for step, data in enumerate(train_dataloader):
            stats = engine.train_lm(data)
```
In this case, the `train_lm` method requires the keys "input_ids", "attention_mask", and
"loss_mask" to function. We first tokenize the dataset to extract the "input_ids" and
"loss_mask". Then, the `pad_sequences_to_tensors` method is used to batch multiple
sequences and append the "attention_mask":
```python
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
def process(sample):
seq_token = tokenizer.encode(
sample["question"] + sample["answer"] + tokenizer.eos_token
)
prompt_token = tokenizer.encode(sample["question"])
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
return {"input_ids": seq_token, "loss_mask": loss_mask}
# Remove unnecessary columns to avoid errors during collation
dataset = dataset.map(process).remove_columns(["question", "answer"])
return dataset
def get_gsm8k_dataset(split, tokenizer, rank, world_size):
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
return process_gsm8k_sft_dataset(dataset, tokenizer)
```
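For intuition, here is a hedged sketch of what a padding collate function like `pad_sequences_to_tensors` needs to produce; the real AReaLite utility may differ, for example in key handling and dtypes:
```python
import torch


def pad_collate_sketch(samples, pad_token_id=0):
    """Batch variable-length samples into padded [batch, max_len] tensors plus an attention mask."""
    max_len = max(len(s["input_ids"]) for s in samples)
    batch = {"input_ids": [], "loss_mask": [], "attention_mask": []}
    for s in samples:
        pad = max_len - len(s["input_ids"])
        batch["input_ids"].append(torch.tensor(s["input_ids"] + [pad_token_id] * pad))
        batch["loss_mask"].append(torch.tensor(s["loss_mask"] + [0] * pad))
        batch["attention_mask"].append(torch.tensor([1] * len(s["input_ids"]) + [0] * pad))
    return {k: torch.stack(v) for k, v in batch.items()}
```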
## GRPO (Online Training)
In the GRPO example, the loaded data is passed to the `InferenceEngine`, rather than the
`TrainEngine`:
```python
# examples/arealite/gsm8k_ppo.py
def main(args):
...
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_gsm8k_dataset("train", rank, world_size),
collate_fn=lambda x: x,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
workflow = RLVRWorkflow(
reward_fn=gsm8k_reward_fn,
...
)
# Run training loop
...
for global_step in range(max_steps):
batch = rollout.rollout_batch(data, workflow=workflow)
...
```
Note that the `collate_fn` here is an identity function, meaning it simply returns the
list of individual data items as a batch. In `rollout_batch`, the data is then
dispatched to multiple concurrent executions of `workflow.arun_episode`, where each
dispatched data corresponds to a single episode.
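Conceptually, the dispatch looks like the following schematic sketch; this is not the engine's internal implementation, which schedules and batches requests differently:
```python
import asyncio


async def rollout_batch_sketch(engine, workflow, data_list):
    """Schematic only: run one workflow episode per prompt concurrently and gather the results."""
    episodes = [workflow.arun_episode(engine, item) for item in data_list]
    return await asyncio.gather(*episodes)
```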
The `RLVRWorkflow` implementation extracts the "messages" field from the data dictionary
as the prompt for generating a response. Additionally, this data is passed to the
`reward_fn` as keyword arguments, which allows the reward function to make use of other
dataset fields, like "answers". Here's an example:
```python
class RLVRWorkflow(RolloutWorkflow):
    async def arun_episode(self, engine: InferenceEngine, data):
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )
        req = LLMRequest(
            input_ids=input_ids,
            ...
        )
        ...
        reward = self.reward_fn(
            prompt=prompt_str,
            completions=completions_str,
            prompt_ids=resp.input_tokens,
            completion_ids=resp.output_tokens,
            **data,
        )
        ...
```
Thus, the "messages" field must be constructed when loading the dataset, and the reward
function should be defined to handle the dataset's specific fields. Here's how you can
process the dataset for this example:
```python
def process_gsm8k_rl_dataset(dataset: Dataset):
def process(sample):
messages = [{"role": "user", "content": sample["question"]}]
return {"messages": messages}
# The dataset has two fields "messages" and "answer"
dataset = dataset.map(process).remove_columns(["question"])
return dataset
def get_gsm8k_dataset(split, rank, world_size):
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
return process_gsm8k_rl_dataset(dataset)
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
# "answer" is passed in through "**data"
from realhf.impl.dataset.math_parser import process_results
return int(process_results(completions, answer)[0])
```

View File

@@ -0,0 +1,162 @@
# Rollout and Agentic RL (Legacy)
> **Note**: While this legacy approach works, we strongly recommend using AReaLite
> for new projects. It provides better flexibility, cleaner abstractions, and easier
> maintenance.
## Step 1: Define Your Agent Class
Create a new file under `realhf/impl/agent/`, such as `math_multi_turn_agent.py`. Your
`Agent` must implement the interface defined in `realhf/api/core/agent.py`, which
requires a single method: `collect_trajectory`.
```python
class MathMultiTurnAgent(Agent):
async def collect_trajectory(
self,
prompt: SequenceSample,
env: EnvironmentService,
obs_queue: asyncio.Queue,
act_queue: asyncio.Queue,
):
# Implementation goes here
...
```
## Step 2: Implement the Trajectory Collection Logic
The `collect_trajectory` method takes a task prompt, an environment, and two
communication queues. Within this method, you control the data flow between your agent
and the inference engine using these queues:
- **obs_queue**: Send observations (token IDs and generation config) to the inference
engine
- **act_queue**: Receive actions (generated responses) from the inference engine
Here's how the multi-turn conversation works:
```python
for turn in range(self.num_turns):
    # Send the current state to the inference engine
    await obs_queue.put((qid, token_ids, self.gconfig))

    # Get the generated response
    act: BundledGenerationOutputs = await act_queue.get()

    # Evaluate the response through the environment
    _, success, *_ = await env.step((qid, answers))

    # ... process results ...
```
### Environment Integration
The environment follows a
[Gym-like interface](https://github.com/Farama-Foundation/Gymnasium) with `reset` and
`step` methods, but uses asynchronous implementations to prevent blocking across
different environment instances.
For math problems, the environment is typically stateless and acts as a wrapper around
your reward function:
```python
class MathCodeSingleStepEnv(EnvironmentService):
async def step(self, action: Tuple[str, List[str]]):
qid, answers = action
# ... setup code ...
# Run reward computation asynchronously
format_rewards = await asyncio.to_thread(
math_verify_call,
answers,
# ... other parameters ...
)
return None, format_rewards, True, False, {}
```
### Handling Multi-Turn Feedback
After receiving the reward from `env.step`, check if the answer is correct. If not,
provide feedback and continue to the next turn:
```python
for turn in range(self.num_turns):
# ... generation and evaluation code ...
# Provide feedback based on the result
if success[0]:
feedback = "Congratulations! You are correct!"
else:
feedback = "Unfortunately your answer is wrong. Let's try again."
# Format feedback as a user message
feedback = "\n" + self.tokenizer.apply_chat_template(
[{"content": feedback, "role": "user"}],
add_generation_prompt=True,
tokenize=False,
)
# Add feedback tokens to the conversation
feedback_tokens = self.tokenizer(feedback)["input_ids"]
token_ids.extend(feedback_tokens)
```
## Step 3: Register and Configure Your Agent
First, register your agent implementation:
```python
# in realhf/impl/agent/math_multi_turn_agent.py
register_agent("math-multi-turn", MathMultiTurnAgent)
```
```python
# in realhf/impl/agent/__init__.py
import realhf.impl.agent.math_multi_turn_agent
```
Then update your experiment configuration in
`realhf/experiments/async_exp/async_math_ppo.py`:
```python
@dataclasses.dataclass
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
# Add any new CLI arguments your agent needs
my_param: float = 1.0
@property
def agent(self) -> AgentAbstraction:
return AgentAbstraction(
"math-multi-turn", # Your registered agent name
args=dict(
# Pass any arguments needed for your __init__ method
my_param=self.my_param,
# ... other configuration ...
),
)
@property
def env(self) -> EnvServiceAbstraction:
# Update to use your custom environment if needed
return EnvServiceAbstraction(
"math-code-single-step",
args=dict(dataset_path=self.dataset.path)
)
```
## Step 4: Run Training
Follow the standard training procedure outlined in the
[quickstart guide](../../tutorial/quickstart.md). Launch your experiment with:
```bash
python3 training/main_async_ppo.py my_param=5.0 # plus any additional CLI arguments
```
## Training Results
Here's an example of the training reward curve from our multi-turn math agent:
![Multi-turn Training Rewards](multiturn_reward.png)
The agent successfully learns to solve math problems with improved accuracy over time,
demonstrating the effectiveness of the multi-turn approach.

View File

@@ -0,0 +1,259 @@
# Training Algorithm (Legacy)
> **Note**: The AReaLite approach is recommended for new implementations due to its
> cleaner separation of concerns and better maintainability.
The legacy approach encapsulates algorithms in a `ModelInterface` with three core
methods:
```python
# From realhf/api/core/model_api.py
class ModelInterface(abc.ABC):
"""Interface for model training, inference, and generation.
This interface follows the dependency injection pattern, allowing
algorithms like REINFORCE and PPO to use the same underlying model
while exhibiting different training behaviors.
"""
def inference(
self,
model: Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
raise NotImplementedError()
def generate(
self,
model: Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> SequenceSample | None:
raise NotImplementedError()
def train_step(
self,
model: Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> Dict | List[Dict]:
raise NotImplementedError()
```
When the dataflow is fixed, you typically only need to modify the algorithm interface
file.
> **Note**: We recommend using asynchronous RL so you can customize generation behavior
> by [modifying your RL agent](agent.md) instead of the `generate` method.
## Example 1: Grouped Advantage Normalization
Let's modify PPO's global advantage normalization to use grouped normalization (GRPO
approach).
### Understanding Data Organization
Each batch contains multiple prompts (batch size) and each prompt may have multiple
responses (group size). So total sequences = batch_size × group_size.
Sequences have different lengths but are packed into a 1D tensor. We use `cu_seqlens`
(cumulative sequence lengths) to mark boundaries, similar to flash-attention.
### Implementation
The standard PPO normalization looks like:
```python
@dataclass
class PPOActorInterface(ModelInterface):
def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
# ...
if self.adv_norm:
advantages = masked_normalization(advantages, loss_mask)
# ...
```
For grouped normalization, we partition advantages by group:
```python
@dataclass
class PPOActorInterface(ModelInterface):
group_adv_norm: bool = False
def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
# ...
if self.adv_norm:
if not self.group_adv_norm:
advantages = masked_normalization(advantages, loss_mask)
else:
n_samples = data.bs
adv_list = []
for i in range(0, n_samples, self.group_size):
# Define chunk boundaries
s = short1cu_seqlens[i]
e = short1cu_seqlens[i + self.group_size]
# Extract advantages for this group
adv = advantages[s:e]
mask = loss_mask[s:e]
# Normalize within group
advn = masked_normalization(adv, mask, all_reduce=False)
adv_list.append(advn)
advantages = torch.cat(adv_list, 0)
# ...
```
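For reference, here is a hedged sketch of what a `masked_normalization` helper computes on a single process; the actual realhf utility also supports an `all_reduce` flag for distributed statistics, as used above:
```python
import torch


def masked_normalization_sketch(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize x to zero mean and unit variance over positions where mask is nonzero."""
    mask = mask.bool()
    count = mask.count_nonzero().clamp(min=1)
    mean = torch.where(mask, x, torch.zeros_like(x)).sum() / count
    var = torch.where(mask, (x - mean) ** 2, torch.zeros_like(x)).sum() / count
    return torch.where(mask, (x - mean) * torch.rsqrt(var + eps), x)
```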
### Configuration Changes
Update the experiment configuration to expose the new parameter:
```python
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
group_adv_norm: bool = False
@property
def rpcs(self):
# ...
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",
args={
**copy.deepcopy(self.ppo_kwargs),
"group_adv_norm": self.group_adv_norm,
# ...
},
)
```
## Example 2: Decoupled PPO Loss
The decoupled PPO loss (from AReaL's paper) recomputes probabilities before mini-batch
updates and uses this as π_prox:
![decoupled loss](decoupled_loss.png)
### Probability Recomputation
We recompute probabilities using the existing `inference` method:
```python
@dataclass
class PPOActorInterface(ModelInterface):
use_decoupled_loss: bool = False
def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
if self.use_decoupled_loss:
s: SequenceSample = self.inference(model, data, mb_spec)
prox_logp = s.data["logprobs"]
# Prepare mini-batch data
flat_data = dict(
advantages=advantages,
old_logp=old_logp,
ppo_loss_mask=loss_mask,
packed_input_ids=input_.data["packed_input_ids"],
kl_rewards=kl_rewards,
)
if self.use_decoupled_loss:
flat_data["prox_logp"] = prox_logp.float()
flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=flat_data,
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
)
# Split into mini-batches and train
datas = flat_input.split_with_spec(spec)
for mb_i, data in enumerate(datas):
train_stat = module.train_batch(
input_=data,
mb_spec=mb_spec,
version_steps=model.version.global_step,
loss_fn=_loss_fn,
loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
token_normalize_scope=self.token_normalize_scope,
)
```
### Modifying the Loss Function
Update the loss computation to use the recomputed probabilities:
```python
def _ppo_actor_loss_from_model_outputs(
logits: torch.FloatTensor, # [tot_seqlen, vocab_size]
input_: SequenceSample,
...
) -> torch.Tensor:
# ...
prox_logp = input_.data.get("prox_logp")
loss, ppo_stat = ppo_functional.actor_loss_fn(
logprobs=logprobs,
old_logprobs=old_logp,
advantages=advantages,
eps_clip=eps_clip,
loss_mask=ppo_loss_mask,
c_clip=c_clip,
proximal_logprobs=prox_logp,
behav_imp_weight_cap=behav_imp_weight_cap,
)
```
And in the core loss function:
```python
def actor_loss_fn(
logprobs: torch.FloatTensor,
old_logprobs: torch.FloatTensor,
advantages: torch.FloatTensor,
eps_clip: float,
loss_mask: Optional[torch.BoolTensor] = None,
c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.FloatTensor] = None,
behav_imp_weight_cap: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Dict]:
# Use proximal probabilities if available, otherwise use old probabilities
denorm_logprobs = proximal_logprobs if proximal_logprobs is not None else old_logprobs
loss_mask_count = loss_mask.count_nonzero() or 1
# Compute importance weights
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
# Apply behavioral importance weighting for decoupled loss
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
behav_kl = torch.where(loss_mask, behav_kl, 0.0)
behav_imp_weight = torch.where(loss_mask, behav_imp_weight, 0.0)
pg_loss = pg_loss * behav_imp_weight
# ...
return pg_loss, stat
```
### Configuration Update
```python
@dataclasses.dataclass
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
use_decoupled_loss: bool = False
@property
def rpcs(self):
# ...
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",
args={
**copy.deepcopy(self.ppo_kwargs),
"use_decoupled_loss": self.use_decoupled_loss,
# ...
},
)
```

View File

@@ -0,0 +1,146 @@
# Dataset (Legacy)
> **Note**: While this legacy approach works, we strongly recommend using AReaLite
> for new projects. It provides better flexibility, cleaner abstractions, and easier
> maintenance.
## Define Your Dataset
Create a new file under `realhf/impl/dataset/`, for example, `my_custom_dataset.py`.
Your `Dataset` must implement the `torch.utils.data.Dataset` interface and follow the
framework's conventions.
```python
class MyCustomDataset(torch.utils.data.Dataset):
def __init__(
self,
util: data_api.DatasetUtility,
max_length: Optional[int] = None,
dataset_path: Optional[str] = None,
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
# Your custom parameters
custom_param: float = 1.0,
):
"""Custom dataset initialization
Args:
util: Dataset utility class containing tokenizer, seed, distributed info, etc.
max_length: Maximum sequence length
dataset_path: Path to dataset file (optional)
dataset_builder: Data construction function (optional, alternative to dataset_path)
custom_param: Your custom parameter
"""
self._util = util
self.max_length = max_length
# Load and split dataset
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
# Your custom data processing logic
...
```
## Implement Core Methods
Every dataset class must implement the following two core methods:
### 1. `__len__` Method
Returns the size of the dataset:
```python
def __len__(self):
return len(self.data_samples)
```
### 2. `__getitem__` Method
Returns the sample at the specified index, must return a `SequenceSample` object:
```python
def __getitem__(self, idx):
# Get raw data
sample = self.data_samples[idx]
# Process data
...
# Return SequenceSample object
return data_api.SequenceSample.from_default(
ids=[sample["id"]],
seqlens=[len(processed_data["input_ids"])],
data=dict(
packed_prompts=torch.tensor(processed_data["input_ids"], dtype=torch.long),
# Other necessary data fields
),
)
```
### Dataset Examples
We provide some examples of dataset under `realhf/impl/dataset/`:
- For SFT, please refer `prompt_answer_dataset.py`.
- For Reward model training, please refer `rw_paired_dataset.py`
- For RL training, please refer `math_code_dataset.py`
## Data Format Requirements
### JSONL File Format
Your data file should be in JSONL format, with one JSON object per line. If you are
using our PromptDataset implementation, your data should be like:
- Math Data
```json
{"qid": "sample_1", "prompt": "Solve this math problem: 2+2=", "solutions": ["\\boxed{4}"]}
```
- Code Data
```json
{"qid": "sample_2", "prompt": "Code problem", "input_output": "{\"inputs\": [\"5\\n2 3 5 10 12\\n\"], \"outputs\": [\"17\\n\"]}"}
```
- `qid`: Unique identifier for the sample
- `prompt`: Input prompt text
- `task`: Task type, used to distinguish how to calculate the reward. ("math" and "code"
are supported now.)
Note: There is no format restriction for a customized dataset as long as it can be
loaded by your custom code.
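For example, a minimal loader sketch for such a JSONL file (illustrative; your dataset class may load and validate the file however it likes):
```python
import json


def load_jsonl(path: str):
    """Read one JSON object per line, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```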
## Registration and Configuration
### Register Dataset
Register your dataset at the end of your dataset file:
```python
# in realhf/impl/dataset/my_custom_dataset.py
data_api.register_dataset("my-custom", MyCustomDataset)
```
### Modify Experiment Configuration
Use your new dataset in the experiment configuration (refer to
`realhf/experiments/common/*_exp.py`):
```python
# in your experiment config file
@property
def datasets(self) -> List[DatasetAbstraction]:
return [
DatasetAbstraction(
"my-custom", # Your registered name
args=dict(
dataset_path=self.dataset_path,
max_length=self.max_length,
custom_param=self.custom_param,
# Other initialization parameters
),
)
]
```

View File

(image file, 47 KiB before and after)

View File

(image file, 32 KiB before and after)