This commit is contained in:
晓雷 2025-07-12 11:43:29 +08:00
parent 805437463f
commit e222cea659
3 changed files with 102 additions and 59 deletions

View File

@@ -565,6 +565,58 @@ class DatasetConfig:
    drop_last: bool = field(default=True)
@dataclass
class LauncherConfig:
"""Configuration for launching the SGLang server."""
inference_server_cpus_per_gpu: int = field(
default=15,
metadata={
"help": "Number of CPUs allocated per GPU for inference server. "
},
)
inference_server_mem_per_gpu: int = field(
default=150*1024,
metadata={
"help": "Memory allocated per GPU for inference server in MB. "
},
)
trainer_cpus_per_gpu: int = field(
default=15,
metadata={
"help": "Number of CPUs allocated per GPU for training. "
},
)
trainer_mem_per_gpu: int = field(
default=150*1024,
metadata={
"help": "Memory allocated per GPU for training in MB. "
},
)
inference_server_env_vars: str = field(
defualt="",
metadata={
"help": "Environment variables for inference server, seperated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
trainer_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for training, seperated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
trainer_port: int = field(
default=27015,
metadata={
"help": "Trainer port used for torch.distributed initialization."
},
)
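For orientation, a rough sketch (not part of the commit) of how these defaults translate into per-task Slurm resources; the tensor-parallel size of 4 is illustrative:

    cfg = LauncherConfig()
    tp_size = 4  # hypothetical tensor-parallel size of one inference server
    cpus_per_task = cfg.inference_server_cpus_per_gpu * tp_size  # 60 CPUs
    mem_per_task = cfg.inference_server_mem_per_gpu * tp_size  # 614400 MB, i.e. 600 GB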
@dataclass
class BaseExperimentConfig:
    # NOTE: we need this unified config class because different experiments
@@ -625,6 +677,7 @@ class BaseExperimentConfig:
    server_only: bool = False
    sglang: SGLangConfig = field(default_factory=SGLangConfig)
    launcher: LauncherConfig = field(default_factory=LauncherConfig)
@dataclass

View File

@@ -1,5 +1,4 @@
import os
import socket
import subprocess
import sys
import time

View File

@@ -12,14 +12,17 @@ from omegaconf import OmegaConf
from arealite.api.cli_args import (
    BaseExperimentConfig,
    ClusterSpecConfig,
    SGLangConfig,
    LauncherConfig,
    parse_cli_args,
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_and_amend_config, find_config
from realhf.scheduler.client import JobException, JobInfo, JobState
from realhf.base import logging, name_resolve, names

logger = logging.getLogger("SlurmLauncher")
@@ -132,8 +135,7 @@ def get_slurm_host_ip(node: str):
        logger.warning(f"Failed to get the host IP for Slurm node {node}.")
SCHEDULING_RETRY_INTERVAL_SECONDS = 30
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SGLANG_SERVER_TIMEOUT_SECONDS = 120
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
@@ -200,12 +202,17 @@ NA132_ENVIRONS = {
}
def get_env_vars(cluster_name: str, additional_env_vars: str = "") -> Dict[str, str]:
    """Returns the environment variables for the cluster."""
    # Parse "ENV1=val1,ENV2=val2" into a dict. Split on the first "=" only,
    # so values may themselves contain "=".
    extra_env_vars = {}
    if additional_env_vars:
        extra_env_vars = dict(
            item.split("=", 1) for item in additional_env_vars.split(",")
        )
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **extra_env_vars}
    return {**BASE_ENVIRONS, **extra_env_vars}
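A quick usage sketch (the cluster name and variables are hypothetical); note that this comma-separated format cannot express values that contain commas:

    env = get_env_vars("my_cluster", "NCCL_DEBUG=WARN,TOKENIZERS_PARALLELISM=false")
    # env == {**BASE_ENVIRONS, "NCCL_DEBUG": "WARN", "TOKENIZERS_PARALLELISM": "false"}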
class SlurmLauncher:
@@ -545,9 +552,11 @@ class SlurmLauncher:
if __name__ == "__main__":
    # usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
    config, config_file = parse_cli_args(sys.argv[2:])
    entry_point = sys.argv[1]
    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
    name_resolve.reconfigure(config.cluster.name_resolve)
    n_nodes = config.n_nodes
    n_gpus_per_node = config.n_gpus_per_node
    if n_gpus_per_node < config.cluster.n_gpus_per_node:
@@ -585,30 +594,15 @@ if __name__ == "__main__":
    n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
    base_seed = config.sglang.random_seed
    sglang_server_cmd_template = (
        f"python3 -m arealite.launcher.sglang_server {' '.join(sys.argv[1:])} "
        "sglang.random_seed={seed}"
    )
    for i in range(n_sglang_servers):
        # Give each server a distinct random seed. Host addresses and ports are
        # chosen by the server processes themselves and published via name_resolve,
        # so the launcher no longer tracks ports here.
        sglang_cmd = sglang_server_cmd_template.format(seed=base_seed + i)
        sglang_cmds.append(sglang_cmd)
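For example, with base_seed = 1 the command for server i = 2 renders as (placeholders as in the usage comment above):

    python3 -m arealite.launcher.sglang_server <entry_point> <config_path> [<args>] sglang.random_seed=3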
    # Launch SGLang servers. Note that we need to leave some resources on each
    # node to schedule jobs that retrieve node IPs (1 CPU core & 10 MB memory).
    if sglang_cmds:
        launcher.submit_array(
            job_name="sglang-server",
@@ -616,44 +610,38 @@
            count=n_sglang_servers,
            nodes=n_sglang_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
            cpus_per_task=launcher.inference_server_cpus_per_gpu * sglang_tp_size,
            mem_per_task=launcher.inference_server_mem_per_gpu * sglang_tp_size,
            container_image=config.cluster.gpu_infer_image,
            container_mounts=config.cluster.mount,
            env_vars=get_env_vars(
                config.cluster.cluster_name,
                config.launcher.inference_server_env_vars,
            ),
        )
        # Wait until every SGLang server has registered its address via name_resolve.
        name = names.gen_servers(config.experiment_name, config.trial_name)
        start = time.perf_counter()
        while True:
            sglang_addrs = name_resolve.get_subtree(name)
            if len(sglang_addrs) >= n_sglang_servers:
                logger.info(
                    f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
                )
                break
            time.sleep(1)
            if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
                raise TimeoutError(
                    f"Timeout waiting for SGLang servers to be ready. "
                    f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
                )
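The loop above relies on each server process publishing its own "host:port" under the same key. A minimal sketch of that producer side (variable names are hypothetical; the actual registration lives in arealite.launcher.sglang_server), assuming name_resolve offers an add_subentry counterpart to get_subtree:

    from realhf.base import name_resolve, names

    name = names.gen_servers(experiment_name, trial_name)
    name_resolve.add_subentry(name, f"{host}:{port}")  # later collected by get_subtree()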
    trainer_n_nodes = n_nodes - n_sglang_nodes
    trainer_cmd_template = (
        f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} "
        f"--node-rank {{node_rank}} "
        f"--master-addr $head_node_ip --master-port {config.launcher.trainer_port} "
        f"{' '.join(sys.argv[1:])}"
    )
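With, say, 2 trainer nodes and 8 processes per node, node rank 0's command renders as (default trainer_port shown):

    torchrun --nnodes=2 --nproc-per-node=8 --node-rank 0 --master-addr $head_node_ip --master-port 27015 <entry_point> <config_path> [<args>]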
    trainer_cmds = []
@@ -676,12 +664,15 @@
        count=trainer_n_nodes,
        nodes=trainer_n_nodes,
        n_gpus_per_node=config.cluster.n_gpus_per_node,
        cpus_per_task=launcher.trainer_cpus_per_gpu * config.cluster.n_gpus_per_node,
        mem_per_task=launcher.trainer_mem_per_gpu * config.cluster.n_gpus_per_node,
        container_image=config.cluster.gpu_image,
        container_mounts=config.cluster.mount,
        env_vars=dict(
            **get_env_vars(
                config.cluster.cluster_name,
                config.launcher.trainer_env_vars,
            ),
            AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
        ),
    )
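Downstream, the training entry point can recover the server list from that variable; a minimal sketch (this parsing is an assumption, not code from the commit):

    import os

    sglang_addrs = os.environ["AREAL_LLM_SERVER_ADDRS"].split(",")  # ["host1:port1", ...]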