0724_merge7

朱晗 2025-07-24 19:34:22 +08:00
commit f5924b1851
27 changed files with 331 additions and 1216 deletions

View File

@@ -163,7 +163,6 @@ class WeightUpdateMeta:
    type: str
    path: str | None
    alloc_mode: AllocationMode | None
-   comm_backend: str | None

@dataclass
class SaveLoadMeta:

View File

@@ -285,85 +285,6 @@ class PPOActorConfig(TrainEngineConfig):
    )
@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
@dataclass
class SGLangConfig:
    """Configuration for SGLang runtime. Refer to:

View File

@@ -12,6 +12,7 @@ from arealite.api.io_struct import (
    FinetuneSpec,
    LLMRequest,
    LLMResponse,
+   ParamSpec,
    SaveLoadMeta,
    WeightUpdateMeta,
)

@@ -63,7 +64,19 @@ class TrainEngine(abc.ABC):
        return self.train(False)

    def upload_weights(self, meta: WeightUpdateMeta):
-       """Upload weights to the inference engine."""
+       """Upload weights to the inference engine (in a blocking manner)."""
        raise NotImplementedError()

+   def get_param_specs(self) -> List[ParamSpec]:
+       """Get the parameter specifications for the model."""
+       raise NotImplementedError()
+
+   def set_version(self, version: int):
+       """Set the current weight version in the train engine."""
+       raise NotImplementedError()
+
+   def get_version(self) -> int:
+       """Get the current weight version in the train engine."""
+       raise NotImplementedError()
+
    def save(self, meta: SaveLoadMeta):

@@ -122,14 +135,22 @@ class InferenceEngine(abc.ABC):
    def destroy(self):
        """Destroy the engine and release GPU memory."""

-   def update_weights(self, meta: WeightUpdateMeta) -> Future:
-       """Update weights in the inference engine."""
-       raise NotImplementedError()
-
    async def agenerate(self, req: LLMRequest) -> LLMResponse:
        """Asynchronously generate a response for the given request."""
        raise NotImplementedError()

+   def update_weights(self, meta: WeightUpdateMeta) -> Future:
+       """Update weights in the inference engine in a non-blocking manner."""
+       raise NotImplementedError()
+
+   def set_version(self, version: int) -> None:
+       """Set the current weight version in the inference engine."""
+       raise NotImplementedError()
+
+   def get_version(self) -> int:
+       """Get the current weight version in the inference engine."""
+       raise NotImplementedError()
+
    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
        """Asynchronously submit a request to the inference engine. Exits immediately."""
        raise NotImplementedError()
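Taken together with the example scripts updated later in this commit, the intended calling pattern for this split API is roughly the sketch below; `actor`, `rollout`, `meta`, and `step` are illustrative names, and the rank-0 gating used in the multi-GPU scripts is omitted.

def sync_weights(actor, rollout, meta, step: int) -> None:
    # Sketch of one weight-update step (single-process view).
    rollout.pause()                      # stop pulling new rollout requests
    fut = rollout.update_weights(meta)   # non-blocking on the inference side, returns a Future
    actor.upload_weights(meta)           # blocking on the training side
    fut.result()                         # wait until all inference servers have the new weights
    rollout.resume()
    actor.set_version(step + 1)          # keep both engines' weight versions in sync
    rollout.set_version(step + 1)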

View File

@@ -2,17 +2,22 @@
# Licensed under the Apache License, Version 2.0
import enum
import itertools
+import os
import re
import uuid
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple

import torch
from gymnasium.core import ActType, ObsType
from PIL.Image import Image as ImageObject
from transformers import AutoProcessor, PreTrainedTokenizerFast

-from arealite.api.cli_args import GenerationHyperparameters
+from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
+from arealite.utils.network import find_free_ports, gethostip
+
+if TYPE_CHECKING:
+    from arealite.api.engine_api import TrainEngine

@dataclass
@@ -168,20 +173,55 @@ class AllocationMode:
        return other_alloc

+@dataclass
+class ParamSpec:
+    name: str
+    shape: Tuple
+    dtype: str
+
+
@dataclass
class WeightUpdateMeta:
-   type: str
-   path: str | None
-   alloc_mode: AllocationMode | None
-   comm_backend: str | None
-   model_version: int = 0
-   tp_size: int = 1
-   master_address: str = "127.0.0.1"
-   master_port: int = 29500
-   world_size: int = 1
-   group_name: str = "aupdate_weights_from_distributed"
-   parameter_names: List[str] = field(default_factory=list)
-   state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)
+   type: Literal["disk", "nccl"]
+   path: str | None = None
+   alloc_mode: AllocationMode | None = None
+   nccl_master_address: str = "127.0.0.1"
+   nccl_master_port: int = 29500
+   nccl_param_specs: List[ParamSpec] = field(default_factory=list)
+   nccl_group_name: str = "update_weight_group"
+
+   @classmethod
+   def from_disk(
+       cls,
+       saver_config: SaverConfig,
+   ) -> "WeightUpdateMeta":
+       from arealite.utils.saver import Saver
+
+       path = os.path.join(
+           Saver.get_save_checkpoint_root(saver_config),
+           "weight_update",
+       )
+       return cls(
+           type="disk",
+           path=path,
+       )
+
+   @classmethod
+   def from_fsdp_nccl(
+       cls,
+       allocation_mode: AllocationMode,
+       fsdp_engine: "TrainEngine",
+       nccl_group_name: str = "update_weight_group",
+   ):
+       return cls(
+           type="nccl",
+           alloc_mode=allocation_mode,
+           nccl_master_address=gethostip(),
+           nccl_master_port=find_free_ports(1)[0],
+           nccl_param_specs=fsdp_engine.get_param_specs(),
+           nccl_group_name=nccl_group_name,
+       )

@dataclass
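As a usage sketch, mirroring the updated tests and example scripts in this commit (config and engine names are illustrative), the two constructors are meant to be used as follows:

def build_update_meta(config, actor_engine, use_nccl: bool) -> WeightUpdateMeta:
    if not use_nccl:
        # Disk-based update: the path derives from the saver's checkpoint root.
        return WeightUpdateMeta.from_disk(config.saver)
    # NCCL-based update: get_param_specs() is collective across trainer ranks,
    # so the example scripts build this on every rank and then broadcast
    # rank 0's copy with dist.broadcast_object_list.
    return WeightUpdateMeta.from_fsdp_nccl(
        AllocationMode.from_str(config.allocation_mode),  # e.g. "sglang.d4p1t1+d4p1t1"
        actor_engine,
    )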

View File

@@ -49,6 +49,7 @@ class BaseHFEngine(TrainEngine):
        self.processor: AutoProcessor | None = None
        # huggingface model config
        self.model_config: PretrainedConfig
+       self._version: int = 0
        # initialization
        self.initialized = False

@@ -64,6 +65,12 @@ class BaseHFEngine(TrainEngine):
        self.world_size = int(os.environ["WORLD_SIZE"])

+   def set_version(self, version: int):
+       self._version = version
+
+   def get_version(self) -> int:
+       return self._version
+
    def train(self, mode: bool = True):
        assert self.model is not None
        self.model.train(mode=mode)

@@ -222,7 +229,7 @@ class BaseHFEngine(TrainEngine):
        )
        state_dict = self.optimizer.state_dict()
        torch.save(state_dict, shard_path)
-       dist.barrier()
+       dist.barrier(device_ids=[self.device.index])

    def load_optimizer_state(self, path: str):
        # Load FSDP sharded state dict

@@ -234,7 +241,7 @@ class BaseHFEngine(TrainEngine):
        )
        optimizer_state_dict = torch.load(shard_path, weights_only=False)
        self.optimizer.load_state_dict(optimizer_state_dict)
-       dist.barrier()
+       dist.barrier(device_ids=[self.device.index])

    def step_lr_scheduler(self):
        assert self.lr_scheduler is not None

View File

@ -1,7 +1,7 @@
import os import os
import time import time
from datetime import datetime from datetime import datetime
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Dict, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -14,7 +14,8 @@ from torch.distributed.checkpoint.state_dict import (
from transformers import AutoProcessor, PreTrainedTokenizerFast from transformers import AutoProcessor, PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta from arealite.api.engine_api import FinetuneSpec
from arealite.api.io_struct import ParamSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.distributed import init_custom_process_group from arealite.utils.distributed import init_custom_process_group
from arealite.utils.fsdp import ( from arealite.utils.fsdp import (
@@ -125,7 +126,7 @@ class FSDPEngine(BaseHFEngine):
            if processor is not None:
                processor.save_pretrained(path)
-       dist.barrier()
+       dist.barrier(device_ids=[self.device.index])

    def _load_model_from_hf(self, path: str):
        """Load model from HuggingFace format."""

@@ -146,7 +147,7 @@ class FSDPEngine(BaseHFEngine):
            if not self.weight_update_group_initialized:
                self._init_distributed_weight_update(meta)
            self._update_weights_from_distributed()
-           dist.barrier()
+           dist.barrier(device_ids=[self.device.index])
            torch.cuda.synchronize()
        elif meta.type == "disk":
            self._save_model_to_hf(meta.path, self.tokenizer, self.processor)

@@ -155,7 +156,7 @@ class FSDPEngine(BaseHFEngine):
            update_name = names.update_weights_from_disk(
                self.config.experiment_name,
                self.config.trial_name,
-               meta.model_version,
+               self.model_version,
            )
            name_resolve.add(
                update_name, str(datetime.now().timestamp()), keepalive_ttl=120
@@ -164,16 +165,18 @@ class FSDPEngine(BaseHFEngine):
            raise ValueError(f"Unknown weight update type {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
+       # NOTE: Processes launched with torchrun will set the following env var to True,
+       # which blocks creating another TCP store for weight update.
+       os.environ["TORCHELASTIC_USE_AGENT_STORE"] = str(False)
        if dist.get_rank() == 0:
            self.weight_update_group = init_custom_process_group(
                backend="nccl",
-               world_size=meta.world_size,
+               world_size=meta.alloc_mode.gen_world_size + 1,
-               init_method=f"tcp://{meta.master_address}:{meta.master_port}",
+               init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}",
                rank=0,
-               group_name=meta.group_name,
+               group_name=meta.nccl_group_name,
            )
-           # NOTE: synchronizing with sglang's barrier
-           dist.barrier(group=self.weight_update_group, device_ids=[self.device.index])
+       # NOTE: sglang v0.4.9.post2 or later does not have the barrier call
        self.weight_update_group_initialized = True
    def _update_weights_from_distributed(self):

@@ -185,23 +188,29 @@ class FSDPEngine(BaseHFEngine):
            else:
                tensor = param.data
            if dist.get_rank() == 0:
-               print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
+               logger.debug(
+                   f"Broadcasting {name} with shape {tensor.shape}", flush=True
+               )
            dist.broadcast(tensor, src=0, group=self.weight_update_group)
-           dist.barrier()
            del tensor  # optional, for memory hygiene
        torch.cuda.empty_cache()

-   def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
-       """Return a dict mapping param name to its shape (expanded if DTensor)."""
-       param_shapes = {}
+   def get_param_specs(self) -> List[ParamSpec]:
+       param_specs = []
        for name, param in self.model.named_parameters():
            if isinstance(param.data, DTensor):
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
-           param_shapes[name] = tuple(tensor.shape)
+           param_specs.append(
+               ParamSpec(
+                   name=name,
+                   shape=tuple(tensor.shape),
+                   dtype=str(tensor.dtype).split("torch.")[1],
+               )
+           )
            del tensor  # free memory if full_tensor was created
-       return param_shapes
+       return param_specs
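A ParamSpec carries the dtype as a plain string such as "bfloat16" (str(tensor.dtype) with the "torch." prefix stripped); a minimal sketch of how a receiver could map a spec back to an empty tensor, with a hypothetical helper name:

import torch

def empty_from_spec(spec: ParamSpec) -> torch.Tensor:
    # "bfloat16" -> torch.bfloat16; shape is the full (un-sharded) shape.
    return torch.empty(spec.shape, dtype=getattr(torch, spec.dtype))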
    def train_batch(
        self,

View File

@@ -1,10 +1,11 @@
import asyncio
import os
import random
+import shutil
import threading
import time
import traceback
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from concurrent.futures import Future, ProcessPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List

@@ -29,17 +30,12 @@ from arealite.api.io_struct import (
)
from arealite.utils.data import concat_padded_tensors
from arealite.utils.http import arequest_with_retry, get_default_connector
-from realhf.base import logging, name_resolve, names, pkg_version
+from realhf.base import logging, name_resolve, names

if TYPE_CHECKING:
    from arealite.api.workflow_api import RolloutWorkflow

logger = logging.getLogger(__name__)

-if pkg_version.is_available("sglang"):
-    if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
-        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
-    else:
-        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
-
ROLLOUT_POLL_WAIT_TIME = 0.05
RID_CACHE_SIZE = 128
@@ -93,10 +89,7 @@ class RemoteSGLangEngine(InferenceEngine):
    def check_health(self, base_url):
        # Check server endpoint
        try:
-           response = requests.get(
-               f"{base_url}/metrics",
-               timeout=30,
-           )
+           response = requests.get(f"{base_url}/health", timeout=30)
            return response.status_code == 200
        except requests.exceptions.RequestException as e:
            return False

@@ -125,7 +118,7 @@ class RemoteSGLangEngine(InferenceEngine):
        """Thread that runs the rollout loop."""
        try:
            uvloop.run(self._rollout_thread_async())
-       except Exception as e:
+       except Exception:
            traceback.print_exc()

    async def _rollout_thread_async(self):
@@ -272,9 +265,7 @@ class RemoteSGLangEngine(InferenceEngine):
        accumulated_output_logprobs = []
        accumulated_versions = []

-       # Deal with rollout interruption
-       stop_reason = "length"
-
+       # A single "rid" shares the same server to allow KV cache reuse
        if req.rid in self.rid_to_address:
            server_addr = self.rid_to_address[req.rid]
        else:

@@ -286,10 +277,19 @@ class RemoteSGLangEngine(InferenceEngine):
            self.rid_to_address[req.rid] = server_addr
            self.rid_queue.append(req.rid)

+       # Deal with rollout interruption.
+       # "abort" is the stop reason reported by sglang v0.4.9.post2 and later
+       # after we call the pause_generation endpoint.
+       stop_reason = None
        while (
            stop_reason != "stop"
            and len(accumulated_output_tokens) < gconfig.max_new_tokens
        ):
+           # The request was interrupted; wait for some time to avoid
+           # interfering with update-weights requests.
+           if stop_reason is not None:
+               await asyncio.sleep(0.5)
            # loop until the generation is complete
            result = await arequest_with_retry(
                session=self.session,

@@ -301,8 +301,17 @@ class RemoteSGLangEngine(InferenceEngine):
                timeout=self.config.request_timeout,
            )

-           # Parse response
            meta_info = result["meta_info"]
+           # Check if generation is complete
+           finish_reason = meta_info["finish_reason"]
+           stop_reason = finish_reason["type"]
+           if (
+               stop_reason == "abort"
+               and finish_reason.get("message") == "Abort before prefill"
+           ):
+               continue
+
+           # Parse response
            output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
            output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]

@@ -312,11 +321,7 @@ class RemoteSGLangEngine(InferenceEngine):
            # FIXME: Update with actual server versions
            accumulated_versions.extend([-1] * len(output_tokens))

-           # Check if generation is complete
-           finish_reason = meta_info["finish_reason"]
-           stop_reason = finish_reason["type"]
-
-           payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
+           payload["input_ids"] += result["output_ids"]
            sample_params["max_new_tokens"] -= len(output_tokens)
        latency = time.perf_counter() - start_time

@@ -344,30 +349,24 @@ class RemoteSGLangEngine(InferenceEngine):
        )
        return response

-   def update_weights(self, meta):
-       executor = ThreadPoolExecutor(max_workers=1)
-       return executor.submit(self._update_weights, meta)
-
-   def _update_weights(self, meta: WeightUpdateMeta):
+   def update_weights(self, meta: WeightUpdateMeta):
+       for addr in self.addresses:
+           res = requests.post(f"http://{addr}/pause_generation")
+           res.raise_for_status()
+       fut = Future()
        if meta.type == "nccl":
-           if not self.distributed_weight_update_initialized:
-               self._init_distributed_weight_update(meta)
-           tik = time.perf_counter()
-           try:
-               loop = asyncio.new_event_loop()
-               asyncio.set_event_loop(loop)
-               for param in meta.parameter_names:
-                   jobs = [
-                       self.aupdate_weights_from_distributed(addr, meta, param)
-                       for addr in self.addresses
-                   ]
-                   loop.run_until_complete(asyncio.gather(*jobs))
-           finally:
-               loop.close()
-           logger.info(
-               f"Distributed update weights done in {time.perf_counter() - tik}s"
-           )
-           self.set_version(meta.model_version)
+           fut = self.executor.submit(
+               update_weights_from_distributed,
+               meta,
+               self.addresses,
+               self.config.request_timeout,
+               not self.distributed_weight_update_initialized,
+           )
+
+           def callback(fut):
+               self.distributed_weight_update_initialized = True
+
+           fut.add_done_callback(callback)
        elif meta.type == "disk":
            # Update weights from disk
            # Use ProcessPool to bypass python GIL for running async coroutines

@@ -375,7 +374,7 @@ class RemoteSGLangEngine(InferenceEngine):
                update_weights_from_disk,
                self.config.experiment_name,
                self.config.trial_name,
-               meta.model_version,
+               self.get_version(),
                self.addresses,
                meta.path,
                self.config.request_retries,

@@ -383,64 +382,19 @@ class RemoteSGLangEngine(InferenceEngine):
            )

            def callback(fut):
-               self.set_version(meta.model_version)
+               shutil.rmtree(meta.path, ignore_errors=True)

            fut.add_done_callback(callback)
-           return fut
        else:
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

-   def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
-       try:
-           # Initialize weights update group
-           jobs = [
-               self.ainit_weights_update_group(addr, meta) for addr in self.addresses
-           ]
-           loop = asyncio.new_event_loop()
-           # asyncio event loop should be manually set when running asyncio stuff in another thread
-           asyncio.set_event_loop(loop)
-           loop.run_until_complete(asyncio.gather(*jobs))
-           self.distributed_weight_update_initialized = True
-           logger.info(f"Distributed update weights initialized")
-       finally:
-           loop.close()
-
-   async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
-       rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
-       payload = {
-           "master_address": meta.master_address,
-           "master_port": str(meta.master_port),
-           "rank_offset": rank_offset,
-           "world_size": meta.world_size,
-           "group_name": meta.group_name,
-           "backend": "nccl",
-       }
-       res = await arequest_with_retry(
-           addr=addr,
-           endpoint="/init_weights_update_group",
-           payload=payload,
-           method="POST",
-           max_retries=1,
-           timeout=self.config.request_timeout,
-       )
-       assert res["success"]
-
-   async def aupdate_weights_from_distributed(
-       self, addr: str, meta: WeightUpdateMeta, parameter_name: str
-   ):
-       res = await arequest_with_retry(
-           addr=addr,
-           endpoint="/update_weights_from_distributed",
-           payload={
-               "name": parameter_name,
-               "dtype": "bfloat16",
-               "shape": meta.state_dict_key_to_shape[parameter_name],
-           },
-           method="POST",
-           max_retries=1,
-           timeout=self.config.request_timeout,
-       )
-       assert res["success"]
+       def callback(fut):
+           for addr in self.addresses:
+               res = requests.post(f"http://{addr}/continue_generation")
+               res.raise_for_status()
+
+       fut.add_done_callback(callback)
+       return fut

    def get_capacity(self):
        if dist.is_initialized():
@@ -546,27 +500,6 @@ class RemoteSGLangEngine(InferenceEngine):
        self.paused.clear()
async def aupdate_weights_from_disk(
session, addr, path: str, request_retries: int, request_timeout: float
):
tik = time.time()
res = await arequest_with_retry(
addr=addr,
session=session,
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
max_retries=request_retries,
timeout=request_timeout,
)
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
def update_weights_from_disk(
    experiment_name,
    trial_name,

@@ -596,12 +529,14 @@ def update_weights_from_disk(
            connector=get_default_connector(),
        )
        jobs = [
-           aupdate_weights_from_disk(
-               session=session,
+           arequest_with_retry(
                addr=addr,
-               path=path,
-               request_retries=request_retries,
-               request_timeout=request_timeout,
+               session=session,
+               endpoint="/update_weights_from_disk",
+               payload=dict(model_path=str(path)),
+               method="POST",
+               max_retries=request_retries,
+               timeout=request_timeout,
            )
            for addr in addresses
        ]

@@ -612,3 +547,72 @@ def update_weights_from_disk(
        )

    return uvloop.run(_fn())
def update_weights_from_distributed(
meta: WeightUpdateMeta,
addresses: List[str],
request_timeout,
init_group: bool,
):
async def _fn():
tik = time.perf_counter()
if init_group:
await asyncio.gather(
*[
ainit_weights_update_group(addr, i, meta, request_timeout)
for i, addr in enumerate(addresses)
]
)
await asyncio.gather(
*[
arequest_with_retry(
addr=addr,
endpoint="/update_weights_from_distributed",
payload={
"names": [pspec.name for pspec in meta.nccl_param_specs],
"dtypes": [pspec.dtype for pspec in meta.nccl_param_specs],
"shapes": [pspec.shape for pspec in meta.nccl_param_specs],
"group_name": meta.nccl_group_name,
},
method="POST",
max_retries=1,
timeout=request_timeout,
)
for addr in addresses
]
)
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
return uvloop.run(_fn())
async def ainit_weights_update_group(
addr: str,
server_idx: int,
meta: WeightUpdateMeta,
request_timeout: float,
):
assert meta.alloc_mode is not None
if meta.alloc_mode.gen_pp_size != 1:
raise NotImplementedError(
"NCCL weight update with PP size > 1 is not implemented yet."
)
rank_offset = 1 + server_idx * meta.alloc_mode.gen_tp_size
payload = {
"master_address": meta.nccl_master_address,
"master_port": str(meta.nccl_master_port),
"rank_offset": rank_offset,
"world_size": meta.alloc_mode.gen_world_size + 1,
"backend": "nccl",
"group_name": meta.nccl_group_name,
}
res = await arequest_with_retry(
addr=addr,
endpoint="/init_weights_update_group",
payload=payload,
method="POST",
max_retries=1,
timeout=request_timeout,
)
assert res["success"]

View File

@@ -283,7 +283,7 @@ def main_local():
    if not cfg.server_only:
        launcher.submit(
            job_name="trainer",
-           cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
+           cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}",
            gpu=alloc_mode.train_world_size,
            env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
        )
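The move away from --standalone presumably avoids collisions on torchrun's default rendezvous port when several local runs share a machine; a minimal sketch of the port selection used above (the script name and GPU count are illustrative):

from arealite.utils.network import find_free_ports

# Pick one unused port in the same range the launcher uses above.
master_port = find_free_ports(1, (10000, 50000))[0]
cmd = (
    f"torchrun --nnodes 1 --nproc-per-node 8 "
    f"--master-addr localhost --master-port {master_port} train.py"
)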

View File

@@ -3,10 +3,8 @@ import subprocess
import sys
import time
import uuid
-from pathlib import Path
from typing import Optional

-import ray
import requests

from arealite.api.cli_args import (

@@ -19,12 +17,12 @@ from arealite.api.cli_args import (
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import TRITON_CACHE_PATH
from arealite.utils.network import find_free_ports, gethostip
-from realhf.base import logging, name_resolve, names, pkg_version
+from realhf.base import logging, name_resolve, names

logger = logging.getLogger("SGLangServer Wrapper")


-def execute_shell_command(command: str) -> subprocess.Popen:
+def launch_server_cmd(command: str) -> subprocess.Popen:
    """
    Execute a shell command and return its process handle.
    """

@@ -45,46 +43,6 @@
    )
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
def launch_server_cmd(command: str):
"""
Launch the server using the given command.
If no port is specified, a free port is reserved.
"""
if not ray.is_initialized():
apply_sglang_patch()
process = execute_shell_command(command)
return process
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
    """Wait for the server to be ready by polling the /v1/models endpoint.

View File

@@ -12,10 +12,9 @@ from arealite.api.cli_args import (
    SGLangConfig,
    TrainEngineConfig,
)
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
-from arealite.utils.network import find_free_ports
from realhf.base import network

EXPR_NAME = "test_fsdp_engine_nccl"

@@ -33,7 +32,7 @@ RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
    try:
-       response = requests.get(f"{base_url}/metrics", timeout=30)
+       response = requests.get(f"{base_url}/health", timeout=30)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False
@@ -82,14 +81,14 @@ def sglang_server_nccl():
def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
-   # 设置分布式环境变量
+   # Set environment variables for torch distributed
    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = HOST
    os.environ["MASTER_PORT"] = str(MASTER_PORT)

-   # 启动本地FSDPEngine
+   # Initialize FSDPEngine
    engine_config = TrainEngineConfig(
        experiment_name=EXPR_NAME,
        trial_name=TRIAL_NAME,

@@ -100,38 +99,26 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
    engine.initialize(None, ft_spec)

-   # 启动远端RemoteSGLangEngine
+   # Initialize RemoteSGLangEngine
    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    remote_engine = RemoteSGLangEngine(config)
    remote_engine.initialize(None, None)

-   # 构造WeightUpdateMeta (type=nccl)
-   param_meta = engine.get_param_meta_for_distributed_update()
-   meta = WeightUpdateMeta(
-       type="nccl",
-       path=None,
-       alloc_mode=None,
-       comm_backend="nccl",
-       model_version=123,
-       tp_size=1,
-       master_address="localhost",
-       master_port=find_free_ports(1)[0],
-       world_size=2,
-       group_name=GROUP_NAME,
-       parameter_names=list(param_meta.keys()),
-       state_dict_key_to_shape=param_meta,
-   )
+   # Get WeightUpdateMeta
+   meta = WeightUpdateMeta.from_fsdp_nccl(
+       AllocationMode.from_str("sglang.d1p1t1+d1p1t1"),
+       engine,
+       nccl_group_name=GROUP_NAME,
+   )

-   # 本地engine广播参数
+   # Broadcast weights
    future = remote_engine.update_weights(meta)
    print("got future", flush=True)
    engine.upload_weights(meta)
    print("uploaded weights to remote engine", flush=True)
-   # 远端engine拉取参数
+   # Wait for remote engine to finish
    future.result(timeout=120)
    print("got result", flush=True)
-   # 检查远端参数版本
-   assert remote_engine.get_version() == 123

    remote_engine.destroy()
    engine.destroy()

View File

@@ -2,7 +2,6 @@ import os
import subprocess
import sys
import time
-import uuid

import pytest
import requests

@@ -14,7 +13,7 @@ from arealite.api.cli_args import (
    InferenceEngineConfig,
    SGLangConfig,
)
-from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
+from arealite.api.io_struct import WeightUpdateMeta
from arealite.utils import network
from realhf.api.core.data_api import load_hf_tokenizer

@@ -31,10 +30,7 @@ RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
    try:
-       response = requests.get(
-           f"{base_url}/metrics",
-           timeout=30,
-       )
+       response = requests.get(f"{base_url}/health", timeout=30)
        return response.status_code == 200
    except requests.exceptions.RequestException as e:
        return False

@@ -77,29 +73,6 @@ def sglang_server():
        process.terminate()
@pytest.mark.asyncio
async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
tokenizer = load_hf_tokenizer(MODEL_PATH)
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
input_ids=tokenizer.encode("hello! how are you today"),
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
assert isinstance(resp, LLMResponse)
assert resp.input_tokens == req.input_ids
assert (
len(resp.output_logprobs)
== len(resp.output_tokens)
== len(resp.output_versions)
)
@pytest.mark.parametrize("n_samples", [1, 2, 4])
def test_remote_sglang_rollout(sglang_server, n_samples):
    from arealite.engine.sglang_remote import RemoteSGLangEngine

@@ -211,6 +184,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
    engine = FSDPEngine(engine_config)
    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
    engine.initialize(None, ft_spec)
+   engine.model_version = 100

    # setup name resolve
    import realhf.base.name_resolve as name_resolve

@@ -227,13 +201,11 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    inf_engine = RemoteSGLangEngine(config)
    inf_engine.initialize(None, None)
+   inf_engine.set_version(100)
    # test update weights
    path = tmp_path_factory.mktemp("upload_weights_from_disk")
-   update_weight_meta = WeightUpdateMeta(
-       type="disk", path=path, alloc_mode=None, comm_backend=None, model_version=100
-   )
+   update_weight_meta = WeightUpdateMeta(type="disk", path=str(path))
    future = inf_engine.update_weights(update_weight_meta)
    engine.upload_weights(update_weight_meta)
    future.result()
-   assert inf_engine.get_version() == 100
    inf_engine.destroy()

View File

@@ -1,310 +0,0 @@
import asyncio
import os
import shutil
import sys
import uuid
import colorama
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import (
GenerationHyperparameters,
GRPOConfig,
load_expr_config,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.device import log_gpu_stats
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker
logger = logging.getLogger("boba math")
class RLVRWorkflow(RolloutWorkflow):
def __init__(
self,
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
dump_dir: str | None = None,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.dump_dir = dump_dir
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
os.makedirs(self.dump_dir, exist_ok=True)
async def arun_episode(self, engine, data):
input_ids = self.tokenizer.encode(data["prompt"])
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
input_ids=input_ids,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
version = engine.get_version()
prompt_strs = []
completions_strs = []
rewards = []
seqlens = []
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0.0] * resp.input_len + resp.output_logprobs
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
prompt_str = data["prompt"]
completions_str = self.tokenizer.decode(resp.output_tokens)
prompt_strs.append(prompt_str)
completions_strs.append(completions_str)
seqlens.append(len(seq))
reward = self.reward_fn(
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
rewards.append(reward)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
# reward
rewards=torch.tensor([float(reward)]),
)
results.append(TensorDict(res, batch_size=[1]))
if self.dump_dir is not None:
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
# Get the unique identifier for this prompt
qid = None
for key in ["query_id", "id", "qid"]:
qid = data.get(key, None)
if qid is not None:
break
qid = qid or uuid.uuid4().hex
# Dump rollout to file
with open(
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
) as f:
n_samples = self.gconfig.n_samples
for i, (p, c, r, sl) in enumerate(
zip(prompt_strs, completions_strs, rewards, seqlens)
):
info = "\n".join(
[
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
]
)
f.write(info + "\n")
return concat_padded_tensors(results)
def get_boba_math_dataset(tokenizer, rank, world_size):
dataset = load_dataset(
path="json",
split="train",
data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
)
dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def boba_reward_fn(
prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
):
from pebble import ProcessExpired, ProcessPool
from realhf.impl.dataset.math_parser import process_results
jobs = []
with ProcessPool(max_workers=1) as executor:
for sol in solutions:
job = executor.schedule(
process_results, args=[completions, sol], timeout=15
)
jobs.append(job)
label = 0
for job in jobs:
try:
x = job.result()
except TimeoutError:
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
return label
def main(args):
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_boba_math_dataset(tokenizer, rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.train_dataset.drop_last,
)
ft_spec = FinetuneSpec(
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, ft_spec)
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = RLVRWorkflow(
reward_fn=boba_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
dump_dir=os.path.join(
StatsLogger.get_log_path(config.stats_logger), "generated"
),
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
max_steps = total_epochs * steps_per_epoch
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = batch.to(actor.device)
# Create barrier to synchronize all rollout processes.
dist.barrier()
torch.cuda.synchronize()
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
with stats_tracker.record_timing("recompute_logp"):
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
log_gpu_stats("recompute logp")
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
log_gpu_stats("compute advantages")
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
path = os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step + 1),
)
meta = WeightUpdateMeta(
type="disk",
path=path,
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
)
if dist.get_rank() == 0:
future = rollout.update_weights(meta)
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
shutil.rmtree(path, ignore_errors=True)
dist.barrier()
torch.cuda.synchronize()
rollout.set_version(global_step + 1)
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
logger.commit(epoch, step, global_step, stats)
logger.close()
rollout.destroy()
if ref is not None:
ref.destroy()
actor.destroy()
if __name__ == "__main__":
main(sys.argv[1:])

View File

@@ -9,7 +9,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
from arealite.api.cli_args import GRPOConfig, load_expr_config
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
@ -117,6 +117,18 @@ def main(args):
ref = FSDPPPOActor(config=config.ref) ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec) ref.initialize(None, ft_spec)
# NOTE: Weight update meta only requires address and free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [
WeightUpdateMeta.from_fsdp_nccl(
AllocationMode.from_str(config.allocation_mode), actor
)
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)

@@ -189,24 +201,16 @@ def main(args):
            log_gpu_stats("ppo update")

        with stats_tracker.record_timing("update_weights"):
-           meta = WeightUpdateMeta(
-               type="disk",
-               path=os.path.join(
-                   Saver.get_save_checkpoint_root(config.saver),
-                   "update_weights",
-                   str(global_step),
-               ),
-               alloc_mode=None,
-               comm_backend=None,
-               model_version=global_step + 1,
-           )
+           rollout.pause()
            if dist.get_rank() == 0:
-               future = rollout.update_weights(meta)
-           actor.upload_weights(meta)
+               future = rollout.update_weights(weight_update_meta)
+           actor.upload_weights(weight_update_meta)
            if dist.get_rank() == 0:
                future.result()
-           dist.barrier()
+           dist.barrier(device_ids=[actor.device.index])
            torch.cuda.synchronize()
+           rollout.resume()
+           actor.set_version(global_step + 1)
            rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):

View File

@@ -1,141 +0,0 @@
experiment_name: lite-boba-math
trial_name: run1
cluster:
n_nodes: 16
n_gpus_per_node: 8
cluster_name: na132
fileroot: /storage/openpsi/experiments
name_resolve:
type: nfs
nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
seed: 1
total_train_epochs: 10
total_train_steps: null
tokenizer_path: ${actor.path}
allocation_mode: sglang.d96p1t1+d32p1t1
async_training: true
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: 400
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 4
enable_rollout_tracing: true
gconfig:
n_samples: 16
min_new_tokens: 0
max_new_tokens: 30720
greedy: false
temperature: 1.0
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 32768
optimizer:
type: adam
lr: 1e-5
weight_decay: 0.01
beta1: 0.9
beta2: 0.999
eps: 1e-8
lr_scheduler_type: constant
gradient_clipping: 1.0
warmup_steps_proportion: 0.001
backend: fsdp
group_size: ${gconfig.n_samples}
group_adv_norm: false
eps_clip: 0.4
temperature: ${gconfig.temperature}
reward_scaling: 10.0
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 4
recompute_logprob: true
use_decoupled_loss: true
behav_imp_weight_cap: 5.0
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 32768
optimizer: null
backend: fsdp
# SGLang
server_only: false
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
# datasets
train_dataset:
batch_size: 512
shuffle: true
pin_memory: true
# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
checkpointer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: null
freq_steps: null
freq_secs: null
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: online
# Launcher
launcher:
inference_server_cpus_per_gpu: 15
inference_server_mem_per_gpu: 153600
trainer_cpus_per_gpu: 15
trainer_mem_per_gpu: 153600
slurm:
mount: /storage:/storage
trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif

View File

@@ -1,8 +1,9 @@
experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
  fileroot: /tmp/arealite/experiments
-  n_nodes: 1
-  n_gpus_per_node: 8

@@ -33,7 +34,7 @@ gconfig:
actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
+  path: Qwen/Qwen2-1.5B-Instruct
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: false

View File

@@ -14,7 +14,7 @@ tokenizer_path: ${model.path}
model:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
+  path: Qwen/Qwen3-1.7B
  init_from_scratch: false
  gradient_checkpointing: false
  dtype: bfloat16

View File

@@ -1,5 +1,4 @@
import os
-import shutil
import sys

import torch

@@ -9,7 +8,7 @@ from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import GRPOConfig, load_expr_config
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.dataset.__init__ import get_custom_dataset
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
@ -95,6 +94,18 @@ def main(args):
ref = FSDPPPOActor(config=config.ref) ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec) ref.initialize(None, ft_spec)
# NOTE: Weight update meta only requires address and free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [
WeightUpdateMeta.from_fsdp_nccl(
AllocationMode.from_str(config.allocation_mode), actor
)
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
# Create rollout workflow # Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
@@ -138,7 +149,7 @@ def main(args):
        batch = batch.to(actor.device)

        # Create barrier to synchronize all rollout processes.
-       dist.barrier()
+       dist.barrier(device_ids=[actor.device.index])
        torch.cuda.synchronize()

        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:

@@ -165,26 +176,16 @@ def main(args):
            log_gpu_stats("ppo update")

        with stats_tracker.record_timing("update_weights"):
-           path = os.path.join(
-               Saver.get_save_checkpoint_root(config.saver),
-               "update_weights",
-               str(global_step + 1),
-           )
-           meta = WeightUpdateMeta(
-               type="disk",
-               path=path,
-               alloc_mode=None,
-               comm_backend=None,
-               model_version=global_step + 1,
-           )
+           rollout.pause()
            if dist.get_rank() == 0:
-               future = rollout.update_weights(meta)
-           actor.upload_weights(meta)
+               future = rollout.update_weights(weight_update_meta)
+           actor.upload_weights(weight_update_meta)
            if dist.get_rank() == 0:
                future.result()
-               shutil.rmtree(path, ignore_errors=True)
-           dist.barrier()
+           dist.barrier(device_ids=[actor.device.index])
            torch.cuda.synchronize()
+           rollout.resume()
+           actor.set_version(global_step + 1)
            rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):

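Taken together, the changes to this training script replace the old disk-based weight sync with an NCCL transfer that is bracketed by pausing and resuming the rollout engine. The following is a minimal consolidation sketch, not new functionality: all names (`WeightUpdateMeta.from_fsdp_nccl`, `AllocationMode`, `rollout.pause/resume/update_weights`, `actor.upload_weights/set_version`) come from the diff above, `config`, `actor`, `rollout`, and `global_step` are assumed to be in scope, and the `stats_tracker` timing context is omitted.

import torch
import torch.distributed as dist
from arealite.api.io_struct import AllocationMode, WeightUpdateMeta

# Create the NCCL weight-update meta on every rank (get_param_specs is collective),
# then keep only rank 0's copy so all ranks share the same address/port.
weight_update_meta = [
    WeightUpdateMeta.from_fsdp_nccl(AllocationMode.from_str(config.allocation_mode), actor)
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]

# Per training step: pause rollout, push weights, resume, bump versions.
rollout.pause()                                            # stop scheduling new generation requests
if dist.get_rank() == 0:
    future = rollout.update_weights(weight_update_meta)    # async request to the inference servers
actor.upload_weights(weight_update_meta)                   # collective NCCL send from all trainer ranks
if dist.get_rank() == 0:
    future.result()                                        # wait until the inference engine has the new weights
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
rollout.resume()                                           # generation continues with the updated policy
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)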
View File

@ -1,11 +0,0 @@
#!/bin/sh
AREAL_PATH=$PWD
cd /sglang
git apply $AREAL_PATH/patch/sglang/v0.4.6.post4.patch
cd $AREAL_PATH
# Package used for calculating math reward
pip install -e evaluation/latex2sympy
# Install AReaL
pip install -e .

View File

@ -1,9 +1,9 @@
#!/bin/bash #!/bin/bash
# basic dependencies # basic dependencies
pip install -U pip pip install -U pip
pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install nvidia-ml-py pip install pynvml nvidia-ml-py
pip install -e evaluation/latex2sympy pip install -e evaluation/latex2sympy
pip install vllm==0.8.5 --no-build-isolation pip install vllm==0.8.5 --no-build-isolation
pip install flash_attn --no-build-isolation pip install "flash-attn<=2.7.3" --no-build-isolation
pip install -r evaluation/requirements.txt pip install -r evaluation/requirements.txt

View File

@ -1,24 +1,14 @@
#!/bin/bash #!/bin/bash
# basic dependencies # basic dependencies
pip install -U pip pip install -U pip
pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.2" pynvml
pip install "sglang[all]==0.4.6.post4" pip install "sglang[all]==0.4.9.post2"
pip install megatron-core==0.11.0 nvidia-ml-py pip install megatron-core==0.11.0 nvidia-ml-py
pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
pip install "flash-attn<=2.7.3" --no-build-isolation pip install "flash-attn<=2.7.3" --no-build-isolation
# Package used for calculating math reward # Package used for calculating math reward
pip install -e evaluation/latex2sympy pip install -e evaluation/latex2sympy
# Install an editable sglang
rm -rf ./sglang
git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang
AREAL_PATH=$PWD
cd sglang
git apply ../patch/sglang/v0.4.6.post4.patch
pip install -e "python[all]" --no-deps
cd $AREAL_PATH
# Install AReaL # Install AReaL
pip install -e . pip install -e .

View File

@ -67,6 +67,7 @@ class InstallationValidator:
def test_flash_attn_functionality(self, flash_attn_module): def test_flash_attn_functionality(self, flash_attn_module):
"""Test flash attention functionality.""" """Test flash attention functionality."""
# Try to import key functions # Try to import key functions
import flash_attn_2_cuda
from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn import flash_attn_func, flash_attn_varlen_func
print(" - Flash attention functions imported successfully") print(" - Flash attention functions imported successfully")
@ -79,12 +80,12 @@ class InstallationValidator:
"""Test SGLang basic functionality.""" """Test SGLang basic functionality."""
# Basic import test is sufficient for CI # Basic import test is sufficient for CI
import sgl_kernel import sgl_kernel
from sglang import launch_server from sglang import Engine, launch_server
assert Version(get_version("sglang")) == Version("0.4.6.post4") assert Version(get_version("sglang")) == Version("0.4.9.post2"), "SGLang version should be v0.4.9.post2"
print(" - SGLang imported successfully") print(" - SGLang imported successfully")
def test_transformers(self, transformers_module): def test_transformers(self, transformers_module):
assert Version(get_version("transformers")) == Version("4.51.1") assert Version(get_version("transformers")) == Version("4.53.1"), "transformers version should be 4.53.1"
print(" - transformers imported successfully") print(" - transformers imported successfully")
def validate_critical_dependencies(self): def validate_critical_dependencies(self):
@ -140,7 +141,7 @@ class InstallationValidator:
self.test_import("flashattn_hopper", required=False) self.test_import("flashattn_hopper", required=False)
# Optional utilities # Optional utilities
self.test_import("tensorboardx", required=False) self.test_import("tensorboardX", required=False)
self.test_import("swanlab", required=False) self.test_import("swanlab", required=False)
self.test_import("matplotlib", required=False) self.test_import("matplotlib", required=False)
self.test_import("seaborn", required=False) self.test_import("seaborn", required=False)

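The validator above pins sglang to 0.4.9.post2 and transformers to 4.53.1. Assuming `get_version` is a thin wrapper around package metadata (its definition is not shown in this hunk), an equivalent standalone check would look like:

# Assumption: get_version() behaves like importlib.metadata.version().
from importlib.metadata import version
from packaging.version import Version

assert Version(version("sglang")) == Version("0.4.9.post2"), "SGLang version should be v0.4.9.post2"
assert Version(version("transformers")) == Version("4.53.1"), "transformers version should be 4.53.1"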
View File

@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 174656b2..33fe0a5f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8891115c..843a8a82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -419,6 +421,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1938,6 +1941,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 82709b09..bfab3ce7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -265,6 +267,9 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -294,6 +299,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -767,6 +776,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Set a break point to wait for the interrupt to finish
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -776,7 +792,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -849,6 +870,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):

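The core of this (now removed) patch is `Scheduler.interrupt_all_requests`: rather than cancelling requests outright, it shrinks each request's `max_new_tokens` so that queued requests generate nothing and running requests stop at their current length on the next scheduling step. Below is a minimal illustration of the same idea; `Req` and `SamplingParams` here are simplified stand-ins for SGLang's real classes, not the patched code itself.

from dataclasses import dataclass, field
from typing import List

@dataclass
class SamplingParams:
    max_new_tokens: int

@dataclass
class Req:
    output_ids: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=lambda: SamplingParams(512))

def interrupt_all(waiting_queue: List[Req], running_reqs: List[Req]) -> int:
    """Mirror of the patched scheduler logic: make every request finish as soon as possible."""
    for req in waiting_queue:
        req.sampling_params.max_new_tokens = 0                     # never starts generating
    for req in running_reqs:
        req.sampling_params.max_new_tokens = len(req.output_ids)   # treated as already complete
    return len(waiting_queue) + len(running_reqs)

# Example: one queued and one half-finished request get interrupted.
print(interrupt_all([Req()], [Req(output_ids=[1, 2, 3])]))  # -> 2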
View File

@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Set a break point to wait for the interrupt to finish
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):

View File

@ -34,7 +34,6 @@ dependencies = [
"transformers==4.53.1", "transformers==4.53.1",
# Scientific computing # Scientific computing
"numpy<2.0.0",
"scipy", "scipy",
"pandas", "pandas",
"matplotlib", "matplotlib",

View File

@ -40,42 +40,11 @@ def execute_shell_command(command: str) -> subprocess.Popen:
) )
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
def launch_server_cmd(command: str, port: int = 30000): def launch_server_cmd(command: str, port: int = 30000):
""" """
Launch the server using the given command. Launch the server using the given command.
If no port is specified, a free port is reserved. If no port is specified, a free port is reserved.
""" """
if not ray.is_initialized():
apply_sglang_patch()
assert port is not None assert port is not None
full_command = f"{command} --port {port}" full_command = f"{command} --port {port}"
process = execute_shell_command(full_command) process = execute_shell_command(full_command)

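With the patch step removed, `launch_server_cmd` simply appends the port to the given command and spawns it. A hedged usage sketch follows; the model path and port are placeholders, and the return value is ignored because the hunk above is truncated before the function's return statement.

# Hypothetical usage of the test utility shown above.
launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen3-1.7B",
    port=30000,
)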
View File

@ -155,39 +155,22 @@ class GserverManager(Worker):
return None return None
Old implementation (removed):

async def flush_requests_and_update_weights(
    self, server_url, new_param_path, update_weights_retries=5
):
    server_index = self.server_urls.index(server_url)
    success = False
    for _ in range(update_weights_retries):
        async with aiohttp.ClientSession(
            server_url,
            timeout=aiohttp.ClientTimeout(
                total=self.config.flush_request_timeout,
                sock_connect=self.config.flush_request_timeout,
            ),
        ) as session:
            async with session.post(
                f"/update_weights_from_disk",
                json=dict(model_path=new_param_path, allow_interrupt=True),
            ) as resp:
                if resp.status == 200:
                    res = await resp.json()
                    success = res["success"]
                    if success:
                        if "num_paused_requests" in res:
                            logger.info(
                                f"{res['num_paused_requests']} requests are interrupted "
                                f"during updating weights for server {server_index}: {server_url}"
                            )
                        return
                    logger.warning(
                        f"Update weights failed: {res['message']}. Retrying."
                    )
                logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
        time.sleep(0.1)
    raise RuntimeError("Update weights failed.")

New implementation (added):

async def flush_requests_and_update_weights(self, server_url, new_param_path):
    async with aiohttp.ClientSession(
        server_url,
        timeout=aiohttp.ClientTimeout(
            total=self.config.flush_request_timeout,
            sock_connect=self.config.flush_request_timeout,
        ),
    ) as session:
        (await session.post("/pause_generation")).raise_for_status()
        async with session.post(
            f"/update_weights_from_disk",
            json=dict(model_path=new_param_path),
        ) as resp:
            resp.raise_for_status()
            assert (await resp.json())["success"]
        (await session.post("/continue_generation")).raise_for_status()
def _round_robin_schedule(self, req_meta: GenReqMeta) -> int: def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
if not hasattr(self, "round_robin_idx"): if not hasattr(self, "round_robin_idx"):

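The new implementation drops the retry loop in favor of a strict three-endpoint sequence: pause generation, load the new weights from disk, then resume generation, with any HTTP error or unsuccessful update surfacing immediately via `raise_for_status` and the assertion. A standalone sketch of the same call sequence; the server URL and checkpoint path are placeholders, and the endpoints are those used in the code above.

import asyncio
import aiohttp

async def sync_weights(server_url: str, new_param_path: str) -> None:
    # Same endpoint sequence as the new GserverManager code above.
    async with aiohttp.ClientSession(server_url) as session:
        (await session.post("/pause_generation")).raise_for_status()
        async with session.post(
            "/update_weights_from_disk", json={"model_path": new_param_path}
        ) as resp:
            resp.raise_for_status()
            assert (await resp.json())["success"]
        (await session.post("/continue_generation")).raise_for_status()

# Placeholder URL and path for illustration only.
asyncio.run(sync_weights("http://localhost:30000", "/path/to/new/checkpoint"))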
View File

@ -25,7 +25,6 @@ numba
packaging packaging
pandas pandas
pybind11>=2.10.0 pybind11>=2.10.0
numpy<2.0.0
psutil psutil
pynvml pynvml
pytest pytest