mirror of https://github.com/inclusionAI/AReaL
[lite] [feature] Bump to SGLang v0.4.9.post2 and use NCCL to update weights (#196)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine (https://code.alipay.com/inclusionAI/AReaL/pull_requests/353). Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in the FSDP engine (https://code.alipay.com/inclusionAI/AReaL/pull_requests/354). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 355 [Lite] GRPO pre-commit 2: refactor the RemoteSGLangEngine thread and SGLang configuration (https://code.alipay.com/inclusionAI/AReaL/pull_requests/355). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 357 [lite] GRPO pre-commit 3: fix typos and experiment utilities; fix destroying the process group (https://code.alipay.com/inclusionAI/AReaL/pull_requests/357). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset; fix the loss mask (https://code.alipay.com/inclusionAI/AReaL/pull_requests/358). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 368 [lite] Refactor the train engine after merging contributions from GitHub (https://code.alipay.com/inclusionAI/AReaL/pull_requests/368). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 371 [lite] [fix] Fix miscellaneous bugs in the GRPO implementation (https://code.alipay.com/inclusionAI/AReaL/pull_requests/371). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 370 [lite] Add a Slurm launcher and a Ray launcher (https://code.alipay.com/inclusionAI/AReaL/pull_requests/370). Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results (https://code.alipay.com/inclusionAI/AReaL/pull_requests/392). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>. Intermediate work squashed here: support the FSDP engine and the remote SGLang engine, refactor the trainer and train engine, get Qwen2 GRPO and async training running, add a Slurm launcher and an SGLang server wrapper, run 32k-context experiments, and fix argument parsing and weight-update errors.
* PullRequest: 408 [Feature] Bump the SGLang version to v0.4.9.post2 (https://code.alipay.com/inclusionAI/AReaL/pull_requests/408). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`
* PullRequest: 422 [lite] Fix tests and scripts after updating SGLang to 0.4.9 (https://code.alipay.com/inclusionAI/AReaL/pull_requests/422). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 423 [lite] Remove the boba example for the GitHub release (https://code.alipay.com/inclusionAI/AReaL/pull_requests/423). Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Parent: 6239633213
Commit: 311bcd7697
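The headline change replaces disk round-trips with NCCL broadcasts for pushing trainer weights to the SGLang servers. Below is a minimal sketch of the new calling convention, distilled from the updated GSM8k example later in this diff; `push_weights_nccl` is a hypothetical helper name, and `config`, `actor` (an `FSDPPPOActor`), and `rollout` (a `RemoteSGLangEngine`) are assumed to be set up as in that script:

```python
import torch.distributed as dist

from arealite.api.io_struct import AllocationMode, WeightUpdateMeta


def push_weights_nccl(config, actor, rollout, version: int):
    # Hypothetical helper. from_fsdp_nccl must run on every trainer rank
    # (it calls engine.get_param_specs()), but only rank 0's host/port matter,
    # so broadcast rank 0's meta to all ranks.
    meta = [
        WeightUpdateMeta.from_fsdp_nccl(
            AllocationMode.from_str(config.allocation_mode), actor
        )
    ]
    dist.broadcast_object_list(meta, src=0)
    meta = meta[0]

    rollout.pause()  # stop issuing generation requests during the update
    if dist.get_rank() == 0:
        # Non-blocking: servers join the update group and receive the broadcast.
        future = rollout.update_weights(meta)
    actor.upload_weights(meta)  # blocking NCCL broadcast on the trainer side
    if dist.get_rank() == 0:
        future.result()
    dist.barrier(device_ids=[actor.device.index])
    rollout.resume()
    rollout.set_version(version)
```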
@@ -163,7 +163,6 @@ class WeightUpdateMeta:
     type: str
     path: str | None
     alloc_mode: AllocationMode | None
-    comm_backend: str | None

 @dataclass
 class SaveLoadMeta:
@@ -285,85 +285,6 @@ class PPOActorConfig(TrainEngineConfig):
     )


-@dataclass
-class PPOActorConfig(TrainEngineConfig):
-    # Core PPO/GRPO Parameters
-    group_size: int = field(
-        default=1, metadata={"help": "Number of sequences in each group"}
-    )
-    group_adv_norm: bool = field(
-        default=False,
-        metadata={
-            "help": "Normalize advantages within each prompt group rather than globally"
-        },
-    )
-    ppo_n_minibatches: int = field(
-        default=4, metadata={"help": "Number of minibatches for each PPO update"}
-    )
-    eps_clip: float = field(
-        default=0.2, metadata={"help": "Clipping factor for policy ratio"}
-    )
-    c_clip: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
-        },
-    )
-    temperature: float = field(
-        default=1.0, metadata={"help": "Temperature during generation."}
-    )
-    # Reward
-    group_reward_norm: bool = field(
-        default=False,
-        metadata={
-            "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
-        },
-    )
-    reward_scaling: float = field(
-        default=1.0, metadata={"help": "Reward scaling factor"}
-    )
-    reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
-    reward_clip: float = field(
-        default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
-    )
-    mask_no_eos_with_zero: bool = field(
-        default=False,
-        metadata={
-            "help": "Mask truncated generations (no EOS token) and exclude from training"
-        },
-    )
-
-    # Advantage Estimation
-    discount: float = field(
-        default=1.0, metadata={"help": "Discount factor for future rewards"}
-    )
-    gae_lambda: float = field(
-        default=1.0, metadata={"help": "Lambda parameter for GAE"}
-    )
-    adv_norm: bool = field(
-        default=True, metadata={"help": "Enable advantage normalization"}
-    )
-
-    # KL Control
-    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
-
-    # Asynchronous RL
-    recompute_logprob: bool = field(
-        default=False,
-        metadata={"help": "Recompute logp and replace the logp returned by inference."},
-    )
-    use_decoupled_loss: bool = field(
-        default=False,
-        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
-    )
-    behav_imp_weight_cap: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
-        },
-    )
-
-
 @dataclass
 class SGLangConfig:
     """Configuration for SGLang runtime. Refer to:
@@ -12,6 +12,7 @@ from arealite.api.io_struct import (
     FinetuneSpec,
     LLMRequest,
     LLMResponse,
+    ParamSpec,
     SaveLoadMeta,
     WeightUpdateMeta,
 )
@@ -63,7 +64,19 @@ class TrainEngine(abc.ABC):
         return self.train(False)

     def upload_weights(self, meta: WeightUpdateMeta):
-        """Upload weights to the inference engine."""
+        """Upload weights to the inference engine (in a blocking manner)."""
         raise NotImplementedError()

+    def get_param_specs(self) -> List[ParamSpec]:
+        """Get the parameter specifications for the model."""
+        raise NotImplementedError()
+
+    def set_version(self, version: int):
+        """Set the current weight version in the train engine."""
+        raise NotImplementedError()
+
+    def get_version(self) -> int:
+        """Get the current weight version in the train engine."""
+        raise NotImplementedError()
+
     def save(self, meta: SaveLoadMeta):
@@ -122,14 +135,22 @@ class InferenceEngine(abc.ABC):
     def destroy(self):
         """Destroy the engine and release GPU memory."""

-    def update_weights(self, meta: WeightUpdateMeta) -> Future:
-        """Update weights in the inference engine."""
-        raise NotImplementedError()
-
     async def agenerate(self, req: LLMRequest) -> LLMResponse:
         """Asynchronously generate a response for the given request."""
         raise NotImplementedError()

+    def update_weights(self, meta: WeightUpdateMeta) -> Future:
+        """Update weights in the inference engine in a non-blocking manner."""
+        raise NotImplementedError()
+
+    def set_version(self, version: int) -> None:
+        """Set the current weight version in the inference engine."""
+        raise NotImplementedError()
+
+    def get_version(self) -> int:
+        """Get the current weight version in the inference engine."""
+        raise NotImplementedError()
+
     def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
         """Asynchronously submit a request to the inference engine. Exits immediately."""
         raise NotImplementedError()
@@ -2,14 +2,19 @@
 # Licensed under the Apache License, Version 2.0
 import enum
 import itertools
 import os
 import re
 import uuid
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple

 from transformers import PreTrainedTokenizerFast

-from arealite.api.cli_args import GenerationHyperparameters
+from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
+from arealite.utils.network import find_free_ports, gethostip
+
+if TYPE_CHECKING:
+    from arealite.api.engine_api import TrainEngine


 @dataclass
@@ -155,20 +160,55 @@ class AllocationMode:
         return other_alloc


+@dataclass
+class ParamSpec:
+    name: str
+    shape: Tuple
+    dtype: str
+
+
 @dataclass
 class WeightUpdateMeta:
-    type: str
-    path: str | None
-    alloc_mode: AllocationMode | None
-    comm_backend: str | None
-    model_version: int = 0
-    tp_size: int = 1
-    master_address: str = "127.0.0.1"
-    master_port: int = 29500
-    world_size: int = 1
-    group_name: str = "aupdate_weights_from_distributed"
-    parameter_names: List[str] = field(default_factory=list)
-    state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)
+    type: Literal["disk", "nccl"]
+    path: str | None = None
+    alloc_mode: AllocationMode | None = None
+
+    nccl_master_address: str = "127.0.0.1"
+    nccl_master_port: int = 29500
+    nccl_param_specs: List[ParamSpec] = field(default_factory=list)
+    nccl_group_name: str = "update_weight_group"
+
+    @classmethod
+    def from_disk(
+        cls,
+        saver_config: SaverConfig,
+    ) -> "WeightUpdateMeta":
+        from arealite.utils.saver import Saver
+
+        path = os.path.join(
+            Saver.get_save_checkpoint_root(saver_config),
+            "weight_update",
+        )
+        return cls(
+            type="disk",
+            path=path,
+        )
+
+    @classmethod
+    def from_fsdp_nccl(
+        cls,
+        allocation_mode: AllocationMode,
+        fsdp_engine: "TrainEngine",
+        nccl_group_name: str = "update_weight_group",
+    ):
+        return cls(
+            type="nccl",
+            alloc_mode=allocation_mode,
+            nccl_master_address=gethostip(),
+            nccl_master_port=find_free_ports(1)[0],
+            nccl_param_specs=fsdp_engine.get_param_specs(),
+            nccl_group_name=nccl_group_name,
+        )


 @dataclass
@@ -46,6 +46,7 @@ class BaseHFEngine(TrainEngine):
         self.tokenizer: PreTrainedTokenizerFast
         # huggingface model config
         self.model_config: PretrainedConfig
+        self._version: int = 0

         # initialization
         self.initialized = False
@@ -55,6 +56,12 @@ class BaseHFEngine(TrainEngine):

         self.world_size = int(os.environ["WORLD_SIZE"])

+    def set_version(self, version: int):
+        self._version = version
+
+    def get_version(self) -> int:
+        return self._version
+
     def train(self, mode: bool = True):
         assert self.model is not None
         self.model.train(mode=mode)
@@ -191,7 +198,7 @@ class BaseHFEngine(TrainEngine):
         )
         state_dict = self.optimizer.state_dict()
         torch.save(state_dict, shard_path)
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])

     def load_optimizer_state(self, path: str):
         # Load FSDP sharded state dict
@@ -203,7 +210,7 @@ class BaseHFEngine(TrainEngine):
         )
         optimizer_state_dict = torch.load(shard_path, weights_only=False)
         self.optimizer.load_state_dict(optimizer_state_dict)
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])

     def step_lr_scheduler(self):
         assert self.lr_scheduler is not None
@@ -1,7 +1,7 @@
 import os
 import time
 from datetime import datetime
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
@@ -14,7 +14,8 @@ from torch.distributed.checkpoint.state_dict import (
 from transformers import PreTrainedTokenizerFast

 from arealite.api.cli_args import TrainEngineConfig
-from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.api.engine_api import FinetuneSpec
+from arealite.api.io_struct import ParamSpec, SaveLoadMeta, WeightUpdateMeta
 from arealite.engine.base_hf_engine import BaseHFEngine
 from arealite.utils.distributed import init_custom_process_group
 from arealite.utils.fsdp import (
@@ -119,7 +120,7 @@ class FSDPEngine(BaseHFEngine):
         if tokenizer is not None:
             tokenizer.save_pretrained(path)

-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])

     def _load_model_from_hf(self, path: str):
         """Load model from HuggingFace format."""
@@ -140,7 +141,7 @@ class FSDPEngine(BaseHFEngine):
             if not self.weight_update_group_initialized:
                 self._init_distributed_weight_update(meta)
             self._update_weights_from_distributed()
-            dist.barrier()
+            dist.barrier(device_ids=[self.device.index])
             torch.cuda.synchronize()
         elif meta.type == "disk":
             self._save_model_to_hf(meta.path, self.tokenizer)
@@ -149,7 +150,7 @@ class FSDPEngine(BaseHFEngine):
             update_name = names.update_weights_from_disk(
                 self.config.experiment_name,
                 self.config.trial_name,
-                meta.model_version,
+                self.model_version,
             )
             name_resolve.add(
                 update_name, str(datetime.now().timestamp()), keepalive_ttl=120
@@ -158,16 +159,18 @@ class FSDPEngine(BaseHFEngine):
             raise ValueError(f"Unknown weight update type {meta.type}")

     def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
+        # NOTE: Processes launched with torchrun will set the following env var to True,
+        # which blocks creating another TCP store for weight update.
+        os.environ["TORCHELASTIC_USE_AGENT_STORE"] = str(False)
         if dist.get_rank() == 0:
             self.weight_update_group = init_custom_process_group(
                 backend="nccl",
-                world_size=meta.world_size,
-                init_method=f"tcp://{meta.master_address}:{meta.master_port}",
+                world_size=meta.alloc_mode.gen_world_size + 1,
+                init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}",
                 rank=0,
-                group_name=meta.group_name,
+                group_name=meta.nccl_group_name,
             )
-            # NOTE: synchronizing with sglang's barrier
-            dist.barrier(group=self.weight_update_group, device_ids=[self.device.index])
+            # NOTE: sglang v0.4.9.post2 or later does not have the barrier call
         self.weight_update_group_initialized = True

     def _update_weights_from_distributed(self):
@@ -179,23 +182,29 @@ class FSDPEngine(BaseHFEngine):
             else:
                 tensor = param.data
             if dist.get_rank() == 0:
-                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
+                logger.debug(
+                    f"Broadcasting {name} with shape {tensor.shape}", flush=True
+                )
                 dist.broadcast(tensor, src=0, group=self.weight_update_group)
-        dist.barrier()
+            del tensor  # optional, for memory hygiene
         torch.cuda.empty_cache()

-    def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
-        """Return a dict mapping param name to its shape (expanded if DTensor)."""
-        param_shapes = {}
+    def get_param_specs(self) -> List[ParamSpec]:
+        param_specs = []
         for name, param in self.model.named_parameters():
             if isinstance(param.data, DTensor):
                 tensor = param.data.full_tensor()
             else:
                 tensor = param.data
-            param_shapes[name] = tuple(tensor.shape)
+            param_specs.append(
+                ParamSpec(
+                    name=name,
+                    shape=tuple(tensor.shape),
+                    dtype=str(tensor.dtype).split("torch.")[1],
+                )
+            )
             del tensor  # free memory if full_tensor was created
-        return param_shapes
+        return param_specs

     def train_batch(
         self,
@@ -1,10 +1,11 @@
 import asyncio
 import os
 import random
+import shutil
 import threading
 import time
 import traceback
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from concurrent.futures import Future, ProcessPoolExecutor
 from datetime import datetime
 from queue import Empty, Full, Queue
 from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -27,17 +28,12 @@ from arealite.api.io_struct import (
 )
 from arealite.utils.data import concat_padded_tensors
 from arealite.utils.http import arequest_with_retry, get_default_connector
-from realhf.base import logging, name_resolve, names, pkg_version
+from realhf.base import logging, name_resolve, names

 if TYPE_CHECKING:
     from arealite.api.workflow_api import RolloutWorkflow
 logger = logging.getLogger(__name__)

-if pkg_version.is_available("sglang"):
-    if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
-        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
-    else:
-        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
-
 ROLLOUT_POLL_WAIT_TIME = 0.05
 RID_CACHE_SIZE = 128
@@ -91,10 +87,7 @@ class RemoteSGLangEngine(InferenceEngine):
     def check_health(self, base_url):
         # Check server endpoint
         try:
-            response = requests.get(
-                f"{base_url}/metrics",
-                timeout=30,
-            )
+            response = requests.get(f"{base_url}/health", timeout=30)
             return response.status_code == 200
         except requests.exceptions.RequestException as e:
             return False
@@ -123,7 +116,7 @@ class RemoteSGLangEngine(InferenceEngine):
         """Thread that runs the rollout loop."""
         try:
             uvloop.run(self._rollout_thread_async())
-        except Exception as e:
+        except Exception:
             traceback.print_exc()

     async def _rollout_thread_async(self):
@@ -259,9 +252,7 @@ class RemoteSGLangEngine(InferenceEngine):
         accumulated_output_logprobs = []
         accumulated_versions = []

-        # Deal with rollout interruption
-        stop_reason = "length"
-
         # A single "rid" shares the same server to allow KV cache reuse
         if req.rid in self.rid_to_address:
             server_addr = self.rid_to_address[req.rid]
         else:
@@ -273,10 +264,19 @@ class RemoteSGLangEngine(InferenceEngine):
             self.rid_to_address[req.rid] = server_addr
             self.rid_queue.append(req.rid)

+        # Deal with rollout interruption
+        # "abort" is the stop reason for later v0.4.9.post2 after
+        # we call the pause_generation endpoint
+        stop_reason = None
         while (
             stop_reason != "stop"
             and len(accumulated_output_tokens) < gconfig.max_new_tokens
         ):
+            # Request is interrupted, wait for some time to avoid interfering
+            # with update weights requests
+            if stop_reason is not None:
+                await asyncio.sleep(0.5)
+
             # loop until the generation is complete
             result = await arequest_with_retry(
                 session=self.session,
@@ -288,8 +288,17 @@ class RemoteSGLangEngine(InferenceEngine):
                 timeout=self.config.request_timeout,
             )

-            # Parse response
             meta_info = result["meta_info"]
+            # Check if generation is complete
+            finish_reason = meta_info["finish_reason"]
+            stop_reason = finish_reason["type"]
+            if (
+                stop_reason == "abort"
+                and finish_reason.get("message") == "Abort before prefill"
+            ):
+                continue
+
+            # Parse response
             output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
             output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
@@ -299,11 +308,7 @@ class RemoteSGLangEngine(InferenceEngine):
             # FIXME: Update with actual server versions
             accumulated_versions.extend([-1] * len(output_tokens))

-            # Check if generation is complete
-            finish_reason = meta_info["finish_reason"]
-            stop_reason = finish_reason["type"]
-
-            payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
+            payload["input_ids"] += result["output_ids"]
             sample_params["max_new_tokens"] -= len(output_tokens)

         latency = time.perf_counter() - start_time
@@ -318,30 +323,24 @@ class RemoteSGLangEngine(InferenceEngine):
             ttft=latency,  # Simplified for non-streaming
         )

-    def update_weights(self, meta):
-        executor = ThreadPoolExecutor(max_workers=1)
-        return executor.submit(self._update_weights, meta)
-
-    def _update_weights(self, meta: WeightUpdateMeta):
+    def update_weights(self, meta: WeightUpdateMeta):
+        for addr in self.addresses:
+            res = requests.post(f"http://{addr}/pause_generation")
+            res.raise_for_status()
+        fut = Future()
         if meta.type == "nccl":
-            if not self.distributed_weight_update_initialized:
-                self._init_distributed_weight_update(meta)
-            tik = time.perf_counter()
-            try:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-                for param in meta.parameter_names:
-                    jobs = [
-                        self.aupdate_weights_from_distributed(addr, meta, param)
-                        for addr in self.addresses
-                    ]
-                    loop.run_until_complete(asyncio.gather(*jobs))
-            finally:
-                loop.close()
-            logger.info(
-                f"Distributed update weights done in {time.perf_counter() - tik}s"
-            )
-            self.set_version(meta.model_version)
+            fut = self.executor.submit(
+                update_weights_from_distributed,
+                meta,
+                self.addresses,
+                self.config.request_timeout,
+                not self.distributed_weight_update_initialized,
+            )
+
+            def callback(fut):
+                self.distributed_weight_update_initialized = True
+
+            fut.add_done_callback(callback)
         elif meta.type == "disk":
             # Update weights from disk
             # Use ProcessPool to bypass python GIL for running async coroutines
@@ -349,7 +348,7 @@ class RemoteSGLangEngine(InferenceEngine):
                 update_weights_from_disk,
                 self.config.experiment_name,
                 self.config.trial_name,
-                meta.model_version,
+                self.get_version(),
                 self.addresses,
                 meta.path,
                 self.config.request_retries,
@@ -357,64 +356,19 @@ class RemoteSGLangEngine(InferenceEngine):
             )

-            def callback(fut):
-                self.set_version(meta.model_version)
-                shutil.rmtree(meta.path, ignore_errors=True)
-
-            fut.add_done_callback(callback)
-            return fut
         else:
             raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

-    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
-        try:
-            # Initialize weights update group
-            jobs = [
-                self.ainit_weights_update_group(addr, meta) for addr in self.addresses
-            ]
-            loop = asyncio.new_event_loop()
-            # asyncio event loop should be manually set when running asyncio stuff in another thread
-            asyncio.set_event_loop(loop)
-            loop.run_until_complete(asyncio.gather(*jobs))
-            self.distributed_weight_update_initialized = True
-            logger.info(f"Distributed update weights initialized")
-        finally:
-            loop.close()
-
-    async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
-        rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
-        payload = {
-            "master_address": meta.master_address,
-            "master_port": str(meta.master_port),
-            "rank_offset": rank_offset,
-            "world_size": meta.world_size,
-            "group_name": meta.group_name,
-            "backend": "nccl",
-        }
-        res = await arequest_with_retry(
-            addr=addr,
-            endpoint="/init_weights_update_group",
-            payload=payload,
-            method="POST",
-            max_retries=1,
-            timeout=self.config.request_timeout,
-        )
-        assert res["success"]
-
-    async def aupdate_weights_from_distributed(
-        self, addr: str, meta: WeightUpdateMeta, parameter_name: str
-    ):
-        res = await arequest_with_retry(
-            addr=addr,
-            endpoint="/update_weights_from_distributed",
-            payload={
-                "name": parameter_name,
-                "dtype": "bfloat16",
-                "shape": meta.state_dict_key_to_shape[parameter_name],
-            },
-            method="POST",
-            max_retries=1,
-            timeout=self.config.request_timeout,
-        )
-        assert res["success"]
+        def callback(fut):
+            for addr in self.addresses:
+                res = requests.post(f"http://{addr}/continue_generation")
+                res.raise_for_status()
+
+        fut.add_done_callback(callback)
+        return fut

     def get_capacity(self):
         if dist.is_initialized():
@@ -519,27 +473,6 @@ class RemoteSGLangEngine(InferenceEngine):
         self.paused.clear()


-async def aupdate_weights_from_disk(
-    session, addr, path: str, request_retries: int, request_timeout: float
-):
-    tik = time.time()
-    res = await arequest_with_retry(
-        addr=addr,
-        session=session,
-        endpoint="/update_weights_from_disk",
-        payload=dict(model_path=str(path), allow_interrupt=True),
-        method="POST",
-        max_retries=request_retries,
-        timeout=request_timeout,
-    )
-    assert res["success"]
-    if "num_paused_requests" in res:
-        logger.info(
-            f"{res['num_paused_requests']} requests are interrupted "
-            f"during updating weights for server {addr}"
-        )
-
-
 def update_weights_from_disk(
     experiment_name,
     trial_name,
@@ -569,12 +502,14 @@ def update_weights_from_disk(
             connector=get_default_connector(),
         )
         jobs = [
-            aupdate_weights_from_disk(
-                session=session,
+            arequest_with_retry(
                 addr=addr,
-                path=path,
-                request_retries=request_retries,
-                request_timeout=request_timeout,
+                session=session,
+                endpoint="/update_weights_from_disk",
+                payload=dict(model_path=str(path)),
+                method="POST",
+                max_retries=request_retries,
+                timeout=request_timeout,
             )
             for addr in addresses
         ]
@@ -585,3 +520,72 @@ def update_weights_from_disk(
         )

     return uvloop.run(_fn())
+
+
+def update_weights_from_distributed(
+    meta: WeightUpdateMeta,
+    addresses: List[str],
+    request_timeout,
+    init_group: bool,
+):
+    async def _fn():
+        tik = time.perf_counter()
+        if init_group:
+            await asyncio.gather(
+                *[
+                    ainit_weights_update_group(addr, i, meta, request_timeout)
+                    for i, addr in enumerate(addresses)
+                ]
+            )
+        await asyncio.gather(
+            *[
+                arequest_with_retry(
+                    addr=addr,
+                    endpoint="/update_weights_from_distributed",
+                    payload={
+                        "names": [pspec.name for pspec in meta.nccl_param_specs],
+                        "dtypes": [pspec.dtype for pspec in meta.nccl_param_specs],
+                        "shapes": [pspec.shape for pspec in meta.nccl_param_specs],
+                        "group_name": meta.nccl_group_name,
+                    },
+                    method="POST",
+                    max_retries=1,
+                    timeout=request_timeout,
+                )
+                for addr in addresses
+            ]
+        )
+        logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
+
+    return uvloop.run(_fn())
+
+
+async def ainit_weights_update_group(
+    addr: str,
+    server_idx: int,
+    meta: WeightUpdateMeta,
+    request_timeout: float,
+):
+    assert meta.alloc_mode is not None
+    if meta.alloc_mode.gen_pp_size != 1:
+        raise NotImplementedError(
+            "NCCL weight update with PP size > 1 is not implemented yet."
+        )
+    rank_offset = 1 + server_idx * meta.alloc_mode.gen_tp_size
+    payload = {
+        "master_address": meta.nccl_master_address,
+        "master_port": str(meta.nccl_master_port),
+        "rank_offset": rank_offset,
+        "world_size": meta.alloc_mode.gen_world_size + 1,
+        "backend": "nccl",
+        "group_name": meta.nccl_group_name,
+    }
+    res = await arequest_with_retry(
+        addr=addr,
+        endpoint="/init_weights_update_group",
+        payload=payload,
+        method="POST",
+        max_retries=1,
+        timeout=request_timeout,
+    )
+    assert res["success"]
@@ -283,7 +283,7 @@ def main_local():
     if not cfg.server_only:
         launcher.submit(
             job_name="trainer",
-            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
+            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}",
             gpu=alloc_mode.train_world_size,
             env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
         )
@@ -3,10 +3,8 @@ import subprocess
 import sys
 import time
 import uuid
-from pathlib import Path
 from typing import Optional

-import ray
 import requests

 from arealite.api.cli_args import (
@@ -19,12 +17,12 @@ from arealite.api.cli_args import (
 from arealite.api.io_struct import AllocationMode, AllocationType
 from arealite.utils.launcher import TRITON_CACHE_PATH
 from arealite.utils.network import find_free_ports, gethostip
-from realhf.base import logging, name_resolve, names, pkg_version
+from realhf.base import logging, name_resolve, names

 logger = logging.getLogger("SGLangServer Wrapper")


-def execute_shell_command(command: str) -> subprocess.Popen:
+def launch_server_cmd(command: str) -> subprocess.Popen:
     """
     Execute a shell command and return its process handle.
     """
@@ -45,46 +43,6 @@ def execute_shell_command(command: str) -> subprocess.Popen:
     )


-def apply_sglang_patch():
-    p = Path(os.path.dirname(__file__))
-    patch_path = str(
-        p.parent.parent
-        / "patch"
-        / "sglang"
-        / f"v{pkg_version.get_version('sglang')}.patch"
-    )
-
-    target_path = ""
-    sglang_meta = subprocess.check_output(
-        "python3 -m pip show sglang", shell=True
-    ).decode("ascii")
-    for line in sglang_meta.split("\n"):
-        line = line.strip()
-        if line.startswith("Editable project location: "):
-            target_path = str(Path(line.split(": ")[1]).parent)
-
-    if target_path:
-        proc = subprocess.Popen(
-            ["git", "apply", patch_path],
-            cwd=target_path,
-            stderr=sys.stdout,
-            stdout=sys.stdout,
-        )
-        proc.wait()
-        logger.info(f"Applied SGLang patch at {target_path}")
-
-
-def launch_server_cmd(command: str):
-    """
-    Launch the server using the given command.
-    If no port is specified, a free port is reserved.
-    """
-    if not ray.is_initialized():
-        apply_sglang_patch()
-    process = execute_shell_command(command)
-    return process
-
-
 def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
     """Wait for the server to be ready by polling the /v1/models endpoint.
@@ -12,10 +12,9 @@ from arealite.api.cli_args import (
     SGLangConfig,
     TrainEngineConfig,
 )
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
 from arealite.engine.fsdp_engine import FSDPEngine
 from arealite.engine.sglang_remote import RemoteSGLangEngine
-from arealite.utils.network import find_free_ports
 from realhf.base import network

 EXPR_NAME = "test_fsdp_engine_nccl"
@@ -33,7 +32,7 @@ RUN_SERVER_TIMEOUT = 180

 def check_server_health(base_url):
     try:
-        response = requests.get(f"{base_url}/metrics", timeout=30)
+        response = requests.get(f"{base_url}/health", timeout=30)
         return response.status_code == 200
     except requests.exceptions.RequestException:
         return False
@@ -82,14 +81,14 @@ def sglang_server_nccl():


 def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
-    # Set distributed environment variables
+    # Set environment variables for torch distributed
     os.environ["WORLD_SIZE"] = "1"
     os.environ["RANK"] = "0"
     os.environ["LOCAL_RANK"] = "0"
     os.environ["MASTER_ADDR"] = HOST
     os.environ["MASTER_PORT"] = str(MASTER_PORT)

-    # Launch the local FSDPEngine
+    # Initialize FSDPEngine
     engine_config = TrainEngineConfig(
         experiment_name=EXPR_NAME,
         trial_name=TRIAL_NAME,
@@ -100,38 +99,26 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server
     ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
     engine.initialize(None, ft_spec)

-    # Launch the remote RemoteSGLangEngine
+    # Initialize RemoteSGLangEngine
     config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
     config.server_addrs = [f"{HOST}:{PORT}"]
     remote_engine = RemoteSGLangEngine(config)
     remote_engine.initialize(None, None)

-    # Construct WeightUpdateMeta (type=nccl)
-    param_meta = engine.get_param_meta_for_distributed_update()
-    meta = WeightUpdateMeta(
-        type="nccl",
-        path=None,
-        alloc_mode=None,
-        comm_backend="nccl",
-        model_version=123,
-        tp_size=1,
-        master_address="localhost",
-        master_port=find_free_ports(1)[0],
-        world_size=2,
-        group_name=GROUP_NAME,
-        parameter_names=list(param_meta.keys()),
-        state_dict_key_to_shape=param_meta,
+    # Get WeightUpdateMeta
+    meta = WeightUpdateMeta.from_fsdp_nccl(
+        AllocationMode.from_str("sglang.d1p1t1+d1p1t1"),
+        engine,
+        nccl_group_name=GROUP_NAME,
     )

-    # The local engine broadcasts the parameters
+    # Broadcast weights
     future = remote_engine.update_weights(meta)
     print("got future", flush=True)
     engine.upload_weights(meta)
     print("uploaded weights to remote engine", flush=True)
-    # The remote engine pulls the parameters
+    # Wait for remote engine to finish
     future.result(timeout=120)
     print("got result", flush=True)
-    # Check the remote parameter version
     assert remote_engine.get_version() == 123
     remote_engine.destroy()
     engine.destroy()
@@ -2,7 +2,6 @@ import os
 import subprocess
 import sys
 import time
 import uuid

 import pytest
 import requests
@@ -14,7 +13,7 @@ from arealite.api.cli_args import (
     InferenceEngineConfig,
     SGLangConfig,
 )
-from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
+from arealite.api.io_struct import WeightUpdateMeta
 from arealite.utils import network
 from realhf.api.core.data_api import load_hf_tokenizer
@@ -31,10 +30,7 @@ RUN_SERVER_TIMEOUT = 180

 def check_server_health(base_url):
     try:
-        response = requests.get(
-            f"{base_url}/metrics",
-            timeout=30,
-        )
+        response = requests.get(f"{base_url}/health", timeout=30)
         return response.status_code == 200
     except requests.exceptions.RequestException as e:
         return False
@@ -77,29 +73,6 @@ def sglang_server():
         process.terminate()


-@pytest.mark.asyncio
-async def test_remote_sglang_generate(sglang_server):
-    from arealite.engine.sglang_remote import RemoteSGLangEngine
-
-    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
-    tokenizer = load_hf_tokenizer(MODEL_PATH)
-    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
-    engine = RemoteSGLangEngine(config)
-    req = LLMRequest(
-        rid=str(uuid.uuid4()),
-        input_ids=tokenizer.encode("hello! how are you today"),
-        gconfig=GenerationHyperparameters(max_new_tokens=16),
-    )
-    resp = await engine.agenerate(req)
-    assert isinstance(resp, LLMResponse)
-    assert resp.input_tokens == req.input_ids
-    assert (
-        len(resp.output_logprobs)
-        == len(resp.output_tokens)
-        == len(resp.output_versions)
-    )
-
-
 @pytest.mark.parametrize("n_samples", [1, 2, 4])
 def test_remote_sglang_rollout(sglang_server, n_samples):
     from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -211,6 +184,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
     engine = FSDPEngine(engine_config)
     ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
     engine.initialize(None, ft_spec)
+    engine.model_version = 100

     # setup name resolve
     import realhf.base.name_resolve as name_resolve
@@ -227,13 +201,11 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
     os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
     inf_engine = RemoteSGLangEngine(config)
     inf_engine.initialize(None, None)
+    inf_engine.set_version(100)
     # test update weights
     path = tmp_path_factory.mktemp("upload_weights_from_disk")
-    update_weight_meta = WeightUpdateMeta(
-        type="disk", path=path, alloc_mode=None, comm_backend=None, model_version=100
-    )
+    update_weight_meta = WeightUpdateMeta(type="disk", path=str(path))
     future = inf_engine.update_weights(update_weight_meta)
     engine.upload_weights(update_weight_meta)
     future.result()
     assert inf_engine.get_version() == 100
     inf_engine.destroy()
@@ -1,310 +0,0 @@
-import asyncio
-import os
-import shutil
-import sys
-import uuid
-
-import colorama
-import torch
-import torch.distributed as dist
-from datasets import load_dataset
-from datasets.distributed import split_dataset_by_node
-from tensordict import TensorDict
-from torchdata.stateful_dataloader import StatefulDataLoader
-from transformers import PreTrainedTokenizerFast
-
-from arealite.api.cli_args import (
-    GenerationHyperparameters,
-    GRPOConfig,
-    load_expr_config,
-)
-from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
-from arealite.api.workflow_api import RolloutWorkflow
-from arealite.engine.ppo.actor import FSDPPPOActor
-from arealite.engine.sglang_remote import RemoteSGLangEngine
-from arealite.utils.data import concat_padded_tensors
-from arealite.utils.device import log_gpu_stats
-from arealite.utils.saver import Saver
-from arealite.utils.stats_logger import StatsLogger
-from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import logging, seeding, stats_tracker
-
-logger = logging.getLogger("boba math")
-
-
-class RLVRWorkflow(RolloutWorkflow):
-    def __init__(
-        self,
-        reward_fn,
-        gconfig: GenerationHyperparameters,
-        tokenizer: PreTrainedTokenizerFast,
-        dump_dir: str | None = None,
-    ):
-        self.reward_fn = reward_fn
-        self.gconfig = gconfig
-        self.tokenizer = tokenizer
-        self.dump_dir = dump_dir
-        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
-            os.makedirs(self.dump_dir, exist_ok=True)
-
-    async def arun_episode(self, engine, data):
-        input_ids = self.tokenizer.encode(data["prompt"])
-        n_samples = self.gconfig.n_samples
-        req = LLMRequest(
-            rid=uuid.uuid4().hex,
-            input_ids=input_ids,
-            gconfig=self.gconfig.new(n_samples=1),
-        )
-        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
-
-        version = engine.get_version()
-        prompt_strs = []
-        completions_strs = []
-        rewards = []
-        seqlens = []
-
-        results = []
-        for resp in resps:
-            seq = resp.input_tokens + resp.output_tokens
-            logprobs = [0.0] * resp.input_len + resp.output_logprobs
-            loss_mask = [0] * resp.input_len + [1] * resp.output_len
-            versions = [-1] * resp.input_len + resp.output_versions
-
-            prompt_str = data["prompt"]
-            completions_str = self.tokenizer.decode(resp.output_tokens)
-            prompt_strs.append(prompt_str)
-            completions_strs.append(completions_str)
-            seqlens.append(len(seq))
-            reward = self.reward_fn(
-                completions=completions_str,
-                prompt_ids=resp.input_tokens,
-                completion_ids=resp.output_tokens,
-                **data,
-            )
-            rewards.append(reward)
-            res = dict(
-                # unsqueeze to add an additional batch dimension
-                input_ids=torch.tensor(seq).unsqueeze(0),
-                loss_mask=torch.tensor(loss_mask).unsqueeze(0),
-                logprobs=torch.tensor(logprobs).unsqueeze(0),
-                versions=torch.tensor(versions).unsqueeze(0),
-                attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
-                # reward
-                rewards=torch.tensor([float(reward)]),
-            )
-            results.append(TensorDict(res, batch_size=[1]))
-
-        if self.dump_dir is not None:
-            os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
-            # Get the unique identifier for this prompt
-            qid = None
-            for key in ["query_id", "id", "qid"]:
-                qid = data.get(key, None)
-                if qid is not None:
-                    break
-            qid = qid or uuid.uuid4().hex
-
-            # Dump rollout to file
-            with open(
-                os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
-            ) as f:
-                n_samples = self.gconfig.n_samples
-                for i, (p, c, r, sl) in enumerate(
-                    zip(prompt_strs, completions_strs, rewards, seqlens)
-                ):
-                    info = "\n".join(
-                        [
-                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
-                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
-                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
-                        ]
-                    )
-                    f.write(info + "\n")
-
-        return concat_padded_tensors(results)
-
-
-def get_boba_math_dataset(tokenizer, rank, world_size):
-    dataset = load_dataset(
-        path="json",
-        split="train",
-        data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
-    )
-    dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
-    return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
-
-
-def boba_reward_fn(
-    prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
-):
-    from pebble import ProcessExpired, ProcessPool
-
-    from realhf.impl.dataset.math_parser import process_results
-
-    jobs = []
-    with ProcessPool(max_workers=1) as executor:
-        for sol in solutions:
-            job = executor.schedule(
-                process_results, args=[completions, sol], timeout=15
-            )
-            jobs.append(job)
-
-    label = 0
-    for job in jobs:
-        try:
-            x = job.result()
-        except TimeoutError:
-            # print("[debug: timeout]")
-            logger.warning(f"Timeout occurred while justifying the math answer.")
-            x = (0, "timeout", "timeout")
-        except ProcessExpired as e:
-            logger.warning(f"Process terminated abnormally: {e}")
-            x = (0, "error", "error")
-        except Exception as e:
-            logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
-            x = (0, "error", "error")
-        label = label or x[0]
-    return label
-
-
-def main(args):
-    config, _ = load_expr_config(args, GRPOConfig)
-    config: GRPOConfig
-
-    rank = int(os.getenv("RANK"))
-    world_size = int(os.getenv("WORLD_SIZE"))
-    tokenizer = load_hf_tokenizer(config.tokenizer_path)
-
-    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
-
-    # Create dataset and dataloaders
-    train_dataloader = StatefulDataLoader(
-        get_boba_math_dataset(tokenizer, rank, world_size),
-        batch_size=config.train_dataset.batch_size // world_size,
-        shuffle=config.train_dataset.shuffle,
-        num_workers=config.train_dataset.num_workers,
-        collate_fn=lambda x: x,
-        drop_last=config.train_dataset.drop_last,
-    )
-    ft_spec = FinetuneSpec(
-        total_train_epochs=config.total_train_epochs,
-        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
-        train_batch_size=config.train_dataset.batch_size,
-    )
-
-    # Initialize inference engine
-    rollout = RemoteSGLangEngine(config.rollout)
-    rollout.initialize(None, ft_spec)
-
-    # Initialize train engine
-    actor = FSDPPPOActor(config=config.actor)
-    actor.initialize(None, ft_spec)
-    ref = None
-    if config.actor.kl_ctl > 0 and config.ref is not None:
-        ref = FSDPPPOActor(config=config.ref)
-        ref.initialize(None, ft_spec)
-
-    # Create rollout workflow
-    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
-        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
-    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
-        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
-    workflow = RLVRWorkflow(
-        reward_fn=boba_reward_fn,
-        gconfig=config.gconfig,
-        tokenizer=tokenizer,
-        dump_dir=os.path.join(
-            StatsLogger.get_log_path(config.stats_logger), "generated"
-        ),
-    )
-
-    # Run training.
-    saver = Saver(config.saver, ft_spec, for_recover=False)
-    logger = StatsLogger(config.stats_logger, ft_spec)
-
-    total_epochs = config.total_train_epochs
-    steps_per_epoch = len(train_dataloader)
-    max_steps = total_epochs * steps_per_epoch
-
-    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
-    data_generator = iter(train_dataloader)
-    for global_step in range(max_steps):
-        epoch = global_step // steps_per_epoch
-        step = global_step % steps_per_epoch
-
-        with stats_tracker.record_timing("rollout"):
-            if config.async_training:
-                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
-            else:
-                try:
-                    data = next(data_generator)
-                except StopIteration:
-                    data_generator = iter(train_dataloader)
-                    data = next(data_generator)
-                batch = rollout.rollout_batch(data, workflow=workflow)
-
-        batch = batch.to(actor.device)
-        # Create barrier to synchronize all rollout processes.
-        dist.barrier()
-        torch.cuda.synchronize()
-
-        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
-            with stats_tracker.record_timing("recompute_logp"):
-                logp = actor.compute_logp(batch)
-                batch["prox_logp"] = logp
-                log_gpu_stats("recompute logp")
-
-        if ref is not None:
-            with stats_tracker.record_timing("ref_logp"):
-                batch["ref_logp"] = ref.compute_logp(batch)
-                log_gpu_stats("ref logp")
-
-        with stats_tracker.record_timing("compute_advantage"):
-            actor.compute_advantages(batch)
-            log_gpu_stats("compute advantages")
-
-        with (
-            stats_tracker.record_timing("train_step"),
-            stats_tracker.scope("grpo_actor"),
-        ):
-            stats = actor.ppo_update(batch)
-            actor.step_lr_scheduler()
-            log_gpu_stats("ppo update")
-
-        with stats_tracker.record_timing("update_weights"):
-            path = os.path.join(
-                Saver.get_save_checkpoint_root(config.saver),
-                "update_weights",
-                str(global_step + 1),
-            )
-            meta = WeightUpdateMeta(
-                type="disk",
-                path=path,
-                alloc_mode=None,
-                comm_backend=None,
-                model_version=global_step + 1,
-            )
-            if dist.get_rank() == 0:
-                future = rollout.update_weights(meta)
-            actor.upload_weights(meta)
-            if dist.get_rank() == 0:
-                future.result()
-                shutil.rmtree(path, ignore_errors=True)
-            dist.barrier()
-            torch.cuda.synchronize()
-            rollout.set_version(global_step + 1)
-
-        with stats_tracker.record_timing("save"):
-            saver.save(actor, epoch, step, global_step)
-
-        logger.commit(epoch, step, global_step, stats)
-
-    logger.close()
-    rollout.destroy()
-    if ref is not None:
-        ref.destroy()
-    actor.destroy()
-
-
-if __name__ == "__main__":
-    main(sys.argv[1:])
@@ -1,141 +0,0 @@
-experiment_name: lite-boba-math
-trial_name: run1
-
-cluster:
-  n_nodes: 16
-  n_gpus_per_node: 8
-  cluster_name: na132
-  fileroot: /storage/openpsi/experiments
-  name_resolve:
-    type: nfs
-    nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
-    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
-
-seed: 1
-total_train_epochs: 10
-total_train_steps: null
-tokenizer_path: ${actor.path}
-allocation_mode: sglang.d96p1t1+d32p1t1
-async_training: true
-
-rollout:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  max_concurrent_rollouts: 400
-  queue_size: null
-  consumer_batch_size: ${train_dataset.batch_size}
-  max_head_offpolicyness: 4
-  enable_rollout_tracing: true
-
-gconfig:
-  n_samples: 16
-  min_new_tokens: 0
-  max_new_tokens: 30720
-  greedy: false
-  temperature: 1.0
-
-actor:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
-  init_from_scratch: false
-  disable_dropout: true
-  gradient_checkpointing: true
-  dtype: bfloat16
-  mb_spec:
-    max_tokens_per_mb: 32768
-  optimizer:
-    type: adam
-    lr: 1e-5
-    weight_decay: 0.01
-    beta1: 0.9
-    beta2: 0.999
-    eps: 1e-8
-    lr_scheduler_type: constant
-    gradient_clipping: 1.0
-    warmup_steps_proportion: 0.001
-  backend: fsdp
-
-  group_size: ${gconfig.n_samples}
-  group_adv_norm: false
-  eps_clip: 0.4
-  temperature: ${gconfig.temperature}
-  reward_scaling: 10.0
-  reward_bias: -0.5
-  kl_ctl: 0.0
-  ppo_n_minibatches: 4
-  recompute_logprob: true
-  use_decoupled_loss: true
-  behav_imp_weight_cap: 5.0
-
-ref:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  path: ${actor.path}
-  init_from_scratch: false
-  disable_dropout: true
-  dtype: ${actor.dtype}
-  mb_spec:
-    max_tokens_per_mb: 32768
-  optimizer: null
-  backend: fsdp
-
-# SGLang
-server_only: false
-sglang:
-  model_path: ${actor.path}
-  random_seed: ${seed}
-  skip_tokenizer_init: true
-  dtype: ${actor.dtype}
-  max_running_requests: null
-  context_length: 32768
-  mem_fraction_static: 0.9
-
-# datasets
-train_dataset:
-  batch_size: 512
-  shuffle: true
-  pin_memory: true
-
-# Utilities
-saver:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  fileroot: ${cluster.fileroot}
-  freq_epochs: 1
-  freq_steps: null
-  freq_secs: null
-
-checkpointer:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  fileroot: ${cluster.fileroot}
-  freq_epochs: 1
-  freq_steps: null
-  freq_secs: 3600
-
-evaluator:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  fileroot: ${cluster.fileroot}
-  freq_epochs: null
-  freq_steps: null
-  freq_secs: null
-
-stats_logger:
-  experiment_name: ${experiment_name}
-  trial_name: ${trial_name}
-  fileroot: ${cluster.fileroot}
-  wandb:
-    mode: online
-
-# Launcher
-launcher:
-  inference_server_cpus_per_gpu: 15
-  inference_server_mem_per_gpu: 153600
-  trainer_cpus_per_gpu: 15
-  trainer_mem_per_gpu: 153600
-  slurm:
-    mount: /storage:/storage
-    trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
-    inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
@@ -1,9 +1,9 @@
 experiment_name: gsm8k-grpo
 trial_name: trial0
 allocation_mode: sglang.d4p1t1+d4p1t1
-n_nodes: 1
-n_gpus_per_node: 8
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /tmp/arealite/experiments
+  name_resolve:
+    type: nfs
@@ -32,7 +32,7 @@ gconfig:
 actor:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
+  path: Qwen/Qwen2-1.5B-Instruct
   init_from_scratch: false
   disable_dropout: true
   gradient_checkpointing: false
@@ -13,7 +13,7 @@ tokenizer_path: ${model.path}
 model:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
+  path: Qwen/Qwen3-1.7B
   init_from_scratch: false
   gradient_checkpointing: false
   dtype: bfloat16
@@ -1,5 +1,4 @@
 import os
-import shutil
 import sys

 import torch
@@ -9,7 +8,7 @@ from datasets.distributed import split_dataset_by_node
 from torchdata.stateful_dataloader import StatefulDataLoader

 from arealite.api.cli_args import GRPOConfig, load_expr_config
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
 from arealite.utils.device import log_gpu_stats
@@ -93,6 +92,18 @@ def main(args):
     ref = FSDPPPOActor(config=config.ref)
     ref.initialize(None, ft_spec)

+    # NOTE: The weight update meta only requires the address and free port of rank 0,
+    # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
+    # due to `engine.get_param_specs()`.
+    # Therefore, we create the weight update meta on all ranks, then broadcast the one on rank 0.
+    weight_update_meta = [
+        WeightUpdateMeta.from_fsdp_nccl(
+            AllocationMode.from_str(config.allocation_mode), actor
+        )
+    ]
+    dist.broadcast_object_list(weight_update_meta, src=0)
+    weight_update_meta = weight_update_meta[0]
+
     # Create rollout workflow
     if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
         config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
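The single-element list is just how `torch.distributed.broadcast_object_list` works: it pickles the objects on the source rank and overwrites the list in place on every other rank. A minimal self-contained sketch of the pattern (a one-process gloo group is used here only so the snippet runs without GPUs):

# Minimal sketch of the broadcast-object pattern used above.
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
meta = ["payload-from-rank-0"]           # every rank builds a placeholder list
dist.broadcast_object_list(meta, src=0)  # rank 0's element overwrites the rest
meta = meta[0]                           # unwrap, as in the training script
print(meta)
dist.destroy_process_group()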
@@ -136,7 +147,7 @@ def main(args):

         batch = batch.to(actor.device)
         # Create barrier to synchronize all rollout processes.
-        dist.barrier()
+        dist.barrier(device_ids=[actor.device.index])
         torch.cuda.synchronize()

         if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
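With the NCCL backend, `device_ids` pins the barrier's collective to the CUDA device each rank actually owns instead of the device NCCL would guess from the rank index, which avoids hangs when ranks do not map one-to-one onto `cuda:rank`. A single-GPU sketch of the call shape, assuming one visible GPU:

# Runs on one GPU: init a one-rank NCCL group and issue a device-pinned barrier.
import torch
import torch.distributed as dist

dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)
device = torch.device("cuda", 0)  # whatever device this rank owns
torch.cuda.set_device(device)
dist.barrier(device_ids=[device.index])  # collective explicitly on cuda:0
dist.destroy_process_group()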
@@ -163,26 +174,16 @@ def main(args):
             log_gpu_stats("ppo update")

         with stats_tracker.record_timing("update_weights"):
-            path = os.path.join(
-                Saver.get_save_checkpoint_root(config.saver),
-                "update_weights",
-                str(global_step + 1),
-            )
-            meta = WeightUpdateMeta(
-                type="disk",
-                path=path,
-                alloc_mode=None,
-                comm_backend=None,
-                model_version=global_step + 1,
-            )
             rollout.pause()
             if dist.get_rank() == 0:
-                future = rollout.update_weights(meta)
-            actor.upload_weights(meta)
+                future = rollout.update_weights(weight_update_meta)
+            actor.upload_weights(weight_update_meta)
             if dist.get_rank() == 0:
                 future.result()
-                shutil.rmtree(path, ignore_errors=True)
-            dist.barrier()
+            dist.barrier(device_ids=[actor.device.index])
             torch.cuda.synchronize()
             rollout.resume()
             actor.set_version(global_step + 1)
             rollout.set_version(global_step + 1)

         with stats_tracker.record_timing("save"):
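The ordering here matters: generation is paused before the transfer so rollout servers are not mid-decode while weights change underneath them, only rank 0 talks to the rollout service while every rank calls `actor.upload_weights` (the NCCL transfer is presumably collective), and both sides bump their version afterwards so stale rollouts can be detected. A schematic restatement with hypothetical stand-in objects (not the arealite classes):

# Schematic restatement of the update ordering with hypothetical stubs.
class Rollout:
    def pause(self): print("pause generation")
    def update_weights(self, meta): print("server pulls weights"); return self
    def result(self): print("server finished")
    def resume(self): print("resume generation")
    def set_version(self, v): print(f"rollout version -> {v}")

class Actor:
    def upload_weights(self, meta): print("trainer pushes weights (collective)")
    def set_version(self, v): print(f"actor version -> {v}")

rollout, actor, rank, step = Rollout(), Actor(), 0, 0
rollout.pause()                          # stop new tokens first
future = rollout.update_weights("meta") if rank == 0 else None
actor.upload_weights("meta")             # every rank joins the transfer
if rank == 0:
    future.result()                      # wait for the server side to finish
rollout.resume()                         # only now is it safe to generate again
actor.set_version(step + 1); rollout.set_version(step + 1)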
@@ -1,11 +0,0 @@
-#!/bin/sh
-AREAL_PATH=$PWD
-cd /sglang
-git apply $AREAL_PATH/patch/sglang/v0.4.6.post4.patch
-cd $AREAL_PATH
-
-# Package used for calculating math reward
-pip install -e evaluation/latex2sympy
-
-# Install AReaL
-pip install -e .
@ -1,9 +1,9 @@
|
|||
#/bin/bash
|
||||
# basic dependencies
|
||||
pip install -U pip
|
||||
pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
|
||||
pip install nvidia-ml-py
|
||||
pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
|
||||
pip install pynvml nvidia-ml-py
|
||||
pip install -e evaluation/latex2sympy
|
||||
pip install vllm==0.8.5 --no-build-isolation
|
||||
pip install flash_attn --no-build-isolation
|
||||
pip install "flash-attn<=2.7.3" --no-build-isolation
|
||||
pip install -r evaluation/requirements.txt
|
|
@@ -1,24 +1,14 @@
 #!/bin/bash
 # basic dependencies
 pip install -U pip
-pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
-pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
-pip install "sglang[all]==0.4.6.post4"
+pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
+pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.2" pynvml
+pip install "sglang[all]==0.4.9.post2"
 pip install megatron-core==0.11.0 nvidia-ml-py
 pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
 pip install "flash-attn<=2.7.3" --no-build-isolation

 # Package used for calculating math reward
 pip install -e evaluation/latex2sympy

-# Install an editable sglang
-rm -rf ./sglang
-git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang
-AREAL_PATH=$PWD
-cd sglang
-git apply ../patch/sglang/v0.4.6.post4.patch
-pip install -e "python[all]" --no-deps
-cd $AREAL_PATH
-
 # Install AReaL
 pip install -e .
@@ -67,6 +67,7 @@ class InstallationValidator:
     def test_flash_attn_functionality(self, flash_attn_module):
         """Test flash attention functionality."""
         # Try to import key functions
+        import flash_attn_2_cuda
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        print(" - Flash attention functions imported successfully")
@@ -79,12 +80,12 @@ class InstallationValidator:
         """Test SGLang basic functionality."""
         # Basic import test is sufficient for CI
         import sgl_kernel
-        from sglang import launch_server
-        assert Version(get_version("sglang")) == Version("0.4.6.post4")
+        from sglang import Engine, launch_server
+        assert Version(get_version("sglang")) == Version("0.4.9.post2"), "SGLang version should be v0.4.9.post2"
         print(" - SGLang imported successfully")

     def test_transformers(self, transformers_module):
-        assert Version(get_version("transformers")) == Version("4.51.1")
+        assert Version(get_version("transformers")) == Version("4.53.1"), "transformers version should be 4.53.1"
         print(" - transformers imported successfully")

     def validate_critical_dependencies(self):
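These assertions compare parsed versions rather than raw strings, so `0.4.9.post2` matches however the installed metadata normalizes it. A self-contained equivalent using the standard library plus `packaging` (the same `Version`/`get_version` pair the validator appears to rely on):

# Standalone version-pinning check, mirroring the validator's approach.
from importlib.metadata import version as get_version
from packaging.version import Version

def assert_pinned(dist_name: str, pinned: str) -> None:
    installed = Version(get_version(dist_name))
    assert installed == Version(pinned), f"{dist_name} should be {pinned}, got {installed}"

assert_pinned("packaging", get_version("packaging"))  # trivially true, demo only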
@@ -140,7 +141,7 @@ class InstallationValidator:
         self.test_import("flashattn_hopper", required=False)

         # Optional utilities
-        self.test_import("tensorboardx", required=False)
+        self.test_import("tensorboardX", required=False)
         self.test_import("swanlab", required=False)
         self.test_import("matplotlib", required=False)
         self.test_import("seaborn", required=False)
@@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 174656b2..33fe0a5f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
     success: bool


+@dataclass
+class InterruptAllReqInput:
+    pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+    num_interrupted_requests: int
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
+    allow_interrupt: bool = False
     # The format to load the weights
     load_format: Optional[str] = None

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8891115c..843a8a82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -419,6 +421,7 @@ class Scheduler(
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
+                (InterruptAllReqInput, self.interrupt_all_requests),
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1938,6 +1941,15 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()

+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
+        for req in self.waiting_queue:
+            req.sampling_params.max_new_tokens = 0
+        for req in self.running_batch.reqs:
+            req.sampling_params.max_new_tokens = len(req.output_ids)
+        logger.info(f"Interrupt {num} requests.")
+        return InterruptAllReqOutput(num)
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 82709b09..bfab3ce7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -265,6 +267,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.interrupt_requests_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -294,6 +299,10 @@ class TokenizerManager:
                 UpdateWeightFromDiskReqOutput,
                 self._handle_update_weights_from_disk_req_output,
             ),
+            (
+                InterruptAllReqOutput,
+                self.interrupt_requests_communicator.handle_recv,
+            ),
             (
                 InitWeightsUpdateGroupReqOutput,
                 self.init_weights_update_group_communicator.handle_recv,
@@ -767,6 +776,13 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()

+        if obj.allow_interrupt:
+            num_interrupted_requests = await self.interrupt_all_requests(
+                InterruptAllReqInput()
+            )
+            # Set a break point to wait for the interrupt to finish
+            await asyncio.sleep(0.1)
+
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
@@ -776,7 +792,12 @@ class TokenizerManager:
         # Hold the lock if it is not async. This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            return await self._wait_for_model_update_from_disk(obj)
+            success, message, n_paused = (
+                await self._wait_for_model_update_from_disk(obj)
+            )
+            if obj.allow_interrupt:
+                return success, message, num_interrupted_requests
+            return success, message, n_paused

     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
@@ -849,6 +870,18 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message

+    async def interrupt_all_requests(
+        self,
+        obj: InterruptAllReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        result = await self.interrupt_requests_communicator(obj)
+        if self.server_args.dp_size == 1:
+            return result[0].num_interrupted_requests
+        else:
+            return [r.num_interrupted_requests for r in result]
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
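The interrupt trick in the deleted patch above is worth noting: rather than cancelling requests outright, the scheduler shrinks each request's `max_new_tokens` so queued requests generate nothing and running requests count as finished at their very next scheduling step. A toy restatement of that length-clamping logic (hypothetical `Req`/`SamplingParams` stand-ins, not the sglang classes):

# Toy restatement of the patch's interrupt logic with hypothetical stand-ins.
from dataclasses import dataclass, field

@dataclass
class SamplingParams:
    max_new_tokens: int = 1024

@dataclass
class Req:
    output_ids: list = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

waiting_queue = [Req()]
running_batch = [Req(output_ids=[1, 2, 3])]

for req in waiting_queue:
    req.sampling_params.max_new_tokens = 0                     # never starts decoding
for req in running_batch:
    req.sampling_params.max_new_tokens = len(req.output_ids)  # finishes immediately

print(waiting_queue[0].sampling_params.max_new_tokens,   # 0
      running_batch[0].sampling_params.max_new_tokens)   # 3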
@@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
     success: bool


+@dataclass
+class InterruptAllReqInput:
+    pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+    num_interrupted_requests: int
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
+    allow_interrupt: bool = False
     # The format to load the weights
     load_format: Optional[str] = None

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
+                (InterruptAllReqInput, self.interrupt_all_requests),
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()

+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
+        for req in self.waiting_queue:
+            req.sampling_params.max_new_tokens = 0
+        for req in self.running_batch.reqs:
+            req.sampling_params.max_new_tokens = len(req.output_ids)
+        logger.info(f"Interrupt {num} requests.")
+        return InterruptAllReqOutput(num)
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
         self.slow_down_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.interrupt_requests_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -309,6 +314,10 @@ class TokenizerManager:
                 UpdateWeightFromDiskReqOutput,
                 self._handle_update_weights_from_disk_req_output,
             ),
+            (
+                InterruptAllReqOutput,
+                self.interrupt_requests_communicator.handle_recv,
+            ),
             (
                 InitWeightsUpdateGroupReqOutput,
                 self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()

+        if obj.allow_interrupt:
+            num_interrupted_requests = await self.interrupt_all_requests(
+                InterruptAllReqInput()
+            )
+            # Set a break point to wait for the interrupt to finish
+            await asyncio.sleep(0.1)
+
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
         # Hold the lock if it is not async. This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            return await self._wait_for_model_update_from_disk(obj)
+            success, message, n_paused = (
+                await self._wait_for_model_update_from_disk(obj)
+            )
+            if obj.allow_interrupt:
+                return success, message, num_interrupted_requests
+            return success, message, n_paused

     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message

+    async def interrupt_all_requests(
+        self,
+        obj: InterruptAllReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        result = await self.interrupt_requests_communicator(obj)
+        if self.server_args.dp_size == 1:
+            return result[0].num_interrupted_requests
+        else:
+            return [r.num_interrupted_requests for r in result]
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -31,10 +31,9 @@ dependencies = [
     "huggingface_hub",
     "datasets",
     "accelerate",
-    "transformers==4.51.1",
+    "transformers==4.53.0",

     # Scientific computing
     "numpy<2.0.0",
     "scipy",
     "pandas",
     "matplotlib",
@@ -40,42 +40,11 @@ def execute_shell_command(command: str) -> subprocess.Popen:
     )


-def apply_sglang_patch():
-    p = Path(os.path.dirname(__file__))
-    patch_path = str(
-        p.parent.parent
-        / "patch"
-        / "sglang"
-        / f"v{pkg_version.get_version('sglang')}.patch"
-    )
-
-    target_path = ""
-    sglang_meta = subprocess.check_output(
-        "python3 -m pip show sglang", shell=True
-    ).decode("ascii")
-    for line in sglang_meta.split("\n"):
-        line = line.strip()
-        if line.startswith("Editable project location: "):
-            target_path = str(Path(line.split(": ")[1]).parent)
-
-    if target_path:
-        proc = subprocess.Popen(
-            ["git", "apply", patch_path],
-            cwd=target_path,
-            stderr=sys.stdout,
-            stdout=sys.stdout,
-        )
-        proc.wait()
-        logger.info(f"Applied SGLang patch at {target_path}")
-
-
 def launch_server_cmd(command: str, port: int = 30000):
     """
     Launch the server using the given command.
     If no port is specified, a free port is reserved.
     """
-    if not ray.is_initialized():
-        apply_sglang_patch()
     assert port is not None
     full_command = f"{command} --port {port}"
     process = execute_shell_command(full_command)
@@ -155,39 +155,22 @@ class GserverManager(Worker):

         return None

-    async def flush_requests_and_update_weights(
-        self, server_url, new_param_path, update_weights_retries=5
-    ):
-        server_index = self.server_urls.index(server_url)
-        success = False
-        for _ in range(update_weights_retries):
-            async with aiohttp.ClientSession(
-                server_url,
-                timeout=aiohttp.ClientTimeout(
-                    total=self.config.flush_request_timeout,
-                    sock_connect=self.config.flush_request_timeout,
-                ),
-            ) as session:
-                async with session.post(
-                    f"/update_weights_from_disk",
-                    json=dict(model_path=new_param_path, allow_interrupt=True),
-                ) as resp:
-                    if resp.status == 200:
-                        res = await resp.json()
-                        success = res["success"]
-                        if success:
-                            if "num_paused_requests" in res:
-                                logger.info(
-                                    f"{res['num_paused_requests']} requests are interrupted "
-                                    f"during updating weights for server {server_index}: {server_url}"
-                                )
-                            return
-                        logger.warning(
-                            f"Update weights failed: {res['message']}. Retrying."
-                        )
-                    logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
-            time.sleep(0.1)
-        raise RuntimeError("Update weights failed.")
+    async def flush_requests_and_update_weights(self, server_url, new_param_path):
+        async with aiohttp.ClientSession(
+            server_url,
+            timeout=aiohttp.ClientTimeout(
+                total=self.config.flush_request_timeout,
+                sock_connect=self.config.flush_request_timeout,
+            ),
+        ) as session:
+            (await session.post("/pause_generation")).raise_for_status()
+            async with session.post(
+                f"/update_weights_from_disk",
+                json=dict(model_path=new_param_path),
+            ) as resp:
+                resp.raise_for_status()
+                assert (await resp.json())["success"]
+            (await session.post("/continue_generation")).raise_for_status()

     def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
         if not hasattr(self, "round_robin_idx"):
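With the interrupt patch gone, the manager now leans on SGLang's own `/pause_generation` and `/continue_generation` endpoints and simply fails fast via `raise_for_status()` instead of looping on retries. A minimal standalone client showing the same three-call sequence (the base URL and a locally running server are assumptions):

# Minimal sketch of the pause -> update -> continue sequence against a
# locally running SGLang server (http://localhost:30000 is an assumption).
import asyncio
import aiohttp

async def swap_weights(base_url: str, model_path: str) -> None:
    timeout = aiohttp.ClientTimeout(total=300)
    async with aiohttp.ClientSession(base_url, timeout=timeout) as session:
        (await session.post("/pause_generation")).raise_for_status()
        async with session.post(
            "/update_weights_from_disk", json={"model_path": model_path}
        ) as resp:
            resp.raise_for_status()
            assert (await resp.json())["success"]
        (await session.post("/continue_generation")).raise_for_status()

# asyncio.run(swap_weights("http://localhost:30000", "/path/to/new/ckpt"))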
@@ -25,7 +25,6 @@ numba
 packaging
 pandas
 pybind11>=2.10.0
-numpy<2.0.0
 psutil
 pynvml
 pytest