[lite] [feature] Bump to SGLang v0.4.9.post2 and use NCCL to update weights (#196)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher

Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/370

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* .
* .
* .
* fix

* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.

Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/392

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* .
* fix
* .
* qwen2 grpo works
* fix
* fix
* async works
* fix
* slurm launcher not tested
* fix arg parse
* .
* sglang server wrapper
* .
* .
* slurm run
* ready for boba
* debug
* 32k run
* .
* .
* fix
* .
* .
* .
* .
* .
* fix
* .
* fix
* .
* .
* .
* .
* fix
* .
* .
* .
* .
* .
* .
* .
* refactor train engine
* refactor train engine
* .
* fix update weight error
* .
* .
* match train
* format
* .
* fix
* seems to work
* .
* .
* .
* .

* .

* PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2

Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/408

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* bump arealite to sglang 0.4.9.post2
* .
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`

* PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9

Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/422

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* bump arealite to sglang 0.4.9.post2
* .
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`
* .

* PullRequest: 423 [lite] Remove the boba example for github release.

Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/423

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Wei Fu 2025-07-24 15:34:52 +08:00 committed by GitHub
parent 6239633213
commit 311bcd7697
26 changed files with 313 additions and 1203 deletions

View File

@ -163,7 +163,6 @@ class WeightUpdateMeta:
type: str
path: str | None
alloc_mode: AllocationMode | None
comm_backend: str | None
@dataclass
class SaveLoadMeta:

View File

@ -285,85 +285,6 @@ class PPOActorConfig(TrainEngineConfig):
)
@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
@dataclass
class SGLangConfig:
"""Configuration for SGLang runtime. Refer to:

View File

@ -12,6 +12,7 @@ from arealite.api.io_struct import (
FinetuneSpec,
LLMRequest,
LLMResponse,
ParamSpec,
SaveLoadMeta,
WeightUpdateMeta,
)
@ -63,7 +64,19 @@ class TrainEngine(abc.ABC):
return self.train(False)
def upload_weights(self, meta: WeightUpdateMeta):
"""Upload weights to the inference engine."""
"""Upload weights to the inference engine (in a blocking manner)."""
raise NotImplementedError()
def get_param_specs(self) -> List[ParamSpec]:
"""Get the parameter specifications for the model."""
raise NotImplementedError()
def set_version(self, version: int):
"""Set the current weight version in the train engine."""
raise NotImplementedError()
def get_version(self) -> int:
"""Get the current weight version in the train engine."""
raise NotImplementedError()
def save(self, meta: SaveLoadMeta):
@ -122,14 +135,22 @@ class InferenceEngine(abc.ABC):
def destroy(self):
"""Destroy the engine and release GPU memory."""
def update_weights(self, meta: WeightUpdateMeta) -> Future:
"""Update weights in the inference engine."""
raise NotImplementedError()
async def agenerate(self, req: LLMRequest) -> LLMResponse:
"""Asynchronously generate a response for the given request."""
raise NotImplementedError()
def update_weights(self, meta: WeightUpdateMeta) -> Future:
"""Update weights in the inference engine in a non-blocking manner."""
raise NotImplementedError()
def set_version(self, version: int) -> None:
"""Set the current weight version in the inference engine."""
raise NotImplementedError()
def get_version(self) -> int:
"""Get the current weight version in the inference engine."""
raise NotImplementedError()
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
"""Asynchronously submit a request to the inference engine. Exits immediately."""
raise NotImplementedError()
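
For orientation, the call order these new methods imply for one weight update is sketched below (a hedged sketch only, assuming an InferenceEngine named `rollout` and a TrainEngine named `actor`; rank-0 gating, pause/resume, and barriers are omitted here and appear in the example scripts later in this commit):

# Hedged sketch of one weight-update round trip; `meta` is a WeightUpdateMeta
# describing either a "disk" or an "nccl" update, and `new_version` is the
# target weight version (e.g. global_step + 1 in the example scripts).
future = rollout.update_weights(meta)   # non-blocking, returns a Future
actor.upload_weights(meta)              # blocking: save to disk or broadcast via NCCL
future.result(timeout=120)              # wait until all inference servers finish
actor.set_version(new_version)          # keep both engines' weight versions in sync
rollout.set_version(new_version)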

View File

@ -2,14 +2,19 @@
# Licensed under the Apache License, Version 2.0
import enum
import itertools
import os
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
from arealite.utils.network import find_free_ports, gethostip
if TYPE_CHECKING:
from arealite.api.engine_api import TrainEngine
@dataclass
@ -155,20 +160,55 @@ class AllocationMode:
return other_alloc
@dataclass
class ParamSpec:
name: str
shape: Tuple
dtype: str
@dataclass
class WeightUpdateMeta:
type: str
path: str | None
alloc_mode: AllocationMode | None
comm_backend: str | None
model_version: int = 0
tp_size: int = 1
master_address: str = "127.0.0.1"
master_port: int = 29500
world_size: int = 1
group_name: str = "aupdate_weights_from_distributed"
parameter_names: List[str] = field(default_factory=list)
state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)
type: Literal["disk", "nccl"]
path: str | None = None
alloc_mode: AllocationMode | None = None
nccl_master_address: str = "127.0.0.1"
nccl_master_port: int = 29500
nccl_param_specs: List[ParamSpec] = field(default_factory=list)
nccl_group_name: str = "update_weight_group"
@classmethod
def from_disk(
cls,
saver_config: SaverConfig,
) -> "WeightUpdateMeta":
from arealite.utils.saver import Saver
path = os.path.join(
Saver.get_save_checkpoint_root(saver_config),
"weight_update",
)
return cls(
type="disk",
path=path,
)
@classmethod
def from_fsdp_nccl(
cls,
allocation_mode: AllocationMode,
fsdp_engine: "TrainEngine",
nccl_group_name: str = "update_weight_group",
):
return cls(
type="nccl",
alloc_mode=allocation_mode,
nccl_master_address=gethostip(),
nccl_master_port=find_free_ports(1)[0],
nccl_param_specs=fsdp_engine.get_param_specs(),
nccl_group_name=nccl_group_name,
)
@dataclass
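
As a usage note, the two constructors above can be invoked roughly as follows (a sketch only; `saver_config`, `config.allocation_mode`, and `actor` stand in for the experiment configuration and the FSDP train engine used elsewhere in this commit):

from arealite.api.io_struct import AllocationMode, WeightUpdateMeta

# Disk-based update: weights are written under the saver's checkpoint root
# and later loaded by the servers via /update_weights_from_disk.
disk_meta = WeightUpdateMeta.from_disk(saver_config)

# NCCL-based update: rank 0 of the trainer broadcasts every parameter listed
# in nccl_param_specs to the SGLang servers over a dedicated process group.
nccl_meta = WeightUpdateMeta.from_fsdp_nccl(
    AllocationMode.from_str(config.allocation_mode),
    actor,  # a TrainEngine that implements get_param_specs()
)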

View File

@ -46,6 +46,7 @@ class BaseHFEngine(TrainEngine):
self.tokenizer: PreTrainedTokenizerFast
# huggingface model config
self.model_config: PretrainedConfig
self._version: int = 0
# initialization
self.initialized = False
@ -55,6 +56,12 @@ class BaseHFEngine(TrainEngine):
self.world_size = int(os.environ["WORLD_SIZE"])
def set_version(self, version: int):
self._version = version
def get_version(self) -> int:
return self._version
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
@ -191,7 +198,7 @@ class BaseHFEngine(TrainEngine):
)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, shard_path)
dist.barrier()
dist.barrier(device_ids=[self.device.index])
def load_optimizer_state(self, path: str):
# Load FSDP sharded state dict
@ -203,7 +210,7 @@ class BaseHFEngine(TrainEngine):
)
optimizer_state_dict = torch.load(shard_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
dist.barrier()
dist.barrier(device_ids=[self.device.index])
def step_lr_scheduler(self):
assert self.lr_scheduler is not None

View File

@ -1,7 +1,7 @@
import os
import time
from datetime import datetime
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, List, Optional
import torch
import torch.distributed as dist
@ -14,7 +14,8 @@ from torch.distributed.checkpoint.state_dict import (
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.api.engine_api import FinetuneSpec
from arealite.api.io_struct import ParamSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.distributed import init_custom_process_group
from arealite.utils.fsdp import (
@ -119,7 +120,7 @@ class FSDPEngine(BaseHFEngine):
if tokenizer is not None:
tokenizer.save_pretrained(path)
dist.barrier()
dist.barrier(device_ids=[self.device.index])
def _load_model_from_hf(self, path: str):
"""Load model from HuggingFace format."""
@ -140,7 +141,7 @@ class FSDPEngine(BaseHFEngine):
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
self._update_weights_from_distributed()
dist.barrier()
dist.barrier(device_ids=[self.device.index])
torch.cuda.synchronize()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer)
@ -149,7 +150,7 @@ class FSDPEngine(BaseHFEngine):
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
self.model_version,
)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
@ -158,16 +159,18 @@ class FSDPEngine(BaseHFEngine):
raise ValueError(f"Unknown weight update type {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
# NOTE: Processes launched with torchrun will set the following env var to True,
# which blocks creating another TCP store for weight update.
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = str(False)
if dist.get_rank() == 0:
self.weight_update_group = init_custom_process_group(
backend="nccl",
world_size=meta.world_size,
init_method=f"tcp://{meta.master_address}:{meta.master_port}",
world_size=meta.alloc_mode.gen_world_size + 1,
init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}",
rank=0,
group_name=meta.group_name,
group_name=meta.nccl_group_name,
)
# NOTE: synchronizing with sglang's barrier
dist.barrier(group=self.weight_update_group, device_ids=[self.device.index])
# NOTE: sglang v0.4.9.post2 or later does not have the barrier call
self.weight_update_group_initialized = True
def _update_weights_from_distributed(self):
@ -179,23 +182,29 @@ class FSDPEngine(BaseHFEngine):
else:
tensor = param.data
if dist.get_rank() == 0:
print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
logger.debug(
f"Broadcasting {name} with shape {tensor.shape}"
)
dist.broadcast(tensor, src=0, group=self.weight_update_group)
dist.barrier()
del tensor # optional, for memory hygiene
torch.cuda.empty_cache()
def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
"""Return a dict mapping param name to its shape (expanded if DTensor)."""
param_shapes = {}
def get_param_specs(self) -> List[ParamSpec]:
param_specs = []
for name, param in self.model.named_parameters():
if isinstance(param.data, DTensor):
tensor = param.data.full_tensor()
else:
tensor = param.data
param_shapes[name] = tuple(tensor.shape)
param_specs.append(
ParamSpec(
name=name,
shape=tuple(tensor.shape),
dtype=str(tensor.dtype).split("torch.")[1],
)
)
del tensor # free memory if full_tensor was created
return param_shapes
return param_specs
def train_batch(
self,

View File

@ -1,10 +1,11 @@
import asyncio
import os
import random
import shutil
import threading
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures import Future, ProcessPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List
@ -27,17 +28,12 @@ from arealite.api.io_struct import (
)
from arealite.utils.data import concat_padded_tensors
from arealite.utils.http import arequest_with_retry, get_default_connector
from realhf.base import logging, name_resolve, names, pkg_version
from realhf.base import logging, name_resolve, names
if TYPE_CHECKING:
from arealite.api.workflow_api import RolloutWorkflow
logger = logging.getLogger(__name__)
if pkg_version.is_available("sglang"):
if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.05
RID_CACHE_SIZE = 128
@ -91,10 +87,7 @@ class RemoteSGLangEngine(InferenceEngine):
def check_health(self, base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
response = requests.get(f"{base_url}/health", timeout=30)
return response.status_code == 200
except requests.exceptions.RequestException as e:
return False
@ -123,7 +116,7 @@ class RemoteSGLangEngine(InferenceEngine):
"""Thread that runs the rollout loop."""
try:
uvloop.run(self._rollout_thread_async())
except Exception as e:
except Exception:
traceback.print_exc()
async def _rollout_thread_async(self):
@ -259,9 +252,7 @@ class RemoteSGLangEngine(InferenceEngine):
accumulated_output_logprobs = []
accumulated_versions = []
# Deal with rollout interruption
stop_reason = "length"
# A single "rid" shares the same server to allow KV cache reuse
if req.rid in self.rid_to_address:
server_addr = self.rid_to_address[req.rid]
else:
@ -273,10 +264,19 @@ class RemoteSGLangEngine(InferenceEngine):
self.rid_to_address[req.rid] = server_addr
self.rid_queue.append(req.rid)
# Deal with rollout interruption
# "abort" is the stop reason for later v0.4.9.post2 after
# we call the pause_generation endpoint
stop_reason = None
while (
stop_reason != "stop"
and len(accumulated_output_tokens) < gconfig.max_new_tokens
):
# Request is interrupted, wait for some time to avoid interfering
# with update weights requests
if stop_reason is not None:
await asyncio.sleep(0.5)
# loop until the generation is complete
result = await arequest_with_retry(
session=self.session,
@ -288,8 +288,17 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=self.config.request_timeout,
)
# Parse response
meta_info = result["meta_info"]
# Check if generation is complete
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
if (
stop_reason == "abort"
and finish_reason.get("message") == "Abort before prefill"
):
continue
# Parse response
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
@ -299,11 +308,7 @@ class RemoteSGLangEngine(InferenceEngine):
# FIXME: Update with actual server versions
accumulated_versions.extend([-1] * len(output_tokens))
# Check if generation is complete
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
payload["input_ids"] += result["output_ids"]
sample_params["max_new_tokens"] -= len(output_tokens)
latency = time.perf_counter() - start_time
@ -318,30 +323,24 @@ class RemoteSGLangEngine(InferenceEngine):
ttft=latency, # Simplified for non-streaming
)
def update_weights(self, meta):
executor = ThreadPoolExecutor(max_workers=1)
return executor.submit(self._update_weights, meta)
def _update_weights(self, meta: WeightUpdateMeta):
def update_weights(self, meta: WeightUpdateMeta):
for addr in self.addresses:
res = requests.post(f"http://{addr}/pause_generation")
res.raise_for_status()
fut = Future()
if meta.type == "nccl":
if not self.distributed_weight_update_initialized:
self._init_distributed_weight_update(meta)
tik = time.perf_counter()
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for param in meta.parameter_names:
jobs = [
self.aupdate_weights_from_distributed(addr, meta, param)
for addr in self.addresses
]
loop.run_until_complete(asyncio.gather(*jobs))
finally:
loop.close()
logger.info(
f"Distributed update weights done in {time.perf_counter() - tik}s"
fut = self.executor.submit(
update_weights_from_distributed,
meta,
self.addresses,
self.config.request_timeout,
not self.distributed_weight_update_initialized,
)
self.set_version(meta.model_version)
def callback(fut):
self.distributed_weight_update_initialized = True
fut.add_done_callback(callback)
elif meta.type == "disk":
# Update weights from disk
# Use ProcessPool to bypass python GIL for running async coroutines
@ -349,7 +348,7 @@ class RemoteSGLangEngine(InferenceEngine):
update_weights_from_disk,
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
self.get_version(),
self.addresses,
meta.path,
self.config.request_retries,
@ -357,64 +356,19 @@ class RemoteSGLangEngine(InferenceEngine):
)
def callback(fut):
self.set_version(meta.model_version)
shutil.rmtree(meta.path, ignore_errors=True)
fut.add_done_callback(callback)
return fut
else:
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
try:
# Initialize weights update group
jobs = [
self.ainit_weights_update_group(addr, meta) for addr in self.addresses
]
loop = asyncio.new_event_loop()
# asyncio event loop should be manually set when running asyncio stuff in another thread
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*jobs))
self.distributed_weight_update_initialized = True
logger.info(f"Distributed update weights initialized")
finally:
loop.close()
def callback(fut):
for addr in self.addresses:
res = requests.post(f"http://{addr}/continue_generation")
res.raise_for_status()
async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
payload = {
"master_address": meta.master_address,
"master_port": str(meta.master_port),
"rank_offset": rank_offset,
"world_size": meta.world_size,
"group_name": meta.group_name,
"backend": "nccl",
}
res = await arequest_with_retry(
addr=addr,
endpoint="/init_weights_update_group",
payload=payload,
method="POST",
max_retries=1,
timeout=self.config.request_timeout,
)
assert res["success"]
async def aupdate_weights_from_distributed(
self, addr: str, meta: WeightUpdateMeta, parameter_name: str
):
res = await arequest_with_retry(
addr=addr,
endpoint="/update_weights_from_distributed",
payload={
"name": parameter_name,
"dtype": "bfloat16",
"shape": meta.state_dict_key_to_shape[parameter_name],
},
method="POST",
max_retries=1,
timeout=self.config.request_timeout,
)
assert res["success"]
fut.add_done_callback(callback)
return fut
def get_capacity(self):
if dist.is_initialized():
@ -519,27 +473,6 @@ class RemoteSGLangEngine(InferenceEngine):
self.paused.clear()
async def aupdate_weights_from_disk(
session, addr, path: str, request_retries: int, request_timeout: float
):
tik = time.time()
res = await arequest_with_retry(
addr=addr,
session=session,
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
max_retries=request_retries,
timeout=request_timeout,
)
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
def update_weights_from_disk(
experiment_name,
trial_name,
@ -569,12 +502,14 @@ def update_weights_from_disk(
connector=get_default_connector(),
)
jobs = [
aupdate_weights_from_disk(
session=session,
arequest_with_retry(
addr=addr,
path=path,
request_retries=request_retries,
request_timeout=request_timeout,
session=session,
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path)),
method="POST",
max_retries=request_retries,
timeout=request_timeout,
)
for addr in addresses
]
@ -585,3 +520,72 @@ def update_weights_from_disk(
)
return uvloop.run(_fn())
def update_weights_from_distributed(
meta: WeightUpdateMeta,
addresses: List[str],
request_timeout,
init_group: bool,
):
async def _fn():
tik = time.perf_counter()
if init_group:
await asyncio.gather(
*[
ainit_weights_update_group(addr, i, meta, request_timeout)
for i, addr in enumerate(addresses)
]
)
await asyncio.gather(
*[
arequest_with_retry(
addr=addr,
endpoint="/update_weights_from_distributed",
payload={
"names": [pspec.name for pspec in meta.nccl_param_specs],
"dtypes": [pspec.dtype for pspec in meta.nccl_param_specs],
"shapes": [pspec.shape for pspec in meta.nccl_param_specs],
"group_name": meta.nccl_group_name,
},
method="POST",
max_retries=1,
timeout=request_timeout,
)
for addr in addresses
]
)
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
return uvloop.run(_fn())
async def ainit_weights_update_group(
addr: str,
server_idx: int,
meta: WeightUpdateMeta,
request_timeout: float,
):
assert meta.alloc_mode is not None
if meta.alloc_mode.gen_pp_size != 1:
raise NotImplementedError(
"NCCL weight update with PP size > 1 is not implemented yet."
)
rank_offset = 1 + server_idx * meta.alloc_mode.gen_tp_size
payload = {
"master_address": meta.nccl_master_address,
"master_port": str(meta.nccl_master_port),
"rank_offset": rank_offset,
"world_size": meta.alloc_mode.gen_world_size + 1,
"backend": "nccl",
"group_name": meta.nccl_group_name,
}
res = await arequest_with_retry(
addr=addr,
endpoint="/init_weights_update_group",
payload=payload,
method="POST",
max_retries=1,
timeout=request_timeout,
)
assert res["success"]

View File

@ -283,7 +283,7 @@ def main_local():
if not cfg.server_only:
launcher.submit(
job_name="trainer",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --standalone {' '.join(sys.argv[1:])}",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} --master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} {' '.join(sys.argv[1:])}",
gpu=alloc_mode.train_world_size,
env_vars=dict(AREAL_LLM_SERVER_ADDRS=",".join(server_addrs)),
)

View File

@ -3,10 +3,8 @@ import subprocess
import sys
import time
import uuid
from pathlib import Path
from typing import Optional
import ray
import requests
from arealite.api.cli_args import (
@ -19,12 +17,12 @@ from arealite.api.cli_args import (
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import TRITON_CACHE_PATH
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import logging, name_resolve, names, pkg_version
from realhf.base import logging, name_resolve, names
logger = logging.getLogger("SGLangServer Wrapper")
def execute_shell_command(command: str) -> subprocess.Popen:
def launch_server_cmd(command: str) -> subprocess.Popen:
"""
Execute a shell command and return its process handle.
"""
@ -45,46 +43,6 @@ def execute_shell_command(command: str) -> subprocess.Popen:
)
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
def launch_server_cmd(command: str):
"""
Launch the server using the given command.
If no port is specified, a free port is reserved.
"""
if not ray.is_initialized():
apply_sglang_patch()
process = execute_shell_command(command)
return process
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.

View File

@ -12,10 +12,9 @@ from arealite.api.cli_args import (
SGLangConfig,
TrainEngineConfig,
)
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.network import find_free_ports
from realhf.base import network
EXPR_NAME = "test_fsdp_engine_nccl"
@ -33,7 +32,7 @@ RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
try:
response = requests.get(f"{base_url}/metrics", timeout=30)
response = requests.get(f"{base_url}/health", timeout=30)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@ -82,14 +81,14 @@ def sglang_server_nccl():
def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
# 设置分布式环境变量
# Set environment variables for torch distributed
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = HOST
os.environ["MASTER_PORT"] = str(MASTER_PORT)
# 启动本地FSDPEngine
# Initialize FSDPEngine
engine_config = TrainEngineConfig(
experiment_name=EXPR_NAME,
trial_name=TRIAL_NAME,
@ -100,38 +99,26 @@ def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
# 启动远端RemoteSGLangEngine
# Initialize RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
remote_engine = RemoteSGLangEngine(config)
remote_engine.initialize(None, None)
# 构造WeightUpdateMetatype=nccl
param_meta = engine.get_param_meta_for_distributed_update()
meta = WeightUpdateMeta(
type="nccl",
path=None,
alloc_mode=None,
comm_backend="nccl",
model_version=123,
tp_size=1,
master_address="localhost",
master_port=find_free_ports(1)[0],
world_size=2,
group_name=GROUP_NAME,
parameter_names=list(param_meta.keys()),
state_dict_key_to_shape=param_meta,
# Get WeightUpdateMeta
meta = WeightUpdateMeta.from_fsdp_nccl(
AllocationMode.from_str("sglang.d1p1t1+d1p1t1"),
engine,
nccl_group_name=GROUP_NAME,
)
# 本地engine广播参数
# Broadcast weights
future = remote_engine.update_weights(meta)
print("got future", flush=True)
engine.upload_weights(meta)
print("uploaded wexights to remote engine", flush=True)
# 远端engine拉取参数
# Wait for remote engine to finish
future.result(timeout=120)
print("got result", flush=True)
# 检查远端参数版本
assert remote_engine.get_version() == 123
remote_engine.destroy()
engine.destroy()

View File

@ -2,7 +2,6 @@ import os
import subprocess
import sys
import time
import uuid
import pytest
import requests
@ -14,7 +13,7 @@ from arealite.api.cli_args import (
InferenceEngineConfig,
SGLangConfig,
)
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
from arealite.api.io_struct import WeightUpdateMeta
from arealite.utils import network
from realhf.api.core.data_api import load_hf_tokenizer
@ -31,10 +30,7 @@ RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
response = requests.get(f"{base_url}/health", timeout=30)
return response.status_code == 200
except requests.exceptions.RequestException as e:
return False
@ -77,29 +73,6 @@ def sglang_server():
process.terminate()
@pytest.mark.asyncio
async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
tokenizer = load_hf_tokenizer(MODEL_PATH)
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
input_ids=tokenizer.encode("hello! how are you today"),
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
assert isinstance(resp, LLMResponse)
assert resp.input_tokens == req.input_ids
assert (
len(resp.output_logprobs)
== len(resp.output_tokens)
== len(resp.output_versions)
)
@pytest.mark.parametrize("n_samples", [1, 2, 4])
def test_remote_sglang_rollout(sglang_server, n_samples):
from arealite.engine.sglang_remote import RemoteSGLangEngine
@ -211,6 +184,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
engine = FSDPEngine(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
engine.model_version = 100
# setup name resolve
import realhf.base.name_resolve as name_resolve
@ -227,13 +201,11 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
inf_engine = RemoteSGLangEngine(config)
inf_engine.initialize(None, None)
inf_engine.set_version(100)
# test update weights
path = tmp_path_factory.mktemp("upload_weights_from_disk")
update_weight_meta = WeightUpdateMeta(
type="disk", path=path, alloc_mode=None, comm_backend=None, model_version=100
)
update_weight_meta = WeightUpdateMeta(type="disk", path=str(path))
future = inf_engine.update_weights(update_weight_meta)
engine.upload_weights(update_weight_meta)
future.result()
assert inf_engine.get_version() == 100
inf_engine.destroy()

View File

@ -1,310 +0,0 @@
import asyncio
import os
import shutil
import sys
import uuid
import colorama
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import (
GenerationHyperparameters,
GRPOConfig,
load_expr_config,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, WeightUpdateMeta
from arealite.api.workflow_api import RolloutWorkflow
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.data import concat_padded_tensors
from arealite.utils.device import log_gpu_stats
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker
logger = logging.getLogger("boba math")
class RLVRWorkflow(RolloutWorkflow):
def __init__(
self,
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
dump_dir: str | None = None,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.dump_dir = dump_dir
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
os.makedirs(self.dump_dir, exist_ok=True)
async def arun_episode(self, engine, data):
input_ids = self.tokenizer.encode(data["prompt"])
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
input_ids=input_ids,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
version = engine.get_version()
prompt_strs = []
completions_strs = []
rewards = []
seqlens = []
results = []
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0.0] * resp.input_len + resp.output_logprobs
loss_mask = [0] * resp.input_len + [1] * resp.output_len
versions = [-1] * resp.input_len + resp.output_versions
prompt_str = data["prompt"]
completions_str = self.tokenizer.decode(resp.output_tokens)
prompt_strs.append(prompt_str)
completions_strs.append(completions_str)
seqlens.append(len(seq))
reward = self.reward_fn(
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
rewards.append(reward)
res = dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(seq).unsqueeze(0),
loss_mask=torch.tensor(loss_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
# reward
rewards=torch.tensor([float(reward)]),
)
results.append(TensorDict(res, batch_size=[1]))
if self.dump_dir is not None:
os.makedirs(os.path.join(self.dump_dir, str(version)), exist_ok=True)
# Get the unique identifier for this prompt
qid = None
for key in ["query_id", "id", "qid"]:
qid = data.get(key, None)
if qid is not None:
break
qid = qid or uuid.uuid4().hex
# Dump rollout to file
with open(
os.path.join(self.dump_dir, str(version), f"{qid}.txt"), "a"
) as f:
n_samples = self.gconfig.n_samples
for i, (p, c, r, sl) in enumerate(
zip(prompt_strs, completions_strs, rewards, seqlens)
):
info = "\n".join(
[
f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
]
)
f.write(info + "\n")
return concat_padded_tensors(results)
def get_boba_math_dataset(tokenizer, rank, world_size):
dataset = load_dataset(
path="json",
split="train",
data_files="/storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl",
)
dataset = dataset.filter(lambda x: len(tokenizer.encode(x["prompt"])) <= 1024)
return split_dataset_by_node(dataset, rank=rank, world_size=world_size)
def boba_reward_fn(
prompt, completions, prompt_ids, completion_ids, query_id, solutions, **kwargs
):
from pebble import ProcessExpired, ProcessPool
from realhf.impl.dataset.math_parser import process_results
jobs = []
with ProcessPool(max_workers=1) as executor:
for sol in solutions:
job = executor.schedule(
process_results, args=[completions, sol], timeout=15
)
jobs.append(job)
label = 0
for job in jobs:
try:
x = job.result()
except TimeoutError:
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
return label
def main(args):
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
seeding.set_random_seed(config.seed, key=f"trainer{rank}")
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
get_boba_math_dataset(tokenizer, rank, world_size),
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
collate_fn=lambda x: x,
drop_last=config.train_dataset.drop_last,
)
ft_spec = FinetuneSpec(
total_train_epochs=config.total_train_epochs,
dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
train_batch_size=config.train_dataset.batch_size,
)
# Initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, ft_spec)
# Initialize train engine
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None
if config.actor.kl_ctl > 0 and config.ref is not None:
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
workflow = RLVRWorkflow(
reward_fn=boba_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
dump_dir=os.path.join(
StatsLogger.get_log_path(config.stats_logger), "generated"
),
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
total_epochs = config.total_train_epochs
steps_per_epoch = len(train_dataloader)
max_steps = total_epochs * steps_per_epoch
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
data_generator = iter(train_dataloader)
for global_step in range(max_steps):
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = batch.to(actor.device)
# Create barrier to synchronize all rollout processes.
dist.barrier()
torch.cuda.synchronize()
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
with stats_tracker.record_timing("recompute_logp"):
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp
log_gpu_stats("recompute logp")
if ref is not None:
with stats_tracker.record_timing("ref_logp"):
batch["ref_logp"] = ref.compute_logp(batch)
log_gpu_stats("ref logp")
with stats_tracker.record_timing("compute_advantage"):
actor.compute_advantages(batch)
log_gpu_stats("compute advantages")
with (
stats_tracker.record_timing("train_step"),
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
path = os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step + 1),
)
meta = WeightUpdateMeta(
type="disk",
path=path,
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
)
if dist.get_rank() == 0:
future = rollout.update_weights(meta)
actor.upload_weights(meta)
if dist.get_rank() == 0:
future.result()
shutil.rmtree(path, ignore_errors=True)
dist.barrier()
torch.cuda.synchronize()
rollout.set_version(global_step + 1)
with stats_tracker.record_timing("save"):
saver.save(actor, epoch, step, global_step)
logger.commit(epoch, step, global_step, stats)
logger.close()
rollout.destroy()
if ref is not None:
ref.destroy()
actor.destroy()
if __name__ == "__main__":
main(sys.argv[1:])

View File

@ -1,141 +0,0 @@
experiment_name: lite-boba-math
trial_name: run1
cluster:
n_nodes: 16
n_gpus_per_node: 8
cluster_name: na132
fileroot: /storage/openpsi/experiments
name_resolve:
type: nfs
nfs_record_root: /storage/openpsi/experiments/name_resolve/lite-boba-math
etcd3_addr: etcd-client.openpsi-etcd.svc.sigma-na130-lingbo.na130.wl-robby.local:2379
seed: 1
total_train_epochs: 10
total_train_steps: null
tokenizer_path: ${actor.path}
allocation_mode: sglang.d96p1t1+d32p1t1
async_training: true
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: 400
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 4
enable_rollout_tracing: true
gconfig:
n_samples: 16
min_new_tokens: 0
max_new_tokens: 30720
greedy: false
temperature: 1.0
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/deepseek-ai__DeepSeek-R1-Distill-Qwen-1.5B/
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 32768
optimizer:
type: adam
lr: 1e-5
weight_decay: 0.01
beta1: 0.9
beta2: 0.999
eps: 1e-8
lr_scheduler_type: constant
gradient_clipping: 1.0
warmup_steps_proportion: 0.001
backend: fsdp
group_size: ${gconfig.n_samples}
group_adv_norm: false
eps_clip: 0.4
temperature: ${gconfig.temperature}
reward_scaling: 10.0
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 4
recompute_logprob: true
use_decoupled_loss: true
behav_imp_weight_cap: 5.0
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 32768
optimizer: null
backend: fsdp
# SGLang
server_only: false
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
# datasets
train_dataset:
batch_size: 512
shuffle: true
pin_memory: true
# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null
checkpointer:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: null
freq_steps: null
freq_secs: null
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: online
# Launcher
launcher:
inference_server_cpus_per_gpu: 15
inference_server_mem_per_gpu: 153600
trainer_cpus_per_gpu: 15
trainer_mem_per_gpu: 153600
slurm:
mount: /storage:/storage
trainer_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif
inference_server_image: /storage/openpsi/images/arealite-20250712-update-hf-xet.sif

View File

@ -1,9 +1,9 @@
experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
n_nodes: 1
n_gpus_per_node: 8
cluster:
n_nodes: 1
n_gpus_per_node: 8
fileroot: /tmp/arealite/experiments
name_resolve:
type: nfs
@ -32,7 +32,7 @@ gconfig:
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/
path: Qwen/Qwen2-1.5B-Instruct
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: false

View File

@ -13,7 +13,7 @@ tokenizer_path: ${model.path}
model:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
path: Qwen/Qwen3-1.7B
init_from_scratch: false
gradient_checkpointing: false
dtype: bfloat16

View File

@ -1,5 +1,4 @@
import os
import shutil
import sys
import torch
@ -9,7 +8,7 @@ from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import GRPOConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.api.io_struct import AllocationMode, FinetuneSpec, WeightUpdateMeta
from arealite.engine.ppo.actor import FSDPPPOActor
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.device import log_gpu_stats
@ -93,6 +92,18 @@ def main(args):
ref = FSDPPPOActor(config=config.ref)
ref.initialize(None, ft_spec)
# NOTE: Weight update meta only requires address and free port of rank 0,
# but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [
WeightUpdateMeta.from_fsdp_nccl(
AllocationMode.from_str(config.allocation_mode), actor
)
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
# Create rollout workflow
if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
@ -136,7 +147,7 @@ def main(args):
batch = batch.to(actor.device)
# Create barrier to synchronize all rollout processes.
dist.barrier()
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
@ -163,26 +174,16 @@ def main(args):
log_gpu_stats("ppo update")
with stats_tracker.record_timing("update_weights"):
path = os.path.join(
Saver.get_save_checkpoint_root(config.saver),
"update_weights",
str(global_step + 1),
)
meta = WeightUpdateMeta(
type="disk",
path=path,
alloc_mode=None,
comm_backend=None,
model_version=global_step + 1,
)
rollout.pause()
if dist.get_rank() == 0:
future = rollout.update_weights(meta)
actor.upload_weights(meta)
future = rollout.update_weights(weight_update_meta)
actor.upload_weights(weight_update_meta)
if dist.get_rank() == 0:
future.result()
shutil.rmtree(path, ignore_errors=True)
dist.barrier()
dist.barrier(device_ids=[actor.device.index])
torch.cuda.synchronize()
rollout.resume()
actor.set_version(global_step + 1)
rollout.set_version(global_step + 1)
with stats_tracker.record_timing("save"):

View File

@ -1,11 +0,0 @@
#!/bin/sh
AREAL_PATH=$PWD
cd /sglang
git apply $AREAL_PATH/patch/sglang/v0.4.6.post4.patch
cd $AREAL_PATH
# Package used for calculating math reward
pip install -e evaluation/latex2sympy
# Install AReaL
pip install -e .

View File

@ -1,9 +1,9 @@
#!/bin/bash
# basic dependencies
pip install -U pip
pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install nvidia-ml-py
pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install pynvml nvidia-ml-py
pip install -e evaluation/latex2sympy
pip install vllm==0.8.5 --no-build-isolation
pip install flash_attn --no-build-isolation
pip install "flash-attn<=2.7.3" --no-build-isolation
pip install -r evaluation/requirements.txt

View File

@ -1,24 +1,14 @@
#!/bin/bash
# basic dependencies
pip install -U pip
pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
pip install "sglang[all]==0.4.6.post4"
pip uninstall pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.2" pynvml
pip install "sglang[all]==0.4.9.post2"
pip install megatron-core==0.11.0 nvidia-ml-py
pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
pip install "flash-attn<=2.7.3" --no-build-isolation
# Package used for calculating math reward
pip install -e evaluation/latex2sympy
# Install an editable sglang
rm -rf ./sglang
git clone -b v0.4.6.post4 https://github.com/sgl-project/sglang
AREAL_PATH=$PWD
cd sglang
git apply ../patch/sglang/v0.4.6.post4.patch
pip install -e "python[all]" --no-deps
cd $AREAL_PATH
# Install AReaL
pip install -e .

View File

@ -67,6 +67,7 @@ class InstallationValidator:
def test_flash_attn_functionality(self, flash_attn_module):
"""Test flash attention functionality."""
# Try to import key functions
import flash_attn_2_cuda
from flash_attn import flash_attn_func, flash_attn_varlen_func
print(" - Flash attention functions imported successfully")
@ -79,12 +80,12 @@ class InstallationValidator:
"""Test SGLang basic functionality."""
# Basic import test is sufficient for CI
import sgl_kernel
from sglang import launch_server
assert Version(get_version("sglang")) == Version("0.4.6.post4")
from sglang import Engine, launch_server
assert Version(get_version("sglang")) == Version("0.4.9.post2"), "SGLang version should be v0.4.9.post2"
print(" - SGLang imported successfully")
def test_transformers(self, transformers_module):
assert Version(get_version("transformers")) == Version("4.51.1")
assert Version(get_version("transformers")) == Version("4.53.1"), "transformers version should be 4.53.1"
print(" - transformers imported successfully")
def validate_critical_dependencies(self):
@ -140,7 +141,7 @@ class InstallationValidator:
self.test_import("flashattn_hopper", required=False)
# Optional utilities
self.test_import("tensorboardx", required=False)
self.test_import("tensorboardX", required=False)
self.test_import("swanlab", required=False)
self.test_import("matplotlib", required=False)
self.test_import("seaborn", required=False)

View File

@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 174656b2..33fe0a5f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8891115c..843a8a82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -419,6 +421,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1938,6 +1941,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 82709b09..bfab3ce7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -265,6 +267,9 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -294,6 +299,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -767,6 +776,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Set a break point to wait for the interrupt to finish
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -776,7 +792,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -849,6 +870,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):

View File

@ -1,144 +0,0 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Set a break point to wait for the interrupt to finish
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):


@@ -31,10 +31,9 @@ dependencies = [
     "huggingface_hub",
     "datasets",
     "accelerate",
-    "transformers==4.51.1",
+    "transformers==4.53.0",
     # Scientific computing
-    "numpy<2.0.0",
     "scipy",
     "pandas",
     "matplotlib",


@@ -40,42 +40,11 @@ def execute_shell_command(command: str) -> subprocess.Popen:
     )
 
 
-def apply_sglang_patch():
-    p = Path(os.path.dirname(__file__))
-    patch_path = str(
-        p.parent.parent
-        / "patch"
-        / "sglang"
-        / f"v{pkg_version.get_version('sglang')}.patch"
-    )
-    target_path = ""
-    sglang_meta = subprocess.check_output(
-        "python3 -m pip show sglang", shell=True
-    ).decode("ascii")
-    for line in sglang_meta.split("\n"):
-        line = line.strip()
-        if line.startswith("Editable project location: "):
-            target_path = str(Path(line.split(": ")[1]).parent)
-    if target_path:
-        proc = subprocess.Popen(
-            ["git", "apply", patch_path],
-            cwd=target_path,
-            stderr=sys.stdout,
-            stdout=sys.stdout,
-        )
-        proc.wait()
-        logger.info(f"Applied SGLang patch at {target_path}")
-
-
 def launch_server_cmd(command: str, port: int = 30000):
     """
     Launch the server using the given command.
     If no port is specified, a free port is reserved.
     """
-    if not ray.is_initialized():
-        apply_sglang_patch()
     assert port is not None
     full_command = f"{command} --port {port}"
     process = execute_shell_command(full_command)


@@ -155,39 +155,22 @@ class GserverManager(Worker):
         return None
 
-    async def flush_requests_and_update_weights(
-        self, server_url, new_param_path, update_weights_retries=5
-    ):
-        server_index = self.server_urls.index(server_url)
-        success = False
-        for _ in range(update_weights_retries):
-            async with aiohttp.ClientSession(
-                server_url,
-                timeout=aiohttp.ClientTimeout(
-                    total=self.config.flush_request_timeout,
-                    sock_connect=self.config.flush_request_timeout,
-                ),
-            ) as session:
-                async with session.post(
-                    f"/update_weights_from_disk",
-                    json=dict(model_path=new_param_path, allow_interrupt=True),
-                ) as resp:
-                    if resp.status == 200:
-                        res = await resp.json()
-                        success = res["success"]
-                        if success:
-                            if "num_paused_requests" in res:
-                                logger.info(
-                                    f"{res['num_paused_requests']} requests are interrupted "
-                                    f"during updating weights for server {server_index}: {server_url}"
-                                )
-                            return
-                        logger.warning(
-                            f"Update weights failed: {res['message']}. Retrying."
-                        )
-                    logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
-            time.sleep(0.1)
-        raise RuntimeError("Update weights failed.")
+    async def flush_requests_and_update_weights(self, server_url, new_param_path):
+        async with aiohttp.ClientSession(
+            server_url,
+            timeout=aiohttp.ClientTimeout(
+                total=self.config.flush_request_timeout,
+                sock_connect=self.config.flush_request_timeout,
+            ),
+        ) as session:
+            (await session.post("/pause_generation")).raise_for_status()
+            async with session.post(
+                f"/update_weights_from_disk",
+                json=dict(model_path=new_param_path),
+            ) as resp:
+                resp.raise_for_status()
+                assert (await resp.json())["success"]
+            (await session.post("/continue_generation")).raise_for_status()
 
     def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
         if not hasattr(self, "round_robin_idx"):
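The rewritten `flush_requests_and_update_weights` above drops the interrupt-and-retry path in favor of SGLang's `/pause_generation` and `/continue_generation` endpoints. Below is a minimal standalone sketch of the same call sequence, useful for exercising a server by hand; `server_url` and `ckpt_path` are placeholders, and the try/finally is an addition of this sketch so a failed reload does not leave the server paused:

```python
import asyncio

import aiohttp


async def pause_update_resume(server_url: str, ckpt_path: str) -> None:
    # Same endpoint sequence as flush_requests_and_update_weights above:
    # pause generation, reload weights from disk, then resume generation.
    async with aiohttp.ClientSession(server_url) as session:
        (await session.post("/pause_generation")).raise_for_status()
        try:
            async with session.post(
                "/update_weights_from_disk", json=dict(model_path=ckpt_path)
            ) as resp:
                resp.raise_for_status()
                assert (await resp.json())["success"]
        finally:
            # Resume serving even if the update failed.
            (await session.post("/continue_generation")).raise_for_status()


# asyncio.run(pause_update_resume("http://localhost:30000", "/path/to/new_ckpt"))
```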


@@ -25,7 +25,6 @@ numba
 packaging
 pandas
 pybind11>=2.10.0
-numpy<2.0.0
 psutil
 pynvml
 pytest