minor revise code walkthrough to be consistent with current impl

晓雷 2025-07-31 11:23:24 +08:00
parent 5ad65846c7
commit 2bfba4804e
2 changed files with 34 additions and 19 deletions


@@ -37,12 +37,12 @@ like how you enjoy real-world milk tea (cheers).
## News
**\[2025/07/31\] (v0.4, AReaLite)** We introduce **AReaLite**, a **light-weight**
version of AReaL designed specifically for AI researchers and rapid prototyping.
AReaLite features an **AI-centric** API design that prioritizes ease of use and
algorithm development, while inherently supporting fully asynchronous **agentic RL**.
With 80% fewer lines of code, AReaLite maintains 90% of AReaL's core functionality.
Check out [our AReaLite design doc](/arealite/README.md) and
**\[2025/07/31\] (AReaLite)** We introduce **AReaLite**, a **light-weight** version of
AReaL designed specifically for AI researchers and rapid prototyping. AReaLite features
an **AI-centric** API design that prioritizes ease of use and algorithm development,
while inherently supporting fully asynchronous **agentic RL**. With 80% fewer lines of
code, AReaLite maintains 90% of AReaL's core functionality. Check out
[our AReaLite design doc](/arealite/README.md) and
[the quickstart guide](/docs/tutorial/quickstart.md) to begin your journey with
**AReaLite**!


@@ -116,13 +116,14 @@ on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as
that interacts with the servers. `RemoteSGLangEngine` runs in an SPMD manner on every
training process, without occupying any GPUs.
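As a rough illustration of this client pattern (the class, constructor, and endpoint
names below are assumptions for illustration, not AReaLite's actual implementation),
each training process simply holds the addresses of the remote servers and talks to
them over HTTP:

```python
import requests

# A minimal sketch of the client-side pattern; field and endpoint names are
# assumptions, not AReaLite's exact code.
class RemoteInferenceClientSketch:
    def __init__(self, addresses: list[str]):
        # e.g. ["10.0.0.1:30000", "10.0.0.2:30000"], one entry per SGLang server
        self.addresses = addresses

    def all_servers_healthy(self) -> bool:
        # Plain HTTP calls -- the client itself never touches a GPU.
        return all(
            requests.get(f"http://{addr}/health", timeout=5).ok
            for addr in self.addresses
        )
```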
`RemoteSGLangEngine` provides two APIs, `agenerate` and `async_update_weights`. It is
worth mentioning that, in asynchronous RL experiments in AReaLite, an inference-side
weight update can happen **in the middle of** generating a response to a prompt. As a
result, one output sequence may be produced by multiple versions of the model. Let us
take a glimpse at the code of `agenerate` and `async_update_weights` for a better
understanding.
`RemoteSGLangEngine` provides two core APIs that access the remote servers, `agenerate`
and `update_weights_async`. It is worth mentioning that, in asynchronous RL experiments
in AReaLite, an inference-side weight update can happen **in the middle of** generating
a response to a prompt. As a result, one output sequence may be produced by multiple
versions of the model. Let us take a glimpse at the code of `agenerate` and
`update_weights_async` for a better understanding.
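To make this point concrete, here is a purely illustrative timeline (the version
numbers are made up) of how one response can end up containing tokens sampled from more
than one policy version:

```python
# Illustrative timeline for a single prompt under asynchronous weight updates:
#   t0: an inference server starts decoding the prompt with policy version 3
#   t1: the trainer pushes new weights; the server pauses and keeps the tokens
#       generated so far
#   t2: the server reloads weights (now policy version 4) and resumes decoding
#       from the partial sequence
#   t3: the finished sequence therefore mixes tokens from versions 3 and 4
```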
In `async_update_weights`, the engine first sends `pause_generation` requests to all
In `update_weights_async`, the engine first sends `pause_generation` requests to all
inference servers, notifying them that a weight update is about to happen. Upon
receiving `pause_generation`, the inference servers immediately stop generating and
respond with the tokens they have generated so far. Then, the engine sends
`update_weights_from_distributed` (for
@@ -133,8 +134,8 @@ start working again.
```python
class RemoteSGLangEngine:
    ...
    def async_update_weights(self, meta: WeightUpdateMeta):
        # `async_update_weights` is completely async.
    def update_weights_async(self, meta: WeightUpdateMeta):
        # `update_weights_async` is completely async.
        # It submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
@@ -195,6 +196,11 @@ class RemoteSGLangEngine:
```
The `InferenceEngine` class is designed to be extensible, supporting not just SGLang but
also other backends like vLLM. While different inference engines may be used, the
rollout management logic remains consistent. This common functionality is abstracted
into the `WorkflowExecutor`, which will be introduced in the following section.
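As a sketch of what this backend-agnostic interface might look like (the method names
come from this walkthrough, but the actual `InferenceEngine` base class in AReaLite may
differ):

```python
from abc import ABC, abstractmethod

class InferenceEngineSketch(ABC):
    """Backend-agnostic interface sketch; a vLLM-based engine would implement the
    same methods while reusing the shared WorkflowExecutor for rollout management."""

    @abstractmethod
    async def agenerate(self, req):
        # Asynchronously generate a completion for a single request.
        ...

    @abstractmethod
    def update_weights_async(self, meta):
        # Kick off an inference-side weight update; returns a future.
        ...
```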
### `RLVRWorkflow` and `WorkflowExecutor`
The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data
@@ -341,11 +347,20 @@ def prepare_batch(
pass
```
The usage of `WorkflowExecutor` in the training script is simple:
The `RemoteSGLangEngine` exposes `rollout_batch` and `prepare_batch` by forwarding the
calls to its workflow executor:
```python
inf_engine = RemoteSGLangEngine(config.inf_engine)
rollout = WorkflowExecutor(config.rollout, inf_engine)
class RemoteSGLangEngine(InferenceEngine):
    ...
    def prepare_batch(self, *args, **kwargs):
        return self.workflow_executor.prepare_batch(*args, **kwargs)
```
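Presumably `rollout_batch` is forwarded in the same way; here is a sketch under that
assumption (not the verbatim source):

```python
class RemoteSGLangEngine(InferenceEngine):
    ...
    def rollout_batch(self, *args, **kwargs):
        # Same delegation pattern as `prepare_batch` above.
        return self.workflow_executor.rollout_batch(*args, **kwargs)
```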
The usage of `RemoteSGLangEngine` in the training script is simple:
```python
rollout = RemoteSGLangEngine(config.inf_engine)
rollout.initialize()
eval_rollout = ...
@@ -444,7 +459,7 @@ inference servers:
```python
rollout.pause()
if dist.get_rank() == 0:
    future = rollout.async_update_weights(weight_update_meta)
    future = rollout.update_weights_async(weight_update_meta)
actor.upload_weights(weight_update_meta)
if dist.get_rank() == 0:
    future.result()
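# Added commentary (not part of the original snippet): `update_weights_async` is
# non-blocking and returns a future (backed by a ProcessPoolExecutor, as shown
# earlier), so rank 0 drives the HTTP side of the update while every rank calls
# `actor.upload_weights` to push its weights; rank 0 then waits on the future so
# that the servers have finished updating before rollout resumes.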
@@ -482,7 +497,7 @@ for global_step in range(max_steps):
rollout.pause()
if dist.get_rank() == 0:
    future = rollout.async_update_weights(weight_update_meta)
    future = rollout.update_weights_async(weight_update_meta)
actor.upload_weights(weight_update_meta)
if dist.get_rank() == 0:
    future.result()