From 311bcd76979c9248880da6aff7c9bbdbcfb26358 Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Thu, 24 Jul 2025 15:34:52 +0800 Subject: [PATCH] [lite] [feature] Bump to SGLang v0.4.9.post2 and use NCCL to update weights (#196) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 * add gradient checkpointing * PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 * . * . * . * . * PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit Reviewed-by: 晓雷 * . * . * . * . * . * . * fix * . * PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment Reviewed-by: 晓雷 * . * . * . * . * . * fix destroy process group * PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/358 Reviewed-by: 晓雷 * . * . * . * . * fix loss mask * fix * . * PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/368 Reviewed-by: 晓雷 * . * . * PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/371 Reviewed-by: 晓雷 * . * PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/370 Reviewed-by: 博惟 * . * . * . * fix * PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results. Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/392 Reviewed-by: 晓雷 * support fsdp engine and sglang remote engine * minor fix * . * refactor trainer * add close * rm mb_spec * . * fix * . * qwen2 grpo works * fix * fix * async works * fix * slurm launcher not tested * fix arg parse * . * sglang server wrapper * . * . * slurm run * ready for boba * debug * 32k run * . * . * fix * . * . * . * . * . * fix * . * fix * . * . * . * . * fix * . * . * . * . * . * . * . * refactor train engine * refactor train engine * . * fix update weight error * . * . * match train * format * . * fix * seems to work * . * . * . * . * . * PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2 Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/408 Reviewed-by: 晓雷 * . 
* bump arealite to sglang 0.4.9.post2 * . * PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta` * PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9 Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/422 Reviewed-by: 晓雷 * . * bump arealite to sglang 0.4.9.post2 * . * PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta` * . * PullRequest: 423 [lite] Remove the boba example for github release. Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/423 Reviewed-by: 晓雷 * . * . --------- Co-authored-by: 晓雷 --- arealite/README.md | 1 - arealite/api/cli_args.py | 79 ----- arealite/api/engine_api.py | 31 +- arealite/api/io_struct.py | 68 +++- arealite/engine/base_hf_engine.py | 11 +- arealite/engine/fsdp_engine.py | 43 ++- arealite/engine/sglang_remote.py | 248 +++++++-------- arealite/launcher/local.py | 2 +- arealite/launcher/sglang_server.py | 46 +-- arealite/tests/test_fsdp_engine_nccl.py | 37 +-- arealite/tests/test_sglang_engine.py | 38 +-- examples/arealite/boba.py | 310 ------------------- examples/arealite/configs/boba.yaml | 141 --------- examples/arealite/configs/gsm8k_grpo.yaml | 6 +- examples/arealite/configs/gsm8k_sft.yaml | 2 +- examples/arealite/gsm8k_grpo.py | 39 +-- examples/env/scripts/setup-container-deps.sh | 11 - examples/env/scripts/setup-eval-pip-deps.sh | 6 +- examples/env/scripts/setup-pip-deps.sh | 16 +- examples/env/validate_installation.py | 9 +- patch/sglang/v0.4.6.post2.patch | 144 --------- patch/sglang/v0.4.6.post4.patch | 144 --------- pyproject.toml | 3 +- realhf/system/generation_server.py | 31 -- realhf/system/gserver_manager.py | 49 +-- requirements.txt | 1 - 26 files changed, 313 insertions(+), 1203 deletions(-) delete mode 100644 examples/arealite/boba.py delete mode 100644 examples/arealite/configs/boba.yaml delete mode 100644 examples/env/scripts/setup-container-deps.sh delete mode 100644 patch/sglang/v0.4.6.post2.patch delete mode 100644 patch/sglang/v0.4.6.post4.patch diff --git a/arealite/README.md b/arealite/README.md index 65c49fc..4a23ce6 100644 --- a/arealite/README.md +++ b/arealite/README.md @@ -163,7 +163,6 @@ class WeightUpdateMeta: type: str path: str | None alloc_mode: AllocationMode | None - comm_backend: str | None @dataclass class SaveLoadMeta: diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py index 3b2afa5..7dc9be6 100644 --- a/arealite/api/cli_args.py +++ b/arealite/api/cli_args.py @@ -285,85 +285,6 @@ class PPOActorConfig(TrainEngineConfig): ) -@dataclass -class PPOActorConfig(TrainEngineConfig): - # Core PPO/GRPO Parameters - group_size: int = field( - default=1, metadata={"help": "Number of sequences in each group"} - ) - group_adv_norm: bool = field( - default=False, - metadata={ - "help": "Normalize advantages within each prompt group rather than globally" - }, - ) - ppo_n_minibatches: int = field( - default=4, metadata={"help": "Number of minibatches for each PPO update"} - ) - eps_clip: float = field( - default=0.2, metadata={"help": "Clipping factor for policy ratio"} - ) - c_clip: Optional[float] = field( - default=None, - metadata={ - "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping." 
- }, - ) - temperature: float = field( - default=1.0, metadata={"help": "Temperature during generation."} - ) - # Reward - group_reward_norm: bool = field( - default=False, - metadata={ - "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias" - }, - ) - reward_scaling: float = field( - default=1.0, metadata={"help": "Reward scaling factor"} - ) - reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"}) - reward_clip: float = field( - default=20.0, metadata={"help": "Maximum absolute value for reward clipping"} - ) - mask_no_eos_with_zero: bool = field( - default=False, - metadata={ - "help": "Mask truncated generations (no EOS token) and exclude from training" - }, - ) - - # Advantage Estimation - discount: float = field( - default=1.0, metadata={"help": "Discount factor for future rewards"} - ) - gae_lambda: float = field( - default=1.0, metadata={"help": "Lambda parameter for GAE"} - ) - adv_norm: bool = field( - default=True, metadata={"help": "Enable advantage normalization"} - ) - - # KL Control - kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"}) - - # Asynchronous RL - recompute_logprob: bool = field( - default=False, - metadata={"help": "Recompute logp and replace the logp returned by inference."}, - ) - use_decoupled_loss: bool = field( - default=False, - metadata={"help": "Use the decoupled loss. recompute_logprob must be True."}, - ) - behav_imp_weight_cap: Optional[float] = field( - default=None, - metadata={ - "help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true" - }, - ) - - @dataclass class SGLangConfig: """Configuration for SGLang runtime. Refer to: diff --git a/arealite/api/engine_api.py b/arealite/api/engine_api.py index 9fb24d2..bd00eb5 100644 --- a/arealite/api/engine_api.py +++ b/arealite/api/engine_api.py @@ -12,6 +12,7 @@ from arealite.api.io_struct import ( FinetuneSpec, LLMRequest, LLMResponse, + ParamSpec, SaveLoadMeta, WeightUpdateMeta, ) @@ -63,7 +64,19 @@ class TrainEngine(abc.ABC): return self.train(False) def upload_weights(self, meta: WeightUpdateMeta): - """Upload weights to the inference engine.""" + """Upload weights to the inference engine (in a blocking manner).""" + raise NotImplementedError() + + def get_param_specs(self) -> List[ParamSpec]: + """Get the parameter specifications for the model.""" + raise NotImplementedError() + + def set_version(self, version: int): + """Set the current weight version in the train engine.""" + raise NotImplementedError() + + def get_version(self) -> int: + """Get the current weight version in the train engine.""" raise NotImplementedError() def save(self, meta: SaveLoadMeta): @@ -122,14 +135,22 @@ class InferenceEngine(abc.ABC): def destroy(self): """Destroy the engine and release GPU memory.""" - def update_weights(self, meta: WeightUpdateMeta) -> Future: - """Update weights in the inference engine.""" - raise NotImplementedError() - async def agenerate(self, req: LLMRequest) -> LLMResponse: """Asynchronously generate a response for the given request.""" raise NotImplementedError() + def update_weights(self, meta: WeightUpdateMeta) -> Future: + """Update weights in the inference engine in a non-blocking manner.""" + raise NotImplementedError() + + def set_version(self, version: int) -> None: + """Set the current weight version in the inference engine.""" + raise NotImplementedError() + + def get_version(self) -> int: + """Get the 
current weight version in the inference engine.""" + raise NotImplementedError() + def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None: """Asynchronously submit a request to the inference engine. Exits immediately.""" raise NotImplementedError() diff --git a/arealite/api/io_struct.py b/arealite/api/io_struct.py index bd11d02..136c15f 100644 --- a/arealite/api/io_struct.py +++ b/arealite/api/io_struct.py @@ -2,14 +2,19 @@ # Licensed under the Apache License, Version 2.0 import enum import itertools +import os import re import uuid from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple from transformers import PreTrainedTokenizerFast -from arealite.api.cli_args import GenerationHyperparameters +from arealite.api.cli_args import GenerationHyperparameters, SaverConfig +from arealite.utils.network import find_free_ports, gethostip + +if TYPE_CHECKING: + from arealite.api.engine_api import TrainEngine @dataclass @@ -155,20 +160,55 @@ class AllocationMode: return other_alloc +@dataclass +class ParamSpec: + name: str + shape: Tuple + dtype: str + + @dataclass class WeightUpdateMeta: - type: str - path: str | None - alloc_mode: AllocationMode | None - comm_backend: str | None - model_version: int = 0 - tp_size: int = 1 - master_address: str = "127.0.0.1" - master_port: int = 29500 - world_size: int = 1 - group_name: str = "aupdate_weights_from_distributed" - parameter_names: List[str] = field(default_factory=list) - state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict) + type: Literal["disk", "nccl"] + path: str | None = None + alloc_mode: AllocationMode | None = None + + nccl_master_address: str = "127.0.0.1" + nccl_master_port: int = 29500 + nccl_param_specs: List[ParamSpec] = field(default_factory=list) + nccl_group_name: str = "update_weight_group" + + @classmethod + def from_disk( + cls, + saver_config: SaverConfig, + ) -> "WeightUpdateMeta": + from arealite.utils.saver import Saver + + path = os.path.join( + Saver.get_save_checkpoint_root(saver_config), + "weight_update", + ) + return cls( + type="disk", + path=path, + ) + + @classmethod + def from_fsdp_nccl( + cls, + allocation_mode: AllocationMode, + fsdp_engine: "TrainEngine", + nccl_group_name: str = "update_weight_group", + ): + return cls( + type="nccl", + alloc_mode=allocation_mode, + nccl_master_address=gethostip(), + nccl_master_port=find_free_ports(1)[0], + nccl_param_specs=fsdp_engine.get_param_specs(), + nccl_group_name=nccl_group_name, + ) @dataclass diff --git a/arealite/engine/base_hf_engine.py b/arealite/engine/base_hf_engine.py index 8d0b1a8..f6b9e2a 100644 --- a/arealite/engine/base_hf_engine.py +++ b/arealite/engine/base_hf_engine.py @@ -46,6 +46,7 @@ class BaseHFEngine(TrainEngine): self.tokenizer: PreTrainedTokenizerFast # huggingface model config self.model_config: PretrainedConfig + self._version: int = 0 # initialization self.initialized = False @@ -55,6 +56,12 @@ class BaseHFEngine(TrainEngine): self.world_size = int(os.environ["WORLD_SIZE"]) + def set_version(self, version: int): + self._version = version + + def get_version(self) -> int: + return self._version + def train(self, mode: bool = True): assert self.model is not None self.model.train(mode=mode) @@ -191,7 +198,7 @@ class BaseHFEngine(TrainEngine): ) state_dict = self.optimizer.state_dict() torch.save(state_dict, shard_path) - dist.barrier() + dist.barrier(device_ids=[self.device.index]) 
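+        # NOTE: device_ids pins the NCCL barrier to this rank's GPU (one
+        # process per device); without it torch has to guess which device
+        # to use for the collective and may pick the wrong one.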
 
     def load_optimizer_state(self, path: str):
         # Load FSDP sharded state dict
@@ -203,7 +210,7 @@
         )
         optimizer_state_dict = torch.load(shard_path, weights_only=False)
         self.optimizer.load_state_dict(optimizer_state_dict)
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])
 
     def step_lr_scheduler(self):
         assert self.lr_scheduler is not None
diff --git a/arealite/engine/fsdp_engine.py b/arealite/engine/fsdp_engine.py
index a0926c3..4338b6d 100644
--- a/arealite/engine/fsdp_engine.py
+++ b/arealite/engine/fsdp_engine.py
@@ -1,7 +1,7 @@
 import os
 import time
 from datetime import datetime
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -14,7 +14,8 @@ from torch.distributed.checkpoint.state_dict import (
 from transformers import PreTrainedTokenizerFast
 
 from arealite.api.cli_args import TrainEngineConfig
-from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.api.engine_api import FinetuneSpec
+from arealite.api.io_struct import ParamSpec, SaveLoadMeta, WeightUpdateMeta
 from arealite.engine.base_hf_engine import BaseHFEngine
 from arealite.utils.distributed import init_custom_process_group
 from arealite.utils.fsdp import (
@@ -119,7 +120,7 @@ class FSDPEngine(BaseHFEngine):
 
         if tokenizer is not None:
             tokenizer.save_pretrained(path)
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])
 
     def _load_model_from_hf(self, path: str):
         """Load model from HuggingFace format."""
@@ -140,7 +141,7 @@
             if not self.weight_update_group_initialized:
                 self._init_distributed_weight_update(meta)
             self._update_weights_from_distributed()
-            dist.barrier()
+            dist.barrier(device_ids=[self.device.index])
             torch.cuda.synchronize()
         elif meta.type == "disk":
             self._save_model_to_hf(meta.path, self.tokenizer)
@@ -149,7 +150,7 @@
             update_name = names.update_weights_from_disk(
                 self.config.experiment_name,
                 self.config.trial_name,
-                meta.model_version,
+                self.get_version(),
             )
             name_resolve.add(
                 update_name, str(datetime.now().timestamp()), keepalive_ttl=120
             )
@@ -158,16 +159,18 @@
             raise ValueError(f"Unknown weight update type {meta.type}")
 
     def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
+        # NOTE: Processes launched with torchrun will set the following env var to True,
+        # which blocks creating another TCP store for weight update.
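+        # (With the flag set to "True", torch's TCP rendezvous assumes the
+        # elastic agent already hosts the TCPStore, so no rank would start a
+        # server on the fresh port picked for this group and the rendezvous
+        # below would never complete; "False" lets rank 0 host its own store.)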
+ os.environ["TORCHELASTIC_USE_AGENT_STORE"] = str(False) if dist.get_rank() == 0: self.weight_update_group = init_custom_process_group( backend="nccl", - world_size=meta.world_size, - init_method=f"tcp://{meta.master_address}:{meta.master_port}", + world_size=meta.alloc_mode.gen_world_size + 1, + init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}", rank=0, - group_name=meta.group_name, + group_name=meta.nccl_group_name, ) - # NOTE: synchronizing with sglang's barrier - dist.barrier(group=self.weight_update_group, device_ids=[self.device.index]) + # NOTE: sglang v0.4.9.post2 or later does not have the barrier call self.weight_update_group_initialized = True def _update_weights_from_distributed(self): @@ -179,23 +182,29 @@ class FSDPEngine(BaseHFEngine): else: tensor = param.data if dist.get_rank() == 0: - print(f"Broadcasting {name} with shape {tensor.shape}", flush=True) + logger.debug( + f"Broadcasting {name} with shape {tensor.shape}", flush=True + ) dist.broadcast(tensor, src=0, group=self.weight_update_group) - dist.barrier() del tensor # optional, for memory hygiene torch.cuda.empty_cache() - def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]: - """Return a dict mapping param name to its shape (expanded if DTensor).""" - param_shapes = {} + def get_param_specs(self) -> List[ParamSpec]: + param_specs = [] for name, param in self.model.named_parameters(): if isinstance(param.data, DTensor): tensor = param.data.full_tensor() else: tensor = param.data - param_shapes[name] = tuple(tensor.shape) + param_specs.append( + ParamSpec( + name=name, + shape=tuple(tensor.shape), + dtype=str(tensor.dtype).split("torch.")[1], + ) + ) del tensor # free memory if full_tensor was created - return param_shapes + return param_specs def train_batch( self, diff --git a/arealite/engine/sglang_remote.py b/arealite/engine/sglang_remote.py index 4c838f4..bac9de3 100644 --- a/arealite/engine/sglang_remote.py +++ b/arealite/engine/sglang_remote.py @@ -1,10 +1,11 @@ import asyncio import os import random +import shutil import threading import time import traceback -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures import Future, ProcessPoolExecutor from datetime import datetime from queue import Empty, Full, Queue from typing import TYPE_CHECKING, Any, Callable, Dict, List @@ -27,17 +28,12 @@ from arealite.api.io_struct import ( ) from arealite.utils.data import concat_padded_tensors from arealite.utils.http import arequest_with_retry, get_default_connector -from realhf.base import logging, name_resolve, names, pkg_version +from realhf.base import logging, name_resolve, names if TYPE_CHECKING: from arealite.api.workflow_api import RolloutWorkflow logger = logging.getLogger(__name__) -if pkg_version.is_available("sglang"): - if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"): - SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids" - else: - SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids" ROLLOUT_POLL_WAIT_TIME = 0.05 RID_CACHE_SIZE = 128 @@ -91,10 +87,7 @@ class RemoteSGLangEngine(InferenceEngine): def check_health(self, base_url): # Check server endpoint try: - response = requests.get( - f"{base_url}/metrics", - timeout=30, - ) + response = requests.get(f"{base_url}/health", timeout=30) return response.status_code == 200 except requests.exceptions.RequestException as e: return False @@ -123,7 +116,7 @@ class RemoteSGLangEngine(InferenceEngine): """Thread that runs the rollout loop.""" try: 
uvloop.run(self._rollout_thread_async()) - except Exception as e: + except Exception: traceback.print_exc() async def _rollout_thread_async(self): @@ -259,9 +252,7 @@ class RemoteSGLangEngine(InferenceEngine): accumulated_output_logprobs = [] accumulated_versions = [] - # Deal with rollout interruption - stop_reason = "length" - + # A single "rid" shares the same sever to allow KV cache reuse if req.rid in self.rid_to_address: server_addr = self.rid_to_address[req.rid] else: @@ -273,10 +264,19 @@ class RemoteSGLangEngine(InferenceEngine): self.rid_to_address[req.rid] = server_addr self.rid_queue.append(req.rid) + # Deal with rollout interruption + # "abort" is the stop reason for later v0.4.9.post2 after + # we call the pause_generation endpoint + stop_reason = None while ( stop_reason != "stop" and len(accumulated_output_tokens) < gconfig.max_new_tokens ): + # Request is interrupted, wait for some time to avoid interfering + # with update weights requests + if stop_reason is not None: + await asyncio.sleep(0.5) + # loop until the generation is complete result = await arequest_with_retry( session=self.session, @@ -288,8 +288,17 @@ class RemoteSGLangEngine(InferenceEngine): timeout=self.config.request_timeout, ) - # Parse response meta_info = result["meta_info"] + # Check if generation is complete + finish_reason = meta_info["finish_reason"] + stop_reason = finish_reason["type"] + if ( + stop_reason == "abort" + and finish_reason.get("message") == "Abort before prefill" + ): + continue + + # Parse response output_tokens = [x[1] for x in meta_info["output_token_logprobs"]] output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]] @@ -299,11 +308,7 @@ class RemoteSGLangEngine(InferenceEngine): # FIXME: Update with actual server versions accumulated_versions.extend([-1] * len(output_tokens)) - # Check if generation is complete - finish_reason = meta_info["finish_reason"] - stop_reason = finish_reason["type"] - - payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER] + payload["input_ids"] += result["output_ids"] sample_params["max_new_tokens"] -= len(output_tokens) latency = time.perf_counter() - start_time @@ -318,30 +323,24 @@ class RemoteSGLangEngine(InferenceEngine): ttft=latency, # Simplified for non-streaming ) - def update_weights(self, meta): - executor = ThreadPoolExecutor(max_workers=1) - return executor.submit(self._update_weights, meta) - - def _update_weights(self, meta: WeightUpdateMeta): + def update_weights(self, meta: WeightUpdateMeta): + for addr in self.addresses: + res = requests.post(f"http://{addr}/pause_generation") + res.raise_for_status() + fut = Future() if meta.type == "nccl": - if not self.distributed_weight_update_initialized: - self._init_distributed_weight_update(meta) - tik = time.perf_counter() - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - for param in meta.parameter_names: - jobs = [ - self.aupdate_weights_from_distributed(addr, meta, param) - for addr in self.addresses - ] - loop.run_until_complete(asyncio.gather(*jobs)) - finally: - loop.close() - logger.info( - f"Distributed update weights done in {time.perf_counter() - tik}s" + fut = self.executor.submit( + update_weights_from_distributed, + meta, + self.addresses, + self.config.request_timeout, + not self.distributed_weight_update_initialized, ) - self.set_version(meta.model_version) + + def callback(fut): + self.distributed_weight_update_initialized = True + + fut.add_done_callback(callback) elif meta.type == "disk": # Update weights from disk # Use 
ProcessPool to bypass python GIL for running async coroutines @@ -349,7 +348,7 @@ class RemoteSGLangEngine(InferenceEngine): update_weights_from_disk, self.config.experiment_name, self.config.trial_name, - meta.model_version, + self.get_version(), self.addresses, meta.path, self.config.request_retries, @@ -357,64 +356,19 @@ class RemoteSGLangEngine(InferenceEngine): ) def callback(fut): - self.set_version(meta.model_version) + shutil.rmtree(meta.path, ignore_errors=True) fut.add_done_callback(callback) - return fut else: raise NotImplementedError(f"Unsupported weight update type: {meta.type}") - def _init_distributed_weight_update(self, meta: WeightUpdateMeta): - try: - # Initialize weights update group - jobs = [ - self.ainit_weights_update_group(addr, meta) for addr in self.addresses - ] - loop = asyncio.new_event_loop() - # asyncio event loop should be manually set when running asyncio stuff in another thread - asyncio.set_event_loop(loop) - loop.run_until_complete(asyncio.gather(*jobs)) - self.distributed_weight_update_initialized = True - logger.info(f"Distributed update weights initialized") - finally: - loop.close() + def callback(fut): + for addr in self.addresses: + res = requests.post(f"http://{addr}/continue_generation") + res.raise_for_status() - async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta): - rank_offset = 1 + self.addresses.index(addr) * meta.tp_size - payload = { - "master_address": meta.master_address, - "master_port": str(meta.master_port), - "rank_offset": rank_offset, - "world_size": meta.world_size, - "group_name": meta.group_name, - "backend": "nccl", - } - res = await arequest_with_retry( - addr=addr, - endpoint="/init_weights_update_group", - payload=payload, - method="POST", - max_retries=1, - timeout=self.config.request_timeout, - ) - assert res["success"] - - async def aupdate_weights_from_distributed( - self, addr: str, meta: WeightUpdateMeta, parameter_name: str - ): - res = await arequest_with_retry( - addr=addr, - endpoint="/update_weights_from_distributed", - payload={ - "name": parameter_name, - "dtype": "bfloat16", - "shape": meta.state_dict_key_to_shape[parameter_name], - }, - method="POST", - max_retries=1, - timeout=self.config.request_timeout, - ) - assert res["success"] + fut.add_done_callback(callback) + return fut def get_capacity(self): if dist.is_initialized(): @@ -519,27 +473,6 @@ class RemoteSGLangEngine(InferenceEngine): self.paused.clear() -async def aupdate_weights_from_disk( - session, addr, path: str, request_retries: int, request_timeout: float -): - tik = time.time() - res = await arequest_with_retry( - addr=addr, - session=session, - endpoint="/update_weights_from_disk", - payload=dict(model_path=str(path), allow_interrupt=True), - method="POST", - max_retries=request_retries, - timeout=request_timeout, - ) - assert res["success"] - if "num_paused_requests" in res: - logger.info( - f"{res['num_paused_requests']} requests are interrupted " - f"during updating weights for server {addr}" - ) - - def update_weights_from_disk( experiment_name, trial_name, @@ -569,12 +502,14 @@ def update_weights_from_disk( connector=get_default_connector(), ) jobs = [ - aupdate_weights_from_disk( - session=session, + arequest_with_retry( addr=addr, - path=path, - request_retries=request_retries, - request_timeout=request_timeout, + session=session, + endpoint="/update_weights_from_disk", + payload=dict(model_path=str(path)), + method="POST", + max_retries=request_retries, + timeout=request_timeout, ) for addr in addresses ] @@ 
-585,3 +520,72 @@ def update_weights_from_disk( ) return uvloop.run(_fn()) + + +def update_weights_from_distributed( + meta: WeightUpdateMeta, + addresses: List[str], + request_timeout, + init_group: bool, +): + async def _fn(): + tik = time.perf_counter() + if init_group: + await asyncio.gather( + *[ + ainit_weights_update_group(addr, i, meta, request_timeout) + for i, addr in enumerate(addresses) + ] + ) + await asyncio.gather( + *[ + arequest_with_retry( + addr=addr, + endpoint="/update_weights_from_distributed", + payload={ + "names": [pspec.name for pspec in meta.nccl_param_specs], + "dtypes": [pspec.dtype for pspec in meta.nccl_param_specs], + "shapes": [pspec.shape for pspec in meta.nccl_param_specs], + "group_name": meta.nccl_group_name, + }, + method="POST", + max_retries=1, + timeout=request_timeout, + ) + for addr in addresses + ] + ) + logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s") + + return uvloop.run(_fn()) + + +async def ainit_weights_update_group( + addr: str, + server_idx: int, + meta: WeightUpdateMeta, + request_timeout: float, +): + assert meta.alloc_mode is not None + if meta.alloc_mode.gen_pp_size != 1: + raise NotImplementedError( + "NCCL weight update with PP size > 1 is not implemented yet." + ) + rank_offset = 1 + server_idx * meta.alloc_mode.gen_tp_size + payload = { + "master_address": meta.nccl_master_address, + "master_port": str(meta.nccl_master_port), + "rank_offset": rank_offset, + "world_size": meta.alloc_mode.gen_world_size + 1, + "backend": "nccl", + "group_name": meta.nccl_group_name, + } + res = await arequest_with_retry( + addr=addr, + endpoint="/init_weights_update_group", + payload=payload, + method="POST", + max_retries=1, + timeout=request_timeout, + ) + assert res["success"] diff --git a/arealite/launcher/local.py b/arealite/launcher/local.py index 447e7db..27cc0ba 100644 --- a/arealite/launcher/local.py +++ b/arealite/launcher/local.py @@ -283,7 +283,7 @@ def main_local(): if not cfg.server_only: launcher.submit( job_name="trainer", - cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}", + cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}", gpu=alloc_mode.train_world_size, env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)), ) diff --git a/arealite/launcher/sglang_server.py b/arealite/launcher/sglang_server.py index 2e88c04..938ce5c 100644 --- a/arealite/launcher/sglang_server.py +++ b/arealite/launcher/sglang_server.py @@ -3,10 +3,8 @@ import subprocess import sys import time import uuid -from pathlib import Path from typing import Optional -import ray import requests from arealite.api.cli_args import ( @@ -19,12 +17,12 @@ from arealite.api.cli_args import ( from arealite.api.io_struct import AllocationMode, AllocationType from arealite.utils.launcher import TRITON_CACHE_PATH from arealite.utils.network import find_free_ports, gethostip -from realhf.base import logging, name_resolve, names, pkg_version +from realhf.base import logging, name_resolve, names logger = logging.getLogger("SGLangServer Wrapper") -def execute_shell_command(command: str) -> subprocess.Popen: +def launch_server_cmd(command: str) -> subprocess.Popen: """ Execute a shell command and return its process handle. 
""" @@ -45,46 +43,6 @@ def execute_shell_command(command: str) -> subprocess.Popen: ) -def apply_sglang_patch(): - p = Path(os.path.dirname(__file__)) - patch_path = str( - p.parent.parent - / "patch" - / "sglang" - / f"v{pkg_version.get_version('sglang')}.patch" - ) - - target_path = "" - sglang_meta = subprocess.check_output( - "python3 -m pip show sglang", shell=True - ).decode("ascii") - for line in sglang_meta.split("\n"): - line = line.strip() - if line.startswith("Editable project location: "): - target_path = str(Path(line.split(": ")[1]).parent) - - if target_path: - proc = subprocess.Popen( - ["git", "apply", patch_path], - cwd=target_path, - stderr=sys.stdout, - stdout=sys.stdout, - ) - proc.wait() - logger.info(f"Applied SGLang patch at {target_path}") - - -def launch_server_cmd(command: str): - """ - Launch the server using the given command. - If no port is specified, a free port is reserved. - """ - if not ray.is_initialized(): - apply_sglang_patch() - process = execute_shell_command(command) - return process - - def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None: """Wait for the server to be ready by polling the /v1/models endpoint. diff --git a/arealite/tests/test_fsdp_engine_nccl.py b/arealite/tests/test_fsdp_engine_nccl.py index 7d71fd6..e06b94b 100644 --- a/arealite/tests/test_fsdp_engine_nccl.py +++ b/arealite/tests/test_fsdp_engine_nccl.py @@ -12,10 +12,9 @@ from arealite.api.cli_args import ( SGLangConfig, TrainEngineConfig, ) -from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta +from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta from arealite.engine.fsdp_engine import FSDPEngine from arealite.engine.sglang_remote import RemoteSGLangEngine -from arealite.utils.network import find_free_ports from realhf.base import network EXPR_NAME = "test_fsdp_engine_nccl" @@ -33,7 +32,7 @@ RUN_SERVER_TIMEOUT = 180 def check_server_health(base_url): try: - response = requests.get(f"{base_url}/metrics", timeout=30) + response = requests.get(f"{base_url}/health", timeout=30) return response.status_code == 200 except requests.exceptions.RequestException: return False @@ -82,14 +81,14 @@ def sglang_server_nccl(): def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl): - # 设置分布式环境变量 + # Set environment variables for torch distributed os.environ["WORLD_SIZE"] = "1" os.environ["RANK"] = "0" os.environ["LOCAL_RANK"] = "0" os.environ["MASTER_ADDR"] = HOST os.environ["MASTER_PORT"] = str(MASTER_PORT) - # 启动本地FSDPEngine + # Initialize FSDPEngine engine_config = TrainEngineConfig( experiment_name=EXPR_NAME, trial_name=TRIAL_NAME, @@ -100,38 +99,26 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) engine.initialize(None, ft_spec) - # 启动远端RemoteSGLangEngine + # Initialize RemoteSGLangEngine config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) config.server_addrs = [f"{HOST}:{PORT}"] remote_engine = RemoteSGLangEngine(config) remote_engine.initialize(None, None) - # 构造WeightUpdateMeta(type=nccl) - param_meta = engine.get_param_meta_for_distributed_update() - meta = WeightUpdateMeta( - type="nccl", - path=None, - alloc_mode=None, - comm_backend="nccl", - model_version=123, - tp_size=1, - master_address="localhost", - master_port=find_free_ports(1)[0], - world_size=2, - group_name=GROUP_NAME, - parameter_names=list(param_meta.keys()), - 
state_dict_key_to_shape=param_meta,
+    # Get WeightUpdateMeta
+    meta = WeightUpdateMeta.from_fsdp_nccl(
+        AllocationMode.from_str("sglang.d1p1t1+d1p1t1"),
+        engine,
+        nccl_group_name=GROUP_NAME,
     )
 
-    # 本地engine广播参数
+    # Broadcast weights
     future = remote_engine.update_weights(meta)
     print("got future", flush=True)
     engine.upload_weights(meta)
     print("uploaded weights to remote engine", flush=True)
-    # 远端engine拉取参数
+    # Wait for remote engine to finish
     future.result(timeout=120)
     print("got result", flush=True)
-    # 检查远端参数版本
-    assert remote_engine.get_version() == 123
 
     remote_engine.destroy()
     engine.destroy()
diff --git a/arealite/tests/test_sglang_engine.py b/arealite/tests/test_sglang_engine.py
index 71a6f2f..d660315 100644
--- a/arealite/tests/test_sglang_engine.py
+++ b/arealite/tests/test_sglang_engine.py
@@ -2,7 +2,6 @@ import os
 import subprocess
 import sys
 import time
-import uuid
 
 import pytest
 import requests
@@ -14,7 +13,7 @@ from arealite.api.cli_args import (
     InferenceEngineConfig,
     SGLangConfig,
 )
-from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
+from arealite.api.io_struct import WeightUpdateMeta
 from arealite.utils import network
 from realhf.api.core.data_api import load_hf_tokenizer
 
@@ -31,10 +30,7 @@ RUN_SERVER_TIMEOUT = 180
 
 def check_server_health(base_url):
     try:
-        response = requests.get(
-            f"{base_url}/metrics",
-            timeout=30,
-        )
+        response = requests.get(f"{base_url}/health", timeout=30)
         return response.status_code == 200
     except requests.exceptions.RequestException as e:
         return False
@@ -77,29 +73,6 @@ def sglang_server():
         process.terminate()
 
 
-@pytest.mark.asyncio
-async def test_remote_sglang_generate(sglang_server):
-    from arealite.engine.sglang_remote import RemoteSGLangEngine
-
-    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
-    tokenizer = load_hf_tokenizer(MODEL_PATH)
-    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
-    engine = RemoteSGLangEngine(config)
-    req = LLMRequest(
-        rid=str(uuid.uuid4()),
-        input_ids=tokenizer.encode("hello!
how are you today"), - gconfig=GenerationHyperparameters(max_new_tokens=16), - ) - resp = await engine.agenerate(req) - assert isinstance(resp, LLMResponse) - assert resp.input_tokens == req.input_ids - assert ( - len(resp.output_logprobs) - == len(resp.output_tokens) - == len(resp.output_versions) - ) - - @pytest.mark.parametrize("n_samples", [1, 2, 4]) def test_remote_sglang_rollout(sglang_server, n_samples): from arealite.engine.sglang_remote import RemoteSGLangEngine @@ -211,6 +184,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server): engine = FSDPEngine(engine_config) ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) engine.initialize(None, ft_spec) + engine.model_version = 100 # setup name resolve import realhf.base.name_resolve as name_resolve @@ -227,13 +201,11 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server): os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}" inf_engine = RemoteSGLangEngine(config) inf_engine.initialize(None, None) + inf_engine.set_version(100) # test update weights path = tmp_path_factory.mktemp("upload_weights_from_disk") - update_weight_meta = WeightUpdateMeta( - type="disk", path=path, alloc_mode=None, comm_backend=None, model_version=100 - ) + update_weight_meta = WeightUpdateMeta(type="disk", path=str(path)) future = inf_engine.update_weights(update_weight_meta) engine.upload_weights(update_weight_meta) future.result() - assert inf_engine.get_version() == 100 inf_engine.destroy() diff --git a/examples/arealite/boba.py b/examples/arealite/boba.py deleted file mode 100644 index 4211d11..0000000 --- a/examples/arealite/boba.py +++ /dev/null @@ -1,310 +0,0 @@ -import asyncio -import os -import shutil -import sys -import uuid - -import colorama -import torch -import torch.distributed as dist -from datasets import load_dataset -from datasets.distributed import split_dataset_by_node -from tensordict import TensorDict -from torchdata.stateful_dataloader import StatefulDataLoader -from transformers import PreTrainedTokenizerFast - -from arealite.api.cli_args import ( - GenerationHyperparameters, - GRPOConfig, - load_expr_config, -) -from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta -from arealite.api.workflow_api import RolloutWorkflow -from arealite.engine.ppo.actor import FSDPPPOActor -from arealite.engine.sglang_remote import RemoteSGLangEngine -from arealite.utils.data import concat_padded_tensors -from arealite.utils.device import log_gpu_stats -from arealite.utils.saver import Saver -from arealite.utils.stats_logger import StatsLogger -from realhf.api.core.data_api import load_hf_tokenizer -from realhf.base import logging, seeding, stats_tracker - -logger = logging.getLogger("boba math") - - -class RLVRWorkflow(RolloutWorkflow): - def __init__( - self, - reward_fn, - gconfig: GenerationHyperparameters, - tokenizer: PreTrainedTokenizerFast, - dump_dir: str | None = None, - ): - self.reward_fn = reward_fn - self.gconfig = gconfig - self.tokenizer = tokenizer - self.dump_dir = dump_dir - if self.dump_dir is not None and not os.path.exists(self.dump_dir): - os.makedirs(self.dump_dir, exist_ok=True) - - async def arun_episode(self, engine, data): - input_ids = self.tokenizer.encode(data["prompt"]) - n_samples = self.gconfig.n_samples - req = LLMRequest( - rid=uuid.uuid4().hex, - input_ids=input_ids, - gconfig=self.gconfig.new(n_samples=1), - ) - resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)]) - - version = 
engine.get_version() - prompt_strs = [] - completions_strs = [] - rewards = [] - seqlens = [] - - results = [] - for resp in resps: - seq = resp.input_tokens + resp.output_tokens - logprobs = [0.0] * resp.input_len + resp.output_logprobs - loss_mask = [0] * resp.input_len + [1] * resp.output_len - versions = [-1] * resp.input_len + resp.output_versions - - prompt_str = data["prompt"] - completions_str = self.tokenizer.decode(resp.output_tokens) - prompt_strs.append(prompt_str) - completions_strs.append(completions_str) - seqlens.append(len(seq)) - reward = self.reward_fn( - completions=completions_str, - prompt_ids=resp.input_tokens, - completion_ids=resp.output_tokens, - **data, - ) - rewards.append(reward) - res = dict( - # unsqueeze to add an additional batch dimension - input_ids=torch.tensor(seq).unsqueeze(0), - loss_mask=torch.tensor(loss_mask).unsqueeze(0), - logprobs=torch.tensor(logprobs).unsqueeze(0), - versions=torch.tensor(versions).unsqueeze(0), - attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0), - # reward - rewards=torch.tensor([float(reward)]), - ) - results.append(TensorDict(res, batch_size=[1])) - - if self.dump_dir is not None: - os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True) - # Get the unique identifier for this prompt - qid = None - for key in ["query_id", "id", "qid"]: - qid = data.get(key, None) - if qid is not None: - break - qid = qid or uuid.uuid4().hex - - # Dump rollout to file - with open( - os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a" - ) as f: - n_samples = self.gconfig.n_samples - for i, (p, c, r, sl) in enumerate( - zip(prompt_strs, completions_strs, rewards, seqlens) - ): - info = "\n".join( - [ - f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.", - f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}", - f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}", - ] - ) - f.write(info + "\n") - - return concat_padded_tensors(results) - - -def get_boba_math_dataset(tokenizer, rank, world_size): - dataset = load_dataset( - path="json", - split="train", - data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl", - ) - dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024) - return split_dataset_by_node(dataset, rank=rank, world_size=world_size) - - -def boba_reward_fn( - prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs -): - from pebble import ProcessExpired, ProcessPool - - from realhf.impl.dataset.math_parser import process_results - - jobs = [] - with ProcessPool(max_workers=1) as executor: - for sol in solutions: - job = executor.schedule( - process_results, args=[completions, sol], timeout=15 - ) - jobs.append(job) - - label = 0 - for job in jobs: - try: - x = job.result() - except TimeoutError: - # print("[debug: timeout]") - logger.warning(f"Timeout occurred while justifying the math answer.") - x = (0, "timeout", "timeout") - except ProcessExpired as e: - logger.warning(f"Process terminated abnormally: {e}") - x = (0, "error", "error") - except Exception as e: - logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}") - x = (0, "error", "error") - label = label or x[0] - return label - - -def main(args): - config, _ = load_expr_config(args, GRPOConfig) - config: GRPOConfig - - rank = int(os.getenv("RANK")) - world_size = int(os.getenv("WORLD_SIZE")) - tokenizer = load_hf_tokenizer(config.tokenizer_path) - - 
seeding.set_random_seed(config.seed, key=f"trainer{rank}") - - # Create dataset and dataloaders - train_dataloader = StatefulDataLoader( - get_boba_math_dataset(tokenizer, rank, world_size), - batch_size=config.train_dataset.batch_size // world_size, - shuffle=config.train_dataset.shuffle, - num_workers=config.train_dataset.num_workers, - collate_fn=lambda x: x, - drop_last=config.train_dataset.drop_last, - ) - ft_spec = FinetuneSpec( - total_train_epochs=config.total_train_epochs, - dataset_size=len(train_dataloader) * config.train_dataset.batch_size, - train_batch_size=config.train_dataset.batch_size, - ) - - # Initialize inference engine - rollout = RemoteSGLangEngine(config.rollout) - rollout.initialize(None, ft_spec) - - # Initialize train engine - actor = FSDPPPOActor(config=config.actor) - actor.initialize(None, ft_spec) - ref = None - if config.actor.kl_ctl > 0 and config.ref is not None: - ref = FSDPPPOActor(config=config.ref) - ref.initialize(None, ft_spec) - - # Create rollout workflow - if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: - config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) - if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: - config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) - workflow = RLVRWorkflow( - reward_fn=boba_reward_fn, - gconfig=config.gconfig, - tokenizer=tokenizer, - dump_dir=os.path.join( - StatsLogger.get_log_path(config.stats_logger), "generated" - ), - ) - - # Run training. - saver = Saver(config.saver, ft_spec, for_recover=False) - logger = StatsLogger(config.stats_logger, ft_spec) - - total_epochs = config.total_train_epochs - steps_per_epoch = len(train_dataloader) - max_steps = total_epochs * steps_per_epoch - - logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}") - data_generator = iter(train_dataloader) - for global_step in range(max_steps): - epoch = global_step // steps_per_epoch - step = global_step % steps_per_epoch - - with stats_tracker.record_timing("rollout"): - if config.async_training: - batch = rollout.prepare_batch(train_dataloader, workflow=workflow) - else: - try: - data = next(data_generator) - except StopIteration: - data_generator = iter(train_dataloader) - data = next(data_generator) - batch = rollout.rollout_batch(data, workflow=workflow) - - batch = batch.to(actor.device) - # Create barrier to synchronize all rollout processes. 
- dist.barrier() - torch.cuda.synchronize() - - if config.actor.recompute_logprob or config.actor.use_decoupled_loss: - with stats_tracker.record_timing("recompute_logp"): - logp = actor.compute_logp(batch) - batch["prox_logp"] = logp - log_gpu_stats("recompute logp") - - if ref is not None: - with stats_tracker.record_timing("ref_logp"): - batch["ref_logp"] = ref.compute_logp(batch) - log_gpu_stats("ref logp") - - with stats_tracker.record_timing("compute_advantage"): - actor.compute_advantages(batch) - log_gpu_stats("compute advantages") - - with ( - stats_tracker.record_timing("train_step"), - stats_tracker.scope("grpo_actor"), - ): - stats = actor.ppo_update(batch) - actor.step_lr_scheduler() - log_gpu_stats("ppo update") - - with stats_tracker.record_timing("update_weights"): - path = os.path.join( - Saver.get_save_checkpoint_root(config.saver), - "update_weights", - str(global_step + 1), - ) - meta = WeightUpdateMeta( - type="disk", - path=path, - alloc_mode=None, - comm_backend=None, - model_version=global_step + 1, - ) - if dist.get_rank() == 0: - future = rollout.update_weights(meta) - actor.upload_weights(meta) - if dist.get_rank() == 0: - future.result() - shutil.rmtree(path, ignore_errors=True) - dist.barrier() - torch.cuda.synchronize() - rollout.set_version(global_step + 1) - - with stats_tracker.record_timing("save"): - saver.save(actor, epoch, step, global_step) - - logger.commit(epoch, step, global_step, stats) - - logger.close() - rollout.destroy() - if ref is not None: - ref.destroy() - actor.destroy() - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/examples/arealite/configs/boba.yaml b/examples/arealite/configs/boba.yaml deleted file mode 100644 index 3924a14..0000000 --- a/examples/arealite/configs/boba.yaml +++ /dev/null @@ -1,141 +0,0 @@ -experiment_name: lite-boba-math -trial_name: run1 - -cluster: - n_nodes: 16 - n_gpus_per_node: 8 - cluster_name: na132 - fileroot: /storage/openpsi/experiments - name_resolve: - type: nfs - nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math - etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379 - -seed: 1 -total_train_epochs: 10 -total_train_steps: null -tokenizer_path: ${actor.path} -allocation_mode: sglang.d96p1t1+d32p1t1 -async_training: true - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 400 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 4 - enable_rollout_tracing: true - -gconfig: - n_samples: 16 - min_new_tokens: 0 - max_new_tokens: 30720 - greedy: false - temperature: 1.0 - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/ - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 32768 - optimizer: - type: adam - lr: 1e-5 - weight_decay: 0.01 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 1.0 - warmup_steps_proportion: 0.001 - backend: fsdp - - group_size: ${gconfig.n_samples} - group_adv_norm: false - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 4 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - 
disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 32768 - optimizer: null - backend: fsdp - -# SGLang -server_only: false -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: true - dtype: ${actor.dtype} - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.9 - -# datasets -train_dataset: - batch_size: 512 - shuffle: true - pin_memory: true - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -checkpointer: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: null - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: online - -# Launcher -launcher: - inference_server_cpus_per_gpu: 15 - inference_server_mem_per_gpu: 153600 - trainer_cpus_per_gpu: 15 - trainer_mem_per_gpu: 153600 - slurm: - mount: /storage:/storage - trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif - inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif \ No newline at end of file diff --git a/examples/arealite/configs/gsm8k_grpo.yaml b/examples/arealite/configs/gsm8k_grpo.yaml index 9f0107e..9b0362d 100644 --- a/examples/arealite/configs/gsm8k_grpo.yaml +++ b/examples/arealite/configs/gsm8k_grpo.yaml @@ -1,9 +1,9 @@ experiment_name: gsm8k-grpo trial_name: trial0 allocation_mode: sglang.d4p1t1+d4p1t1 -n_nodes: 1 -n_gpus_per_node: 8 cluster: + n_nodes: 1 + n_gpus_per_node: 8 fileroot: /tmp/arealite/experiments name_resolve: type: nfs @@ -32,7 +32,7 @@ gconfig: actor: experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/ + path: Qwen/Qwen2-1.5B-Instruct init_from_scratch: false disable_dropout: true gradient_checkpointing: false diff --git a/examples/arealite/configs/gsm8k_sft.yaml b/examples/arealite/configs/gsm8k_sft.yaml index c8aaa4b..8d05abe 100644 --- a/examples/arealite/configs/gsm8k_sft.yaml +++ b/examples/arealite/configs/gsm8k_sft.yaml @@ -13,7 +13,7 @@ tokenizer_path: ${model.path} model: experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /storage/openpsi/models/Qwen__Qwen3-1.7B/ + path: Qwen/Qwen3-1.7B init_from_scratch: false gradient_checkpointing: false dtype: bfloat16 diff --git a/examples/arealite/gsm8k_grpo.py b/examples/arealite/gsm8k_grpo.py index 6fbccb7..f44c0d3 100644 --- a/examples/arealite/gsm8k_grpo.py +++ b/examples/arealite/gsm8k_grpo.py @@ -1,5 +1,4 @@ import os -import shutil import sys import torch @@ -9,7 +8,7 @@ from datasets.distributed import split_dataset_by_node from torchdata.stateful_dataloader import StatefulDataLoader from arealite.api.cli_args import GRPOConfig, load_expr_config -from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta +from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta from arealite.engine.ppo.actor import FSDPPPOActor from arealite.engine.sglang_remote import RemoteSGLangEngine from arealite.utils.device import log_gpu_stats @@ -93,6 +92,18 @@ def main(args): ref = FSDPPPOActor(config=config.ref) ref.initialize(None, ft_spec) + # NOTE: Weight update 
meta only requires address and free port of rank 0, + # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks + # due to `engine.get_param_specs()`. + # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. + weight_update_meta = [ + WeightUpdateMeta.from_fsdp_nccl( + AllocationMode.from_str(config.allocation_mode), actor + ) + ] + dist.broadcast_object_list(weight_update_meta, src=0) + weight_update_meta = weight_update_meta[0] + # Create rollout workflow if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) @@ -136,7 +147,7 @@ def main(args): batch = batch.to(actor.device) # Create barrier to synchronize all rollout processes. - dist.barrier() + dist.barrier(device_ids=[actor.device.index]) torch.cuda.synchronize() if config.actor.recompute_logprob or config.actor.use_decoupled_loss: @@ -163,26 +174,16 @@ def main(args): log_gpu_stats("ppo update") with stats_tracker.record_timing("update_weights"): - path = os.path.join( - Saver.get_save_checkpoint_root(config.saver), - "update_weights", - str(global_step + 1), - ) - meta = WeightUpdateMeta( - type="disk", - path=path, - alloc_mode=None, - comm_backend=None, - model_version=global_step + 1, - ) + rollout.pause() if dist.get_rank() == 0: - future = rollout.update_weights(meta) - actor.upload_weights(meta) + future = rollout.update_weights(weight_update_meta) + actor.upload_weights(weight_update_meta) if dist.get_rank() == 0: future.result() - shutil.rmtree(path, ignore_errors=True) - dist.barrier() + dist.barrier(device_ids=[actor.device.index]) torch.cuda.synchronize() + rollout.resume() + actor.set_version(global_step + 1) rollout.set_version(global_step + 1) with stats_tracker.record_timing("save"): diff --git a/examples/env/scripts/setup-container-deps.sh b/examples/env/scripts/setup-container-deps.sh deleted file mode 100644 index 2b07e31..0000000 --- a/examples/env/scripts/setup-container-deps.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh -AREAL_PATH=$PWD -cd /sglang -git apply $AREAL_PATH/patch/sglang/v0.4.6.post4.patch -cd $AREAL_PATH - -# Package used for calculating math reward -pip install -e evaluation/latex2sympy - -# Install AReaL -pip install -e . 
\ No newline at end of file
diff --git a/examples/env/scripts/setup-eval-pip-deps.sh b/examples/env/scripts/setup-eval-pip-deps.sh
index fbdf4ee..659ed73 100644
--- a/examples/env/scripts/setup-eval-pip-deps.sh
+++ b/examples/env/scripts/setup-eval-pip-deps.sh
@@ -1,9 +1,9 @@
 #!/bin/bash
 # basic dependencies
 pip install -U pip
-pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
-pip install nvidia-ml-py
+pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
+pip install pynvml nvidia-ml-py
 pip install -e evaluation/latex2sympy
 pip install vllm==0.8.5 --no-build-isolation
-pip install flash_attn --no-build-isolation
+pip install "flash-attn<=2.7.3" --no-build-isolation
 pip install -r evaluation/requirements.txt
\ No newline at end of file
diff --git a/examples/env/scripts/setup-pip-deps.sh b/examples/env/scripts/setup-pip-deps.sh
index 8fa203b..1599170 100644
--- a/examples/env/scripts/setup-pip-deps.sh
+++ b/examples/env/scripts/setup-pip-deps.sh
@@ -1,24 +1,14 @@
 #!/bin/bash
 # basic dependencies
 pip install -U pip
-pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
-pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
-pip install "sglang[all]==0.4.6.post4"
+pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
+pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.2" pynvml
+pip install "sglang[all]==0.4.9.post2"
 pip install megatron-core==0.11.0 nvidia-ml-py
 pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
 pip install "flash-attn<=2.7.3" --no-build-isolation
 
 # Package used for calculating math reward
 pip install -e evaluation/latex2sympy
-
-# Install an editable sglang
-rm -rf ./sglang
-git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang
-AREAL_PATH=$PWD
-cd sglang
-git apply ../patch/sglang/v0.4.6.post4.patch
-pip install -e "python[all]" --no-deps
-cd $AREAL_PATH
-
 # Install AReaL
 pip install -e .
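
A quick post-install sanity check, mirroring the version pins above (a minimal sketch, not a file shipped with the repo; validate_installation.py below performs the full check):

    from importlib.metadata import version

    # Version pins taken from setup-pip-deps.sh in this patch; startswith
    # tolerates local build tags such as "+cu128" on torch wheels.
    for pkg, want in [("torch", "2.7.1"), ("sglang", "0.4.9.post2")]:
        got = version(pkg)
        assert got.startswith(want), f"{pkg}: expected {want}, found {got}"
    print("pinned versions OK")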
diff --git a/examples/env/validate_installation.py b/examples/env/validate_installation.py
index 61ef6f7..67ae740 100644
--- a/examples/env/validate_installation.py
+++ b/examples/env/validate_installation.py
@@ -67,6 +67,7 @@ class InstallationValidator:
     def test_flash_attn_functionality(self, flash_attn_module):
         """Test flash attention functionality."""
         # Try to import key functions
+        import flash_attn_2_cuda
         from flash_attn import flash_attn_func, flash_attn_varlen_func
 
         print(" - Flash attention functions imported successfully")
@@ -79,12 +80,12 @@ class InstallationValidator:
         """Test SGLang basic functionality."""
         # Basic import test is sufficient for CI
         import sgl_kernel
-        from sglang import launch_server
-        assert Version(get_version("sglang")) == Version("0.4.6.post4")
+        from sglang import Engine, launch_server
+        assert Version(get_version("sglang")) == Version("0.4.9.post2"), "SGLang version should be v0.4.9.post2"
         print(" - SGLang imported successfully")
 
     def test_transformers(self, transformers_module):
-        assert Version(get_version("transformers")) == Version("4.51.1")
+        assert Version(get_version("transformers")) == Version("4.53.1"), "transformers version should be 4.53.1"
         print(" - transformers imported successfully")
 
     def validate_critical_dependencies(self):
@@ -140,7 +141,7 @@ class InstallationValidator:
         self.test_import("flashattn_hopper", required=False)
 
         # Optional utilities
-        self.test_import("tensorboardx", required=False)
+        self.test_import("tensorboardX", required=False)
         self.test_import("swanlab", required=False)
         self.test_import("matplotlib", required=False)
         self.test_import("seaborn", required=False)
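The validator's equality checks lean on packaging.version.Version, which understands post-release tags like .post2. A tiny illustration of why that matters for the pins in this diff:

from packaging.version import Version

assert Version("0.4.9.post2") > Version("0.4.9")   # post-releases sort after the base version
assert Version("0.4.9.post2") < Version("0.4.10")  # but before the next patch release
assert Version("4.53.1") > Version("4.51.1")       # the transformers bump in this PR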
diff --git a/patch/sglang/v0.4.6.post2.patch b/patch/sglang/v0.4.6.post2.patch
deleted file mode 100644
index 6bf47bf..0000000
--- a/patch/sglang/v0.4.6.post2.patch
+++ /dev/null
@@ -1,144 +0,0 @@
-diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
-index 174656b2..33fe0a5f 100644
---- a/python/sglang/srt/managers/io_struct.py
-+++ b/python/sglang/srt/managers/io_struct.py
-@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
-     success: bool
- 
- 
-+@dataclass
-+class InterruptAllReqInput:
-+    pass
-+
-+
-+@dataclass
-+class InterruptAllReqOutput:
-+    num_interrupted_requests: int
-+
-+
- @dataclass
- class UpdateWeightFromDiskReqInput:
-     # The model path with the new weights
-     model_path: str
-+    allow_interrupt: bool = False
-     # The format to load the weights
-     load_format: Optional[str] = None
- 
-diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
-index 8891115c..843a8a82 100644
---- a/python/sglang/srt/managers/scheduler.py
-+++ b/python/sglang/srt/managers/scheduler.py
-@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
-     HealthCheckOutput,
-     InitWeightsUpdateGroupReqInput,
-     InitWeightsUpdateGroupReqOutput,
-+    InterruptAllReqInput,
-+    InterruptAllReqOutput,
-     OpenSessionReqInput,
-     OpenSessionReqOutput,
-     ProfileReq,
-@@ -419,6 +421,7 @@ class Scheduler(
-         # Init request dispatcher
-         self._request_dispatcher = TypeBasedDispatcher(
-             [
-+                (InterruptAllReqInput, self.interrupt_all_requests),
-                 (TokenizedGenerateReqInput, self.handle_generate_request),
-                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                 (FlushCacheReqInput, self.flush_cache_wrapped),
-@@ -1938,6 +1941,15 @@ class Scheduler(
-     def _pause_engine(self) -> Tuple[List[Req], int]:
-         raise NotImplementedError()
- 
-+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
-+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
-+        for req in self.waiting_queue:
-+            req.sampling_params.max_new_tokens = 0
-+        for req in self.running_batch.reqs:
-+            req.sampling_params.max_new_tokens = len(req.output_ids)
-+        logger.info(f"Interrupt {num} requests.")
-+        return InterruptAllReqOutput(num)
-+
-     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-         """In-place update of the weights from disk."""
-         success, message = self.tp_worker.update_weights_from_disk(recv_req)
-diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
-index 82709b09..bfab3ce7 100644
---- a/python/sglang/srt/managers/tokenizer_manager.py
-+++ b/python/sglang/srt/managers/tokenizer_manager.py
-@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
-     HealthCheckOutput,
-     InitWeightsUpdateGroupReqInput,
-     InitWeightsUpdateGroupReqOutput,
-+    InterruptAllReqInput,
-+    InterruptAllReqOutput,
-     OpenSessionReqInput,
-     OpenSessionReqOutput,
-     ProfileReq,
-@@ -265,6 +267,9 @@ class TokenizerManager:
-         self.resume_memory_occupation_communicator = _Communicator(
-             self.send_to_scheduler, server_args.dp_size
-         )
-+        self.interrupt_requests_communicator = _Communicator(
-+            self.send_to_scheduler, server_args.dp_size
-+        )
-         self.flush_cache_communicator = _Communicator(
-             self.send_to_scheduler, server_args.dp_size
-         )
-@@ -294,6 +299,10 @@
-                     UpdateWeightFromDiskReqOutput,
-                     self._handle_update_weights_from_disk_req_output,
-                 ),
-+                (
-+                    InterruptAllReqOutput,
-+                    self.interrupt_requests_communicator.handle_recv,
-+                ),
-                 (
-                     InitWeightsUpdateGroupReqOutput,
-                     self.init_weights_update_group_communicator.handle_recv,
-@@ -767,6 +776,13 @@
-     ) -> Tuple[bool, str]:
-         self.auto_create_handle_loop()
- 
-+        if obj.allow_interrupt:
-+            num_interrupted_requests = await self.interrupt_all_requests(
-+                InterruptAllReqInput()
-+            )
-+            # Set a break point to wait for the interrupt to finish
-+            await asyncio.sleep(0.1)
-+
-         # default the load format to the server_args
-         if obj.load_format is None:
-             obj.load_format = self.server_args.load_format
-@@ -776,7 +792,12 @@
-         # Hold the lock if it is not async. This means that weight sync
-         # cannot run while requests are in progress.
-         async with self.model_update_lock.writer_lock:
--            return await self._wait_for_model_update_from_disk(obj)
-+            success, message, n_paused = (
-+                await self._wait_for_model_update_from_disk(obj)
-+            )
-+            if obj.allow_interrupt:
-+                return success, message, num_interrupted_requests
-+            return success, message, n_paused
- 
-     async def _wait_for_model_update_from_disk(
-         self, obj: UpdateWeightFromDiskReqInput
-@@ -849,6 +870,18 @@
-         result = (await self.update_weights_from_tensor_communicator(obj))[0]
-         return result.success, result.message
- 
-+    async def interrupt_all_requests(
-+        self,
-+        obj: InterruptAllReqInput,
-+        request: Optional[fastapi.Request] = None,
-+    ) -> Tuple[bool, str]:
-+        self.auto_create_handle_loop()
-+        result = await self.interrupt_requests_communicator(obj)
-+        if self.server_args.dp_size == 1:
-+            return result[0].num_interrupted_requests
-+        else:
-+            return [r.num_interrupted_requests for r in result]
-+
-     async def get_weights_by_name(
-         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-     ):
diff --git a/patch/sglang/v0.4.6.post4.patch b/patch/sglang/v0.4.6.post4.patch
deleted file mode 100644
index b7dbd09..0000000
--- a/patch/sglang/v0.4.6.post4.patch
+++ /dev/null
@@ -1,144 +0,0 @@
-diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
-index 5390668c..db370d19 100644
---- a/python/sglang/srt/managers/io_struct.py
-+++ b/python/sglang/srt/managers/io_struct.py
-@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
-     success: bool
- 
- 
-+@dataclass
-+class InterruptAllReqInput:
-+    pass
-+
-+
-+@dataclass
-+class InterruptAllReqOutput:
-+    num_interrupted_requests: int
-+
-+
- @dataclass
- class UpdateWeightFromDiskReqInput:
-     # The model path with the new weights
-     model_path: str
-+    allow_interrupt: bool = False
-     # The format to load the weights
-     load_format: Optional[str] = None
- 
-diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
-index 1178eec5..318dee33 100644
---- a/python/sglang/srt/managers/scheduler.py
-+++ b/python/sglang/srt/managers/scheduler.py
-@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
-     HealthCheckOutput,
-     InitWeightsUpdateGroupReqInput,
-     InitWeightsUpdateGroupReqOutput,
-+    InterruptAllReqInput,
-+    InterruptAllReqOutput,
-     OpenSessionReqInput,
-     OpenSessionReqOutput,
-     ProfileReq,
-@@ -427,6 +429,7 @@ class Scheduler(
-         # Init request dispatcher
-         self._request_dispatcher = TypeBasedDispatcher(
-             [
-+                (InterruptAllReqInput, self.interrupt_all_requests),
-                 (TokenizedGenerateReqInput, self.handle_generate_request),
-                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                 (FlushCacheReqInput, self.flush_cache_wrapped),
-@@ -1971,6 +1974,15 @@ class Scheduler(
-     def _pause_engine(self) -> Tuple[List[Req], int]:
-         raise NotImplementedError()
- 
-+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
-+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
-+        for req in self.waiting_queue:
-+            req.sampling_params.max_new_tokens = 0
-+        for req in self.running_batch.reqs:
-+            req.sampling_params.max_new_tokens = len(req.output_ids)
-+        logger.info(f"Interrupt {num} requests.")
-+        return InterruptAllReqOutput(num)
-+
-     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-         """In-place update of the weights from disk."""
-         success, message = self.tp_worker.update_weights_from_disk(recv_req)
-diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
-index b646fae1..c668728b 100644
---- a/python/sglang/srt/managers/tokenizer_manager.py
-+++ b/python/sglang/srt/managers/tokenizer_manager.py
-@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
-     HealthCheckOutput,
-     InitWeightsUpdateGroupReqInput,
-     InitWeightsUpdateGroupReqOutput,
-+    InterruptAllReqInput,
-+    InterruptAllReqOutput,
-     OpenSessionReqInput,
-     OpenSessionReqOutput,
-     ProfileReq,
-@@ -279,6 +281,9 @@ class TokenizerManager:
-         self.slow_down_communicator = _Communicator(
-             self.send_to_scheduler, server_args.dp_size
-         )
-+        self.interrupt_requests_communicator = _Communicator(
-+            self.send_to_scheduler, server_args.dp_size
-+        )
-         self.flush_cache_communicator = _Communicator(
-             self.send_to_scheduler, server_args.dp_size
-         )
-@@ -309,6 +314,10 @@
-                 (
-                     UpdateWeightFromDiskReqOutput,
-                     self._handle_update_weights_from_disk_req_output,
-                 ),
-+                (
-+                    InterruptAllReqOutput,
-+                    self.interrupt_requests_communicator.handle_recv,
-+                ),
-                 (
-                     InitWeightsUpdateGroupReqOutput,
-                     self.init_weights_update_group_communicator.handle_recv,
-@@ -799,6 +808,13 @@
-     ) -> Tuple[bool, str]:
-         self.auto_create_handle_loop()
- 
-+        if obj.allow_interrupt:
-+            num_interrupted_requests = await self.interrupt_all_requests(
-+                InterruptAllReqInput()
-+            )
-+            # Set a break point to wait for the interrupt to finish
-+            await asyncio.sleep(0.1)
-+
-         # default the load format to the server_args
-         if obj.load_format is None:
-             obj.load_format = self.server_args.load_format
-@@ -808,7 +824,12 @@
-         # Hold the lock if it is not async. This means that weight sync
-         # cannot run while requests are in progress.
-         async with self.model_update_lock.writer_lock:
--            return await self._wait_for_model_update_from_disk(obj)
-+            success, message, n_paused = (
-+                await self._wait_for_model_update_from_disk(obj)
-+            )
-+            if obj.allow_interrupt:
-+                return success, message, num_interrupted_requests
-+            return success, message, n_paused
- 
-     async def _wait_for_model_update_from_disk(
-         self, obj: UpdateWeightFromDiskReqInput
-@@ -881,6 +902,18 @@
-         result = (await self.update_weights_from_tensor_communicator(obj))[0]
-         return result.success, result.message
- 
-+    async def interrupt_all_requests(
-+        self,
-+        obj: InterruptAllReqInput,
-+        request: Optional[fastapi.Request] = None,
-+    ) -> Tuple[bool, str]:
-+        self.auto_create_handle_loop()
-+        result = await self.interrupt_requests_communicator(obj)
-+        if self.server_args.dp_size == 1:
-+            return result[0].num_interrupted_requests
-+        else:
-+            return [r.num_interrupted_requests for r in result]
-+
-     async def get_weights_by_name(
-         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-     ):
diff --git a/pyproject.toml b/pyproject.toml
index cd4e26d..c779fa7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,10 +31,9 @@ dependencies = [
     "huggingface_hub",
     "datasets",
     "accelerate",
-    "transformers==4.51.1",
+    "transformers==4.53.0",
 
     # Scientific computing
-    "numpy<2.0.0",
     "scipy",
     "pandas",
     "matplotlib",
diff --git a/realhf/system/generation_server.py b/realhf/system/generation_server.py
index 2a46141..5d09707 100644
--- a/realhf/system/generation_server.py
+++ b/realhf/system/generation_server.py
@@ -40,42 +40,11 @@ def execute_shell_command(command: str) -> subprocess.Popen:
     )
 
 
-def apply_sglang_patch():
-    p = Path(os.path.dirname(__file__))
-    patch_path = str(
-        p.parent.parent
-        / "patch"
-        / "sglang"
-        / f"v{pkg_version.get_version('sglang')}.patch"
-    )
-
-    target_path = ""
-    sglang_meta = subprocess.check_output(
-        "python3 -m pip show sglang", shell=True
-    ).decode("ascii")
-    for line in sglang_meta.split("\n"):
-        line = line.strip()
-        if line.startswith("Editable project location: "):
-            target_path = str(Path(line.split(": ")[1]).parent)
-
-    if target_path:
-        proc = subprocess.Popen(
-            ["git", "apply", patch_path],
-            cwd=target_path,
-            stderr=sys.stdout,
-            stdout=sys.stdout,
-        )
-        proc.wait()
-        logger.info(f"Applied SGLang patch at {target_path}")
-
-
 def launch_server_cmd(command: str, port: int = 30000):
     """
     Launch the server using the given command. If no port is specified, a free port is reserved.
     """
-    if not ray.is_initialized():
-        apply_sglang_patch()
     assert port is not None
     full_command = f"{command} --port {port}"
     process = execute_shell_command(full_command)
diff --git a/realhf/system/gserver_manager.py b/realhf/system/gserver_manager.py
index e746fe9..a789ae7 100644
--- a/realhf/system/gserver_manager.py
+++ b/realhf/system/gserver_manager.py
@@ -155,39 +155,22 @@ class GserverManager(Worker):
 
         return None
 
-    async def flush_requests_and_update_weights(
-        self, server_url, new_param_path, update_weights_retries=5
-    ):
-        server_index = self.server_urls.index(server_url)
-        success = False
-        for _ in range(update_weights_retries):
-            async with aiohttp.ClientSession(
-                server_url,
-                timeout=aiohttp.ClientTimeout(
-                    total=self.config.flush_request_timeout,
-                    sock_connect=self.config.flush_request_timeout,
-                ),
-            ) as session:
-                async with session.post(
-                    f"/update_weights_from_disk",
-                    json=dict(model_path=new_param_path, allow_interrupt=True),
-                ) as resp:
-                    if resp.status == 200:
-                        res = await resp.json()
-                        success = res["success"]
-                        if success:
-                            if "num_paused_requests" in res:
-                                logger.info(
-                                    f"{res['num_paused_requests']} requests are interrupted "
-                                    f"during updating weights for server {server_index}: {server_url}"
-                                )
-                            return
-                        logger.warning(
-                            f"Update weights failed: {res['message']}. Retrying."
-                        )
-                    logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
-            time.sleep(0.1)
-        raise RuntimeError("Update weights failed.")
+    async def flush_requests_and_update_weights(self, server_url, new_param_path):
+        async with aiohttp.ClientSession(
+            server_url,
+            timeout=aiohttp.ClientTimeout(
+                total=self.config.flush_request_timeout,
+                sock_connect=self.config.flush_request_timeout,
+            ),
+        ) as session:
+            (await session.post("/pause_generation")).raise_for_status()
+            async with session.post(
+                f"/update_weights_from_disk",
+                json=dict(model_path=new_param_path),
+            ) as resp:
+                resp.raise_for_status()
+                assert (await resp.json())["success"]
+            (await session.post("/continue_generation")).raise_for_status()
 
     def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
         if not hasattr(self, "round_robin_idx"):
diff --git a/requirements.txt b/requirements.txt
index 178e958..0c21f9b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,7 +25,6 @@ numba
 packaging
 pandas
 pybind11>=2.10.0
-numpy<2.0.0
 psutil
 pynvml
 pytest