AReaL/arealite/tests/test_fsdp_engine_nccl.py

138 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import subprocess
import sys
import time
import pytest
import requests
from arealite.api.cli_args import (
InferenceEngineConfig,
OptimizerConfig,
SGLangConfig,
TrainEngineConfig,
)
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.sglang_remote import RemoteSGLangEngine
from arealite.utils.network import find_free_ports
from realhf.base import network
# Identifiers used to namespace this experiment/trial.
EXPR_NAME = "test_fsdp_engine_nccl"
TRIAL_NAME = "trial_nccl"

# Prefer the pre-downloaded checkpoint on the shared test storage; otherwise
# fall back to a small model identifier (presumably resolved remotely by the
# engines, e.g. from the Hugging Face Hub — not verified here).
_LOCAL_MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
_FALLBACK_MODEL_PATH = "Qwen/Qwen2-0.5B"
MODEL_PATH = (
    _LOCAL_MODEL_PATH if os.path.exists(_LOCAL_MODEL_PATH) else _FALLBACK_MODEL_PATH
)

# Ports: the SGLang server serves HTTP on PORT and is handed
# dist_init_addr=HOST:DIST_PORT; MASTER_PORT is exported as the local
# trainer's MASTER_PORT environment variable.
PORT = 13998
DIST_PORT = 15998
MASTER_PORT = DIST_PORT + 1

# Name of the NCCL group used for the weight broadcast.
GROUP_NAME = "test_nccl_group"

# Address of this machine; used both for the server and for MASTER_ADDR.
HOST = network.gethostip()

# Seconds to wait for the SGLang server to report healthy.
RUN_SERVER_TIMEOUT = 180
def check_server_health(base_url):
    """Probe ``{base_url}/metrics`` and report whether the server is up.

    Args:
        base_url: Scheme + host + port of the server, e.g. ``"http://1.2.3.4:8000"``.

    Returns:
        ``True`` when the endpoint answers with HTTP 200, ``False`` on any other
        status code or on any ``requests`` error (connection refused, timeout, ...),
        so callers can simply keep polling.
    """
    metrics_url = f"{base_url}/metrics"
    try:
        resp = requests.get(metrics_url, timeout=30)
    except requests.exceptions.RequestException:
        return False
    return resp.status_code == 200
@pytest.fixture(scope="module")
def sglang_server_nccl():
    """Launch an SGLang server for this test module and stop it afterwards.

    The launch command comes from ``SGLangConfig.build_cmd`` with
    ``MODEL_PATH``, ``tp_size=1`` and ``base_gpu_id=1``. The ``base_gpu_id=1``
    is presumably chosen so GPU 0 stays free for the local FSDP engine
    (confirm). The server listens on ``HOST:PORT``. Setup blocks until
    ``GET http://HOST:PORT/metrics`` returns 200.

    Side effects:
        ``AREAL_LLM_SERVER_ADDRS`` is set to ``"HOST:PORT"`` while the fixture
        is active. On teardown it is restored to its previous value, or
        removed if it was previously unset.

    Raises:
        RuntimeError: if the server process exits during startup or does not
            become healthy within ``RUN_SERVER_TIMEOUT`` seconds.
    """
    from realhf.base import seeding

    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(
            mem_fraction_static=0.2,
            model_path=MODEL_PATH,
            skip_tokenizer_init=False,
            log_level="info",
        ),
        tp_size=1,
        base_gpu_id=1,
        host=HOST,
        port=PORT,
        dist_init_addr=f"{HOST}:{DIST_PORT}",
    )
    # NOTE(review): build_cmd already receives port=PORT, so this extra
    # "--port" may be redundant. It is kept so the launched command is unchanged.
    full_command = f"{cmd} --port {PORT}"
    # Drop backslash line continuations (and any stray backslashes) so the
    # command can be split into an argv list and run without a shell.
    full_command = full_command.replace("\\\n", " ").replace("\\", " ")

    prev_server_addrs = os.environ.get("AREAL_LLM_SERVER_ADDRS")
    os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
    print(f"full_command to start sglang server: {full_command}", flush=True)
    process = subprocess.Popen(
        full_command.split(),
        text=True,
        stdout=sys.stdout,
        stderr=sys.stdout,
    )
    try:
        base_url = f"http://{HOST}:{PORT}"
        # Track success explicitly instead of re-checking elapsed time after
        # the loop. A single probe can block for its own request timeout, so a
        # server that turns healthy during a slow final probe must not be
        # reported as a launch failure.
        deadline = time.monotonic() + RUN_SERVER_TIMEOUT
        healthy = False
        while time.monotonic() < deadline:
            if process.poll() is not None:
                # The server died during startup; waiting out the timeout is pointless.
                raise RuntimeError(
                    f"server launch failed: process exited with code {process.returncode}"
                )
            if check_server_health(base_url):
                healthy = True
                break
            time.sleep(1)
        if not healthy:
            raise RuntimeError("server launch failed")
        yield
    finally:
        # Runs on normal teardown and also when setup raised. Wait for the
        # process to actually exit; escalate to SIGKILL if it ignores SIGTERM.
        process.terminate()
        try:
            process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
        if prev_server_addrs is None:
            os.environ.pop("AREAL_LLM_SERVER_ADDRS", None)
        else:
            os.environ["AREAL_LLM_SERVER_ADDRS"] = prev_server_addrs
def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
    """Push FSDP weights to the remote SGLang server over NCCL and check the version.

    The local ``FSDPEngine`` runs as a single-rank job (``WORLD_SIZE=1``). Its
    parameters are sent through an NCCL group named ``GROUP_NAME`` with
    ``world_size=2``. This is presumably the trainer rank plus the single
    ``tp_size=1`` SGLang worker (confirm). Afterwards the remote engine must
    report ``model_version`` 123.

    Args:
        tmp_path_factory: pytest fixture; not used by this test.
        sglang_server_nccl: module fixture that keeps the SGLang server running.
    """
    from contextlib import ExitStack

    # Single-process distributed environment for the local FSDP engine.
    dist_env = {
        "WORLD_SIZE": "1",
        "RANK": "0",
        "LOCAL_RANK": "0",
        "MASTER_ADDR": HOST,
        "MASTER_PORT": str(MASTER_PORT),
    }
    saved_env = {key: os.environ.get(key) for key in dist_env}

    def _restore_dist_env():
        # Put back (or remove) the variables so they do not leak into other tests.
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    # ExitStack callbacks run in reverse registration order on exit, including
    # when an assertion or timeout fails. The remote engine is destroyed first,
    # then the local engine, then the environment is restored. This matches
    # the original success-path order.
    with ExitStack() as cleanup:
        cleanup.callback(_restore_dist_env)
        os.environ.update(dist_env)

        # Start the local FSDP training engine.
        engine_config = TrainEngineConfig(
            experiment_name=EXPR_NAME,
            trial_name=TRIAL_NAME,
            path=MODEL_PATH,
            optimizer=OptimizerConfig(),
        )
        engine = FSDPEngine(engine_config)
        ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
        engine.initialize(None, ft_spec)
        cleanup.callback(engine.destroy)

        # Connect a RemoteSGLangEngine to the server started by the fixture.
        config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
        config.server_addrs = [f"{HOST}:{PORT}"]
        remote_engine = RemoteSGLangEngine(config)
        remote_engine.initialize(None, None)
        cleanup.callback(remote_engine.destroy)

        # Build an NCCL WeightUpdateMeta covering every parameter the local
        # engine reports for distributed update.
        param_meta = engine.get_param_meta_for_distributed_update()
        meta = WeightUpdateMeta(
            type="nccl",
            path=None,
            alloc_mode=None,
            comm_backend="nccl",
            model_version=123,
            tp_size=1,
            master_address="localhost",
            master_port=find_free_ports(1)[0],
            world_size=2,
            group_name=GROUP_NAME,
            parameter_names=list(param_meta.keys()),
            state_dict_key_to_shape=param_meta,
        )

        # Request the remote update first; it returns a future. Then send from
        # the local engine. Presumably both sides must be inside the collective
        # at the same time, which is why this order matters.
        future = remote_engine.update_weights(meta)
        print("got future", flush=True)
        engine.upload_weights(meta)
        print("uploaded weights to remote engine", flush=True)
        # Wait for the remote side to finish receiving the parameters.
        future.result(timeout=120)
        print("got result", flush=True)

        # The remote engine should now report the version carried by `meta`.
        assert remote_engine.get_version() == 123