merge slurm launcher

This commit is contained in:
bowei.fw 2025-07-12 14:54:49 +08:00
commit 0d03141cbc
7 changed files with 921 additions and 11 deletions

View File

@ -340,12 +340,12 @@ class SGLangConfig:
host,
port,
dist_init_addr: Optional[str] = None,
sglang_version: Optional[str] = None,
):
from realhf.base import pkg_version
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args = dict(
host=host,
port=port,
@ -362,13 +362,28 @@ class SGLangConfig:
base_gpu_id=base_gpu_id,
nnodes=1,
node_rank=0,
# initialization addresses and ports
dist_init_addr=dist_init_addr,
**args,
)
if sglang_version:
version_less_than_0_4_4 = (
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
)
version_less_than_0_4_3 = (
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
)
elif pkg_version.is_available("sglang"):
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
else:
raise ValueError(
"An installed SGLang package or an explicit SGLang version is required to build the SGLang server command."
)
if pkg_version.is_version_less("sglang", "0.4.4"):
if version_less_than_0_4_4:
args.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
if version_less_than_0_4_3:
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
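The new sglang_version argument lets callers build the server command even on a machine where SGLang is not importable (for example, the Slurm submission host); only when no version is given does the code fall back to probing the installed package. A minimal usage sketch, assuming the positional parameters mirror the call in SGLangServerWrapper.run further below (all values are illustrative):

cmd = SGLangConfig.build_cmd(
    config.sglang,           # SGLangConfig instance
    2,                       # tensor-parallel size of one server
    0,                       # base_gpu_id
    "10.0.0.5",              # host
    30000,                   # port
    dist_init_addr="10.0.0.5:29500",
    sglang_version="0.4.6",  # skip probing the local sglang installation
)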
@ -608,6 +623,46 @@ class DatasetConfig:
drop_last: bool = field(default=True)
@dataclass
class LauncherConfig:
"""Configuration for launching SGLang inference servers and trainers."""
inference_server_cpus_per_gpu: int = field(
default=4,
metadata={"help": "Number of CPUs allocated per GPU for inference server. "},
)
inference_server_mem_per_gpu: int = field(
default=32 * 1024,
metadata={"help": "Memory allocated per GPU for inference server in MB. "},
)
trainer_cpus_per_gpu: int = field(
default=4,
metadata={"help": "Number of CPUs allocated per GPU for training. "},
)
trainer_mem_per_gpu: int = field(
default=32 * 1024,
metadata={"help": "Memory allocated per GPU for training in MB. "},
)
inference_server_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for the inference server, separated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
trainer_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for training, separated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
trainer_port: int = field(
default=27015,
metadata={"help": "Trainer port used for torch.distributed initialization."},
)
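These per-GPU figures are multiplied by the number of GPUs a task owns when the Slurm launcher builds its resource requests (see the submit_array calls in arealite/launcher/slurm.py below). A small worked example with the defaults above and hypothetical TP-2 inference servers on 8-GPU nodes:

# Defaults from LauncherConfig; tp_size and n_gpus_per_node are example values.
inference_server_cpus_per_gpu = trainer_cpus_per_gpu = 4
inference_server_mem_per_gpu = trainer_mem_per_gpu = 32 * 1024  # MB
tp_size, n_gpus_per_node = 2, 8
# One SGLang server task (TP=2): 8 CPUs, 65536 MB (64 GiB).
print(inference_server_cpus_per_gpu * tp_size, inference_server_mem_per_gpu * tp_size)
# One trainer task (a whole node): 32 CPUs, 262144 MB (256 GiB).
print(trainer_cpus_per_gpu * n_gpus_per_node, trainer_mem_per_gpu * n_gpus_per_node)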
@dataclass
class BaseExperimentConfig:
# NOTE: we need this unified config class because different experiments
@ -668,6 +723,7 @@ class BaseExperimentConfig:
server_only: bool = False
sglang: SGLangConfig = field(default_factory=SGLangConfig)
launcher: LauncherConfig = field(default_factory=LauncherConfig)
@dataclass
@ -714,7 +770,7 @@ def to_structured_cfg(cfg, config_cls):
return cfg
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
def load_expr_config(argv: List[str], config_cls):
cfg, config_file = parse_cli_args(argv)
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)

View File

@ -57,6 +57,8 @@ class RemoteSGLangEngine(InferenceEngine):
raise RuntimeError("No configured SGLang servers.")
logger.info("Waiting for server ready...")
for addr in self.addresses:
# FIXME
print(f"waiting for server address {addr}")
self._wait_for_server(addr)
logger.info("Servers are all ready!")
@ -92,7 +94,8 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
except requests.exceptions.RequestException as e:
print(f"Health check at {base_url}/metrics failed: {e}")
return False
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):

0
arealite/launcher/ray.py Normal file
View File

View File

@ -0,0 +1,157 @@
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional
import ray
import requests
from arealite.api.cli_args import (
NameResolveConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.network import find_free_ports, gethostip
from realhf.base import logging, name_resolve, names, pkg_version
logger = logging.getLogger("SGLangServer Wrapper")
def execute_shell_command(command: str) -> subprocess.Popen:
"""
Execute a shell command and return its process handle.
"""
# Replace newline continuations and split the command string.
command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split()
return subprocess.Popen(
parts,
text=True,
stdout=sys.stdout,
stderr=subprocess.STDOUT,
)
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
def launch_server_cmd(command: str):
"""
Launch the server process using the given command and return its handle.
"""
if not ray.is_initialized():
apply_sglang_patch()
process = execute_shell_command(command)
return process
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.
Args:
base_url: The base URL of the server
timeout: Maximum time to wait in seconds. None means wait forever.
"""
start_time = time.time()
while True:
try:
response = requests.get(
f"{base_url}/v1/models",
headers={"Authorization": "Bearer None"},
)
if response.status_code == 200:
time.sleep(5)
break
if timeout and time.time() - start_time > timeout:
raise TimeoutError("Server did not become ready within timeout period")
except requests.exceptions.RequestException:
time.sleep(1)
class SGLangServerWrapper:
def __init__(
self,
experiment_name: str,
trial_name: str,
sglang_config: SGLangConfig,
tp_size: int,
):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.config = sglang_config
self.tp_size = tp_size
self.server_process = None
def run(self):
server_port, dist_init_port = find_free_ports(2, (10000, 50000))
dist_init_addr = f"localhost:{dist_init_port}"
host_ip = gethostip()
cmd = SGLangConfig.build_cmd(
self.config, self.tp_size, 0, host_ip, server_port, dist_init_addr=dist_init_addr
)
self.server_process = launch_server_cmd(cmd)
wait_for_server(f"http://{host_ip}:{server_port}")
name = names.gen_servers(self.experiment_name, self.trial_name)
name_resolve.add_subentry(name, f"{host_ip}:{server_port}")
logger.info(f"SGLang server launched at: http://{host_ip}:{server_port}")
return_code = self.server_process.wait()
logger.info(
f"SGLang server at http://{host_ip}:{server_port} exited with returncode={return_code}"
)
if __name__ == "__main__":
config, _ = parse_cli_args(sys.argv[2:])
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
config.cluster.name_resolve = to_structured_cfg(
config.cluster.name_resolve, NameResolveConfig
)
name_resolve.reconfigure(config.cluster.name_resolve)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
assert allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG
tp_size = allocation_mode.gen_tp_size
sglang_server = SGLangServerWrapper(
config.experiment_name,
config.trial_name,
config.sglang,
tp_size,
)
sglang_server.run()
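This wrapper and the Slurm launcher below coordinate through name_resolve: the server publishes its host:port under a per-trial key, and the launcher polls that key until every expected server has registered. A condensed sketch of the handshake (experiment/trial names and counts are illustrative):

from realhf.base import name_resolve, names

# Server side (SGLangServerWrapper.run): advertise the listening address.
key = names.gen_servers("my-exp", "my-trial")
name_resolve.add_subentry(key, "10.0.0.5:30000")

# Launcher side (arealite/launcher/slurm.py): poll until all servers appear.
n_expected = 4
addrs = name_resolve.get_subtree(key)
if len(addrs) >= n_expected:
    print("SGLang servers ready:", ", ".join(addrs))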

694
arealite/launcher/slurm.py Normal file
View File

@ -0,0 +1,694 @@
import argparse
import getpass
import os
import pathlib
import re
import subprocess
import sys
import time
from typing import Dict, List, Literal, Optional, Tuple
from omegaconf import OmegaConf
import realhf.base.logging as logging
from arealite.api.cli_args import (
BaseExperimentConfig,
ClusterSpecConfig,
LauncherConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("SlurmLauncher")
SQUEUE_FIELDS = [
"JobID",
"State",
"SubmitTime",
"StartTime",
"Name",
"NodeList",
"UserName",
"MaxCPUs",
"cpus-per-task",
"NumTasks",
"tres-alloc",
]
STATUS_MAPPING = {
"RUNNING": JobState.RUNNING,
"COMPLETING": JobState.RUNNING,
"PENDING": JobState.PENDING,
"CANCELLED": JobState.CANCELLED,
"FAILED": JobState.FAILED,
"COMPLETED": JobState.COMPLETED,
"OUT_OF_MEMORY": JobState.FAILED,
"DEADLINE": JobState.COMPLETED,
"TIMEOUT": JobState.COMPLETED,
}
def cancel_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
assert (
slurm_names is not None or slurm_ids is not None
), "Must specify slurm_names or slurm_ids."
assert not (
slurm_names and slurm_ids
), "Cannot specify both slurm_names and slurm_ids."
cmd = ["scancel", "-s", signal]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
elif slurm_ids is not None:
cmd += [",".join(str(s) for s in slurm_ids)]
subprocess.check_call(cmd)
logger.info(
f"Cancelled Slurm job with signal {signal}: "
f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
)
def query_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
status: str = "all",
delimiter: str = "__PSI__",
) -> List[JobInfo]:
squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
if slurm_ids is not None:
cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
output = (
subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
)
rs = []
for line in output.split("\n")[1:]:
job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
delimiter
)
rs.append(
JobInfo(
name=slurm_name,
state=STATUS_MAPPING[state],
host=nodelist,
submit_time=submit_time,
start_time=start_time,
slurm_id=int(job_id.strip()),
)
)
return rs
def parse_slurm_nodelist(nodelist: str) -> List[str]:
return (
subprocess.check_output(
[
"scontrol",
"show",
"hostnames",
nodelist,
]
)
.decode("utf-8")
.strip()
.split("\n")
)
def get_slurm_host_ip(node: str):
try:
cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
except subprocess.CalledProcessError:
logger.warning(f"Failed to get Slurm host IP for node {node}.")
SGLANG_SERVER_TIMEOUT_SECONDS = 120
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}
# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes
nodes_array=($nodes)
echo node_array=$nodes_array
head_node=${{nodes_array[0]}}
echo head_node=$head_node
# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip
# srun commands
{srun_cmds}
wait
"""
SRUN_CMD_TEMPLATE = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodelist=${{nodes_array[{node_id}]}} \\
--nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
--mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
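Returning to SRUN_CMD_TEMPLATE above: for concreteness, this is roughly how submit_array (defined below) fills the template for a single task. Every value here is a placeholder rather than something mandated by this commit:

example_srun = SRUN_CMD_TEMPLATE.format(
    node_id=0,
    nodes=1,
    ntasks=1,
    n_gpus_per_node=8,
    cpus_per_task=32,
    mem_per_cpu=1024,
    apptainer_name="singularity",
    apptainer_options="--no-home --writable-tmpfs --nv --pid",
    container_mounts="/storage:/storage",
    container_env_strings="--env CUDA_VISIBLE_DEVICES=0,1",
    container_image="/storage/images/arealite.sif",
    cmd="python3 -m arealite.launcher.sglang_server <passthrough args>",
)
# Renders to one backgrounded srun line: srun --overlap ... --nodelist=${nodes_array[0]}
#   --gres=gpu:8 ... singularity exec ... /storage/images/arealite.sif <cmd> &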
BASE_ENVIRONS = {
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
"HF_ENDPOINT": "https://hf-mirror.com", # FIXME: move to user option
"PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
"NCCL_IB_GID_INDEX": "3",
"NCCL_IB_TIMEOUT": "2",
"NCCL_IB_RETRY_CNT": "7",
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
"""Returns the environment variables for the cluster."""
additional_env_vars = (
dict(item.split("=") for item in additional_env_vars.split(","))
if additional_env_vars
else {}
)
if cluster_name == "na132":
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
else:
return {**BASE_ENVIRONS, **additional_env_vars}
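A quick illustration of the merge order: cluster-wide defaults come first and the comma-separated per-job overrides win last. The override string here is made up:

env = get_env_vars("na132", "NCCL_DEBUG=INFO,HF_TOKEN=dummy")
assert env["NCCL_DEBUG"] == "INFO"                # override beats NA132_ENVIRONS
assert env["TOKENIZERS_PARALLELISM"] == "true"    # BASE_ENVIRONS default kept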
class SlurmLauncher:
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.fileroot = fileroot
# slurm_job_id -> JobInfo
self.jobs: Dict[int, JobInfo] = {}
self.job_names = []
@property
def run_name(self) -> str:
"""Returns the run name of this launcher."""
return f"{self.experiment_name}_{self.trial_name}"
def slurm_name(self, job_name: str) -> str:
"""Returns the slurm name of a job."""
return f"{self.experiment_name}_{self.trial_name}:{job_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def sbatch_path_of(self, job_name: str) -> str:
sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(sbatch_path, exist_ok=True)
return os.path.join(sbatch_path, f"{job_name}.sh")
def submit(self, job_name, cmd, **kwargs):
"""Submits and launches a job via sbatch.
Args:
cmd (str or List[str]): The core command to be executed.
"""
return self.submit_array(job_name, cmd, count=1, **kwargs)
def find_job_id(self, job_name: str):
job_name = self.slurm_name(job_name)
for job_id, job_info in self.jobs.items():
if job_info.name == job_name:
return job_id
return None
def submit_array(
self,
job_name: str,
cmd: List[str] | str,
count: int,
nodes: int,
n_gpus_per_node: int,
cpus_per_task: int,
mem_per_task: int, # MB
container_image: str,
container_mounts: Optional[str] = None,
env_vars: Optional[Dict] = None,
nodelist: Optional[str] = None,
exclude: Optional[str] = None,
apptainer_name: Optional[str] = "singularity",
apptainer_options: Optional[Tuple[str, ...]] = (
"--no-home",
"--writable-tmpfs",
"--nv",
"--pid",
),
):
"""Submits and launches a job array via sbatch.
Note that a job array has one (unique) slurm name, and one (unique) slurm id.
Args:
job_name (str): The job name of the job array. The actual slurm name will be
`<experiment_name>_<trial_name>:<job_name>`.
cmd (str or List[str]): The core command to be executed.
count (int): The number of jobs in the array.
"""
assert job_name not in self.job_names, (
f"Job {job_name} is already submitted. "
"Please use a different job name or stop the existing job."
)
if isinstance(cmd, str):
cmd = [cmd]
assert len(cmd) == count, (
f"Command length {len(cmd)} does not match the job count {count}. "
"Please provide a command for each job in the array."
)
assert count % nodes == 0, (
f"Job count {count} must be divisible by the number of nodes {nodes}. "
"Please adjust the job count or the number of nodes."
)
ntasks_per_node = count // nodes
assert n_gpus_per_node % ntasks_per_node == 0, (
"GPUs must be evenly distributed across tasks. "
f"Current #GPUs per node {n_gpus_per_node}, #tasks per node {ntasks_per_node}."
)
mem_per_cpu = mem_per_task // cpus_per_task # MB per CPU
mem_per_node = (
mem_per_task * count // nodes + 1024 * 10
) # make sure slurm does not run out of resources
sbatch_options = [
f"--job-name={self.slurm_name(job_name)}",
f"--output={self.log_path_of(job_name)}",
"--open-mode=append",
"--no-requeue",
f"--nodes={nodes}-{nodes}",
f"--ntasks-per-node={ntasks_per_node}",
f"--gres=gpu:{n_gpus_per_node}",
f"--cpus-per-task={cpus_per_task}",
f"--mem={mem_per_node}M",
]
if nodelist:
sbatch_options.append(f"--nodelist={nodelist}")
if exclude:
sbatch_options.append(f"--exclude={exclude}")
sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])
if env_vars is None:
env_vars = dict()
n_gpus_per_task = n_gpus_per_node // ntasks_per_node
assert (
"CUDA_VISIBLE_DEVICES" not in env_vars
), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
srun_cmds = []
for i in range(count):
# resolve CUDA_VISIBLE_DEVICES for each task
gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task
gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task
node_id = i // ntasks_per_node
_env_vars = {
**env_vars,
"CUDA_VISIBLE_DEVICES": ",".join(
str(x) for x in range(gpu_id_start, gpu_id_end)
),
}
env_string = " ".join(
"--env {}={}".format(k, v) for k, v in (_env_vars or {}).items()
)
# Prepare the command for each job in the array
job_cmd = cmd[i]
# FIXME: only for debugging, remove and replace new image
job_cmd = f'bash -c "pip3 install -U gymnasium torchdata tensordict hf-xet; {job_cmd}"'
srun_cmd = SRUN_CMD_TEMPLATE.format(
nodes=1,
ntasks=1,
node_id=node_id,
n_gpus_per_node=n_gpus_per_node,
cpus_per_task=cpus_per_task,
mem_per_cpu=mem_per_cpu,
apptainer_name=apptainer_name,
apptainer_options=" ".join(apptainer_options),
container_mounts=container_mounts or "",
container_env_strings=env_string,
container_image=container_image,
cmd=job_cmd,
)
srun_cmds.append(srun_cmd)
srun_cmds = "\n".join(srun_cmds)
sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
sbatch_options=sbatch_options_str, srun_cmds=srun_cmds
)
sbatch_file_path = self.sbatch_path_of(f"{job_name}")
with open(sbatch_file_path, "w") as f:
f.write(sbatch_script)
# Submit the job
try:
output = (
subprocess.check_output(["sbatch", sbatch_file_path])
.decode("utf-8")
.strip()
)
logger.info(
f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`."
)
except subprocess.CalledProcessError as e:
logger.warning(
f"Failed to submit job {self.slurm_name(job_name)}. "
f"For debugging, please make sure your sbatch command works "
f"and check the generated sbatch file at {sbatch_file_path}."
)
logger.error(f"Error message: {e}")
return
match = re.search(r"Submitted batch job (\d+)", output)
slurm_job_id = int(match.group(1)) if match else None
if slurm_job_id is None:
logger.warning(
f"Failed to obtain job id for job {self.slurm_name(job_name)}. "
f"sbatch output: {output}"
)
return
assert isinstance(slurm_job_id, int)
self.jobs[slurm_job_id] = JobInfo(
name=self.slurm_name(job_name),
state=JobState.PENDING,
slurm_id=slurm_job_id,
)
self._update_all()
def stop(self, job_name, signal="SIGKILL"):
"""Stops a running job.
Does nothing if no matching job is found; passes if the job
has already stopped, whether successfully or not.
Args:
job_name: The job name of the job array to stop.
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
"""
job_id = self.find_job_id(job_name)
if not job_id:
return
return cancel_jobs(slurm_ids=[job_id], signal=signal)
def stop_all(self, signal="SIGKILL"):
"""Stops all running jobs."""
return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal)
def find(self, job_name) -> JobInfo | None:
"""Gets the status of a submitted job.
Args:
job_name: The job name of the job array to find.
The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
Returns:
A JobInfo if the job is found, or None otherwise.
"""
self._update_all()
job_id = self.find_job_id(job_name)
return self.jobs[job_id] if job_id else None
def find_all(self, job_name_regex=".*") -> List[JobInfo]:
"""Finds jobs.
Args:
job_name_regex: job name regex.
Returns:
A list of found JobInfo.
"""
self._update_all()
infos = []
for r in self.jobs.values():
job_name = r.name.split(":")[-1] # Extract the job name from slurm name
if re.fullmatch(job_name_regex, job_name):
infos.append(r)
return infos
def _find_job_with_status(
self,
status: List[JobState],
) -> List[JobInfo]:
"""Finds jobs with the given status.
Args:
status: A list of JobState to filter jobs.
Returns:
A list of JobInfo with the given status.
"""
self._update_all()
return [r for r in self.jobs.values() if r.state in status]
def wait(
self,
timeout=None,
check_status: Tuple[JobState, ...] = (
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
),
remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
update=False,
):
"""Waits until all jobs submitted via this client instance finish."""
# begin wait
deadline = None if timeout is None else time.time() + timeout
num_jobs_left = len(self.jobs)
left = list(self.jobs.keys())
logger.info(
f"Waiting for {num_jobs_left} jobs. Job IDs: "
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
)
while len(left) > 0:
if len(left) < num_jobs_left:
num_jobs_left = len(left)
logger.info(f"Waiting for {num_jobs_left} jobs.")
if deadline is not None and time.time() > deadline:
raise TimeoutError(
f"Timeout waiting for {num_jobs_left} jobs. Job IDs: "
f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
)
self._update_all()
left = list(self.jobs.keys())
for slurm_id in list(left):
slurm_info = self.jobs[slurm_id]
if slurm_info.slurm_id is None:
continue
if slurm_info.state in check_status:
raise JobException(
run_name=self.run_name,
worker_type=slurm_info.name,
host=slurm_info.host,
reason=slurm_info.state,
)
if slurm_info.state in remove_status:
logger.info(
f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
)
left.remove(slurm_id)
if update:
self.jobs.pop(slurm_info.slurm_id)
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
def _update_all(self):
"""Updates the status of all jobs."""
try:
slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys()))
for slurm_info in slurm_infos:
assert slurm_info.slurm_id is not None
self.jobs[slurm_info.slurm_id] = slurm_info
except subprocess.CalledProcessError:
logger.warning(
"Calling squeue failed. Check slurm manually if you continue to see this warning."
)
if __name__ == "__main__":
# usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
config, config_file = parse_cli_args(sys.argv[2:])
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
name_resolve.reconfigure(config.cluster.name_resolve)
n_nodes = config.n_nodes
n_gpus_per_node = config.n_gpus_per_node
if n_gpus_per_node < config.cluster.n_gpus_per_node:
raise ValueError(
f"The Slurm launcher requires at least {config.cluster.n_gpus_per_node} GPUs (one full node). For use cases with fewer GPUs, use LocalLauncher instead."
)
elif n_gpus_per_node > config.cluster.n_gpus_per_node:
raise ValueError(
f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({config.cluster.n_gpus_per_node})."
)
launcher = SlurmLauncher(
experiment_name=config.experiment_name,
trial_name=config.trial_name,
fileroot=config.cluster.fileroot,
)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
sglang_addrs = []
n_sglang_nodes = 0
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
# Launcher should launch SGLang servers according to allocation mode.
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
assert (
allocation_mode.gen_pp_size == 1
), "Pipeline generation in SGLang is not supported for now."
assert (
allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
), "Currently only SGLang TP sizes <= #GPUs per node are supported."
sglang_world_size = allocation_mode.gen_world_size
sglang_tp_size = allocation_mode.gen_tp_size
n_sglang_servers = allocation_mode.gen_dp_size
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
base_seed = config.sglang.random_seed
sglang_server_cmd_template = f"python3 -m arealite.launcher.sglang_server {' '.join(sys.argv[1:])} sglang.random_seed={{seed}}"
for i in range(n_sglang_servers):
sglang_cmd = sglang_server_cmd_template.format(
seed=base_seed + i,
)
sglang_cmds.append(sglang_cmd)
if sglang_cmds:
launcher.submit_array(
job_name="sglang-server",
cmd=sglang_cmds,
count=n_sglang_servers,
nodes=n_sglang_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
* sglang_tp_size,
mem_per_task=config.launcher.inference_server_mem_per_gpu
* sglang_tp_size,
container_image=config.cluster.gpu_infer_image,
container_mounts=config.cluster.mount,
env_vars=get_env_vars(
config.cluster.cluster_name,
config.launcher.inference_server_env_vars,
),
)
# Get SGLang slurm nodes, find the hosts
name = names.gen_servers(config.experiment_name, config.trial_name)
start = time.perf_counter()
while True:
sglang_addrs = name_resolve.get_subtree(name)
if len(sglang_addrs) >= n_sglang_servers:
logger.info(
f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
)
break
time.sleep(1)
if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
raise TimeoutError(
f"Timeout waiting for SGLang servers to be ready. "
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
)
trainer_n_nodes = n_nodes - n_sglang_nodes
trainer_cmd_template = (
f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
f"--master-addr $head_node_ip --master-port {config.launcher.trainer_port} {' '.join(sys.argv[1:])}"
)
trainer_cmds = []
for i in range(trainer_n_nodes):
# For each trainer node, we launch a trainer with the same command.
# The node rank is the index of the node among the trainer nodes.
trainer_cmds.append(
trainer_cmd_template.format(
nnodes=trainer_n_nodes,
nproc_per_node=config.cluster.n_gpus_per_node,
node_rank=i,
)
)
if not config.server_only:
# launch trainers
launcher.submit_array(
job_name="trainer",
cmd=trainer_cmds,
count=trainer_n_nodes,
nodes=trainer_n_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=config.launcher.trainer_cpus_per_gpu
* config.cluster.n_gpus_per_node,
mem_per_task=config.launcher.trainer_mem_per_gpu
* config.cluster.n_gpus_per_node,
container_image=config.cluster.gpu_image,
container_mounts=config.cluster.mount,
env_vars=dict(
**get_env_vars(
config.cluster.cluster_name,
config.launcher.trainer_env_vars,
),
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
),
)
try:
launcher.wait(
check_status=(
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
JobState.COMPLETED,
),
remove_status=(),
)
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all("SIGKILL")
raise e
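Putting the allocation arithmetic together, here is a worked example for a hypothetical 6-node cluster with 8 GPUs per node and a decoupled SGLang deployment of 16 GPUs at TP=2; the formulas mirror the __main__ block above:

n_nodes, n_gpus_per_node = 6, 8                             # illustrative cluster shape
gen_world_size, gen_tp_size = 16, 2                         # from allocation_mode
n_sglang_servers = gen_world_size // gen_tp_size            # gen_dp_size -> 8 servers
n_sglang_nodes = gen_world_size // n_gpus_per_node          # 2 nodes serve inference
n_sglang_servers_per_node = n_gpus_per_node // gen_tp_size  # 4 servers per node
trainer_n_nodes = n_nodes - n_sglang_nodes                  # 4 nodes run torchrun
print(n_sglang_servers, n_sglang_nodes, n_sglang_servers_per_node, trainer_n_nodes)
# -> 8 2 4 4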

View File

@ -41,12 +41,12 @@ class JobException(Exception):
class JobInfo:
name: str
state: JobState
host: str = (
host: Optional[str] = (
None # The host on which the job is/was running. None if the job had not run.
)
submit_time: str = None
start_time: str = None
slurm_id: str = None # Slurm only. The Slurm id of the job.
submit_time: Optional[str] = None
start_time: Optional[str] = None
slurm_id: Optional[int] = None # Slurm only. The Slurm id of the job.
class SchedulerClient:

View File

@ -40,7 +40,7 @@ def execute_shell_command(command: str) -> subprocess.Popen:
)
def apply_sglang_path():
def apply_sglang_patch():
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
@ -75,7 +75,7 @@ def launch_server_cmd(command: str, port: int = 30000):
If no port is specified, a free port is reserved.
"""
if not ray.is_initialized():
apply_sglang_path()
apply_sglang_patch()
assert port is not None
full_command = f"{command} --port {port}"
process = execute_shell_command(full_command)