diff --git a/docs/arealite/gsm8k_grpo.md b/docs/arealite/gsm8k_grpo.md
index 26125d3..a30a254 100644
--- a/docs/arealite/gsm8k_grpo.md
+++ b/docs/arealite/gsm8k_grpo.md
@@ -109,12 +109,100 @@ details.
 
 ## Rollout
 
-The data lifecycle is controlled by an `RLVRWorkflow`, which defines how data progresses
-from prompts to complete rollout data containing all fields required for training. Our
-example shows a single-turn RLVR workflow with a math reward function. The core logic of
-the workflow is implemented in an async method `arun_episode`, which takes a prompt,
-generate answers with `RemoteSGLangEngine`, computes rewards, and populates additional
-fields to produce finalized training data.
+### Inference Engine: `RemoteSGLangEngine`
+
+In AReaLite, generation tasks are offloaded to remote inference servers, which operate
+on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as a client
+that interacts with these servers. `RemoteSGLangEngine` runs in an SPMD manner on every
+training process, without occupying any GPUs.
+
+`RemoteSGLangEngine` provides two APIs, `agenerate` and `update_weights`. It is worth
+noting that, in asynchronous RL experiments in AReaLite, an inference-side weight update
+can happen **in the middle of** generating a response to a prompt. In other words, one
+output sequence may be produced by several versions of the model. Let us look into the
+code of `agenerate` and `update_weights` for a better understanding.
+
+In `update_weights`, the engine first sends `pause_generation` requests to all inference
+servers, notifying them that a weight update is about to happen. Upon receiving
+`pause_generation`, the inference servers immediately stop generating and respond with
+the tokens generated so far. The engine then sends `update_weights_from_distributed`
+(for an NCCL update) or `update_weights_from_disk` (for a disk update). After the update
+is finished, the engine sends `continue_generation` to the inference servers, telling
+them to resume generation.
+
+```python
+class RemoteSGLangEngine:
+    ...
+    def update_weights(self, meta: WeightUpdateMeta):
+        # `update_weights` is completely asynchronous: it submits a task
+        # to a ProcessPoolExecutor and returns a future.
+        for addr in self.addresses:
+            res = requests.post(f"http://{addr}/pause_generation")
+        if meta.type == "nccl":
+            future = self.executor.submit(
+                # a function that sends the `update_weights_from_distributed` request
+                update_weights_from_distributed,
+            )
+        elif meta.type == "disk":
+            ...
+
+        def callback(future):
+            for addr in self.addresses:
+                requests.post(f"http://{addr}/continue_generation")
+
+        future.add_done_callback(callback)
+        return future
+```
+
+`agenerate` takes an `LLMRequest` containing the `input_ids` of **a single prompt** and
+generation hyperparameters, and returns the final generation result as an `LLMResponse`
+with `output_tokens` and other outputs. Since generation can be interrupted,
+`agenerate` iteratively prepares the payload, sends requests, and collects responses
+until generation finishes.
+
+```python
+class RemoteSGLangEngine:
+    ...
+    async def agenerate(self, req: LLMRequest):
+        payload = ...  # prepare the payload for the request
+        # If the request comes from the same workflow, choose the previous
+        # server to allow KV cache reuse; otherwise choose a server in a
+        # round-robin manner.
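+        # NOTE: a weight update may pause the servers mid-request (see
+        # `update_weights` above); the loop below then re-sends the request
+        # with the tokens generated so far until generation finishes.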
+        server_addr = self.choose_server(req)
+        stop_reason = None
+        # other outputs and the `stop_reason` update are omitted for simplicity
+        output_tokens = []
+        while stop_reason != "stop" and payload["sample_params"]["max_new_tokens"] > 0:
+            # the request was interrupted; wait briefly to avoid contention
+            if stop_reason is not None:
+                await asyncio.sleep(0.5)
+            # send the request to the remote server
+            result = await arequest_with_retry(
+                addr=server_addr,
+                endpoint="/generate",
+                payload=payload,
+                method="POST",
+            )
+            output_tokens.extend(result["output_ids"])
+            # prepare the payload for the next request
+            payload["input_ids"] += result["output_ids"]
+            payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
+        return LLMResponse(
+            input_tokens=req.input_ids,
+            output_tokens=output_tokens,
+            ...
+        )
+
+```
+
+### `RLVRWorkflow` and `WorkflowExecutor`
+
+The rollout data lifecycle is controlled by an `RLVRWorkflow`, which defines how data
+progresses from prompts to complete rollout data containing all fields required for
+training. Our example shows a single-turn RLVR workflow with a math reward function. The
+core logic of the workflow is implemented in an async method `arun_episode`, which takes
+a prompt, generates answers with `RemoteSGLangEngine`, computes rewards, and populates
+additional fields to produce finalized training data.
 
 ```python
 class RLVRWorkflow(RolloutWorkflow):
@@ -158,12 +246,7 @@ workflow = RLVRWorkflow(
 )
 ```
 
-In AReaLite, generation tasks are offloaded to remote inference servers, which operate
-on separate GPUs from those used for training. The `RemoteSGLangEngine` acts as a client
-that interacts with the servers. `RemoteSGLangEngine` runs in a SPMD manner on every
-training process, without occupying any GPUs.
-
-`RemoteSGLangEngine` is responsible for managing the data streaming through rollout
+`WorkflowExecutor` is responsible for managing the data streaming through rollout
 workflows, and collates completed rollout data into batched training samples. When
 initializing, it launches a rollout thread that runs rollout workflows as `asyncio`
 tasks. The following code shows the simplified version of rollout thread implementation,
@@ -177,7 +260,7 @@ which iteratively:
 - Gathers data from finished workflows and puts them into `output_queue`
 
 ```python
-class RemoteSGLangEngine(InferenceEngine):
+class WorkflowExecutor:
     ...
     async def _rollout_thread_async(self):
         rid = 0
@@ -257,10 +340,11 @@ def prepare_batch(
     pass
 ```
 
-The usage of `RemoteSGLangEngine` in the training script is simple:
+Using `WorkflowExecutor` in the training script is simple:
 
 ```python
-rollout = RemoteSGLangEngine(config.rollout)
+inf_engine = RemoteSGLangEngine(config.inf_engine)
+rollout = WorkflowExecutor(config.rollout, inf_engine)
 rollout.initialize()
 
 eval_rollout = ...
diff --git a/docs/arealite/gsm8k_grpo.png b/docs/arealite/gsm8k_grpo.png
index 2a89b47..51e5bb7 100644
Binary files a/docs/arealite/gsm8k_grpo.png and b/docs/arealite/gsm8k_grpo.png differ
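
The retry loop in `agenerate` above is what makes interruption-tolerant generation work. The snippet below is a self-contained toy sketch of that resume-after-interruption pattern, runnable on its own: `mock_generate` and `generate_with_resume` are hypothetical stand-ins invented for illustration (only the payload keys `input_ids`, `sample_params`, `max_new_tokens`, and `output_ids` are taken from the snippet above), not AReaLite or SGLang APIs.

```python
import asyncio
import random


async def mock_generate(payload: dict) -> dict:
    """Stand-in for a remote /generate endpoint that may stop early, as if paused."""
    budget = payload["sample_params"]["max_new_tokens"]
    # Half of the time return fewer tokens than requested, emulating a server
    # that received `pause_generation` in the middle of the request.
    n = budget if random.random() < 0.5 else random.randint(1, budget)
    await asyncio.sleep(0)  # pretend to do network I/O
    return {
        "output_ids": [random.randint(0, 999) for _ in range(n)],
        "stop_reason": "stop" if n == budget else "interrupted",
    }


async def generate_with_resume(input_ids: list[int], max_new_tokens: int) -> list[int]:
    """Re-send the request with the accumulated tokens until generation finishes."""
    payload = {
        "input_ids": list(input_ids),
        "sample_params": {"max_new_tokens": max_new_tokens},
    }
    output_tokens: list[int] = []
    stop_reason = None
    while stop_reason != "stop" and payload["sample_params"]["max_new_tokens"] > 0:
        if stop_reason is not None:
            await asyncio.sleep(0.1)  # back off briefly after an interruption
        result = await mock_generate(payload)
        output_tokens.extend(result["output_ids"])
        stop_reason = result["stop_reason"]
        # Feed the generated tokens back as context and shrink the remaining budget,
        # mirroring how `agenerate` updates its payload between requests.
        payload["input_ids"] += result["output_ids"]
        payload["sample_params"]["max_new_tokens"] -= len(result["output_ids"])
    return output_tokens


if __name__ == "__main__":
    out = asyncio.run(generate_with_resume(input_ids=[1, 2, 3], max_new_tokens=16))
    print(f"generated {len(out)} tokens")
```

The real `agenerate` additionally picks a server with `choose_server`, retries failed HTTP requests via `arequest_with_retry`, and wraps the result in an `LLMResponse`; those details are omitted from the toy above.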