mirror of https://github.com/inclusionAI/AReaL
Implement fsdp distributed update (#183)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
  Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* add gradient checkpointing
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* .
* .
* .
* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* .
* .
* .
* .
* .
* fix
* .
* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
  Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* .
* .
* .
* .
* fix destroy process group
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
  Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* .
* .
* .
* fix loss mask
* fix
* .
* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub
  Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/368
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* .
* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
  Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/371
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* added remote nccl weight update
  feat: implement update_weights_from_distributed in fsdp_engine.py
  unfinishd test, raise PR first
  coroutine for each server
  chore: change uploads weights behavior, change test order
  fix small bug
  fixed test
* fix rebase
* add test.sh
* updated, test stil fails
* .
* .
* .
* fix: full_tensor() should happen in all rank (#187)
  Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>
---------
Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: 博惟 <bowei.fw@antgroup.com>
Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>
Co-authored-by: ChangyiYang <112288487+ChangyiYang@users.noreply.github.com>
This commit is contained in:
parent 29e164a69d
commit f68a4f677d
@@ -1,12 +1,11 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

import enum
import itertools
import re
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Tuple

from transformers import PreTrainedTokenizerFast
@@ -163,6 +162,13 @@ class WeightUpdateMeta:
    alloc_mode: AllocationMode | None
    comm_backend: str | None
    model_version: int = 0
    tp_size: int = 1
    master_address: str = "127.0.0.1"
    master_port: int = 29500
    world_size: int = 1
    group_name: str = "aupdate_weights_from_distributed"
    parameter_names: List[str] = field(default_factory=list)
    state_dict_key_to_shape: Dict[str, Tuple[int]] = field(default_factory=dict)


@dataclass
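The new fields describe the NCCL rendezvous that the trainer and the SGLang servers must agree on. As a minimal sketch of how a caller might fill them in (values mirror the test added later in this commit; the helper name and group name are illustrative, not part of the diff):

from arealite.api.io_struct import WeightUpdateMeta
from arealite.utils.network import find_free_ports

def make_nccl_update_meta(engine, model_version: int) -> WeightUpdateMeta:
    # engine is assumed to be an initialized FSDPEngine.
    param_meta = engine.get_param_meta_for_distributed_update()
    return WeightUpdateMeta(
        type="nccl",
        path=None,
        alloc_mode=None,
        comm_backend="nccl",
        model_version=model_version,
        tp_size=1,
        master_address="localhost",
        master_port=find_free_ports(1)[0],
        world_size=2,  # 1 trainer rank + 1 server * tp_size
        group_name="update_weights_group",  # arbitrary, but must match on both sides
        parameter_names=list(param_meta.keys()),
        state_dict_key_to_shape=param_meta,
    )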
@@ -66,12 +66,16 @@ class BaseHFEngine(TrainEngine):
        return self._parallelism_group

    def create_process_group(self):
        # Required by NCCL weight update group for SGLang
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["NCCL_NVLS_ENABLE"] = "0"
        if not dist.is_initialized():
            # TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
            # NOTE: device_id **SHOULD NOT** be passed into init_process_group,
            # otherwise initializing the NCCL weight update group will be wrong!
            dist.init_process_group(
                backend="nccl",
                timeout=constants.NCCL_DEFAULT_TIMEOUT,
                device_id=torch.device(int(os.environ["LOCAL_RANK"])),
            )
            self.own_global_group = True
        self._parallelism_group = dist.new_group()
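create_process_group reads the usual torch.distributed variables from the environment (see the TODO above). A minimal single-process sketch of what a launcher is assumed to set, matching what the new test does by hand:

import os

def set_single_process_env(master_addr: str = "127.0.0.1", master_port: int = 29500) -> None:
    # Assumed single-GPU, single-process setup; a real launcher (e.g. torchrun) sets these instead.
    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

With these set, an engine's create_process_group() can initialize NCCL and create the parallelism group.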
@@ -1,20 +1,32 @@
import dis
import gc
import os
import threading
import time
from datetime import datetime
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from tensordict import TensorDict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedTokenizerFast,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.distributed import init_custom_process_group
from arealite.utils.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
@@ -138,6 +150,8 @@ class FSDPEngine(BaseHFEngine):
            if not self.weight_update_group_initialized:
                self._init_distributed_weight_update(meta)
            self._update_weights_from_distributed()
            dist.barrier()
            torch.cuda.synchronize()
        elif meta.type == "disk":
            self._save_model_to_hf(meta.path, self.tokenizer)
            # dist.barrier() is called when _save_model_to_hf finishes
@@ -154,14 +168,44 @@
            raise ValueError(f"Unknown weight update type {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
        if dist.get_rank() == 0:
            self.weight_update_group = init_custom_process_group(
                backend="nccl",
                world_size=meta.world_size,
                init_method=f"tcp://{meta.master_address}:{meta.master_port}",
                rank=0,
                group_name=meta.group_name,
            )
            # NOTE: synchronizing with sglang's barrier
            dist.barrier(group=self.weight_update_group, device_ids=[self.device.index])
        self.weight_update_group_initialized = True

    def _update_weights_from_distributed(self):
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
        """Broadcast parameters from rank 0 (FSDP2 compatible)."""

        for name, param in self.model.named_parameters():
            if isinstance(param.data, DTensor):
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
            if dist.get_rank() == 0:
                print(f"Broadcasting {name} with shape {tensor.shape}", flush=True)
                dist.broadcast(tensor, src=0, group=self.weight_update_group)
            dist.barrier()
            del tensor  # optional, for memory hygiene
        torch.cuda.empty_cache()

    def get_param_meta_for_distributed_update(self) -> Dict[str, Tuple[int]]:
        """Return a dict mapping param name to its shape (expanded if DTensor)."""
        param_shapes = {}
        for name, param in self.model.named_parameters():
            if isinstance(param.data, DTensor):
                tensor = param.data.full_tensor()
            else:
                tensor = param.data
            param_shapes[name] = tuple(tensor.shape)
            del tensor  # free memory if full_tensor was created
        return param_shapes

    def train_batch(
        self,
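The point of the follow-up fix in this commit ("full_tensor() should happen in all rank") is that DTensor.full_tensor() is a collective gather across the FSDP shards, so every training rank must reach that call, while only rank 0, the sole trainer member of the external weight-update group, broadcasts the result to the SGLang servers. A minimal free-standing sketch of that pattern (the helper name is hypothetical, not part of the diff):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor

def broadcast_full_params(model, weight_update_group):
    # Sketch of the loop above, assuming dist is already initialized.
    for name, param in model.named_parameters():
        # Collective across FSDP shards: all training ranks must call full_tensor().
        if isinstance(param.data, DTensor):
            tensor = param.data.full_tensor()
        else:
            tensor = param.data
        if dist.get_rank() == 0:
            # Only rank 0 belongs to the external NCCL group shared with SGLang.
            dist.broadcast(tensor, src=0, group=weight_update_group)
        dist.barrier()  # keep training ranks in lockstep, one parameter at a time
        del tensor
    torch.cuda.empty_cache()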
@@ -4,7 +4,7 @@ import random
import threading
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -75,6 +75,7 @@ class RemoteSGLangEngine(InferenceEngine):
        self.lock = threading.Lock()

        self.rollout_stat = RolloutStat()
        self.distributed_weight_update_initialized = False

        self._version = 0
@@ -317,8 +318,31 @@
                ttft=latency,  # Simplified for non-streaming
            )

    def update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "disk":
    def update_weights(self, meta):
        executor = ThreadPoolExecutor(max_workers=1)
        return executor.submit(self._update_weights, meta)

    def _update_weights(self, meta: WeightUpdateMeta):
        if meta.type == "nccl":
            if not self.distributed_weight_update_initialized:
                self._init_distributed_weight_update(meta)
            tik = time.perf_counter()
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                for param in meta.parameter_names:
                    jobs = [
                        self.aupdate_weights_from_distributed(addr, meta, param)
                        for addr in self.addresses
                    ]
                    loop.run_until_complete(asyncio.gather(*jobs))
            finally:
                loop.close()
            logger.info(
                f"Distributed update weights done in {time.perf_counter() - tik}s"
            )
            self.set_version(meta.model_version)
        elif meta.type == "disk":
            # Update weights from disk
            # Use ProcessPool to bypass python GIL for running async coroutines
            fut = self.executor.submit(
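The new update_weights hands the work to a single-worker ThreadPoolExecutor and returns a concurrent.futures.Future, so the caller can overlap the trainer-side NCCL broadcast with the HTTP requests that drive the servers. A sketch of the intended call order, mirroring the test added by this commit (the wrapper function is illustrative):

from concurrent.futures import Future

def push_weights(engine, remote_engine, meta) -> None:
    # engine: FSDPEngine, remote_engine: RemoteSGLangEngine, meta: WeightUpdateMeta
    # (the same objects the new test constructs).
    future: Future = remote_engine.update_weights(meta)  # non-blocking; runs in a worker thread
    engine.upload_weights(meta)                          # trainer side: broadcasts each parameter over NCCL
    future.result(timeout=120)                           # wait until every server has received the weights
    assert remote_engine.get_version() == meta.model_version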
@@ -340,6 +364,58 @@
        else:
            raise NotImplementedError(f"Unsupported weight update type: {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        try:
            # Initialize weights update group
            jobs = [
                self.ainit_weights_update_group(addr, meta) for addr in self.addresses
            ]
            loop = asyncio.new_event_loop()
            # asyncio event loop should be manually set when running asyncio stuff in another thread
            asyncio.set_event_loop(loop)
            loop.run_until_complete(asyncio.gather(*jobs))
            self.distributed_weight_update_initialized = True
            logger.info(f"Distributed update weights initialized")
        finally:
            loop.close()

    async def ainit_weights_update_group(self, addr: str, meta: WeightUpdateMeta):
        rank_offset = 1 + self.addresses.index(addr) * meta.tp_size
        payload = {
            "master_address": meta.master_address,
            "master_port": str(meta.master_port),
            "rank_offset": rank_offset,
            "world_size": meta.world_size,
            "group_name": meta.group_name,
            "backend": "nccl",
        }
        res = await arequest_with_retry(
            addr=addr,
            endpoint="/init_weights_update_group",
            payload=payload,
            method="POST",
            max_retries=1,
            timeout=self.config.request_timeout,
        )
        assert res["success"]

    async def aupdate_weights_from_distributed(
        self, addr: str, meta: WeightUpdateMeta, parameter_name: str
    ):
        res = await arequest_with_retry(
            addr=addr,
            endpoint="/update_weights_from_distributed",
            payload={
                "name": parameter_name,
                "dtype": "bfloat16",
                "shape": meta.state_dict_key_to_shape[parameter_name],
            },
            method="POST",
            max_retries=1,
            timeout=self.config.request_timeout,
        )
        assert res["success"]

    def get_capacity(self):
        if dist.is_initialized():
            world_size = dist.get_world_size()
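The rank layout implied by rank_offset = 1 + self.addresses.index(addr) * meta.tp_size is: the trainer holds rank 0 of the weight-update group, and the TP workers of server i occupy the next tp_size ranks, so the group's world size is 1 + num_servers * tp_size (2 in the single-server, tp_size=1 test below). A small sketch of that arithmetic (the helper is illustrative):

def expected_world_size(num_servers: int, tp_size: int) -> int:
    # Rank 0 is the trainer; server i uses ranks 1 + i * tp_size .. 1 + (i + 1) * tp_size - 1.
    return 1 + num_servers * tp_size

assert expected_world_size(num_servers=1, tp_size=1) == 2  # matches world_size=2 in the test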
@@ -0,0 +1,140 @@
import os
import subprocess
import sys
import time
import uuid

import pytest
import requests
import torch

from arealite.api.cli_args import (
    InferenceEngineConfig,
    OptimizerConfig,
    SGLangConfig,
    TrainEngineConfig,
)
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.network import find_free_ports
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network

EXPR_NAME = "test_fsdp_engine_nccl"
TRIAL_NAME = "trial_nccl"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13998
DIST_PORT = 15998
GROUP_NAME = "test_nccl_group"
MASTER_PORT = DIST_PORT + 1
HOST = network.gethostip()
RUN_SERVER_TIMEOUT = 180


def check_server_health(base_url):
    try:
        response = requests.get(f"{base_url}/metrics", timeout=30)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


@pytest.fixture(scope="module")
def sglang_server_nccl():
    from realhf.base import seeding

    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(
            mem_fraction_static=0.2,
            model_path=MODEL_PATH,
            skip_tokenizer_init=False,
            log_level="info",
        ),
        tp_size=1,
        base_gpu_id=1,
        host=HOST,
        port=PORT,
        dist_init_addr=f"{HOST}:{DIST_PORT}",
    )
    full_command = f"{cmd} --port {PORT}"
    full_command = full_command.replace("\\\n", " ").replace("\\", " ")
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"

    print(f"full_command to start sglang server: {full_command}", flush=True)
    process = subprocess.Popen(
        full_command.split(),
        text=True,
        stdout=sys.stdout,
        stderr=sys.stdout,
    )
    base_url = f"http://{HOST}:{PORT}"
    tik = time.time()
    while time.time() - tik < RUN_SERVER_TIMEOUT:
        if check_server_health(base_url):
            break
        time.sleep(1)
    if time.time() - tik > RUN_SERVER_TIMEOUT:
        process.terminate()
        raise RuntimeError("server launch failed")
    yield
    process.terminate()


def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
    # Set up the distributed environment variables
    os.environ["WORLD_SIZE"] = "1"
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = HOST
    os.environ["MASTER_PORT"] = str(MASTER_PORT)

    # Start the local FSDPEngine
    engine_config = TrainEngineConfig(
        experiment_name=EXPR_NAME,
        trial_name=TRIAL_NAME,
        path=MODEL_PATH,
        optimizer=OptimizerConfig(),
    )
    engine = FSDPEngine(engine_config)
    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
    engine.initialize(None, ft_spec)

    # Start the remote RemoteSGLangEngine
    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    remote_engine = RemoteSGLangEngine(config)
    remote_engine.initialize(None, None)

    # Build the WeightUpdateMeta (type="nccl")
    param_meta = engine.get_param_meta_for_distributed_update()
    meta = WeightUpdateMeta(
        type="nccl",
        path=None,
        alloc_mode=None,
        comm_backend="nccl",
        model_version=123,
        tp_size=1,
        master_address="localhost",
        master_port=find_free_ports(1)[0],
        world_size=2,
        group_name=GROUP_NAME,
        parameter_names=list(param_meta.keys()),
        state_dict_key_to_shape=param_meta,
    )

    # The local engine broadcasts its parameters
    future = remote_engine.update_weights(meta)
    print("got future", flush=True)
    engine.upload_weights(meta)
    print("uploaded weights to remote engine", flush=True)
    # The remote engine pulls the parameters
    future.result(timeout=120)
    print("got result", flush=True)
    # Check the remote model version
    assert remote_engine.get_version() == 123
    remote_engine.destroy()
    engine.destroy()
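The test assumes at least two GPUs on the host: the SGLang server is pinned to base_gpu_id=1 while the FSDP engine runs on LOCAL_RANK 0. Assuming a standard pytest setup (the file's path is not shown in this diff), it could be invoked roughly like this:

import pytest

# Hypothetical path; substitute wherever this test file lives in the repository.
pytest.main(["-s", "-x", "arealite/tests/test_fsdp_engine_nccl.py"])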
@@ -0,0 +1,73 @@
import torch


# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # We need to determine the appropriate parameter name based on PyTorch version
    pg_options_param_name = (
        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
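For reference, this helper is what rank 0 of the trainer uses to join the ad-hoc NCCL group shared with the SGLang servers (see _init_distributed_weight_update in fsdp_engine.py above). A minimal sketch of that call, assuming a single server with tp_size=1; the address, port, and group name here are placeholders that must match what the servers receive:

import torch.distributed as dist

from arealite.utils.distributed import init_custom_process_group

# Assumed: this process is rank 0 and world_size = 1 + num_servers * tp_size.
group = init_custom_process_group(
    backend="nccl",
    init_method="tcp://127.0.0.1:29501",  # assumed rendezvous address and port
    world_size=2,
    rank=0,
    group_name="update_weights_group",    # assumed name; must match the servers' payload
)
dist.barrier(group=group, device_ids=[0])  # sync with the servers' post-init barrier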