# AReaLite Design Doc

A simplified and easy-to-read version of AReaL with a minimal set of APIs.

## Motivation

AReaL is too heavy for AI researchers to use, understand, and develop with, for several reasons. The most important is that its code architecture is *system-centric* rather than *AI-centric*: the RL algorithm workflow consists of multiple *workers* that run consecutive *model function calls*, neither of which is a familiar concept for AI researchers. As a result, users must first learn these concepts before they can develop workflows and algorithms for their own use cases.

Additionally, for historical reasons, AReaL's code is not clean. Large pieces of code inherited from previous projects are no longer useful but significantly increase the burden on users and developers. Sometimes debugging is difficult even for core developers like myself.

Since the tools for building RL workflows are becoming increasingly mature, implementing a framework with comparable efficiency now requires far fewer lines of code. It is the proper time to revisit the API design and distill the giant codebase into a neat and clean one. The distilled codebase does not need to be ultra-efficient; instead, we want to deliver 90% of the original AReaL's functionality while minimizing the lines of code and the burden on potential users. Our aim is to build an RL training framework that is fast to use, fast to read, and fast to execute. Here comes the lite version of AReaL: AReaLite.

AReaLite is the first step in AReaL's refactoring process. It is not only a standalone training library with shallow interfaces, but will also provide the core API definitions used by AReaL in the future. AReaL will eventually transform its current worker-based architecture into an AI-centric architecture like AReaLite's, and will **extend** AReaLite's APIs and implementations to support more backends for efficient large-scale training.

## Expectations of AReaLite

### Highlights

+ Fully asynchronous training with decoupled inference and training.
+ Elastic inference device scaling: users can shut down or launch additional inference processes independently during training.
+ Full SFT/RL algorithmic functionality matching AReaL.
+ Arbitrary agentic rollout workflow customization in a single file.
+ Easy navigation to implementation references via Ctrl+click in VSCode.
+ Support for distributed launching with Ray/SLURM/torchrun.

### AReaLite's Scope

+ Not bound to Ray.
+ Only supports SGLang and PyTorch FSDP2 with SPMD launching.
+ No customized data structures like `SequenceSample`; all data are PyTorch tensors.
+ Uses HuggingFace (models, datasets) and PyTorch (FSDP, data structures) as much as possible.

## Architecture

### Core Components

```
arealite/
├── api/       # Abstract interfaces and data structures
├── impl/      # Concrete implementations
├── cli/       # Command-line interfaces
├── config/    # Configuration templates
└── tests/     # Standalone test scripts
```

#### 1. API Layer (`api/`)

The API layer defines abstract interfaces and data structures that provide a clean contract between components:

- **`engine_api.py`**: Defines `SPMDWrapper` for SPMD-based training backends (FSDP) and `EngineFactory`.
- **`trainer_api.py`**: Defines the `Trainer` base class for different training algorithms and `TrainerFactory`.
- **`rollout_api.py`**: Defines `RolloutWorkflow`, `Agent`, and `Environment` for RL data collection, plus `RolloutWorkflowFactory`.
- **`cli_args.py`**: Defines configuration dataclasses for all components.
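
For orientation, here is a rough sketch of a subset of these interfaces. The real definitions live in `arealite/api/` and may differ; `train_batch` in particular is an illustrative assumption, while `update_weights_to()`, `train()`, and `arun_episode()` follow the usages shown later in this document.

```python
# Rough sketch only -- see arealite/api/*.py for the real definitions.
import abc
from typing import Any, Dict, Optional


class SPMDWrapper(abc.ABC):
    """Wraps an SPMD training backend such as PyTorch FSDP2."""

    @abc.abstractmethod
    def train_batch(self, batch: Dict[str, Any]) -> Dict[str, float]:
        """Run one training step on a batch of tensors and return metrics.

        Note: the method name here is an assumption for illustration.
        """

    @abc.abstractmethod
    def update_weights_to(self, servers) -> None:
        """Push the current model weights to the inference (SGLang) servers."""


class Trainer(abc.ABC):
    """Owns the training loop of one algorithm (PPO, SFT, ...)."""

    @abc.abstractmethod
    def train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        ...


class RolloutWorkflow(abc.ABC):
    """Collects one complete agentic trajectory per call."""

    @abc.abstractmethod
    async def arun_episode(self, gconfig, env_option=None, seed=None):
        ...
```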

#### 2. Implementation Layer (`impl/`)

The implementation layer contains concrete implementations of the API interfaces:

- **`fsdp_wrapper.py`**: FSDP-based training engine using PyTorch FSDP2.
- **`trainer/ppo.py`**: PPO trainer implementation for reinforcement learning.
- **`rollout_controller.py`**: Coordinates rollout data collection across workers.
- **`rlvr/`**: RLVR (RL with verifiable rewards) workflow implementations.
- **`agentic/`**: Agentic workflow implementations (math, code tasks).

#### 3. CLI Layer (`cli/`)

The CLI layer provides user-facing commands:

- **`main.py`**: Main entry point for launching complete training pipelines.
- **`launch_server.py`**: Utility for launching standalone LLM servers.

### Data Flow Architecture

AReaLite uses an **async producer-consumer pattern**:

```
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   LLM Servers   │◄──►│ Rollout Workers  │───►│   Data Buffer   │
│    (SGLang)     │    │  (Async Batch)   │    │                 │
└─────────────────┘    └──────────────────┘    └─────────────────┘
         ▲                                              │
         │                                              ▼
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Checkpoints   │◄───│   FSDP Trainer   │◄───│  Training Loop  │
│                 │    │   (Sync Batch)   │    │                 │
└─────────────────┘    └──────────────────┘    └─────────────────┘
```
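
Conceptually, the data buffer decouples the asynchronous producers (rollout workers) from the synchronous consumer (the training loop). A toy sketch of this pattern, assuming an asyncio-based queue (the actual `RolloutController` implementation may differ):

```python
import asyncio


class DataBuffer:
    """Toy stand-in for the data buffer between rollout workers and the trainer."""

    def __init__(self, capacity: int = 1024):
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=capacity)

    async def put(self, trajectory) -> None:
        # Producer side: rollout workers push finished trajectories.
        await self._queue.put(trajectory)

    async def get_batch(self, batch_size: int) -> list:
        # Consumer side: the training loop pulls a fixed-size batch.
        return [await self._queue.get() for _ in range(batch_size)]
```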

### Key Design Principles

#### 1. **AI-Centric API Design**

Unlike the original AReaL's system-centric approach with workers and model function calls, AReaLite uses concepts familiar to ML researchers:

- `Agent` and `Environment` (from the RL literature)
- `RolloutWorkflow` (combines agents and the environment to generate rollout data)
- `Trainer` (as in HuggingFace/PyTorch; fetches rollout data and updates model parameters)

#### 2. **Factory Pattern for Extensibility**

Each major component uses a factory pattern for easy customization:

- `EngineFactory` creates training backends
- `TrainerFactory` creates training algorithms
- `RolloutWorkflowFactory` creates rollout workflows
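
A minimal sketch of the dispatch pattern; the full `TrainerFactory` example appears in the Customization Guide below.

```python
# Minimal sketch; see the full TrainerFactory in the Customization Guide.
def make_trainer(self, config: TrainerConfig) -> Trainer:
    if config.type == "ppo":
        from arealite.impl.trainer.ppo import SpmdPPOTrainer
        # Dataset/controller arguments omitted here for brevity.
        return SpmdPPOTrainer(self.args, config)
    raise NotImplementedError(f"Unknown trainer type: {config.type}")
```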

#### 3. **Configuration-Driven Architecture**

All components are configured through dataclasses defined in `cli_args.py`, enabling:

- Type-safe configuration
- Easy CLI argument generation
- Clear documentation of available options
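
For illustration, the nested dataclasses might look like the sketch below, loosely based on the `TrainerConfig` and `PPOTrainerConfig` fields shown later in this document (fields are simplified). A dotted CLI override such as `trainer.ppo.actor.path=Qwen/Qwen2-0.5B` then maps directly onto this nested structure.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class EngineConfig:
    # Simplified; the real EngineConfig has more fields.
    path: str = "Qwen/Qwen2-0.5B"


@dataclass
class PPOTrainerConfig:
    actor: EngineConfig = field(default_factory=EngineConfig)
    async_training: bool = True


@dataclass
class TrainerConfig:
    type: str = field(default="ppo", metadata={"help": "Trainer type"})
    ppo: Optional[PPOTrainerConfig] = field(default=None)
```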

## Implementation Details

### Training Pipeline

1. **Initialization**: Factory classes create configured instances of engines, trainers, and rollout workflows.
2. **Rollout Phase**: The `RolloutController` coordinates async data collection across multiple `RolloutWorker` instances.
3. **Training Phase**: The `Trainer` performs synchronous gradient updates using the collected data.
4. **Weight Updates**: Updated model weights are pushed to the LLM servers via `update_weights_to()`.
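
A condensed sketch of how these phases can fit together inside the trainer's `train()` method. The `submit`/`max_head_offpolicyness` logic mirrors the PPO trainer's asynchronous submission; `get_batch`, `rollout_sync`, `train_batch`, and `self.llm_servers` are placeholder names for illustration, not the actual implementation.

```python
# Condensed sketch of a trainer main loop; not the exact implementation.
global_step = 0
for epoch in range(total_epochs):
    for step, data in enumerate(self.train_dataloader):
        if self.config.async_training:
            # Rollout phase (async): submitted prompts are over-subscribed by
            # the rollout controller so enough data is always being generated.
            self.rollout_controller.submit(data)
            # Skip training until enough off-policy headroom has been built up.
            if global_step < self.args.rollout.max_head_offpolicyness + 1:
                global_step += 1
                continue
            batch = self.rollout_controller.get_batch()         # placeholder name
        else:
            # Rollout phase (sync): generate rollouts for this batch and wait.
            batch = self.rollout_controller.rollout_sync(data)  # placeholder name

        # Training phase: synchronous gradient update on the collected batch.
        stats = self.engine.train_batch(batch)                  # placeholder name

        # Weight updates: push the new weights to the SGLang servers.
        self.engine.update_weights_to(self.llm_servers)         # placeholder attr
        global_step += 1
```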

### Rollout System

The rollout system supports arbitrary agentic rollout paradigms, implemented as `RolloutWorkflow` instances. A `RolloutWorkflow` exposes an `arun_episode` method in which users implement the logic for collecting one complete agentic trajectory. Users can first implement gymnasium-compatible `Agent` and `Environment` interfaces and combine them into a workflow, as in the standard RL setting (see `arealite/impl/agentic/`), or implement the workflow directly when the agent-environment abstraction does not fit the use case (see `arealite/impl/rlvr/`).
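
As a rough illustration, a gymnasium-style agent and environment combined into a single-turn workflow might look like the sketch below. The `arun_episode` signature follows the example in the Customization Guide; the `Agent`/`Environment` method names (`act`, `reset`, `step`) and the reward handling are assumptions.

```python
from arealite.api.rollout_api import Agent, Environment, RolloutWorkflow


class MathEnv(Environment):
    """Hypothetical environment that scores a model answer against a reference."""

    def reset(self, env_option=None, seed=None):
        ...  # return the initial observation (e.g., the math problem)

    def step(self, action):
        ...  # return (observation, reward, done, info)


class MathAgent(Agent):
    """Hypothetical agent that queries the LLM server for the next action."""

    async def act(self, observation, gconfig):
        ...  # call the SGLang client and return the generated answer


class SingleTurnMathWorkflow(RolloutWorkflow):
    """Combine one agent and one environment into a single-turn episode."""

    def __init__(self, agent: MathAgent, env: MathEnv):
        self.agent = agent
        self.env = env

    async def arun_episode(self, gconfig, env_option=None, seed=None):
        obs = self.env.reset(env_option=env_option, seed=seed)
        action = await self.agent.act(obs, gconfig)
        obs, reward, done, info = self.env.step(action)
        # Package the trajectory as plain PyTorch tensors for the trainer.
        ...
```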

## Expected Usage

Launch LLM servers and trainers at the same time according to the allocation mode:

### Basic RL Training

```bash
python3 -m arealite.cli.main \
    experiment_name=my-exp trial_name=my-trial \
    mode=${torchrun|ray|slurm} \
    allocation_mode=sglang.d16p1m1+d32p2m1 \
    shutdown_server_on_exit=False \
    trainer.type=ppo \
    trainer.ppo.async_training=True \
    trainer.ppo.actor.path=Qwen/Qwen2-0.5B
```

Add new servers elastically:

```bash
python3 -m arealite.cli.launch_server \
    experiment_name=my-exp trial_name=my-trial
```

If the trainer dies, restart the experiment without re-launching the servers:

```bash
python3 -m arealite.cli.main \
    experiment_name=my-exp trial_name=my-trial \
    min_required_servers=17
```

### Rollout-Only Evaluation

Run rollout or evaluation only:

```bash
python3 -m arealite.cli.main \
    trainer.type=null \
    valid_dataset.path=huggingface/dataset
```

### Distributed Training

```bash
python3 -m arealite.cli.main \
    mode=ray \
    allocation_mode=sglang.d16p1m1+d32p2m1 \
    trainer.type=ppo
```

## How to use in its current form (Dev WIP)

1. Launch a wrapped SGLang server:

```bash
python3 arealite/cli/launch_server.py experiment_name=test_rollout trial_name=test_rollout
```

This command reads the configuration from `config/llm_server.yaml` and launches a local SGLang server. It registers the server address for later use by the trainer.

2. Run simple tests:

```bash
python3 arealite/tests/test_rollout.py
```

This script uses the same experiment and trial names as those used to launch the LLM servers, so it automatically finds the server addresses. You can then run the rollout loop with a pre-defined workflow to collect training data.

## Customization Guide

All customizations follow similar procedures:

1. Inherit the base class (e.g., `Trainer`) and write your customized class under the `impl/` folder (e.g., `impl/trainer/ppo.py`).

2. Modify the factory class in the API file (e.g., `api/trainer_api.py`) to allow initialization of your customized class, e.g.:

```diff
 @dataclass
 class TrainerFactory:
     args: TrainingArgs

     def make_trainer(
         self,
         config: TrainerConfig,
         train_dataset: Dataset,
         valid_dataset: Optional[Dataset] = None,
         rollout_controller: Optional["RolloutController"] = None,
         extra_args: Optional[Dict] = None,
     ) -> Trainer:
         if config.type == "ppo":
             from arealite.impl.trainer.ppo import SpmdPPOTrainer

             return SpmdPPOTrainer(
                 self.args,
                 config,
                 train_dataset=train_dataset,
                 valid_dataset=valid_dataset,
                 rollout_controller=rollout_controller,
                 extra_args=extra_args,
             )
+        if config.type == "sft":
+            from arealite.impl.trainer.sft import SpmdSFTTrainer
+            return SpmdSFTTrainer(
+                ...
+            )
         else:
             raise NotImplementedError(f"Unknown trainer type: {config.type}")
```

3. Modify the CLI args in `cli_args.py` so you can configure your customized class through the command line, e.g.:

```python
@dataclass
class SFTTrainerConfig:
    model: EngineConfig = field(
        default_factory=EngineConfig,
        metadata={"help": "Model configuration for SFT training"},
    )
    mb_spec: MicroBatchSpec = field(
        default_factory=MicroBatchSpec,
        metadata={"help": "Micro-batch specification for SFT training"},
    )
```

```diff
 @dataclass
 class TrainerConfig:
     type: str = field(
         default="ppo",
         metadata={"help": "Trainer type", "choices": ["ppo", "sft", "null"]},
     )
     ppo: Optional[PPOTrainerConfig] = field(
         default=None, metadata={"help": "PPO trainer configuration (if using PPO)"}
     )
+    sft: Optional[SFTTrainerConfig] = field(
+        default=None, metadata={"help": "SFT trainer configuration (if using SFT)"}
+    )
```

Then you can use your own trainer with a command like:

```bash
python3 train.py trainer.sft.model.path=Qwen/Qwen2-0.5B
```

The same procedure applies to any new component. In short:

### Adding New Trainers

1. **Implement the trainer class** in `impl/trainer/`:

```python
from arealite.api.trainer_api import Trainer


class MyTrainer(Trainer):
    def train(self, resume_from_checkpoint=None):
        # Implementation here
        pass
```

2. **Add configuration** in `cli_args.py`:

```python
@dataclass
class MyTrainerConfig:
    learning_rate: float = 1e-4
```

3. **Register in factory** in `trainer_api.py`:

```python
def make_trainer(self, config: TrainerConfig) -> Trainer:
    if config.type == "my_trainer":
        return MyTrainer(...)
```

### Adding New Rollout Workflows

1. **Implement the workflow** in `impl/`:

```python
from arealite.api.rollout_api import RolloutWorkflow


class MyWorkflow(RolloutWorkflow):
    async def arun_episode(self, gconfig, env_option=None, seed=None):
        # Implementation here
        pass
```

2. **Register in factory** in `rollout_api.py`:

```python
def make_workflow(self, config: RolloutWorkflowConfig):
    if config.type == "my_workflow":
        return MyWorkflow(...)
```

## Unit Tests

Unit tests are placed under the `tests` folder, but currently they are not really *unit* tests; they are standalone scripts that check that individual components are runnable. They need further refactoring to use `pytest`.

## Roadmap

- [ ] Finalize API design. (In progress)
- [x] Implement standalone SGLang server (`impl/sglang_server.py`).
- [x] Implement SGLang client generation (`impl/sglang_client.py`).
- [x] Rollout pipeline (`tests/test_rollout.py`).
- [x] SGLang rollout interruption.
- [x] Asynchronous RL system-wide utilities (e.g., `RolloutController`).
- [ ] FSDP2 engine with transformers models. (In progress)
- [ ] SGLang update weights. (In progress)
- [ ] Synchronous PPO training pipeline (`impl/trainer/ppo.py`). (In progress)
- [ ] SFT trainer. (In progress)
- [ ] CI and unit tests. (FSDP trainer and PPO trainer in progress.)
- [ ] Benchmark performance against the original AReaL code.
- [ ] Launching scripts for Ray, torchrun, and SLURM.
- [ ] Other RL algorithms (DPO, REINFORCE, etc.).
- [ ] Support for multi-modal models.
- [ ] Advanced agentic workflows (tool use, planning).
- [ ] Design doc and user guide for transitioning from v0.3.0.
- [ ] Examples of training on GSM8K, TLDR, and a search agent.
- [ ] Allow external persistent SGLang servers for debugging purposes.