Implement FSDP distributed update (#183)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* added remote NCCL weight update

feat: implement update_weights_from_distributed in fsdp_engine.py

unfinished test, raised PR first

coroutine for each server

chore: change upload_weights behavior, change test order

fix small bug

fixed test

(An end-to-end usage sketch of this flow appears just before the file diffs below.)

* fix rebase

* add test.sh

* updated, test still fails

* .

* .

* .

* fix: full_tensor() should happen on all ranks (#187)

Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: 博惟 <bowei.fw@antgroup.com>
Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>
Co-authored-by: ChangyiYang <112288487+ChangyiYang@users.noreply.github.com>
Night 2025-07-20 21:42:00 -07:00 committed by GitHub
parent 29e164a69d
commit f68a4f677d
6 changed files with 357 additions and 14 deletions
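
Before the diffs, here is a rough, hedged sketch of how the pieces introduced in this commit fit together, mirroring the new test at the end of the diff; `push_weights_via_nccl`, its arguments, and the group name are illustrative names, not part of the committed API:

from arealite.api.io_struct import WeightUpdateMeta

def push_weights_via_nccl(engine, remote_engine, master_address, master_port,
                          n_servers, tp_size=1, version=1):
    # Hypothetical helper; mirrors the call order of the new test below.
    param_meta = engine.get_param_meta_for_distributed_update()
    meta = WeightUpdateMeta(
        type="nccl",
        path=None,
        alloc_mode=None,
        comm_backend="nccl",
        model_version=version,
        tp_size=tp_size,
        master_address=master_address,
        master_port=master_port,
        world_size=1 + n_servers * tp_size,  # trainer rank 0 + one rank per SGLang TP worker
        group_name="update_weights_group",   # placeholder; any name shared by both sides
        parameter_names=list(param_meta.keys()),
        state_dict_key_to_shape=param_meta,
    )
    # The client must be waiting first: each SGLang server joins the NCCL group and
    # blocks in its broadcast until the trainer sends the tensors.
    future = remote_engine.update_weights(meta)   # returns a concurrent.futures.Future
    engine.upload_weights(meta)                   # rank 0 broadcasts every parameter
    future.result(timeout=300)
    assert remote_engine.get_version() == version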

View File: arealite/api/io_struct.py

@@ -1,12 +1,11 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import enum
import itertools
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Tuple
from transformers import PreTrainedTokenizerFast

@@ -163,6 +162,13 @@ class WeightUpdateMeta:
    alloc_mode: AllocationMode | None
    comm_backend: str | None
    model_version: int = 0
    tp_size: int = 1
    master_address: str = "127.0.0.1"
    master_port: int = 29500
    world_size: int = 1
    group_name: str = "aupdate_weights_from_distributed"
    parameter_names: List[str] = field(default_factory=list)
    state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)


@dataclass
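
The new fields describe a single NCCL group spanning the trainer and all SGLang tensor-parallel workers. A small illustrative sketch of the implied rank layout; the helper name is hypothetical, but the offset formula matches `ainit_weights_update_group` later in this commit:

from typing import Dict, List

def nccl_update_group_layout(server_addrs: List[str], tp_size: int) -> Dict[str, object]:
    # Illustrative only: rank 0 is the trainer (FSDP rank 0); each SGLang server
    # occupies tp_size consecutive ranks starting at 1 + server_index * tp_size.
    layout = {"trainer": 0, "world_size": 1 + len(server_addrs) * tp_size}
    for i, addr in enumerate(server_addrs):
        rank_offset = 1 + i * tp_size
        layout[addr] = list(range(rank_offset, rank_offset + tp_size))
    return layout

# Example: two servers with tp_size=2 occupy ranks [1, 2] and [3, 4], world_size 5.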

View File: arealite/engine/base_hf_engine.py

@@ -66,12 +66,16 @@ class BaseHFEngine(TrainEngine):
        return self._parallelism_group

    def create_process_group(self):
        # Required by NCCL weight update group for SGLang
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["NCCL_NVLS_ENABLE"] = "0"
        if not dist.is_initialized():
            # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
            # NOTE: device_id **SHOULD NOT** be passed into init_process_group,
            # otherwise initializing the NCCL weight update group will be wrong!
            dist.init_process_group(
                backend="nccl",
                timeout=constants.NCCL_DEFAULT_TIMEOUT,
                device_id=torch.device(int(os.environ["LOCAL_RANK"])),
            )
            self.own_global_group = True
        self._parallelism_group = dist.new_group()

View File: arealite/engine/fsdp_engine.py

@@ -1,20 +1,32 @@
import dis
import gc
import os
import threading
import time
from datetime import datetime
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from tensordict import TensorDict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedTokenizerFast,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.distributed import init_custom_process_group
from arealite.utils.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,

@@ -138,6 +150,8 @@ class FSDPEngine(BaseHFEngine):
            if not self.weight_update_group_initialized:
                self._init_distributed_weight_update(meta)
            self._update_weights_from_distributed()
            dist.barrier()
            torch.cuda.synchronize()
        elif meta.type == "disk":
            self._save_model_to_hf(meta.path, self.tokenizer)
            # dist.barrier() is called when _save_model_to_hf finishes

@@ -154,14 +168,44 @@
            raise ValueError(f"Unknown weight update type {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
        if dist.get_rank() == 0:
            self.weight_update_group = init_custom_process_group(
                backend="nccl",
                world_size=meta.world_size,
                init_method=f"tcp://{meta.master_address}:{meta.master_port}",
                rank=0,
                group_name=meta.group_name,
            )
            # NOTE: synchronizing with sglang's barrier
            dist.barrier(group=self.weight_update_group, device_ids=[self.device.index])
        self.weight_update_group_initialized = True

    def _update_weights_from_distributed(self):
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
        """Broadcast parameters from rank 0 (FSDP2 compatible)."""
        for name, param in self.model.named_parameters():
            if isinstance(param.data, DTensor):
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
            if dist.get_rank() == 0:
                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
                dist.broadcast(tensor, src=0, group=self.weight_update_group)
            dist.barrier()
            del tensor  # optional, for memory hygiene
        torch.cuda.empty_cache()

    def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
        """Return a dict mapping param name to its shape (expanded if DTensor)."""
        param_shapes = {}
        for name, param in self.model.named_parameters():
            if isinstance(param.data, DTensor):
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
            param_shapes[name] = tuple(tensor.shape)
            del tensor  # free memory if full_tensor was created
        return param_shapes

    def train_batch(
        self,
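
Why the follow-up fix (#187) calls full_tensor() on every rank: for an FSDP2 DTensor, full_tensor() is a collective gather across the sharding mesh, so all trainer ranks must execute it even though only rank 0 belongs to the trainer-to-SGLang NCCL group and performs the broadcast. A minimal sketch of that pattern, not the engine code itself:

import torch.distributed as dist
from torch.distributed._tensor import DTensor

def gather_then_broadcast(param, weight_update_group):
    # Collective: every FSDP rank must reach this call, or the gather deadlocks.
    if isinstance(param.data, DTensor):
        tensor = param.data.full_tensor()
    else:
        tensor = param.data
    # Only rank 0 joined the trainer<->SGLang group, so only it broadcasts there.
    if dist.get_rank() == 0:
        dist.broadcast(tensor, src=0, group=weight_update_group)
    dist.barrier()  # keep trainer ranks in step between parameters
    del tensor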

View File: arealite/engine/sglang_remote.py

@@ -4,7 +4,7 @@ import random
import threading
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List

@@ -75,6 +75,7 @@ class RemoteSGLangEngine(InferenceEngine):
        self.lock = threading.Lock()
        self.rollout_stat = RolloutStat()
        self.distributed_weight_update_initialized = False
        self._version = 0

@@ -317,8 +318,31 @@ class RemoteSGLangEngine(InferenceEngine):
            ttft=latency,  # Simplified for non-streaming
        )

    def update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "disk":
    def update_weights(self, meta):
        executor = ThreadPoolExecutor(max_workers=1)
        return executor.submit(self._update_weights, meta)

    def _update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "nccl":
            if not self.distributed_weight_update_initialized:
                self._init_distributed_weight_update(meta)
            tik = time.perf_counter()
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                for param in meta.parameter_names:
                    jobs = [
                        self.aupdate_weights_from_distributed(addr, meta, param)
                        for addr in self.addresses
                    ]
                    loop.run_until_complete(asyncio.gather(*jobs))
            finally:
                loop.close()
            logger.info(
                f"Distributed update weights done in {time.perf_counter() - tik}s"
            )
            self.set_version(meta.model_version)
        elif meta.type == "disk":
            # Update weights from disk
            # Use ProcessPool to bypass python GIL for running async coroutines
            fut = self.executor.submit(

@@ -340,6 +364,58 @@ class RemoteSGLangEngine(InferenceEngine):
        else:
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        try:
            # Initialize weights update group
            jobs = [
                self.ainit_weights_update_group(addr, meta) for addr in self.addresses
            ]
            loop = asyncio.new_event_loop()
            # asyncio event loop should be manually set when running asyncio stuff in another thread
            asyncio.set_event_loop(loop)
            loop.run_until_complete(asyncio.gather(*jobs))
            self.distributed_weight_update_initialized = True
            logger.info(f"Distributed update weights initialized")
        finally:
            loop.close()

    async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
        rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
        payload = {
            "master_address": meta.master_address,
            "master_port": str(meta.master_port),
            "rank_offset": rank_offset,
            "world_size": meta.world_size,
            "group_name": meta.group_name,
            "backend": "nccl",
        }
        res = await arequest_with_retry(
            addr=addr,
            endpoint="/init_weights_update_group",
            payload=payload,
            method="POST",
            max_retries=1,
            timeout=self.config.request_timeout,
        )
        assert res["success"]

    async def aupdate_weights_from_distributed(
        self, addr: str, meta: WeightUpdateMeta, parameter_name: str
    ):
        res = await arequest_with_retry(
            addr=addr,
            endpoint="/update_weights_from_distributed",
            payload={
                "name": parameter_name,
                "dtype": "bfloat16",
                "shape": meta.state_dict_key_to_shape[parameter_name],
            },
            method="POST",
            max_retries=1,
            timeout=self.config.request_timeout,
        )
        assert res["success"]

    def get_capacity(self):
        if dist.is_initialized():
            world_size = dist.get_world_size()

View File (new test: FSDP-to-SGLang NCCL weight update)

@@ -0,0 +1,140 @@
import os
import subprocess
import sys
import time
import uuid

import pytest
import requests
import torch

from arealite.api.cli_args import (
    InferenceEngineConfig,
    OptimizerConfig,
    SGLangConfig,
    TrainEngineConfig,
)
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.network import find_free_ports
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network

EXPR_NAME = "test_fsdp_engine_nccl"
TRIAL_NAME = "trial_nccl"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13998
DIST_PORT = 15998
GROUP_NAME = "test_nccl_group"
MASTER_PORT = DIST_PORT + 1
HOST = network.gethostip()
RUN_SERVER_TIMEOUT = 180


def check_server_health(base_url):
    try:
        response = requests.get(f"{base_url}/metrics", timeout=30)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


@pytest.fixture(scope="module")
def sglang_server_nccl():
    from realhf.base import seeding

    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(
            mem_fraction_static=0.2,
            model_path=MODEL_PATH,
            skip_tokenizer_init=False,
            log_level="info",
        ),
        tp_size=1,
        base_gpu_id=1,
        host=HOST,
        port=PORT,
        dist_init_addr=f"{HOST}:{DIST_PORT}",
    )
    full_command = f"{cmd} --port {PORT}"
    full_command = full_command.replace("\\\n", " ").replace("\\", " ")
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    print(f"full_command to start sglang server: {full_command}", flush=True)
    process = subprocess.Popen(
        full_command.split(),
        text=True,
        stdout=sys.stdout,
        stderr=sys.stdout,
    )
    base_url = f"http://{HOST}:{PORT}"
    tik = time.time()
    while time.time() - tik < RUN_SERVER_TIMEOUT:
        if check_server_health(base_url):
            break
        time.sleep(1)
    if time.time() - tik > RUN_SERVER_TIMEOUT:
        process.terminate()
        raise RuntimeError("server launch failed")
    yield
    process.terminate()


def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
    # Set up distributed environment variables
    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = HOST
    os.environ["MASTER_PORT"] = str(MASTER_PORT)
    # Start the local FSDPEngine
    engine_config = TrainEngineConfig(
        experiment_name=EXPR_NAME,
        trial_name=TRIAL_NAME,
        path=MODEL_PATH,
        optimizer=OptimizerConfig(),
    )
    engine = FSDPEngine(engine_config)
    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
    engine.initialize(None, ft_spec)
    # Start the remote RemoteSGLangEngine client
    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    remote_engine = RemoteSGLangEngine(config)
    remote_engine.initialize(None, None)
    # Build a WeightUpdateMeta with type="nccl"
    param_meta = engine.get_param_meta_for_distributed_update()
    meta = WeightUpdateMeta(
        type="nccl",
        path=None,
        alloc_mode=None,
        comm_backend="nccl",
        model_version=123,
        tp_size=1,
        master_address="localhost",
        master_port=find_free_ports(1)[0],
        world_size=2,
        group_name=GROUP_NAME,
        parameter_names=list(param_meta.keys()),
        state_dict_key_to_shape=param_meta,
    )
    # The local engine broadcasts the parameters
    future = remote_engine.update_weights(meta)
    print("got future", flush=True)
    engine.upload_weights(meta)
    print("uploaded weights to remote engine", flush=True)
    # The remote engine pulls the parameters
    future.result(timeout=120)
    print("got result", flush=True)
    # Check the weight version reported by the remote engine
    assert remote_engine.get_version() == 123
    remote_engine.destroy()
    engine.destroy()

View File: arealite/utils/distributed.py

@@ -0,0 +1,73 @@
import torch


# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # We need to determine the appropriate parameter name based on PyTorch version
    pg_options_param_name = (
        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg