added remote NCCL weight update

feat: implement update_weights_from_distributed in fsdp_engine.py

unfinished test, raise PR first

coroutine for each server

chore: change upload_weights behavior, change test order

fix small bug

fixed test
Zhuoran Y 2025-07-14 21:43:34 -07:00 committed by Zhuoran Y
parent 29e164a69d
commit 543a169116
6 changed files with 297 additions and 13 deletions

View File

@@ -1,12 +1,11 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import enum
import itertools
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Tuple
from transformers import PreTrainedTokenizerFast
@@ -163,6 +162,13 @@ class WeightUpdateMeta:
alloc_mode: AllocationMode | None
comm_backend: str | None
model_version: int = 0
tp_size: int = 1
master_address: str = "127.0.0.1"
master_port: int = 29500
world_size: int = 1
group_name: str = "aupdate_weights_from_distributed"
parameter_names: List[str] = field(default_factory=list)
state_dict_key_to_shape: Dict[str, Tuple[int, ...]] = field(default_factory=dict)
@dataclass

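For orientation, here is a minimal sketch of how the new NCCL-related fields are meant to be filled in. It assumes one broadcasting trainer rank (rank 0) plus tp_size ranks per SGLang server, and that param_meta comes from FSDPEngine.get_param_meta_for_distributed_update() added later in this commit; the helper name build_nccl_update_meta, the group name, and the host/port values are placeholders of ours, not part of the commit.

from typing import Dict, List, Tuple

from arealite.api.io_struct import WeightUpdateMeta


def build_nccl_update_meta(
    param_meta: Dict[str, Tuple[int, ...]],
    server_addrs: List[str],
    tp_size: int = 1,
    master_address: str = "127.0.0.1",
    master_port: int = 29500,
) -> WeightUpdateMeta:
    # One rank for the broadcasting trainer plus tp_size ranks per server,
    # matching rank_offset = 1 + server_index * tp_size on the receive side.
    world_size = 1 + len(server_addrs) * tp_size
    return WeightUpdateMeta(
        type="nccl",
        path=None,
        alloc_mode=None,
        comm_backend="nccl",
        tp_size=tp_size,
        master_address=master_address,
        master_port=master_port,
        world_size=world_size,
        group_name="update_weights_group",  # any name not used by another NCCL group
        parameter_names=list(param_meta.keys()),
        state_dict_key_to_shape=param_meta,
    )
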
View File

@@ -1,16 +1,28 @@
import dis
import gc
import os
import threading
import time
from datetime import datetime
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from sglang.srt.utils import init_custom_process_group
from tensordict import TensorDict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@@ -137,6 +149,10 @@ class FSDPEngine(BaseHFEngine):
if meta.type == "nccl":
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
print(
"Initialized distributed weight update group in training engine",
flush=True,
)
self._update_weights_from_distributed()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer)
@@ -154,14 +170,53 @@
raise ValueError(f"Unknown weight update type {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
raise NotImplementedError(
"Distributed weight update is not implemented for FSDPEngine yet. "
)
if dist.get_rank() == 0:
print(
f"[FSDP Engine]World size: {meta.world_size}, Master address: {meta.master_address}, Master port: {meta.master_port}",
flush=True,
)
self.weight_update_group = init_custom_process_group(
backend="nccl",
world_size=meta.world_size,
init_method=f"tcp://{meta.master_address}:{meta.master_port}",
rank=0,
group_name=meta.group_name,
)
self.weight_update_group_initialized = True
dist.barrier(group=self.weight_update_group, device_ids=[0])
def _update_weights_from_distributed(self):
raise NotImplementedError(
"Distributed weight update is not implemented for FSDPEngine yet. "
)
"""Broadcast parameters from rank 0 (FSDP2 compatible)."""
if dist.get_rank() == 0:
for name, param in self.model.named_parameters():
if isinstance(param.data, DTensor):
tensor = param.data.full_tensor()
else:
tensor = param.data
print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
dist.broadcast(tensor, src=0, group=self.weight_update_group)
del tensor # optional, for memory hygiene
torch.cuda.empty_cache()
# dist.barrier(group=self.weight_update_group)
def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int, ...]]:
"""Return a dict mapping param name to its shape (expanded if DTensor)."""
param_shapes = {}
for name, param in self.model.named_parameters():
try:
if isinstance(param.data, DTensor):
tensor = param.data.full_tensor()
else:
tensor = param.data
param_shapes[name] = tuple(tensor.shape)
del tensor # free memory if full_tensor was created
except Exception as e:
logger.warning(f"Failed to get shape for param {name}: {e}")
return param_shapes
def train_batch(
self,

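One implicit contract in the two loops above is worth spelling out: the NCCL broadcasts carry raw tensors only, no names, so the trainer's named_parameters() order must match the order in which the remote engine asks each server to receive (meta.parameter_names, filled from get_param_meta_for_distributed_update(), which walks the same iterator). A small sanity-check sketch; the helper name check_broadcast_order is ours, not part of the commit:

def check_broadcast_order(model, meta) -> None:
    # Trainer side: _update_weights_from_distributed() broadcasts tensors in
    # model.named_parameters() order.
    trainer_order = [name for name, _ in model.named_parameters()]
    # Inference side: update_all_params_for_addr() asks each SGLang server to
    # receive tensors in meta.parameter_names order.
    if trainer_order != list(meta.parameter_names):
        raise RuntimeError(
            "Parameter order mismatch between trainer broadcast and remote "
            "receive loop; the NCCL broadcasts would pair up with the wrong "
            "names and shapes."
        )
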
View File

@@ -6,10 +6,10 @@ from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import sglang as sgl
import torch.distributed as dist
from tensordict import TensorDict
import sglang as sgl
from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import (

View File

@@ -75,6 +75,7 @@ class RemoteSGLangEngine(InferenceEngine):
self.lock = threading.Lock()
self.rollout_stat = RolloutStat()
self.distributed_weight_update_initialized = False
self._version = 0
@@ -317,8 +318,32 @@
ttft=latency, # Simplified for non-streaming
)
def update_weights(self, meta: WeightUpdateMeta):
if meta.type == "disk":
def update_weights(self, meta):
executor = ThreadPoolExecutor(max_workers=1)
return executor.submit(self._update_weights, meta)
async def update_all_params_for_addr(self, addr, meta):
for param in meta.parameter_names:
await self.aupdate_weights_from_distributed(addr, meta, param)
def _update_weights(self, meta: WeightUpdateMeta):
if meta.type == "nccl":
if not self.distributed_weight_update_initialized:
self._init_distributed_weight_update(meta)
try:
jobs = [
self.update_all_params_for_addr(addr, meta)
for addr in self.addresses
]
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*jobs))
finally:
loop.close()
logger.info(f"Distributed update weights done")
self.set_version(meta.model_version)
elif meta.type == "disk":
# Update weights from disk
# Use ProcessPool to bypass python GIL for running async coroutines
fut = self.executor.submit(
@@ -340,6 +365,68 @@
else:
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
try:
# Initialize weights update group
jobs = [
self.ainit_weights_update_group(addr, meta) for addr in self.addresses
]
loop = asyncio.new_event_loop()
# The asyncio event loop must be set manually when running coroutines in a worker thread
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*jobs))
self.distributed_weight_update_initialized = True
logger.info(f"Distributed update weights initialized")
finally:
loop.close()
async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
payload = {
"master_address": meta.master_address,
"master_port": meta.master_port,
"rank_offset": rank_offset,
"world_size": meta.world_size,
"group_name": meta.group_name,
"backend": "nccl",
}
response = await self.arequest_with_retry(
endpoint="/init_weights_update_group",
payload=payload,
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
async def aupdate_weights_from_distributed(
self, addr: str, meta: WeightUpdateMeta, parameter_name: str
):
response = await self.arequest_with_retry(
endpoint="/update_weights_from_distributed",
payload={
"name": parameter_name,
"dtype": "bfloat16",
"shape": meta.state_dict_key_to_shape[parameter_name],
},
method="POST",
max_retries=3,
timeout=self.config.request_timeout,
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {addr}"
)
def get_capacity(self):
if dist.is_initialized():
world_size = dist.get_world_size()

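Because update_weights() hands the NCCL path to a ThreadPoolExecutor, the per-server coroutines run on a private event loop created inside the worker thread rather than on any pre-existing loop. A stripped-down sketch of that pattern, detached from the engine (the names _call_one_server and _fan_out are placeholders; note the loop is created before the try block so the finally clause always has a loop to close):

import asyncio
from concurrent.futures import ThreadPoolExecutor


async def _call_one_server(addr: str) -> str:
    # Stand-in for one per-server coroutine, e.g. ainit_weights_update_group
    # or the per-parameter loop in update_all_params_for_addr above.
    await asyncio.sleep(0)
    return addr


def _fan_out(addresses):
    # Each worker thread needs its own event loop; create it, set it as the
    # thread's current loop, run the gathered coroutines, then close it.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            asyncio.gather(*(_call_one_server(a) for a in addresses))
        )
    finally:
        loop.close()


# Submitting the blocking fan-out to a worker thread returns a Future,
# mirroring update_weights() returning executor.submit(self._update_weights, meta).
executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(_fan_out, ["server-0:30000", "server-1:30000"])
print(future.result())
executor.shutdown(wait=True)
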
View File

@@ -0,0 +1,135 @@
import os
import subprocess
import sys
import time
import uuid
import pytest
import requests
import torch
from arealite.api.cli_args import (
InferenceEngineConfig,
OptimizerConfig,
SGLangConfig,
TrainEngineConfig,
)
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network
EXPR_NAME = "test_fsdp_engine_nccl"
TRIAL_NAME = "trial_nccl"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13998
DIST_PORT = 15998
GROUP_NAME = "test_nccl_group21"
MASTER_PORT = DIST_PORT + 1
HOST = network.gethostip()
RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
try:
response = requests.get(f"{base_url}/metrics", timeout=30)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@pytest.fixture(scope="module")
def sglang_server_nccl():
from realhf.base import seeding
seeding.set_random_seed(1, EXPR_NAME)
cmd = SGLangConfig.build_cmd(
sglang_config=SGLangConfig(mem_fraction_static=0.2),
model_path=MODEL_PATH,
tp_size=1,
base_gpu_id=1,
dist_init_addr=f"{HOST}:{DIST_PORT}",
served_model_name=MODEL_PATH,
skip_tokenizer_init=False,
)
full_command = f"{cmd} --port {PORT}"
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
print(f"full_command to start sglang server: {full_command}", flush=True)
process = subprocess.Popen(
full_command.split(),
text=True,
stdout=sys.stdout,
stderr=sys.stdout,
)
base_url = f"http://{HOST}:{PORT}"
tik = time.time()
while time.time() - tik < RUN_SERVER_TIMEOUT:
if check_server_health(base_url):
break
time.sleep(1)
if time.time() - tik > RUN_SERVER_TIMEOUT:
process.terminate()
raise RuntimeError("server launch failed")
yield
process.terminate()
def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
# Set up distributed environment variables
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = HOST
os.environ["MASTER_PORT"] = str(MASTER_PORT)
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"
# Start the local FSDPEngine
engine_config = TrainEngineConfig(
experiment_name=EXPR_NAME,
trial_name=TRIAL_NAME,
path=MODEL_PATH,
optimizer=OptimizerConfig(),
)
engine = FSDPEngine(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
# Start the remote RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
remote_engine = RemoteSGLangEngine(config)
# Construct a WeightUpdateMeta with type="nccl"
param_meta = engine.get_param_meta_for_distributed_update()
meta = WeightUpdateMeta(
type="nccl",
path=None,
alloc_mode=None,
comm_backend="nccl",
model_version=123,
tp_size=1,
master_address=HOST,
master_port=MASTER_PORT,
world_size=2,
group_name="test_nccl_group12",
parameter_names=list(param_meta.keys()),
state_dict_key_to_shape=param_meta,
)
# The local engine broadcasts parameters
remote_engine.initialize(None, None)
future = remote_engine.update_weights(meta)
print("got future", flush=True)
engine.upload_weights(meta)
print("uploaded wexights to remote engine", flush=True)
# The remote engine receives the parameters
future.result(timeout=120)
print("got result", flush=True)
# Check the remote engine's weight version
assert remote_engine.get_version() == 123
remote_engine.destroy()
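
The call ordering in this test is load-bearing; a condensed recap of the handshake, reusing the engine, remote_engine, and meta constructed above:

# 1. Kick off the receive side first. update_weights() returns a Future while
#    the servers wait inside NCCL group init / broadcast for the trainer.
future = remote_engine.update_weights(meta)

# 2. Let the trainer join the group and broadcast every parameter. Doing this
#    before step 1 would hang here: the trainer would wait for server ranks
#    that have not yet been asked to join the group.
engine.upload_weights(meta)

# 3. Only now is it safe to block on the remote side finishing.
future.result(timeout=120)
assert remote_engine.get_version() == meta.model_version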

View File

@@ -79,6 +79,7 @@ class InstallationValidator:
"""Test SGLang basic functionality."""
# Basic import test is sufficient for CI
import sgl_kernel
from sglang import launch_server
assert Version(get_version("sglang")) == Version("0.4.6.post4")
print(" - SGLang imported successfully")