workflow executor doc

This commit is contained in:
晓雷 2025-07-25 15:47:24 +08:00
parent aa6c28ed24
commit e7c713125d
2 changed files with 99 additions and 15 deletions


@ -109,12 +109,100 @@ details.
## Rollout
The data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses
from prompts to complete rollout data containing all fields required for training. Our
example shows a single-turn RLVR workflow with a math reward function. The core logic of
the workflow is implemented in an async method `arun_episode`, which takes a prompt,
generates answers with `RemoteSGLangEngine`, computes rewards, and populates additional
fields to produce finalized training data.
### Inference Engine: `RemoteSGLangEngine`
In AReaLite, generation tasks are offloaded to remote inference servers, which operate
on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as a client
that interacts with the servers. `RemoteSGLangEngine` runs in a SPMD manner on every
training process, without occupying any GPUs.
`RemoteSGLangEngine` provides two APIs, `agenerate` and `update_weights`. It is worth
noting that, in asynchronous RL experiments in AReaLite, an inference-side weight update
can happen **in the middle of** generating a response to a prompt. As a result, a single
output sequence may be produced by multiple versions of the model. Let us look into the
code of `agenerate` and `update_weights` for a better understanding.
In `update_weights`, the engine first sends `pause_generation` requests to all inference
servers, notifying them that a weight update is about to happen. Upon receiving
`pause_generation`, inference servers immediately stop generating and respond with the
tokens generated so far. Then, the engine sends `update_weights_from_distributed` (for
NCCL update) or `update_weights_from_disk` (for disk update). After the update is
finished, the engine sends `continue_generation` to the inference servers, telling them
to resume generation.
```python
class RemoteSGLangEngine:
    ...
    def update_weights(self, meta: WeightUpdateMeta):
        # `update_weights` is completely async.
        # It submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends `update_weights_from_distributed` requests
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...
        def callback(future):
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")
        future.add_done_callback(callback)
        return future
```
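The pause → update → continue sequencing hinges on the future's done-callback. The following self-contained sketch (using a plain `ThreadPoolExecutor` and a fake update function, both hypothetical stand-ins for the real weight transfer) illustrates that the callback fires only after the submitted update completes:

```python
from concurrent.futures import ThreadPoolExecutor

events = []

def fake_update():
    # Stand-in for the NCCL/disk weight transfer
    events.append("update")
    return "ok"

executor = ThreadPoolExecutor(max_workers=1)
events.append("pause")  # analogous to the `pause_generation` requests
future = executor.submit(fake_update)
# analogous to sending `continue_generation` once the update finishes
future.add_done_callback(lambda f: events.append("continue"))
executor.shutdown(wait=True)  # worker runs done-callbacks before exiting
print(events)  # ['pause', 'update', 'continue']
```

Because done-callbacks are invoked only after the future's result is set, generation is guaranteed not to resume while weights are mid-transfer.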
`agenerate` takes an `LLMRequest` with the `input_ids` of **a single prompt** and
generation hyperparameters, and returns the final generation result, an `LLMResponse`
with `output_tokens` and other outputs. Since generation can be interrupted,
`agenerate` iteratively prepares the payload, sends requests, and receives responses
until generation finishes.
```python
class RemoteSGLangEngine:
    ...
    async def agenerate(self, req: LLMRequest):
        payload = ...  # prepare payload for request
        # If the request is from the same workflow, choose the old server
        # to allow KV cache reuse. Otherwise choose a server in a
        # round-robin manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        # other outputs are omitted for simplicity
        output_tokens = []
        while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
            # Request was interrupted; wait to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
            )
            output_tokens.extend(result["output_ids"])
            # prepare payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
```
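The body of `arequest_with_retry` is not shown above. A minimal sketch of such a retry loop, simplified to take an injected coroutine instead of a real HTTP call (the function name, parameters, and backoff schedule here are illustrative assumptions, not AReaLite's actual implementation), might look like:

```python
import asyncio

async def arequest_with_retry_sketch(do_request, max_retries=3, base_delay=0.01):
    # Retry the request with exponential backoff; re-raise on the last attempt.
    for attempt in range(max_retries):
        try:
            return await do_request()
        except ConnectionError:
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(base_delay * 2 ** attempt)

# Fails twice, then succeeds on the third attempt.
calls = {"n": 0}

async def flaky_generate():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("server busy")
    return {"output_ids": [7, 8, 9]}

result = asyncio.run(arequest_with_retry_sketch(flaky_generate))
print(result)  # {'output_ids': [7, 8, 9]}
```

Retrying matters here because inference servers briefly reject requests around `pause_generation`/`continue_generation` windows.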
### `RLVRWorkflow` and `WorkflowExecutor`
The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data
progresses from prompts to complete rollout data containing all fields required for
training. Our example shows a single-turn RLVR workflow with a math reward function. The
core logic of the workflow is implemented in an async method `arun_episode`, which takes
a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates
additional fields to produce finalized training data.
```python
class RLVRWorkflow(RolloutWorkflow):
@ -158,12 +246,7 @@ workflow = RLVRWorkflow(
)
```
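Since the body of `arun_episode` is elided above, here is a hedged, self-contained sketch of the single-turn pattern it follows. The stub engine, reward function, and output field names are illustrative assumptions, not AReaLite's actual API:

```python
import asyncio
from dataclasses import dataclass

@dataclass
class LLMRequest:
    input_ids: list

@dataclass
class LLMResponse:
    input_tokens: list
    output_tokens: list

class StubEngine:
    # Stands in for RemoteSGLangEngine; always "generates" the same tokens.
    async def agenerate(self, req: LLMRequest) -> LLMResponse:
        return LLMResponse(input_tokens=req.input_ids, output_tokens=[7, 8, 9])

def match_reward(output_tokens, target_tokens):
    # Hypothetical verifiable reward: 1.0 iff the answer matches exactly.
    return 1.0 if output_tokens == target_tokens else 0.0

class SingleTurnWorkflowSketch:
    def __init__(self, engine, reward_fn):
        self.engine = engine
        self.reward_fn = reward_fn

    async def arun_episode(self, prompt_ids, target_tokens):
        # generate -> score -> populate the fields the trainer needs
        resp = await self.engine.agenerate(LLMRequest(input_ids=prompt_ids))
        reward = self.reward_fn(resp.output_tokens, target_tokens)
        return {
            "input_ids": resp.input_tokens + resp.output_tokens,
            "loss_mask": [0] * len(resp.input_tokens) + [1] * len(resp.output_tokens),
            "reward": reward,
        }

workflow = SingleTurnWorkflowSketch(StubEngine(), match_reward)
sample = asyncio.run(workflow.arun_episode([1, 2], target_tokens=[7, 8, 9]))
print(sample["reward"])  # 1.0
```

Multi-turn workflows follow the same shape, calling `agenerate` several times inside one `arun_episode` before emitting the finalized sample.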
`WorkflowExecutor` is responsible for managing the data streaming through rollout
workflows, and collates completed rollout data into batched training samples. Upon
initialization, it launches a rollout thread that runs rollout workflows as `asyncio`
tasks. The following code shows a simplified version of the rollout thread implementation,
@ -177,7 +260,7 @@ which iteratively:
- Gathers data from finished workflows and puts them into `output_queue`
```python
class WorkflowExecutor:
...
async def _rollout_thread_async(self):
rid = 0
@ -257,10 +340,11 @@ def prepare_batch(
pass
```
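The iterative loop described above can be sketched in a self-contained way. The function and variable names (`capacity`, `output_queue`) are illustrative, and a plain list stands in for the real queue:

```python
import asyncio

async def rollout_loop_sketch(prompts, run_episode, capacity=2):
    # Launch episodes as asyncio tasks up to `capacity`; gather finished
    # results into an output list (playing the role of `output_queue`).
    output_queue = []
    pending = set()
    prompts = list(prompts)
    while prompts or pending:
        # Submit new workflow tasks while capacity allows
        while prompts and len(pending) < capacity:
            pending.add(asyncio.create_task(run_episode(prompts.pop(0))))
        # Wait for at least one workflow to finish, then collect its data
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        output_queue.extend(task.result() for task in done)
    return output_queue

async def fake_episode(prompt):
    await asyncio.sleep(0.001)
    return {"prompt": prompt, "reward": 1.0}

batch = asyncio.run(rollout_loop_sketch(range(4), fake_episode))
print(len(batch))  # 4
```

Running workflows as `asyncio` tasks rather than threads keeps many concurrent generation requests in flight on a single rollout thread.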
The usage of `WorkflowExecutor` in the training script is simple:
```python
inf_engine = RemoteSGLangEngine(config.inf_engine)
rollout = WorkflowExecutor(config.rollout, inf_engine)
rollout.initialize()
eval_rollout = ...

(The second changed file is a binary image, not shown: 100 KiB before, 50 KiB after.)