mirror of https://github.com/inclusionAI/AReaL
change doc (#180)
This commit is contained in: parent e13db01f67, commit 0283cfa124
@ -1,8 +1,6 @@

# AReaL v1.0.0 Design Doc

Update 20250710

______________________________________________________________________

SFT example:

@ -10,12 +8,24 @@ SFT example:

```bash
torchrun --nnodes 1 --nproc-per-node 8 examples/arealite/gsm8k_sft.py --config examples/arealite/configs/gsm8k_sft.yaml
```

______________________________________________________________________

GRPO example:

```bash
python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py --config examples/arealite/configs/gsm8k_grpo.yaml
```

We will provide both single-controller and SPMD user interfaces. The SPMD interface will be delivered with AReaLite; it is the paradigm most users are familiar with, just like using `torchrun` or `deepspeed`. However, this paradigm may lack some flexibility in global scheduling and control. To unlock the full potential of customized distributed execution, we will also provide a single-controller mode, just like using Ray --- but our scheduler backend will not be restricted to Ray. Our code will be able to run with any scheduler in the cluster, such as native SLURM and K8s.

However, we want the user code to stay the same in both modes. The following is a simple usage example:

```python
def get_current_time():

@ -30,7 +40,7 @@ class MyRolloutWorkflow:
        self.env = LocalToolingEnv()
        self.env.register_tool(get_current_time)

    async def arun_episode(self, engine: InferenceEngine,
                           data: Dict[str, Any]) -> Dict[str, Tensor]:
        ...
        message = [

@ -94,7 +104,7 @@ def main_grpo():
    # synchronous rollout
    rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
    # or asynchronous rollout with filtering and off-policyness control
    # rollout_batch = rollout.prepare_batch(batch,
    #     workflow=MyRolloutWorkflow(rollout_config.workflow),
    #     should_accept=lambda x: x['rewards'].mean() > 0)

@ -137,12 +147,14 @@ CUDA_VISIBLE_DEVICES=3,4 \

## Core API

- A specific algorithm must use these core components.
- Concrete implementations must follow the API definition.

### TrainEngine

TrainEngine is a thin wrapper around existing training frameworks (FSDP, Megatron). It provides a unified interface for RL algorithm computation and for saving and loading parameters, as well as a unified weight-update interface for inference engines.

```python
#############################################

@ -165,7 +177,7 @@ class TrainEngine(abc.ABC):
    # control api
    def __init__(self, para_config):
        self.para_config = para_config

    def initialize(self, addr: str | None, ft_spec=None):
        # Initialize the distributed environment, then initialize and load the model.
        # addr is the corresponding service address when SGLang or FSDP/Megatron is deployed remotely.

@ -196,7 +208,7 @@ class TrainEngine(abc.ABC):
        # Due to PPO minibatch updates, train_batch may need to be called multiple times
        # before calling step_lr_scheduler once, so this API needs to be separate.
        raise NotImplementedError()

    def train_batch(
        self,
        input_: Dict,

@ -240,7 +252,7 @@ class FSDPEngine(TrainEngine):
    def destroy(self):
        del self.optimizer
        del self.model
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

@ -265,8 +277,8 @@ class FSDPEngine(TrainEngine):
    def load(self, meta):
        ...

    ############# Helper methods start #############

    def load_from_hf(self, path):

@ -287,7 +299,7 @@ class FSDPEngine(TrainEngine):
        if base_model_path is not None:
            copy_hf_configs(base_model_path, path)
        dist.barrier()

    def save_for_recover(self, path: str):
        self.save_model_to_dcp(path)
        self.save_optimizer_state(path)

@ -295,7 +307,7 @@ class FSDPEngine(TrainEngine):
    def load_from_recover(self, path):
        self.load_model_from_dcp(path)
        self.load_optimizer_state(path)

    async def init_distributed_weight_update(self, meta: WeightUpdateMeta):
        # Initialize NCCL communication group for weight updates
        ...

@ -327,11 +339,14 @@ class FSDPEngine(TrainEngine):

### Algorithm-Specific TrainEngine API

Extended engines (such as the Actor in PPO) conveniently organize and expose algorithm-specific computational interfaces. These interfaces keep single-process computational logic, but can be called by controllers in top-level training scripts to orchestrate distributed computation.

```python
class Actor(Engine):

    @torch.no_grad()
    def compute_logps(self, input_: Dict[str, Tensor]) -> torch.Tensor:
        ...  # unpad

@ -351,7 +366,7 @@ class Actor(Engine):
        return all_stats


class Critic(Engine):

    @torch.no_grad()
    def compute_values(self, input_: Dict) -> torch.Tensor:
        pass

@ -379,15 +394,18 @@ class MegatronCritic(MegatronEngine, Critic):

### Inference Engine API

The InferenceEngine API is defined in a local-like mode rather than a client-server separated form, mainly for the user's convenience in using the InferenceEngine as a tool.

InferenceEngine can internally start an SGLang subprocess (SGLangEngine), or call a remotely deployed service (RemoteSGLangEngine).

```python
class InferenceEngine(ABC):
    def __init__(self, config):
        self.config = config

    @abstractmethod
    def initialize(self, addr: str | None, ft_spec):
        """Start the SGLang engine; starting the engine loads the model by default.

@ -439,7 +457,7 @@ class InferenceEngine(ABC):
class SGLangEngine(InferenceEngine):

    def __init__(self, config: InfEngineConfig):
        self.config = config
        self.weight_update_group_initialized = False

    async def update_weights(self, meta) -> None:

@ -454,7 +472,7 @@ class SGLangEngine(InferenceEngine):
    async def agenerate(self, req: LLMRequest) -> LLMResponse:
        # Given a prompt, generate a response with the LLM
        return await self.llm.generate_async(xxx)

    # Weight update
    @abstractmethod
    def update_weights_from_disk(self, path) -> None:

@ -483,7 +501,7 @@ class SGLangEngine(InferenceEngine):
    @abstractmethod
    def check_health(self) -> bool:
        """Check server health status.

        Returns:
            bool: Whether the server is healthy
        """

@ -492,7 +510,8 @@ class SGLangEngine(InferenceEngine):

### RolloutWorkflow

RolloutWorkflow is a class that provides the arun_episode method. This method has a fixed signature and is used to collect one agent trajectory.

```python
class MyRolloutWorkflow:

@ -501,7 +520,7 @@ class MyRolloutWorkflow:
        self.tool_executor = ToolExecutor()
        self.tool_executor.register_tool(get_current_time)

    async def arun_episode(self, engine: InferenceEngine,
                           data: Dict[str, Any]) -> Dict[str, Tensor]:
        ...
        req = LLMRequest(input_ids=data['input_ids'], ...)

@ -520,14 +539,20 @@ They have the same API as `InferenceEngine` and `TrainEngine`, respectively.

## Other Components

1. Algorithm workflows don't necessarily use these components; other replaceable components, such as implementations in rllm or verl-agent, can also be used.
2. These are mainly for internal implementation and division of labor, serving as thin wrappers that facilitate adapting external packages.

### Environment

1. Supports multi-tool management and a unified execution interface.
2. Supports local calls and MCP service calls.
3. "Tools" belong to an instance rather than a class, so register_tool is defined as a method rather than a static method. This is (1) to prevent tools from subclasses being registered to the base class, which could cause naming or calling conflicts, and (2) to support multiple tasks for the same service (e.g., a browser), with each task having a different toolset.

```python
from abc import ABC, abstractmethod

@ -536,14 +561,14 @@ from typing import Any, Dict, List
class ToolingEnv:
    def __init__(self):
        self.is_initialized = False

        self.tool_registry: Dict[str, Callable] = {}
        self.tool_schemas: List[Dict[str, Any]] = []

    async def ainitialize(self):
        """
        Performs the initialization logic for the environment.
        For stateful environments, this is where resources are created and
        prepared (e.g., launching a browser).
        """
        pass

@ -575,7 +600,7 @@ class MCPToolingEnv(ToolingEnv):
    def aexecute(self, tool_name: str, tool_args: Dict[str, Any]):
        pass


class LocalToolingEnv(ToolingEnv):

    @staticmethod

@ -591,7 +616,7 @@ class LocalToolingEnv(ToolingEnv):
            "properties": {},
            "required": [],
        }

        # Mapping from Python types to JSON Schema types
        type_mapping = {
            str: "string",

@ -602,13 +627,13 @@ class LocalToolingEnv(ToolingEnv):
        for name, param in sig.parameters.items():
            # Default to string type if type hint is missing or complex
            param_type = type_mapping.get(param.annotation, "string")
            parameters["properties"][name] = {"type": param_type}

            # If a parameter has no default value, it is considered required.
            if param.default is inspect.Parameter.empty:
                parameters["required"].append(name)

        return {
            "type": "function",
            "function": {

@ -624,11 +649,11 @@ class LocalToolingEnv(ToolingEnv):
        """
        if not callable(func):
            raise TypeError("The provided object must be a callable function.")

        tool_name = func.__name__
        if tool_name in self.tool_registry:
            raise ValueError(f"Tool with name '{tool_name}' is already registered.")

        # Add the function to the registry and its schema to the schema list.
        self.tool_registry[tool_name] = func
        self.tool_schemas.append(self.generate_schema(func))

@ -636,7 +661,7 @@ class LocalToolingEnv(ToolingEnv):
    def list_tools(self) -> List[Dict[str, Any]]:
        """
        Lists all available tools provided by this environment and their descriptions.

        Returns:
            A list of dictionaries, where each dictionary describes a tool.
        """

@ -651,14 +676,14 @@ class LocalToolingEnv(ToolingEnv):
            tool_args (Dict[str, Any]): The arguments required to call the tool.

        Returns:
            Any: The result of the tool's execution, typically a string or
            structured JSON.
        """
        if tool_name not in self.tool_registry:
            return f"Error: Tool '{tool_name}' is not registered."

        tool_func = self.tool_registry[tool_name]

        try:
            result = tool_func(**tool_args)
            return result

@ -667,14 +692,19 @@ class LocalToolingEnv(ToolingEnv):
            return f"Error executing '{tool_name}': Invalid arguments. Details: {e}"
        except Exception as e:
            return f"Error executing '{tool_name}': An unexpected error occurred. Details: {e}"
```

### Reward

1. Workflows for computing rewards with models and computing rewards from rules should be separated.
2. Rule-based reward computation is defined as a function with a predefined signature, which can be local or remote, and is generally wrapped in the rollout workflow; user-written reward functions need not use this signature as long as they provide a workflow (a minimal sketch follows this list).
3. Computing rewards with models requires users to initialize a controller/engine themselves and to call it explicitly in the algorithm workflow.
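
The exact predefined signature is not shown in this excerpt; the following is only an illustrative sketch of a rule-based reward function, assuming it receives the prompt, the model completion, and the ground-truth answer:

```python
import re

# Illustrative only: the actual predefined signature in AReaL may differ.
def gsm8k_reward(prompt: str, completion: str, answer: str) -> float:
    # Rule-based check: reward 1.0 if the last number in the completion
    # matches the ground-truth answer, 0.0 otherwise.
    numbers = re.findall(r"-?\d+\.?\d*", completion)
    return 1.0 if numbers and numbers[-1] == answer.strip() else 0.0
```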

```python

@ -704,7 +734,8 @@ for _ in range(epochs):

### Dataset

Use HuggingFace `datasets` and PyTorch `torchdata`. In single-controller mode, only one process per experiment is responsible for data loading.

```python
from datasets import Dataset, load_dataset

@ -1,8 +1,35 @@

# Quickstart

This guide walks you through a simple example of training an LLM to solve math problems. Please ensure you have properly [installed dependencies and set up the runtime environment](installation.md) before proceeding.

## Option 1: Using *AReaLite*

AReaLite is an RL training framework that provides the same functionality as AReaL but is much easier to use, customize, and understand. It does not depend on AReaL except for some common core utilities such as logging.

We provide usage examples in the `examples/arealite` folder. To launch an experiment that trains your LLM to solve GSM8k math problems, run the following command:

```bash
python3 -m arealite.launcher.local examples/arealite/gsm8k_grpo.py --config examples/arealite/configs/gsm8k_grpo.yaml
```

You can modify any options in `examples/arealite/configs/gsm8k_grpo.yaml`, such as the base model and hyperparameters. Note that this example does not support changing the dataset through configuration modifications alone. To use another dataset, modify the dataset processing logic in the training script `examples/arealite/gsm8k_grpo.py` with the HuggingFace `datasets` package.
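
For example, a minimal sketch of swapping in a different dataset might look like the following; the dataset name and column names here are placeholders, so match whatever fields `gsm8k_grpo.py` actually consumes:

```python
from datasets import load_dataset

# Hypothetical sketch: load a different Hugging Face dataset and map it to the
# fields the training script expects (field names below are placeholders).
dataset = load_dataset("your-org/your-math-dataset", split="train")

def to_training_format(example):
    return {"prompt": example["problem"], "answer": example["solution"]}

dataset = dataset.map(to_training_format, remove_columns=dataset.column_names)
```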

> **Note**: This command assumes you can connect to the HuggingFace Hub to download models and datasets. Use [hf-mirror](https://hf-mirror.com/) if necessary.

## Option 2: Using the old version of AReaL

### Dataset

Use `huggingface-cli` to download our open-source dataset:

@ -10,20 +37,24 @@ Use `huggingface-cli` to download our open-source dataset:

```bash
huggingface-cli download --repo-type=dataset inclusionAI/AReaL-RL-Data
```

> **Note**: The command above will display the path of the downloaded dataset. You'll need to pass this path to the training command.

### Model

We train using open-source models available on the Hugging Face Hub. You can either download the model in advance or use the model identifier when running the experiment.

```bash
# If you want to download it in advance
huggingface-cli download Qwen/Qwen3-1.7B
```

Refer to the [official documentation](https://huggingface.co/docs/huggingface_hub/guides/cli) for more information on using `huggingface-cli`.

### Training

From the repository directory, run:

@ -48,11 +79,13 @@ python3 training/main_async_ppo.py \

    max_head_offpolicyness=4
```

::::{important}
Running `main_async_ppo.py` with `ppo.recompute_logprob=False`, `ppo.use_decoupled_loss=False`, and `max_head_offpolicyness=0` will essentially replicate the behavior of synchronous PPO. Therefore, it's usually not recommended to run synchronous PPO directly (i.e., `main_sync_ppo.py`). The workflow of asynchronous RL is more stable and easier to customize.
::::

### Command Line Options

To view all available options:

@ -60,62 +93,97 @@ To view all available options:

```bash
python3 training/main_sync_ppo.py --help
```

#### Configuration Parameters

- **`experiment_name`**: The name of your project.
- **`trial_name`**: The name of this trial in your project.
- **`{actor|ref}.path`**: The path to the model files.
- **`dataset.path`**: The path to the dataset JSONL file.
- **`cluster.fileroot`**: The root path for saving training outputs (logs and checkpoints).
- **`n_nodes`**: The number of nodes in the cluster.
- **`n_gpus_per_node`**: The number of GPUs per node.
- **`allocation_mode`**: The GPU allocation strategy and 3D parallelism configuration for the experiment (see the example after this list). Format:
  - `sglang.d${DP1}m${TP1}p${PP1}+d${DP2}m${TP2}p${PP2}`: Configures parallel strategies for SGLang generation and training respectively. Generation and training use separate GPU sets, and the total GPU count must satisfy DP1×TP1×PP1 + DP2×TP2×PP2 = #GPUs.
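
For example, on a single node with 8 GPUs, one illustrative split (an assumption for demonstration, not a recommended default) gives 4 GPUs to SGLang generation with pure data parallelism and 4 GPUs to training with 2-way data parallelism and 2-way tensor parallelism:

```bash
# 4*1*1 (SGLang generation) + 2*2*1 (training) = 8 GPUs in total
allocation_mode=sglang.d4m1p1+d2m2p1
```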

#### Training Control

- **`exp_ctrl.total_train_epochs`**: Number of training epochs (complete dataset iterations).
- **`exp_ctrl.save_freq_{epochs|steps|secs}`**: Frequency for saving model parameters to persistent storage. Set to null to disable saving.
- **`exp_ctrl.ckpt_freq_{epochs|steps|secs}`**: Frequency for saving temporary parameters for restart capability.
- **`dataset.train_bs_n_seqs`**: Training batch size (number of prompts sampled per training iteration).
- **`group_size`**: Number of responses sampled per prompt.

#### Memory and Performance

- **`{actor_train|ref_inf|actor_inf}.mb_spec.max_tokens_per_mb`**: Maximum tokens per mini-batch for forward/backward passes during reference model inference and actor model training. Reduce this value to avoid OOM errors.
- **`max_concurrent_rollouts`**: The maximum number of concurrent rollouts. SGLang will run out of memory if this value is too large. Defaults to `dataset.train_bs_n_seqs`.

#### Algorithm Configuration

- **`max_head_offpolicyness`**: The maximum allowed data staleness. 0 recovers synchronous training. A large value will increase generation throughput but degrade final performance. We recommend keeping this value at 8 or below.
- **`ppo.recompute_logprob`**: Whether to compute proximal log probabilities for training. Defaults to True for asynchronous experiments and False for synchronous baselines.
- **`ppo.use_decoupled_loss`**: Use decoupled loss to stabilize asynchronous training. Defaults to True.
- **`ppo.gen.max_new_tokens`**: Maximum tokens to generate per prompt.
- **`ppo.ppo_n_minibatches`**: Number of mini-batches for dividing data during each PPO update.
- **`success_rate_ub`**: Upper bound of success rate. Prompts with a higher success rate will be filtered out.
- **`success_rate_lb`**: Lower bound of success rate. Prompts with a lower success rate will be filtered out.

### Monitoring the Training Process

- We recommend using [Weights & Biases (wandb)](https://github.com/wandb/wandb) or [SwanLab](https://github.com/SwanHubX/SwanLab) for monitoring: run `wandb login` or `swanlab login`, or set the corresponding API key environment variable (`WANDB_API_KEY` or `SWANLAB_API_KEY`). Set `wandb.mode="online"` or `swanlab.mode="cloud"` in your configuration to upload training statistics (see the example below). If you cannot connect to the server, you can also use `wandb.mode="offline"` or `swanlab.mode="local"` to save data locally without uploading.
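
A minimal sketch of enabling online wandb logging; passing `wandb.mode` as a command-line override like the other options above is an assumption here, so adapt it to how you set the rest of your configuration:

```bash
# Assumes `wandb login` has been run or WANDB_API_KEY is exported.
python3 training/main_async_ppo.py \
    experiment_name=my-experiment \
    trial_name=my-trial \
    wandb.mode=online
```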

You can also use TensorBoard by setting the `tensorboard.path` parameter.

The main log will be saved to `${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the statistics uploaded to wandb.

If SwanLab is enabled, logs will be saved to the directory specified by `swanlab.logdir`.

#### Key Training Statistics

- **`Epoch 1/5`**: Indicates the total number of epochs required and the current epoch being trained.
- **`step 6/19`**: Shows that the current epoch has 19 steps, with the 6th step just completed.
- **`global step 6`**: Step count across all epochs.
- **`ppo_actor/task_reward/avg`**: Average reward value of all sampled responses in this step. This should steadily increase during training and eventually stabilize.
- **`ppo_actor/importance_weight/avg`**: Average importance sampling ratio across all tokens in the PPO loss. This is typically close to 1.0.
- **`ppo_actor/actor_clip_ratio/avg`**: Ratio of clipped tokens in the PPO loss to total tokens. This is usually less than 0.1.
- **`ppo_actor/actor_loss/avg`**: PPO loss value. **This does not show clear trends during training** and should not be used as a performance indicator.

## Next Steps

[Evaluate your model](eval.md) or check the [troubleshooting section](troubleshooting.md) if you encounter any issues.