mirror of https://github.com/inclusionAI/AReaL

workflow executor doc (parent aa6c28ed24, commit e7c713125d)

## Rollout

The data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses from prompts to complete rollout data containing all fields required for training. Our example shows a single-turn RLVR workflow with a math reward function. The core logic of the workflow is implemented in an async method, `arun_episode`, which takes a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates additional fields to produce finalized training data.

### Inference Engine: `RemoteSGLangEngine`

In AReaLite, generation tasks are offloaded to remote inference servers, which run on GPUs separate from those used for training. `RemoteSGLangEngine` acts as a client that interacts with these servers. It runs in an SPMD manner on every training process and does not occupy any GPUs itself.

`RemoteSGLangEngine` provides two APIs, `agenerate` and `update_weights`. It is worth mentioning that, in asynchronous RL experiments in AReaLite, an inference-side weight update can happen **in the middle of** generating a response to a prompt. In other words, one output sequence can be generated by multiple versions of the model. Let us look into the code of `agenerate` and `update_weights` for a better understanding.

In `update_weights`, the engine first sends `pause_generation` requests to all inference servers, notifying them that a weight update is about to happen. Upon receiving `pause_generation`, the inference servers immediately stop generating and respond with the tokens generated so far. The engine then sends `update_weights_from_distributed` (for NCCL updates) or `update_weights_from_disk` (for disk updates). After the update finishes, the engine sends `continue_generation` to the inference servers, telling them to resume work.

```python
class RemoteSGLangEngine:
    ...

    def update_weights(self, meta: WeightUpdateMeta):
        # `update_weights` is completely async.
        # It submits a task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends an `update_weights_from_distributed` request
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...

        def callback(future):
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")

        future.add_done_callback(callback)
        return future
```
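
The pause, update, resume sequencing around the returned future can be sketched with a plain `concurrent.futures` executor and stand-in server calls. This is a minimal toy, not AReaLite code: a `ThreadPoolExecutor` stands in for the `ProcessPoolExecutor`, and the `pause`/`resume`/`do_update` functions stand in for the HTTP requests and weight transfer.

```python
from concurrent.futures import ThreadPoolExecutor

log = []
addresses = ["server0", "server1"]

def pause(addr):
    log.append(f"pause:{addr}")

def resume(addr):
    log.append(f"continue:{addr}")

def do_update():
    # stands in for the NCCL/disk weight transfer
    log.append("update")
    return "ok"

executor = ThreadPoolExecutor(max_workers=1)

# 1) pause all servers, 2) submit the update,
# 3) resume all servers only once the future completes (done callback)
for a in addresses:
    pause(a)
future = executor.submit(do_update)
future.add_done_callback(lambda f: [resume(a) for a in addresses])

assert future.result() == "ok"
executor.shutdown(wait=True)  # the done callback has run by now
print(log)
```

Because the resume step hangs off the future's done callback, the caller gets control back immediately after submission, which is what lets the training process overlap other work with the weight transfer.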

`agenerate` takes an `LLMRequest` with the `input_ids` of **a single prompt** and generation hyperparameters, and returns the final generation result, an `LLMResponse` with `output_tokens` and other outputs. Since generation can be interrupted, `agenerate` iteratively prepares the payload, sends a request, and receives the response until generation finishes.

```python
class RemoteSGLangEngine:
    ...

    async def agenerate(self, req: LLMRequest):
        payload = ...  # prepare payload for the request
        # If the request is from the same workflow, choose the previous
        # server to allow KV cache reuse. Otherwise choose a server in a
        # round-robin manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        # other outputs are omitted for simplicity
        output_tokens = []
        while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
            # The request was interrupted; wait to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send the request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
            )
            output_tokens.extend(result["output_ids"])
            # record why the server stopped (finished vs. interrupted)
            stop_reason = result["stop_reason"]
            # prepare the payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
```
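
The `choose_server` policy described in the comment above (same workflow goes to the same server for KV cache reuse, new workflows are assigned round-robin) can be sketched as a small router. This is a toy reimplementation with hypothetical names, not AReaLite's actual class:

```python
import itertools

class MiniRouter:
    """Toy stand-in for the engine's server-selection bookkeeping."""

    def __init__(self, addresses):
        self.addresses = addresses
        self._rr = itertools.cycle(range(len(addresses)))
        self._sticky = {}  # workflow id -> server address

    def choose_server(self, workflow_id):
        # Same workflow -> same server, so the prefix KV cache on that
        # server can be reused; unseen workflows are assigned round-robin.
        if workflow_id not in self._sticky:
            self._sticky[workflow_id] = self.addresses[next(self._rr)]
        return self._sticky[workflow_id]

router = MiniRouter(["10.0.0.1:30000", "10.0.0.2:30000"])
print(router.choose_server("wf-0"))  # → 10.0.0.1:30000
print(router.choose_server("wf-1"))  # → 10.0.0.2:30000
print(router.choose_server("wf-0"))  # → 10.0.0.1:30000 (sticky)
```

Sticky routing matters here because an interrupted request is re-sent with the already-generated tokens appended to `input_ids`; sending it back to the same server lets that server serve the shared prefix from its KV cache.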

### `RLVRWorkflow` and `WorkflowExecutor`

The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses from prompts to complete rollout data containing all fields required for training. Our example shows a single-turn RLVR workflow with a math reward function. The core logic of the workflow is implemented in an async method, `arun_episode`, which takes a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates additional fields to produce finalized training data.

```python
class RLVRWorkflow(RolloutWorkflow):
    ...

workflow = RLVRWorkflow(
    ...
)
```

`WorkflowExecutor` manages the data streaming through rollout workflows and collates completed rollout data into batched training samples. On initialization, it launches a rollout thread that runs rollout workflows as `asyncio` tasks. The following code shows a simplified version of the rollout thread implementation, which iteratively:

- Gathers data from finished workflows and puts them into `output_queue`

```python
class WorkflowExecutor:
    ...

    async def _rollout_thread_async(self):
        rid = 0
        ...

    def prepare_batch(
        ...
    ):
        pass
```
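
The queue-and-tasks pattern of the rollout thread can be condensed into a runnable toy. All names here are hypothetical, and the real executor additionally caps concurrency and checks data staleness, which this sketch omits; it only shows prompts being drained into concurrent `asyncio` tasks whose results land in an output queue:

```python
import asyncio
import queue

# thread-safe queues, mirroring the executor's input/output queues
input_queue: queue.Queue = queue.Queue()
output_queue: queue.Queue = queue.Queue()

async def fake_workflow(prompt: str) -> dict:
    # stands in for RLVRWorkflow.arun_episode (generation + reward)
    await asyncio.sleep(0)
    return {"prompt": prompt, "reward": 1.0}

async def rollout_thread():
    # drain pending prompts into concurrent asyncio tasks, then move
    # each finished rollout into the output queue as it completes
    tasks = []
    while not input_queue.empty():
        tasks.append(asyncio.create_task(fake_workflow(input_queue.get())))
    for fut in asyncio.as_completed(tasks):
        output_queue.put(await fut)

for p in ["p0", "p1", "p2"]:
    input_queue.put(p)
asyncio.run(rollout_thread())
print(output_queue.qsize())  # → 3
```

Running the workflows as `asyncio` tasks inside one thread is what lets many rollouts overlap their network waits on the inference servers without any extra processes on the training side.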

The usage of `WorkflowExecutor` in the training script is simple:

```python
inf_engine = RemoteSGLangEngine(config.inf_engine)
rollout = WorkflowExecutor(config.rollout, inf_engine)
rollout.initialize()
eval_rollout = ...
```
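
How `prepare_batch` turns per-rollout results into a training batch can be illustrated with a toy collation over the output queue. The function body, field names, and batch layout below are assumptions for illustration only; the real signature is elided in this excerpt, and the real method also handles padding and data staleness:

```python
import queue

output_queue: queue.Queue = queue.Queue()
for i in range(8):
    # finished rollouts as produced by the workflows (fields assumed)
    output_queue.put({"input_ids": [i], "reward": float(i % 2)})

def toy_prepare_batch(q: queue.Queue, batch_size: int) -> dict:
    # take batch_size finished rollouts off the queue and collate
    # them into one dict-of-lists training batch
    items = [q.get() for _ in range(batch_size)]
    return {key: [item[key] for item in items] for key in items[0]}

batch = toy_prepare_batch(output_queue, batch_size=4)
print(len(batch["input_ids"]))  # → 4
print(batch["reward"])  # → [0.0, 1.0, 0.0, 1.0]
```

The key property is that batching is decoupled from generation: the training loop blocks only on the queue, while workflows keep streaming new rollouts in behind it.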