mirror of https://github.com/inclusionAI/AReaL
workflow executor doc
This commit is contained in: parent aa6c28ed24, commit e7c713125d

@@ -109,12 +109,100 @@ details.

## Rollout

### Inference Engine: `RemoteSGLangEngine`

In AReaLite, generation tasks are offloaded to remote inference servers, which operate
on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as a client
that interacts with the servers. `RemoteSGLangEngine` runs in an SPMD manner on every
training process, without occupying any GPUs.

`RemoteSGLangEngine` provides two APIs, `agenerate` and `update_weights`. It is worth
mentioning that, in asynchronous RL experiments in AReaLite, an inference-side weight
update can happen **in the middle of** the generation of one prompt. That said, one
output sequence could be generated by multiple versions of the model. Let us look into
the code of `agenerate` and `update_weights` for a better understanding.

In `update_weights`, the engine first sends `pause_generation` requests to all inference
servers, notifying them that a weight update is about to happen. Upon receiving
`pause_generation`, the inference servers immediately stop generating and respond with
the tokens generated so far. Then, the engine sends `update_weights_from_distributed`
(for NCCL update) or `update_weights_from_disk` (for disk update). After the update is
finished, the engine sends `continue_generation` to the inference servers, telling them
to start working again.

```python
class RemoteSGLangEngine:
    ...

    def update_weights(self, meta: WeightUpdateMeta):
        # `update_weights` is completely async.
        # It submits the task to a ProcessPoolExecutor and returns a future.
        for addr in self.addresses:
            res = requests.post(f"http://{addr}/pause_generation")
        if meta.type == "nccl":
            future = self.executor.submit(
                # a function that sends the `update_weights_from_distributed` request
                update_weights_from_distributed,
            )
        elif meta.type == "disk":
            ...

        def callback(future):
            for addr in self.addresses:
                requests.post(f"http://{addr}/continue_generation")

        future.add_done_callback(callback)
        return future
```
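
On the trainer side, a minimal usage sketch could look like the following. The
`WeightUpdateMeta` fields and the checkpoint path here are assumptions for
illustration, not the exact AReaLite API:

```python
# A sketch, assuming `inf_engine` is an initialized RemoteSGLangEngine
# and `WeightUpdateMeta` supports disk-based updates (fields assumed).
meta = WeightUpdateMeta(type="disk", path="/path/to/saved/checkpoint")

# Fire the update: generation pauses, the update runs in the background
# executor, and the done-callback resumes generation on all servers.
future = inf_engine.update_weights(meta)

# The trainer may do other work here before synchronizing.
future.result()  # block until every inference server has the new weights
```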

`agenerate` takes an `LLMRequest` with the `input_ids` of **a single prompt** and
generation hyperparameters, and returns the final generation result, an `LLMResponse`
with `output_tokens` and other outputs. Since the generation could be interrupted,
`agenerate` iteratively prepares the payload, sends requests, and receives responses
until the generation finishes.

```python
class RemoteSGLangEngine:
    ...

    async def agenerate(self, req: LLMRequest):
        payload = ...  # prepare the payload for the request
        # If the request is from the same workflow, choose the old server
        # to allow KV cache reuse. Otherwise choose a server in a
        # round-robin manner.
        server_addr = self.choose_server(req)
        stop_reason = None
        # other outputs are omitted for simplicity
        output_tokens = []
        max_new_tokens = payload["sample_params"]["max_new_tokens"]
        while stop_reason != "stop" and len(output_tokens) < max_new_tokens:
            # The request was interrupted; wait to avoid contention
            if stop_reason is not None:
                await asyncio.sleep(0.5)
            # send the request to the remote server
            result = await arequest_with_retry(
                addr=server_addr,
                endpoint="/generate",
                payload=payload,
                method="POST",
            )
            output_tokens.extend(result["output_ids"])
            stop_reason = result["stop_reason"]  # simplified response field
            # prepare the payload for the next request
            payload["input_ids"] += result["output_ids"]
            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
        return LLMResponse(
            input_tokens=req.input_ids,
            output_tokens=output_tokens,
            ...
        )
```
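
As a quick usage sketch (the `LLMRequest` field names are assumptions for
illustration), a coroutine can request one completion like this:

```python
import uuid

# A sketch, assuming `inf_engine` is an initialized RemoteSGLangEngine.
req = LLMRequest(
    rid=str(uuid.uuid4()),        # request id, reused for KV cache affinity
    input_ids=prompt_token_ids,   # token ids of a single prompt
)
resp = await inf_engine.agenerate(req)
# `resp.output_tokens` may have been produced by several model versions
# if a weight update happened mid-generation.
```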

### `RLVRWorkflow` and `WorkflowExecutor`

The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data
progresses from prompts to complete rollout data containing all fields required for
training. Our example shows a single-turn RLVR workflow with a math reward function. The
core logic of the workflow is implemented in an async method `arun_episode`, which takes
a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates
additional fields to produce finalized training data.

```python
class RLVRWorkflow(RolloutWorkflow):
    ...

@@ -158,12 +246,7 @@ workflow = RLVRWorkflow(
)
```
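
The body of `arun_episode` is elided in the hunk above. For orientation, here is a
minimal sketch of such a method following the description in the text; the helper
`math_reward_fn`, the `data` fields, and the returned field layout are assumptions for
illustration, not the actual AReaLite implementation:

```python
class RLVRWorkflow(RolloutWorkflow):
    ...

    async def arun_episode(self, engine, data):
        # Generate an answer for the single prompt via the inference client.
        req = LLMRequest(input_ids=data["input_ids"])  # simplified request
        resp = await engine.agenerate(req)

        # Compute the reward from the decoded answer (assumed math helper).
        answer_str = self.tokenizer.decode(resp.output_tokens)
        reward = math_reward_fn(answer_str, data["answer"])

        # Populate the additional fields required for training.
        return dict(
            input_ids=resp.input_tokens + resp.output_tokens,
            loss_mask=[0] * len(resp.input_tokens) + [1] * len(resp.output_tokens),
            rewards=reward,
        )
```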

`WorkflowExecutor` is responsible for managing the data streaming through rollout
workflows, and collates completed rollout data into batched training samples. When
initializing, it launches a rollout thread that runs rollout workflows as `asyncio`
tasks. The following code shows a simplified version of the rollout thread
implementation,
@@ -177,7 +260,7 @@ which iteratively:
- Gathers data from finished workflows and puts them into `output_queue`

```python
class WorkflowExecutor:
    ...

    async def _rollout_thread_async(self):
        rid = 0

@@ -257,10 +340,11 @@ def prepare_batch(
        pass
```
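
The loop body is elided in the hunk above; as a rough sketch of the iteration described
in the text (the queue names, capacity attribute, and shutdown flag are assumptions for
illustration):

```python
import asyncio
import queue

class WorkflowExecutor:
    ...

    async def _rollout_thread_async(self):
        rid = 0
        tasks = set()
        # Assumed attributes: `exiting` (threading.Event), `input_queue` /
        # `output_queue` (queue.Queue), `inference_engine`, and a
        # `max_concurrent_rollouts` capacity limit.
        while not self.exiting.is_set():
            # Pull new prompts and run their workflows as asyncio tasks.
            while len(tasks) < self.max_concurrent_rollouts:
                try:
                    data, workflow = self.input_queue.get_nowait()
                except queue.Empty:
                    break
                tasks.add(asyncio.create_task(
                    workflow.arun_episode(self.inference_engine, data),
                    name=str(rid),
                ))
                rid += 1
            if not tasks:
                await asyncio.sleep(0.1)
                continue
            # Gather data from finished workflows and put them into output_queue.
            done, tasks = await asyncio.wait(
                tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                self.output_queue.put_nowait(task.result())
```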

The usage of `WorkflowExecutor` in the training script is simple:

```python
inf_engine = RemoteSGLangEngine(config.inf_engine)
rollout = WorkflowExecutor(config.rollout, inf_engine)
rollout.initialize()
eval_rollout = ...
```
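
With the executor initialized, a training step can pull ready batches from it. A
minimal end-to-end sketch, assuming the `prepare_batch` signature, the `actor.ppo_update`
trainer API, and the `meta`/`workflow` objects from the earlier snippets (all names here
are illustrative):

```python
# Hypothetical training loop; `prepare_batch` signature, `actor`, and
# `max_steps` are assumptions built on the snippets above.
for step in range(max_steps):
    # Block until enough finished rollouts form one training batch.
    batch = rollout.prepare_batch(dataloader, workflow=workflow)
    stats = actor.ppo_update(batch)           # trainer-side update
    future = inf_engine.update_weights(meta)  # push new weights to servers
    future.result()                           # wait for servers to resume
```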

Binary file not shown. Before: 100 KiB, After: 50 KiB