workflow executor doc

晓雷 2025-07-25 15:47:24 +08:00
parent aa6c28ed24
commit e7c713125d
2 changed files with 99 additions and 15 deletions


## Rollout
### Inference Engine: `RemoteSGLangEngine`

In AReaLite, generation tasks are offloaded to remote inference servers, which operate
on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as a client
that interacts with these servers. It runs in an SPMD manner on every training process,
without occupying any GPUs.
`RemoteSGLangEngine` provides two APIs, `agenerate` and `update_weights`. It is worth
mentioning that, in asynchronous RL experiments in AReaLite, an inference-side weight
update can happen **in the middle of** generating a response to a prompt. As a result, a
single output sequence can be produced by multiple versions of the model. The code of
`agenerate` and `update_weights` below shows how this works.

In `update_weights`, the engine first sends a `pause_generation` request to every
inference server, notifying it that a weight update is about to happen. Upon receiving
`pause_generation`, a server immediately stops generating and responds with the tokens
generated so far. The engine then sends `update_weights_from_distributed` (for NCCL
updates) or `update_weights_from_disk` (for disk updates). After the update finishes,
the engine sends `continue_generation` to tell the servers to resume.
```python
class RemoteSGLangEngine:
    ...
    def update_weights(self, meta: WeightUpdateMeta):
        # `update_weights` is completely async:
        # it submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends the `update_weights_from_distributed` request
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...

        def callback(future):
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")

        future.add_done_callback(callback)
        return future
```
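The resume step relies on `Future.add_done_callback`: the `continue_generation` requests
go out only after the update task completes, without the caller ever blocking. Below is
a minimal, self-contained sketch of that pattern; it uses a `ThreadPoolExecutor` and toy
stand-in functions (`do_update`, `on_done` are hypothetical names) in place of the real
weight transfer and HTTP calls.

```python
from concurrent.futures import ThreadPoolExecutor
import threading

events = []
done_evt = threading.Event()

def do_update():
    # stands in for the (slow) NCCL/disk weight transfer
    events.append("update")
    return "ok"

def on_done(future):
    # stands in for sending `continue_generation` to every server
    events.append("continue_generation")
    done_evt.set()

executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(do_update)
future.add_done_callback(on_done)

result = future.result()   # the caller may wait on the future, or not
done_evt.wait(timeout=2)   # the callback fires once the future completes
print(result, events)      # ok ['update', 'continue_generation']
```

Because the callback is attached to the future rather than awaited inline, the training
process can keep doing other work while the update and the resume notification proceed
in the background.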
`agenerate` takes an `LLMRequest` containing the `input_ids` of **a single prompt** and
generation hyperparameters, and returns the final generation result as an `LLMResponse`
with `output_tokens` and other outputs. Since generation can be interrupted,
`agenerate` iteratively prepares payloads, sends requests, and receives responses until
the generation finishes.
```python
class RemoteSGLangEngine:
    ...
    async def agenerate(self, req: LLMRequest):
        payload = ...  # prepare payload for the request
        # If the request comes from the same workflow, choose the old server
        # to allow KV cache reuse. Otherwise choose a server in a round-robin
        # manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        max_new_tokens = payload["sample_params"]["max_new_tokens"]
        # other outputs are omitted for simplicity
        output_tokens = []
        while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
            # The previous request was interrupted; wait to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send the request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
            )
            stop_reason = result["stop_reason"]
            output_tokens.extend(result["output_ids"])
            # prepare the payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
```
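To see why this loop always produces a complete sequence even when generation is
interrupted mid-prompt, here is a runnable toy version. `fake_generate` and
`agenerate_sketch` are hypothetical stand-ins, not AReaLite APIs: the fake server
returns at most three tokens per call and only reports `"stop"` when the remaining
budget fits in one call, mimicking interruption by a weight update.

```python
import asyncio

# Toy stand-in for an inference server whose generation can be interrupted:
# each call returns at most `chunk` new tokens plus a stop reason.
async def fake_generate(payload, chunk=3):
    remaining = payload["sample_params"]["max_new_tokens"]
    n = min(chunk, remaining)
    start = len(payload["input_ids"])
    reason = "stop" if n == remaining else "interrupt"
    return {"output_ids": list(range(start, start + n)), "stop_reason": reason}

async def agenerate_sketch(input_ids, max_new_tokens):
    payload = {"input_ids": list(input_ids),
               "sample_params": {"max_new_tokens": max_new_tokens}}
    output_tokens, stop_reason = [], None
    while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
        result = await fake_generate(payload)
        stop_reason = result["stop_reason"]
        output_tokens.extend(result["output_ids"])
        # roll the finished tokens into the next request's prompt,
        # and shrink the remaining token budget accordingly
        payload["input_ids"] += result["output_ids"]
        payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
    return output_tokens

tokens = asyncio.run(agenerate_sketch([10, 11], max_new_tokens=7))
print(tokens)  # [2, 3, 4, 5, 6, 7, 8], accumulated across three calls
```

Each round restarts generation from the concatenated prompt-plus-prefix, which is also
why choosing the same server again (for KV cache reuse) pays off.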
### `RLVRWorkflow` and `WorkflowExecutor`

The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data
progresses from prompts to complete rollout data containing all fields required for
training. Our example shows a single-turn RLVR workflow with a math reward function. The
core logic of the workflow is implemented in an async method `arun_episode`, which takes
a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates
additional fields to produce finalized training data.
```python
class RLVRWorkflow(RolloutWorkflow):
    ...

workflow = RLVRWorkflow(
    ...
)
```
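The prompt-to-training-data progression described above can be sketched end to end with
toy stand-ins. Everything here (`StubEngine`, `RLVRWorkflowSketch`, `math_reward`, the
dict-based request/response) is hypothetical and only illustrates the shape of
`arun_episode`: generate, score, then populate the fields the trainer needs.

```python
import asyncio

class StubEngine:
    async def agenerate(self, req):
        # pretend the model always answers with tokens [7, 8, 9]
        return {"input_tokens": req["input_ids"], "output_tokens": [7, 8, 9]}

def math_reward(answer_tokens, target):
    # toy verifier: reward 1.0 if the target token appears in the answer
    return 1.0 if target in answer_tokens else 0.0

class RLVRWorkflowSketch:
    def __init__(self, engine, reward_fn):
        self.engine = engine
        self.reward_fn = reward_fn

    async def arun_episode(self, prompt):
        # 1. generate an answer for the single prompt
        resp = await self.engine.agenerate({"input_ids": prompt["input_ids"]})
        # 2. score it with the reward function
        reward = self.reward_fn(resp["output_tokens"], prompt["target"])
        # 3. populate the fields required for training
        return {
            "input_ids": resp["input_tokens"] + resp["output_tokens"],
            "rewards": reward,
        }

wf = RLVRWorkflowSketch(StubEngine(), math_reward)
data = asyncio.run(wf.arun_episode({"input_ids": [1, 2], "target": 8}))
print(data)  # {'input_ids': [1, 2, 7, 8, 9], 'rewards': 1.0}
```

Because `arun_episode` is a plain coroutine, many episodes can run concurrently as
`asyncio` tasks, which is exactly what the executor below exploits.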
`WorkflowExecutor` is responsible for managing the data streaming through rollout
workflows, and collates completed rollout data into batched training samples. When
initializing, it launches a rollout thread that runs rollout workflows as `asyncio`
tasks. The following code shows a simplified version of the rollout thread
implementation, which iteratively:
- Gathers data from finished workflows and puts them into `output_queue`
```python
class WorkflowExecutor:
    ...
    async def _rollout_thread_async(self):
        rid = 0
        ...

def prepare_batch(
    ...
):
    pass
```
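The launch-and-gather loop can be illustrated with a small self-contained sketch. The
names (`run_workflow`, `rollout_loop`, `capacity`) are hypothetical, and the real
executor runs this loop inside a dedicated thread; here `asyncio.run` drives it directly
for brevity.

```python
import asyncio
import queue

output_queue = queue.Queue()

async def run_workflow(rid):
    # stands in for `workflow.arun_episode(...)`
    await asyncio.sleep(0.01)
    return {"rid": rid, "rewards": 1.0}

async def rollout_loop(num_episodes=4, capacity=2):
    rid, tasks = 0, set()
    while rid < num_episodes or tasks:
        # check capacity and create new workflow tasks
        while rid < num_episodes and len(tasks) < capacity:
            tasks.add(asyncio.create_task(run_workflow(rid)))
            rid += 1
        # gather data from finished workflows into the output queue
        done, tasks = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED
        )
        for t in done:
            output_queue.put(t.result())

asyncio.run(rollout_loop())
print(output_queue.qsize())  # 4
```

The capacity check is what bounds the number of in-flight rollouts, so the executor can
throttle generation relative to training speed.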
The usage of `WorkflowExecutor` in the training script is simple:

```python
inf_engine = RemoteSGLangEngine(config.inf_engine)
rollout = WorkflowExecutor(config.rollout, inf_engine)
rollout.initialize()
eval_rollout = ...
```
