meizhiyu.mzy 2025-07-11 22:56:59 +08:00
parent 97511e43ff
commit 4a26f28adf
4 changed files with 307 additions and 171 deletions

View File

@ -283,12 +283,12 @@ class SGLangConfig:
host,
port,
dist_init_addr: Optional[str] = None,
sglang_version: Optional[str] = None,
):
from realhf.base import network, pkg_version, seeding
from realhf.base import pkg_version
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args = dict(
host=host,
port=port,
@ -309,10 +309,24 @@ class SGLangConfig:
dist_init_addr=dist_init_addr,
**args,
)
if sglang_version:
version_less_than_0_4_4 = (
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
)
version_less_than_0_4_3 = (
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
)
elif pkg_version.is_available("sglang"):
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
else:
raise ValueError(
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
)
if pkg_version.is_version_less("sglang", "0.4.4"):
if version_less_than_0_4_4:
args.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
if version_less_than_0_4_3:
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
@ -634,8 +648,12 @@ class ArgParseResult:
config: BaseExperimentConfig
config_file: Path
additional_args: Optional[argparse.Namespace] = None
overrides: Optional[List[str]] = None
def parse_cli_args(argv: List[str], parser: Optional[argparse.ArgumentParser] = None) -> ArgParseResult:
def parse_cli_args(
argv: List[str], parser: Optional[argparse.ArgumentParser] = None
) -> ArgParseResult:
if parser is None:
parser = argparse.ArgumentParser()
parser.add_argument(
@ -654,11 +672,10 @@ def parse_cli_args(argv: List[str], parser: Optional[argparse.ArgumentParser] =
overrides=overrides,
)
return ArgParseResult(
config=cfg,
config_file=config_file,
additional_args=args
config=cfg, config_file=config_file, additional_args=args, overrides=overrides
)
def to_structured_cfg(cfg, config_cls):
# Merge with the default configuration.
# The yaml and commandline can omit some default values defined in python dataclasses.
@ -668,9 +685,7 @@ def to_structured_cfg(cfg, config_cls):
def load_expr_config(
argv: List[str],
config_cls,
parser: Optional[argparse.ArgumentParser] = None
argv: List[str], config_cls, parser: Optional[argparse.ArgumentParser] = None
) -> ArgParseResult:
r = parse_cli_args(argv, parser=parser)
cfg = r.config
@ -684,7 +699,5 @@ def load_expr_config(
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
name_resolve.reconfigure(cfg.cluster.name_resolve)
return ArgParseResult(
config=cfg,
config_file=r.config_file,
additional_args=r.additional_args
config=cfg, config_file=r.config_file, additional_args=r.additional_args
)
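A minimal usage sketch of the new `overrides` field (the command-line layout is assumed): it exposes the raw dotlist overrides so a launcher can forward them verbatim to worker commands, as the Slurm launcher below does.

import sys
from arealite.api.cli_args import parse_cli_args

r = parse_cli_args(sys.argv[1:])
# r.config and r.config_file behave as before; r.overrides carries the raw
# "key=value" override strings parsed from the command line.
forwarded = " ".join(r.overrides or [])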

View File

@ -1,23 +1,25 @@
from typing import Tuple, List, Optional, Literal, Dict
import subprocess
import time
import re
import os
import getpass
import argparse
import getpass
import os
import pathlib
import re
import subprocess
import sys
import time
from typing import Dict, List, Literal, Optional, Tuple
from omegaconf import OmegaConf
from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig, parse_cli_args, to_structured_cfg
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_config, find_and_amend_config
import realhf.base.logging as logging
from realhf.scheduler.client import (
JobException,
JobInfo,
JobState
from arealite.api.cli_args import (
ClusterSpecConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_and_amend_config, find_config
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("SlurmLauncher")
@ -50,7 +52,7 @@ STATUS_MAPPING = {
def cancel_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
assert (
@ -63,16 +65,17 @@ def cancel_jobs(
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
elif slurm_ids is not None:
cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
cmd += [",".join(str(s) for s in slurm_ids)]
subprocess.check_call(cmd)
logger.info(
f"Cancelled Slurm job with signal {signal}: "
f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}"
f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
)
def query_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
status: str = "all",
delimiter: str = "__PSI__",
) -> List[JobInfo]:
@ -82,6 +85,7 @@ def query_jobs(
cmd += ["-n", ",".join(slurm_names)]
if slurm_ids is not None:
cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
output = (
subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
)
@ -97,19 +101,43 @@ def query_jobs(
host=nodelist,
submit_time=submit_time,
start_time=start_time,
slurm_id=job_id.strip(),
slurm_id=int(job_id.strip()),
)
)
print(rs)
return rs
def parse_slurm_nodelist(nodelist: str) -> List[str]:
return (
subprocess.check_output(
[
"scontrol",
"show",
"hostnames",
nodelist,
]
)
.decode("utf-8")
.strip()
.split("\n")
)
def get_slurm_host_ip(node: str):
try:
cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
except subprocess.CalledProcessError:
logger.warning(f"Get slurm host ip for node {node} failed.")
SCHEDULING_RETRY_INTERVAL_SECONDS = 30
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
SBATCH_SCRIPT_TEMPLATE = """
#!/bin/bash
SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}
# Getting the node names
@ -119,11 +147,11 @@ echo nodes=$nodes
nodes_array=($nodes)
echo node_array=$nodes_array
head_node=$\{nodes_array[0]\}
head_node=${{nodes_array[0]}}
echo head_node=$head_node
# Getting the head node IP address
head_node_ip=$(srun --mpi=pmi2 --nodes=1 --ntasks=1 --nodelist="$head_node" hostname --ip-address)
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip
# srun commands
@ -132,11 +160,11 @@ echo head_node_ip=$head_node_ip
wait
"""
SRUN_CMD_TEMPLATE = """
srun --mpi=pmi2 -K -l --chdir $PWD --nodes={nodes} --ntasks={ntasks} --gres:gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
SRUN_CMD_TEMPLATE = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
--mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
{container_env_strings} {container_image} {cmd} &
{container_env_strings} \\
{container_image} \\
{cmd} &
"""
LOCAL_CACHE_DIR = "/tmp/arealite"
@ -152,6 +180,8 @@ BASE_ENVIRONS = {
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
"HF_ENDPOINT": "https://hf-mirror.com",
"PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
@ -164,9 +194,11 @@ NA132_ENVIRONS = {
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
def get_env_vars(cluster_name: str):
"""Returns the environment variables for the cluster."""
if cluster_name == "na132":
@ -181,30 +213,29 @@ class SlurmLauncher:
self.trial_name = trial_name
self.fileroot = fileroot
# actual slurm job name -> JobInfo
self.jobs: Dict[str, JobInfo] = {}
# slurm_job_id -> JobInfo
self.jobs: Dict[int, JobInfo] = {}
self.job_names = []
@property
def run_name(self) -> str:
"""Returns the run name of this launcher."""
return f"{self.experiment_name}_{self.trial_name}"
def slurm_name(self, job_name: str) -> str:
"""Returns the slurm name of a job."""
return f"{self.experiment_name}_{self.trial_name}:{job_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def sbatch_path_of(self, job_name: str) -> str:
sbatch_path = f"{self.fileroot}/sbatch/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(sbatch_path, exist_ok=True)
return os.path.join(sbatch_path, f"{job_name}.sh")
def submit(self, job_name, cmd, **kwargs):
"""Submits and launch a job with SBATCH.
@ -213,36 +244,46 @@ class SlurmLauncher:
"""
return self.submit_array(job_name, cmd, count=1, **kwargs)
def find_job_id(self, job_name: str):
job_name = self.slurm_name(job_name)
for job_id, job_info in self.jobs.items():
if job_info.name == job_name:
return job_id
return None
def submit_array(
self,
job_name: str,
cmd: List[str] | str,
count: int,
nodes: int,
n_gpus_per_node: int,
cpus_per_task: int,
mem_per_task: int, # MB
container_image: str,
container_mounts: Optional[str] = None,
env_vars: Optional[Dict] = None,
nodelist: Optional[str] = None,
exclude: Optional[str] = None,
apptainer_name: Optional[str] = "singularity",
apptainer_options: Optional[Tuple[str]] = (
"--no-home", "--writable-tmpfs", "--nv", "--pid"
apptainer_options: Optional[Tuple[str, ...]] = (
"--no-home",
"--writable-tmpfs",
"--nv",
"--pid",
),
):
"""Submits and launch a job array with SBATCH.
"""Submits and launch a job array with SBATCH.
Note that a job array has one (unique) slurm name, and one (unique) slurm id.
Args:
job_name (str): The job name of the job array. The actual slurm name will be
`<experiment_name>_<trial_name>:<job_name>`.
cmd (str or List[str]): The core command to be executed.
count (int): The number of jobs in the array.
"""
assert self.slurm_name(job_name) not in self.jobs, (
f"Job {self.slurm_name(job_name)} is already submitted. "
assert job_name not in self.job_names, (
f"Job {job_name} is already submitted. "
"Please use a different job name or stop the existing job."
)
if isinstance(cmd, str):
@ -262,17 +303,18 @@ class SlurmLauncher:
)
mem_per_cpu = mem_per_task // cpus_per_task # MB per CPU
mem_per_node = mem_per_task * count // nodes + 1024 * 100 # FIXME
sbatch_options = [
f"--job-name={self.slurm_name(job_name)}",
f"--output={self.fileroot}/{self.run_name}/{job_name}.out",
"--open-mode=append",
f"--output={self.log_path_of(job_name)}",
# "--open-mode=append", # FIXME
"--no-requeue",
f"--nodes={nodes}-{nodes}",
f"--ntasks-per-node={ntasks_per_node}",
f"--gres=gpu:{n_gpus_per_node}",
f"--cpus-per-task={cpus_per_task}",
f"--mem-per-cpu={mem_per_cpu}M",
f"--mem={mem_per_node}M",
]
if nodelist:
@ -282,13 +324,12 @@ class SlurmLauncher:
sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])
if env_vars is None:
env_vars = dict()
n_gpus_per_task = n_gpus_per_node // ntasks_per_node
assert "CUDA_VISIBLE_DEVICES" not in env_vars, (
"CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
)
assert (
"CUDA_VISIBLE_DEVICES" not in env_vars
), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
srun_cmds = []
for i in range(count):
@ -301,12 +342,17 @@ class SlurmLauncher:
str(x) for x in range(gpu_id_start, gpu_id_end)
),
}
env_string = " ".join("--env {}={}".format(k, v) for k, v in (_env_vars or {}).items())
env_string = " ".join(
"--env {}={}".format(k, v) for k, v in (_env_vars or {}).items()
)
# Prepare the command for each job in the array
job_cmd = cmd[i]
# FIXME: only for debugging, remove and replace new image
job_cmd = f'bash -c "pip3 install -U gymnasium torchdata tensordict hf-xet; {job_cmd}"'
srun_cmd = SRUN_CMD_TEMPLATE.format(
nodes=nodes,
ntasks=ntasks_per_node,
nodes=1,
ntasks=1,
n_gpus_per_node=n_gpus_per_node,
cpus_per_task=cpus_per_task,
mem_per_cpu=mem_per_cpu,
@ -319,59 +365,83 @@ class SlurmLauncher:
)
srun_cmds.append(srun_cmd)
srun_cmd = "\n".join(srun_cmds)
srun_cmds = "\n".join(srun_cmds)
sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
sbatch_options=sbatch_options_str, srun_cmds=srun_cmd
sbatch_options=sbatch_options_str, srun_cmds=srun_cmds
)
sbatch_file_path = self.sbatch_path_of(f"{job_name}_{i}")
sbatch_file_path = self.sbatch_path_of(f"{job_name}")
with open(sbatch_file_path, "w") as f:
f.write(sbatch_script)
# Submit the job
return_code = subprocess.check_call(["sbatch", sbatch_file_path])
if return_code != 0:
# FIXME: debug only
try:
output = (
subprocess.check_output(["sbatch", sbatch_file_path])
.decode("utf-8")
.strip()
)
logger.info(
f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`."
)
except subprocess.CalledProcessError as e:
logger.warning(
f"Failed to submit job {self.slurm_name(job_name)}. "
f"For debugging, please make sure your sbatch command works "
f"and check generated sbatch file on {sbatch_file_path}."
)
self.jobs[self.slurm_name(job_name)] = JobInfo(
logger.error(f"Error message: {e}")
return
match = re.search(r"Submitted batch job (\d+)", output)
slurm_job_id = int(match.group(1)) if match else None
if slurm_job_id is None:
logger.warning(
f"Failed to obtain job id for job {self.slurm_name(job_name)}. "
f"sbatch output: {output}"
)
return
assert isinstance(slurm_job_id, int)
self.jobs[slurm_job_id] = JobInfo(
name=self.slurm_name(job_name),
state=JobState.PENDING,
slurm_id=slurm_job_id,
)
self._update_all()
def stop(self, job_name, signal=None):
def stop(self, job_name, signal="SIGKILL"):
"""Stops a running job.
Passes silently if the job has already stopped, whether successfully or not,
or if no such job was submitted by this launcher.
Args:
job_name: The job name of the job array to stop.
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
"""
raise NotImplementedError()
job_id = self.find_job_id(job_name)
if not job_id:
return
return cancel_jobs(slurm_ids=[job_id], signal=signal)
def stop_all(self, signal=None):
def stop_all(self, signal="SIGKILL"):
"""Stops all running jobs."""
raise NotImplementedError()
return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal)
def find(self, job_name) -> JobInfo:
def find(self, job_name) -> JobInfo | None:
"""Gets the status of a job of this job.
Args:
job_name: The job name of the job array to find.
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
Returns:
A JobInfo if the job is found, or None otherwise.
"""
self._update_all()
return self.jobs.get(self.slurm_name(job_name), None)
job_id = self.find_job_id(job_name)
return self.jobs[job_id] if job_id else None
def find_all(self, job_name_regex=".*") -> List[JobInfo]:
"""Finds jobs.
@ -385,11 +455,11 @@ class SlurmLauncher:
self._update_all()
infos = []
for r in self.jobs.values():
job_name = r.name.split(":")[-1]  # Extract the job name from slurm name
if re.fullmatch(job_name_regex, job_name):
infos.append(r)
return infos
def _find_job_with_status(
self,
status: List[JobState],
@ -405,7 +475,6 @@ class SlurmLauncher:
self._update_all()
return [r for r in self.jobs.values() if r.state in status]
def wait(
self,
timeout=None,
@ -422,10 +491,10 @@ class SlurmLauncher:
deadline = None if timeout is None else time.time() + timeout
num_jobs_left = len(self.jobs)
left = set(self.jobs.values())
left = list(self.jobs.keys())
logger.info(
f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
f"{','.join(sorted([x.slurm_id for x in self.jobs.values()]))}."
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
)
while len(left) > 0:
if len(left) < num_jobs_left:
@ -434,11 +503,12 @@ class SlurmLauncher:
if deadline is not None and time.time() > deadline:
raise TimeoutError(
f"Timeout waiting for {num_jobs_left} jobs. Job ID: "
f"{','.join(sorted([x.slurm_id for x in self.jobs.values()]))}."
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
)
self._update_all()
left = list(self.jobs.values())
for slurm_info in list(left):
left = list(self.jobs.keys())
for slurm_id in list(left):
slurm_info = self.jobs[slurm_id]
if slurm_info.slurm_id is None:
continue
if slurm_info.state in check_status:
@ -452,18 +522,18 @@ class SlurmLauncher:
logger.info(
f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
)
left.remove(slurm_info)
left.remove(slurm_id)
if update:
self.jobs.pop(slurm_info.name)
self.jobs.pop(slurm_info.slurm_id)
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
def _update_all(self):
"""Updates the status of all jobs. """
"""Updates the status of all jobs."""
try:
slurm_infos = query_jobs(list(self.jobs.keys()))
slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys()))
for slurm_info in slurm_infos:
if slurm_info.name in self.jobs:
self.jobs[slurm_info.name] = slurm_info
assert slurm_info.slurm_id is not None
self.jobs[slurm_info.slurm_id] = slurm_info
except subprocess.CalledProcessError:
logger.warning(
"Calling squeue failed. Check slurm manually if you continue to see this warning."
@ -472,99 +542,146 @@ class SlurmLauncher:
def slurm_args_parser():
parser = argparse.ArgumentParser(description="Slurm Launcher for AReaL")
parser.add_argument("entry_point", type=str, help="The entry point script to run.")
parser.add_argument("config_path", type=str, help="Path to the configuration file.")
parser.add_argument(
"--sglang-server-base-port", type=int, required=False, default=27010,
help="Base port for SGLang servers. SGLang servers on the same node will ."
"--sglang-server-base-port",
type=int,
required=False,
default=27010,
help="Base port for SGLang servers. SGLang servers on the same node will .",
)
parser.add_argument(
"--trainer-port",
type=int,
required=False,
default=27009,
help="Pytorch distributed initialization port for trainer.",
)
parser.add_argument(
"--sglang-version",
type=str,
required=False,
default="0.4.6.post4",
help="SGLang version in your GPU inference image.",
)
parser.add_argument("--trainer-port", type=int, required=False, default=27009, help="Pytorch distributed initialization port for trainer.")
# parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")
return parser
if __name__ == "__main__":
# usage: python -m arealite.launcher.slurm <entry_point> --allocation_mode <allocation_mode> <config_path> [<args>]
# usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
r = parse_cli_args(sys.argv[2:], parser=slurm_args_parser())
entry_point = sys.argv[1]
config = r.config
config_file = r.config_file
args = r.additional_args
cluster_config: ClusterSpecConfig = config.cluster
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
n_nodes = config.n_nodes
n_gpus_per_node = config.n_gpus_per_node
if n_gpus_per_node < cluster_config.n_gpus_per_node:
if n_gpus_per_node < config.cluster.n_gpus_per_node:
raise ValueError(
f"Slurm Launcher requires at least {cluster_config.n_gpus_per_node} (#GPUs per node) GPU. For usecases of less GPUs, use LocalLauncher instead."
f"Slurm Launcher requires at least {config.cluster.n_gpus_per_node} (#GPUs per node) GPU. For usecases of less GPUs, use LocalLauncher instead."
)
elif n_gpus_per_node > cluster_config.n_gpus_per_node:
elif n_gpus_per_node > config.cluster.n_gpus_per_node:
raise ValueError(
f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({cluster_config.n_gpus_per_node})."
f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({config.cluster.n_gpus_per_node})."
)
launcher = SlurmLauncher(
experiment_name=config.experiment_name,
trial_name=config.trial_name,
fileroot=cluster_config.fileroot
fileroot=config.cluster.fileroot,
)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
sglang_addrs = []
n_sglang_nodes = 0
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
# Launcher should launch SGLang servers according to allocation mode.
assert isinstance(allocation_mode, str), "Allocation mode should be a string."
sglang_config = find_and_amend_config(config, "sglang", SGLangConfig)
assert "gen" in allocation_mode
assert allocation_mode.gen_pp_size == 1, "Pipeline generation in SGLang is not supported for now."
assert allocation_mode.gen_tp_size <= cluster_config.n_gpus_per_node, "Currently only support SGLang TP size less <= #GPUs per node."
sglang_world_size = allocation_mode.gen_world_size
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
assert (
allocation_mode.gen_pp_size == 1
), "Pipeline generation in SGLang is not supported for now."
assert (
allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
), "Currently only support SGLang TP size less <= #GPUs per node."
sglang_world_size = allocation_mode.gen_world_size
sglang_tp_size = allocation_mode.gen_tp_size
n_sglang_server_instances = allocation_mode.gen_dp_size
n_sglang_servers = allocation_mode.gen_dp_size
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
model_path = find_config(config, "path")
assert model_path is not None and isinstance(model_path, str)
seed = find_config(config, "seed")
base_seed = config.sglang.random_seed
base_port = args.sglang_server_base_port
for i in range(n_sglang_server_instances):
sglang_ports_on_node = set()
for i in range(n_sglang_servers):
base_gpu_id = i * sglang_tp_size % n_gpus_per_node
n_server_per_node = cluster_config.n_gpus_per_node // sglang_tp_size
# Since we cannot get port information from slurm, we only ensure that ports assigned on the same node do not conflict.
server_port = base_port + i % n_server_per_node * 2
dist_port = base_port + i % n_server_per_node * 2 + 1
dist_init_addr = f"tcp://localhost:{server_port}"
server_port = base_port + i % n_sglang_servers_per_node * 2
# print(server_port, dist_init_addr)
config.sglang.random_seed = base_seed + i
sglang_cmds.append(
SGLangConfig.build_cmd(
sglang_config,
model_path,
config.sglang,
sglang_tp_size,
base_gpu_id=base_gpu_id,
seed=seed + i
base_gpu_id=0,
host="localhost",
port=server_port,
dist_init_addr=None,
sglang_version=args.sglang_version,
)
)
sglang_ports_on_node.add(server_port)
assert len(sglang_ports_on_node) == n_sglang_servers_per_node
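To make the port layout concrete, a worked example of the arithmetic above (numbers are illustrative):

# With base_port=27010, 8 GPUs per node and TP size 4, there are 8 // 4 = 2
# servers per node, and server i gets port 27010 + (i % 2) * 2:
#   i = 0, 2, 4, ... -> 27010;  i = 1, 3, 5, ... -> 27012
base_port, n_gpus, tp = 27010, 8, 4
servers_per_node = n_gpus // tp
ports = sorted({base_port + (i % servers_per_node) * 2 for i in range(4)})
assert ports == [27010, 27012]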
# launch sglang servers
# launch SGLang servers, note that we need to leave some resources on each node
# to schedule jobs that retrieve node IP. (1 CPU core & 10 MB memory)
if sglang_cmds:
launcher.submit_array(
job_name="sglang-server",
cmd=sglang_cmds,
count=n_sglang_nodes,
nodes=n_sglang_server_instances,
n_gpus_per_node=cluster_config.n_gpus_per_node,
cpus_per_task=15 * sglang_tp_size, # one sglang task occupy the entire node
mem_per_task=150 * 1024 * sglang_tp_size, # 150GB per task
container_image=cluster_config.gpu_infer_image,
container_mounts=cluster_config.mount,
env_vars=get_env_vars(cluster_config.cluster_name),
count=n_sglang_servers,
nodes=n_sglang_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=15 * sglang_tp_size,
mem_per_task=150 * 1024 * sglang_tp_size, # 150GB per TP rank
container_image=config.cluster.gpu_infer_image,
container_mounts=config.cluster.mount,
env_vars=get_env_vars(config.cluster.cluster_name),
)
# Get SGLang slurm nodes, find the hosts
start_time = time.perf_counter()
while True:
job_info = launcher.find("sglang-server")
assert job_info is not None
print(job_info)
sglang_hosts = job_info.host
logger.info(
f"Waiting for SGLang servers to be scheduled by slurm, time since started = {time.perf_counter() - start_time:.2f}"
)
if sglang_hosts:
print(sglang_hosts)
sglang_hosts = [
get_slurm_host_ip(node)
for node in parse_slurm_nodelist(sglang_hosts)
]
print(sglang_hosts)
for host in sglang_hosts:
sglang_addrs.extend(
[f"{host}:{port}" for port in sglang_ports_on_node]
)
assert len(sglang_addrs) == n_sglang_servers
break
time.sleep(10)
trainer_n_nodes = n_nodes - n_sglang_nodes
assert r.overrides is not None
trainer_cmd_template = (
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
f"--master-addr $head_node_ip --master-port {args.trainer_port} {' '.join(sys.argv[1:])}"
f"--master-addr $head_node_ip --master-port {args.trainer_port} {entry_point} --config {config_file} {' '.join(r.overrides)}"
)
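As an illustration, the per-node command this template produces looks roughly like the following (entry point, config path, and override are placeholders):

example = trainer_cmd_template.format(nnodes=2, nproc_per_node=8, node_rank=0)
# `example` is a single line, roughly:
#   torchrun --nnodes=2 --nproc-per-node=8 --node-rank 0 --master-addr $head_node_ip
#   --master-port 27009 my_entry.py --config /path/to/config.yaml actor.lr=1e-5
# $head_node_ip is resolved inside the sbatch script (see SBATCH_SCRIPT_TEMPLATE).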
trainer_cmds = []
for i in range(trainer_n_nodes):
# For each trainer node, we launch a trainer with the same command.
@ -572,24 +689,28 @@ if __name__ == "__main__":
trainer_cmds.append(
trainer_cmd_template.format(
nnodes=trainer_n_nodes,
nproc_per_node=cluster_config.n_gpus_per_node,
node_rank=i
nproc_per_node=config.cluster.n_gpus_per_node,
node_rank=i,
)
)
# launch trainers
launcher.submit_array(
job_name="trainer",
cmd=trainer_cmds,
count=trainer_n_nodes,
nodes=trainer_n_nodes,
n_gpus_per_node=cluster_config.n_gpus_per_node,
cpus_per_task=120, # one trainer task occupy the entire node
mem_per_task=1200 * 1024, # 1.2T per task
container_image=cluster_config.gpu_image,
container_mounts=cluster_config.mount,
env_vars=get_env_vars(cluster_config.cluster_name),
)
if not config.server_only:
# launch trainers
launcher.submit_array(
job_name="trainer",
cmd=trainer_cmds,
count=trainer_n_nodes,
nodes=trainer_n_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=120, # one trainer task occupies the entire node
mem_per_task=1024 * 1024, # 1024GB per task
container_image=config.cluster.gpu_image,
container_mounts=config.cluster.mount,
env_vars=dict(
**get_env_vars(config.cluster.cluster_name),
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
),
)
try:
launcher.wait(
@ -602,6 +723,5 @@ if __name__ == "__main__":
remove_status=(),
)
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all("SIGTERM")
launcher.stop_all("SIGKILL")
raise e

View File

@ -1,5 +1,6 @@
from omegaconf import DictConfig, OmegaConf
def find_config(config: DictConfig, name: str) -> DictConfig | None:
# iterate through the nested DictConfig and find the first matching config with name
for key, value in config.items():
@ -11,6 +12,7 @@ def find_config(config: DictConfig, name: str) -> DictConfig | None:
return found
return None
def amend_config(config: DictConfig, config_cls):
default_config = OmegaConf.structured(config_cls)
config = OmegaConf.merge(default_config, config)
@ -18,9 +20,10 @@ def amend_config(config: DictConfig, config_cls):
assert isinstance(config, config_cls)
return config
def find_and_amend_config(config: DictConfig, name: str, config_cls):
# Find the config with the given name and amend it with the given config_cls
found = find_config(config, name)
if found is not None:
return amend_config(found, config_cls)
return None
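A usage sketch for the helpers in this file (the nested config below is invented for illustration):

from omegaconf import OmegaConf
from arealite.api.cli_args import SGLangConfig

cfg = OmegaConf.create({"rollout": {"sglang": {"random_seed": 1}}})
# find_config walks the nested DictConfig and returns the first sub-config
# named "sglang"; find_and_amend_config additionally merges it onto the
# structured SGLangConfig defaults so omitted fields are filled in.
raw = find_config(cfg, "sglang")
sglang_cfg = find_and_amend_config(cfg, "sglang", SGLangConfig)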

View File

@ -41,12 +41,12 @@ class JobException(Exception):
class JobInfo:
name: str
state: JobState
host: str = (
host: Optional[str] = (
None # The host on which the job is/was running; None if the job has not run.
)
submit_time: str = None
start_time: str = None
slurm_id: str = None # Slurm only. The Slurm id of the job.
submit_time: Optional[str] = None
start_time: Optional[str] = None
slurm_id: Optional[int] = None # Slurm only. The Slurm id of the job.
class SchedulerClient: