mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent
805437463f
commit
e222cea659
|
@ -565,6 +565,58 @@ class DatasetConfig:
|
|||
drop_last: bool = field(default=True)
|
||||
|
||||
|
||||
@dataclass
class LauncherConfig:
    """Resource and environment settings used when launching jobs.

    Covers both the inference server and the trainer: per-GPU CPU and
    memory requests, extra environment variables for each job type, and the
    port the trainer uses for ``torch.distributed`` initialization.

    Environment variables are given as a single comma-separated string of
    ``KEY=value`` pairs, e.g. ``"ENV1=val1,ENV2=val2"``.
    """

    # Per-GPU resource requests for the inference server.
    inference_server_cpus_per_gpu: int = field(
        default=15,
        metadata={
            "help": "Number of CPUs allocated per GPU for inference server. "
        },
    )
    inference_server_mem_per_gpu: int = field(
        default=150 * 1024,  # expressed in MB, see help text
        metadata={
            "help": "Memory allocated per GPU for inference server in MB. "
        },
    )

    # Per-GPU resource requests for the trainer.
    trainer_cpus_per_gpu: int = field(
        default=15,
        metadata={
            "help": "Number of CPUs allocated per GPU for training. "
        },
    )
    trainer_mem_per_gpu: int = field(
        default=150 * 1024,  # expressed in MB, see help text
        metadata={
            "help": "Memory allocated per GPU for training in MB. "
        },
    )

    # Extra environment variables, encoded as "KEY=value,KEY2=value2".
    # NOTE: the keyword must be spelled ``default``; any other spelling makes
    # ``dataclasses.field`` raise TypeError when the class is defined.
    inference_server_env_vars: str = field(
        default="",
        metadata={
            "help": "Environment variables for inference server, seperated by commas. "
            "Example: 'ENV1=val1,ENV2=val2'. "
        },
    )
    trainer_env_vars: str = field(
        default="",
        metadata={
            "help": "Environment variables for training, seperated by commas. "
            "Example: 'ENV1=val1,ENV2=val2'. "
        },
    )

    trainer_port: int = field(
        default=27015,
        metadata={
            "help": "Trainer port used for torch.distributed initialization."
        },
    )
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseExperimentConfig:
|
||||
# NOTE: we need this unified config class because different experiments
|
||||
|
@ -625,6 +677,7 @@ class BaseExperimentConfig:
|
|||
|
||||
server_only: bool = False
|
||||
sglang: SGLangConfig = field(default_factory=SGLangConfig)
|
||||
launcher: LauncherConfig = field(default_factory=LauncherConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
|
|
@ -12,14 +12,17 @@ from omegaconf import OmegaConf
|
|||
|
||||
import realhf.base.logging as logging
|
||||
from arealite.api.cli_args import (
|
||||
BaseExperimentConfig,
|
||||
ClusterSpecConfig,
|
||||
SGLangConfig,
|
||||
LauncherConfig,
|
||||
parse_cli_args,
|
||||
to_structured_cfg,
|
||||
)
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.launcher.utils import find_and_amend_config, find_config
|
||||
from realhf.scheduler.client import JobException, JobInfo, JobState
|
||||
from realhf.base import logging, name_resolve, names
|
||||
|
||||
|
||||
logger = logging.getLogger("SlurmLauncher")
|
||||
|
||||
|
@ -132,8 +135,7 @@ def get_slurm_host_ip(node: str):
|
|||
logger.warning(f"Get slurm host ip for node {node} failed.")
|
||||
|
||||
|
||||
SCHEDULING_RETRY_INTERVAL_SECONDS = 30
|
||||
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
|
||||
SGLANG_SERVER_TIMEOUT_SECONDS = 120
|
||||
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
|
||||
|
||||
|
||||
|
@ -200,12 +202,17 @@ NA132_ENVIRONS = {
|
|||
}
|
||||
|
||||
|
||||
def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
    """Return the environment variables for jobs launched on a cluster.

    The result is built from ``BASE_ENVIRONS``, then ``NA132_ENVIRONS`` when
    ``cluster_name`` is ``"na132"``, then the user-supplied variables, so
    later sources override earlier ones.

    Args:
        cluster_name: Name of the target cluster.
        additional_env_vars: Optional comma-separated ``KEY=value`` pairs,
            e.g. ``"ENV1=val1,ENV2=val2"``. ``None`` or an empty string
            means no extra variables.

    Returns:
        A new dict mapping variable names to values; the module-level
        environment tables are never returned directly or mutated.

    Raises:
        ValueError: If an item has no ``=`` or has an empty name.
    """
    # Parse into a separate local: rebinding ``additional_env_vars`` to an
    # empty dict before inspecting it would silently discard the caller's input.
    extra_env_vars: Dict[str, str] = {}
    if additional_env_vars:
        for item in additional_env_vars.split(","):
            if not item.strip():
                # Tolerate trailing or doubled commas.
                continue
            # Split on the first "=" only so values may themselves contain "=".
            key, sep, value = item.partition("=")
            key = key.strip()
            if not sep or not key:
                raise ValueError(
                    f"Invalid environment variable specification {item!r}; "
                    "expected 'KEY=value'."
                )
            extra_env_vars[key] = value

    environs = dict(BASE_ENVIRONS)
    if cluster_name == "na132":
        environs.update(NA132_ENVIRONS)
    environs.update(extra_env_vars)
    return environs
|
||||
|
||||
|
||||
class SlurmLauncher:
|
||||
|
@ -545,9 +552,11 @@ class SlurmLauncher:
|
|||
if __name__ == "__main__":
|
||||
# usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
|
||||
config, config_file = parse_cli_args(sys.argv[2:])
|
||||
entry_point = sys.argv[1]
|
||||
|
||||
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
|
||||
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||
|
||||
n_nodes = config.n_nodes
|
||||
n_gpus_per_node = config.n_gpus_per_node
|
||||
if n_gpus_per_node < config.cluster.n_gpus_per_node:
|
||||
|
@ -585,30 +594,15 @@ if __name__ == "__main__":
|
|||
n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
|
||||
|
||||
base_seed = config.sglang.random_seed
|
||||
base_port = args.sglang_server_base_port
|
||||
sglang_ports_on_node = set()
|
||||
sglang_server_cmd_template = (
|
||||
f"python3 -m arealite.launcher.sglang_server {sys.argv[1:]} sglang.random_seed={{seed}}"
|
||||
)
|
||||
for i in range(n_sglang_servers):
|
||||
base_gpu_id = i * sglang_tp_size % n_gpus_per_node
|
||||
# Since we cannot get port information from slurm, we only ensure ports on the same .
|
||||
server_port = base_port + i % n_sglang_servers_per_node * 2
|
||||
# print(server_port, dist_init_addr)
|
||||
config.sglang.random_seed = base_seed + i
|
||||
sglang_cmds.append(
|
||||
SGLangConfig.build_cmd(
|
||||
config.sglang,
|
||||
sglang_tp_size,
|
||||
base_gpu_id=0,
|
||||
host="localhost",
|
||||
port=server_port,
|
||||
dist_init_addr=None,
|
||||
sglang_version=args.sglang_version,
|
||||
)
|
||||
sglang_cmd = sglang_server_cmd_template.format(
|
||||
seed=base_seed + i,
|
||||
)
|
||||
sglang_ports_on_node.add(server_port)
|
||||
assert len(sglang_ports_on_node) == n_sglang_servers_per_node
|
||||
sglang_cmds.append(sglang_cmd)
|
||||
|
||||
# launch SGLang servers, note that we need to leave some resources on each node
|
||||
# to schedule jobs that retrieve node IP. (1 CPU core & 10 MB memory)
|
||||
if sglang_cmds:
|
||||
launcher.submit_array(
|
||||
job_name="sglang-server",
|
||||
|
@ -616,44 +610,38 @@ if __name__ == "__main__":
|
|||
count=n_sglang_servers,
|
||||
nodes=n_sglang_nodes,
|
||||
n_gpus_per_node=config.cluster.n_gpus_per_node,
|
||||
cpus_per_task=15 * sglang_tp_size,
|
||||
mem_per_task=150 * 1024 * sglang_tp_size, # 20GB per task
|
||||
cpus_per_task=launcher.inference_server_cpus_per_gpu * sglang_tp_size,
|
||||
mem_per_task=launcher.inference_server_mem_per_gpu * sglang_tp_size,
|
||||
container_image=config.cluster.gpu_infer_image,
|
||||
container_mounts=config.cluster.mount,
|
||||
env_vars=get_env_vars(config.cluster.cluster_name),
|
||||
env_vars=get_env_vars(
|
||||
config.cluster.cluster_name,
|
||||
config.laucnher.inference_server_env_vars,
|
||||
),
|
||||
)
|
||||
|
||||
# Get SGLang slurm nodes, find the hosts
|
||||
start_time = time.perf_counter()
|
||||
name = names.gen_servers(
|
||||
config.experiment_name, config.trial_name
|
||||
)
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
job_info = launcher.find("sglang-server")
|
||||
assert job_info is not None
|
||||
print(job_info)
|
||||
sglang_hosts = job_info.host
|
||||
logger.info(
|
||||
f"Waiting for SGLang servers to be scheduled by slurm, time since started = {time.perf_counter() - start_time:.2f}"
|
||||
)
|
||||
if sglang_hosts:
|
||||
print(sglang_hosts)
|
||||
sglang_hosts = [
|
||||
get_slurm_host_ip(node)
|
||||
for node in parse_slurm_nodelist(sglang_hosts)
|
||||
]
|
||||
print(sglang_hosts)
|
||||
for host in sglang_hosts:
|
||||
sglang_addrs.extend(
|
||||
[f"{host}:{port}" for port in sglang_ports_on_node]
|
||||
)
|
||||
logger.info(f"Get SGLang addresses: {' '.join(sglang_addrs)}")
|
||||
assert len(sglang_addrs) == n_sglang_servers
|
||||
sglang_addrs = name_resolve.get_subtree(name)
|
||||
if len(sglang_addrs) >= n_sglang_servers:
|
||||
logger.info(f"Found {n_sglang_servers} SGLang servers: {", ".join(sglang_addrs)}")
|
||||
break
|
||||
time.sleep(10)
|
||||
|
||||
time.sleep(1)
|
||||
if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for SGLang servers to be ready. "
|
||||
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
|
||||
)
|
||||
|
||||
trainer_n_nodes = n_nodes - n_sglang_nodes
|
||||
assert r.overrides is not None
|
||||
trainer_cmd_template = (
|
||||
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
|
||||
f"--master-addr $head_node_ip --master-port {args.trainer_port} {entry_point} --config {config_file} {' '.join(r.overrides)}"
|
||||
f"--master-addr $head_node_ip --master-port {config.launcher.trainer_port} {sys.argv[1:]}"
|
||||
)
|
||||
|
||||
trainer_cmds = []
|
||||
|
@ -676,12 +664,15 @@ if __name__ == "__main__":
|
|||
count=trainer_n_nodes,
|
||||
nodes=trainer_n_nodes,
|
||||
n_gpus_per_node=config.cluster.n_gpus_per_node,
|
||||
cpus_per_task=120, # one trainer task occupy the entire node
|
||||
mem_per_task=1024 * 1024, # 1024GB per task
|
||||
cpus_per_task=launcher.trainer_cpus_per_gpu * config.cluster.n_gpus_per_node,
|
||||
mem_per_task=launcher.trainer_mem_per_gpu * config.cluster.n_gpus_per_node,
|
||||
container_image=config.cluster.gpu_image,
|
||||
container_mounts=config.cluster.mount,
|
||||
env_vars=dict(
|
||||
**get_env_vars(config.cluster.cluster_name),
|
||||
**get_env_vars(
|
||||
config.cluster.cluster_name,
|
||||
config.launcher.trainer_env_vars,
|
||||
),
|
||||
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
|
||||
),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue