Mirror of https://github.com/inclusionAI/AReaL
Commit 4a26f28adf (parent 97511e43ff)
arealite/api/cli_args.py:

@@ -283,12 +283,12 @@ class SGLangConfig:
        host,
        port,
        dist_init_addr: Optional[str] = None,
        sglang_version: Optional[str] = None,
    ):
        from realhf.base import network, pkg_version, seeding
        from realhf.base import pkg_version
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)

        args = dict(
            host=host,
            port=port,
@@ -309,10 +309,24 @@ class SGLangConfig:
            dist_init_addr=dist_init_addr,
            **args,
        )
        if sglang_version:
            version_less_than_0_4_4 = (
                pkg_version.compare_versions(sglang_version, "0.4.4") < 0
            )
            version_less_than_0_4_3 = (
                pkg_version.compare_versions(sglang_version, "0.4.3") < 0
            )
        elif pkg_version.is_available("sglang"):
            version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
            version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
        else:
            raise ValueError(
                "An installed SGLang package or a specific SGLang version must be provided to build the SGLang server cmd."
            )

        if pkg_version.is_version_less("sglang", "0.4.4"):
        if version_less_than_0_4_4:
            args.pop("log_requests_level")
        if pkg_version.is_version_less("sglang", "0.4.3"):
        if version_less_than_0_4_3:
            args.pop("enable_nccl_nvls")
            args.pop("triton_attention_num_kv_splits")
            args.pop("cuda_graph_bs")
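For orientation, a sketch of how the launcher later in this commit invokes this helper; the concrete values are illustrative and only mirror the call site shown further below:

# Illustrative call, mirroring the launcher's usage further below in this commit.
cmd = SGLangConfig.build_cmd(
    config.sglang,      # the SGLang section of the experiment config
    model_path,         # model served by the SGLang server
    sglang_tp_size,     # tensor-parallel size per server
    base_gpu_id=0,
    host="localhost",
    port=27010,
    dist_init_addr=None,
    sglang_version="0.4.6.post4",  # lets the version gate above run without importing sglang
)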
@@ -634,8 +648,12 @@ class ArgParseResult:
    config: BaseExperimentConfig
    config_file: Path
    additional_args: Optional[argparse.Namespace] = None
    overrides: Optional[List[str]] = None


def parse_cli_args(argv: List[str], parser: Optional[argparse.ArgumentParser] = None) -> ArgParseResult:

def parse_cli_args(
    argv: List[str], parser: Optional[argparse.ArgumentParser] = None
) -> ArgParseResult:
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument(
@@ -654,11 +672,10 @@ def parse_cli_args(argv: List[str], parser: Optional[argparse.ArgumentParser] =
        overrides=overrides,
    )
    return ArgParseResult(
        config=cfg,
        config_file=config_file,
        additional_args=args
        config=cfg, config_file=config_file, additional_args=args, overrides=overrides
    )


def to_structured_cfg(cfg, config_cls):
    # Merge with the default configuration.
    # The YAML and command line can omit some default values defined in Python dataclasses.
@@ -668,9 +685,7 @@ def to_structured_cfg(cfg, config_cls):


def load_expr_config(
    argv: List[str],
    config_cls,
    parser: Optional[argparse.ArgumentParser] = None
    argv: List[str], config_cls, parser: Optional[argparse.ArgumentParser] = None
) -> ArgParseResult:
    r = parse_cli_args(argv, parser=parser)
    cfg = r.config
@@ -684,7 +699,5 @@ def load_expr_config(
    constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
    name_resolve.reconfigure(cfg.cluster.name_resolve)
    return ArgParseResult(
        config=cfg,
        config_file=r.config_file,
        additional_args=r.additional_args
        config=cfg, config_file=r.config_file, additional_args=r.additional_args
    )
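A minimal sketch of how an entry point consumes the updated result, assuming load_expr_config lives alongside parse_cli_args in arealite.api.cli_args; the experiment config class name is hypothetical:

# Hypothetical entry point; MyExperimentConfig stands in for a real experiment config class.
import sys
from arealite.api.cli_args import load_expr_config

r = load_expr_config(sys.argv[1:], MyExperimentConfig)
cfg, overrides = r.config, r.overrides   # overrides now round-trips the CLI override strings
print(cfg.experiment_name, cfg.trial_name, r.config_file, overrides)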
arealite/launcher/slurm.py:

@@ -1,23 +1,25 @@
from typing import Tuple, List, Optional, Literal, Dict
import subprocess
import time
import re
import os
import getpass
import argparse
import getpass
import os
import pathlib
import re
import subprocess
import sys
import time
from typing import Dict, List, Literal, Optional, Tuple

from omegaconf import OmegaConf

from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig, parse_cli_args, to_structured_cfg
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_config, find_and_amend_config
import realhf.base.logging as logging
from realhf.scheduler.client import (
    JobException,
    JobInfo,
    JobState
from arealite.api.cli_args import (
    ClusterSpecConfig,
    SGLangConfig,
    parse_cli_args,
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_and_amend_config, find_config
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("SlurmLauncher")
@@ -50,7 +52,7 @@ STATUS_MAPPING = {

def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
@@ -63,16 +65,17 @@ def cancel_jobs(
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}"
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
@@ -82,6 +85,7 @@ def query_jobs(
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
@@ -97,19 +101,43 @@ def query_jobs(
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=job_id.strip(),
                slurm_id=int(job_id.strip()),
            )
        )
    print(rs)
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str):
    try:
        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Get slurm host ip for node {node} failed.")

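A quick sketch of what the two helpers above return; the hostnames are illustrative, and the expansion simply defers to scontrol show hostnames:

# Illustrative: expand a compressed Slurm nodelist, then resolve each host's IP.
hosts = parse_slurm_nodelist("node[01-03]")    # -> ["node01", "node02", "node03"]
ips = [get_slurm_host_ip(h) for h in hosts]    # each lookup runs a tiny 1-core srun job
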
SCHEDULING_RETRY_INTERVAL_SECONDS = 30
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5


SBATCH_SCRIPT_TEMPLATE = """
#!/bin/bash
SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
@@ -119,11 +147,11 @@ echo nodes=$nodes
nodes_array=($nodes)
echo node_array=$nodes_array

head_node=$\{nodes_array[0]\}
head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun --mpi=pmi2 --nodes=1 --ntasks=1 --nodelist="$head_node" hostname --ip-address)
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# srun commands
@@ -132,11 +160,11 @@ echo head_node_ip=$head_node_ip
wait
"""

SRUN_CMD_TEMPLATE = """
srun --mpi=pmi2 -K -l --chdir $PWD --nodes={nodes} --ntasks={ntasks} --gres:gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
SRUN_CMD_TEMPLATE = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
    --mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
    {container_env_strings} {container_image} {cmd} &

    {container_env_strings} \\
    {container_image} \\
    {cmd} &
"""

LOCAL_CACHE_DIR = "/tmp/arealite"
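For reference, a sketch of how the template above is rendered later in submit_array; every value here is illustrative:

# Illustrative values only; submit_array fills these in from the job and cluster config.
srun_cmd = SRUN_CMD_TEMPLATE.format(
    nodes=1,
    ntasks=1,
    n_gpus_per_node=8,
    cpus_per_task=120,
    mem_per_cpu=10240,
    apptainer_name="singularity",
    apptainer_options="--no-home --writable-tmpfs --nv --pid",
    container_mounts="/storage:/storage",
    container_env_strings="--env OMP_NUM_THREADS=32",
    container_image="areal-gpu-image",
    cmd="python -m my_entry_point",
)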
@@ -152,6 +180,8 @@ BASE_ENVIRONS = {
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
    "HF_ENDPOINT": "https://hf-mirror.com",
    "PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
@@ -164,9 +194,11 @@ NA132_ENVIRONS = {
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG": "WARN",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}


def get_env_vars(cluster_name: str):
    """Returns the environment variables for the cluster."""
    if cluster_name == "na132":
@@ -181,30 +213,29 @@ class SlurmLauncher:
        self.trial_name = trial_name
        self.fileroot = fileroot

        # actual slurm job name -> JobInfo
        self.jobs: Dict[str, JobInfo] = {}
        # slurm_job_id -> JobInfo
        self.jobs: Dict[int, JobInfo] = {}
        self.job_names = []

    @property
    def run_name(self) -> str:
        """Returns the run name of this launcher."""
        return f"{self.experiment_name}_{self.trial_name}"

    def slurm_name(self, job_name: str) -> str:
        """Returns the slurm name of a job."""
        return f"{self.experiment_name}_{self.trial_name}:{job_name}"

    def log_path_of(self, job_name: str) -> str:
        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(log_path, exist_ok=True)
        return os.path.join(log_path, f"{job_name}.log")

    def sbatch_path_of(self, job_name: str) -> str:
        sbatch_path = f"{self.fileroot}/sbatch/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        sbatch_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(sbatch_path, exist_ok=True)
        return os.path.join(sbatch_path, f"{job_name}.sh")

    def submit(self, job_name, cmd, **kwargs):
        """Submits and launches a job with SBATCH.
@@ -213,36 +244,46 @@ class SlurmLauncher:
        """
        return self.submit_array(job_name, cmd, count=1, **kwargs)

    def find_job_id(self, job_name: str):
        job_name = self.slurm_name(job_name)
        for job_id, job_info in self.jobs.items():
            if job_info.name == job_name:
                return job_id
        return None

    def submit_array(
        self,
        job_name: str,
        self,
        job_name: str,
        cmd: List[str] | str,
        count: int,
        nodes: int,
        n_gpus_per_node: int,
        cpus_per_task: int,
        mem_per_task: int, # MB
        mem_per_task: int,  # MB
        container_image: str,
        container_mounts: Optional[str] = None,
        env_vars: Optional[Dict] = None,
        nodelist: Optional[str] = None,
        exclude: Optional[str] = None,
        apptainer_name: Optional[str] = "singularity",
        apptainer_options: Optional[Tuple[str]] = (
            "--no-home", "--writable-tmpfs", "--nv", "--pid"
        apptainer_options: Optional[Tuple[str, ...]] = (
            "--no-home",
            "--writable-tmpfs",
            "--nv",
            "--pid",
        ),
    ):
        """Submits and launches a job array with SBATCH.
        """Submits and launches a job array with SBATCH.
        Note that a job array has one (unique) slurm name, and one (unique) slurm id.

        Args:
            job_name (str): The job name of the job array. The actual slurm name will be
                `<experiment_name>_<trial_name>:<job_name>`.
                `<experiment_name>_<trial_name>:<job_name>`.
            cmd (str or List[str]): The core command to be executed.
            count (int): The number of jobs in the array.
        """
        assert self.slurm_name(job_name) not in self.jobs, (
            f"Job {self.slurm_name(job_name)} is already submitted. "
        assert job_name not in self.job_names, (
            f"Job {job_name} is already submitted. "
            "Please use a different job name or stop the existing job."
        )
        if isinstance(cmd, str):
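A minimal sketch of a submit_array call, mirroring how the __main__ block below drives it; the resource numbers, commands, and image name are illustrative:

# Illustrative values only; see the launcher's __main__ block for the real call sites.
launcher = SlurmLauncher(experiment_name="my-exp", trial_name="trial0", fileroot="/storage/experiments")
launcher.submit_array(
    job_name="trainer",
    cmd=[f"python -m my_entry --node-rank {i}" for i in range(2)],  # one command per array element
    count=2,
    nodes=2,
    n_gpus_per_node=8,
    cpus_per_task=120,
    mem_per_task=1024 * 1024,  # MB
    container_image="areal-gpu-image",
    container_mounts="/storage:/storage",
    env_vars={"OMP_NUM_THREADS": "32"},
)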
@@ -262,17 +303,18 @@ class SlurmLauncher:
        )

        mem_per_cpu = mem_per_task // cpus_per_task  # MB per CPU
        mem_per_node = mem_per_task * count // nodes + 1024 * 100  # FIXME

        sbatch_options = [
            f"--job-name={self.slurm_name(job_name)}",
            f"--output={self.fileroot}/{self.run_name}/{job_name}.out",
            "--open-mode=append",
            f"--output={self.log_path_of(job_name)}",
            # "--open-mode=append", # FIXME
            "--no-requeue",
            f"--nodes={nodes}-{nodes}",
            f"--ntasks-per-node={ntasks_per_node}",
            f"--gres=gpu:{n_gpus_per_node}",
            f"--cpus-per-task={cpus_per_task}",
            f"--mem-per-cpu={mem_per_cpu}M",
            f"--mem={mem_per_node}M",
        ]

        if nodelist:
@@ -282,13 +324,12 @@ class SlurmLauncher:

        sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])

        if env_vars is None:
            env_vars = dict()
        n_gpus_per_task = n_gpus_per_node // ntasks_per_node
        assert "CUDA_VISIBLE_DEVICES" not in env_vars, (
            "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
        )
        assert (
            "CUDA_VISIBLE_DEVICES" not in env_vars
        ), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."

        srun_cmds = []
        for i in range(count):
@@ -301,12 +342,17 @@ class SlurmLauncher:
                    str(x) for x in range(gpu_id_start, gpu_id_end)
                ),
            }
            env_string = " ".join("--env {}={}".format(k, v) for k, v in (_env_vars or {}).items())
            env_string = " ".join(
                "--env {}={}".format(k, v) for k, v in (_env_vars or {}).items()
            )
            # Prepare the command for each job in the array
            job_cmd = cmd[i]
            # FIXME: only for debugging, remove and replace new image
            job_cmd = f'bash -c "pip3 install -U gymnasium torchdata tensordict hf-xet; {job_cmd}"'

            srun_cmd = SRUN_CMD_TEMPLATE.format(
                nodes=nodes,
                ntasks=ntasks_per_node,
                nodes=1,
                ntasks=1,
                n_gpus_per_node=n_gpus_per_node,
                cpus_per_task=cpus_per_task,
                mem_per_cpu=mem_per_cpu,
@@ -319,59 +365,83 @@ class SlurmLauncher:
            )
            srun_cmds.append(srun_cmd)

        srun_cmd = "\n".join(srun_cmds)
        srun_cmds = "\n".join(srun_cmds)
        sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
            sbatch_options=sbatch_options_str, srun_cmds=srun_cmd
            sbatch_options=sbatch_options_str, srun_cmds=srun_cmds
        )
        sbatch_file_path = self.sbatch_path_of(f"{job_name}_{i}")
        sbatch_file_path = self.sbatch_path_of(f"{job_name}")
        with open(sbatch_file_path, "w") as f:
            f.write(sbatch_script)

        # Submit the job
        return_code = subprocess.check_call(["sbatch", sbatch_file_path])

        if return_code != 0:
        # FIXME: debug only
        try:
            output = (
                subprocess.check_output(["sbatch", sbatch_file_path])
                .decode("utf-8")
                .strip()
            )
            logger.info(
                f"Submitted Slurm job {self.slurm_name(job_name)} to scheduler. To check the output, run \n\t`tail -f {self.log_path_of(job_name)}`."
            )
        except subprocess.CalledProcessError as e:
            logger.warning(
                f"Failed to submit job {self.slurm_name(job_name)}. "
                f"For debugging, please make sure your sbatch command works "
                f"and check generated sbatch file on {sbatch_file_path}."
            )
            self.jobs[self.slurm_name(job_name)] = JobInfo(
            logger.error(f"Error message: {e}")
            return

        match = re.search(r"Submitted batch job (\d+)", output)
        slurm_job_id = int(match.group(1)) if match else None
        if slurm_job_id is None:
            logger.warning(
                f"Failed to obtain job id for job {self.slurm_name(job_name)}. "
                f"sbatch output: {output}"
            )
            return

        assert isinstance(slurm_job_id, int)
        self.jobs[slurm_job_id] = JobInfo(
            name=self.slurm_name(job_name),
            state=JobState.PENDING,
            slurm_id=slurm_job_id,
        )
        self._update_all()

    def stop(self, job_name, signal=None):
    def stop(self, job_name, signal="SIGKILL"):
        """Stops a running job.

        Raises an exception if there is no such job, but passes if the job
        has stopped either successfully or not.

        Args:
            job_name: The job name of the job array to stop.
                The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
            job_name: The job name of the job array to stop.
                The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
        """
        raise NotImplementedError()
        job_id = self.find_job_id(job_name)
        if not job_id:
            return
        return cancel_jobs(slurm_ids=[job_id], signal=signal)

    def stop_all(self, signal=None):
    def stop_all(self, signal="SIGKILL"):
        """Stops all running jobs."""
        raise NotImplementedError()
        return cancel_jobs(slurm_ids=list(self.jobs.keys()), signal=signal)

    def find(self, job_name) -> JobInfo:
    def find(self, job_name) -> JobInfo | None:
        """Gets the status of a job of this launcher.

        Args:
            job_name: The job name of the job array to find.
                The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
            job_name: The job name of the job array to find.
                The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.

        Returns:
            A JobInfo if the job is found, or None otherwise.
        """
        self._update_all()
        return self.jobs.get(self.slurm_name(job_name), None)

        job_id = self.find_job_id(job_name)
        return self.jobs[job_id] if job_id else None

    def find_all(self, job_name_regex=".*") -> List[JobInfo]:
        """Finds jobs.
@@ -385,11 +455,11 @@ class SlurmLauncher:
        self._update_all()
        infos = []
        for r in self.jobs.values():
            job_name = r.name.split(":")[-1] # Extract the job name from slurm name
            job_name = r.name.split(":")[-1]  # Extract the job name from slurm name
            if re.fullmatch(job_name_regex, job_name):
                infos.append(r)
        return infos

    def _find_job_with_status(
        self,
        status: List[JobState],
@@ -405,7 +475,6 @@ class SlurmLauncher:
        self._update_all()
        return [r for r in self.jobs.values() if r.state in status]

    def wait(
        self,
        timeout=None,
@@ -422,10 +491,10 @@ class SlurmLauncher:
        deadline = None if timeout is None else time.time() + timeout

        num_jobs_left = len(self.jobs)
        left = set(self.jobs.values())
        left = list(self.jobs.keys())
        logger.info(
            f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
            f"{','.join(sorted([x.slurm_id for x in self.jobs.values()]))}."
            f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
        )
        while len(left) > 0:
            if len(left) < num_jobs_left:
@@ -434,11 +503,12 @@ class SlurmLauncher:
            if deadline is not None and time.time() > deadline:
                raise TimeoutError(
                    f"Timeout waiting for {num_jobs_left} jobs. Job ID: "
                    f"{','.join(sorted([x.slurm_id for x in self.jobs.values()]))}."
                    f"{','.join(sorted([str(x.slurm_id) for x in self.jobs.values()]))}."
                )
            self._update_all()
            left = list(self.jobs.values())
            for slurm_info in list(left):
            left = list(self.jobs.keys())
            for slurm_id in list(left):
                slurm_info = self.jobs[slurm_id]
                if slurm_info.slurm_id is None:
                    continue
                if slurm_info.state in check_status:
@@ -452,18 +522,18 @@ class SlurmLauncher:
                    logger.info(
                        f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
                    )
                    left.remove(slurm_info)
                    left.remove(slurm_id)
                    if update:
                        self.jobs.pop(slurm_info.name)
                        self.jobs.pop(slurm_info.slurm_id)
            time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)

    def _update_all(self):
        """Updates the status of all jobs. """
        """Updates the status of all jobs."""
        try:
            slurm_infos = query_jobs(list(self.jobs.keys()))
            slurm_infos = query_jobs(slurm_ids=list(self.jobs.keys()))
            for slurm_info in slurm_infos:
                if slurm_info.name in self.jobs:
                    self.jobs[slurm_info.name] = slurm_info
                assert slurm_info.slurm_id is not None
                self.jobs[slurm_info.slurm_id] = slurm_info
        except subprocess.CalledProcessError:
            logger.warning(
                "Calling squeue failed. Check slurm manually if you continue to see this warning."
@@ -472,99 +542,146 @@ class SlurmLauncher:

def slurm_args_parser():
    parser = argparse.ArgumentParser(description="Slurm Launcher for AReaL")
    parser.add_argument("entry_point", type=str, help="The entry point script to run.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument(
        "--sglang-server-base-port", type=int, required=False, default=27010,
        help="Base port for SGLang servers. SGLang servers on the same node will be assigned different ports starting from this base."
        "--sglang-server-base-port",
        type=int,
        required=False,
        default=27010,
        help="Base port for SGLang servers. SGLang servers on the same node will be assigned different ports starting from this base.",
    )
    parser.add_argument(
        "--trainer-port",
        type=int,
        required=False,
        default=27009,
        help="Pytorch distributed initialization port for trainer.",
    )
    parser.add_argument(
        "--sglang-version",
        type=str,
        required=False,
        default="0.4.6.post4",
        help="SGLang version in your GPU inference image.",
    )
    parser.add_argument("--trainer-port", type=int, required=False, default=27009, help="Pytorch distributed initialization port for trainer.")
    # parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")

    return parser

if __name__ == "__main__":
    # usage: python -m arealite.launcher.slurm <entry_point> --allocation_mode <allocation_mode> <config_path> [<args>]
    # usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
    r = parse_cli_args(sys.argv[2:], parser=slurm_args_parser())
    entry_point = sys.argv[1]
    config = r.config
    config_file = r.config_file
    args = r.additional_args

    cluster_config: ClusterSpecConfig = config.cluster
    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
    n_nodes = config.n_nodes
    n_gpus_per_node = config.n_gpus_per_node
    if n_gpus_per_node < cluster_config.n_gpus_per_node:
    if n_gpus_per_node < config.cluster.n_gpus_per_node:
        raise ValueError(
            f"Slurm Launcher requires at least {cluster_config.n_gpus_per_node} GPUs (#GPUs per node). For use cases with fewer GPUs, use LocalLauncher instead."
            f"Slurm Launcher requires at least {config.cluster.n_gpus_per_node} GPUs (#GPUs per node). For use cases with fewer GPUs, use LocalLauncher instead."
        )
    elif n_gpus_per_node > cluster_config.n_gpus_per_node:
    elif n_gpus_per_node > config.cluster.n_gpus_per_node:
        raise ValueError(
            f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({cluster_config.n_gpus_per_node})."
            f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({config.cluster.n_gpus_per_node})."
        )

    launcher = SlurmLauncher(
        experiment_name=config.experiment_name,
        trial_name=config.trial_name,
        fileroot=cluster_config.fileroot
        fileroot=config.cluster.fileroot,
    )
    allocation_mode = config.allocation_mode
    allocation_mode = AllocationMode.from_str(allocation_mode)
    sglang_cmds = []
    sglang_addrs = []
    n_sglang_nodes = 0
    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
        # Launcher should launch SGLang servers according to allocation mode.
        assert isinstance(allocation_mode, str), "Allocation mode should be a string."
        sglang_config = find_and_amend_config(config, "sglang", SGLangConfig)
        assert "gen" in allocation_mode
        assert allocation_mode.gen_pp_size == 1, "Pipeline generation in SGLang is not supported for now."
        assert allocation_mode.gen_tp_size <= cluster_config.n_gpus_per_node, "Currently only support SGLang TP size <= #GPUs per node."
        sglang_world_size = allocation_mode.gen_world_size
        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
        assert (
            allocation_mode.gen_pp_size == 1
        ), "Pipeline generation in SGLang is not supported for now."
        assert (
            allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
        ), "Currently only support SGLang TP size <= #GPUs per node."
        sglang_world_size = allocation_mode.gen_world_size
        sglang_tp_size = allocation_mode.gen_tp_size
        n_sglang_server_instances = allocation_mode.gen_dp_size
        n_sglang_servers = allocation_mode.gen_dp_size
        n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
        n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size

        model_path = find_config(config, "path")
        assert model_path is not None and isinstance(model_path, str)

        seed = find_config(config, "seed")
        base_seed = config.sglang.random_seed
        base_port = args.sglang_server_base_port
        for i in range(n_sglang_server_instances):
        sglang_ports_on_node = set()
        for i in range(n_sglang_servers):
            base_gpu_id = i * sglang_tp_size % n_gpus_per_node
            n_server_per_node = cluster_config.n_gpus_per_node // sglang_tp_size
            # Since we cannot get port information from slurm, we only ensure ports on the same node do not conflict.
            server_port = base_port + i % n_server_per_node * 2
            dist_port = base_port + i % n_server_per_node * 2 + 1
            dist_init_addr = f"tcp://localhost:{server_port}"
            server_port = base_port + i % n_sglang_servers_per_node * 2
            # print(server_port, dist_init_addr)
            config.sglang.random_seed = base_seed + i
            sglang_cmds.append(
                SGLangConfig.build_cmd(
                    sglang_config,
                    model_path,
                    config.sglang,
                    sglang_tp_size,
                    base_gpu_id=base_gpu_id,
                    seed=seed + i
                    base_gpu_id=0,
                    host="localhost",
                    port=server_port,
                    dist_init_addr=None,
                    sglang_version=args.sglang_version,
                )
            )
            sglang_ports_on_node.add(server_port)
        assert len(sglang_ports_on_node) == n_sglang_servers_per_node

        # launch sglang servers
        # launch SGLang servers, note that we need to leave some resources on each node
        # to schedule jobs that retrieve node IP. (1 CPU core & 10 MB memory)
        if sglang_cmds:
            launcher.submit_array(
                job_name="sglang-server",
                cmd=sglang_cmds,
                count=n_sglang_nodes,
                nodes=n_sglang_server_instances,
                n_gpus_per_node=cluster_config.n_gpus_per_node,
                cpus_per_task=15 * sglang_tp_size, # one sglang task occupy the entire node
                mem_per_task=150 * 1024 * sglang_tp_size, # 150GB per task
                container_image=cluster_config.gpu_infer_image,
                container_mounts=cluster_config.mount,
                env_vars=get_env_vars(cluster_config.cluster_name),
                count=n_sglang_servers,
                nodes=n_sglang_nodes,
                n_gpus_per_node=config.cluster.n_gpus_per_node,
                cpus_per_task=15 * sglang_tp_size,
                mem_per_task=150 * 1024 * sglang_tp_size, # 150GB per task
                container_image=config.cluster.gpu_infer_image,
                container_mounts=config.cluster.mount,
                env_vars=get_env_vars(config.cluster.cluster_name),
            )

        # Get SGLang slurm nodes, find the hosts
        start_time = time.perf_counter()
        while True:
            job_info = launcher.find("sglang-server")
            assert job_info is not None
            print(job_info)
            sglang_hosts = job_info.host
            logger.info(
                f"Waiting for SGLang servers to be scheduled by slurm, time since started = {time.perf_counter() - start_time:.2f}"
            )
            if sglang_hosts:
                print(sglang_hosts)
                sglang_hosts = [
                    get_slurm_host_ip(node)
                    for node in parse_slurm_nodelist(sglang_hosts)
                ]
                print(sglang_hosts)
                for host in sglang_hosts:
                    sglang_addrs.extend(
                        [f"{host}:{port}" for port in sglang_ports_on_node]
                    )
                assert len(sglang_addrs) == n_sglang_servers
                break
            time.sleep(10)

    trainer_n_nodes = n_nodes - n_sglang_nodes
    assert r.overrides is not None
    trainer_cmd_template = (
        f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
        f"--master-addr $head_node_ip --master-port {args.trainer_port} {' '.join(sys.argv[1:])}"
        f"--master-addr $head_node_ip --master-port {args.trainer_port} {entry_point} --config {config_file} {' '.join(r.overrides)}"
    )

    trainer_cmds = []
    for i in range(trainer_n_nodes):
        # For each trainer node, we launch a trainer with the same command.
@@ -572,24 +689,28 @@ if __name__ == "__main__":
        trainer_cmds.append(
            trainer_cmd_template.format(
                nnodes=trainer_n_nodes,
                nproc_per_node=cluster_config.n_gpus_per_node,
                node_rank=i
                nproc_per_node=config.cluster.n_gpus_per_node,
                node_rank=i,
            )
        )

    # launch trainers
    launcher.submit_array(
        job_name="trainer",
        cmd=trainer_cmds,
        count=trainer_n_nodes,
        nodes=trainer_n_nodes,
        n_gpus_per_node=cluster_config.n_gpus_per_node,
        cpus_per_task=120, # one trainer task occupies the entire node
        mem_per_task=1200 * 1024, # 1.2T per task
        container_image=cluster_config.gpu_image,
        container_mounts=cluster_config.mount,
        env_vars=get_env_vars(cluster_config.cluster_name),
    )

    if not config.server_only:
        # launch trainers
        launcher.submit_array(
            job_name="trainer",
            cmd=trainer_cmds,
            count=trainer_n_nodes,
            nodes=trainer_n_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
            cpus_per_task=120, # one trainer task occupies the entire node
            mem_per_task=1024 * 1024, # 1024GB per task
            container_image=config.cluster.gpu_image,
            container_mounts=config.cluster.mount,
            env_vars=dict(
                **get_env_vars(config.cluster.cluster_name),
                AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
            ),
        )

    try:
        launcher.wait(
@@ -602,6 +723,5 @@ if __name__ == "__main__":
            remove_status=(),
        )
    except (KeyboardInterrupt, JobException, TimeoutError) as e:
        launcher.stop_all("SIGTERM")
        launcher.stop_all("SIGKILL")
        raise e
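How the trainer side consumes the AREAL_LLM_SERVER_ADDRS variable injected above is not part of this diff; a plausible sketch, assuming trainer code reads it from the environment as a comma-separated list of host:port pairs:

# Assumption: trainer-side code splits the injected variable into host:port strings.
import os
sglang_addrs = os.environ.get("AREAL_LLM_SERVER_ADDRS", "").split(",")
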
arealite/launcher/utils.py:

@@ -1,5 +1,6 @@
from omegaconf import DictConfig, OmegaConf


def find_config(config: DictConfig, name: str) -> DictConfig | None:
    # iterate through the nested DictConfig and find the first matching config with name
    for key, value in config.items():
@@ -11,6 +12,7 @@ def find_config(config: DictConfig, name: str) -> DictConfig | None:
            return found
    return None


def amend_config(config: DictConfig, config_cls):
    default_config = OmegaConf.structured(config_cls)
    config = OmegaConf.merge(default_config, config)
@@ -18,9 +20,10 @@ def amend_config(config: DictConfig, config_cls):
    assert isinstance(config, config_cls)
    return config


def find_and_amend_config(config: DictConfig, name: str, config_cls):
    # Find the config with the given name and amend it with the given config_cls
    found = find_config(config, name)
    if found is not None:
        return amend_config(found, config_cls)
    return None
    return None
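A small sketch of how the launcher uses these helpers, mirroring the call sites earlier in this commit:

# Locate the nested "sglang" section and merge it over the SGLangConfig defaults.
sglang_cfg = find_and_amend_config(config, "sglang", SGLangConfig)
model_path = find_config(config, "path")  # first nested value keyed "path", or None
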
realhf/scheduler/client.py:

@@ -41,12 +41,12 @@ class JobException(Exception):
class JobInfo:
    name: str
    state: JobState
    host: str = (
    host: Optional[str] = (
        None  # The host on which the job is/was running. None if the job had not run.
    )
    submit_time: str = None
    start_time: str = None
    slurm_id: str = None  # Slurm only. The Slurm id of the job.
    submit_time: Optional[str] = None
    start_time: Optional[str] = None
    slurm_id: Optional[int] = None  # Slurm only. The Slurm id of the job.


class SchedulerClient: