mirror of https://github.com/inclusionAI/AReaL
minor revise code walkthrough to be consistent with current impl
parent 5ad65846c7
commit 2bfba4804e

README.md | 12
@@ -37,12 +37,12 @@ like how you enjoy real-world milk tea (cheers).
 
 ## News
 
-**\[2025/07/31\] (v0.4, AReaLite)** We introduce **AReaLite**, a **light-weight**
-version of AReaL designed specifically for AI researchers and rapid prototyping.
-AReaLite features an **AI-centric** API design that prioritizes ease of use and
-algorithm development, while inherently supporting fully asynchronous **agentic RL**.
-With 80% fewer lines of code, AReaLite maintains 90% of AReaL's core functionality.
-Check out [our AReaLite design doc](/arealite/README.md) and
+**\[2025/07/31\] (AReaLite)** We introduce **AReaLite**, a **light-weight** version of
+AReaL designed specifically for AI researchers and rapid prototyping. AReaLite features
+an **AI-centric** API design that prioritizes ease of use and algorithm development,
+while inherently supporting fully asynchronous **agentic RL**. With 80% fewer lines of
+code, AReaLite maintains 90% of AReaL's core functionality. Check out
+[our AReaLite design doc](/arealite/README.md) and
 [the quickstart guide](/docs/tutorial/quickstart.md) to begin your journey with
 **AReaLite**!

@@ -116,13 +116,14 @@ on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as
 that interacts with the servers. `RemoteSGLangEngine` runs in a SPMD manner on every
 training process, without occupying any GPUs.
 
-`RemoteSGLangEngine` provides two APIs, `agenerate` and `async_update_weights`. It is
-worth mentioning that, in asynchronous RL experiments in AReaLite, an inference-side
-weight update can happen **in the middle of** generating a response to a prompt. In other
-words, one output sequence can be generated by multiple versions of the model. Let us
-look into the code of `agenerate` and `async_update_weights` for a better understanding.
+`RemoteSGLangEngine` provides two core APIs that access the remote servers, `agenerate`
+and `update_weights_async`. It is worth mentioning that, in asynchronous RL experiments
+in AReaLite, an inference-side weight update can happen **in the middle of** generating
+a response to a prompt. In other words, one output sequence can be generated by
+multiple versions of the model. Let us look into the code of `agenerate` and
+`update_weights_async` for a better understanding.
 
-In `async_update_weights`, the engine first sends `pause_generation` requests to all
+In `update_weights_async`, the engine first sends `pause_generation` requests to all
 inference servers, notifying them that a weight update is about to happen. Upon receiving
 `pause_generation`, inference servers immediately stop generating and respond with the
 tokens generated so far. Then, the engine sends `update_weights_from_distributed` (for

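Because a weight update can land in the middle of a sequence, `agenerate` cannot assume a
single request/response round trip: a paused server returns whatever tokens it has produced,
and the client must resubmit the remainder until the sequence actually finishes. The sketch
below illustrates only that loop; the endpoint, payload fields, and the `"interrupted"`
finish reason are illustrative assumptions, not AReaLite's or SGLang's exact API.

```python
import requests


def generate_across_weight_updates(addr: str, prompt_ids: list[int], max_new_tokens: int) -> list[int]:
    """Keep generating until the sequence finishes, resuming after server-side pauses."""
    output_ids: list[int] = []
    while len(output_ids) < max_new_tokens:
        payload = {
            # Resume from everything generated so far (hypothetical request format).
            "input_ids": prompt_ids + output_ids,
            "max_new_tokens": max_new_tokens - len(output_ids),
        }
        resp = requests.post(f"http://{addr}/generate", json=payload).json()
        output_ids += resp["output_ids"]
        if resp.get("finish_reason") != "interrupted":
            break  # finished normally (stop token or length limit)
        # Otherwise the server was paused for a weight update; loop and continue.
        # Tokens before and after this point may come from different model versions.
    return output_ids
```
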
@@ -133,8 +134,8 @@ start working again.
 ```python
 class RemoteSGLangEngine:
     ...
-    def async_update_weights(self, meta: WeightUpdateMeta):
-        # `async_update_weights` is completely async.
+    def update_weights_async(self, meta: WeightUpdateMeta):
+        # `update_weights_async` is completely async.
         # It submits a task to a ProcessPoolExecutor and returns a future.
         for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")

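The hunk above is cut off by the diff context. According to the surrounding walkthrough, the
rest of `update_weights_async` hands the actual update off to a `ProcessPoolExecutor` and
returns a future, after which the servers start working again. A rough sketch of that overall
shape is below; the worker function, the `/continue_generation` call, and treating `meta` as a
JSON-serializable payload are simplifying assumptions rather than the real implementation.

```python
import requests
from concurrent.futures import Future, ProcessPoolExecutor


def _push_new_weights(addresses: list[str], meta: dict) -> None:
    # Runs in a worker process: ask every server to load the new weights,
    # then let generation continue.
    for addr in addresses:
        requests.post(f"http://{addr}/update_weights_from_distributed", json=meta)
    for addr in addresses:
        requests.post(f"http://{addr}/continue_generation")


class WeightUpdateSketch:
    def __init__(self, addresses: list[str]):
        self.addresses = addresses
        self.executor = ProcessPoolExecutor(max_workers=1)

    def update_weights_async(self, meta: dict) -> Future:
        # Pause first so servers stop decoding and flush the tokens produced so far.
        for addr in self.addresses:
            requests.post(f"http://{addr}/pause_generation")
        # Hand the blocking HTTP work to another process and return immediately.
        return self.executor.submit(_push_new_weights, self.addresses, meta)
```
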
@@ -195,6 +196,11 @@ class RemoteSGLangEngine:
 
 ```
 
+The `InferenceEngine` class is designed to be extensible, supporting not just SGLang but
+also other backends like vLLM. While different inference engines may be used, the
+rollout management logic remains consistent. This common functionality is abstracted
+into the `WorkflowExecutor`, which will be introduced in the following section.
+
 ### `RLVRWorkflow` and `WorkflowExecutor`
 
 The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data

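To make the extensibility point concrete, the split might be organized roughly as follows:
backend-specific engines implement generation and weight updates, while the batch-level entry
points forward to a shared `WorkflowExecutor`. This is a sketch of the idea, not AReaLite's
actual class definition; everything beyond the method names quoted in the walkthrough is an
assumption.

```python
from abc import ABC, abstractmethod


class InferenceEngineSketch(ABC):
    # Subclasses (e.g. an SGLang- or vLLM-backed engine) are expected to create
    # `self.workflow_executor` when they are initialized.

    @abstractmethod
    async def agenerate(self, req):
        """Backend-specific: generate a completion for a single request."""

    @abstractmethod
    def update_weights_async(self, meta):
        """Backend-specific: push new weights to the inference servers."""

    # Rollout management is backend-independent and shared via the executor.
    def rollout_batch(self, *args, **kwargs):
        return self.workflow_executor.rollout_batch(*args, **kwargs)

    def prepare_batch(self, *args, **kwargs):
        return self.workflow_executor.prepare_batch(*args, **kwargs)
```
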
@@ -341,11 +347,20 @@ def prepare_batch(
     pass
 ```
 
-The usage of `WorkflowExecutor` in the training script is simple:
+The `RemoteSGLangEngine` exposes `rollout_batch` and `prepare_batch` by delegating them
+to its workflow executor:
 
 ```python
-inf_engine = RemoteSGLangEngine(config.inf_engine)
-rollout = WorkflowExecutor(config.rollout, inf_engine)
+class RemoteSGLangEngine(InferenceEngine):
+    ...
+    def prepare_batch(self, *args, **kwargs):
+        return self.workflow_executor.prepare_batch(*args, **kwargs)
+```
+
+The usage of `RemoteSGLangEngine` in the training script is simple:
+
+```python
+rollout = RemoteSGLangEngine(config.inf_engine)
 rollout.initialize()
 eval_rollout = ...

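The training-script snippet above is truncated by the hunk; after construction and
`initialize()`, the two entry points are what the training loop calls to obtain data. A hedged
sketch of that continuation is below; `train_dataloader`, `data_generator`,
`config.async_training`, the `workflow=` keyword, and the `RLVRWorkflow` constructor arguments
are placeholders rather than the exact script.

```python
# Hypothetical continuation of the training script above (free names are placeholders).
workflow = RLVRWorkflow(...)  # built with the reward function, generation config, tokenizer, ...

for global_step in range(max_steps):
    if config.async_training:
        # Asynchronous RL: keep rollouts in flight and take the next ready batch.
        batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
    else:
        # Synchronous RL: generate rollouts for exactly one dataloader batch.
        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
```
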
@@ -444,7 +459,7 @@ inference servers:
 ```python
 rollout.pause()
 if dist.get_rank() == 0:
-    future = rollout.async_update_weights(weight_update_meta)
+    future = rollout.update_weights_async(weight_update_meta)
 actor.upload_weights(weight_update_meta)
 if dist.get_rank() == 0:
     future.result()

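Only the renamed call changes in this hunk, but the pattern around it is worth spelling out:
a single rank drives the HTTP-side update (the call returns a future immediately), every
training rank participates in uploading weights, and rank 0 waits on the future before
generation resumes. The sketch below fills in the usual synchronization around the lines
shown; the exact placement of the barrier and of `rollout.resume()` is an assumption to be
checked against the real script.

```python
import torch
import torch.distributed as dist

# Assumed from the surrounding script: `rollout`, `actor`, `weight_update_meta`.
rollout.pause()                               # stop issuing new rollout requests during the update
if dist.get_rank() == 0:
    future = rollout.update_weights_async(weight_update_meta)  # non-blocking, returns a future
actor.upload_weights(weight_update_meta)      # every training rank takes part in the upload
if dist.get_rank() == 0:
    future.result()                           # wait until the servers hold the new weights
dist.barrier()                                # keep ranks in step before resuming
torch.cuda.synchronize()
rollout.resume()                              # generation continues with the updated weights
```
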
@@ -482,7 +497,7 @@ for global_step in range(max_steps):
 
     rollout.pause()
     if dist.get_rank() == 0:
-        future = rollout.async_update_weights(weight_update_meta)
+        future = rollout.update_weights_async(weight_update_meta)
     actor.upload_weights(weight_update_meta)
     if dist.get_rank() == 0:
         future.result()