diff --git a/arealite/README.md b/arealite/README.md index 4a23ce6..d0cae4e 100644 --- a/arealite/README.md +++ b/arealite/README.md @@ -1,52 +1,174 @@ -# AReaL v1.0.0 Design Doc +# AReaLite Design Doc -______________________________________________________________________ +## TL;DR -SFT example: +Follow our [step-by-step code walk-through](../docs/arealite/gsm8k_grpo.md) to +immediately get started with AReaLite! -```bash -torchrun --nnodes 1 --nproc-per-node 8 examples/arealite/gsm8k_sft.py --config examples/arealite/configs/gsm8k_sft.yaml +## Motivation + +AReaL presents several challenges that make it difficult for AI researchers to adopt, +understand, and develop with effectively. The primary issue stems from its +*system-centric* rather than *AI-centric* architecture. The reinforcement learning +algorithm workflow is built around multiple *workers* executing consecutive *model +function calls* — concepts that are unfamiliar to most AI researchers. This forces users +to first master these system-level abstractions before they can implement workflows and +algorithms for their specific research needs. + +Beyond architectural concerns, AReaL suffers from accumulated technical debt. The +codebase contains substantial legacy code inherited from previous projects that no +longer serves a purpose but significantly increases complexity for both users and +developers. Even experienced core developers sometimes struggle with debugging due to +this accumulated complexity. + +The landscape of RL workflow development tools has matured considerably, making it +possible to achieve comparable efficiency with significantly fewer lines of code. This +presents an ideal opportunity to redesign the API and distill the massive codebase into +something clean and maintainable. Rather than pursuing maximum efficiency, our goal is +to deliver 90% of AReaL's functionality while dramatically reducing code complexity and +user burden. This philosophy drives AReaLite — the lightweight version of AReaL. + +AReaLite serves as the first phase in AReaL's broader refactoring initiative. It +functions both as a standalone training library with intuitive interfaces and as the +foundation for AReaL's future core API definitions. The plan is to transform AReaL's +current worker-based architecture into an AI-centric architecture similar to AReaLite, +where AReaL will **extend** AReaLite's APIs and implementations to support additional +backends for efficient large-scale training. + +## Design Principles + +Our design is guided by seven core principles: + +1. **Native asynchronous RL training support** — Built from the ground up for + disentangled generation and training +1. **AI-centric design** — Minimize exposure to system concepts like "PlacementGroup" +1. **PyTorch-centric approach** — Use raw PyTorch types without unnecessary abstractions +1. **Transparent algorithm orchestration** — Make the flow of operations clear and + understandable +1. **Developer-friendly navigation** — Enable easy access to implementation details via + Ctrl+click in IDEs +1. **Ecosystem compatibility** — Integrate smoothly with existing ML/RL tools +1. 
**Single-file customization** — Allow RL pipeline modifications within a single file + +## Architecture + +### Core Directory Structure + +``` +arealite/ +├── api/ # Abstract interfaces and dataclasses +├── engine/ # Training and inference engines +├── launcher/ # Launcher for different backends +├── tests/ # Standalone test scripts +└── workflow/ # Custom RL rollout workflows ``` -GRPO example: +### Component Overview + +#### 1. API Layer (`api/`) + +The API layer establishes clean contracts between components through abstract interfaces +and data classes: + +- **`engine_api.py`**: Defines `TrainEngine` for SPMD-based distributed training + backends and `InferenceEngine` for streaming LLM inference +- **`workflow.py`**: Defines `RolloutWorkflow` for RL data collection within a unified + method interface +- **`cli_args.py`**: Contains configuration dataclasses for all system components + +The workflow object invokes `InferenceEngine` to complete data collection following +customized patterns, providing flexibility while maintaining consistency. + +#### 2. Backend Layer (`engine/`) + +The backend layer provides adapters for third-party libraries, ensuring they conform to +the APIs defined in `engine_api.py`. These components deliver core inference and +training capabilities: + +- **`fsdp_engine.py`**: FSDP-based training engine utilizing PyTorch FSDP2 +- **`sglang_remote.py`**: Client interface for generation with remote SGLang servers + +#### 3. Customization Layer (`engine/ppo/`, `workflow/`) + +This layer applies backend capabilities to implement specific reinforcement learning +pipelines: + +- **`engine/ppo/actor.py`**: PPO algorithm implementation that leverages a `TrainEngine` +- **`workflow/rlvr.py`**: RLVR workflow that utilizes an `InferenceEngine` to sample + multiple responses for each prompt + +#### 4. Entry Point Layer (`examples/arealite/`) + +The entry point layer composes customization layer implementations to create complete RL +training pipelines. While we provide several reference examples, users have complete +freedom to adapt these to their specific use cases. + +Entry points can be launched using launchers from `arealite/launcher/`, similar to other +distributed launch tools like `torchrun`: ```bash -python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py --config examples/arealite/configs/gsm8k_grpo.yaml +python3 -m arealite.launcher.ray entrypoint.py --config my-config.yaml ``` -______________________________________________________________________ +## Usage Examples -We will provide both single-controller and SPMD user interfaces. The SPMD interface will -be delivered with AReaLite, which is the paradigm most users are familiar with, just -like using `torchrun` or `deepspeed`. However, this paradigm may lack some flexibility -over global scheduling and control. To unlock the full potential with customized -distributed execution, we will also provide a single-controller mode just like using Ray ---- but our scheduler backend will not be restricted to Ray. Our code will be able to -run with any scheduler in the cluster, such as native SLURM and K8S. +### Basic RL Training -However, we want the user code to stay the same for both modes. 
The following is a -simple usage example: +Users must provide a YAML configuration file, though they can override configuration +parameters for hyperparameter searches or other experimental needs: + +```bash +# Launch with Ray launcher: 4 nodes (4 GPUs each), 3 nodes for generation, 1 node for training +python3 -m arealite.launcher.ray examples/arealite/gsm8k_grpo.py \ + --config examples/arealite/configs/gsm8k_grpo.yaml \ + experiment_name= \ + trial_name= \ + allocation_mode=sglang.d12p1t1+d4p1t1 \ + cluster.n_nodes=4 \ + cluster.n_gpus_per_node=4 + +# Launch with Slurm launcher: 16 nodes (8 GPUs each), 12 nodes for generation, 4 nodes for training +python3 -m arealite.launcher.slurm examples/arealite/gsm8k_grpo.py \ + --config examples/arealite/configs/gsm8k_grpo.yaml \ + experiment_name= \ + trial_name= \ + allocation_mode=sglang.d96p1t1+d32p1t1 \ + cluster.n_nodes=16 \ + cluster.n_gpus_per_node=8 +``` + +## Customization Guide + +For detailed customization instructions, please refer to our documentation: + +- [Adding new agents](../docs/customization/agent.md) +- [Adding new datasets](../docs/customization/dataset.md) +- [Adding new algorithms](../docs/customization/algorithm.md) + +## Implementation Details + +### Entry Point Design Philosophy + +We considered two primary design patterns for entry points, each with distinct +trade-offs: + +#### Single-Controller Pattern + +The most modular approach uses a single-controller pattern where only one process in the +cluster executes the main coordination logic: ```python -def get_current_time(): - ... - def my_reward_fn(prompt, completion, prompt_ids, completion_ids, **kwargs): return len(completion_ids) class MyRolloutWorkflow: - def __init__(self, config: Any): - self.config = config - self.env = LocalToolingEnv() - self.env.register_tool(get_current_time) - async def arun_episode(self, engine: InferenceEngine, data: Dict[str, Any]) -> Dict[str, Tensor]: - ... message = [ {"role": "system", "message": ...}, {"role": "user", "message": ...}, ] + for _ in range(self.config.num_turns): text = tokenizer.apply_chat_template(message, tools=self.env.list_tools()) req = LLMRequest(text=text, ...) @@ -54,463 +176,335 @@ class MyRolloutWorkflow: tool_name, tool_args = parse_tool(resp) cur_time = await self.env.aexecute(tool_name, tool_args) message += [{"role": "user", "message": f"The current time is {cur_time}"}] + reward = my_reward_fn(None, None, None, req.input_ids, **data) - ... 
return output def main_grpo(): - dataset = load_dataset("openai/gsm8k", split="train") + config, _ = load_expr_config(args, GRPOConfig) - rollout_config, training_config = load_expr_config(sys.argv[1:]) + # Create rollout workflow + workflow = MyRolloutWorkflow() # Single-controller mode initialization scheduler = SlurmScheduler() - rollout = RolloutController( - SGLangEngine(rollout_config.engine), - rollout_config.controller, - scheduler, - ) - actor = TrainController( - MegatronGRPOActor(training_config.actor), - config.training_controller_config, - scheduler, - ) - ref = TrainController( - MegatronGRPOActor(training_config.ref), - config.training_controller_config, - scheduler, - ) - # SPMD mode initialization - # rollout = RemoteSGLangEngine(rollout_config.engine) - # actor = MegatronGRPOActor(training_config.actor) - # ref = MegatronGRPOActor(training_config.ref) + rollout = RolloutController(SGLangEngine(config.rollout), scheduler) + actor = TrainController(MegatronGRPOActor(config.actor), scheduler) - rollout.initialize() - actor.initialize() - ref.initialize() - - # Synchronous RL + # Training loop dataloader = StatefulDataloader(dataset) - for epoch in range(config.epoches): - data_generator = iter(dataloader) - for prompt in range(steps_per_epoch): - prompt = next(data_generator) + for _ in range(max_steps): + # Collect trajectories using rollout workflow + batch = rollout.rollout_batch(next(dataloader), workflow=workflow) + batch: DistributedBatch - # Update inference engine weights - future = rollout.update_weights(wcfg) - actor.upload_weights(wcfg) - future.result() + # Prepare training inputs + adv_batch = actor.compute_advantages_and_returns(batch) + batch['advantages'] = adv_batch['advantages'] - # synchronous rollout - rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow)) - # or asynchronous rollout with filtering and off-policyness control - # rollout_batch = rollout.prepare_batch(batch, - # workflow=MyRolloutWorkflow(rollout_config.workflow), - # should_accept=lambda x: x['rewards'].mean() > 0) + # Execute PPO update + stats = actor.ppo_update(batch) - # In the single-controller mode - rollout_batch: DistributedBatch - x: TensorDict = rollout_batch.load_data() - # In the SPMD mode - # rollout_batch: TensorDict - - batch['input_ids'] = rollout_batch['input_ids'] - batch['rewards'] = rollout_batch['rewards'] - - # prepare train inputs - batch['ref_logp'] = ref.compute_logp(batch) - adv_batch = actor.compute_advantages_and_returns(batch) - batch['advantages'] = adv_batch['advantages'] - - # PPO update - stats = actor.ppo_update(batch) - print(stats) - -if __name__ == "__main__": - main_grpo() + # Update inference engine weights (non-blocking to prevent NCCL blocking) + future = rollout.update_weights(wcfg) + actor.upload_weights(wcfg) + future.result() ``` -The launch commands will be: +**Advantages:** -```bash -# Single-controller mode -python3 main_grpo.py --config my_config.yaml rollout.workflow.x=1 -# SPMD mode -CUDA_VISIBLE_DEVICES=0,1 nohup python3 -m sglang.launch_server \ - --seed 1 --host x.x.x.x --port 7777 --dp_size 2 > server.out 2>&1 & -CUDA_VISIBLE_DEVICES=3,4 \ - torchrun --nnodes 1 --nproc-per-node 2 \ - main_grpo.py --config my_config.yaml \ - rollout.workflow.x=1 \ - rollout.engine.addresses="\[x.x.x.x, y.y.y.y\]" -``` +- Provides maximum flexibility for device allocation, scheduling, and data arrangement -## Core API +**Disadvantages:** -- A specific algorithm must use these core components. 
-- Concrete implementations must follow the API definition. +- Introduces multiple abstractions (`TrainController`, `Scheduler`, `DistributedBatch`) + that complicate the script -### TrainEngine +#### SPMD Pattern -TrainEngine is a thin wrapper around existing training frameworks (FSDP, Megatron), -providing a unified interface for RL algorithms for computation, parameter saving and -loading, and providing a unified weight update interface for inference engines. +Given that AI researchers are more familiar with the SPMD (Single Program, Multiple +Data) pattern used in standard model training, we also provide an entry point following +this approach. With N GPUs dedicated to training, N processes execute the following +code: ```python -############################################# -@dataclass -class WeightUpdateMeta: - type: str - path: str | None - alloc_mode: AllocationMode | None +def main_grpo(): + config, _ = load_expr_config(args, GRPOConfig) -@dataclass -class SaveLoadMeta: - path: str - weight_format: str - with_optim: bool - tokenizer: PreTrainedTokenizerFast | None - base_model_path: str | None + # Create rollout workflow + workflow = MyRolloutWorkflow() + # SPMD mode initialization + rollout = RemoteSGLangEngine(config.rollout) + actor = MegatronGRPOActor(config.actor) + + # Training loop + dataloader = StatefulDataloader(dataset) + for _ in range(max_steps): + # Data collection (only on data parallel head) + if is_dp_head: + batch = rollout.rollout_batch(next(dataloader), workflow=workflow) + batch_list = [batch] + else: + batch_list = [None] + + # Broadcast data to all processes + batch = dist.broadcast(batch_list, src=0, group=model_parallel_group)[0] + batch: TensorDict + + # Prepare training inputs + adv_batch = actor.compute_advantages_and_returns(batch) + batch['advantages'] = adv_batch['advantages'] + + # Execute PPO update + stats = actor.ppo_update(batch) + + # Update weights (coordinated across processes) + if rank == 0: + future = rollout.update_weights(wcfg) + actor.upload_weights(wcfg) + if rank == 0: + future.result() +``` + +The SPMD pattern uses only concepts familiar to AI researchers, though it requires some +control flow branching based on parallelism strategy. + +### Training Engine Architecture + +The training engine operates at two abstraction levels to balance flexibility with ease +of use. + +#### Basic Level: Backend Adapters + +The foundational level provides unified interfaces for RL algorithms, handling +computation, parameter management, and weight updates for inference engines. 
Each RL +training experiment must use one of the implemented backends: + +```python class TrainEngine(abc.ABC): - # control api - def __init__(self, para_config) - self.para_config = para_config - def initialize(self, addr: str|None, ft_spec|None): - # Initialize distributed environment, initialize and load model - # addr is the corresponding service address when deploying sglang or fsdp/megatron remotely - # The controller passes the master addr when calling - pass - - def get_scheduling_config(self): - # Get the resource configuration information required by the scheduler to schedule the engine, - # such as the engine's image, cpu/gpu/memory size - pass + def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None): + """Initialize distributed training environment and load models.""" + raise NotImplementedError() def destroy(self): - """Destroy the engine and release GPU memory.""" + """Clean up engine resources and release GPU memory.""" pass - async def upload_weights(self, meta: WeightUpdateMeta): - pass + def upload_weights(self, meta: WeightUpdateMeta): + """Upload weights to inference engine (blocking operation).""" + raise NotImplementedError() def save(self, meta: SaveLoadMeta): - pass + """Save model weights and optimizer states for checkpointing.""" + raise NotImplementedError() def load(self, meta: SaveLoadMeta): - pass - - # data api - def step_lr_scheduler(self): - """Step learning rate scheduler.""" - # Due to PPO minibatch updates, multiple train batches may need to be called - # before calling step_lr_scheduler once, so this api needs to be separated + """Load model weights and optimizer states from checkpoint.""" raise NotImplementedError() def train_batch( self, - input_: Dict, - loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], - loss_weight_fn: Callable[[Dict], float], + input_: TensorDict, + loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor], + loss_weight_fn: Callable[[TensorDict], float], ) -> Dict[str, float]: - pass + """Update model parameters using provided batch and loss function.""" + raise NotImplementedError() @torch.no_grad() - def eval_batch( - self, - input_: Dict, - loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], - loss_weight_fn: Callable[[Dict], float], - ) -> torch.Tensor | None: - pass - - @torch.no_grad() def forward( self, - input_: Dict, - output_seqlens: List[List[int]] | None = None, - post_hook: Callable[[torch.Tensor, Dict], Any] | None = None, + input_: TensorDict, + output_seqlens: List[int] | None = None, + post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None, aggregate_fn: Callable[[List[Any]], Any] = torch.cat, ) -> Any | None: - pass -############################################# + """Execute gradient-free forward pass for inference.""" + raise NotImplementedError() +``` -# Implementation example -class FSDPEngine(TrainEngine): - def __init__(self, config: EngineConfig): +#### Algorithm Level: Extended Engines + +Extended engines like PPO Actor provide algorithm-specific organization and +computational interfaces. They leverage backend core methods (such as `forward`) to +generate algorithm-required tensors and execute specialized model updates. The produced +objects (e.g., `FSDPPPOActor`) are also instances of `TrainEngine`s, but affiliated with +methods that specially designed for the algorithm (e.g., `ppo_update`). 
+ +```python +class PPOActor: + + def __init__(self, config: PPOActorConfig, engine: TrainEngine): self.config = config - self.weight_update_group_initialized = False + self.engine = engine - def initialize(self, addr: str|None, ft_spec: FinetuneSpec): - self.model_config = AutoConfig.from_pretrained(self.config.path) - with torch.device("meta"): - model = AutoModelForCausalLM.from_config(self.model_config) - self.model = FSDP(model) - self.optimizer = Adam(self.model.parameters()) + @torch.no_grad() + def compute_logp( + self, + data: TensorDict, + temperature: Optional[float] = None, + ) -> torch.Tensor | None: + + def calc_logprobs(logits, input_data): + labels = torch.roll(input_data["input_ids"], shifts=-1, dims=-1) + logprobs = gather_logprobs(logits, labels, temperature or 1.0) + return logprobs + + self.engine.eval() + return self.engine.forward( + input_=data, + post_hook=calc_logprobs, + aggregate_fn=lambda xs: torch.cat(xs, dim=-1), + ) + + def compute_advantages(self, data: TensorDict) -> None: + """Compute advantages for PPO training.""" + # Implementation details... + pass + + def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]: + """Execute PPO policy update.""" + # Implementation details... + pass + +class FSDPPPOActor(FSDPEngine): + """FSDP-backed PPO Actor implementation.""" + + def __init__(self, config: PPOActorConfig): + super().__init__(config) + self.actor = PPOActor(config, self) + + @torch.no_grad() + def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: + return self.actor.compute_logp(*args, **kwargs) + + @torch.no_grad() + def compute_advantages(self, *args, **kwargs) -> None: + self.actor.compute_advantages(*args, **kwargs) + + def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: + return self.actor.ppo_update(*args, **kwargs) +``` + +### Inference Engine Design + +The inference engine's core functionality revolves around `generate` and +`update_weights` methods. These methods can interface with HTTP server APIs or invoke +local LLM engines: + +```python +class InferenceEngine(abc.ABC): + + def initialize(self, addr: str | None, ft_spec): + """Initialize distributed inference environment and load models.""" + raise NotImplementedError() def destroy(self): - del self.optimizer - del self.model - gc.collect() - torch.cuda.empty_cache() - gc.collect() - - async def upload_weights(self, meta: WeightUpdateMeta): - if meta.type == 'nccl': - if not self.weight_update_group_initialized: - await self.init_distributed_weight_update(meta) - return await self.aupdate_weights_from_distributed() - if meta.type == 'disk': - self.save_to_hf(meta.path) - return - - def save(self, meta): - if meta.weight_format == 'hf': - self.save_to_hf(meta.path, meta.tokenizer, meta.base_model_path) - elif meta.weight_format == 'dcp': - self.save_model_to_dcp(meta.path) - - if meta.with_optim: - self.save_optimizer_state(meta.path) - - def load(self, meta): - ... 
- - - ############# Helper methods start ############# - - def load_from_hf(self, path): - sd = load_hf_state_dict(path) - load_fsdp_state_dict(self.model, full_sd=sd) - - def save_to_hf(self, path, - tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None, - base_model_path: Optional[str] = None, - ): - if dist.rank() == 0: - sd = {} - for n, p in self.model.named_parameters(): - sd[n] = p.data.gather() - torch.save(sd, path) - if tokenizer is not None: - tokenizer.save_pretrained(path) - if base_model_path is not None: - copy_hf_configs(base_model_path, path) - dist.barrier() - - def save_for_recover(self, path: str): - self.save_model_to_dcp(path) - self.save_optimizer_state(path) - - def load_from_recover(self, path): - self.load_model_from_dcp(path) - self.load_optimizer_state(path) - - async def init_distributed_weight_update(self, meta: WeightUpdateMeta): - # Initialize NCCL communication group for weight updates - ... - - async def aupdate_weights_from_distributed(self): - # Update inference weights through NCCL - # Different engines (FSDP, Megatron) may have different weight aggregation, - # splitting and communication methods - # Keep this high-level interface instead of defining subdivided interfaces - # to facilitate different engines implementing the most efficient weight communication path - ... - - def load_model_from_dcp(self, path: str): - # Load pytorch distributed checkpoint model from disk during recovery - ... - - def save_model_to_dcp(self, path: str): - # Save model in dcp format for recovery - ... - - def save_optimizer_state(self, path: str): - # Save optimizer state for recovery - raise NotImplementedError() - - def load_optimizer_state(self, path: str): - # Load optimizer state during recovery - raise NotImplementedError() -``` - -### Algorithm-Specific TrainEngine API - -Extended engines (such as Actor in PPO) provide convenient organization and calling of -computational interfaces specific to algorithms. These computational interfaces maintain -single-process computational logic, but can be called by controllers in top-level -training scripts to complete distributed semantic computational orchestration. - -```python -class Actor(Engine): - - @torch.no_grad() - def compute_logps(self, input_: Dict[str, Tensor]) -> torch.Tensor: - ... # unpad - logps = self.forward(xxx) - ... # pad back - return logps - - def compute_advantages_and_returns(self, input_: Dict) -> Dict: - pass - - def ppo_update(self, input_: Dict) -> List[Dict[str, float]]: - ... - all_stats = [] - for _ in range(self.ppo_n_minibatches): - stats = self.train_batch(xxx, loss_fn=actor_loss_fn) - all_stats.append(stats) - return all_stats - -class Critic(Engine): - - @torch.no_grad() - def compute_values(self, input_: Dict) -> torch.Tensor: - pass - - def ppo_update(self, input_: Dict) -> List[Dict[str, float]]: - ... - all_stats = [] - for _ in range(self.ppo_n_minibatches): - stats = self.engine.train_batch(xxx, loss_fn=critic_loss_fn) - all_stats.append(stats) - return all_stats - -class FSDPActor(FSDPEngine, Actor): - pass - -class MegatronActor(FSDPEngine, Actor): - pass - -class FSDPCritic(MegatronEngine, Critic): - pass - -class MegatronCritic(MegatronEngine, Critic): - pass -``` - -### Inference Engine API - -Define the InferenceEngine API in a local-like mode, rather than a client-server -separated form, mainly for user-centered convenience in using the InferenceEngine as a -tool. 
- -InferenceEngine can internally start a SGLang subprocess (SGLangEngine), or call a -remotely deployed service (RemoteSGLangEngine). - -```python -class InferenceEngine(ABC): - def __init__(self, config) - self.config = config - - @abstractmethod - def initialize(self, addr: str|None, ft_spec): - """Start SGLang Engine, starting the engine will call model loading by default - """ - self.tasks = [] - - async def update_weights(self, meta) -> None: - # Update model weights based on meta information + """Clean up engine resources and release GPU memory.""" pass async def agenerate(self, req: LLMRequest) -> LLMResponse: - # Given a prompt, generate a response with LLM - pass + """Generate response asynchronously for the given request.""" + raise NotImplementedError() - def submit(self, data: Dict[str, Any], workflow): - """ - Asynchronously submit rollout request - """ - task = asyncio.create_task(workflow.arun_episode(self, data)) - self.tasks.append(task) + def update_weights(self, meta: WeightUpdateMeta) -> Future: + """Update inference engine weights asynchronously.""" + raise NotImplementedError() +``` + +#### Workflow Integration + +User-defined rollout workflows utilize the `InferenceEngine` to generate trajectories. +The workflow's `arun_episode` method produces one or more trajectories from a single +prompt. The generation process is streaming rather than batched, with each dataset item +processed independently. Here's a simplified RLVR example: + +```python +class RLVRWorkflow(RolloutWorkflow): + async def arun_episode(self, engine: InferenceEngine, data): + input_ids = self.tokenizer.apply_chat_template( + data["messages"], + tokenize=True, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + + n_samples = self.gconfig.n_samples + req = LLMRequest( + rid=uuid.uuid4().hex, + input_ids=input_ids, + gconfig=self.gconfig.new(n_samples=1), + ) + + # Generate multiple responses concurrently + resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)]) - async def _wait_async(self, count: int, timeout: int) -> DistributedBatch: - tik = time.time() - n = 0 results = [] - while time.time() - tik < timeout and n < count: - done, _ = await asyncio.wait(self.tasks, return_when=FIRST_COMPLETED) - for task in done: - results.append(await task) - if n < count: - raise TimeoutError() - return DistributedBatch(results) + for resp in resps: + reward = self.reward_fn( + prompt=prompt_str, + completions=completions_str, + prompt_ids=resp.input_tokens, + completion_ids=resp.output_tokens, + **data, + ) - def wait(self, count: int, timeout: int) -> DistributedBatch: - """ - Synchronous wait interface, until the request returns count records - """ - return asyncio.run(self._wait_async(count, timeout)) + results.append(TensorDict(res, batch_size=[1])) - @abstractmethod - def rollout(self, data: List[Dict[str, Any]], workflow) -> DistributedBatch: - """ - Synchronously submit rollout request, until all inference requests are completed and returned - """ - pass - -################################### -# Implementation example -class SGLangEngine(InferenceEngine): - - def __init__(self, config: InfEngineConfig): - self.config = config - self.weight_update_group_initialized = False - - async def update_weights(self, meta) -> None: - if meta.type == 'nccl': - if not self.weight_update_group_initialized: - await self.init_distributed_weight_update(meta) - return await self.aupdate_weights_from_distributed() - if meta.type == 'disk': - 
self.update_weights_from_disk(meta.path) - return - - async def agenerate(self, req: LLMRequest) -> LLMResponse: - # Given a prompt, generate a response with LLM - return await self.llm.generate_async(xxx) - - # Weight update - @abstractmethod - def update_weights_from_disk(self, path) -> None: - """Update model weights from disk""" - ... - - @abstractmethod - async def ainit_distributed_weights_update(self, meta_info: WeightUpdateMeta): - # Initialize **all** needed weight synchronization communication groups and communication plans - # (which communication type and which parameters to communicate at each step) - # Depending on the engine's partitioning method, multiple communication groups may be initialized - # for weight synchronization - # Since both inference and training engines need to enter this function, - # it needs to be defined as an async function - ... - - @abstractmethod - async def aupdate_weights_from_distributed(self) -> None: - """Use the initialized weight communication plan and communication groups to update model weights with NCCL - - Since both inference and training engines need to enter this function, - it needs to be defined as an async function - """ - ... - - @abstractmethod - def check_health(self) -> bool: - """Check server health status - - Returns: - bool: Whether the server is healthy - """ - pass + return concat_padded_tensors(results) ``` -### RolloutWorkflow +#### Batch Processing and Asynchronous Operations -RolloutWorkflow is a class that provides the arun_episode method. This method has a -fixed signature, used to collect one agent trajectory. +While individual trajectory collection is straightforward, batching and asynchronous +execution require additional infrastructure. The `InferenceEngine` provides high-level +methods: `submit`, `wait`, `rollout_batch`, and `prepare_batch`. + +The `rollout_batch` method submits multiple `workflow.arun_episode` jobs to an +asynchronous thread pool using `submit`, then waits for completion using `wait`. The +`prepare_batch` method separates submission and waiting to enable asynchronous rollout: + +```python +def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None: + try: + self.input_queue.put_nowait((data, workflow)) + except Full: + raise RuntimeError("Input queue full. Consider increasing queue_size.") + +def wait( + self, + count: int, + timeout: float | None = None, + should_accept: Callable | None = None, +) -> TensorDict: + """Wait for specified number of results with optional filtering.""" + # Implementation details... + pass + +def rollout_batch( + self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow" +) -> TensorDict: + """Submit batch requests and wait for all results.""" + for item in data: + self.submit(item, workflow) + return self.wait(count=len(data)) + +def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow: "RolloutWorkflow", +): + """Prepare batch for asynchronous processing.""" + # Implementation details... + pass +``` + +### RolloutWorkflow Interface + +The `RolloutWorkflow` class provides the `arun_episode` method with a standardized +signature for collecting agent trajectories: ```python class MyRolloutWorkflow: @@ -521,263 +515,21 @@ class MyRolloutWorkflow: async def arun_episode(self, engine: InferenceEngine, data: Dict[str, Any]) -> Dict[str, Tensor]: - ... req = LLMRequest(input_ids=data['input_ids'], ...) 
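        # Multi-turn loop: each iteration generates a completion, executes the
        # tool call it contains, and appends the tool output to the request.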
+ for _ in range(self.config.num_turns): resp = await engine.agenerate(req) res = await self.tool_executor.aexecute_tool(resp.completion) req.input_ids += res + reward = my_reward_fn(None, None, None, req.input_ids, **data) - ... return output ``` -### RolloutController & TrainController +### Controller Architecture -They have the same API as `InferenceEngine` and `TrainEngine`, respectively. - -## Other Components - -1. Algorithm workflows don't necessarily use these components; other replaceable - components such as implementations in rllm or verl-agent can also be used -1. Mainly for internal implementation and division of labor, as a thin wrapper to - facilitate adaptation of external packages - -### Environment - -1. Support multi-tool management and unified execution interface -1. Support local calling and mcp service calling -1. "Tools" belong to an instance rather than a class, register_tool is defined as a - method rather than a static method, this is (1) to prevent tools from subclasses - being registered to the base class, causing potential naming or calling conflicts; - (2) to support multiple tasks for the same service (e.g., browser), with each task - having a different toolset - -```python -from abc import ABC, abstractmethod -from typing import Any, Dict, List - -class ToolingEnv: - def __init__(self): - self.is_initialized = False - - self.tool_registry: Dict[str, Callable] = {} - self.tool_schemas: List[Dict[str, Any]] = [] - - async def ainitialize(self): - """ - Performs the initialization logic for the environment. - For stateful environments, this is where resources are created and - prepared (e.g., launching a browser). - """ - pass - - def list_tools(self) -> List[Dict[str, Any]]: - pass - - async def aexecute(self, tool_name: str, tool_args: Dict[str, Any]) -> Any: - pass - - async def aclose(self): - """ - Destroys the environment, releasing all held resources. - This method is critical for stateful environments (e.g., a browser session). - """ - pass - -class MCPToolingEnv(ToolingEnv): - def __init__(self, config: MCPToolingEnvConfig): - self.config = config - # init a mcp client - self.mcp_client = mcp.session_client(config.url) - self.tool_registry: Dict[str, Callable] = mcp.session_client.list_tools() - self.tool_schemas: List[Dict[str, Any]] = [] - - def list_tools(self) -> List[Dict[str, Any]]: - return self.tool_schemas - - def aexecute(self, tool_name: str, tool_args: Dict[str, Any]): - pass - - -class LocalToolingEnv(ToolingEnv): - - @staticmethod - def generate_schema(func: Callable) -> Dict[str, Any]: - """ - Generates a JSON schema for a function using introspection. - """ - # Use the function's docstring as the tool's description. - description = inspect.getdoc(func) or "No description provided." - sig = inspect.signature(func) - parameters = { - "type": "object", - "properties": {}, - "required": [], - } - - # Mapping from Python types to JSON Schema types - type_mapping = { - str: "string", - int: "integer", - float: "number", - bool: "boolean", - } - - for name, param in sig.parameters.items(): - # Default to string type if type hint is missing or complex - param_type = type_mapping.get(param.annotation, "string") - parameters["properties"][name] = {"type": param_type} - - # If a parameter has no default value, it is considered required. 
- if param.default is inspect.Parameter.empty: - parameters["required"].append(name) - - return { - "type": "function", - "function": { - "name": func.__name__, - "description": description, - "parameters": parameters, - } - } - - def register_tool(self, func: Callable) -> Callable: - """ - A decorator that registers a Python function as a tool in this environment. - """ - if not callable(func): - raise TypeError("The provided object must be a callable function.") - - tool_name = func.__name__ - if tool_name in self.tool_registry: - raise ValueError(f"Tool with name '{tool_name}' is already registered.") - - # Add the function to the registry and its schema to the schema list. - self.tool_registry[tool_name] = func - self.tool_schemas.append(self.generate_schema(func)) - - def list_tools(self) -> List[Dict[str, Any]]: - """ - Lists all available tools provided by this environment and their descriptions. - - Returns: - A list of dictionaries, where each dictionary describes a tool. - """ - return self.tool_schemas - - async def aexecute(self, tool_name: str, tool_args: Dict[str, Any]) -> Any: - """ - Executes a specified tool. - - Args: - tool_name (str): The name of the tool to execute. - tool_args (Dict[str, Any]): The arguments required to call the tool. - - Returns: - Any: The result of the tool's execution, typically a string or - structured JSON. - """ - if tool_name not in self._tool_registry: - return f"Error: Tool '{tool_name}' is not registered." - - tool_func = self._tool_registry[tool_name] - - try: - result = tool_func(**tool_args) - return result - except TypeError as e: - # This exception is often triggered by missing or incorrect argument types. - return f"Error executing '{tool_name}': Invalid arguments. Details: {e}" - except Exception as e: - return f"Error executing '{tool_name}': An unexpected error occurred. Details: {e}" - -``` - -### Reward - -1. Workflows for computing rewards using models and computing rewards based on rules - should be separated -1. Rule-based reward computation is defined as a function with a predefined signature, - which can be local or remote, and is generally wrapped in the rollout workflow; - user-written reward functions can also not use this signature as long as they provide - a workflow -1. Computing rewards using models requires users to initialize a controller/engine - themselves and explicitly call it in the algorithm workflow - -```python - -############ Rule-based reward ############ - -# The signature is just a reference, can be defined arbitrarily -def reward_fn(prompt: str, completion: str, prompt_ids: List[int], - completion_ids: List[int], **kwargs): - # prompt: prompt string (the task this data needs to complete) - # completion: trajectory string generated by the model based on the task - # prompt_ids: token ids of the prompt - # completion_ids: token ids of the trajectory generated by the model - # kwargs: all other attributes of this data in the dataset, - # for example, solutions, input_outputs, etc. - pass - -############ Model-based reward ############ - -reward = TrainController(Critic()) -rollout_controller = RolloutController(...) -for _ in range(epochs): - for _ in range(steps_per_epoch): - data = rollout_controller.rollout_batch(prompt) - data['reward'] = reward.compute_values(data) - ... -``` - -### Dataset - -Use huggingface datasets and pytorch torchdata. In Single-Controller mode, only one -process per experiment is responsible for data loading. 
- -```python -from datasets import Dataset, load_dataset -from datasets.distributed import split_dataset_by_node - -dataset = load_dataset( - path, - name=name, - split=split, - data_files=data_files, -) -dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) - -# Access the first data item -data: Dict = dataset[0] - -# Access the "prompt" column -data: List = data['prompt'] - -# Data processing -def process_example(example, idx): - # Add query_id column - example["query_id"] = str(idx) - example["prompt"] = example["question"] - - # used by the reward function - example["method"] = reward_mode - return example - -dataset = dataset.map( - lambda example, idx: process_example(example, idx), - with_indices=True, -) - -# Data loading and shuffle -from torchdata.stateful_dataloader import StatefulDataLoader -dataloader = StatefulDataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=True, - drop_last=True, - collate_fn=lambda x: x, # Can change the data batch packing method by changing this parameter -) -for data in dataloader: - assert isinstance(data, list) -``` +`RolloutController` and `TrainController` mirror the APIs of `InferenceEngine` and +`TrainEngine` respectively. Controllers handle engine deployment across the cluster and +manage data distribution, invoking engine methods through remote procedure calls (RPCs). +This architecture enables distributed operation while maintaining familiar interfaces +for users. diff --git a/blog/AReaL_v0_2.md b/blog/AReaL_v0_2.md index aa03561..f6d2cea 100644 --- a/blog/AReaL_v0_2.md +++ b/blog/AReaL_v0_2.md @@ -1,5 +1,5 @@

-AReaL v0.2: Training a SOTA 7B LRM with 1.5x Throughput Improvment +AReaL v0.2: Training a SOTA 7B LRM with 1.5x Throughput Improvement

diff --git a/blog/AReaL_v0_3.md b/blog/AReaL_v0_3.md index aa9e662..0222343 100644 --- a/blog/AReaL_v0_3.md +++ b/blog/AReaL_v0_3.md @@ -1,4 +1,6 @@ -# AReaL v0.3 +

+AReaL v0.3: SOTA Coding Models with 2.77x Faster Asynchronous RL Training +

## Introduction diff --git a/docs/_toc.yml b/docs/_toc.yml index 30c9774..8ecdd01 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -7,8 +7,8 @@ parts: - caption: Tutorial chapters: - file: tutorial/installation - - file: tutorial/quickstart_arealite - file: tutorial/quickstart + - file: tutorial/quickstart_legacy - file: tutorial/eval - file: tutorial/troubleshooting - caption: Getting Started with AReaLite diff --git a/docs/arealite/gsm8k_grpo.md b/docs/arealite/gsm8k_grpo.md index 9fdfc2d..26125d3 100644 --- a/docs/arealite/gsm8k_grpo.md +++ b/docs/arealite/gsm8k_grpo.md @@ -1 +1,497 @@ -# Running GRPO on GSM8K Dataset \ No newline at end of file +# Running GRPO on GSM8K Dataset + +This guide introduces how AReaLite runs the GRPO algorithm on the GSM8K dataset, using +the training script +[examples/arealite/gsm8k_grpo.py](../../examples/arealite/gsm8k_grpo.py) and +configuration file +[examples/arealite/configs/gsm8k_grpo.yaml](../../examples/arealite/configs/gsm8k_grpo.yaml). + +## How AReaLite Works + +The following figure illustrates the launching and one asynchronous training step of the +GRPO algorithm on the GSM8K dataset on AReaLite. Compared with the old AReaL +implementation, AReaLite runs inference servers and a SPMD training script instead of a +bunch of various workers. In a training step, AReaLite: + +1. Submits prompts from the dataset to `RemoteSGLangEngine`, who runs `RLVRWorkflow` in + a streaming manner. +1. Completes `RLVRWorkflow` by interacting with remote `SGLangServer` instances to + generate sequences, and computing rewards with the reward function. +1. Once there are enough outputs from `RLVRWorkflow`, aggregates them into a data batch + for algorithm-specific training engine `FSDPPPOActor`. +1. Computes losses and update weights in `FSDPPPOActor`. +1. Transfers the updated weights to remote `SGLangServer` instances. + +![arealite-gsm8k-example](gsm8k_grpo.png) + +In the following sections, we will walk you through the code to explain concepts and +show you how these steps are done in details. + +## Launching the Experiment + +As shown in [Quickstart Guide](../tutorial/quickstart.md), experiments in AReaLite are +launched using standalone launchers with the following commands: + +``` +# Local Launcher +python -m arealite.launcher.local --config +# Ray Launcher +python -m arealite.launcher.ray --config +# Slurm Launcher +python -m arealite.launcher.slurm --config +``` + +In AReaLite: + +- The **training script** is an SPMD python script that serves as the experiment entry + point. +- The launcher runs the training script with its distributed backend (`subprocess` for + `LocalLauncher`, `ray.remote` for `RayLauncher`, `srun` for `SlurmLauncher`). +- The launcher also manages inference servers (currently only supporting + `SGLangServer`). The number and parallelization strategies (e.g. tensor parallel) are + determined by the option [allocation_mode](../../arealite/api/cli_args.py#L797). +- For distributed launchers (`RayLauncher` and `SlurmLauncher`), inference servers run + with a wrapper + [arealite/launcher/sglang_server.py](../../arealite/launcher/sglang_server.py) to + handle addresses and ports in distributed settings. +- After `SGLangServer` instances are started, launchers collect their addresses and + ports to set the `AREAL_LLM_SERVER_ADDRS` environment variable for training scripts to + access these inference servers. 
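To make the last point concrete, the sketch below shows how a process could inspect the
collected addresses. It assumes `AREAL_LLM_SERVER_ADDRS` holds a comma-separated list of
`host:port` entries; treat the exact format as a launcher implementation detail rather
than a documented contract.

```python
import os

# Assumption: the launcher exports server addresses as a comma-separated
# "host:port" list, e.g. "10.0.0.1:30000,10.0.0.2:30000".
addrs = os.environ.get("AREAL_LLM_SERVER_ADDRS", "")
servers = [entry.strip() for entry in addrs.split(",") if entry.strip()]
print(f"{len(servers)} inference server(s) visible to this process: {servers}")
```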
+ +The **configuration file** is a YAML file that sets the options provided in +[arealite/api/cli_args.py](../../arealite/api/cli_args.py). It could be modified via CLI +arguments such as `actor.path=Qwen/Qwen3-1.7B` and `+sglang.attention_backend=triton`. +The training scripts parse the config with CLI arguments into the config class defined +in [arealite/api/cli_args.py](../../arealite/api/cli_args.py). + +``` +config, _ = load_expr_config(args, GRPOConfig) +config: GRPOConfig +``` + +## Loading and Preprocessing Dataset + +We use the `datasets` and `torchdata` packages to load and preprocess the dataset into +our dataloader. First, we download `openai/gsm8k` from Huggingface and split it by data +parallel ranks, then map it to our desired format: + +```python +def process_gsm8k_rl_dataset(dataset: Dataset): + def process(sample): + messages = [{"role": "user", "content": sample["question"]}] + return {"messages": messages} + dataset = dataset.map(process).remove_columns(["question"]) + return dataset + +def get_gsm8k_dataset(split, rank, world_size): + dataset = load_dataset(path="openai/gsm8k", name="main", split=split) + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + return process_gsm8k_rl_dataset(dataset) +``` + +We then prepare training and evaluation dataloaders with `torchdata.StatefulDataLoader`: + +```python +train_dataloader = torchdata.StatefulDataLoader( + get_gsm8k_dataset("train", rank, world_size), + batch_size=config.train_dataset.batch_size // world_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, +) +valid_dataloader = ... +``` + +If you wish to use your own huggingface datasets or datasets on your local storage, +please refers to [Customization: Dataset](../customization/dataset.md) for further +details. + +## Rollout + +The data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses +from prompts to complete rollout data containing all fields required for training. Our +example shows a single-turn RLVR workflow with a math reward function. The core logic of +the workflow is implemented in an async method `arun_episode`, which takes a prompt, +generate answers with `RemoteSGLangEngine`, computes rewards, and populates additional +fields to produce finalized training data. + +```python +class RLVRWorkflow(RolloutWorkflow): + def __init__( + self, reward_fn, gconfig, tokenizer, ... + ): + self.reward_fn = reward_fn + self.gconfig = gconfig + self.tokenizer = tokenizer + + async def arun_episode(self, engine, data): + # rollout data with inference engine + input_ids = self.tokenizer.apply_chat_template(data["message"], ...) + req = LLMRequest(rid=..., input_ids=input_ids, gconfig=self.gconfig.new(n_samples=1)) + resps = await asyncio.gather( + *[engine.agenerate(req) for _ in range(self.gconfig.n_samples)] + ) + # post process rollout responses + results = [] + for resp in resps: + reward = self.reward_fn(...) + ... # other required fields for training + res = dict( + input_ids=..., + rewards=..., + ... # other required fields for training + ) + results.append(res) + # return padded `self.gconfig.n_samples` samples with prompt `data["message"]` + return concat_padded_tensors(results) + +def gsm8k_reward_fn(completions, answer): + ... + +tokenizer = load_hf_tokenizer(config.tokenizer_path) +workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + ... 
+) +``` + +In AReaLite, generation tasks are offloaded to remote inference servers, which operate +on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as a client +that interacts with the servers. `RemoteSGLangEngine` runs in a SPMD manner on every +training process, without occupying any GPUs. + +`RemoteSGLangEngine` is responsible for managing the data streaming through rollout +workflows, and collates completed rollout data into batched training samples. When +initializing, it launches a rollout thread that runs rollout workflows as `asyncio` +tasks. The following code shows the simplified version of rollout thread implementation, +which iteratively: + +- Checks available capacity. The capacity controls current number of rollout workflows + to limit concurrency and data off-policyness. +- If there is capacity left and rollout is not paused for weight update, continuously + obtains data from `input_queue` and creates `asyncio` tasks to run the workflows. +- Waits for rollout workflows to finish. +- Gathers data from finished workflows and puts them into `output_queue` + +```python +class RemoteSGLangEngine(InferenceEngine): + ... + async def _rollout_thread_async(self): + rid = 0 + try: + while not self.exiting.is_set(): + # Check capacity + capacity = self.get_capacity() + # Create rollout tasks with data obtained from input_queue + while ( + capacity > 0 + and not self.paused.is_set() + and self.input_queue.qsize() > 0 + ): + data, workflow = self.input_queue.get_nowait() + task = asyncio.create_task( + workflow.arun_episode(self, data), name=str(rid) + ) + rollout_tasks[str(rid)] = task + self.rollout_stat.submitted += 1 + self.rollout_stat.running += 1 + capacity -= 1 + rid += 1 + # Wait for rollout completion + tasks = list(rollout_tasks.values()) + done = [] + if tasks: + done, _ = await asyncio.wait( + tasks, + timeout=ROLLOUT_POLL_WAIT_TIME, + return_when=asyncio.FIRST_COMPLETED, + ) + # Collect done results, put the results into output queue + for task in done: + traj = await task + task_rid = task.get_name() + rollout_tasks.pop(task_rid) + self.rollout_stat.accepted += 1 + self.output_queue.put_nowait(traj) + self.rollout_stat.running -= 1 + await asyncio.sleep(1) + ... +``` + +With this rollout thread running, the training script (the main thread) submits prompts +into `input_queue` and collates rollout data from `output_queue` into training batches +with `prepare_batch` (for asynchronous RL) or `rollout_batch` (for synchronous RL). 
The +following code shows the implementation of `prepare_batch`: + +```python +def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow: "RolloutWorkflow", +): + if not hasattr(self, "data_generator"): + self.data_generator = iter(dataloader) + assert dataloader.batch_size is not None + while True: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + dataloader.batch_size > 0 + and self.input_queue.qsize() + dataloader.batch_size + < self.input_queue.maxsize + ): + try: + data = next(self.data_generator) + except StopIteration: + self.data_generator = iter(dataloader) + data = next(self.data_generator) + for item in data: + # submit data into input_queue + self.submit(item, workflow=workflow) + try: + # wait for dataloader.batch_size data from output_queue + return self.wait(dataloader.batch_size, timeout=1) + except TimeoutError: + pass +``` + +The usage of `RemoteSGLangEngine` in the training script is simple: + +```python +rollout = RemoteSGLangEngine(config.rollout) +rollout.initialize() +eval_rollout = ... + +data_generator = iter(train_dataloader) +for global_step in range(max_steps): + # rollout batched training data for current step + if config.async_training: + batch = rollout.prepare_batch(train_dataloader, workflow=workflow) + else: + try: + data = next(data_generator) + except StopIteration: + data_generator = iter(train_dataloader) + data = next(data_generator) + batch = rollout.rollout_batch(data, workflow=workflow) +``` + +If you want to use rollout workflows with custom reward functions or agentic tool +calling, see [Customization: Rollout Workflows](../customization/agent.md) for more +details. + +## Training + +After obtaining the training batch, we use `FSDPPPOActor` to calculate losses and update +weights. Each train engine corresponds to one model, therefore we need an additional +engine for the reference model. Note that `torch.distributed` process groups will be +lazily initialized using `init_process_group` when the first train engine is +initialized. The initialization of train engine will also load model weights from paths +specified by the configuration. + +```python +actor = FSDPPPOActor(config=config.actor) +actor.initialize(None, ft_spec) +ref = None +if config.actor.kl_ctl > 0 and config.ref is not None: + ref = FSDPPPOActor(config=config.ref) + ref.initialize(None, ft_spec) +``` + +`FSDPPPOActor` is a high-level engine with algorithm-specific APIs, such as +`compute_logp`,`compute_advantages` and `ppo_update`. `FSDPPPOActor` is powered by the +lower-level train engine `FSDPEngine`, which use **pytorch FSDP2** to provide basic APIs +for the model such as `train_batch` and `forward`. The following code shows a GRPO +training step: + +```python +logp = actor.compute_logp(batch) +batch["prox_logp"] = logp +if ref is not None: + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") +actor.compute_advantages(batch) +stats = actor.ppo_update(batch) +actor.step_lr_scheduler() +``` + +If you want to customize your own training algorithm, see +[Customize algorithms](../customization/algorithm.md) for more details. + +## Transferring Weights to Inference Servers + +After training, we transfer updated model weights to remote inference servers through +cooperation between `FSDPPPOActor` and `RemoteSGLangEngine`. We provide options to +transfer model weights from shared storage or NCCL. In our example training script, we +first prepare `WeightUpdateMeta` for NCCL backend on all training processes. 
+ +```python +# NOTE: Weight update meta only requires address and free port of rank 0, +# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks +# due to `engine.get_param_specs()`. +# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. +weight_update_meta = [ + WeightUpdateMeta.from_fsdp_nccl( + AllocationMode.from_str(config.allocation_mode), actor + ) +] +dist.broadcast_object_list(weight_update_meta, src=0) +weight_update_meta = weight_update_meta[0] +``` + +If you wish to transfer model weights from shared storage, you can use: + +```python +weight_update_meta = WeightUpdateMeta.from_disk(config.saver) +``` + +After a training step is finished, we transfer new weights from actor engine to remote +inference servers with steps shown in the following code: + +```python +# 1. Pause rollout on remote inference servers +rollout.pause() +# 2. Send requests to remote servers, tell them to update weights +if dist.get_rank() == 0: + future = rollout.update_weights(weight_update_meta) +# 3. Actor begins to transfer weights +actor.upload_weights(weight_update_meta) +# 4. Wait for remote servers to return after finishing updates +if dist.get_rank() == 0: + future.result() +# 5. Synchronize rollout processes for model version update +dist.barrier(device_ids=[actor.device.index]) +torch.cuda.synchronize() +# 6. Resume rollout on remote inference servers +rollout.resume() +# 7. Set version, ensures versions on actor and rollout engine are identical +actor.set_version(global_step + 1) +rollout.set_version(global_step + 1) +``` + +Now a complete GRPO training step in AReaLite is done! The core logic of our example +training script can be summarized as: + +```python +data_generator = iter(train_dataloader) +for global_step in range(max_steps): + if config.async_training: + batch = rollout.prepare_batch(train_dataloader, workflow=workflow) + else: + try: + data = next(data_generator) + except StopIteration: + data_generator = iter(train_dataloader) + data = next(data_generator) + batch = rollout.rollout_batch(data, workflow=workflow) + + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + if ref is not None: + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + actor.compute_advantages(batch) + stats = actor.ppo_update(batch) + actor.step_lr_scheduler() + + rollout.pause() + if dist.get_rank() == 0: + future = rollout.update_weights(weight_update_meta) + actor.upload_weights(weight_update_meta) + if dist.get_rank() == 0: + future.result() + rollout.resume() + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) +``` + +## Utilities + +In AReaLite, we provide a wide range of utilities for basic functionalities required for +observing and tuning your experiments. + +### `Saver` and `Evaluator` + +`Saver` ([arealite/utils/saver.py](../../arealite/utils/saver.py)) and `Evaluator` +([arealite/utils/evaluator.py](../../arealite/utils/evaluator.py)) manage the frequency +to save and evaluate the model with the train engine. + +In our example, we call `saver.save` and `evaluator.evaluate` after every training step. +these two methods will automatically check if it is time to save or evaluate the model, +according to the experiment configuration. + +### `stats_tracker` + +`stats_tracker` ([realhf/base/stats_tracker.py](../../realhf/base/stats_tracker.py)) +gathers training statistics across parallel ranks and reduce them. + +1. 
**Scalar-type statistics** are recorded via `stats_tracker.scalar(key=value)` and are
+   averaged over the number of scalars recorded under the same key when reduced.
+1. **Tensor-type statistics** require a `denominator` and a `reduce_type` that decide
+   how statistics under the same key are reduced.
+
+- `denominator` is a boolean tensor that masks out the elements of the tensor that we do
+  not want to record.
+- `reduce_type` can be average, sum, min, or max. By default, the average, min, and max
+  are all calculated.
+
+For example, if we want to record the lengths of sequences with correct and incorrect
+answers in a training batch:
+
+```python
+seqlens = ...  # tensor of shape [#seqs,]
+reward_score = ...  # tensor of shape [#seqs,]
+
+result_denominators = {
+    "correct_n_seqs": (reward_score > 0).bool(),
+    "incorrect_n_seqs": (reward_score <= 0).bool(),
+}
+# Register the denominators
+stats_tracker.denominator(**result_denominators)
+# Record the correct and incorrect sequence lengths
+stats_tracker.stat(
+    correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
+)
+stats_tracker.stat(
+    incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
+)
+```
+
+`stats_tracker` also offers a timer context that records the time cost of a code block
+as a scalar, and a scope context that manages the keys under which statistics are
+recorded.
+
+```python
+with stats_tracker.record_timing("train_step"):
+    # training step
+    ...
+
+with stats_tracker.scope("A"):
+    stats_tracker.scalar(c=123)  # key="A/c", value=123
+    with stats_tracker.scope("B"):
+        stats_tracker.scalar(c=234)  # key="A/B/c", value=234
+```
+
+After recording sufficient data, e.g., after a `train_batch` has finished,
+`stats_tracker.export` is called to aggregate all statistics and dump them into a
+dictionary.
+
+```python
+stats = stats_tracker.export()
+```
+
+### `StatsLogger`
+
+`StatsLogger` ([arealite/utils/stats_logger.py](../../arealite/utils/stats_logger.py))
+logs the gathered training statistics to recorders such as `wandb` and `tensorboard` on
+rank 0. In our example script, after each training step,
+`logger.commit(epoch, step, global_step, stats)` is called to print all statistics from
+`stats_tracker` and log them to the recorders set by the configuration.
+
+## Next Steps
+
+- [Customize dataset](../customization/dataset.md)
+- [Customize Agentic/RLVR rollout workflows](../customization/agent.md)
+- [Customize algorithms](../customization/algorithm.md)
diff --git a/docs/arealite/gsm8k_grpo.png b/docs/arealite/gsm8k_grpo.png
new file mode 100644
index 0000000..2a89b47
Binary files /dev/null and b/docs/arealite/gsm8k_grpo.png differ
diff --git a/docs/legacy/customization/agent.md b/docs/legacy/customization/agent.md
index c926a6d..cec3ad3 100644
--- a/docs/legacy/customization/agent.md
+++ b/docs/legacy/customization/agent.md
@@ -146,7 +146,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
 
 ## Step 4: Run Training
 
 Follow the standard training procedure outlined in the
-[quickstart guide](../../tutorial/quickstart.md). Launch your experiment with:
+[quickstart guide](../../tutorial/quickstart_legacy.md).
Launch your experiment with: ```bash python3 training/main_async_ppo.py my_param=5.0 # plus any additional CLI arguments diff --git a/docs/tutorial/quickstart.md b/docs/tutorial/quickstart.md index 5d22665..15bfd24 100644 --- a/docs/tutorial/quickstart.md +++ b/docs/tutorial/quickstart.md @@ -1,123 +1,121 @@ -# Quickstart (Legacy) +# Quickstart -> **Note**: This is a quickstart guide for launching AReaL experiment with legacy code in `realhf/`. We strongly recommend users to try AReaLite for better experiences. [Click here](quickstart_arealite.md) for AReaLite quickstart guide! +Welcome to the **AReaLite** Quickstart Guide! This guide demonstrates how to run an +AReaLite experiment training an LLM on the GSM8K dataset using the GRPO algorithm with +function-based rewards. Ensure you've completed +[the installation and environment setup](installation.md) before proceeding. -This guide walks you through a simple example of training an LLM to solve math problems. Please ensure you have properly [installed dependencies and set up the runtime environment](installation.md) before proceeding. +## Running the Experiment (on a single node) -## Dataset +To run the experiment, you will need: -Use `huggingface-cli` to download our open-source dataset: +- Training script: + [examples/arealite/gsm8k_grpo.py](../../examples/arealite/gsm8k_grpo.py) +- Config YAML: + [examples/arealite/configs/gsm8k_grpo.yaml](../../examples/arealite/configs/gsm8k_grpo.yaml) -```bash -huggingface-cli download --repo-type=dataset inclusionAI/AReaL-RL-Data +Our training scripts will automatically download the dataset (openai/gsm8k) and model +(Qwen/Qwen2-1.5B-Instruct). To run the example with default configuration, execute from +the repository directory: + +``` +python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py --config examples/arealite/configs/gsm8k_grpo.yaml experiment_name= trial_name= ``` -> **Note**: The command above will display the path of the downloaded dataset. You'll need to pass this path to the training command. +> **Note**: The command above uses `LocalLauncher`, which only works for a single node +> (`cluster.n_nodes == 1`). For distributed experiments, see +> [Distributed Experiments with Ray or Slurm](quickstart.md#distributed-experiments-with-ray-or-slurm). -## Model +## Modifying configuration -We train using open-source models available on Hugging Face Hub. You can either download the model in advance or use the model identifier when running the experiment. +All available configuration options are listed in +[arealite/api/cli_args.py](https://github.com/inclusionAI/AReaL/blob/main/arealite/api/cli_args.py). +To customize the experiment (models, resources, algorithm options), you can: -```bash -# If you want to download it in advance -huggingface-cli download Qwen/Qwen3-1.7B +1. Edit the YAML file directly at + [examples/arealite/configs/gsm8k_grpo.yaml](../../examples/arealite/configs/gsm8k_grpo.yaml). +1. Add command-line options: + - For existing options in the YAML file, directly add the option: + `actor.path=Qwen/Qwen3-1.7B`. + - For other options in `cli_args.py`, but not in the YAML file, add with a prefix + "+": `+sglang.attention_backend=triton`. 
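+
+If you prefer editing the YAML file (option 1 above), the same overrides can be written
+into the config directly. The snippet below is an illustrative sketch rather than the
+actual file contents: the exact keys and nesting are defined by the dataclasses in
+[arealite/api/cli_args.py](https://github.com/inclusionAI/AReaL/blob/main/arealite/api/cli_args.py)
+and by the shipped YAML, so check those before copying.
+
+```yaml
+# Hypothetical excerpt mirroring the CLI overrides shown below; field names may differ.
+experiment_name: my-gsm8k-grpo
+trial_name: trial-0
+allocation_mode: sglang.d2p1t1+d2p1t1
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 4
+gconfig:
+  max_new_tokens: 2048
+train_dataset:
+  batch_size: 1024
+```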
+
+For example, here is the command to launch a customized configuration, based on our
+GSM8K GRPO example:
+
+```
+python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py \
+    --config examples/arealite/configs/gsm8k_grpo.yaml \
+    experiment_name= \
+    trial_name= \
+    allocation_mode=sglang.d2p1t1+d2p1t1 \
+    cluster.n_nodes=1 \
+    cluster.n_gpus_per_node=4 \
+    gconfig.max_new_tokens=2048 \
+    train_dataset.batch_size=1024 \
+    +sglang.attention_backend=triton
+```
+
+::::{important} We're currently refactoring from legacy AReaL to AReaLite, which
+introduces some configuration differences. For convenience, we provide a **config
+converter** that converts an old AReaL config into an AReaLite YAML file.
+[Click here](xxx) to learn how to use the **config converter**. ::::
+
+## Distributed Experiments with Ray or Slurm
+
+AReaLite provides standalone launchers for distributed experiments. After setting up
+your Ray or Slurm cluster, launch experiments similarly to `LocalLauncher`:
+
+```
+# Launch with the Ray launcher. 4 nodes (4 GPUs each), 3 nodes for generation, 1 node for training.
+python3 -m arealite.launcher.ray examples/arealite/gsm8k_grpo.py \
+    --config examples/arealite/configs/gsm8k_grpo.yaml \
+    experiment_name= \
+    trial_name= \
+    allocation_mode=sglang.d12p1t1+d4p1t1 \
+    cluster.n_nodes=4 \
+    cluster.n_gpus_per_node=4 \
+    ...
+
+# Launch with the Slurm launcher. 16 nodes (8 GPUs each), 12 nodes for generation, 4 nodes for training.
+python3 -m arealite.launcher.slurm examples/arealite/gsm8k_grpo.py \
+    --config examples/arealite/configs/gsm8k_grpo.yaml \
+    experiment_name= \
+    trial_name= \
+    allocation_mode=sglang.d96p1t1+d32p1t1 \
+    cluster.n_nodes=16 \
+    cluster.n_gpus_per_node=8 \
+    ...
+```
+
+Additional references:
+
+- For more launcher options, check `LauncherConfig` in
+  [arealite/api/cli_args.py](https://github.com/inclusionAI/AReaL/blob/main/arealite/api/cli_args.py).
+- See the
+  [Ray cluster setup guide](installation.md#optional-launch-ray-cluster-for-distributed-training)
+  for how to set up a Ray cluster.
+
+> **Important Notes**:
+>
+> 1. Ensure `allocation_mode` matches your cluster configuration
+>    (`#GPUs == cluster.n_nodes * cluster.n_gpus_per_node`).
+> 1. Ray/Slurm launchers only work for more than 1 node (`cluster.n_nodes > 1`).
For +> single node scenario, please use `LocalLauncher`. +> 1. In Ray/Slurm launchers, GPUs are allocated at node granularity, which means #GPUs +> for generation or training must be integer multiples of `cluster.n_gpus_per_node`. -```bash -python3 training/main_sync_ppo.py --help -``` - -### Configuration Parameters - -- **`experiment_name`**: The name of your project. -- **`trial_name`**: The name of this trial in your project. -- **`{actor|ref}.path`**: The path to the model files. -- **`dataset.path`**: The path to the dataset JSONL file. -- **`cluster.fileroot`**: The root path for saving training outputs (logs and checkpoints). -- **`n_nodes`**: The number of nodes in the cluster. -- **`n_gpus_per_node`**: The number of GPUs per node. -- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration for the experiment. Format: - - `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures parallel strategies for SGLang generation and training respectively. Generation and training use separate GPU sets, and the total GPU count must equal: DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs. - -### Training Control - -- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset iterations). -- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to persistent storage. Set to null to disable saving. -- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary parameters for restart capability. -- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per training iteration). -- **`group_size`**: Number of responses sampled per prompt. - -### Memory and Performance - -- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per mini-batch for forward/backward passes during reference model inference and actor model training. Reduce this value to avoid OOM errors. -- **`max_concurrent_rollouts`**: The maximum number of concurrent rollouts. SGLang will run out of memory if this value is too large. Defaults to `dataset.train_bs_n_seqs`. - -### Algorithm Configuration - -- **`max_head_offpolicyness`**: The allowed maximum data staleness. 0 recovers synchronous training. A large value will increase generation throughput but degrade final performance. We recommend keeping this value at 8 or below. -- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for training. Defaults to True for asynchronous experiments and False for synchronous baselines. -- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training. Defaults to True. -- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt. -- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO update. -- **`success_rate_ub`**: Upper bound of success rate. Prompts with a higher success rate will be filtered out. -- **`success_rate_lb`**: Lower bound of success rate. Prompts with a lower success rate will be filtered out. - -## Monitoring the Training Process - -+ We recommend using [Weights & Biases (wandb)](https://github.com/wandb/wandb) or [SwanLab](https://github.com/SwanHubX/SwanLab) for monitoring—run `wandb login` or `swanlab login`, or set the corresponding environment variable API key (`WANDB_API_KEY` or `SWANLAB_API_KEY`). Set `wandb.mode="online"` or `swanlab.mode="cloud"` in your configuration to upload training statistics. 
If you cannot connect to the server, you can also use `wandb.mode="offline"` or `swanlab.mode="local"` to save data locally without uploading. - - -You can also use TensorBoard by setting the `tensorboard.path` parameter. - -The main log will be saved to `${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the statistics uploaded to wandb. - -If SwanLab is enabled, logs will be saved to the directory specified by `swanlab.logdir`. - -### Key Training Statistics - -- **`Epoch 1/5`**: Indicates the total epochs required and the current epoch being trained. -- **`step 6/19`**: Shows that the current epoch has 19 steps, with the 6th step just completed. -- **`global step 6`**: Step count across all epochs. -- **`ppo_actor/task_reward/avg`**: Average reward value of all sampled responses in this step. This should steadily increase during training and eventually stabilize. -- **`ppo_actor/importance_weight/avg`**: Average importance sampling ratio across all tokens in the PPO loss. This is typically close to 1.0. -- **`ppo_actor/actor_clip_ratio/avg`**: Ratio of clipped tokens in PPO loss to total tokens. This is usually less than 0.1. -- **`ppo_actor/actor_loss/avg`**: PPO loss value. **This does not show clear trends during training** and should not be used as a performance indicator. + ## Next Steps -[Evaluate your model](eval.md) or check the [troubleshooting section](troubleshooting.md) if you encounter any issues. \ No newline at end of file +Check [Getting Started with AReaLite](../arealite/gsm8k_grpo.md) for a complete code +walkthrough on the GRPO GSM8K Example. + +Customization guides: + +- [Custom dataset](../customization/dataset.md) +- [Custom agentic/RVLR rollout workflows](../customization/agent.md) +- [Custom algorithms](../customization/algorithm.md) diff --git a/docs/tutorial/quickstart_arealite.md b/docs/tutorial/quickstart_arealite.md deleted file mode 100644 index 086179e..0000000 --- a/docs/tutorial/quickstart_arealite.md +++ /dev/null @@ -1,101 +0,0 @@ -# Quickstart - -Welcome to the **AReaLite** Quickstart Guide! -This guide demonstrates how to run an AReaLite experiment training an LLM on the GSM8K dataset using the GRPO algorithm with function-based rewards. -Ensure you've completed [the installation and environment setup](installation.md) before proceeding. - -## Running the Experiment (on a single node) - -To run the experiment, you will need: -- Training script: [examples/arealite/gsm8k_grpo.py](../../examples/arealite/gsm8k_grpo.py) -- Config YAML: [examples/arealite/configs/gsm8k_grpo.yaml](../../examples/arealite/configs/gsm8k_grpo.yaml) - -Our training scripts will automatically download the dataset (openai/gsm8k) and model (Qwen/Qwen2-1.5B-Instruct). -To run the example with default configuration, execute from the repository directory: -``` -python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py --config examples/arealite/configs/gsm8k_grpo.yaml experiment_name= trial_name= -``` - -> **Note**: The command above uses `LocalLauncher`, which only works for a single node (`cluster.n_nodes == 1`). For distributed experiments, see [Distributed Experiments with Ray or Slurm](quickstart_arealite.md#distributed-experiments-with-ray-or-slurm). - -## Modifying configuration - -All available configuration options are listed in [arealite/api/cli_args.py](https://github.com/inclusionAI/AReaL/blob/main/arealite/api/cli_args.py). -To customize the experiment (models, resources, algorithm options), you can: -1. 
Edit the YAML file directly at [examples/arealite/configs/gsm8k_grpo.yaml](../../examples/arealite/configs/gsm8k_grpo.yaml). -2. Add command-line options: - - For existing options in the YAML file, directly add the option: `actor.path=Qwen/Qwen3-1.7B`. - - For other options in `cli_args.py`, but not in the YAML file, add with a prefix "+": `+sglang.attention_backend=triton`. - - - -For example, here is the command to launch a customized configuration, based on our GSM8K GRPO example: -``` -python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py \ - --config examples/arealite/configs/gsm8k_grpo.yaml \ - experiment_name= \ - trial_name= \ - allocation_mode=sglang.d2p1t1+d2p1t1 \ - cluster.n_nodes=1 \ - cluster.n_gpus_per_node=4 \ - gconfig.max_new_tokens=2048 \ - train_dataset.batch_size=1024 \ - +sglang.attention_backend=triton -``` - -::::{important} -We're currently refactoring from legacy AReaL to AReaLite, which introduces some configuration differences. We provide a **config converter** to transfer old AReaL config into AReaLite YAML file for users' convenience. [Click here](xxx) to learn how to use the **config converter**. -:::: - -## Distributed Experiments with Ray or Slurm - -AReaLite provides standalone launchers for distributed experiments. After setting up your Ray or Slurm cluster, launch experiments similarly to `LocalLauncher`: - -``` -# Launch with Ray launcher. 4 nodes (4 GPUs each), 3 nodes for generation, 1 node for training. -python3 -m arealite.launcher.ray examples/arealite/gsm8k_grpo.py \ - --config examples/arealite/configs/gsm8k_grpo.yaml \ - experiment_name= \ - trial_name= \ - allocation_mode=sglang.d12p1t1+d4p1t1 \ - cluster.n_nodes=4 \ - cluster.n_gpus_per_node=4 \ - ... - -# Launch with Slurm launcher. 16 nodes (8 GPUs each), 12 nodes for generation, 4 nodes for training -python3 -m arealite.launcher.slurm examples/arealite/gsm8k_grpo.py \ - --config examples/arealite/configs/gsm8k_grpo.yaml \ - experiment_name= \ - trial_name= \ - allocation_mode=sglang.d96p1t1+d32p1t1 \ - cluster.n_nodes=16 \ - cluster.n_gpus_per_node=8 \ - ... -``` - -Additional references: -- For more options for launchers, check `LauncherConfig` in [arealite/api/cli_args.py](https://github.com/inclusionAI/AReaL/blob/main/arealite/api/cli_args.py). -- [Ray cluster setup guide](installation.md#optional-launch-ray-cluster-for-distributed-training) for a guide on how to set up a ray cluster. - -> **Important Notes**: -> 1. Ensure `allocation_mode` matches your cluster configuration (`#GPUs == cluster.n_nodes * cluster.n_gpus_per_node`) -> 2. Ray/Slurm launchers only works for more than 1 node (`cluster.n_nodes > 1`). For single node scenario, please use `LocalLauncher`. -> 3. In Ray/Slurm launchers, GPUs are allocated at node granularity, which means #GPUs for generation or training must be integer multiples of `cluster.n_gpus_per_node`. - - - -## Next Steps - - - -Customization guides: -- [Custom dataset](../customization/dataset.md) -- [Custom agentic/RVLR rollout workflows](../customization/agent.md) -- [Custom algorithms](../customization/algorithm.md) \ No newline at end of file diff --git a/docs/tutorial/quickstart_legacy.md b/docs/tutorial/quickstart_legacy.md new file mode 100644 index 0000000..5444456 --- /dev/null +++ b/docs/tutorial/quickstart_legacy.md @@ -0,0 +1,169 @@ +# Quickstart (Legacy) + +> **Note**: This is a quickstart guide for launching AReaL experiment with legacy code +> in `realhf/`. We strongly recommend users to try AReaLite for better experiences. 
+> [Click here](quickstart.md) for AReaLite quickstart guide! + +This guide walks you through a simple example of training an LLM to solve math problems. +Please ensure you have properly +[installed dependencies and set up the runtime environment](installation.md) before +proceeding. + +## Dataset + +Use `huggingface-cli` to download our open-source dataset: + +```bash +huggingface-cli download --repo-type=dataset inclusionAI/AReaL-RL-Data +``` + +> **Note**: The command above will display the path of the downloaded dataset. You'll +> need to pass this path to the training command. + +## Model + +We train using open-source models available on Hugging Face Hub. You can either download +the model in advance or use the model identifier when running the experiment. + +```bash +# If you want to download it in advance +huggingface-cli download Qwen/Qwen3-1.7B +``` + +Refer to the +[official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for +more information on using `huggingface-cli`. + +## Training + +From the repository directory, run: + +```bash +# examples/run_async_ppo.sh +python3 training/main_async_ppo.py \ + n_nodes=1 n_gpus_per_node=8 \ + allocation_mode=sglang.d4p1m1+d2p2m1 \ + cluster.fileroot=/path/to/save/logs/checkpoints/ \ + actor.type._class=qwen3 \ + actor.path=Qwen/Qwen3-1.7B \ + ref.type._class=qwen3 \ + ref.path=Qwen/Qwen3-1.7B \ + dataset.path=/path/to/boba_106k_0319.jsonl \ + dataset.train_bs_n_seqs=32 \ + group_size=8 \ + ppo.gen.max_new_tokens=4096 \ + ppo.ppo_n_minibatches=4 \ + actor_train.mb_spec.max_tokens_per_mb=32768 \ + actor_inf.mb_spec.max_tokens_per_mb=32768 \ + max_concurrent_rollouts=16 \ + max_head_offpolicyness=4 +``` + +::::{important} Running `main_async_ppo.py` with `ppo.recompute_logprob=False`, +`ppo.use_decoupled_loss=False`, and `max_head_offpolicyness=0` will essentially +replicate the behavior of synchronous PPO. Therefore, it's usually not recommended to +run synchronous PPO directly (i.e., `main_sync_ppo.py`). The workflow of asynchronous RL +is more stable and easier to customize. :::: + +## Command Line Options + +To view all available options: + +```bash +python3 training/main_sync_ppo.py --help +``` + +### Configuration Parameters + +- **`experiment_name`**: The name of your project. +- **`trial_name`**: The name of this trial in your project. +- **`{actor|ref}.path`**: The path to the model files. +- **`dataset.path`**: The path to the dataset JSONL file. +- **`cluster.fileroot`**: The root path for saving training outputs (logs and + checkpoints). +- **`n_nodes`**: The number of nodes in the cluster. +- **`n_gpus_per_node`**: The number of GPUs per node. +- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration + for the experiment. Format: + - `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures parallel strategies + for SGLang generation and training respectively. Generation and training use + separate GPU sets, and the total GPU count must equal: DP1×TP1×PP1 + DP2×TP2×PP2 = + #GPUs. + +### Training Control + +- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset + iterations). +- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to + persistent storage. Set to null to disable saving. +- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary + parameters for restart capability. +- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per + training iteration). 
+- **`group_size`**: Number of responses sampled per prompt. + +### Memory and Performance + +- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per + mini-batch for forward/backward passes during reference model inference and actor + model training. Reduce this value to avoid OOM errors. +- **`max_concurrent_rollouts`**: The maximum number of concurrent rollouts. SGLang will + run out of memory if this value is too large. Defaults to `dataset.train_bs_n_seqs`. + +### Algorithm Configuration + +- **`max_head_offpolicyness`**: The allowed maximum data staleness. 0 recovers + synchronous training. A large value will increase generation throughput but degrade + final performance. We recommend keeping this value at 8 or below. +- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for + training. Defaults to True for asynchronous experiments and False for synchronous + baselines. +- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training. + Defaults to True. +- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt. +- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO + update. +- **`success_rate_ub`**: Upper bound of success rate. Prompts with a higher success rate + will be filtered out. +- **`success_rate_lb`**: Lower bound of success rate. Prompts with a lower success rate + will be filtered out. + +## Monitoring the Training Process + +- We recommend using [Weights & Biases (wandb)](https://github.com/wandb/wandb) or + [SwanLab](https://github.com/SwanHubX/SwanLab) for monitoring—run `wandb login` or + `swanlab login`, or set the corresponding environment variable API key + (`WANDB_API_KEY` or `SWANLAB_API_KEY`). Set `wandb.mode="online"` or + `swanlab.mode="cloud"` in your configuration to upload training statistics. If you + cannot connect to the server, you can also use `wandb.mode="offline"` or + `swanlab.mode="local"` to save data locally without uploading. + +You can also use TensorBoard by setting the `tensorboard.path` parameter. + +The main log will be saved to +`${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the +statistics uploaded to wandb. + +If SwanLab is enabled, logs will be saved to the directory specified by +`swanlab.logdir`. + +### Key Training Statistics + +- **`Epoch 1/5`**: Indicates the total epochs required and the current epoch being + trained. +- **`step 6/19`**: Shows that the current epoch has 19 steps, with the 6th step just + completed. +- **`global step 6`**: Step count across all epochs. +- **`ppo_actor/task_reward/avg`**: Average reward value of all sampled responses in this + step. This should steadily increase during training and eventually stabilize. +- **`ppo_actor/importance_weight/avg`**: Average importance sampling ratio across all + tokens in the PPO loss. This is typically close to 1.0. +- **`ppo_actor/actor_clip_ratio/avg`**: Ratio of clipped tokens in PPO loss to total + tokens. This is usually less than 0.1. +- **`ppo_actor/actor_loss/avg`**: PPO loss value. **This does not show clear trends + during training** and should not be used as a performance indicator. + +## Next Steps + +[Evaluate your model](eval.md) or check the +[troubleshooting section](troubleshooting.md) if you encounter any issues. 
diff --git a/training/utils.py b/training/utils.py index e5d409b..e2cd4f4 100644 --- a/training/utils.py +++ b/training/utils.py @@ -11,6 +11,7 @@ from typing import Any, List import psutil import ray +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from realhf.api.cli_args import NameResolveConfig from realhf.api.core.system_api import Experiment, ExperimentScheduling, TasksGroup @@ -167,6 +168,19 @@ def _run_experiment(exp_cfg, expr_name, trial_name): scheduling: ExperimentScheduling = exp_cfg.scheduling_setup() + # We assume all nodes have the same resources (CPU, GPU, memory). + all_available_resources = ray.available_resources() + all_available_nodes = [ + k + for k in all_available_resources + if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k) + ] + n_nodes = len(all_available_nodes) + n_gpus_per_node = int(all_available_resources["GPU"] // n_nodes) + assert ( + all_available_resources["GPU"] % n_nodes == 0 + ), "AReaL assumes all nodes has the same number of GPUs." + for worker_type in WORKER_TYPES: sch = getattr(scheduling, worker_type) if sch is None: @@ -187,34 +201,73 @@ def _run_experiment(exp_cfg, expr_name, trial_name): f"Please launch more Ray nodes otherwise the experiment will get stuck." ) - # Use a customized packed scheduling method - # that sequentially allocates nodes. - available_nodes = [ - k - for k in available_resources - if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k) - ] - total_gpus = available_resources["GPU"] - n_gpus_per_node = int(total_gpus // len(available_nodes)) - - count = sch.count - all_schedules: List[TasksGroup] = [] - for _ in range(sch.count): - s_ = copy.deepcopy(sch) - s_.count = 1 - all_schedules.append(s_) - workers = [] + if sch.scheduling.gpu > 0 and n_nodes > 1: + # When # nodes > 1, for GPU workers, schedule them in granularity of nodes. + assert ( + n_gpus_per_node % sch.scheduling.gpu == 0 + ), f"Each node should be allocated with identical numbers of {worker_type}." + n_worker_per_node = int(n_gpus_per_node / sch.scheduling.gpu) + assert sch.count % n_worker_per_node == 0, ( + f"Total {worker_type} count ({sch.count}) should be divisible by " + f"the number of workers per node ({n_worker_per_node})." + ) + n_nodes = int(sch.count / n_worker_per_node) + placement_group = ray.util.placement_group( + bundles=[ + { + "CPU": sch.scheduling.cpu * n_worker_per_node, + "GPU": sch.scheduling.gpu * n_worker_per_node, + "memory": sch.scheduling.mem + * 1024**2 + * n_worker_per_node, # in bytes + } + ] + * n_nodes, + ) + try: + ray.get(placement_group.ready(), timeout=30) + except ray.exceptions.GetTimeoutError: + logger.critical( + f"Failed to create placement group for {worker_type}s. " + f"Please make sure at least {n_nodes} node " + f"has resources for {n_worker_per_node} {worker_type}s." + ) - for node_idx, i in enumerate(range(0, count, n_gpus_per_node)): - _schedules = all_schedules[i : i + n_gpus_per_node] - for _idx, sch in enumerate(_schedules): - # Schedule jobs one-by-one to maintain the order on remote nodes. + for node_id in range(n_nodes): + # Use a customized packed scheduling method + # that sequentially allocates nodes. 
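+                # Workers are packed node by node: worker indices
+                # [node_id * n_worker_per_node, (node_id + 1) * n_worker_per_node)
+                # are pinned to the placement-group bundle of node `node_id`.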
+ for i in range(n_worker_per_node): + _idx = node_id * n_worker_per_node + i + worker = RayWorker.options( + name=f"{worker_type}/{_idx}", + num_cpus=sch.scheduling.cpu, + num_gpus=sch.scheduling.gpu, + memory=sch.scheduling.mem * 1024**2, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=node_id, + placement_group_capture_child_tasks=True, + ), + ).remote( + args=exp_cfg, + worker_type=worker_type, + worker_cls=load_worker(worker_type), + kv_store_name=ray_kv_store_name, + ) + workers.append(worker) + else: + # Schedule them with SPREAD strategy when + # 1. CPU workers when n_nodes > 1, + # to save as much resource as possible on nodes for GPU workers. + # 2. all workers when n_nodes = 1 + for _idx in range(sch.count): worker = RayWorker.options( - name=f"{worker_type}/{_idx + i}", + name=f"{worker_type}/{_idx}", num_cpus=sch.scheduling.cpu, num_gpus=sch.scheduling.gpu, memory=sch.scheduling.mem * 1024**2, + scheduling_strategy="SPREAD", ).remote( args=exp_cfg, worker_type=worker_type, @@ -222,7 +275,6 @@ def _run_experiment(exp_cfg, expr_name, trial_name): kv_store_name=ray_kv_store_name, ) workers.append(worker) - all_workers[worker_type] = workers try: