mirror of https://github.com/inclusionAI/AReaL
added remote nccl weight update
feat: implement update_weights_from_distributed in fsdp_engine.py (test unfinished, raising the PR first); use one coroutine per server; chore: change upload-weights behavior and test order; fix small bug; fix test
This commit is contained in:
parent 29e164a69d
commit 543a169116
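A minimal sketch of the call sequence this commit adds, assuming an already-initialized FSDPEngine (`engine`), a RemoteSGLangEngine (`remote_engine`), and an NCCL-type WeightUpdateMeta (`meta`) built as in the new test at the bottom of this diff; all names below come from the diff itself:

# Assumed objects: engine (FSDPEngine), remote_engine (RemoteSGLangEngine),
# meta (WeightUpdateMeta with type="nccl"); see the new test below for how they are built.
future = remote_engine.update_weights(meta)  # non-blocking: servers join the group and wait for broadcasts
engine.upload_weights(meta)                  # trainer rank 0 broadcasts every parameter over the NCCL group
future.result(timeout=120)                   # resolves once every server has received every parameter
assert remote_engine.get_version() == meta.model_version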
@@ -1,12 +1,11 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

import enum
import itertools
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Tuple

from transformers import PreTrainedTokenizerFast

@@ -163,6 +162,13 @@ class WeightUpdateMeta:
    alloc_mode: AllocationMode | None
    comm_backend: str | None
    model_version: int = 0
    tp_size: int = 1
    master_address: str = "127.0.0.1"
    master_port: int = 29500
    world_size: int = 1
    group_name: str = "aupdate_weights_from_distributed"
    parameter_names: List[str] = field(default_factory=list)
    state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)


@dataclass
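For reference, a hedged example of filling the new fields for an NCCL update, mirroring the test at the end of this diff; the parameter name, shape, address, and port below are placeholders rather than values from the commit:

from arealite.api.io_struct import WeightUpdateMeta

meta = WeightUpdateMeta(
    type="nccl",
    path=None,                   # only used by disk-based updates
    alloc_mode=None,
    comm_backend="nccl",
    model_version=1,             # version the inference engine reports after the update
    tp_size=1,                   # tensor-parallel degree of each SGLang server
    master_address="127.0.0.1",  # rendezvous address of the dedicated weight-update group
    master_port=29500,
    world_size=2,                # 1 trainer rank + num_servers * tp_size server ranks
    group_name="update_weights_from_distributed",  # arbitrary label shared by both sides
    parameter_names=["model.embed_tokens.weight"],                          # placeholder
    state_dict_key_to_shape={"model.embed_tokens.weight": (151936, 2048)},  # placeholder
)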
@@ -1,16 +1,28 @@
import dis
import gc
import os
import threading
import time
from datetime import datetime
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from sglang.srt.utils import init_custom_process_group
from tensordict import TensorDict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedTokenizerFast,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@@ -137,6 +149,10 @@ class FSDPEngine(BaseHFEngine):
        if meta.type == "nccl":
            if not self.weight_update_group_initialized:
                self._init_distributed_weight_update(meta)
                print(
                    "Initialized distributed weight update group in training engine",
                    flush=True,
                )
            self._update_weights_from_distributed()
        elif meta.type == "disk":
            self._save_model_to_hf(meta.path, self.tokenizer)
@@ -154,14 +170,53 @@ class FSDPEngine(BaseHFEngine):
            raise ValueError(f"Unknown weight update type {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        # NOTE: intentionally disabled in this PR; the code below is the planned
        # implementation and stays unreachable until this raise is removed.
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
        if dist.get_rank() == 0:
            print(
                f"[FSDP Engine] World size: {meta.world_size}, Master address: {meta.master_address}, Master port: {meta.master_port}",
                flush=True,
            )
            self.weight_update_group = init_custom_process_group(
                backend="nccl",
                world_size=meta.world_size,
                init_method=f"tcp://{meta.master_address}:{meta.master_port}",
                rank=0,
                group_name=meta.group_name,
            )
            self.weight_update_group_initialized = True
            dist.barrier(group=self.weight_update_group, device_ids=[0])

    def _update_weights_from_distributed(self):
        """Broadcast parameters from rank 0 (FSDP2 compatible)."""
        # NOTE: intentionally disabled in this PR; the code below is the planned
        # implementation and stays unreachable until this raise is removed.
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
        if dist.get_rank() == 0:
            for name, param in self.model.named_parameters():
                if isinstance(param.data, DTensor):
                    tensor = param.data.full_tensor()
                else:
                    tensor = param.data
                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
                dist.broadcast(tensor, src=0, group=self.weight_update_group)
                del tensor  # optional, for memory hygiene
            torch.cuda.empty_cache()

        # dist.barrier(group=self.weight_update_group)

    def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
        """Return a dict mapping param name to its shape (expanded if DTensor)."""
        param_shapes = {}

        for name, param in self.model.named_parameters():
            try:
                if isinstance(param.data, DTensor):
                    tensor = param.data.full_tensor()
                else:
                    tensor = param.data

                param_shapes[name] = tuple(tensor.shape)
                del tensor  # free memory if full_tensor was created
            except Exception as e:
                logger.warning(f"Failed to get shape for param {name}: {e}")

        return param_shapes

    def train_batch(
        self,
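The broadcast above covers only the trainer side; each SGLang worker is expected to post a matching receive when its /update_weights_from_distributed endpoint is called. A hypothetical sketch of that receiver side (not part of this commit, and not the actual SGLang implementation) for one announced parameter:

import torch
import torch.distributed as dist


def receive_param(name: str, dtype: str, shape, group, device: str = "cuda"):
    # Allocate a buffer matching the dtype/shape announced in the request payload ...
    buf = torch.empty(tuple(shape), dtype=getattr(torch, dtype), device=device)
    # ... and join the broadcast rooted at the training engine (rank 0 of the group).
    dist.broadcast(buf, src=0, group=group)
    return name, buf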
@@ -6,10 +6,10 @@ from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

import sglang as sgl
import torch.distributed as dist
from tensordict import TensorDict

import sglang as sgl
from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import (
@@ -75,6 +75,7 @@ class RemoteSGLangEngine(InferenceEngine):
        self.lock = threading.Lock()

        self.rollout_stat = RolloutStat()
        self.distributed_weight_update_initialized = False

        self._version = 0
@@ -317,8 +318,32 @@
            ttft=latency,  # Simplified for non-streaming
        )

    def update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "disk":
    def update_weights(self, meta):
        executor = ThreadPoolExecutor(max_workers=1)
        return executor.submit(self._update_weights, meta)

    async def update_all_params_for_addr(self, addr, meta):
        for param in meta.parameter_names:
            await self.aupdate_weights_from_distributed(addr, meta, param)

    def _update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "nccl":
            if not self.distributed_weight_update_initialized:
                self._init_distributed_weight_update(meta)
            loop = asyncio.new_event_loop()
            try:
                jobs = [
                    self.update_all_params_for_addr(addr, meta)
                    for addr in self.addresses
                ]
                asyncio.set_event_loop(loop)
                loop.run_until_complete(asyncio.gather(*jobs))
            finally:
                loop.close()
            logger.info("Distributed update weights done")
            self.set_version(meta.model_version)

        elif meta.type == "disk":
            # Update weights from disk
            # Use ProcessPool to bypass python GIL for running async coroutines
            fut = self.executor.submit(
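A self-contained illustration of the concurrency pattern used above: update_weights() hands the work to a single worker thread and immediately returns a Future, and the worker drives one coroutine per server on a private event loop (threads other than the main one have no default loop, hence new_event_loop plus set_event_loop). The server addresses here are placeholders:

import asyncio
from concurrent.futures import ThreadPoolExecutor


async def _update_one_server(addr: str) -> str:
    await asyncio.sleep(0.01)  # stands in for the per-parameter HTTP requests
    return addr


def _update_all(addresses):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        jobs = [_update_one_server(a) for a in addresses]
        return loop.run_until_complete(asyncio.gather(*jobs))
    finally:
        loop.close()


executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(_update_all, ["server-0:13998", "server-1:13998"])
print(future.result())  # the caller can do other work before blocking here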
@@ -340,6 +365,68 @@
        else:
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        loop = asyncio.new_event_loop()
        try:
            # Initialize weights update group
            jobs = [
                self.ainit_weights_update_group(addr, meta) for addr in self.addresses
            ]
            # asyncio event loop should be manually set when running asyncio stuff in another thread
            asyncio.set_event_loop(loop)
            loop.run_until_complete(asyncio.gather(*jobs))
            self.distributed_weight_update_initialized = True
            logger.info("Distributed update weights initialized")
        finally:
            loop.close()

    async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
        rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
        payload = {
            "master_address": meta.master_address,
            "master_port": meta.master_port,
            "rank_offset": rank_offset,
            "world_size": meta.world_size,
            "group_name": meta.group_name,
            "backend": "nccl",
        }
        response = await self.arequest_with_retry(
            endpoint="/init_weights_update_group",
            payload=payload,
            method="POST",
            max_retries=3,
            timeout=self.config.request_timeout,
        )
        res = await response.json()
        assert res["success"]
        if "num_paused_requests" in res:
            logger.info(
                f"{res['num_paused_requests']} requests are interrupted "
                f"during updating weights for server {addr}"
            )

    async def aupdate_weights_from_distributed(
        self, addr: str, meta: WeightUpdateMeta, parameter_name: str
    ):
        response = await self.arequest_with_retry(
            endpoint="/update_weights_from_distributed",
            payload={
                "name": parameter_name,
                "dtype": "bfloat16",
                "shape": meta.state_dict_key_to_shape[parameter_name],
            },
            method="POST",
            max_retries=3,
            timeout=self.config.request_timeout,
        )
        res = await response.json()
        assert res["success"]
        if "num_paused_requests" in res:
            logger.info(
                f"{res['num_paused_requests']} requests are interrupted "
                f"during updating weights for server {addr}"
            )

    def get_capacity(self):
        if dist.is_initialized():
            world_size = dist.get_world_size()
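A small sketch of the rank layout implied by rank_offset = 1 + self.addresses.index(addr) * meta.tp_size above: the trainer owns rank 0 and the workers of server i occupy the next tp_size ranks, so world_size must be 1 + len(addresses) * tp_size (the new test uses one server with tp_size=1, hence world_size=2). Addresses and tp_size here are placeholders:

addresses = ["10.0.0.1:13998", "10.0.0.2:13998"]  # placeholder server addresses
tp_size = 2

world_size = 1 + len(addresses) * tp_size
rank_offsets = {addr: 1 + i * tp_size for i, addr in enumerate(addresses)}
print(world_size)    # 5
print(rank_offsets)  # {'10.0.0.1:13998': 1, '10.0.0.2:13998': 3}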
@@ -0,0 +1,135 @@
import os
import subprocess
import sys
import time
import uuid

import pytest
import requests
import torch

from arealite.api.cli_args import (
    InferenceEngineConfig,
    OptimizerConfig,
    SGLangConfig,
    TrainEngineConfig,
)
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network

EXPR_NAME = "test_fsdp_engine_nccl"
TRIAL_NAME = "trial_nccl"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13998
DIST_PORT = 15998
GROUP_NAME = "test_nccl_group21"
MASTER_PORT = DIST_PORT + 1
HOST = network.gethostip()
RUN_SERVER_TIMEOUT = 180


def check_server_health(base_url):
    try:
        response = requests.get(f"{base_url}/metrics", timeout=30)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


@pytest.fixture(scope="module")
def sglang_server_nccl():
    from realhf.base import seeding

    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(mem_fraction_static=0.2),
        model_path=MODEL_PATH,
        tp_size=1,
        base_gpu_id=1,
        dist_init_addr=f"{HOST}:{DIST_PORT}",
        served_model_name=MODEL_PATH,
        skip_tokenizer_init=False,
    )
    full_command = f"{cmd} --port {PORT}"
    full_command = full_command.replace("\\\n", " ").replace("\\", " ")

    print(f"full_command to start sglang server: {full_command}", flush=True)
    process = subprocess.Popen(
        full_command.split(),
        text=True,
        stdout=sys.stdout,
        stderr=sys.stdout,
    )
    base_url = f"http://{HOST}:{PORT}"
    tik = time.time()
    while time.time() - tik < RUN_SERVER_TIMEOUT:
        if check_server_health(base_url):
            break
        time.sleep(1)
    if time.time() - tik > RUN_SERVER_TIMEOUT:
        process.terminate()
        raise RuntimeError("server launch failed")
    yield
    process.terminate()


def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
    # Set up the distributed environment variables
    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = HOST
    os.environ["MASTER_PORT"] = str(MASTER_PORT)
    os.environ["NCCL_P2P_DISABLE"] = "1"
    os.environ["NCCL_IB_DISABLE"] = "1"

    # Start the local FSDPEngine
    engine_config = TrainEngineConfig(
        experiment_name=EXPR_NAME,
        trial_name=TRIAL_NAME,
        path=MODEL_PATH,
        optimizer=OptimizerConfig(),
    )
    engine = FSDPEngine(engine_config)
    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
    engine.initialize(None, ft_spec)

    # Start the remote RemoteSGLangEngine
    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    remote_engine = RemoteSGLangEngine(config)

    # Build a WeightUpdateMeta (type=nccl)
    param_meta = engine.get_param_meta_for_distributed_update()
    meta = WeightUpdateMeta(
        type="nccl",
        path=None,
        alloc_mode=None,
        comm_backend="nccl",
        model_version=123,
        tp_size=1,
        master_address=HOST,
        master_port=MASTER_PORT,
        world_size=2,
        group_name="test_nccl_group12",
        parameter_names=list(param_meta.keys()),
        state_dict_key_to_shape=param_meta,
    )
    # The local engine broadcasts its parameters
    remote_engine.initialize(None, None)

    future = remote_engine.update_weights(meta)
    print("got future", flush=True)
    engine.upload_weights(meta)
    print("uploaded weights to remote engine", flush=True)
    # The remote engine pulls the parameters
    future.result(timeout=120)
    print("got result", flush=True)
    # Check the remote parameter version
    assert remote_engine.get_version() == 123
    remote_engine.destroy()
@@ -79,6 +79,7 @@ class InstallationValidator:
        """Test SGLang basic functionality."""
        # Basic import test is sufficient for CI
        import sgl_kernel

        from sglang import launch_server
        assert Version(get_version("sglang")) == Version("0.4.6.post4")
        print(" - SGLang imported successfully")