mirror of https://github.com/inclusionAI/AReaL
[doc] [lite] Add customization docs for AReaLite. (#191)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * add gradient checkpointing * PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * . * fix * . * PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * . * fix destroy process group * PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/358 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * . * . * fix loss mask * fix * . * PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/368 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * . * PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/371 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/370 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * . * . * . * fix * PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results. Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/392 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com> * support fsdp engine and sglang remote engine * minor fix * . * refactor trainer * add close * rm mb_spec * . * fix * . * qwen2 grpo works * fix * fix * async works * fix * slurm launcher not tested * fix arg parse * . * sglang server wrapper * . * . * slurm run * ready for boba * debug * 32k run * . * . * fix * . * . * . * . * . * fix * . * fix * . * . * . * . * fix * . * . * . * . * . * . * . * refactor train engine * refactor train engine * . * fix update weight error * . * . * match train * format * . * fix * seems to work * . * . * . * . * . * . * . * . * . * . * . --------- Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
parent
ba16d4ef44
commit
6239633213
|
@ -0,0 +1,87 @@
|
|||
import uuid
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
from arealite.api.cli_args import GenerationHyperparameters
|
||||
from arealite.api.engine_api import InferenceEngine
|
||||
from arealite.api.io_struct import LLMRequest
|
||||
from arealite.api.workflow_api import RolloutWorkflow
|
||||
from arealite.utils.data import concat_padded_tensors
|
||||
|
||||
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
reward_fn,
|
||||
gconfig: GenerationHyperparameters,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
max_turns: int,
|
||||
turn_discount: float,
|
||||
):
|
||||
self.reward_fn = reward_fn
|
||||
self.gconfig = gconfig
|
||||
self.tokenizer = tokenizer
|
||||
self.max_turns = max_turns
|
||||
self.turn_discount = turn_discount
|
||||
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
# Placeholders for the results
|
||||
seq, logprobs, loss_mask, versions = [], [], [], []
|
||||
messages = data["messages"]
|
||||
# Run multi-turn rollout until correct
|
||||
t = reward = 0
|
||||
discount = 0
|
||||
rid = uuid.uuid4().hex
|
||||
while reward == 0 and t < self.max_turns:
|
||||
# Amend a prompt if the previous answer is incorrect
|
||||
if t > 0:
|
||||
messages += [
|
||||
{"role": "asistant", "content": completions_str},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Your answer is not correct. Please try to answer it again.",
|
||||
},
|
||||
]
|
||||
# Convert the prompt into input_ids
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
# Send generate request to get the response.
|
||||
req = LLMRequest(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
gconfig=self.gconfig.new(n_samples=1),
|
||||
)
|
||||
resp = await engine.agenerate(req)
|
||||
# compute reward: 1 for correct and 0 otherwise
|
||||
prompt_str = self.tokenizer.decode(input_ids)
|
||||
completions_str = self.tokenizer.decode(resp.output_tokens)
|
||||
reward = self.reward_fn(
|
||||
prompt=prompt_str,
|
||||
completions=completions_str,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
# Amend results
|
||||
input_len = len(resp.input_tokens) - len(seq)
|
||||
seq += resp.input_tokens[-input_len:] + resp.output_tokens
|
||||
logprobs += [0.0] * input_len + resp.output_logprobs
|
||||
loss_mask += [0] * input_len + [1] * resp.output_len
|
||||
versions += [-1] * input_len + resp.output_versions
|
||||
# Increase counter
|
||||
t += 1
|
||||
discount *= self.turn_discount
|
||||
res = dict(
|
||||
seq=torch.tensor(seq),
|
||||
logprobs=torch.tensor(logprobs),
|
||||
loss_mask=torch.tensor(loss_mask),
|
||||
versions=torch.tensor(versions),
|
||||
rewards=torch.tensor([float(reward * discount)]),
|
||||
attetion_mask=torch.ones(len(seq), dtype=torch.bool),
|
||||
)
|
||||
res = {k: v.unsqueeze(0) for k, v in res.items()}
|
||||
return concat_padded_tensors([res])
|
|
@ -39,3 +39,8 @@ parts:
|
|||
- caption: Contributing
|
||||
chapters:
|
||||
- file: contrib
|
||||
- caption: Customization (Legacy)
|
||||
chapters:
|
||||
- file: legacy/customization/dataset
|
||||
- file: legacy/customization/agent
|
||||
- file: legacy/customization/algorithm
|
||||
|
|
|
@ -1,135 +1,267 @@
|
|||
# Rollout and Agentic RL
|
||||
|
||||
This guide provides an example of modifying the rollout behavior for PPO training.
|
||||
This guide shows you how to create custom rollout behaviors for RL training by building
|
||||
a multi-turn math agent with **AReaLite**. This agent keeps trying to solve math
|
||||
problems until it finds the correct answer.
|
||||
|
||||
In particular, we implement a multi-turn math agent using end-to-end RL. The math agent will continuously attempt to think through and solve math problems until it reaches the correct answer.
|
||||
You can find the complete implementation in `arealite/workflow/multi_turn.py`.
|
||||
|
||||
## Define Your Agent
|
||||
## Step 1: Define Your Workflow
|
||||
|
||||
Create a new file under `realhf/impl/agent/`, for example, `math_multi_turn_agent.py`. Your `Agent` must implement the interface defined in `realhf/api/core/agent.py`, which requires implementing a single method: `collect_trajectory`.
|
||||
AReaLite gives you flexibility in how you design your agents. Instead of rigid `Agent`
|
||||
classes that might constrain your agent's capabilities, AReaLite captures all rollout
|
||||
behavior in a `RolloutWorkflow` class. This approach lets you customize your agent's
|
||||
behavior however you need.
|
||||
|
||||
```python
|
||||
class MathMultiTurnAgent(Agent):
|
||||
|
||||
async def collect_trajectory(
|
||||
# arealite/api/workflow_api.py
|
||||
class RolloutWorkflow:
|
||||
async def arun_episode(
|
||||
self, engine: InferenceEngine, data: Dict[str, Any]
|
||||
) -> TensorDict:
|
||||
"""Run a single episode of the workflow.
|
||||
|
||||
See concrete example implementations under the `arealite/workflow` directory.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
```
|
||||
|
||||
The workflow exposes an `arun_episode` method that runs and collects data from a single
|
||||
episode. This method takes two key arguments:
|
||||
|
||||
1. **InferenceEngine**: Provides the `agenerate` method for generating responses to user
|
||||
inputs
|
||||
1. **data**: The prompt data loaded from your RL dataset
|
||||
|
||||
Within this method, you have complete control over how your agent and environment
|
||||
interact.
|
||||
|
||||
> **Note**: Each `arun_episode` call takes a single prompt and outputs the trajectories
|
||||
> generated from that prompt—it's not batched. However, you can generate multiple
|
||||
> trajectories from a single prompt (for example, with GRPO or tree search).
|
||||
|
||||
### Setting Up the Multi-Turn Math Workflow
|
||||
|
||||
Let's build a multi-turn rollout workflow for solving math problems. First, we'll define
|
||||
the `__init__` method to set up what we need during rollout:
|
||||
|
||||
> **Note**: You have complete flexibility in defining the `__init__` method. Pass
|
||||
> whatever arguments you need to construct your workflow. If you want to use tools, pass
|
||||
> the corresponding environment here so your agent can call it in the `arun_episode`
|
||||
> method.
|
||||
|
||||
```python
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
prompt: SequenceSample,
|
||||
env: EnvironmentService,
|
||||
obs_queue: asyncio.Queue,
|
||||
act_queue: asyncio.Queue,
|
||||
reward_fn,
|
||||
gconfig: GenerationHyperparameters, # aka sampling_params
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
max_turns: int,
|
||||
turn_discount: float,
|
||||
):
|
||||
...
|
||||
self.reward_fn = reward_fn
|
||||
self.gconfig = gconfig
|
||||
self.tokenizer = tokenizer
|
||||
self.max_turns = max_turns
|
||||
# Discount rewards if the agent takes longer to find the correct answer
|
||||
self.turn_discount = turn_discount
|
||||
```
|
||||
|
||||
## Implement the `collect_trajectory` Logic
|
||||
### Implementing the Episode Logic
|
||||
|
||||
The `collect_trajectory` function takes a task prompt, an environment, and two queues as input, then produces several trajectories for the RL trainer. Within this function, you can create arbitrary data processing logic to produce the input for the inference engine (i.e., via `obs_queue`) and extract the action (i.e., via `act_queue`) from the generated tokens.
|
||||
|
||||
In this example, the initial observation is the math problem itself. We put the token IDs and generation config into `obs_queue` and wait for the action produced by the inference engine from `act_queue`. After the inference engine returns, we extract the generated answers and send them to the environment.
|
||||
Now let's implement the `arun_episode` method. We'll start by tokenizing the prompt data
|
||||
and converting it into an `LLMRequest` object for the inference engine:
|
||||
|
||||
```python
|
||||
for turn in range(self.num_turns):
|
||||
await obs_queue.put((qid, token_ids, self.gconfig))
|
||||
act: BundledGenerationOutputs = await act_queue.get()
|
||||
_, success, *_ = await env.step((qid, answers))
|
||||
...
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
# ... __init__ method above ...
|
||||
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
# Initialize result containers
|
||||
seq, logprobs, loss_mask, versions = [], [], [], []
|
||||
messages = data["messages"]
|
||||
# Run multi-turn rollout until we get the correct answer
|
||||
t = reward = 0
|
||||
discount = 1.0
|
||||
rid = uuid.uuid4().hex
|
||||
while reward == 0 and t < self.max_turns:
|
||||
# Convert the conversation into input tokens
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
# Generate response from the model
|
||||
req = LLMRequest(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
gconfig=self.gconfig.new(n_samples=1),
|
||||
)
|
||||
resp = await engine.agenerate(req)
|
||||
# ... continue processing ...
|
||||
```
|
||||
|
||||
The environment is similar to a [gym environment](https://github.com/Farama-Foundation/Gymnasium), which defines two methods: `reset` and `step`. However, to maintain efficiency, we use an asynchronous implementation to avoid mutual blocking across different environment instances.
|
||||
> **Note**: This example uses the "messages" key from the prompt data to get
|
||||
> OpenAI-compatible messages. This isn't required—the key and prompt format depend
|
||||
> entirely on your implementation. For instance, if your dataset stores prompt strings
|
||||
> in a "prompt" column, you could get input token IDs with
|
||||
> `self.tokenizer.encode(data["prompt"])`.
|
||||
|
||||
The math environment is stateless and essentially serves as a wrapper around the reward function:
|
||||
> **Note**: The `rid` field in `LLMRequest` is the request ID. Requests with the same ID
|
||||
> will reuse the LLM inference server's KV caches for better efficiency.
|
||||
|
||||
### Handling Multi-Turn Conversations
|
||||
|
||||
Next, we'll check if the current answer is correct using our `reward_fn`. This function
|
||||
should return 1 for correct answers and 0 otherwise. When the answer is wrong, we'll
|
||||
apply a discount, add feedback to the conversation, and let the model try again:
|
||||
|
||||
```python
|
||||
class MathCodeSingleStepEnv(EnvironmentService):
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
# ... previous methods ...
|
||||
|
||||
async def step(self, action: Tuple[str, List[str]]):
|
||||
qid, answers = action
|
||||
...
|
||||
# Make `math_verify_call` async
|
||||
format_rewards = await asyncio.to_thread(
|
||||
math_verify_call,
|
||||
answers,
|
||||
...
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
# ... initialization code ...
|
||||
while reward == 0 and t < self.max_turns:
|
||||
# Add feedback if the previous answer was incorrect
|
||||
if t > 0:
|
||||
messages += [
|
||||
{"role": "assistant", "content": completions_str},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Your answer is not correct. Please try to answer it again."
|
||||
},
|
||||
]
|
||||
# Generate response (code from above)
|
||||
# ...
|
||||
# Evaluate the response
|
||||
prompt_str = self.tokenizer.decode(input_ids)
|
||||
completions_str = self.tokenizer.decode(resp.output_tokens)
|
||||
reward = self.reward_fn(
|
||||
prompt=prompt_str,
|
||||
completions=completions_str,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
# Update counters
|
||||
t += 1
|
||||
discount *= self.turn_discount
|
||||
```
|
||||
|
||||
### Reward Function Signature
|
||||
|
||||
To make it easier to switch between different reward functions, we recommend following
|
||||
this signature:
|
||||
|
||||
```python
|
||||
def reward_fn(
|
||||
prompt: str,
|
||||
completions: str,
|
||||
prompt_ids: List[int],
|
||||
completion_ids: List[int],
|
||||
**kwargs,
|
||||
):
|
||||
"""Reward function for evaluating agent performance.
|
||||
|
||||
This signature is recommended for compatibility with predefined workflows,
|
||||
but you can modify it freely in custom implementations.
|
||||
|
||||
Args:
|
||||
prompt: The task description string
|
||||
completions: The agent's response string
|
||||
prompt_ids: Tokenized prompt
|
||||
completion_ids: Tokenized response
|
||||
**kwargs: Additional dataset attributes (solutions, input_outputs, etc.)
|
||||
|
||||
Returns:
|
||||
float: Reward value (typically 1.0 for correct, 0.0 for incorrect)
|
||||
"""
|
||||
```
|
||||
|
||||
While this signature is convenient, you're not restricted to it in custom
|
||||
workflows—modify as needed for your specific use case.
|
||||
|
||||
### Collecting Training Data
|
||||
|
||||
Finally, let's complete the implementation by collecting trajectories in the
|
||||
`TensorDict` format:
|
||||
|
||||
```python
|
||||
class MultiTurnWorkflow(RolloutWorkflow):
|
||||
# ... previous methods ...
|
||||
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
# ... episode logic above ...
|
||||
|
||||
while reward == 0 and t < self.max_turns:
|
||||
# ... generation and evaluation ...
|
||||
|
||||
# Collect trajectory data
|
||||
input_len = len(resp.input_tokens) - len(seq)
|
||||
seq += resp.input_tokens[-input_len:] + resp.output_tokens
|
||||
logprobs += [0.0] * input_len + resp.output_logprobs
|
||||
loss_mask += [0] * input_len + [1] * resp.output_len
|
||||
versions += [-1] * input_len + resp.output_versions
|
||||
|
||||
# Package results
|
||||
res = dict(
|
||||
input_ids=torch.tensor(seq),
|
||||
logprobs=torch.tensor(logprobs),
|
||||
loss_mask=torch.tensor(loss_mask),
|
||||
versions=torch.tensor(versions),
|
||||
rewards=torch.tensor(float(reward * discount)),
|
||||
attention_mask=torch.ones(len(seq), dtype=torch.bool),
|
||||
)
|
||||
return None, format_rewards, True, False, {}
|
||||
res = {k: v.unsqueeze(0) for k, v in res.items()}
|
||||
return concat_padded_tensors([res])
|
||||
```
|
||||
|
||||
After `env.step` returns the reward for the current step, we can check whether the answer is correct. If not, we can append a user prompt and send it to `obs_queue` again to enter the next round.
|
||||
> **Important**: The returned `TensorDict` must follow HuggingFace's padded data format,
|
||||
> where each tensor has shape `[batch_size, sequence_length, *]`. This allows AReaLite
|
||||
> to automatically batch multiple trajectories for training. Since this example returns
|
||||
> a single trajectory, we use `unsqueeze(0)` to create a batch of size 1.
|
||||
|
||||
> **Note**: You're not restricted to specific keys in your `TensorDict`—different
|
||||
> algorithms need different keys. This example targets the GRPO algorithm, so we include
|
||||
> `input_ids`, `loss_mask`, `attention_mask`, and `logprobs` (needed for computing
|
||||
> importance ratios).
|
||||
|
||||
## Step 2: Training with Your Custom Workflow
|
||||
|
||||
Using your custom workflow is straightforward—just create it in your training script and
|
||||
pass it to the `rollout_batch` or `prepare_batch` method:
|
||||
|
||||
```python
|
||||
for turn in range(self.num_turns):
|
||||
...
|
||||
feedback = None
|
||||
if success[0]:
|
||||
feedback = "Congratulations! You are correct!"
|
||||
else:
|
||||
feedback = "Unfortunately your answer is wrong. Let's try again."
|
||||
|
||||
feedback = "\n" + self.tokenizer.apply_chat_template(
|
||||
[dict(content=feedback, role="user")],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
def main(args):
|
||||
# ... setup code ...
|
||||
|
||||
# Create your custom workflow
|
||||
workflow = MultiTurnWorkflow(
|
||||
reward_fn=gsm8k_reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
turn_discount=0.9,
|
||||
max_turns=5,
|
||||
)
|
||||
feedback = self.tokenizer(feedback)["input_ids"]
|
||||
token_ids.extend(feedback)
|
||||
|
||||
# Run training—no other changes needed!
|
||||
data_generator = iter(train_dataloader)
|
||||
for global_step in range(max_steps):
|
||||
with stats_tracker.record_timing("rollout"):
|
||||
if config.async_training:
|
||||
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
|
||||
else:
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(train_dataloader)
|
||||
data = next(data_generator)
|
||||
batch = rollout.rollout_batch(data, workflow=workflow)
|
||||
# ... continue with training loop ...
|
||||
```
|
||||
|
||||
## Modify the Configuration
|
||||
|
||||
You're now close to running the end-to-end RL loop. The final step is to register and import your implementation, then modify the experiment configuration.
|
||||
|
||||
```python
|
||||
# in realhf/impl/agent/math_multi_turn_agent.py
|
||||
register_agent("math-multi-turn", MathMultiTurnAgent)
|
||||
```
|
||||
|
||||
```python
|
||||
# in realhf/impl/agent/__init__.py
|
||||
import realhf.impl.agent.math_multi_turn_agent
|
||||
```
|
||||
|
||||
In `realhf/experiments/async_exp/async_math_ppo.py`:
|
||||
|
||||
```diff
|
||||
@dataclasses.dataclass
|
||||
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
|
||||
+ # New CLI arguments are defined here
|
||||
+ my_param: float = 1.0
|
||||
|
||||
# in realhf/experiments/async_exp/async_ppo_math_exp.py
|
||||
@property
|
||||
def agent(self) -> AgentAbstraction:
|
||||
return AgentAbstraction(
|
||||
- "math-single-step",
|
||||
+ "math-multi-turn", # Your registered name
|
||||
args=dict(
|
||||
- ...
|
||||
+ # Any configurations for your __init__ method
|
||||
+ my_param=my_param,
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def env(self) -> EnvServiceAbstraction:
|
||||
- return EnvServiceAbstraction(
|
||||
- "math-code-single-step", args=dict(dataset_path=self.dataset.path)
|
||||
- )
|
||||
+ # Change to your customized environment if necessary
|
||||
+ return EnvServiceAbstraction(
|
||||
+ "my-env", args=dict(...)
|
||||
+ )
|
||||
```
|
||||
|
||||
## Run Training
|
||||
|
||||
Please follow the guide in [quickstart](../tutorial/quickstart.md). Generally, start your experiments by running:
|
||||
|
||||
```bash
|
||||
python3 training/main_async_ppo.py my_param=5.0 # and any additional CLI arguments
|
||||
```
|
||||
|
||||
The training reward of our trial is shown below:
|
||||
|
||||

|
||||
|
||||
Happy coding!
|
||||
That's it! Your custom multi-turn math agent is now ready for reinforcement learning
|
||||
training. The workflow will automatically handle the multi-turn conversations, reward
|
||||
computation, and data collection needed for effective RL training.
|
||||
|
|
|
@ -1,305 +1,203 @@
|
|||
# Training Algorithm
|
||||
|
||||
An algorithm is encapsulated in a `ModelInterface`, which primarily defines three methods:
|
||||
> **Note**: We recommend the user to first read the
|
||||
> [agent customization guide](agent.md).
|
||||
|
||||
**AReaLite** structures RL algorithms around two core components:
|
||||
|
||||
- **RolloutWorkflow**: Defines what data to generate during rollouts
|
||||
- **TrainEngine**: Defines how to process the generated data for training
|
||||
|
||||
We'll demonstrate this by implementing an RL algorithm similar to ReMax.
|
||||
|
||||
## Step 1: Implementing the RolloutWorkflow
|
||||
|
||||
The rollout workflow generates both greedy and sampled completions, then uses the reward
|
||||
difference as the final training signal:
|
||||
|
||||
```python
|
||||
# in realhf/api/core/model_api.py
|
||||
class ModelInterface(abc.ABC):
|
||||
"""An interface for model training, inference, and generation.
|
||||
class ReMaxRLVRWorkflow(RolloutWorkflow):
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
# Prepare input tokens from chat messages
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
data["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=self.enable_thinking,
|
||||
)
|
||||
|
||||
This interface is designed to follow the dependency injection pattern.
|
||||
We pass the model to the interface and call its methods, ensuring that model APIs
|
||||
and algorithms are fully decoupled. For example, REINFORCE and PPO can exhibit
|
||||
different behaviors during training. Separate interfaces can be written for these
|
||||
algorithms while using the same model that provides basic forward-backward-update
|
||||
functionality (i.e., :class:`PipelinableEngine`).
|
||||
n_samples = self.gconfig.n_samples
|
||||
rid = uuid.uuid4().hex
|
||||
|
||||
During runtime, the master worker requests model workers to execute a specific
|
||||
interface type (e.g., generate) on a specific model. The model worker locates
|
||||
the corresponding model, passes it into the requested interface, performs the
|
||||
computation, and returns the result.
|
||||
"""
|
||||
# Create requests for both sampled and greedy generation
|
||||
sample_req = LLMRequest(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
gconfig=self.gconfig,
|
||||
)
|
||||
greedy_req = LLMRequest(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
gconfig=self.gconfig.new(greedy=True),
|
||||
)
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample | None:
|
||||
raise NotImplementedError()
|
||||
# Generate both responses concurrently
|
||||
resp, greedy_resp = await asyncio.gather(
|
||||
engine.agenerate(sample_req),
|
||||
engine.agenerate(greedy_req),
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample | None:
|
||||
raise NotImplementedError()
|
||||
# Calculate rewards for both completions
|
||||
prompt_str = self.tokenizer.decode(input_ids)
|
||||
completions_str = self.tokenizer.decode(resp.output_tokens)
|
||||
|
||||
def train_step(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict | List[Dict]:
|
||||
raise NotImplementedError()
|
||||
sample_reward = self.reward_fn(
|
||||
prompt=prompt_str,
|
||||
completions=completions_str,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
|
||||
greedy_completions = self.tokenizer.decode(greedy_resp.output_tokens)
|
||||
greedy_reward = self.reward_fn(
|
||||
prompt=prompt_str,
|
||||
completions=greedy_completions,
|
||||
prompt_ids=greedy_resp.input_tokens,
|
||||
completion_ids=greedy_resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
|
||||
# Package results for training
|
||||
res = dict(
|
||||
# Add batch dimension
|
||||
input_ids=torch.tensor(resp.input_tokens + resp.output_tokens).unsqueeze(0),
|
||||
loss_mask=torch.tensor([0] * resp.input_len + [1] * resp.output_len).unsqueeze(0),
|
||||
versions=torch.tensor([-1] * resp.input_len + resp.output_versions).unsqueeze(0),
|
||||
attention_mask=torch.ones(resp.input_len + resp.output_len, dtype=torch.bool).unsqueeze(0),
|
||||
# Use reward difference across all tokens
|
||||
rewards=torch.tensor([float(sample_reward - greedy_reward)] * (resp.input_len + resp.output_len)),
|
||||
)
|
||||
|
||||
return TensorDict(res, batch_size=[1])
|
||||
```
|
||||
|
||||
When the dataflow is fixed, it's usually sufficient to modify or add the file that defines the algorithm interface.
|
||||
> **Note**: For detailed guidance on customizing rollout workflows, see the
|
||||
> [agent customization guide](agent.md).
|
||||
|
||||
We provide two examples: (1) changing PPO's global advantage normalization to grouped normalization in GRPO, and (2) changing the original PPO loss to the decoupled PPO loss in AReaL's paper.
|
||||
## Step 2: Implementing the REINFORCE Training Algorithm
|
||||
|
||||
Training algorithms are implemented by subclassing `TrainEngine` and using its atomic
|
||||
operations like `forward`, `train_batch`, and `eval_batch`.
|
||||
|
||||
First, let's define the REINFORCE loss function:
|
||||
|
||||
```python
|
||||
def reinforce_loss_fn(logits, data):
|
||||
input_ids = data["input_ids"]
|
||||
loss_mask = data["loss_mask"].bool()
|
||||
rewards = data["rewards"]
|
||||
|
||||
logprobs = gather_logprobs(
|
||||
logits, torch.roll(input_ids, shifts=-1, dims=-1)
|
||||
)
|
||||
loss = -logprobs * rewards
|
||||
loss = torch.where(loss_mask, loss, 0.0)
|
||||
|
||||
return loss.sum() / loss_mask.count_nonzero()
|
||||
```
|
||||
|
||||
```{note}
|
||||
We recommend using asynchronous RL, so that you can customize the generation behavior by [modifying your RL agent](agent.md) and don't need to modify the `generate` method of model interfaces.
|
||||
To decrease memory usage, AReaLite automatically packs multiple sequences in an 1D tensor before forward passes. Hence, the loss function should assume handling 1D *packed* tensors instead of *padded* tensors.
|
||||
```
|
||||
|
||||
## Grouped Advantage Normalization
|
||||
|
||||
The PPO algorithm is written in a single file `ppo_interface.py`. The method we are going to modify is the `train_step` method in `PPOActorInterface`. PPO's global advantage normalization looks like:
|
||||
Next, we implement the training engine. We use a two-class design to maintain backend
|
||||
compatibility:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
def train_step(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict | List[Dict]:
|
||||
...
|
||||
if self.adv_norm:
|
||||
advantages = masked_normalization(advantages, loss_mask)
|
||||
...
|
||||
```
|
||||
class ReinforceActor:
|
||||
def __init__(self, engine: TrainEngine):
|
||||
self.engine = engine
|
||||
|
||||
### An Additional Note on Data Management
|
||||
|
||||
We need to explain how data in each batch is organized.
|
||||
|
||||
Usually, each data batch (i.e., the `data` variable) includes multiple prompts. The number of prompts is called "batch size". Additionally, each prompt may have multiple corresponding answers. The number of answers is called "group_size". Therefore, there are batch_size × group_size sequences in each batch.
|
||||
|
||||
These sequences have different lengths, but they are concatenated (or packed) together as a 1D tensor. The inner dimension is the "group" with the same prompt, and the outer dimension consists of answers from different prompts. Similar to flash-attention, we use `cu_seqlens` to mark the boundary of each sequence. `cu_seqlens` is the cumulative sum of sequence lengths across the batch.
|
||||
|
||||
Each token in the sequence has a corresponding reward and advantage, so `advantages` is also a packed 1D tensor just like the tokens (i.e., `packed_input_ids`). However, the "sequences" of advantages are all one step shorter than tokens due to the auto-regressive nature of LLMs. We can only compute the loss on tokens except for the first one in each sequence.
|
||||
|
||||
### Implementation
|
||||
|
||||
For grouped advantage normalization, we need to partition the advantages into groups and run normalization within the tensor chunk of each group:
|
||||
|
||||
```diff
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
+ group_adv_norm: bool = False
|
||||
|
||||
def train_step(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict | List[Dict]:
|
||||
...
|
||||
if self.adv_norm:
|
||||
- advantages = masked_normalization(advantages, loss_mask)
|
||||
+ if not self.group_adv_norm:
|
||||
+ advantages = masked_normalization(advantages, loss_mask)
|
||||
+ else:
|
||||
+ n_samples = data.bs
|
||||
+ adv_list = []
|
||||
+ for i in range(0, n_samples, self.group_size):
|
||||
+ # Start and end of the chunk
|
||||
+ s = short1cu_seqlens[i]
|
||||
+ e = short1cu_seqlens[i + self.group_size]
|
||||
+ # Get advantages within each group of the same prompt
|
||||
+ adv = advantages[s: e]
|
||||
+ mask = loss_mask[s: e]
|
||||
+ # Run normalization
|
||||
+ advn = masked_normalization(adv, mask, all_reduce=False)
|
||||
+ adv_list.append(advn)
|
||||
+ advantages = torch.cat(adv_list, 0)
|
||||
...
|
||||
```
|
||||
|
||||
### Modify Your Experiment Configuration
|
||||
|
||||
To make our new argument `group_adv_norm` effective in CLI args, we should make the following changes to the `PPOMathConfig` under `realhf/experiments/common/ppo_math_exp.py`:
|
||||
|
||||
```diff
|
||||
@dataclasses.dataclass
|
||||
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
||||
+ group_adv_norm: bool = False
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
...
|
||||
# interfaces
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
args={
|
||||
**copy.deepcopy(self.ppo_kwargs),
|
||||
+ "group_adv_norm": self.group_adv_norm,
|
||||
...
|
||||
},
|
||||
def train_reinforce(self, data: TensorDict):
|
||||
# Enable gradient checkpointing
|
||||
self.engine.train()
|
||||
return self.engine.train_batch(
|
||||
data,
|
||||
loss_fn=reinforce_loss_fn,
|
||||
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
|
||||
)
|
||||
|
||||
class FSDPReinforceActor(FSDPEngine):
|
||||
def __init__(self):
|
||||
self.actor = ReinforceActor(self)
|
||||
|
||||
def train_reinforce(self, *args, **kwargs):
|
||||
return self.actor.train_reinforce(*args, **kwargs)
|
||||
```
|
||||
|
||||
## The Decoupled PPO Loss
|
||||
**Why two classes?** This design separates concerns:
|
||||
|
||||

|
||||
1. **Backend Agnostic Logic**: `ReinforceActor` contains the core REINFORCE algorithm
|
||||
that works with any backend (FSDP, DeepSpeed, Megatron) since they share the same
|
||||
`train_batch` API.
|
||||
|
||||
As mentioned in AReaL's paper, we implement this loss by recomputing the probabilities before mini-batched updates, and use this value as π_prox to compute the above loss.
|
||||
1. **Backend-Specific Features**: `FSDPReinforceActor` inherits from `FSDPEngine` to
|
||||
provide backend-specific utilities like `save`, `load`, and `upload_weights`. For
|
||||
other backends, you'd create `MegatronReinforceActor`, etc.
|
||||
|
||||
### Probability Recomputation
|
||||
> **Note**: This pattern is similar to interfaces in Go or traits in Rust, adapted for
|
||||
> Python's object model.
|
||||
|
||||
Recomputation involves a single forward pass, which has already been implemented by `PPOActorInterface.inference`. We need to call this method in the `train_step` method:
|
||||
## Step 3: Composing the Complete Training Loop
|
||||
|
||||
```diff
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
+ use_decoupled_loss: bool = False
|
||||
|
||||
def train_step(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict | List[Dict]:
|
||||
+ if self.use_decoupled_loss:
|
||||
+ s: SequenceSample = self.inference(model, data, mb_spec)
|
||||
+ prox_logp = s.data["logprobs"]
|
||||
...
|
||||
```
|
||||
|
||||
Next, we need to pass `prox_logp` to loss computation:
|
||||
|
||||
```diff
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
...
|
||||
|
||||
def train_step(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict | List[Dict]:
|
||||
# Prepare data to be split into mini-batches.
|
||||
flat_data = dict(
|
||||
advantages=advantages,
|
||||
old_logp=old_logp,
|
||||
ppo_loss_mask=loss_mask,
|
||||
packed_input_ids=input_.data["packed_input_ids"],
|
||||
kl_rewards=kl_rewards,
|
||||
)
|
||||
+ if self.use_decoupled_loss:
|
||||
+ flat_data["prox_logp"] = prox_logp.float()
|
||||
|
||||
flat_input = SequenceSample.from_default(
|
||||
ids=list(range(input_.bs * self.group_size)),
|
||||
data=flat_data,
|
||||
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
|
||||
)
|
||||
...
|
||||
datas = flat_input.split_with_spec(spec)
|
||||
...
|
||||
for mb_i, data in enumerate(datas):
|
||||
train_stat = module.train_batch(
|
||||
input_=data,
|
||||
mb_spec=mb_spec,
|
||||
version_steps=model.version.global_step,
|
||||
loss_fn=_loss_fn,
|
||||
loss_weight_fn=lambda x: x.data[
|
||||
"ppo_loss_mask"
|
||||
].count_nonzero(),
|
||||
token_normalize_scope=self.token_normalize_scope,
|
||||
)
|
||||
```
|
||||
|
||||
The `flat_input` variable will be divided into mini-batches. Each mini-batch of data will be passed into the `train_batch` method to run distributed training. The data included in this `SequenceSample` object will all be passed into the `_loss_fn`. In this case, `_loss_fn` is a wrapper over `_ppo_actor_loss_from_model_outputs`:
|
||||
The main training loop brings everything together:
|
||||
|
||||
```python
|
||||
def _ppo_actor_loss_from_model_outputs(
|
||||
logits: torch.FloatTensor, # [tot_seqlen, vocab_size]
|
||||
input_: SequenceSample,
|
||||
...
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
def main(args):
|
||||
# Initialize inference engine for rollouts
|
||||
rollout = RemoteSGLangEngine(config.rollout)
|
||||
rollout.initialize(None, ft_spec)
|
||||
|
||||
`logits` is the output of model forward, and `input_` is exactly the `input_` we passed into `train_batch`. So now we can retrieve the `prox_logp` via:
|
||||
# Initialize training engine
|
||||
actor = FSDPReinforceActor(config=config.actor)
|
||||
actor.initialize(None, ft_spec)
|
||||
|
||||
```diff
|
||||
def _ppo_actor_loss_from_model_outputs(
|
||||
logits: torch.FloatTensor, # [tot_seqlen, vocab_size]
|
||||
input_: SequenceSample,
|
||||
...
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
+ prox_logp = input_.data["prox_logp"]
|
||||
loss, ppo_stat = ppo_functional.actor_loss_fn(
|
||||
logprobs=logprobs,
|
||||
old_logprobs=old_logp,
|
||||
advantages=advantages,
|
||||
eps_clip=eps_clip,
|
||||
loss_mask=ppo_loss_mask,
|
||||
c_clip=c_clip,
|
||||
+ proximal_logprobs=prox_logp,
|
||||
behav_imp_weight_cap=behav_imp_weight_cap,
|
||||
# Create rollout workflow
|
||||
workflow = ReMaxRLVRWorkflow(
|
||||
reward_fn=gsm8k_reward_fn,
|
||||
gconfig=config.gconfig,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# Main training loop
|
||||
for global_step in range(max_steps):
|
||||
# Generate training data
|
||||
with stats_tracker.record_timing("rollout"):
|
||||
try:
|
||||
data = next(data_generator)
|
||||
except StopIteration:
|
||||
data_generator = iter(train_dataloader)
|
||||
data = next(data_generator)
|
||||
|
||||
batch = rollout.rollout_batch(data, workflow=workflow)
|
||||
|
||||
batch = batch.to(actor.device)
|
||||
|
||||
# Synchronize all processes
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Training step
|
||||
with (
|
||||
stats_tracker.record_timing("train_step"),
|
||||
stats_tracker.scope("actor"),
|
||||
):
|
||||
stats = actor.train_reinforce(batch)
|
||||
actor.step_lr_scheduler()
|
||||
|
||||
# Update model weights
|
||||
with stats_tracker.record_timing("update_weights"):
|
||||
# Weight update logic here
|
||||
...
|
||||
```
|
||||
|
||||
We have successfully recomputed the probability and passed it into the loss function. Next we should revise the loss computation code.
|
||||
|
||||
### Modifying the PPO Loss
|
||||
|
||||
```diff
|
||||
def actor_loss_fn(
|
||||
logprobs: torch.FloatTensor,
|
||||
old_logprobs: torch.FloatTensor,
|
||||
advantages: torch.FloatTensor,
|
||||
eps_clip: float,
|
||||
loss_mask: Optional[torch.BoolTensor] = None,
|
||||
c_clip: Optional[float] = None,
|
||||
+ proximal_logprobs: Optional[torch.FloatTensor] = None,
|
||||
behav_imp_weight_cap: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
...
|
||||
+ if proximal_logprobs is not None:
|
||||
+ denorm_logprobs = proximal_logprobs
|
||||
+ else:
|
||||
+ denorm_logprobs = old_logprobs
|
||||
...
|
||||
loss_mask_count = loss_mask.count_nonzero() or 1
|
||||
# For numerical stability.
|
||||
- ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
|
||||
+ ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
|
||||
...
|
||||
+ if proximal_logprobs is not None:
|
||||
+ behav_kl = proximal_logprobs - old_logprobs
|
||||
+ behav_imp_weight = behav_kl.exp()
|
||||
+ behav_kl = torch.where(loss_mask, behav_kl, 0.0)
|
||||
+ behav_imp_weight = torch.where(loss_mask, behav_imp_weight, 0.0)
|
||||
+ pg_loss = pg_loss * behav_imp_weight
|
||||
...
|
||||
return pg_loss, stat
|
||||
```
|
||||
|
||||
### Modify the Experiment Configuration
|
||||
|
||||
```diff
|
||||
@dataclasses.dataclass
|
||||
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
||||
+ use_decoupled_loss: bool = False
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
...
|
||||
# interfaces
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
args={
|
||||
**copy.deepcopy(self.ppo_kwargs),
|
||||
+ "use_decoupled_loss": self.use_decoupled_loss,
|
||||
...
|
||||
},
|
||||
)
|
||||
```
|
|
@ -1,134 +1,144 @@
|
|||
# Dataset
|
||||
|
||||
This guide provides detailed examples of how to create custom datasets in AReaL for model training.
|
||||
**AReaLite** directly integrates with the `Dataset` class from the HuggingFace
|
||||
`datasets` package. This gives you full flexibility to load, process, and filter your
|
||||
data before training.
|
||||
|
||||
## Define Your Dataset
|
||||
The required columns in your dataset depend on the specific implementation of the
|
||||
`RolloutWorkflow` (for online reinforcement learning) or the training engines (for
|
||||
offline training, such as `LMEngine` for Supervised Fine-Tuning (SFT)).
|
||||
|
||||
Create a new file under `realhf/impl/dataset/`, for example, `my_custom_dataset.py`. Your `Dataset` must implement the `torch.utils.data.Dataset` interface and follow the framework's conventions.
|
||||
Here are two concrete examples from the existing implementation:
|
||||
|
||||
## SFT (Offline Training)
|
||||
|
||||
In the SFT example, we see that the loaded data is directly passed to the `train_lm`
|
||||
method:
|
||||
|
||||
```python
|
||||
class MyCustomDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
util: data_api.DatasetUtility,
|
||||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
# Your custom parameters
|
||||
custom_param: float = 1.0,
|
||||
):
|
||||
"""Custom dataset initialization
|
||||
|
||||
Args:
|
||||
util: Dataset utility class containing tokenizer, seed, distributed info, etc.
|
||||
max_length: Maximum sequence length
|
||||
dataset_path: Path to dataset file (optional)
|
||||
dataset_builder: Data construction function (optional, alternative to dataset_path)
|
||||
custom_param: Your custom parameter
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
# Load and split dataset
|
||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||
|
||||
# Your custom data processing logic
|
||||
# examples/arealite/gsm8k_sft.py
|
||||
def main(args):
|
||||
...
|
||||
# Create dataset and dataloaders
|
||||
train_dataloader = StatefulDataLoader(
|
||||
get_gsm8k_dataset("train", tokenizer, rank, world_size),
|
||||
collate_fn=pad_sequences_to_tensors,
|
||||
)
|
||||
...
|
||||
# Run training loop
|
||||
for epoch in range(total_epochs):
|
||||
for step, data in enumerate(train_dataloader):
|
||||
stats = engine.train_lm(data)
|
||||
```
|
||||
|
||||
In this case, the `train_lm` method requires the keys "input_ids", "attention_mask", and
|
||||
"loss_mask" to function. We first tokenize the dataset to extract the "input_ids" and
|
||||
"loss_mask". Then, the `pad_sequences_to_tensors` method is used to batch multiple
|
||||
sequences and append the "attention_mask":
|
||||
|
||||
```python
|
||||
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
|
||||
def process(sample):
|
||||
seq_token = tokenizer.encode(
|
||||
sample["question"] + sample["answer"] + tokenizer.eos_token
|
||||
)
|
||||
prompt_token = tokenizer.encode(sample["question"])
|
||||
loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
|
||||
return {"input_ids": seq_token, "loss_mask": loss_mask}
|
||||
|
||||
# Remove unnecessary columns to avoid errors during collation
|
||||
dataset = dataset.map(process).remove_columns(["question", "answer"])
|
||||
return dataset
|
||||
|
||||
def get_gsm8k_dataset(split, tokenizer, rank, world_size):
|
||||
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
return process_gsm8k_sft_dataset(dataset, tokenizer)
|
||||
```
|
||||
|
||||
## GRPO (Online Training)
|
||||
|
||||
In the GRPO example, the loaded data is passed to the `InferenceEngine`, rather than the
|
||||
`TrainEngine`:
|
||||
|
||||
```python
|
||||
# examples/arealite/gsm8k_ppo.py
|
||||
def main(args):
|
||||
...
|
||||
# Create dataset and dataloaders
|
||||
train_dataloader = StatefulDataLoader(
|
||||
get_gsm8k_dataset("train", rank, world_size),
|
||||
collate_fn=lambda x: x,
|
||||
)
|
||||
# Initialize inference engine
|
||||
rollout = RemoteSGLangEngine(config.rollout)
|
||||
workflow = RLVRWorkflow(
|
||||
reward_fn=gsm8k_reward_fn,
|
||||
...
|
||||
)
|
||||
# Run training loop
|
||||
...
|
||||
for global_step in range(max_steps):
|
||||
batch = rollout.rollout_batch(data, workflow=workflow)
|
||||
...
|
||||
```
|
||||
|
||||
## Implement Core Methods
|
||||
Note that the `collate_fn` here is an identity function, meaning it simply returns the
|
||||
list of individual data items as a batch. In `rollout_batch`, the data is then
|
||||
dispatched to multiple concurrent executions of `workflow.arun_episode`, where each
|
||||
dispatched data corresponds to a single episode.
|
||||
|
||||
Every dataset class must implement the following two core methods:
|
||||
|
||||
### 1. `__len__` Method
|
||||
|
||||
Returns the size of the dataset:
|
||||
The `RLVRWorkflow` implementation extracts the "messages" field from the data dictionary
|
||||
as the prompt for generating a response. Additionally, this data is passed to the
|
||||
`reward_fn` as keyword arguments, which allows the reward function to make use of other
|
||||
dataset fields, like "answers". Here’s an example:
|
||||
|
||||
```python
|
||||
def __len__(self):
|
||||
return len(self.data_samples)
|
||||
```
|
||||
class RLVRWorkflow(RolloutWorkflow):
|
||||
|
||||
### 2. `__getitem__` Method
|
||||
|
||||
Returns the sample at the specified index, must return a `SequenceSample` object:
|
||||
|
||||
```python
|
||||
def __getitem__(self, idx):
|
||||
# Get raw data
|
||||
sample = self.data_samples[idx]
|
||||
|
||||
# Process data
|
||||
...
|
||||
|
||||
# Return SequenceSample object
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[sample["id"]],
|
||||
seqlens=[len(processed_data["input_ids"])],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(processed_data["input_ids"], dtype=torch.long),
|
||||
# Other necessary data fields
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### Dataset Examples
|
||||
|
||||
We provide some examples of dataset under `realhf/impl/dataset/`:
|
||||
- For SFT, please refer `prompt_answer_dataset.py`.
|
||||
- For Reward model training, please refer `rw_paired_dataset.py`
|
||||
- For RL training, please refer `math_code_dataset.py`
|
||||
|
||||
## Data Format Requirements
|
||||
|
||||
### JSONL File Format
|
||||
|
||||
Your data file should be in JSONL format, with one JSON object per line.
|
||||
If you are using our PromptDataset implementation, your data should be like:
|
||||
- Math Data
|
||||
```json
|
||||
{"qid": "sample_1", "prompt": "Solve this math problem: 2+2=", "solutions": ["\\boxed{4}"]}
|
||||
```
|
||||
- Code Data
|
||||
```json
|
||||
{"qid": "sample_2", "prompt": "Code problem", "input_output": "{\"inputs\": [\"5\\n2 3 5 10 12\\n\"], \"outputs\": [\"17\\n\"]}"}
|
||||
```
|
||||
|
||||
- `qid`: Unique identifier for the sample
|
||||
- `prompt`: Input prompt text
|
||||
- `task`: Task type, used to distinguish how to calculate the reward. ("math" and "code" are supported now.)
|
||||
|
||||
Note: There is no format restriction for a customized dataset as long as it can be loaded by your custom code.
|
||||
|
||||
## Registration and Configuration
|
||||
|
||||
### Register Dataset
|
||||
|
||||
Register your dataset at the end of your dataset file:
|
||||
|
||||
```python
|
||||
# in realhf/impl/dataset/my_custom_dataset.py
|
||||
data_api.register_dataset("my-custom", MyCustomDataset)
|
||||
```
|
||||
|
||||
### Modify Experiment Configuration
|
||||
|
||||
Use your new dataset in the experiment configuration (refer to `realhf/experiments/common/*_exp.py`):
|
||||
|
||||
```python
|
||||
# in your experiment config file
|
||||
@property
|
||||
def datasets(self) -> List[DatasetAbstraction]:
|
||||
return [
|
||||
DatasetAbstraction(
|
||||
"my-custom", # Your registered name
|
||||
args=dict(
|
||||
dataset_path=self.dataset_path,
|
||||
max_length=self.max_length,
|
||||
custom_param=self.custom_param,
|
||||
# Other initialization parameters
|
||||
),
|
||||
async def arun_episode(self, engine: InferenceEngine, data):
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
data["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=self.enable_thinking,
|
||||
)
|
||||
req = LLMRequest(
|
||||
input_ids=input_ids,
|
||||
...
|
||||
)
|
||||
...
|
||||
reward = self.reward_fn(
|
||||
prompt=prompt_str,
|
||||
completions=completions_str,
|
||||
prompt_ids=resp.input_tokens,
|
||||
completion_ids=resp.output_tokens,
|
||||
**data,
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
Thus, the "messages" field must be constructed when loading the dataset, and the reward
|
||||
function should be defined to handle the dataset's specific fields. Here’s how you can
|
||||
process the dataset for this example:
|
||||
|
||||
```python
|
||||
def process_gsm8k_rl_dataset(dataset: Dataset):
|
||||
def process(sample):
|
||||
messages = [{"role": "user", "content": sample["question"]}]
|
||||
return {"messages": messages}
|
||||
|
||||
# The dataset has two fields "messages" and "answer"
|
||||
dataset = dataset.map(process).remove_columns(["question"])
|
||||
return dataset
|
||||
|
||||
def get_gsm8k_dataset(split, rank, world_size):
|
||||
dataset = load_dataset(path="openai/gsm8k", name="main", split=split)
|
||||
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
||||
return process_gsm8k_rl_dataset(dataset)
|
||||
|
||||
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
|
||||
# "answer" is passed in through "**data"
|
||||
from realhf.impl.dataset.math_parser import process_results
|
||||
|
||||
return int(process_results(completions, answer)[0])
|
||||
```
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
# Rollout and Agentic RL (Legacy)
|
||||
|
||||
> **Note**: While this legacy approach works, we strongly recommend using the AReaLite
|
||||
> for new projects. It provides better flexibility, cleaner abstractions, and easier
|
||||
> maintenance.
|
||||
|
||||
## Step 1: Define Your Agent Class
|
||||
|
||||
Create a new file under `realhf/impl/agent/`, such as `math_multi_turn_agent.py`. Your
|
||||
`Agent` must implement the interface defined in `realhf/api/core/agent.py`, which
|
||||
requires a single method: `collect_trajectory`.
|
||||
|
||||
```python
|
||||
class MathMultiTurnAgent(Agent):
|
||||
async def collect_trajectory(
|
||||
self,
|
||||
prompt: SequenceSample,
|
||||
env: EnvironmentService,
|
||||
obs_queue: asyncio.Queue,
|
||||
act_queue: asyncio.Queue,
|
||||
):
|
||||
# Implementation goes here
|
||||
...
|
||||
```
|
||||
|
||||
## Step 2: Implement the Trajectory Collection Logic
|
||||
|
||||
The `collect_trajectory` method takes a task prompt, an environment, and two
|
||||
communication queues. Within this method, you control the data flow between your agent
|
||||
and the inference engine using these queues:
|
||||
|
||||
- **obs_queue**: Send observations (token IDs and generation config) to the inference
|
||||
engine
|
||||
- **act_queue**: Receive actions (generated responses) from the inference engine
|
||||
|
||||
Here's how the multi-turn conversation works:
|
||||
|
||||
```python
|
||||
for turn in range(self.num_turns):
|
||||
# Send the current state to the inference engine
|
||||
await obs_queue.put((qid, token_ids, self.gconfig))
|
||||
|
||||
# Get the generated response
|
||||
act: BundledGenerationOutputs = await act_queue.get()
|
||||
|
||||
# Evaluate the response through the environment
|
||||
success, rewards = await env.step((qid, answers))
|
||||
# ... process results ...
|
||||
```
|
||||
|
||||
### Environment Integration
|
||||
|
||||
The environment follows a
|
||||
[Gym-like interface](https://github.com/Farama-Foundation/Gymnasium) with `reset` and
|
||||
`step` methods, but uses asynchronous implementations to prevent blocking across
|
||||
different environment instances.
|
||||
|
||||
For math problems, the environment is typically stateless and acts as a wrapper around
|
||||
your reward function:
|
||||
|
||||
```python
|
||||
class MathCodeSingleStepEnv(EnvironmentService):
|
||||
async def step(self, action: Tuple[str, List[str]]):
|
||||
qid, answers = action
|
||||
# ... setup code ...
|
||||
|
||||
# Run reward computation asynchronously
|
||||
format_rewards = await asyncio.to_thread(
|
||||
math_verify_call,
|
||||
answers,
|
||||
# ... other parameters ...
|
||||
)
|
||||
return None, format_rewards, True, False, {}
|
||||
```
|
||||
|
||||
### Handling Multi-Turn Feedback
|
||||
|
||||
After receiving the reward from `env.step`, check if the answer is correct. If not,
|
||||
provide feedback and continue to the next turn:
|
||||
|
||||
```python
|
||||
for turn in range(self.num_turns):
|
||||
# ... generation and evaluation code ...
|
||||
|
||||
# Provide feedback based on the result
|
||||
if success[0]:
|
||||
feedback = "Congratulations! You are correct!"
|
||||
else:
|
||||
feedback = "Unfortunately your answer is wrong. Let's try again."
|
||||
|
||||
# Format feedback as a user message
|
||||
feedback = "\n" + self.tokenizer.apply_chat_template(
|
||||
[{"content": feedback, "role": "user"}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
# Add feedback tokens to the conversation
|
||||
feedback_tokens = self.tokenizer(feedback)["input_ids"]
|
||||
token_ids.extend(feedback_tokens)
|
||||
```
|
||||
|
||||
## Step 3: Register and Configure Your Agent
|
||||
|
||||
First, register your agent implementation:
|
||||
|
||||
```python
|
||||
# in realhf/impl/agent/math_multi_turn_agent.py
|
||||
register_agent("math-multi-turn", MathMultiTurnAgent)
|
||||
```
|
||||
|
||||
```python
|
||||
# in realhf/impl/agent/__init__.py
|
||||
import realhf.impl.agent.math_multi_turn_agent
|
||||
```
|
||||
|
||||
Then update your experiment configuration in
|
||||
`realhf/experiments/async_exp/async_math_ppo.py`:
|
||||
|
||||
```python
|
||||
@dataclasses.dataclass
|
||||
class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
|
||||
# Add any new CLI arguments your agent needs
|
||||
my_param: float = 1.0
|
||||
|
||||
@property
|
||||
def agent(self) -> AgentAbstraction:
|
||||
return AgentAbstraction(
|
||||
"math-multi-turn", # Your registered agent name
|
||||
args=dict(
|
||||
# Pass any arguments needed for your __init__ method
|
||||
my_param=self.my_param,
|
||||
# ... other configuration ...
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def env(self) -> EnvServiceAbstraction:
|
||||
# Update to use your custom environment if needed
|
||||
return EnvServiceAbstraction(
|
||||
"math-code-single-step",
|
||||
args=dict(dataset_path=self.dataset.path)
|
||||
)
|
||||
```
|
||||
|
||||
## Step 4: Run Training
|
||||
|
||||
Follow the standard training procedure outlined in the
|
||||
[quickstart guide](../../tutorial/quickstart.md). Launch your experiment with:
|
||||
|
||||
```bash
|
||||
python3 training/main_async_ppo.py my_param=5.0 # plus any additional CLI arguments
|
||||
```
|
||||
|
||||
## Training Results
|
||||
|
||||
Here's an example of the training reward curve from our multi-turn math agent:
|
||||
|
||||

|
||||
|
||||
The agent successfully learns to solve math problems with improved accuracy over time,
|
||||
demonstrating the effectiveness of the multi-turn approach.
|
|
@ -0,0 +1,259 @@
|
|||
# Training Algorithm (Legacy)
|
||||
|
||||
> **Note**: The AReaLite approach is more recommended for new implementations due to its
|
||||
> cleaner separation of concerns and better maintainability.
|
||||
|
||||
The legacy approach encapsulates algorithms in a `ModelInterface` with three core
|
||||
methods:
|
||||
|
||||
```python
|
||||
# From realhf/api/core/model_api.py
|
||||
class ModelInterface(abc.ABC):
|
||||
"""Interface for model training, inference, and generation.
|
||||
|
||||
This interface follows the dependency injection pattern, allowing
|
||||
algorithms like REINFORCE and PPO to use the same underlying model
|
||||
while exhibiting different training behaviors.
|
||||
"""
|
||||
|
||||
def inference(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> SequenceSample | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def train_step(
|
||||
self,
|
||||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict | List[Dict]:
|
||||
raise NotImplementedError()
|
||||
```
|
||||
|
||||
When the dataflow is fixed, you typically only need to modify the algorithm interface
|
||||
file.
|
||||
|
||||
> **Note**: We recommend using asynchronous RL so you can customize generation behavior
|
||||
> by [modifying your RL agent](agent.md) instead of the `generate` method.
|
||||
|
||||
## Example 1: Grouped Advantage Normalization
|
||||
|
||||
Let's modify PPO's global advantage normalization to use grouped normalization (GRPO
|
||||
approach).
|
||||
|
||||
### Understanding Data Organization
|
||||
|
||||
Each batch contains multiple prompts (batch size) and each prompt may have multiple
|
||||
responses (group size). So total sequences = batch_size × group_size.
|
||||
|
||||
Sequences have different lengths but are packed into a 1D tensor. We use `cu_seqlens`
|
||||
(cumulative sequence lengths) to mark boundaries, similar to flash-attention.
|
||||
|
||||
### Implementation
|
||||
|
||||
The standard PPO normalization looks like:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
|
||||
# ...
|
||||
if self.adv_norm:
|
||||
advantages = masked_normalization(advantages, loss_mask)
|
||||
# ...
|
||||
```
|
||||
|
||||
For grouped normalization, we partition advantages by group:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
group_adv_norm: bool = False
|
||||
|
||||
def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
|
||||
# ...
|
||||
if self.adv_norm:
|
||||
if not self.group_adv_norm:
|
||||
advantages = masked_normalization(advantages, loss_mask)
|
||||
else:
|
||||
n_samples = data.bs
|
||||
adv_list = []
|
||||
for i in range(0, n_samples, self.group_size):
|
||||
# Define chunk boundaries
|
||||
s = short1cu_seqlens[i]
|
||||
e = short1cu_seqlens[i + self.group_size]
|
||||
|
||||
# Extract advantages for this group
|
||||
adv = advantages[s:e]
|
||||
mask = loss_mask[s:e]
|
||||
|
||||
# Normalize within group
|
||||
advn = masked_normalization(adv, mask, all_reduce=False)
|
||||
adv_list.append(advn)
|
||||
|
||||
advantages = torch.cat(adv_list, 0)
|
||||
# ...
|
||||
```
|
||||
|
||||
### Configuration Changes
|
||||
|
||||
Update the experiment configuration to expose the new parameter:
|
||||
|
||||
```python
|
||||
@dataclasses.dataclass
|
||||
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
||||
group_adv_norm: bool = False
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
# ...
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
args={
|
||||
**copy.deepcopy(self.ppo_kwargs),
|
||||
"group_adv_norm": self.group_adv_norm,
|
||||
# ...
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Example 2: Decoupled PPO Loss
|
||||
|
||||
The decoupled PPO loss (from AReaL's paper) recomputes probabilities before mini-batch
|
||||
updates and uses this as π_prox:
|
||||
|
||||

|
||||
|
||||
### Probability Recomputation
|
||||
|
||||
We recompute probabilities using the existing `inference` method:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PPOActorInterface(ModelInterface):
|
||||
use_decoupled_loss: bool = False
|
||||
|
||||
def train_step(self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec) -> Dict | List[Dict]:
|
||||
if self.use_decoupled_loss:
|
||||
s: SequenceSample = self.inference(model, data, mb_spec)
|
||||
prox_logp = s.data["logprobs"]
|
||||
|
||||
# Prepare mini-batch data
|
||||
flat_data = dict(
|
||||
advantages=advantages,
|
||||
old_logp=old_logp,
|
||||
ppo_loss_mask=loss_mask,
|
||||
packed_input_ids=input_.data["packed_input_ids"],
|
||||
kl_rewards=kl_rewards,
|
||||
)
|
||||
|
||||
if self.use_decoupled_loss:
|
||||
flat_data["prox_logp"] = prox_logp.float()
|
||||
|
||||
flat_input = SequenceSample.from_default(
|
||||
ids=list(range(input_.bs * self.group_size)),
|
||||
data=flat_data,
|
||||
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
|
||||
)
|
||||
|
||||
# Split into mini-batches and train
|
||||
datas = flat_input.split_with_spec(spec)
|
||||
for mb_i, data in enumerate(datas):
|
||||
train_stat = module.train_batch(
|
||||
input_=data,
|
||||
mb_spec=mb_spec,
|
||||
version_steps=model.version.global_step,
|
||||
loss_fn=_loss_fn,
|
||||
loss_weight_fn=lambda x: x.data["ppo_loss_mask"].count_nonzero(),
|
||||
token_normalize_scope=self.token_normalize_scope,
|
||||
)
|
||||
```
|
||||
|
||||
### Modifying the Loss Function
|
||||
|
||||
Update the loss computation to use the recomputed probabilities:
|
||||
|
||||
```python
|
||||
def _ppo_actor_loss_from_model_outputs(
|
||||
logits: torch.FloatTensor, # [tot_seqlen, vocab_size]
|
||||
input_: SequenceSample,
|
||||
...
|
||||
) -> torch.Tensor:
|
||||
# ...
|
||||
prox_logp = input_.data.get("prox_logp")
|
||||
|
||||
loss, ppo_stat = ppo_functional.actor_loss_fn(
|
||||
logprobs=logprobs,
|
||||
old_logprobs=old_logp,
|
||||
advantages=advantages,
|
||||
eps_clip=eps_clip,
|
||||
loss_mask=ppo_loss_mask,
|
||||
c_clip=c_clip,
|
||||
proximal_logprobs=prox_logp,
|
||||
behav_imp_weight_cap=behav_imp_weight_cap,
|
||||
)
|
||||
```
|
||||
|
||||
And in the core loss function:
|
||||
|
||||
```python
|
||||
def actor_loss_fn(
|
||||
logprobs: torch.FloatTensor,
|
||||
old_logprobs: torch.FloatTensor,
|
||||
advantages: torch.FloatTensor,
|
||||
eps_clip: float,
|
||||
loss_mask: Optional[torch.BoolTensor] = None,
|
||||
c_clip: Optional[float] = None,
|
||||
proximal_logprobs: Optional[torch.FloatTensor] = None,
|
||||
behav_imp_weight_cap: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
# Use proximal probabilities if available, otherwise use old probabilities
|
||||
denorm_logprobs = proximal_logprobs if proximal_logprobs is not None else old_logprobs
|
||||
|
||||
loss_mask_count = loss_mask.count_nonzero() or 1
|
||||
|
||||
# Compute importance weights
|
||||
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
|
||||
|
||||
# Apply behavioral importance weighting for decoupled loss
|
||||
if proximal_logprobs is not None:
|
||||
behav_kl = proximal_logprobs - old_logprobs
|
||||
behav_imp_weight = behav_kl.exp()
|
||||
behav_kl = torch.where(loss_mask, behav_kl, 0.0)
|
||||
behav_imp_weight = torch.where(loss_mask, behav_imp_weight, 0.0)
|
||||
pg_loss = pg_loss * behav_imp_weight
|
||||
|
||||
# ...
|
||||
return pg_loss, stat
|
||||
```
|
||||
|
||||
### Configuration Update
|
||||
|
||||
```python
|
||||
@dataclasses.dataclass
|
||||
class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
||||
use_decoupled_loss: bool = False
|
||||
|
||||
@property
|
||||
def rpcs(self):
|
||||
# ...
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
args={
|
||||
**copy.deepcopy(self.ppo_kwargs),
|
||||
"use_decoupled_loss": self.use_decoupled_loss,
|
||||
# ...
|
||||
},
|
||||
)
|
||||
```
|
|
@ -0,0 +1,146 @@
|
|||
# Dataset (Legacy)
|
||||
|
||||
> **Note**: While this legacy approach works, we strongly recommend using the AReaLite
|
||||
> for new projects. It provides better flexibility, cleaner abstractions, and easier
|
||||
> maintenance.
|
||||
|
||||
## Define Your Dataset
|
||||
|
||||
Create a new file under `realhf/impl/dataset/`, for example, `my_custom_dataset.py`.
|
||||
Your `Dataset` must implement the `torch.utils.data.Dataset` interface and follow the
|
||||
framework's conventions.
|
||||
|
||||
```python
|
||||
class MyCustomDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
util: data_api.DatasetUtility,
|
||||
max_length: Optional[int] = None,
|
||||
dataset_path: Optional[str] = None,
|
||||
dataset_builder: Optional[Callable[[], List[Dict]]] = None,
|
||||
# Your custom parameters
|
||||
custom_param: float = 1.0,
|
||||
):
|
||||
"""Custom dataset initialization
|
||||
|
||||
Args:
|
||||
util: Dataset utility class containing tokenizer, seed, distributed info, etc.
|
||||
max_length: Maximum sequence length
|
||||
dataset_path: Path to dataset file (optional)
|
||||
dataset_builder: Data construction function (optional, alternative to dataset_path)
|
||||
custom_param: Your custom parameter
|
||||
"""
|
||||
self._util = util
|
||||
self.max_length = max_length
|
||||
|
||||
# Load and split dataset
|
||||
data = data_api.load_shuffle_split_dataset(util, dataset_path, dataset_builder)
|
||||
|
||||
# Your custom data processing logic
|
||||
...
|
||||
```
|
||||
|
||||
## Implement Core Methods
|
||||
|
||||
Every dataset class must implement the following two core methods:
|
||||
|
||||
### 1. `__len__` Method
|
||||
|
||||
Returns the size of the dataset:
|
||||
|
||||
```python
|
||||
def __len__(self):
|
||||
return len(self.data_samples)
|
||||
```
|
||||
|
||||
### 2. `__getitem__` Method
|
||||
|
||||
Returns the sample at the specified index, must return a `SequenceSample` object:
|
||||
|
||||
```python
|
||||
def __getitem__(self, idx):
|
||||
# Get raw data
|
||||
sample = self.data_samples[idx]
|
||||
|
||||
# Process data
|
||||
...
|
||||
|
||||
# Return SequenceSample object
|
||||
return data_api.SequenceSample.from_default(
|
||||
ids=[sample["id"]],
|
||||
seqlens=[len(processed_data["input_ids"])],
|
||||
data=dict(
|
||||
packed_prompts=torch.tensor(processed_data["input_ids"], dtype=torch.long),
|
||||
# Other necessary data fields
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### Dataset Examples
|
||||
|
||||
We provide some examples of dataset under `realhf/impl/dataset/`:
|
||||
|
||||
- For SFT, please refer `prompt_answer_dataset.py`.
|
||||
- For Reward model training, please refer `rw_paired_dataset.py`
|
||||
- For RL training, please refer `math_code_dataset.py`
|
||||
|
||||
## Data Format Requirements
|
||||
|
||||
### JSONL File Format
|
||||
|
||||
Your data file should be in JSONL format, with one JSON object per line. If you are
|
||||
using our PromptDataset implementation, your data should be like:
|
||||
|
||||
- Math Data
|
||||
|
||||
```json
|
||||
{"qid": "sample_1", "prompt": "Solve this math problem: 2+2=", "solutions": ["\\boxed{4}"]}
|
||||
```
|
||||
|
||||
- Code Data
|
||||
|
||||
```json
|
||||
{"qid": "sample_2", "prompt": "Code problem", "input_output": "{\"inputs\": [\"5\\n2 3 5 10 12\\n\"], \"outputs\": [\"17\\n\"]}"}
|
||||
```
|
||||
|
||||
- `qid`: Unique identifier for the sample
|
||||
- `prompt`: Input prompt text
|
||||
- `task`: Task type, used to distinguish how to calculate the reward. ("math" and "code"
|
||||
are supported now.)
|
||||
|
||||
Note: There is no format restriction for a customized dataset as long as it can be
|
||||
loaded by your custom code.
|
||||
|
||||
## Registration and Configuration
|
||||
|
||||
### Register Dataset
|
||||
|
||||
Register your dataset at the end of your dataset file:
|
||||
|
||||
```python
|
||||
# in realhf/impl/dataset/my_custom_dataset.py
|
||||
data_api.register_dataset("my-custom", MyCustomDataset)
|
||||
```
|
||||
|
||||
### Modify Experiment Configuration
|
||||
|
||||
Use your new dataset in the experiment configuration (refer to
|
||||
`realhf/experiments/common/*_exp.py`):
|
||||
|
||||
```python
|
||||
# in your experiment config file
|
||||
@property
|
||||
def datasets(self) -> List[DatasetAbstraction]:
|
||||
return [
|
||||
DatasetAbstraction(
|
||||
"my-custom", # Your registered name
|
||||
args=dict(
|
||||
dataset_path=self.dataset_path,
|
||||
max_length=self.max_length,
|
||||
custom_param=self.custom_param,
|
||||
# Other initialization parameters
|
||||
),
|
||||
)
|
||||
]
|
||||
```
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 47 KiB |
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 32 KiB |
Loading…
Reference in New Issue