Mirror of https://github.com/inclusionAI/AReaL
Commit f5924b1851: "0724_merge7"

@@ -163,7 +163,6 @@ class WeightUpdateMeta:
     type: str
     path: str | None
     alloc_mode: AllocationMode | None
-    comm_backend: str | None
 
 @dataclass
 class SaveLoadMeta:

@@ -285,85 +285,6 @@ class PPOActorConfig(TrainEngineConfig):
     )
 
 
-@dataclass
-class PPOActorConfig(TrainEngineConfig):
-    # Core PPO/GRPO Parameters
-    group_size: int = field(
-        default=1, metadata={"help": "Number of sequences in each group"}
-    )
-    group_adv_norm: bool = field(
-        default=False,
-        metadata={
-            "help": "Normalize advantages within each prompt group rather than globally"
-        },
-    )
-    ppo_n_minibatches: int = field(
-        default=4, metadata={"help": "Number of minibatches for each PPO update"}
-    )
-    eps_clip: float = field(
-        default=0.2, metadata={"help": "Clipping factor for policy ratio"}
-    )
-    c_clip: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
-        },
-    )
-    temperature: float = field(
-        default=1.0, metadata={"help": "Temperature during generation."}
-    )
-    # Reward
-    group_reward_norm: bool = field(
-        default=False,
-        metadata={
-            "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
-        },
-    )
-    reward_scaling: float = field(
-        default=1.0, metadata={"help": "Reward scaling factor"}
-    )
-    reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
-    reward_clip: float = field(
-        default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
-    )
-    mask_no_eos_with_zero: bool = field(
-        default=False,
-        metadata={
-            "help": "Mask truncated generations (no EOS token) and exclude from training"
-        },
-    )
-
-    # Advantage Estimation
-    discount: float = field(
-        default=1.0, metadata={"help": "Discount factor for future rewards"}
-    )
-    gae_lambda: float = field(
-        default=1.0, metadata={"help": "Lambda parameter for GAE"}
-    )
-    adv_norm: bool = field(
-        default=True, metadata={"help": "Enable advantage normalization"}
-    )
-
-    # KL Control
-    kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
-
-    # Asynchronous RL
-    recompute_logprob: bool = field(
-        default=False,
-        metadata={"help": "Recompute logp and replace the logp returned by inference."},
-    )
-    use_decoupled_loss: bool = field(
-        default=False,
-        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
-    )
-    behav_imp_weight_cap: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
-        },
-    )
-
-
 @dataclass
 class SGLangConfig:
     """Configuration for SGLang runtime. Refer to:

@@ -12,6 +12,7 @@ from arealite.api.io_struct import (
     FinetuneSpec,
     LLMRequest,
     LLMResponse,
+    ParamSpec,
     SaveLoadMeta,
     WeightUpdateMeta,
 )

@@ -63,7 +64,19 @@ class TrainEngine(abc.ABC):
         return self.train(False)
 
     def upload_weights(self, meta: WeightUpdateMeta):
-        """Upload weights to the inference engine."""
+        """Upload weights to the inference engine (in a blocking manner)."""
+        raise NotImplementedError()
+
+    def get_param_specs(self) -> List[ParamSpec]:
+        """Get the parameter specifications for the model."""
+        raise NotImplementedError()
+
+    def set_version(self, version: int):
+        """Set the current weight version in the train engine."""
+        raise NotImplementedError()
+
+    def get_version(self) -> int:
+        """Get the current weight version in the train engine."""
         raise NotImplementedError()
 
     def save(self, meta: SaveLoadMeta):

@@ -122,14 +135,22 @@ class InferenceEngine(abc.ABC):
     def destroy(self):
         """Destroy the engine and release GPU memory."""
 
-    def update_weights(self, meta: WeightUpdateMeta) -> Future:
-        """Update weights in the inference engine."""
-        raise NotImplementedError()
-
     async def agenerate(self, req: LLMRequest) -> LLMResponse:
         """Asynchronously generate a response for the given request."""
         raise NotImplementedError()
 
+    def update_weights(self, meta: WeightUpdateMeta) -> Future:
+        """Update weights in the inference engine in a non-blocking manner."""
+        raise NotImplementedError()
+
+    def set_version(self, version: int) -> None:
+        """Set the current weight version in the inference engine."""
+        raise NotImplementedError()
+
+    def get_version(self) -> int:
+        """Get the current weight version in the inference engine."""
+        raise NotImplementedError()
+
     def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
         """Asynchronously submit a request to the inference engine. Exits immediately."""
         raise NotImplementedError()
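
Note: the split between the blocking upload_weights on the train engine and the
non-blocking update_weights on the inference engine implies a small handshake.
A minimal sketch of the intended call order, mirroring the tests later in this
commit; "engine" and "inf_engine" are assumed to be an initialized FSDPEngine
and RemoteSGLangEngine, and the final set_version call is an assumption about
keeping version counters in sync, not part of the commit:

    future = inf_engine.update_weights(meta)  # non-blocking, returns a Future
    engine.upload_weights(meta)               # blocking broadcast/save on the trainer
    future.result(timeout=120)                # wait until all servers have new weights
    inf_engine.set_version(engine.get_version())  # assumption: sync version counters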

@@ -2,17 +2,22 @@
 # Licensed under the Apache License, Version 2.0
 import enum
 import itertools
+import os
 import re
 import uuid
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
 
 import torch
 from gymnasium.core import ActType, ObsType
 from PIL.Image import Image as ImageObject
 from transformers import AutoProcessor, PreTrainedTokenizerFast
 
-from arealite.api.cli_args import GenerationHyperparameters
+from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
+from arealite.utils.network import find_free_ports, gethostip
+
+if TYPE_CHECKING:
+    from arealite.api.engine_api import TrainEngine
 
 
 @dataclass

@@ -168,20 +173,55 @@ class AllocationMode:
         return other_alloc
 
 
+@dataclass
+class ParamSpec:
+    name: str
+    shape: Tuple
+    dtype: str
+
+
 @dataclass
 class WeightUpdateMeta:
-    type: str
-    path: str | None
-    alloc_mode: AllocationMode | None
-    comm_backend: str | None
-    model_version: int = 0
-    tp_size: int = 1
-    master_address: str = "127.0.0.1"
-    master_port: int = 29500
-    world_size: int = 1
-    group_name: str = "aupdate_weights_from_distributed"
-    parameter_names: List[str] = field(default_factory=list)
-    state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)
+    type: Literal["disk", "nccl"]
+    path: str | None = None
+    alloc_mode: AllocationMode | None = None
+    nccl_master_address: str = "127.0.0.1"
+    nccl_master_port: int = 29500
+    nccl_param_specs: List[ParamSpec] = field(default_factory=list)
+    nccl_group_name: str = "update_weight_group"
+
+    @classmethod
+    def from_disk(
+        cls,
+        saver_config: SaverConfig,
+    ) -> "WeightUpdateMeta":
+        from arealite.utils.saver import Saver
+
+        path = os.path.join(
+            Saver.get_save_checkpoint_root(saver_config),
+            "weight_update",
+        )
+        return cls(
+            type="disk",
+            path=path,
+        )
+
+    @classmethod
+    def from_fsdp_nccl(
+        cls,
+        allocation_mode: AllocationMode,
+        fsdp_engine: "TrainEngine",
+        nccl_group_name: str = "update_weight_group",
+    ):
+        return cls(
+            type="nccl",
+            alloc_mode=allocation_mode,
+            nccl_master_address=gethostip(),
+            nccl_master_port=find_free_ports(1)[0],
+            nccl_param_specs=fsdp_engine.get_param_specs(),
+            nccl_group_name=nccl_group_name,
+        )
 
 
 @dataclass
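
Note: the two classmethods replace hand-built WeightUpdateMeta instances. A
short usage sketch based on the tests later in this commit (the allocation
string and group name are test values, not library defaults):

    # NCCL path: the train engine must implement get_param_specs().
    meta = WeightUpdateMeta.from_fsdp_nccl(
        AllocationMode.from_str("sglang.d1p1t1+d1p1t1"),
        engine,
        nccl_group_name="my_group",  # hypothetical group name
    )
    # Disk path: resolves the checkpoint root from the saver config.
    meta = WeightUpdateMeta.from_disk(saver_config)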

@@ -49,6 +49,7 @@ class BaseHFEngine(TrainEngine):
         self.processor: AutoProcessor | None = None
         # huggingface model config
         self.model_config: PretrainedConfig
+        self._version: int = 0
 
         # initialization
         self.initialized = False

@@ -64,6 +65,12 @@ class BaseHFEngine(TrainEngine):
 
         self.world_size = int(os.environ["WORLD_SIZE"])
 
+    def set_version(self, version: int):
+        self._version = version
+
+    def get_version(self) -> int:
+        return self._version
+
     def train(self, mode: bool = True):
         assert self.model is not None
         self.model.train(mode=mode)

@@ -222,7 +229,7 @@ class BaseHFEngine(TrainEngine):
         )
         state_dict = self.optimizer.state_dict()
         torch.save(state_dict, shard_path)
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])
 
     def load_optimizer_state(self, path: str):
         # Load FSDP sharded state dict

@@ -234,7 +241,7 @@ class BaseHFEngine(TrainEngine):
         )
         optimizer_state_dict = torch.load(shard_path, weights_only=False)
         self.optimizer.load_state_dict(optimizer_state_dict)
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])
 
     def step_lr_scheduler(self):
         assert self.lr_scheduler is not None

@@ -1,7 +1,7 @@
 import os
 import time
 from datetime import datetime
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist

@@ -14,7 +14,8 @@ from torch.distributed.checkpoint.state_dict import (
 from transformers import AutoProcessor, PreTrainedTokenizerFast
 
 from arealite.api.cli_args import TrainEngineConfig
-from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.api.engine_api import FinetuneSpec
+from arealite.api.io_struct import ParamSpec, SaveLoadMeta, WeightUpdateMeta
 from arealite.engine.base_hf_engine import BaseHFEngine
 from arealite.utils.distributed import init_custom_process_group
 from arealite.utils.fsdp import (

@@ -125,7 +126,7 @@ class FSDPEngine(BaseHFEngine):
         if processor is not None:
             processor.save_pretrained(path)
 
-        dist.barrier()
+        dist.barrier(device_ids=[self.device.index])
 
     def _load_model_from_hf(self, path: str):
         """Load model from HuggingFace format."""

@@ -146,7 +147,7 @@ class FSDPEngine(BaseHFEngine):
             if not self.weight_update_group_initialized:
                 self._init_distributed_weight_update(meta)
             self._update_weights_from_distributed()
-            dist.barrier()
+            dist.barrier(device_ids=[self.device.index])
             torch.cuda.synchronize()
         elif meta.type == "disk":
             self._save_model_to_hf(meta.path, self.tokenizer, self.processor)

@@ -155,7 +156,7 @@ class FSDPEngine(BaseHFEngine):
             update_name = names.update_weights_from_disk(
                 self.config.experiment_name,
                 self.config.trial_name,
-                meta.model_version,
+                self.model_version,
             )
             name_resolve.add(
                 update_name, str(datetime.now().timestamp()), keepalive_ttl=120

@@ -164,16 +165,18 @@ class FSDPEngine(BaseHFEngine):
             raise ValueError(f"Unknown weight update type {meta.type}")
 
     def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
+        # NOTE: Processes launched with torchrun will set the following env var to True,
+        # which blocks creating another TCP store for weight update.
+        os.environ["TORCHELASTIC_USE_AGENT_STORE"] = str(False)
         if dist.get_rank() == 0:
             self.weight_update_group = init_custom_process_group(
                 backend="nccl",
-                world_size=meta.world_size,
-                init_method=f"tcp://{meta.master_address}:{meta.master_port}",
+                world_size=meta.alloc_mode.gen_world_size + 1,
+                init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}",
                 rank=0,
-                group_name=meta.group_name,
+                group_name=meta.nccl_group_name,
             )
-            # NOTE: synchronizing with sglang's barrier
-            dist.barrier(group=self.weight_update_group, device_ids=[self.device.index])
+            # NOTE: sglang v0.4.9.post2 or later does not have the barrier call
         self.weight_update_group_initialized = True
 
     def _update_weights_from_distributed(self):

@@ -185,23 +188,29 @@ class FSDPEngine(BaseHFEngine):
             else:
                 tensor = param.data
             if dist.get_rank() == 0:
-                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
+                logger.debug(
+                    f"Broadcasting {name} with shape {tensor.shape}", flush=True
+                )
             dist.broadcast(tensor, src=0, group=self.weight_update_group)
-            dist.barrier()
             del tensor  # optional, for memory hygiene
         torch.cuda.empty_cache()
 
-    def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
-        """Return a dict mapping param name to its shape (expanded if DTensor)."""
-        param_shapes = {}
+    def get_param_specs(self) -> List[ParamSpec]:
+        param_specs = []
         for name, param in self.model.named_parameters():
             if isinstance(param.data, DTensor):
                 tensor = param.data.full_tensor()
             else:
                 tensor = param.data
-            param_shapes[name] = tuple(tensor.shape)
+            param_specs.append(
+                ParamSpec(
+                    name=name,
+                    shape=tuple(tensor.shape),
+                    dtype=str(tensor.dtype).split("torch.")[1],
+                )
+            )
             del tensor  # free memory if full_tensor was created
-        return param_shapes
+        return param_specs
 
     def train_batch(
         self,
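
Note: get_param_specs() now returns structured ParamSpec entries instead of a
name-to-shape dict, so dtypes can travel with shapes over HTTP. A sketch of one
entry (the name and shape are illustrative, not from a real model):

    ParamSpec(name="layers.0.up_proj.weight", shape=(11008, 4096), dtype="bfloat16")
    # dtype is str(tensor.dtype) with the "torch." prefix stripped, which is
    # what the /update_weights_from_distributed payload expects.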

@@ -1,10 +1,11 @@
 import asyncio
 import os
 import random
+import shutil
 import threading
 import time
 import traceback
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from concurrent.futures import Future, ProcessPoolExecutor
 from datetime import datetime
 from queue import Empty, Full, Queue
 from typing import TYPE_CHECKING, Any, Callable, Dict, List

@@ -29,17 +30,12 @@ from arealite.api.io_struct import (
 )
 from arealite.utils.data import concat_padded_tensors
 from arealite.utils.http import arequest_with_retry, get_default_connector
-from realhf.base import logging, name_resolve, names, pkg_version
+from realhf.base import logging, name_resolve, names
 
 if TYPE_CHECKING:
     from arealite.api.workflow_api import RolloutWorkflow
 logger = logging.getLogger(__name__)
 
-if pkg_version.is_available("sglang"):
-    if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
-        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
-    else:
-        SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
-
 ROLLOUT_POLL_WAIT_TIME = 0.05
 RID_CACHE_SIZE = 128

@@ -93,10 +89,7 @@ class RemoteSGLangEngine(InferenceEngine):
     def check_health(self, base_url):
         # Check server endpoint
         try:
-            response = requests.get(
-                f"{base_url}/metrics",
-                timeout=30,
-            )
+            response = requests.get(f"{base_url}/health", timeout=30)
             return response.status_code == 200
         except requests.exceptions.RequestException as e:
             return False

@@ -125,7 +118,7 @@ class RemoteSGLangEngine(InferenceEngine):
         """Thread that runs the rollout loop."""
         try:
             uvloop.run(self._rollout_thread_async())
-        except Exception as e:
+        except Exception:
             traceback.print_exc()
 
     async def _rollout_thread_async(self):

@@ -272,9 +265,7 @@ class RemoteSGLangEngine(InferenceEngine):
         accumulated_output_logprobs = []
         accumulated_versions = []
 
-        # Deal with rollout interruption
-        stop_reason = "length"
-
+        # A single "rid" shares the same server to allow KV cache reuse
         if req.rid in self.rid_to_address:
             server_addr = self.rid_to_address[req.rid]
         else:

@@ -286,10 +277,19 @@ class RemoteSGLangEngine(InferenceEngine):
             self.rid_to_address[req.rid] = server_addr
             self.rid_queue.append(req.rid)
 
+        # Deal with rollout interruption.
+        # "abort" is the stop reason for sglang v0.4.9.post2 or later after
+        # we call the pause_generation endpoint.
+        stop_reason = None
         while (
             stop_reason != "stop"
             and len(accumulated_output_tokens) < gconfig.max_new_tokens
         ):
+            # Request is interrupted, wait for some time to avoid interfering
+            # with update weights requests
+            if stop_reason is not None:
+                await asyncio.sleep(0.5)
+
             # loop until the generation is complete
             result = await arequest_with_retry(
                 session=self.session,

@@ -301,8 +301,17 @@ class RemoteSGLangEngine(InferenceEngine):
                 timeout=self.config.request_timeout,
             )
 
-            # Parse response
             meta_info = result["meta_info"]
+            # Check if generation is complete
+            finish_reason = meta_info["finish_reason"]
+            stop_reason = finish_reason["type"]
+            if (
+                stop_reason == "abort"
+                and finish_reason.get("message") == "Abort before prefill"
+            ):
+                continue
+
+            # Parse response
             output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
             output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
 

@@ -312,11 +321,7 @@ class RemoteSGLangEngine(InferenceEngine):
             # FIXME: Update with actual server versions
             accumulated_versions.extend([-1] * len(output_tokens))
 
-            # Check if generation is complete
-            finish_reason = meta_info["finish_reason"]
-            stop_reason = finish_reason["type"]
-
-            payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
+            payload["input_ids"] += result["output_ids"]
             sample_params["max_new_tokens"] -= len(output_tokens)
 
             latency = time.perf_counter() - start_time

@@ -344,30 +349,24 @@ class RemoteSGLangEngine(InferenceEngine):
         )
         return response
 
-    def update_weights(self, meta):
-        executor = ThreadPoolExecutor(max_workers=1)
-        return executor.submit(self._update_weights, meta)
-
-    def _update_weights(self, meta: WeightUpdateMeta):
+    def update_weights(self, meta: WeightUpdateMeta):
+        for addr in self.addresses:
+            res = requests.post(f"http://{addr}/pause_generation")
+            res.raise_for_status()
+        fut = Future()
         if meta.type == "nccl":
-            if not self.distributed_weight_update_initialized:
-                self._init_distributed_weight_update(meta)
-            tik = time.perf_counter()
-            try:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-                for param in meta.parameter_names:
-                    jobs = [
-                        self.aupdate_weights_from_distributed(addr, meta, param)
-                        for addr in self.addresses
-                    ]
-                    loop.run_until_complete(asyncio.gather(*jobs))
-            finally:
-                loop.close()
-            logger.info(
-                f"Distributed update weights done in {time.perf_counter() - tik}s"
-            )
-            self.set_version(meta.model_version)
+            fut = self.executor.submit(
+                update_weights_from_distributed,
+                meta,
+                self.addresses,
+                self.config.request_timeout,
+                not self.distributed_weight_update_initialized,
+            )
+
+            def callback(fut):
+                self.distributed_weight_update_initialized = True
+
+            fut.add_done_callback(callback)
         elif meta.type == "disk":
             # Update weights from disk
             # Use ProcessPool to bypass python GIL for running async coroutines
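
Note: update_weights now brackets the transfer with SGLang's pause_generation
and continue_generation endpoints. A self-contained sketch of that protocol
under the same assumptions (servers exposing those endpoints; "transfer" stands
in for the NCCL broadcast or the disk load):

    import requests
    from concurrent.futures import Future, ThreadPoolExecutor
    from typing import Callable, List

    def paused_weight_update(
        addresses: List[str],
        transfer: Callable[[], None],
        executor: ThreadPoolExecutor,
    ) -> Future:
        # 1. Stop rollout; in-flight requests finish with stop_reason == "abort".
        for addr in addresses:
            requests.post(f"http://{addr}/pause_generation").raise_for_status()
        # 2. Run the actual weight transfer off-thread; callers wait on the Future.
        fut = executor.submit(transfer)

        # 3. Resume rollout only after the transfer has completed.
        def resume(_: Future) -> None:
            for addr in addresses:
                requests.post(f"http://{addr}/continue_generation").raise_for_status()

        fut.add_done_callback(resume)
        return fut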

@@ -375,7 +374,7 @@ class RemoteSGLangEngine(InferenceEngine):
                 update_weights_from_disk,
                 self.config.experiment_name,
                 self.config.trial_name,
-                meta.model_version,
+                self.get_version(),
                 self.addresses,
                 meta.path,
                 self.config.request_retries,

@@ -383,64 +382,19 @@ class RemoteSGLangEngine(InferenceEngine):
             )
 
             def callback(fut):
-                self.set_version(meta.model_version)
+                shutil.rmtree(meta.path, ignore_errors=True)
 
             fut.add_done_callback(callback)
-            return fut
         else:
             raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
 
-    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
-        try:
-            # Initialize weights update group
-            jobs = [
-                self.ainit_weights_update_group(addr, meta) for addr in self.addresses
-            ]
-            loop = asyncio.new_event_loop()
-            # asyncio event loop should be manually set when running asyncio stuff in another thread
-            asyncio.set_event_loop(loop)
-            loop.run_until_complete(asyncio.gather(*jobs))
-            self.distributed_weight_update_initialized = True
-            logger.info(f"Distributed update weights initialized")
-        finally:
-            loop.close()
-
-    async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
-        rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
-        payload = {
-            "master_address": meta.master_address,
-            "master_port": str(meta.master_port),
-            "rank_offset": rank_offset,
-            "world_size": meta.world_size,
-            "group_name": meta.group_name,
-            "backend": "nccl",
-        }
-        res = await arequest_with_retry(
-            addr=addr,
-            endpoint="/init_weights_update_group",
-            payload=payload,
-            method="POST",
-            max_retries=1,
-            timeout=self.config.request_timeout,
-        )
-        assert res["success"]
-
-    async def aupdate_weights_from_distributed(
-        self, addr: str, meta: WeightUpdateMeta, parameter_name: str
-    ):
-        res = await arequest_with_retry(
-            addr=addr,
-            endpoint="/update_weights_from_distributed",
-            payload={
-                "name": parameter_name,
-                "dtype": "bfloat16",
-                "shape": meta.state_dict_key_to_shape[parameter_name],
-            },
-            method="POST",
-            max_retries=1,
-            timeout=self.config.request_timeout,
-        )
-        assert res["success"]
+        def callback(fut):
+            for addr in self.addresses:
+                res = requests.post(f"http://{addr}/continue_generation")
+                res.raise_for_status()
+
+        fut.add_done_callback(callback)
+        return fut
 
     def get_capacity(self):
         if dist.is_initialized():

@@ -546,27 +500,6 @@ class RemoteSGLangEngine(InferenceEngine):
         self.paused.clear()
 
 
-async def aupdate_weights_from_disk(
-    session, addr, path: str, request_retries: int, request_timeout: float
-):
-    tik = time.time()
-    res = await arequest_with_retry(
-        addr=addr,
-        session=session,
-        endpoint="/update_weights_from_disk",
-        payload=dict(model_path=str(path), allow_interrupt=True),
-        method="POST",
-        max_retries=request_retries,
-        timeout=request_timeout,
-    )
-    assert res["success"]
-    if "num_paused_requests" in res:
-        logger.info(
-            f"{res['num_paused_requests']} requests are interrupted "
-            f"during updating weights for server {addr}"
-        )
-
-
 def update_weights_from_disk(
     experiment_name,
     trial_name,

@@ -596,12 +529,14 @@ def update_weights_from_disk(
         connector=get_default_connector(),
     )
     jobs = [
-        aupdate_weights_from_disk(
-            session=session,
+        arequest_with_retry(
             addr=addr,
-            path=path,
-            request_retries=request_retries,
-            request_timeout=request_timeout,
+            session=session,
+            endpoint="/update_weights_from_disk",
+            payload=dict(model_path=str(path)),
+            method="POST",
+            max_retries=request_retries,
+            timeout=request_timeout,
         )
         for addr in addresses
     ]

@@ -612,3 +547,72 @@ def update_weights_from_disk(
     )
 
     return uvloop.run(_fn())
+
+
+def update_weights_from_distributed(
+    meta: WeightUpdateMeta,
+    addresses: List[str],
+    request_timeout,
+    init_group: bool,
+):
+    async def _fn():
+        tik = time.perf_counter()
+        if init_group:
+            await asyncio.gather(
+                *[
+                    ainit_weights_update_group(addr, i, meta, request_timeout)
+                    for i, addr in enumerate(addresses)
+                ]
+            )
+        await asyncio.gather(
+            *[
+                arequest_with_retry(
+                    addr=addr,
+                    endpoint="/update_weights_from_distributed",
+                    payload={
+                        "names": [pspec.name for pspec in meta.nccl_param_specs],
+                        "dtypes": [pspec.dtype for pspec in meta.nccl_param_specs],
+                        "shapes": [pspec.shape for pspec in meta.nccl_param_specs],
+                        "group_name": meta.nccl_group_name,
+                    },
+                    method="POST",
+                    max_retries=1,
+                    timeout=request_timeout,
+                )
+                for addr in addresses
+            ]
+        )
+        logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
+
+    return uvloop.run(_fn())
+
+
+async def ainit_weights_update_group(
+    addr: str,
+    server_idx: int,
+    meta: WeightUpdateMeta,
+    request_timeout: float,
+):
+    assert meta.alloc_mode is not None
+    if meta.alloc_mode.gen_pp_size != 1:
+        raise NotImplementedError(
+            "NCCL weight update with PP size > 1 is not implemented yet."
+        )
+    rank_offset = 1 + server_idx * meta.alloc_mode.gen_tp_size
+    payload = {
+        "master_address": meta.nccl_master_address,
+        "master_port": str(meta.nccl_master_port),
+        "rank_offset": rank_offset,
+        "world_size": meta.alloc_mode.gen_world_size + 1,
+        "backend": "nccl",
+        "group_name": meta.nccl_group_name,
+    }
+    res = await arequest_with_retry(
+        addr=addr,
+        endpoint="/init_weights_update_group",
+        payload=payload,
+        method="POST",
+        max_retries=1,
+        timeout=request_timeout,
+    )
+    assert res["success"]
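
Note: ainit_weights_update_group places the trainer at rank 0 and gives each
SGLang server a contiguous block of tensor-parallel ranks. A worked example of
the arithmetic, assuming two servers with gen_tp_size=2 (addresses are
hypothetical):

    gen_tp_size = 2
    addresses = ["10.0.0.1:30000", "10.0.0.2:30000"]
    for server_idx, addr in enumerate(addresses):
        rank_offset = 1 + server_idx * gen_tp_size
        print(addr, list(range(rank_offset, rank_offset + gen_tp_size)))
    # 10.0.0.1:30000 [1, 2]
    # 10.0.0.2:30000 [3, 4]
    # world_size = gen_world_size + 1 = 5; rank 0 is the FSDP trainer.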

@@ -283,7 +283,7 @@ def main_local():
     if not cfg.server_only:
         launcher.submit(
             job_name="trainer",
-            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
+            cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}",
             gpu=alloc_mode.train_world_size,
             env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
         )

@@ -3,10 +3,8 @@ import subprocess
 import sys
 import time
 import uuid
-from pathlib import Path
 from typing import Optional
 
-import ray
 import requests
 
 from arealite.api.cli_args import (
|
||||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||||
from arealite.utils.launcher import TRITON_CACHE_PATH
|
from arealite.utils.launcher import TRITON_CACHE_PATH
|
||||||
from arealite.utils.network import find_free_ports, gethostip
|
from arealite.utils.network import find_free_ports, gethostip
|
||||||
from realhf.base import logging, name_resolve, names, pkg_version
|
from realhf.base import logging, name_resolve, names
|
||||||
|
|
||||||
logger = logging.getLogger("SGLangServer Wrapper")
|
logger = logging.getLogger("SGLangServer Wrapper")
|
||||||
|
|
||||||
|
|
||||||
def execute_shell_command(command: str) -> subprocess.Popen:
|
def launch_server_cmd(command: str) -> subprocess.Popen:
|
||||||
"""
|
"""
|
||||||
Execute a shell command and return its process handle.
|
Execute a shell command and return its process handle.
|
||||||
"""
|
"""
|
||||||
|

@@ -45,46 +43,6 @@ def execute_shell_command(command: str) -> subprocess.Popen:
     )
 
 
-def apply_sglang_patch():
-    p = Path(os.path.dirname(__file__))
-    patch_path = str(
-        p.parent.parent
-        / "patch"
-        / "sglang"
-        / f"v{pkg_version.get_version('sglang')}.patch"
-    )
-
-    target_path = ""
-    sglang_meta = subprocess.check_output(
-        "python3 -m pip show sglang", shell=True
-    ).decode("ascii")
-    for line in sglang_meta.split("\n"):
-        line = line.strip()
-        if line.startswith("Editable project location: "):
-            target_path = str(Path(line.split(": ")[1]).parent)
-
-    if target_path:
-        proc = subprocess.Popen(
-            ["git", "apply", patch_path],
-            cwd=target_path,
-            stderr=sys.stdout,
-            stdout=sys.stdout,
-        )
-        proc.wait()
-        logger.info(f"Applied SGLang patch at {target_path}")
-
-
-def launch_server_cmd(command: str):
-    """
-    Launch the server using the given command.
-    If no port is specified, a free port is reserved.
-    """
-    if not ray.is_initialized():
-        apply_sglang_patch()
-    process = execute_shell_command(command)
-    return process
-
-
 def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
     """Wait for the server to be ready by polling the /v1/models endpoint.

@@ -12,10 +12,9 @@ from arealite.api.cli_args import (
     SGLangConfig,
     TrainEngineConfig,
 )
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
 from arealite.engine.fsdp_engine import FSDPEngine
 from arealite.engine.sglang_remote import RemoteSGLangEngine
-from arealite.utils.network import find_free_ports
 from realhf.base import network
 
 EXPR_NAME = "test_fsdp_engine_nccl"

@@ -33,7 +32,7 @@ RUN_SERVER_TIMEOUT = 180
 
 def check_server_health(base_url):
     try:
-        response = requests.get(f"{base_url}/metrics", timeout=30)
+        response = requests.get(f"{base_url}/health", timeout=30)
         return response.status_code == 200
     except requests.exceptions.RequestException:
         return False

@@ -82,14 +81,14 @@ def sglang_server_nccl():
 
 
 def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
-    # 设置分布式环境变量
+    # Set environment variables for torch distributed
     os.environ["WORLD_SIZE"] = "1"
     os.environ["RANK"] = "0"
     os.environ["LOCAL_RANK"] = "0"
     os.environ["MASTER_ADDR"] = HOST
     os.environ["MASTER_PORT"] = str(MASTER_PORT)
 
-    # 启动本地FSDPEngine
+    # Initialize FSDPEngine
     engine_config = TrainEngineConfig(
         experiment_name=EXPR_NAME,
         trial_name=TRIAL_NAME,

@@ -100,38 +99,26 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server
     ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
     engine.initialize(None, ft_spec)
 
-    # 启动远端RemoteSGLangEngine
+    # Initialize RemoteSGLangEngine
     config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
     config.server_addrs = [f"{HOST}:{PORT}"]
     remote_engine = RemoteSGLangEngine(config)
     remote_engine.initialize(None, None)
 
-    # 构造WeightUpdateMeta(type=nccl)
-    param_meta = engine.get_param_meta_for_distributed_update()
-    meta = WeightUpdateMeta(
-        type="nccl",
-        path=None,
-        alloc_mode=None,
-        comm_backend="nccl",
-        model_version=123,
-        tp_size=1,
-        master_address="localhost",
-        master_port=find_free_ports(1)[0],
-        world_size=2,
-        group_name=GROUP_NAME,
-        parameter_names=list(param_meta.keys()),
-        state_dict_key_to_shape=param_meta,
+    # Get WeightUpdateMeta
+    meta = WeightUpdateMeta.from_fsdp_nccl(
+        AllocationMode.from_str("sglang.d1p1t1+d1p1t1"),
+        engine,
+        nccl_group_name=GROUP_NAME,
     )
 
-    # 本地engine广播参数
+    # Broadcast weights
     future = remote_engine.update_weights(meta)
     print("got future", flush=True)
     engine.upload_weights(meta)
     print("uploaded weights to remote engine", flush=True)
-    # 远端engine拉取参数
+    # Wait for remote engine to finish
     future.result(timeout=120)
     print("got result", flush=True)
-    # 检查远端参数版本
-    assert remote_engine.get_version() == 123
     remote_engine.destroy()
     engine.destroy()

@@ -2,7 +2,6 @@ import os
 import subprocess
 import sys
 import time
-import uuid
 
 import pytest
 import requests

@@ -14,7 +13,7 @@ from arealite.api.cli_args import (
     InferenceEngineConfig,
     SGLangConfig,
 )
-from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
+from arealite.api.io_struct import WeightUpdateMeta
 from arealite.utils import network
 from realhf.api.core.data_api import load_hf_tokenizer

@@ -31,10 +30,7 @@ RUN_SERVER_TIMEOUT = 180
 
 def check_server_health(base_url):
     try:
-        response = requests.get(
-            f"{base_url}/metrics",
-            timeout=30,
-        )
+        response = requests.get(f"{base_url}/health", timeout=30)
         return response.status_code == 200
     except requests.exceptions.RequestException as e:
         return False

@@ -77,29 +73,6 @@ def sglang_server():
     process.terminate()
 
 
-@pytest.mark.asyncio
-async def test_remote_sglang_generate(sglang_server):
-    from arealite.engine.sglang_remote import RemoteSGLangEngine
-
-    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
-    tokenizer = load_hf_tokenizer(MODEL_PATH)
-    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
-    engine = RemoteSGLangEngine(config)
-    req = LLMRequest(
-        rid=str(uuid.uuid4()),
-        input_ids=tokenizer.encode("hello! how are you today"),
-        gconfig=GenerationHyperparameters(max_new_tokens=16),
-    )
-    resp = await engine.agenerate(req)
-    assert isinstance(resp, LLMResponse)
-    assert resp.input_tokens == req.input_ids
-    assert (
-        len(resp.output_logprobs)
-        == len(resp.output_tokens)
-        == len(resp.output_versions)
-    )
-
-
 @pytest.mark.parametrize("n_samples", [1, 2, 4])
 def test_remote_sglang_rollout(sglang_server, n_samples):
     from arealite.engine.sglang_remote import RemoteSGLangEngine

@@ -211,6 +184,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
     engine = FSDPEngine(engine_config)
     ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
     engine.initialize(None, ft_spec)
+    engine.model_version = 100
 
     # setup name resolve
     import realhf.base.name_resolve as name_resolve

@@ -227,13 +201,11 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
     os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
     inf_engine = RemoteSGLangEngine(config)
     inf_engine.initialize(None, None)
+    inf_engine.set_version(100)
     # test update weights
     path = tmp_path_factory.mktemp("upload_weights_from_disk")
-    update_weight_meta = WeightUpdateMeta(
-        type="disk", path=path, alloc_mode=None, comm_backend=None, model_version=100
-    )
+    update_weight_meta = WeightUpdateMeta(type="disk", path=str(path))
     future = inf_engine.update_weights(update_weight_meta)
     engine.upload_weights(update_weight_meta)
     future.result()
-    assert inf_engine.get_version() == 100
     inf_engine.destroy()
@ -1,310 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
import colorama
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from datasets import load_dataset
|
|
||||||
from datasets.distributed import split_dataset_by_node
|
|
||||||
from tensordict import TensorDict
|
|
||||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
|
||||||
from transformers import PreTrainedTokenizerFast
|
|
||||||
|
|
||||||
from arealite.api.cli_args import (
|
|
||||||
GenerationHyperparameters,
|
|
||||||
GRPOConfig,
|
|
||||||
load_expr_config,
|
|
||||||
)
|
|
||||||
from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
|
|
||||||
from arealite.api.workflow_api import RolloutWorkflow
|
|
||||||
from arealite.engine.ppo.actor import FSDPPPOActor
|
|
||||||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
|
||||||
from arealite.utils.data import concat_padded_tensors
|
|
||||||
from arealite.utils.device import log_gpu_stats
|
|
||||||
from arealite.utils.saver import Saver
|
|
||||||
from arealite.utils.stats_logger import StatsLogger
|
|
||||||
from realhf.api.core.data_api import load_hf_tokenizer
|
|
||||||
from realhf.base import logging, seeding, stats_tracker
|
|
||||||
|
|
||||||
logger = logging.getLogger("boba math")
|
|
||||||
|
|
||||||
|
|
||||||
class RLVRWorkflow(RolloutWorkflow):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
reward_fn,
|
|
||||||
gconfig: GenerationHyperparameters,
|
|
||||||
tokenizer: PreTrainedTokenizerFast,
|
|
||||||
dump_dir: str | None = None,
|
|
||||||
):
|
|
||||||
self.reward_fn = reward_fn
|
|
||||||
self.gconfig = gconfig
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.dump_dir = dump_dir
|
|
||||||
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
|
|
||||||
os.makedirs(self.dump_dir, exist_ok=True)
|
|
||||||
|
|
||||||
async def arun_episode(self, engine, data):
|
|
||||||
input_ids = self.tokenizer.encode(data["prompt"])
|
|
||||||
n_samples = self.gconfig.n_samples
|
|
||||||
req = LLMRequest(
|
|
||||||
rid=uuid.uuid4().hex,
|
|
||||||
input_ids=input_ids,
|
|
||||||
gconfig=self.gconfig.new(n_samples=1),
|
|
||||||
)
|
|
||||||
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
|
|
||||||
|
|
||||||
version = engine.get_version()
|
|
||||||
prompt_strs = []
|
|
||||||
completions_strs = []
|
|
||||||
rewards = []
|
|
||||||
seqlens = []
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for resp in resps:
|
|
||||||
seq = resp.input_tokens + resp.output_tokens
|
|
||||||
logprobs = [0.0] * resp.input_len + resp.output_logprobs
|
|
||||||
loss_mask = [0] * resp.input_len + [1] * resp.output_len
|
|
||||||
versions = [-1] * resp.input_len + resp.output_versions
|
|
||||||
|
|
||||||
prompt_str = data["prompt"]
|
|
||||||
completions_str = self.tokenizer.decode(resp.output_tokens)
|
|
||||||
prompt_strs.append(prompt_str)
|
|
||||||
completions_strs.append(completions_str)
|
|
||||||
seqlens.append(len(seq))
|
|
||||||
reward = self.reward_fn(
|
|
||||||
completions=completions_str,
|
|
||||||
prompt_ids=resp.input_tokens,
|
|
||||||
completion_ids=resp.output_tokens,
|
|
||||||
**data,
|
|
||||||
)
|
|
||||||
rewards.append(reward)
|
|
||||||
res = dict(
|
|
||||||
# unsqueeze to add an additional batch dimension
|
|
||||||
input_ids=torch.tensor(seq).unsqueeze(0),
|
|
||||||
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
|
|
||||||
logprobs=torch.tensor(logprobs).unsqueeze(0),
|
|
||||||
versions=torch.tensor(versions).unsqueeze(0),
|
|
||||||
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
|
|
||||||
# reward
|
|
||||||
rewards=torch.tensor([float(reward)]),
|
|
||||||
)
|
|
||||||
results.append(TensorDict(res, batch_size=[1]))
|
|
||||||
|
|
||||||
if self.dump_dir is not None:
|
|
||||||
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
|
|
||||||
# Get the unique identifier for this prompt
|
|
||||||
qid = None
|
|
||||||
for key in ["query_id", "id", "qid"]:
|
|
||||||
qid = data.get(key, None)
|
|
||||||
if qid is not None:
|
|
||||||
break
|
|
||||||
qid = qid or uuid.uuid4().hex
|
|
||||||
|
|
||||||
# Dump rollout to file
|
|
||||||
with open(
|
|
||||||
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
|
|
||||||
) as f:
|
|
||||||
n_samples = self.gconfig.n_samples
|
|
||||||
for i, (p, c, r, sl) in enumerate(
|
|
||||||
zip(prompt_strs, completions_strs, rewards, seqlens)
|
|
||||||
):
|
|
||||||
info = "\n".join(
|
|
||||||
[
|
|
||||||
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
|
|
||||||
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
|
|
||||||
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
f.write(info + "\n")
|
|
||||||
|
|
||||||
return concat_padded_tensors(results)
|
|
||||||
|
|
||||||
|
|
||||||
def get_boba_math_dataset(tokenizer, rank, world_size):
|
|
||||||
dataset = load_dataset(
|
|
||||||
path="json",
|
|
||||||
split="train",
|
|
||||||
data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
|
|
||||||
)
|
|
||||||
dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
|
|
||||||
return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
def boba_reward_fn(
|
|
||||||
prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
|
|
||||||
):
|
|
||||||
from pebble import ProcessExpired, ProcessPool
|
|
||||||
|
|
||||||
from realhf.impl.dataset.math_parser import process_results
|
|
||||||
|
|
||||||
jobs = []
|
|
||||||
with ProcessPool(max_workers=1) as executor:
|
|
||||||
for sol in solutions:
|
|
||||||
job = executor.schedule(
|
|
||||||
process_results, args=[completions, sol], timeout=15
|
|
||||||
)
|
|
||||||
jobs.append(job)
|
|
||||||
|
|
||||||
label = 0
|
|
||||||
for job in jobs:
|
|
||||||
try:
|
|
||||||
x = job.result()
|
|
||||||
except TimeoutError:
|
|
||||||
# print("[debug: timeout]")
|
|
||||||
logger.warning(f"Timeout occurred while justifying the math answer.")
|
|
||||||
x = (0, "timeout", "timeout")
|
|
||||||
except ProcessExpired as e:
|
|
||||||
logger.warning(f"Process terminated abnormally: {e}")
|
|
||||||
x = (0, "error", "error")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
|
|
||||||
x = (0, "error", "error")
|
|
||||||
label = label or x[0]
|
|
||||||
return label
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
config, _ = load_expr_config(args, GRPOConfig)
|
|
||||||
config: GRPOConfig
|
|
||||||
|
|
||||||
rank = int(os.getenv("RANK"))
|
|
||||||
world_size = int(os.getenv("WORLD_SIZE"))
|
|
||||||
tokenizer = load_hf_tokenizer(config.tokenizer_path)
|
|
||||||
|
|
||||||
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
|
|
||||||
|
|
||||||
# Create dataset and dataloaders
|
|
||||||
train_dataloader = StatefulDataLoader(
|
|
||||||
get_boba_math_dataset(tokenizer, rank, world_size),
|
|
||||||
batch_size=config.train_dataset.batch_size // world_size,
|
|
||||||
shuffle=config.train_dataset.shuffle,
|
|
||||||
num_workers=config.train_dataset.num_workers,
|
|
||||||
collate_fn=lambda x: x,
|
|
||||||
drop_last=config.train_dataset.drop_last,
|
|
||||||
)
|
|
||||||
ft_spec = FinetuneSpec(
|
|
||||||
total_train_epochs=config.total_train_epochs,
|
|
||||||
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
|
|
||||||
train_batch_size=config.train_dataset.batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(None, ft_spec)

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.initialize(None, ft_spec)
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.initialize(None, ft_spec)

    # Create rollout workflow
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=boba_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated"
        ),
    )

    # Run training.
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch

    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
    data_generator = iter(train_dataloader)
    for global_step in range(max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch

        with stats_tracker.record_timing("rollout"):
            if config.async_training:
                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
            else:
                try:
                    data = next(data_generator)
                except StopIteration:
                    data_generator = iter(train_dataloader)
                    data = next(data_generator)
                batch = rollout.rollout_batch(data, workflow=workflow)

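        # (Illustrative note, not part of the original script.) In async mode
        # the engine keeps generating in the background and `prepare_batch`
        # only waits until enough finished rollouts are queued; in sync mode
        # each batch is generated on demand, so training stalls on generation.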
        batch = batch.to(actor.device)
        # Create barrier to synchronize all rollout processes.
        dist.barrier()
        torch.cuda.synchronize()

        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
                batch["prox_logp"] = logp
                log_gpu_stats("recompute logp")

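        # (Illustrative note, not part of the original script.) `prox_logp`
        # holds the trainer-side recomputed log-probabilities; the decoupled
        # PPO loss uses them as the proximal policy in place of the possibly
        # stale logprobs returned by the inference server.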
        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)
                log_gpu_stats("ref logp")

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)
            log_gpu_stats("compute advantages")

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()
            log_gpu_stats("ppo update")

        with stats_tracker.record_timing("update_weights"):
            path = os.path.join(
                Saver.get_save_checkpoint_root(config.saver),
                "update_weights",
                str(global_step + 1),
            )
            meta = WeightUpdateMeta(
                type="disk",
                path=path,
                alloc_mode=None,
                comm_backend=None,
                model_version=global_step + 1,
            )
            if dist.get_rank() == 0:
                future = rollout.update_weights(meta)
            actor.upload_weights(meta)
            if dist.get_rank() == 0:
                future.result()
                shutil.rmtree(path, ignore_errors=True)
            dist.barrier()
            torch.cuda.synchronize()
            rollout.set_version(global_step + 1)

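        # (Illustrative note, not part of the original script.) Disk-based
        # weight sync: the trainer writes a checkpoint, the SGLang servers
        # reload it, and rank 0 deletes the checkpoint once the servers
        # acknowledge; bumping the version lets rollouts track policy staleness.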
with stats_tracker.record_timing("save"):
|
|
||||||
saver.save(actor, epoch, step, global_step)
|
|
||||||
|
|
||||||
logger.commit(epoch, step, global_step, stats)
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
rollout.destroy()
|
|
||||||
if ref is not None:
|
|
||||||
ref.destroy()
|
|
||||||
actor.destroy()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main(sys.argv[1:])
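A minimal launch sketch (hypothetical file and config names; the script itself only
assumes torchrun-style RANK/WORLD_SIZE environment variables):

    torchrun --nnodes=4 --nproc-per-node=8 boba_math_grpo.py --config lite-boba-math.yaml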
@@ -9,7 +9,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
 from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
 from arealite.api.cli_args import GRPOConfig, load_expr_config
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
 from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -117,6 +117,18 @@ def main(args):
         ref = FSDPPPOActor(config=config.ref)
         ref.initialize(None, ft_spec)
 
+    # NOTE: Weight update meta only requires address and free port of rank 0,
+    # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
+    # due to `engine.get_param_specs()`.
+    # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
+    weight_update_meta = [
+        WeightUpdateMeta.from_fsdp_nccl(
+            AllocationMode.from_str(config.allocation_mode), actor
+        )
+    ]
+    dist.broadcast_object_list(weight_update_meta, src=0)
+    weight_update_meta = weight_update_meta[0]
+
     # Create rollout workflow
     if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
         config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
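A minimal sketch of the broadcast idiom introduced above (assumes an initialized
torch.distributed process group; `local_value` is a stand-in for the per-rank meta):

    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already run.
    local_value = object()                   # this rank's candidate object
    objs = [local_value]                     # every rank builds its own copy
    dist.broadcast_object_list(objs, src=0)  # rank 0's object overwrites the rest
    meta = objs[0]                           # all ranks now hold rank 0's copy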
@@ -189,24 +201,16 @@ def main(args):
             log_gpu_stats("ppo update")
 
         with stats_tracker.record_timing("update_weights"):
-            meta = WeightUpdateMeta(
-                type="disk",
-                path=os.path.join(
-                    Saver.get_save_checkpoint_root(config.saver),
-                    "update_weights",
-                    str(global_step),
-                ),
-                alloc_mode=None,
-                comm_backend=None,
-                model_version=global_step + 1,
-            )
+            rollout.pause()
             if dist.get_rank() == 0:
-                future = rollout.update_weights(meta)
-            actor.upload_weights(meta)
+                future = rollout.update_weights(weight_update_meta)
+            actor.upload_weights(weight_update_meta)
             if dist.get_rank() == 0:
                 future.result()
-            dist.barrier()
+            dist.barrier(device_ids=[actor.device.index])
             torch.cuda.synchronize()
+            rollout.resume()
+            actor.set_version(global_step + 1)
             rollout.set_version(global_step + 1)
 
         with stats_tracker.record_timing("save"):
@@ -1,141 +0,0 @@
experiment_name: lite-boba-math
trial_name: run1

cluster:
  n_nodes: 16
  n_gpus_per_node: 8
  cluster_name: na132
  fileroot: /storage/openpsi/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
    etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379

seed: 1
total_train_epochs: 10
total_train_steps: null
tokenizer_path: ${actor.path}
allocation_mode: sglang.d96p1t1+d32p1t1
async_training: true

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 400
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 4
  enable_rollout_tracing: true

gconfig:
  n_samples: 16
  min_new_tokens: 0
  max_new_tokens: 30720
  greedy: false
  temperature: 1.0

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 1e-5
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  backend: fsdp

  group_size: ${gconfig.n_samples}
  group_adv_norm: false
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 4
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer: null
  backend: fsdp

# SGLang
server_only: false
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.9

# datasets
train_dataset:
  batch_size: 512
  shuffle: true
  pin_memory: true

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: online

# Launcher
launcher:
  inference_server_cpus_per_gpu: 15
  inference_server_mem_per_gpu: 153600
  trainer_cpus_per_gpu: 15
  trainer_mem_per_gpu: 153600
  slurm:
    mount: /storage:/storage
    trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
    inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
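The allocation string above encodes the GPU split (reading d/p/t as data-, pipeline-,
and tensor-parallel degrees, an assumption based on the d{dp}p{pp}t{tp} naming):
"sglang.d96p1t1+d32p1t1" dedicates 96 GPUs to SGLang inference and 32 to FSDP
training, 128 in total, matching the 16 nodes x 8 GPUs cluster section.

    # Illustrative check of the arithmetic (hypothetical helper, not repo code):
    import re

    def gpus(spec: str) -> int:
        d, p, t = map(int, re.fullmatch(r"d(\d+)p(\d+)t(\d+)", spec).groups())
        return d * p * t

    assert gpus("d96p1t1") + gpus("d32p1t1") == 16 * 8  # 128 GPUs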
@@ -1,8 +1,9 @@
 experiment_name: gsm8k-grpo
 trial_name: trial0
 allocation_mode: sglang.d4p1t1+d4p1t1
 
 cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
   fileroot: /tmp/arealite/experiments
 n_nodes: 1
 n_gpus_per_node: 8
@@ -33,7 +34,7 @@ gconfig:
 actor:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
+  path: Qwen/Qwen2-1.5B-Instruct
   init_from_scratch: false
   disable_dropout: true
   gradient_checkpointing: false
@@ -14,7 +14,7 @@ tokenizer_path: ${model.path}
 model:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
+  path: Qwen/Qwen3-1.7B
   init_from_scratch: false
   gradient_checkpointing: false
   dtype: bfloat16
@@ -1,5 +1,4 @@
 import os
-import shutil
 import sys
 
 import torch
@@ -9,7 +8,7 @@ from datasets.distributed import split_dataset_by_node
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from arealite.api.cli_args import GRPOConfig, load_expr_config
-from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
+from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
 from arealite.dataset.__init__ import get_custom_dataset
 from arealite.engine.ppo.actor import FSDPPPOActor
 from arealite.engine.sglang_remote import RemoteSGLangEngine
@@ -95,6 +94,18 @@ def main(args):
         ref = FSDPPPOActor(config=config.ref)
         ref.initialize(None, ft_spec)
 
+    # NOTE: Weight update meta only requires address and free port of rank 0,
+    # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
+    # due to `engine.get_param_specs()`.
+    # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
+    weight_update_meta = [
+        WeightUpdateMeta.from_fsdp_nccl(
+            AllocationMode.from_str(config.allocation_mode), actor
+        )
+    ]
+    dist.broadcast_object_list(weight_update_meta, src=0)
+    weight_update_meta = weight_update_meta[0]
+
     # Create rollout workflow
     if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
         config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
@@ -138,7 +149,7 @@ def main(args):
 
         batch = batch.to(actor.device)
         # Create barrier to synchronize all rollout processes.
-        dist.barrier()
+        dist.barrier(device_ids=[actor.device.index])
         torch.cuda.synchronize()
 
         if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
@@ -165,26 +176,16 @@ def main(args):
             log_gpu_stats("ppo update")
 
         with stats_tracker.record_timing("update_weights"):
-            path = os.path.join(
-                Saver.get_save_checkpoint_root(config.saver),
-                "update_weights",
-                str(global_step + 1),
-            )
-            meta = WeightUpdateMeta(
-                type="disk",
-                path=path,
-                alloc_mode=None,
-                comm_backend=None,
-                model_version=global_step + 1,
-            )
+            rollout.pause()
             if dist.get_rank() == 0:
-                future = rollout.update_weights(meta)
-            actor.upload_weights(meta)
+                future = rollout.update_weights(weight_update_meta)
+            actor.upload_weights(weight_update_meta)
             if dist.get_rank() == 0:
                 future.result()
-                shutil.rmtree(path, ignore_errors=True)
-            dist.barrier()
+            dist.barrier(device_ids=[actor.device.index])
             torch.cuda.synchronize()
+            rollout.resume()
+            actor.set_version(global_step + 1)
             rollout.set_version(global_step + 1)
 
         with stats_tracker.record_timing("save"):
@@ -1,11 +0,0 @@
#!/bin/sh
AREAL_PATH=$PWD
cd /sglang
git apply $AREAL_PATH/patch/sglang/v0.4.6.post4.patch
cd $AREAL_PATH

# Package used for calculating math reward
pip install -e evaluation/latex2sympy

# Install AReaL
pip install -e .
@@ -1,9 +1,9 @@
 #!/bin/bash
 # basic dependencies
 pip install -U pip
-pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
+pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
-pip install nvidia-ml-py
+pip install pynvml nvidia-ml-py
 pip install -e evaluation/latex2sympy
 pip install vllm==0.8.5 --no-build-isolation
-pip install flash_attn --no-build-isolation
+pip install "flash-attn<=2.7.3" --no-build-isolation
 pip install -r evaluation/requirements.txt
@@ -1,24 +1,14 @@
 #!/bin/bash
 # basic dependencies
 pip install -U pip
-pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
+pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
-pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
+pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.2" pynvml
-pip install "sglang[all]==0.4.6.post4"
+pip install "sglang[all]==0.4.9.post2"
 pip install megatron-core==0.11.0 nvidia-ml-py
 pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
 pip install "flash-attn<=2.7.3" --no-build-isolation
 
 # Package used for calculating math reward
 pip install -e evaluation/latex2sympy
 
-# Install an editable sglang
-rm -rf ./sglang
-git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang
-AREAL_PATH=$PWD
-cd sglang
-git apply ../patch/sglang/v0.4.6.post4.patch
-pip install -e "python[all]" --no-deps
-cd $AREAL_PATH
-
 # Install AReaL
 pip install -e .
@@ -67,6 +67,7 @@ class InstallationValidator:
     def test_flash_attn_functionality(self, flash_attn_module):
         """Test flash attention functionality."""
         # Try to import key functions
+        import flash_attn_2_cuda
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         print(" - Flash attention functions imported successfully")

@@ -79,12 +80,12 @@ class InstallationValidator:
         """Test SGLang basic functionality."""
         # Basic import test is sufficient for CI
         import sgl_kernel
-        from sglang import launch_server
+        from sglang import Engine, launch_server
-        assert Version(get_version("sglang")) == Version("0.4.6.post4")
+        assert Version(get_version("sglang")) == Version("0.4.9.post2"), "SGLang version should be v0.4.9.post2"
         print(" - SGLang imported successfully")
 
     def test_transformers(self, transformers_module):
-        assert Version(get_version("transformers")) == Version("4.51.1")
+        assert Version(get_version("transformers")) == Version("4.53.1"), "transformers version should be 4.53.1"
         print(" - transformers imported successfully")
 
     def validate_critical_dependencies(self):

@@ -140,7 +141,7 @@ class InstallationValidator:
         self.test_import("flashattn_hopper", required=False)
 
         # Optional utilities
-        self.test_import("tensorboardx", required=False)
+        self.test_import("tensorboardX", required=False)
         self.test_import("swanlab", required=False)
         self.test_import("matplotlib", required=False)
         self.test_import("seaborn", required=False)
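The asserts above pin exact versions. A minimal standalone check in the same style
(assumes the `packaging` library is installed; `importlib.metadata` is standard):

    from importlib.metadata import version as get_version
    from packaging.version import Version

    assert Version(get_version("sglang")) == Version("0.4.9.post2")
    assert Version(get_version("transformers")) == Version("4.53.1")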
@@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 174656b2..33fe0a5f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
     success: bool
 
 
+@dataclass
+class InterruptAllReqInput:
+    pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+    num_interrupted_requests: int
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
+    allow_interrupt: bool = False
     # The format to load the weights
     load_format: Optional[str] = None
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8891115c..843a8a82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -419,6 +421,7 @@ class Scheduler(
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
+                (InterruptAllReqInput, self.interrupt_all_requests),
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1938,6 +1941,15 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
 
+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
+        for req in self.waiting_queue:
+            req.sampling_params.max_new_tokens = 0
+        for req in self.running_batch.reqs:
+            req.sampling_params.max_new_tokens = len(req.output_ids)
+        logger.info(f"Interrupt {num} requests.")
+        return InterruptAllReqOutput(num)
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 82709b09..bfab3ce7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -265,6 +267,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.interrupt_requests_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -294,6 +299,10 @@ class TokenizerManager:
             (
                 UpdateWeightFromDiskReqOutput,
                 self._handle_update_weights_from_disk_req_output,
             ),
+            (
+                InterruptAllReqOutput,
+                self.interrupt_requests_communicator.handle_recv,
+            ),
             (
                 InitWeightsUpdateGroupReqOutput,
                 self.init_weights_update_group_communicator.handle_recv,
@@ -767,6 +776,13 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
 
+        if obj.allow_interrupt:
+            num_interrupted_requests = await self.interrupt_all_requests(
+                InterruptAllReqInput()
+            )
+            # Set a break point to wait for the interrupt to finish
+            await asyncio.sleep(0.1)
+
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
@@ -776,7 +792,12 @@ class TokenizerManager:
         # Hold the lock if it is not async. This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            return await self._wait_for_model_update_from_disk(obj)
+            success, message, n_paused = (
+                await self._wait_for_model_update_from_disk(obj)
+            )
+            if obj.allow_interrupt:
+                return success, message, num_interrupted_requests
+            return success, message, n_paused
 
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
@@ -849,6 +870,18 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message
 
+    async def interrupt_all_requests(
+        self,
+        obj: InterruptAllReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        result = await self.interrupt_requests_communicator(obj)
+        if self.server_args.dp_size == 1:
+            return result[0].num_interrupted_requests
+        else:
+            return [r.num_interrupted_requests for r in result]
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
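How the deleted patch interrupted in-flight requests: rather than cancelling them,
the scheduler clamped each request's token budget so that it terminated at its next
step. A condensed restatement of the core logic from the patch above:

    def interrupt_all_requests(self, recv_req):
        # Queued requests will produce nothing; running ones stop where they are.
        for req in self.waiting_queue:
            req.sampling_params.max_new_tokens = 0
        for req in self.running_batch.reqs:
            req.sampling_params.max_new_tokens = len(req.output_ids)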
@@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
     success: bool
 
 
+@dataclass
+class InterruptAllReqInput:
+    pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+    num_interrupted_requests: int
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
+    allow_interrupt: bool = False
     # The format to load the weights
     load_format: Optional[str] = None
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
+                (InterruptAllReqInput, self.interrupt_all_requests),
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
 
+    def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+        num = len(self.waiting_queue) + len(self.running_batch.reqs)
+        for req in self.waiting_queue:
+            req.sampling_params.max_new_tokens = 0
+        for req in self.running_batch.reqs:
+            req.sampling_params.max_new_tokens = len(req.output_ids)
+        logger.info(f"Interrupt {num} requests.")
+        return InterruptAllReqOutput(num)
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    InterruptAllReqInput,
+    InterruptAllReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
         self.slow_down_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.interrupt_requests_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -309,6 +314,10 @@ class TokenizerManager:
             (
                 UpdateWeightFromDiskReqOutput,
                 self._handle_update_weights_from_disk_req_output,
             ),
+            (
+                InterruptAllReqOutput,
+                self.interrupt_requests_communicator.handle_recv,
+            ),
             (
                 InitWeightsUpdateGroupReqOutput,
                 self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
 
+        if obj.allow_interrupt:
+            num_interrupted_requests = await self.interrupt_all_requests(
+                InterruptAllReqInput()
+            )
+            # Set a break point to wait for the interrupt to finish
+            await asyncio.sleep(0.1)
+
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
         # Hold the lock if it is not async. This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            return await self._wait_for_model_update_from_disk(obj)
+            success, message, n_paused = (
+                await self._wait_for_model_update_from_disk(obj)
+            )
+            if obj.allow_interrupt:
+                return success, message, num_interrupted_requests
+            return success, message, n_paused
 
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message
 
+    async def interrupt_all_requests(
+        self,
+        obj: InterruptAllReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        result = await self.interrupt_requests_communicator(obj)
+        if self.server_args.dp_size == 1:
+            return result[0].num_interrupted_requests
+        else:
+            return [r.num_interrupted_requests for r in result]
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -34,7 +34,6 @@ dependencies = [
     "transformers==4.53.1",
 
     # Scientific computing
-    "numpy<2.0.0",
     "scipy",
     "pandas",
     "matplotlib",
@@ -40,42 +40,11 @@ def execute_shell_command(command: str) -> subprocess.Popen:
     )
 
 
-def apply_sglang_patch():
-    p = Path(os.path.dirname(__file__))
-    patch_path = str(
-        p.parent.parent
-        / "patch"
-        / "sglang"
-        / f"v{pkg_version.get_version('sglang')}.patch"
-    )
-
-    target_path = ""
-    sglang_meta = subprocess.check_output(
-        "python3 -m pip show sglang", shell=True
-    ).decode("ascii")
-    for line in sglang_meta.split("\n"):
-        line = line.strip()
-        if line.startswith("Editable project location: "):
-            target_path = str(Path(line.split(": ")[1]).parent)
-
-    if target_path:
-        proc = subprocess.Popen(
-            ["git", "apply", patch_path],
-            cwd=target_path,
-            stderr=sys.stdout,
-            stdout=sys.stdout,
-        )
-        proc.wait()
-        logger.info(f"Applied SGLang patch at {target_path}")
-
-
 def launch_server_cmd(command: str, port: int = 30000):
     """
     Launch the server using the given command.
     If no port is specified, a free port is reserved.
     """
-    if not ray.is_initialized():
-        apply_sglang_patch()
     assert port is not None
     full_command = f"{command} --port {port}"
     process = execute_shell_command(full_command)
@@ -155,39 +155,22 @@ class GserverManager(Worker):
 
         return None
 
-    async def flush_requests_and_update_weights(
-        self, server_url, new_param_path, update_weights_retries=5
-    ):
-        server_index = self.server_urls.index(server_url)
-        success = False
-        for _ in range(update_weights_retries):
-            async with aiohttp.ClientSession(
-                server_url,
-                timeout=aiohttp.ClientTimeout(
-                    total=self.config.flush_request_timeout,
-                    sock_connect=self.config.flush_request_timeout,
-                ),
-            ) as session:
-                async with session.post(
-                    f"/update_weights_from_disk",
-                    json=dict(model_path=new_param_path, allow_interrupt=True),
-                ) as resp:
-                    if resp.status == 200:
-                        res = await resp.json()
-                        success = res["success"]
-                        if success:
-                            if "num_paused_requests" in res:
-                                logger.info(
-                                    f"{res['num_paused_requests']} requests are interrupted "
-                                    f"during updating weights for server {server_index}: {server_url}"
-                                )
-                            return
-                        logger.warning(
-                            f"Update weights failed: {res['message']}. Retrying."
-                        )
-                    logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
-                time.sleep(0.1)
-        raise RuntimeError("Update weights failed.")
+    async def flush_requests_and_update_weights(self, server_url, new_param_path):
+        async with aiohttp.ClientSession(
+            server_url,
+            timeout=aiohttp.ClientTimeout(
+                total=self.config.flush_request_timeout,
+                sock_connect=self.config.flush_request_timeout,
+            ),
+        ) as session:
+            (await session.post("/pause_generation")).raise_for_status()
+            async with session.post(
+                f"/update_weights_from_disk",
+                json=dict(model_path=new_param_path),
+            ) as resp:
+                resp.raise_for_status()
+                assert (await resp.json())["success"]
+            (await session.post("/continue_generation")).raise_for_status()
 
     def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
         if not hasattr(self, "round_robin_idx"):
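With the SGLang patches gone, weight sync now relies on upstream server endpoints:
generation is paused, weights are reloaded from disk, and generation resumes. A
minimal standalone sketch of the same sequence (server address and checkpoint path
are assumptions; mirrors the new method above):

    import asyncio

    import aiohttp

    async def sync_weights(server_url: str, ckpt_path: str) -> None:
        async with aiohttp.ClientSession(server_url) as session:
            # Stop scheduling new tokens, swap weights, then resume serving.
            (await session.post("/pause_generation")).raise_for_status()
            async with session.post(
                "/update_weights_from_disk", json={"model_path": ckpt_path}
            ) as resp:
                resp.raise_for_status()
            (await session.post("/continue_generation")).raise_for_status()

    # asyncio.run(sync_weights("http://localhost:30000", "/path/to/checkpoint"))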
@@ -25,7 +25,6 @@ numba
 packaging
 pandas
 pybind11>=2.10.0
-numpy<2.0.0
 psutil
 pynvml
 pytest