mirror of https://github.com/inclusionAI/AReaL
138 lines
4.0 KiB
Python
138 lines
4.0 KiB
Python
import os
|
||
import subprocess
|
||
import sys
|
||
import time
|
||
|
||
import pytest
|
||
import requests
|
||
|
||
from arealite.api.cli_args import (
|
||
InferenceEngineConfig,
|
||
OptimizerConfig,
|
||
SGLangConfig,
|
||
TrainEngineConfig,
|
||
)
|
||
from arealite.api.io_struct import FinetuneSpec, WeightUpdateMeta
|
||
from arealite.engine.fsdp_engine import FSDPEngine
|
||
from arealite.engine.sglang_remote import RemoteSGLangEngine
|
||
from arealite.utils.network import find_free_ports
|
||
from realhf.base import network
|
||
|
||
EXPR_NAME = "test_fsdp_engine_nccl"
|
||
TRIAL_NAME = "trial_nccl"
|
||
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
|
||
if not os.path.exists(MODEL_PATH):
|
||
MODEL_PATH = "Qwen/Qwen2-0.5B"
|
||
PORT = 13998
|
||
DIST_PORT = 15998
|
||
GROUP_NAME = "test_nccl_group"
|
||
MASTER_PORT = DIST_PORT + 1
|
||
HOST = network.gethostip()
|
||
RUN_SERVER_TIMEOUT = 180
|
||
|
||
|
||
def check_server_health(base_url):
|
||
try:
|
||
response = requests.get(f"{base_url}/metrics", timeout=30)
|
||
return response.status_code == 200
|
||
except requests.exceptions.RequestException:
|
||
return False
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def sglang_server_nccl():
|
||
from realhf.base import seeding
|
||
|
||
seeding.set_random_seed(1, EXPR_NAME)
|
||
cmd = SGLangConfig.build_cmd(
|
||
sglang_config=SGLangConfig(
|
||
mem_fraction_static=0.2,
|
||
model_path=MODEL_PATH,
|
||
skip_tokenizer_init=False,
|
||
log_level="info",
|
||
),
|
||
tp_size=1,
|
||
base_gpu_id=1,
|
||
host=HOST,
|
||
port=PORT,
|
||
dist_init_addr=f"{HOST}:{DIST_PORT}",
|
||
)
|
||
full_command = f"{cmd} --port {PORT}"
|
||
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
|
||
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
|
||
|
||
print(f"full_command to start sglang server: {full_command}", flush=True)
|
||
process = subprocess.Popen(
|
||
full_command.split(),
|
||
text=True,
|
||
stdout=sys.stdout,
|
||
stderr=sys.stdout,
|
||
)
|
||
base_url = f"http://{HOST}:{PORT}"
|
||
tik = time.time()
|
||
while time.time() - tik < RUN_SERVER_TIMEOUT:
|
||
if check_server_health(base_url):
|
||
break
|
||
time.sleep(1)
|
||
if time.time() - tik > RUN_SERVER_TIMEOUT:
|
||
process.terminate()
|
||
raise RuntimeError("server launch failed")
|
||
yield
|
||
process.terminate()
|
||
|
||
|
||
def test_fsdpengine_nccl_weight_update_to_remote(tmp_path_factory, sglang_server_nccl):
|
||
# 设置分布式环境变量
|
||
os.environ["WORLD_SIZE"] = "1"
|
||
os.environ["RANK"] = "0"
|
||
os.environ["LOCAL_RANK"] = "0"
|
||
os.environ["MASTER_ADDR"] = HOST
|
||
os.environ["MASTER_PORT"] = str(MASTER_PORT)
|
||
|
||
# 启动本地FSDPEngine
|
||
engine_config = TrainEngineConfig(
|
||
experiment_name=EXPR_NAME,
|
||
trial_name=TRIAL_NAME,
|
||
path=MODEL_PATH,
|
||
optimizer=OptimizerConfig(),
|
||
)
|
||
engine = FSDPEngine(engine_config)
|
||
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
|
||
engine.initialize(None, ft_spec)
|
||
|
||
# 启动远端RemoteSGLangEngine
|
||
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||
config.server_addrs = [f"{HOST}:{PORT}"]
|
||
remote_engine = RemoteSGLangEngine(config)
|
||
remote_engine.initialize(None, None)
|
||
|
||
# 构造WeightUpdateMeta(type=nccl)
|
||
param_meta = engine.get_param_meta_for_distributed_update()
|
||
meta = WeightUpdateMeta(
|
||
type="nccl",
|
||
path=None,
|
||
alloc_mode=None,
|
||
comm_backend="nccl",
|
||
model_version=123,
|
||
tp_size=1,
|
||
master_address="localhost",
|
||
master_port=find_free_ports(1)[0],
|
||
world_size=2,
|
||
group_name=GROUP_NAME,
|
||
parameter_names=list(param_meta.keys()),
|
||
state_dict_key_to_shape=param_meta,
|
||
)
|
||
|
||
# 本地engine广播参数
|
||
future = remote_engine.update_weights(meta)
|
||
print("got future", flush=True)
|
||
engine.upload_weights(meta)
|
||
print("uploaded wexights to remote engine", flush=True)
|
||
# 远端engine拉取参数
|
||
future.result(timeout=120)
|
||
print("got result", flush=True)
|
||
# 检查远端参数版本
|
||
assert remote_engine.get_version() == 123
|
||
remote_engine.destroy()
|
||
engine.destroy()
|