From e507ce281c10f4968237dd798b3404dfc750f4db Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Thu, 31 Jul 2025 19:29:55 +0800 Subject: [PATCH] [lite] [fix] Fix a performance issue and several minor issues before release (#203) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 * add gradient checkpointing * PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 * . * . * . * . * PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 * . * . * . * . * . * . * fix * . * PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment Reviewed-by: 晓雷 * . * . * . * . * . * fix destroy process group * PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/358 Reviewed-by: 晓雷 * . * . * . * . * fix loss mask * fix * . * PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/368 Reviewed-by: 晓雷 * . * . * PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/371 Reviewed-by: 晓雷 * . * PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/370 Reviewed-by: 博惟 * . * . * . * fix * PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results. Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/392 Reviewed-by: 晓雷 * support fsdp engine and sglang remote engine * minor fix * . * refactor trainer * add close * rm mb_spec * . * fix * . * qwen2 grpo works * fix * fix * async works * fix * slurm launcher not tested * fix arg parse * . * sglang server wrapper * . * . * slurm run * ready for boba * debug * 32k run * . * . * fix * . * . * . * . * . * fix * . * fix * . * . * . * . * fix * . * . * . * . * . * . * . * refactor train engine * refactor train engine * . * fix update weight error * . * . * match train * format * . * fix * seems to work * . * . * . * . * . * PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2 Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/408 Reviewed-by: 晓雷 * . 
* bump arealite to sglang 0.4.9.post2 * . * PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta` * PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9 Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/422 Reviewed-by: 晓雷 * . * bump arealite to sglang 0.4.9.post2 * . * PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta` * . * PullRequest: 423 [lite] Remove the boba example for github release. Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/423 Reviewed-by: 晓雷 * . * . * update readme * PullRequest: 431 [Fix] Fix environment of lite Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/431 Reviewed-by: 晓雷 * change requirements * . * . * . * PullRequest: 440 [FIX] fix update weight from disk Merge branch sxj/lite-fix-disk-update of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/440 Reviewed-by: 博惟 * [FIX] fix update weight from disk * PullRequest: 442 [lite] Refactor `RemoteSGLangEngine` into two parts: `RemoteSGLangEngine` and `WorkflowExecutor`. Merge branch mzy/workflow-executor of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/442 Reviewed-by: 博惟 * refactor workflow executor * . * fix tests and eval * . * . * revert workflow executor into remote sglang engine * . * PullRequest: 456 [lite] [Bug] Use `ProcessPoolExecutor` to calculate reward to avoid rollout slow down Merge branch mzy/lite/fix-reward of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/456?tab=comment Reviewed-by: 博惟 * fix reward * . * . * . 
* PullRequest: 460 [lite][fix] add a warning when reward computation timeout Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/460 Reviewed-by: 晓雷 * add a warning when reward computation timeout * PullRequest: 465 [lite][fix] Fix issues raised by tsao Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/465 Reviewed-by: 晓雷 * fix --------- Co-authored-by: 晓雷 Co-authored-by: 冰临 --- arealite/api/cli_args.py | 42 +--- arealite/api/engine_api.py | 1 + arealite/api/workflow_api.py | 245 +++++++++++++++++++- arealite/engine/fsdp_engine.py | 2 +- arealite/engine/ppo/actor.py | 2 +- arealite/engine/sglang_remote.py | 270 ++++------------------ arealite/experimental/sglang_engine.py | 1 + arealite/launcher/ray.py | 1 - arealite/utils/http.py | 1 + arealite/workflow/multi_turn.py | 2 +- arealite/workflow/rlvr.py | 36 ++- docs/arealite/gsm8k_grpo.md | 26 +-- docs/customization/agent.md | 29 ++- docs/customization/algorithm.md | 9 +- examples/arealite/configs/gsm8k_grpo.yaml | 10 +- examples/arealite/gsm8k_grpo.py | 10 +- examples/env/scripts/setup-pip-deps.sh | 2 +- pyproject.toml | 2 + requirements.txt | 3 + 19 files changed, 384 insertions(+), 310 deletions(-) diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py index a25c841..38a33b3 100644 --- a/arealite/api/cli_args.py +++ b/arealite/api/cli_args.py @@ -195,7 +195,10 @@ class TrainEngineConfig: gradient_checkpointing: bool = field( default=True, metadata={"help": "Enable gradient checkpointing"} ) - dtype: str = field(default="float16", metadata={"help": "Parameter dtype."}) + dtype: str = field(default="bfloat16", metadata={"help": "Parameter dtype."}) + grad_reduce_dtype: str = field( + default="float32", metadata={"help": "Gradient reduce dtype."} + ) optimizer: Optional[OptimizerConfig] = field( default=None, metadata={"help": "Optimizer configuration"} ) @@ -262,7 +265,7 @@ class PPOActorConfig(TrainEngineConfig): default=1.0, metadata={"help": "Lambda parameter for GAE"} ) adv_norm: bool = field( - default=True, metadata={"help": "Enable advantage normalization"} + default=True, metadata={"help": "Enable advantage normalization globally"} ) # KL Control @@ -275,7 +278,9 @@ class PPOActorConfig(TrainEngineConfig): ) use_decoupled_loss: bool = field( default=False, - metadata={"help": "Use the decoupled loss. recompute_logprob must be True."}, + metadata={ + "help": "Use the decoupled loss. Implicitly enable recompute_logprob." 
+ }, ) behav_imp_weight_cap: Optional[float] = field( default=None, @@ -329,7 +334,7 @@ class SGLangConfig: schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 - dtype: str = "float16" + dtype: str = "bfloat16" kv_cache_dtype: str = "auto" # logging log_level: str = "warning" @@ -407,31 +412,8 @@ class SGLangConfig: dist_init_addr=dist_init_addr, **args, ) - sglang_version = pkg_version.get_version("sglang") - if sglang_version: - version_less_than_0_4_4 = ( - pkg_version.compare_versions(sglang_version, "0.4.4") < 0 - ) - version_less_than_0_4_3 = ( - pkg_version.compare_versions(sglang_version, "0.4.3") < 0 - ) - elif pkg_version.is_available("sglang"): - version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4") - version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3") - else: - raise ValueError( - "A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd." - ) - if version_less_than_0_4_4: - args.pop("log_requests_level") - if version_less_than_0_4_3: - args.pop("enable_nccl_nvls") - args.pop("triton_attention_num_kv_splits") - args.pop("cuda_graph_bs") - args.pop("enable_memory_saver") - args.pop("allow_auto_truncate") - args.pop("file_storage_path") - + if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"): + raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.") return args @@ -466,7 +448,7 @@ class InferenceEngineConfig: default="round_robin", metadata={"help": "Request scheduling policy", "choices": ["round_robin"]}, ) - setup_timeout: float = field(default=90.0) + setup_timeout: float = field(default=120.0) request_timeout: float = field( default=3600, metadata={"help": "Timeout for HTTP requests."} ) diff --git a/arealite/api/engine_api.py b/arealite/api/engine_api.py index bd00eb5..db0f4f5 100644 --- a/arealite/api/engine_api.py +++ b/arealite/api/engine_api.py @@ -174,6 +174,7 @@ class InferenceEngine(abc.ABC): self, dataloader: StatefulDataLoader, workflow: "RolloutWorkflow", + should_accept: Callable | None = None, ): """Asynchronously submit and wait until a full batch is ready.""" raise NotImplementedError() diff --git a/arealite/api/workflow_api.py b/arealite/api/workflow_api.py index 9141399..fbb24dd 100644 --- a/arealite/api/workflow_api.py +++ b/arealite/api/workflow_api.py @@ -1,10 +1,30 @@ -from typing import TYPE_CHECKING, Any, Dict +import asyncio +import itertools +import queue +import threading +import time +import traceback +from typing import TYPE_CHECKING, Any, Callable, Dict, List +import torch.distributed as dist +import uvloop from tensordict import TensorDict +from torchdata.stateful_dataloader import StatefulDataLoader + +from arealite.api.cli_args import InferenceEngineConfig +from arealite.api.engine_api import InferenceEngine +from arealite.api.io_struct import RolloutStat +from arealite.utils.data import concat_padded_tensors +from realhf.base import logging if TYPE_CHECKING: from arealite.api.engine_api import InferenceEngine +logger = logging.getLogger("arealite.workflow_api") + + +ROLLOUT_POLL_WAIT_TIME = 0.05 + class RolloutWorkflow: @@ -16,3 +36,226 @@ class RolloutWorkflow: See concrete example implementations under the `arealite/workflow` directory. 
""" raise NotImplementedError() + + +class WorkflowExecutor: + + def __init__( + self, + config: InferenceEngineConfig, + inference_engine: "InferenceEngine", + ): + config.max_concurrent_rollouts = ( + config.max_concurrent_rollouts or config.consumer_batch_size + ) + self.config = config + self.exiting = threading.Event() + self.paused = threading.Event() + self.lock = threading.Lock() + + self.inference_engine = inference_engine + + qsize = config.queue_size or config.max_concurrent_rollouts * 16 + self.input_queue = queue.Queue(maxsize=qsize) + self.output_queue = queue.Queue(maxsize=qsize) + self.result_cache: List[TensorDict] = [] + + self.rollout_stat = RolloutStat() + + def initialize(self): + self.rollout_tasks: Dict[str, asyncio.Task] = {} + self.rollout_thread = threading.Thread(target=self._rollout_thread) + self.rollout_thread.start() + + def destroy(self): + self.exiting.set() + self.rollout_thread.join() + + def get_capacity(self): + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + + with self.lock: + max_concurrent_rollouts = max( + 1, self.config.max_concurrent_rollouts // world_size + ) + capacity = max_concurrent_rollouts - len(self.rollout_tasks) + # Staleness control + version = self.inference_engine.get_version() + ofp = self.config.max_head_offpolicyness + sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running + consumer_bs = max(1, self.config.consumer_batch_size // world_size) + capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt) + return capacity + + def _rollout_thread(self): + """Thread that runs the rollout loop.""" + try: + uvloop.run(self._rollout_thread_async()) + except Exception: + traceback.print_exc() + + async def _rollout_thread_async(self): + rollout_tasks = self.rollout_tasks + rid = 0 + try: + while not self.exiting.is_set(): + # Check capacity + capacity = self.get_capacity() + # Create new rollout task + self.lock.acquire() + while ( + capacity > 0 + and not self.paused.is_set() + and self.input_queue.qsize() > 0 + ): + data, workflow = self.input_queue.get_nowait() + logger.debug(f"Get data from puller: {data}") + task = asyncio.create_task( + workflow.arun_episode(self.inference_engine, data), + name=str(rid), + ) + rollout_tasks[str(rid)] = task + self.rollout_stat.submitted += 1 + self.rollout_stat.running += 1 + if self.config.enable_rollout_tracing: + logger.info( + f"Submit rollout rid {rid}. " + f"Submit: {self.rollout_stat.submitted}, " + f"running: {self.rollout_stat.running}, " + f"accepted: {self.rollout_stat.accepted}." + ) + capacity -= 1 + rid += 1 + tasks = list(rollout_tasks.values()) + self.lock.release() + + # Wait for rollout completion + done = [] + if tasks: + done, _ = await asyncio.wait( + tasks, + timeout=ROLLOUT_POLL_WAIT_TIME, + return_when=asyncio.FIRST_COMPLETED, + ) + # Collect done results + for task in done: + traj = await task + traj: TensorDict + task_rid = task.get_name() + with self.lock: + rollout_tasks.pop(task_rid) + self.rollout_stat.accepted += 1 + + self.rollout_stat.running -= 1 + if self.config.enable_rollout_tracing: + logger.info( + f"Finish rollout {task_rid}. " + f"Submit: {self.rollout_stat.submitted}, " + f"running: {self.rollout_stat.running}, " + f"accepted: {self.rollout_stat.accepted}." + ) + try: + self.output_queue.put_nowait(traj) + except queue.Full: + raise RuntimeError( + "Output queue full. Please increase queue_size." 
+ ) + + await asyncio.sleep(1) + except Exception: + traceback.print_exc() + finally: + # Cancel remaining tasks + with self.lock: + for task in rollout_tasks.values(): + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None: + try: + self.input_queue.put_nowait((data, workflow)) + except queue.Full: + raise RuntimeError("Input queue full. Please increase queue_size.") + + def wait( + self, + count: int, + timeout: float | None = None, + should_accept: Callable | None = None, + ) -> TensorDict: + tik = time.perf_counter() + accepted = len(self.result_cache) + timeout = timeout or float(7 * 24 * 3600) + while ( + accepted < count + and not self.exiting.is_set() + and time.perf_counter() - tik < timeout + ): + try: + result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME) + if should_accept is None or should_accept(result): + self.result_cache.append(result) + accepted += 1 + else: + with self.lock: + self.rollout_stat.accepted -= 1 + except queue.Empty: + pass + if self.exiting.is_set(): + raise RuntimeError("Rollout engine is exiting, cannot wait for results.") + if accepted < count: + raise TimeoutError( + f"Timed out waiting for {count} rollouts, " f"only received {accepted}." + ) + results, self.result_cache = ( + self.result_cache[:count], + self.result_cache[count:], + ) + return concat_padded_tensors(results) + + def rollout_batch( + self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow" + ) -> TensorDict: + """Submit a batch of requests to the inference engine and wait for the results.""" + for item in data: + self.submit(item, workflow) + return self.wait(count=len(data)) + + def prepare_batch( + self, + dataloader: StatefulDataLoader, + workflow: "RolloutWorkflow", + should_accept: Callable | None = None, + ): + if not hasattr(self, "data_generator"): + self.data_generator = itertools.cycle(dataloader) + assert dataloader.batch_size is not None + while True: + # Submit at least two batches to allow maximum overlap + if ( + self.get_capacity() + dataloader.batch_size > 0 + and self.input_queue.qsize() + dataloader.batch_size + < self.input_queue.maxsize + ): + data = next(self.data_generator) + for item in data: + self.submit(item, workflow=workflow) + try: + return self.wait( + dataloader.batch_size, timeout=1, should_accept=should_accept + ) + except TimeoutError: + pass + + def pause(self): + self.paused.set() + + def resume(self): + self.paused.clear() diff --git a/arealite/engine/fsdp_engine.py b/arealite/engine/fsdp_engine.py index 76fe449..30b5205 100644 --- a/arealite/engine/fsdp_engine.py +++ b/arealite/engine/fsdp_engine.py @@ -55,7 +55,7 @@ class FSDPEngine(BaseHFEngine): # Simple auto wrap policy self.mixed_precision_policy = MixedPrecisionPolicy( param_dtype=getattr(torch, self.config.dtype), - reduce_dtype=torch.float32, + reduce_dtype=getattr(torch, self.config.grad_reduce_dtype), cast_forward_inputs=True, ) self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size) diff --git a/arealite/engine/ppo/actor.py b/arealite/engine/ppo/actor.py index 802c03e..ba40e0c 100644 --- a/arealite/engine/ppo/actor.py +++ b/arealite/engine/ppo/actor.py @@ -128,7 +128,7 @@ class PPOActor: advantages = torch.stack(advantages_reversed[::-1], dim=1) # Optionally perform advantage normalization. 
- if self.adv_norm: + if self.adv_norm or self.group_adv_norm: if self.group_adv_norm: adv_list = [] for i in range(0, bs, self.group_size): diff --git a/arealite/engine/sglang_remote.py b/arealite/engine/sglang_remote.py index c904d25..66f3d1e 100644 --- a/arealite/engine/sglang_remote.py +++ b/arealite/engine/sglang_remote.py @@ -2,17 +2,13 @@ import asyncio import os import random import shutil -import threading import time -import traceback from concurrent.futures import Future, ProcessPoolExecutor from datetime import datetime -from queue import Empty, Full, Queue -from typing import TYPE_CHECKING, Any, Callable, Dict, List +from typing import Any, Callable, Dict, List import aiohttp import requests -import torch.distributed as dist import uvloop from tensordict import TensorDict from torchdata.stateful_dataloader import StatefulDataLoader @@ -28,25 +24,19 @@ from arealite.api.io_struct import ( VLMResponse, WeightUpdateMeta, ) -from arealite.utils.data import concat_padded_tensors +from arealite.api.workflow_api import RolloutWorkflow, WorkflowExecutor from arealite.utils.http import arequest_with_retry, get_default_connector from realhf.base import logging, name_resolve, names -if TYPE_CHECKING: - from arealite.api.workflow_api import RolloutWorkflow logger = logging.getLogger(__name__) -ROLLOUT_POLL_WAIT_TIME = 0.05 RID_CACHE_SIZE = 128 class RemoteSGLangEngine(InferenceEngine): def __init__(self, config: InferenceEngineConfig): - config.max_concurrent_rollouts = ( - config.max_concurrent_rollouts or config.consumer_batch_size - ) self.config = config self.rid_to_address = {} @@ -54,29 +44,20 @@ class RemoteSGLangEngine(InferenceEngine): self.rid_queue = [] self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",") + self.session = None + if not self.addresses: raise RuntimeError("No configured SGLang servers.") - logger.info("Waiting for server ready...") - for addr in self.addresses: - self._wait_for_server(addr) - logger.info("Servers are all ready!") self.server_idx = random.randint(0, len(self.addresses) - 1) - - qsize = config.queue_size or config.max_concurrent_rollouts * 16 - self.input_queue = Queue(maxsize=qsize) - self.output_queue = Queue(maxsize=qsize) - self.result_cache = [] - - self.exiting = threading.Event() - self.paused = threading.Event() - self.lock = threading.Lock() - - self.rollout_stat = RolloutStat() self.distributed_weight_update_initialized = False - self._version = 0 + self.workflow_executor = WorkflowExecutor( + config=config, + inference_engine=self, + ) + def _wait_for_server(self, address): base_url = f"http://{address}" tik = time.time() @@ -94,137 +75,46 @@ class RemoteSGLangEngine(InferenceEngine): except requests.exceptions.RequestException as e: return False - def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None): - self.rollout_tasks: Dict[str, asyncio.Task] = {} - + def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None): + logger.info("Waiting for server ready...") + for addr_ in self.addresses: + self._wait_for_server(addr_) + logger.info("Servers are all ready!") self.executor = ProcessPoolExecutor(max_workers=1) - self.rollout_thread = threading.Thread(target=self._rollout_thread) - self.rollout_thread.start() + self.workflow_executor.initialize() def destroy(self): self.executor.shutdown() - self.exiting.set() - self.rollout_thread.join() def set_version(self, version): - with self.lock: - self._version = version + self._version = version def get_version(self): - with self.lock: - return 
self._version - - def _rollout_thread(self): - """Thread that runs the rollout loop.""" - try: - uvloop.run(self._rollout_thread_async()) - except Exception: - traceback.print_exc() - - async def _rollout_thread_async(self): - rollout_tasks = self.rollout_tasks - rid = 0 - - # NOTE: session is not thread-safe, but we only submit requests in the sub-thread. - self.session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout( - total=self.config.request_timeout, - sock_connect=self.config.request_timeout, - connect=self.config.request_timeout, - ), - read_bufsize=1024 * 1024 * 10, - connector=get_default_connector(), - ) - - try: - while not self.exiting.is_set(): - # Check capacity - capacity = self.get_capacity() - # Create new rollout task - while ( - capacity > 0 - and not self.paused.is_set() - and self.input_queue.qsize() > 0 - ): - data, workflow = self.input_queue.get_nowait() - logger.debug(f"Get data from puller: {data}") - task = asyncio.create_task( - workflow.arun_episode(self, data), name=str(rid) - ) - with self.lock: - rollout_tasks[str(rid)] = task - self.rollout_stat.submitted += 1 - self.rollout_stat.running += 1 - if self.config.enable_rollout_tracing: - logger.info( - f"Submit rollout rid {rid}. " - f"Submit: {self.rollout_stat.submitted}, " - f"running: {self.rollout_stat.running}, " - f"accepted: {self.rollout_stat.accepted}." - ) - capacity -= 1 - rid += 1 - # Wait for rollout completion - with self.lock: - tasks = list(rollout_tasks.values()) - done = [] - if tasks: - done, _ = await asyncio.wait( - tasks, - timeout=ROLLOUT_POLL_WAIT_TIME, - return_when=asyncio.FIRST_COMPLETED, - ) - # Collect done results - for task in done: - traj = await task - traj: TensorDict - task_rid = task.get_name() - with self.lock: - rollout_tasks.pop(task_rid) - self.rollout_stat.accepted += 1 - - try: - self.output_queue.put_nowait(traj) - except Full: - raise RuntimeError( - "Output queue full. Please increase queue_size." - ) - - with self.lock: - self.rollout_stat.running -= 1 - if self.config.enable_rollout_tracing: - logger.info( - f"Finish rollout {task_rid}. " - f"Submit: {self.rollout_stat.submitted}, " - f"running: {self.rollout_stat.running}, " - f"accepted: {self.rollout_stat.accepted}." 
- ) - await asyncio.sleep(1) - except Exception: - traceback.print_exc() - finally: - # Cancel remaining tasks - with self.lock: - for task in rollout_tasks.values(): - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + return self._version def choose_server(self) -> str: - with self.lock: - if self.config.schedule_policy == "round_robin": - server = self.addresses[self.server_idx] - self.server_idx = (self.server_idx + 1) % len(self.addresses) - return server + if self.config.schedule_policy == "round_robin": + server = self.addresses[self.server_idx] + self.server_idx = (self.server_idx + 1) % len(self.addresses) + return server raise NotImplementedError("Only round-robin scheduling is implemented.") async def agenerate( self, req: LLMRequest | VLMRequest ) -> LLMResponse | VLMResponse: """Async version of generate using aiohttp.""" + if self.session is None: + # NOTE: Lazily initialize aiohttp.ClientSession since it needs to be initialized + # inside asyncio loop in WorkflowExecutor + self.session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=self.config.request_timeout, + sock_connect=self.config.request_timeout, + connect=self.config.request_timeout, + ), + read_bufsize=1024 * 1024 * 10, + connector=get_default_connector(), + ) # Prepare request payload gconfig = req.gconfig stop_token_ids = gconfig.stop_token_ids @@ -396,30 +286,8 @@ class RemoteSGLangEngine(InferenceEngine): fut.add_done_callback(callback) return fut - def get_capacity(self): - if dist.is_initialized(): - world_size = dist.get_world_size() - else: - world_size = 1 - - max_concurrent_rollouts = max( - 1, self.config.max_concurrent_rollouts // world_size - ) - capacity = max_concurrent_rollouts - len(self.rollout_tasks) - # Staleness control - version = self.get_version() - ofp = self.config.max_head_offpolicyness - with self.lock: - sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running - consumer_bs = max(1, self.config.consumer_batch_size // world_size) - capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt) - return capacity - - def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None: - try: - self.input_queue.put_nowait((data, workflow)) - except Full: - raise RuntimeError("Input queue full. Please increase queue_size.") + def submit(self, data: Dict[str, Any], workflow: RolloutWorkflow) -> None: + return self.workflow_executor.submit(data, workflow) def wait( self, @@ -427,77 +295,31 @@ class RemoteSGLangEngine(InferenceEngine): timeout: float | None = None, should_accept: Callable | None = None, ) -> TensorDict: - tik = time.perf_counter() - accepted = len(self.result_cache) - timeout = timeout or float(7 * 24 * 3600) - while ( - accepted < count - and not self.exiting.is_set() - and time.perf_counter() - tik < timeout - ): - try: - result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME) - if should_accept is None or should_accept(result): - self.result_cache.append(result) - accepted += 1 - else: - with self.lock: - self.rollout_stat.accepted -= 1 - except Empty: - pass - if self.exiting.is_set(): - raise RuntimeError("Rollout engine is exiting, cannot wait for results.") - if accepted < count: - raise TimeoutError( - f"Timed out waiting for {count} rollouts, " f"only received {accepted}." 
- ) - results, self.result_cache = ( - self.result_cache[:count], - self.result_cache[count:], + return self.workflow_executor.wait( + count, + timeout=timeout, + should_accept=should_accept, ) - return concat_padded_tensors(results) def rollout_batch( self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow" ) -> TensorDict: - """Submit a batch of requests to the inference engine and wait for the results.""" - for item in data: - self.submit(item, workflow) - return self.wait(count=len(data)) + return self.workflow_executor.rollout_batch(data, workflow) def prepare_batch( self, dataloader: StatefulDataLoader, - workflow: "RolloutWorkflow", + workflow: RolloutWorkflow, ): - if not hasattr(self, "data_generator"): - self.data_generator = iter(dataloader) - assert dataloader.batch_size is not None - while True: - # Submit at least two batches to allow maximum overlap - if ( - self.get_capacity() + dataloader.batch_size > 0 - and self.input_queue.qsize() + dataloader.batch_size - < self.input_queue.maxsize - ): - try: - data = next(self.data_generator) - - except StopIteration: - self.data_generator = iter(dataloader) - data = next(self.data_generator) - for item in data: - self.submit(item, workflow=workflow) - try: - return self.wait(dataloader.batch_size, timeout=1) - except TimeoutError: - pass + return self.workflow_executor.prepare_batch(dataloader, workflow) def pause(self): - self.paused.set() + """Pause request submission for async rollout. Used during evaluation to prevent data over generation.""" + return self.workflow_executor.pause() def resume(self): - self.paused.clear() + """Resume request submission for async rollout.""" + return self.workflow_executor.resume() def update_weights_from_disk( diff --git a/arealite/experimental/sglang_engine.py b/arealite/experimental/sglang_engine.py index 777091e..9ee7bdc 100644 --- a/arealite/experimental/sglang_engine.py +++ b/arealite/experimental/sglang_engine.py @@ -94,6 +94,7 @@ class SGLangEngine(InferenceEngine): asyncio.run(self._rollout_thread_async()) except Exception as e: traceback.print_exc() + raise e async def _rollout_thread_async(self): data = None diff --git a/arealite/launcher/ray.py b/arealite/launcher/ray.py index 0344e98..0a0fdf5 100644 --- a/arealite/launcher/ray.py +++ b/arealite/launcher/ray.py @@ -324,7 +324,6 @@ def ray_main(): ) allocation_mode = config.allocation_mode allocation_mode = AllocationMode.from_str(allocation_mode) - sglang_cmds = [] sglang_addrs = [] n_sglang_nodes = 0 if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG: diff --git a/arealite/utils/http.py b/arealite/utils/http.py index a39a361..0ddca14 100644 --- a/arealite/utils/http.py +++ b/arealite/utils/http.py @@ -61,6 +61,7 @@ async def arequest_with_retry( ctx = _session.delete(url, timeout=timeo) else: raise ValueError(f"Unsupported HTTP method: {method}") + async with ctx as response: if verbose: logger.info("http requests return") diff --git a/arealite/workflow/multi_turn.py b/arealite/workflow/multi_turn.py index eec8aa1..14bd263 100644 --- a/arealite/workflow/multi_turn.py +++ b/arealite/workflow/multi_turn.py @@ -31,7 +31,7 @@ class MultiTurnWorkflow(RolloutWorkflow): messages = data["messages"] # Run multi-turn rollout until correct t = reward = 0 - discount = 0 + discount = 1 rid = uuid.uuid4().hex while reward == 0 and t < self.max_turns: # Amend a prompt if the previous answer is incorrect diff --git a/arealite/workflow/rlvr.py b/arealite/workflow/rlvr.py index cf7d21b..7a2170b 100644 --- a/arealite/workflow/rlvr.py +++ 
b/arealite/workflow/rlvr.py @@ -1,6 +1,8 @@ import asyncio +import functools import os import uuid +from concurrent.futures import ProcessPoolExecutor import colorama import torch @@ -12,6 +14,11 @@ from arealite.api.engine_api import InferenceEngine from arealite.api.io_struct import LLMRequest from arealite.api.workflow_api import RolloutWorkflow from arealite.utils.data import concat_padded_tensors +from realhf.base import logging + +logger = logging.getLogger("RLVR workflow") + +REWARD_TIMEOUT_SECONDS = 15 class RLVRWorkflow(RolloutWorkflow): @@ -28,6 +35,7 @@ class RLVRWorkflow(RolloutWorkflow): self.tokenizer = tokenizer self.enable_thinking = enable_thinking self.dump_dir = dump_dir + self.rw_executor = ProcessPoolExecutor(max_workers=4) if self.dump_dir is not None and not os.path.exists(self.dump_dir): os.makedirs(self.dump_dir, exist_ok=True) @@ -54,6 +62,7 @@ class RLVRWorkflow(RolloutWorkflow): seqlens = [] results = [] + loop = asyncio.get_event_loop() for resp in resps: seq = resp.input_tokens + resp.output_tokens logprobs = [0.0] * resp.input_len + resp.output_logprobs @@ -65,13 +74,26 @@ class RLVRWorkflow(RolloutWorkflow): prompt_strs.append(prompt_str) completions_strs.append(completions_str) seqlens.append(len(seq)) - reward = self.reward_fn( - prompt=prompt_str, - completions=completions_str, - prompt_ids=resp.input_tokens, - completion_ids=resp.output_tokens, - **data, - ) + try: + reward = await asyncio.wait_for( + loop.run_in_executor( + self.rw_executor, + functools.partial( + self.reward_fn, + prompt_str, + completions_str, + resp.input_tokens, + resp.output_tokens, + **data, + ), + ), + timeout=REWARD_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + logger.warning( + f"Computing reward timeout after {REWARD_TIMEOUT_SECONDS}s. Set reward to 0." + ) + reward = 0 rewards.append(reward) res = dict( # unsqueeze to add an additional batch dimension diff --git a/docs/arealite/gsm8k_grpo.md b/docs/arealite/gsm8k_grpo.md index 26125d3..be3b9ce 100644 --- a/docs/arealite/gsm8k_grpo.md +++ b/docs/arealite/gsm8k_grpo.md @@ -233,7 +233,7 @@ def prepare_batch( workflow: "RolloutWorkflow", ): if not hasattr(self, "data_generator"): - self.data_generator = iter(dataloader) + self.data_generator = itertools.cycle(dataloader) assert dataloader.batch_size is not None while True: # Submit at least two batches to allow maximum overlap @@ -242,11 +242,7 @@ def prepare_batch( and self.input_queue.qsize() + dataloader.batch_size < self.input_queue.maxsize ): - try: - data = next(self.data_generator) - except StopIteration: - self.data_generator = iter(dataloader) - data = next(self.data_generator) + data = next(self.data_generator) for item in data: # submit data into input_queue self.submit(item, workflow=workflow) @@ -264,18 +260,13 @@ rollout = RemoteSGLangEngine(config.rollout) rollout.initialize() eval_rollout = ... 
-data_generator = iter(train_dataloader) +data_generator = itertools.cycle(train_dataloader) for global_step in range(max_steps): # rollout batched training data for current step if config.async_training: batch = rollout.prepare_batch(train_dataloader, workflow=workflow) else: - try: - data = next(data_generator) - except StopIteration: - data_generator = iter(train_dataloader) - data = next(data_generator) - batch = rollout.rollout_batch(data, workflow=workflow) + batch = rollout.rollout_batch(next(data_generator), workflow=workflow) ``` If you want to use rollout workflows with custom reward functions or agentic tool @@ -375,17 +366,12 @@ Now a complete GRPO training step in AReaLite is done! The core logic of our exa training script can be summarized as: ```python -data_generator = iter(train_dataloader) +data_generator = itertools.cycle(train_dataloader) for global_step in range(max_steps): if config.async_training: batch = rollout.prepare_batch(train_dataloader, workflow=workflow) else: - try: - data = next(data_generator) - except StopIteration: - data_generator = iter(train_dataloader) - data = next(data_generator) - batch = rollout.rollout_batch(data, workflow=workflow) + batch = rollout.rollout_batch(next(data_generator), workflow=workflow) logp = actor.compute_logp(batch) batch["prox_logp"] = logp diff --git a/docs/customization/agent.md b/docs/customization/agent.md index 8d5c6ff..1ebd5b3 100644 --- a/docs/customization/agent.md +++ b/docs/customization/agent.md @@ -75,7 +75,7 @@ and converting it into an `LLMRequest` object for the inference engine: class MultiTurnWorkflow(RolloutWorkflow): # ... __init__ method above ... - async def arun_episode(self, engine: InferenceEngine, data): + async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict: # Initialize result containers seq, logprobs, loss_mask, versions = [], [], [], [] messages = data["messages"] @@ -119,7 +119,7 @@ we'll apply a discount, add feedback to the conversation, and let the model try class MultiTurnWorkflow(RolloutWorkflow): # ... previous methods ... - async def arun_episode(self, engine: InferenceEngine, data): + async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict: # ... initialization code ... while reward == 0 and t < self.max_turns: # Add feedback if the previous answer was incorrect @@ -190,7 +190,7 @@ Finally, let's complete the implementation by collecting trajectories in the class MultiTurnWorkflow(RolloutWorkflow): # ... previous methods ... - async def arun_episode(self, engine: InferenceEngine, data): + async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict: # ... episode logic above ... while reward == 0 and t < self.max_turns: @@ -417,8 +417,27 @@ Using your custom workflow is straightforward—just create it in your training pass it to the `rollout_batch` or `prepare_batch` method: ```python -# in realhf/impl/agent/__init__.py -import realhf.impl.agent.math_multi_turn_agent +def main(args): + # ... setup code ... + + # Create your custom workflow + workflow = MultiTurnWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + turn_discount=0.9, + max_turns=5, + ) + + # Run training—no other changes needed!
+ data_generator = itertools.cycle(train_dataloader) + for global_step in range(max_steps): + with stats_tracker.record_timing("rollout"): + if config.async_training: + batch = rollout.prepare_batch(train_dataloader, workflow=workflow) + else: + batch = rollout.rollout_batch(next(data_generator), workflow=workflow) + # ... continue with training loop ... ``` Then update your experiment configuration in diff --git a/docs/customization/algorithm.md b/docs/customization/algorithm.md index f087f83..04a9100 100644 --- a/docs/customization/algorithm.md +++ b/docs/customization/algorithm.md @@ -171,16 +171,11 @@ def main(args): ) # Main training loop + data_generator = itertools.cycle(dataloader) for global_step in range(max_steps): # Generate training data with stats_tracker.record_timing("rollout"): - try: - data = next(data_generator) - except StopIteration: - data_generator = iter(train_dataloader) - data = next(data_generator) - - batch = rollout.rollout_batch(data, workflow=workflow) + batch = rollout.rollout_batch(next(data_generator), workflow=workflow) batch = batch.to(actor.device) diff --git a/examples/arealite/configs/gsm8k_grpo.yaml b/examples/arealite/configs/gsm8k_grpo.yaml index db16d36..6a2cfa7 100644 --- a/examples/arealite/configs/gsm8k_grpo.yaml +++ b/examples/arealite/configs/gsm8k_grpo.yaml @@ -68,6 +68,7 @@ ref: trial_name: ${trial_name} path: ${actor.path} init_from_scratch: false + disable_dropout: true dtype: ${actor.dtype} mb_spec: max_tokens_per_mb: 10240 @@ -81,9 +82,9 @@ sglang: random_seed: ${seed} skip_tokenizer_init: true dtype: ${actor.dtype} - max_running_requests: null + max_running_requests: null context_length: 32768 - mem_fraction_static: 0.9 + mem_fraction_static: 0.8 # datasets train_dataset: @@ -94,7 +95,7 @@ train_dataset: path: openai/gsm8k type: rl -valid_dataset: +valid_dataset: batch_size: 256 shuffle: true pin_memory: true @@ -131,4 +132,5 @@ stats_logger: experiment_name: ${experiment_name} trial_name: ${trial_name} fileroot: ${cluster.fileroot} - + wandb: + mode: disabled diff --git a/examples/arealite/gsm8k_grpo.py b/examples/arealite/gsm8k_grpo.py index f0cfb9e..8014122 100644 --- a/examples/arealite/gsm8k_grpo.py +++ b/examples/arealite/gsm8k_grpo.py @@ -1,3 +1,4 @@ +import itertools import os import sys @@ -129,7 +130,7 @@ def main(args): max_steps = total_epochs * steps_per_epoch logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}") - data_generator = iter(train_dataloader) + data_generator = itertools.cycle(train_dataloader) for global_step in range(max_steps): epoch = global_step // steps_per_epoch step = global_step % steps_per_epoch @@ -138,12 +139,7 @@ def main(args): if config.async_training: batch = rollout.prepare_batch(train_dataloader, workflow=workflow) else: - try: - data = next(data_generator) - except StopIteration: - data_generator = iter(train_dataloader) - data = next(data_generator) - batch = rollout.rollout_batch(data, workflow=workflow) + batch = rollout.rollout_batch(next(data_generator), workflow=workflow) batch = batch.to(actor.device) # Create barrier to synchronize all rollout processes. diff --git a/examples/env/scripts/setup-pip-deps.sh b/examples/env/scripts/setup-pip-deps.sh index 1599170..e907109 100644 --- a/examples/env/scripts/setup-pip-deps.sh +++ b/examples/env/scripts/setup-pip-deps.sh @@ -6,7 +6,7 @@ pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17. 
pip install "sglang[all]==0.4.9.post2" pip install megatron-core==0.11.0 nvidia-ml-py pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose -pip install "flash-attn<=2.7.3" --no-build-isolation +pip install "flash-attn<=2.8.2" --no-build-isolation # Package used for calculating math reward pip install -e evaluation/latex2sympy diff --git a/pyproject.toml b/pyproject.toml index 2100587..4e9e558 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,8 @@ dependencies = [ "torchdata", "autoflake", "tensordict", + "pybase64", + "msgspec", # Monitoring and logging "wandb", diff --git a/requirements.txt b/requirements.txt index 0c21f9b..7cde02b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -74,3 +74,6 @@ torchdata autoflake tensordict deepspeed>=0.17.2 +pybase64 +msgspec +transformers==4.53.1 \ No newline at end of file