mirror of https://github.com/inclusionAI/AReaL
slurm launcher not tested
parent c38cffc023
commit 6acd1696f6
@@ -195,26 +195,29 @@ class SGLangConfig:
    @staticmethod
    def build_cmd(
        sglang_config: "SGLangConfig",
        model_path,
        tp_size,
        base_gpu_id,
        model_path: str,
        tp_size: int,
        base_gpu_id: int,
        server_port: Optional[int] = None,
        dist_init_addr: Optional[str] = None,
        seed: Optional[int] = None,
        served_model_name: Optional[str] = None,
        skip_tokenizer_init: bool = True,
    ):
    ) -> str:
        from realhf.base import network, pkg_version, seeding
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)
        args["random_seed"] = seeding.get_seed()

        if server_port is not None:
            args["port"] = server_port
        if served_model_name is None:
            served_model_name = model_path
        host_ip = network.gethostip()
        host = "localhost" if not sglang_config.enable_metrics else host_ip
        host = "localhost"
        args = dict(
            host=host,
            model_path=model_path,
            # seed
            seed=seed if seed is not None else seeding.get_seed(),
            # Model and tokenizer
            tokenizer_path=model_path,
            tokenizer_mode="auto",
@@ -230,6 +233,7 @@ class SGLangConfig:
            base_gpu_id=base_gpu_id,
            nnodes=1,
            node_rank=0,
            # initialization addresses and ports
            dist_init_addr=dist_init_addr,
            **args,
        )
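For reference, a minimal sketch of how the typed `build_cmd` signature above might be called. The config object is default-constructed, and the model path, TP size, and port are placeholder values, not values taken from this commit:

from arealite.api.cli_args import SGLangConfig

# Build an SGLang server command from a default config (illustrative values).
cmd = SGLangConfig.build_cmd(
    SGLangConfig(),
    model_path="/path/to/model",  # placeholder checkpoint path
    tp_size=2,
    base_gpu_id=0,
    server_port=27010,
    seed=1,
)
print(cmd)  # the returned command string can be handed to a launcher as-is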
@@ -543,6 +547,16 @@ class SFTConfig(BaseExperimentConfig):
    model: TrainEngineConfig = field(default_factory=TrainEngineConfig)


@dataclass
class GRPOConfig(BaseExperimentConfig):
    actor: TrainEngineConfig = field(default_factory=TrainEngineConfig)
    ref: TrainEngineConfig = field(default_factory=TrainEngineConfig)
    rollout: InferenceEngineConfig = field(
        default_factory=InferenceEngineConfig
    )
    sglang: SGLangConfig = field(default_factory=SGLangConfig)


def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
    parser = argparse.ArgumentParser()
    parser.add_argument(
arealite/launcher/slurm.py
@@ -0,0 +1,614 @@
from typing import Tuple, List, Optional, Literal, Dict
import subprocess
import time
import re
import os
import getpass
import argparse

from omegaconf import OmegaConf

from arealite.api.cli_args import SGLangConfig, ClusterSpecConfig
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.launcher.utils import find_config, find_and_amend_config
import realhf.base.logging as logging
from realhf.scheduler.client import (
    JobException,
    JobInfo,
    JobState
)

logger = logging.getLogger("SlurmLauncher")


SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[str]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[str]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=job_id.strip(),
            )
        )
    return rs

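# Usage sketch (illustrative, not part of this commit): for a single job name,
# query_jobs builds a squeue invocation roughly like the one below. Each field
# is padded with the "__PSI__" delimiter so result lines can be split reliably,
# even when field values contain spaces.
_example_format = ":.__PSI__,".join(["JobID", "State", "Name"])
# -> "JobID:.__PSI__,State:.__PSI__,Name"
_example_cmd = ["squeue", "-O", _example_format, "-tall", "-n", "my_exp_my_trial:trainer"]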
SCHEDULING_RETRY_INTERVAL_SECONDS = 30
SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5


SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun --mpi=pmi2 --nodes=1 --ntasks=1 --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# srun commands
{srun_cmds}

wait
"""

SRUN_CMD_TEMPLATE = """
srun --mpi=pmi2 -K -l --chdir $PWD --nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
    --mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
    {container_env_strings} {container_image} {cmd} &

"""
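# Illustration (not part of this commit): rendering SRUN_CMD_TEMPLATE with
# placeholder values shows the shape of each generated line; the launcher joins
# these lines and drops them into the {srun_cmds} slot of SBATCH_SCRIPT_TEMPLATE.
_example_srun_line = SRUN_CMD_TEMPLATE.format(
    nodes=1,
    ntasks=8,
    n_gpus_per_node=8,
    cpus_per_task=15,
    mem_per_cpu=10240,
    apptainer_name="singularity",
    apptainer_options="--no-home --writable-tmpfs --nv --pid",
    container_mounts="/storage:/storage",          # placeholder bind mount
    container_env_strings="--env CUDA_VISIBLE_DEVICES=0",
    container_image="areal.sif",                   # placeholder image name
    cmd="echo hello",                              # placeholder payload command
)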
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
    "TOKENIZERS_PARALLELISM": "true",
    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
    "NCCL_NET_PLUGIN": "",
    "NCCL_IB_GID_INDEX": "3",
    "NCCL_IB_TIMEOUT": "2",
    "NCCL_IB_RETRY_CNT": "7",
    "NCCL_IB_SL": "5",
    "NCCL_IB_TC": "136",
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}


def get_env_vars(cluster_name: str):
    """Returns the environment variables for the cluster."""
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS}
    else:
        return BASE_ENVIRONS


class SlurmLauncher:
    def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
        self.experiment_name = experiment_name
        self.trial_name = trial_name
        self.fileroot = fileroot

        # actual slurm job name -> JobInfo
        self.jobs: Dict[str, JobInfo] = {}

    @property
    def run_name(self) -> str:
        """Returns the run name of this launcher."""
        return f"{self.experiment_name}_{self.trial_name}"

    def slurm_name(self, job_name: str) -> str:
        """Returns the slurm name of a job."""
        return f"{self.experiment_name}_{self.trial_name}:{job_name}"

    def log_path_of(self, job_name: str) -> str:
        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(log_path, exist_ok=True)
        return os.path.join(log_path, f"{job_name}.log")

    def sbatch_path_of(self, job_name: str) -> str:
        sbatch_path = f"{self.fileroot}/sbatch/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(sbatch_path, exist_ok=True)
        return os.path.join(sbatch_path, f"{job_name}.sh")

    def submit(self, job_name, cmd, **kwargs):
        """Submits and launches a job with SBATCH.

        Args:
            cmd (str or List[str]): The core command to be executed.
        """
        return self.submit_array(job_name, cmd, count=1, **kwargs)

    def submit_array(
        self,
        job_name: str,
        cmd: List[str] | str,
        count: int,
        nodes: int,
        n_gpus_per_node: int,
        cpus_per_task: int,
        mem_per_task: int,  # MB
        container_image: str,
        container_mounts: Optional[str] = None,
        env_vars: Optional[Dict] = None,
        nodelist: Optional[str] = None,
        exclude: Optional[str] = None,
        apptainer_name: Optional[str] = "singularity",
        apptainer_options: Optional[Tuple[str]] = (
            "--no-home", "--writable-tmpfs", "--nv", "--pid"
        ),
    ):
        """Submits and launches a job array with SBATCH.
        Note that a job array has one (unique) slurm name, and one (unique) slurm id.

        Args:
            job_name (str): The job name of the job array. The actual slurm name will be
                `<experiment_name>_<trial_name>:<job_name>`.
            cmd (str or List[str]): The core command to be executed.
            count (int): The number of jobs in the array.
        """
        assert self.slurm_name(job_name) not in self.jobs, (
            f"Job {self.slurm_name(job_name)} is already submitted. "
            "Please use a different job name or stop the existing job."
        )
        if isinstance(cmd, str):
            cmd = [cmd]
        assert len(cmd) == count, (
            f"Command length {len(cmd)} does not match the job count {count}. "
            "Please provide a command for each job in the array."
        )
        assert count % nodes == 0, (
            f"Job count {count} must be divisible by the number of nodes {nodes}. "
            "Please adjust the job count or the number of nodes."
        )
        ntasks_per_node = count // nodes
        assert n_gpus_per_node % ntasks_per_node == 0, (
            "GPUs must be evenly distributed across tasks. "
            f"Current #GPUs per node {n_gpus_per_node}, #tasks per node {ntasks_per_node}."
        )

        mem_per_cpu = mem_per_task // cpus_per_task  # MB per CPU

        sbatch_options = [
            f"--job-name={self.slurm_name(job_name)}",
            f"--output={self.fileroot}/{self.run_name}/{job_name}.out",
            "--open-mode=append",
            "--no-requeue",
            f"--nodes={nodes}-{nodes}",
            f"--ntasks-per-node={ntasks_per_node}",
            f"--gres=gpu:{n_gpus_per_node}",
            f"--cpus-per-task={cpus_per_task}",
            f"--mem-per-cpu={mem_per_cpu}M",
        ]

        if nodelist:
            sbatch_options.append(f"--nodelist={nodelist}")
        if exclude:
            sbatch_options.append(f"--exclude={exclude}")

        sbatch_options_str = "\n".join([f"#SBATCH {opt}" for opt in sbatch_options])

        if env_vars is None:
            env_vars = dict()
        n_gpus_per_task = n_gpus_per_node // ntasks_per_node
        assert "CUDA_VISIBLE_DEVICES" not in env_vars, (
            "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
        )

        srun_cmds = []
        for i in range(count):
            # resolve CUDA_VISIBLE_DEVICES for each task
            gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task
            gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task
            _env_vars = {
                **env_vars,
                "CUDA_VISIBLE_DEVICES": ",".join(
                    str(x) for x in range(gpu_id_start, gpu_id_end)
                ),
            }
            env_string = " ".join("--env {}={}".format(k, v) for k, v in (_env_vars or {}).items())
            # Prepare the command for each job in the array
            job_cmd = cmd[i]
            srun_cmd = SRUN_CMD_TEMPLATE.format(
                nodes=nodes,
                ntasks=ntasks_per_node,
                n_gpus_per_node=n_gpus_per_node,
                cpus_per_task=cpus_per_task,
                mem_per_cpu=mem_per_cpu,
                apptainer_name=apptainer_name,
                apptainer_options=" ".join(apptainer_options),
                container_mounts=container_mounts or "",
                container_env_strings=env_string,
                container_image=container_image,
                cmd=job_cmd,
            )
            srun_cmds.append(srun_cmd)

        srun_cmd = "\n".join(srun_cmds)
        sbatch_script = SBATCH_SCRIPT_TEMPLATE.format(
            sbatch_options=sbatch_options_str, srun_cmds=srun_cmd
        )
        sbatch_file_path = self.sbatch_path_of(f"{job_name}_{i}")
        with open(sbatch_file_path, "w") as f:
            f.write(sbatch_script)

        # Submit the job
        return_code = subprocess.check_call(["sbatch", sbatch_file_path])

        if return_code != 0:
            logger.info(
                f"Failed to submit job {self.slurm_name(job_name)}. "
                f"For debugging, please make sure your sbatch command works "
                f"and check generated sbatch file on {sbatch_file_path}."
            )

        self.jobs[self.slurm_name(job_name)] = JobInfo(
            name=self.slurm_name(job_name),
            state=JobState.PENDING,
        )
        self._update_all()
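    # Usage sketch (illustrative values, not part of this commit): submit an
    # array of 8 commands spread over 2 nodes, i.e. 4 tasks per node with
    # 2 GPUs each on an 8-GPU node:
    #
    #   launcher = SlurmLauncher("my_exp", "my_trial", fileroot="/storage/experiments")
    #   launcher.submit_array(
    #       job_name="worker",
    #       cmd=[f"python worker.py --rank {i}" for i in range(8)],
    #       count=8,
    #       nodes=2,
    #       n_gpus_per_node=8,
    #       cpus_per_task=15,
    #       mem_per_task=100 * 1024,            # 100 GB per task, in MB
    #       container_image="areal.sif",        # placeholder image
    #       container_mounts="/storage:/storage",
    #       env_vars=get_env_vars("na132"),
    #   )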
    def stop(self, job_name, signal=None):
        """Stops a running job.

        Raises an exception if there is no such job, but passes if the job
        has already stopped, whether successfully or not.

        Args:
            job_name: The job name of the job array to stop.
                The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.
        """
        raise NotImplementedError()

    def stop_all(self, signal=None):
        """Stops all running jobs."""
        raise NotImplementedError()

    def find(self, job_name) -> JobInfo:
        """Gets the status of a job submitted by this launcher.

        Args:
            job_name: The job name of the job array to find.
                The actual slurm job name will be `<experiment_name>_<trial_name>:<job_name>`.

        Returns:
            A JobInfo if the job is found, or None otherwise.
        """
        self._update_all()
        return self.jobs.get(self.slurm_name(job_name), None)

    def find_all(self, job_name_regex=".*") -> List[JobInfo]:
        """Finds jobs whose names match a regex.

        Args:
            job_name_regex: job name regex.

        Returns:
            A list of found JobInfo.
        """
        self._update_all()
        infos = []
        for r in self.jobs.values():
            job_name = r.name.split(":")[-1]  # Extract the job name from the slurm name
            if re.fullmatch(job_name_regex, job_name):
                infos.append(r)
        return infos

    def _find_job_with_status(
        self,
        status: List[JobState],
    ) -> List[JobInfo]:
        """Finds jobs with the given status.

        Args:
            status: A list of JobState to filter jobs.

        Returns:
            A list of JobInfo with the given status.
        """
        self._update_all()
        return [r for r in self.jobs.values() if r.state in status]

    def wait(
        self,
        timeout=None,
        check_status: Tuple[JobState, ...] = (
            JobState.CANCELLED,
            JobState.FAILED,
            JobState.NOT_FOUND,
        ),
        remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
        update=False,
    ):
        """Waits until all jobs submitted via this client instance finish."""
        # begin wait
        deadline = None if timeout is None else time.time() + timeout

        num_jobs_left = len(self.jobs)
        left = set(self.jobs.values())
        logger.info(
            f"Waiting for {num_jobs_left} jobs. Job IDs: "
            f"{','.join(sorted([x.slurm_id for x in self.jobs.values()]))}."
        )
        while len(left) > 0:
            if len(left) < num_jobs_left:
                num_jobs_left = len(left)
                logger.info(f"Waiting for {num_jobs_left} jobs.")
            if deadline is not None and time.time() > deadline:
                raise TimeoutError(
                    f"Timeout waiting for {num_jobs_left} jobs. Job IDs: "
                    f"{','.join(sorted([x.slurm_id for x in self.jobs.values()]))}."
                )
            self._update_all()
            left = list(self.jobs.values())
            for slurm_info in list(left):
                if slurm_info.slurm_id is None:
                    continue
                if slurm_info.state in check_status:
                    raise JobException(
                        run_name=self.run_name,
                        worker_type=slurm_info.name,
                        host=slurm_info.host,
                        reason=slurm_info.state,
                    )
                if slurm_info.state in remove_status:
                    logger.info(
                        f"Job {slurm_info.name} is {slurm_info.state}. (Removed)"
                    )
                    left.remove(slurm_info)
                    if update:
                        self.jobs.pop(slurm_info.name)
            time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)

    def _update_all(self):
        """Updates the status of all jobs."""
        try:
            slurm_infos = query_jobs(list(self.jobs.keys()))
            for slurm_info in slurm_infos:
                if slurm_info.name in self.jobs:
                    self.jobs[slurm_info.name] = slurm_info
        except subprocess.CalledProcessError:
            logger.warning(
                "Calling squeue failed. Check slurm manually if you continue to see this warning."
            )

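# Usage sketch (illustrative, not part of this commit): querying job state
# without blocking, using the short job names passed to submit_array().
#
#   info = launcher.find("trainer")                # slurm name "<exp>_<trial>:trainer"
#   sglang_jobs = launcher.find_all(r"sglang-.*")  # regex over the short names
#   failed = launcher._find_job_with_status([JobState.FAILED])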
def parse_args():
    parser = argparse.ArgumentParser(description="Slurm Launcher for AReaL")
    parser.add_argument("entry_point", type=str, help="The entry point script to run.")
    parser.add_argument("config_path", type=str, help="Path to the configuration file.")
    parser.add_argument(
        "--sglang-server-base-port", type=int, required=False, default=27010,
        help="Base port for SGLang servers. SGLang servers on the same node are "
        "assigned ports counting up from this base port."
    )
    parser.add_argument("--trainer-port", type=int, required=False, default=27009, help="PyTorch distributed initialization port for the trainer.")
    parser.add_argument("remaining_args", nargs='*', help="Additional arguments to pass to the entry point script.")

    return parser.parse_args()
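# Example invocation (paths are illustrative; extra key=value overrides are
# merged into the YAML config via OmegaConf.from_dotlist in __main__ below):
#
#   python -m arealite.launcher.slurm entry_point.py config.yaml \
#       --sglang-server-base-port 27010 --trainer-port 27009 \
#       allocation_mode=<allocation_mode> cluster.n_gpus_per_node=8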
if __name__ == "__main__":
    # usage: python -m arealite.launcher.slurm <entry_point> <config_path> [allocation_mode=<allocation_mode>] [<args>]
    args = parse_args()
    config = OmegaConf.load(args.config_path)
    # Amend the config with the remaining command line arguments
    config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.remaining_args))

    cluster_config: ClusterSpecConfig = find_and_amend_config(config, "cluster", ClusterSpecConfig)
    assert cluster_config is not None, "Cluster configuration is required for the slurm launcher."

    n_nodes = find_config(config, "n_nodes")
    n_gpus_per_node = find_config(config, "n_gpus_per_node")
    assert n_gpus_per_node is not None and isinstance(n_gpus_per_node, int)
    assert n_nodes is not None and isinstance(n_nodes, int)
    if n_gpus_per_node < cluster_config.n_gpus_per_node:
        raise ValueError(
            f"Slurm Launcher requires at least {cluster_config.n_gpus_per_node} GPUs (#GPUs per node). For use cases with fewer GPUs, use LocalLauncher instead."
        )
    elif n_gpus_per_node > cluster_config.n_gpus_per_node:
        raise ValueError(
            f"#GPUs per node required by the experiment ({n_gpus_per_node}) is larger than #GPUs per node in the cluster ({cluster_config.n_gpus_per_node})."
        )

    launcher = SlurmLauncher(
        experiment_name=find_config(config, "experiment_name"),
        trial_name=find_config(config, "trial_name"),
        fileroot=cluster_config.fileroot
    )

    allocation_mode = find_config(config, "allocation_mode")
    allocation_mode = AllocationMode.from_str(allocation_mode)
    sglang_cmds = []
    n_sglang_nodes = 0
    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
        # Launcher should launch SGLang servers according to allocation mode.
        assert isinstance(allocation_mode, str), "Allocation mode should be a string."
        sglang_config = find_and_amend_config(config, "sglang", SGLangConfig)
        assert "gen" in allocation_mode
        assert allocation_mode.gen_pp_size == 1, "Pipeline generation in SGLang is not supported for now."
        assert allocation_mode.gen_tp_size <= cluster_config.n_gpus_per_node, "Currently only supports SGLang TP size <= #GPUs per node."
        sglang_world_size = allocation_mode.gen_world_size
        sglang_tp_size = allocation_mode.gen_tp_size
        n_sglang_server_instances = allocation_mode.gen_dp_size
        n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node

        model_path = find_config(config, "path")
        assert model_path is not None and isinstance(model_path, str)

        seed = find_config(config, "seed")
        base_port = args.sglang_server_base_port
        for i in range(n_sglang_server_instances):
            base_gpu_id = i * sglang_tp_size % n_gpus_per_node
            n_server_per_node = cluster_config.n_gpus_per_node // sglang_tp_size
            # Since we cannot get port information from slurm, we only make sure
            # that ports assigned to servers on the same node do not conflict.
            server_port = base_port + i % n_server_per_node * 2
            dist_port = base_port + i % n_server_per_node * 2 + 1
            dist_init_addr = f"tcp://localhost:{dist_port}"
            sglang_cmds.append(
                SGLangConfig.build_cmd(
                    sglang_config,
                    model_path,
                    sglang_tp_size,
                    base_gpu_id=base_gpu_id,
                    seed=seed + i
                )
            )

    # launch sglang servers
    if sglang_cmds:
        launcher.submit_array(
            job_name="sglang-server",
            cmd=sglang_cmds,
            count=n_sglang_server_instances,
            nodes=n_sglang_nodes,
            n_gpus_per_node=cluster_config.n_gpus_per_node,
            cpus_per_task=15 * sglang_tp_size,  # 15 CPUs per GPU used by each SGLang server
            mem_per_task=150 * 1024 * sglang_tp_size,  # 150GB per GPU used by each SGLang server
            container_image=cluster_config.gpu_infer_image,
            container_mounts=cluster_config.mount,
            env_vars=get_env_vars(cluster_config.cluster_name),
        )

    trainer_n_nodes = n_nodes - n_sglang_nodes
    entry_point_cmd = f"{args.entry_point} --config {args.config_path} {' '.join(args.remaining_args)}"
    trainer_cmd_template = (
        f"torchrun --nnodes={{nnodes}} --nproc-per-node={{nproc_per_node}} --node-rank {{node_rank}} "
        f"--master-addr $head_node_ip --master-port {args.trainer_port} {entry_point_cmd}"
    )

    trainer_cmds = []
    for i in range(trainer_n_nodes):
        # For each trainer node, we launch a trainer with the same command.
        # The node rank is the index of the node in the cluster.
        trainer_cmds.append(
            trainer_cmd_template.format(
                nnodes=trainer_n_nodes,
                nproc_per_node=cluster_config.n_gpus_per_node,
                node_rank=i
            )
        )

    # launch trainers
    launcher.submit_array(
        job_name="trainer",
        cmd=trainer_cmds,
        count=trainer_n_nodes,
        nodes=trainer_n_nodes,
        n_gpus_per_node=cluster_config.n_gpus_per_node,
        cpus_per_task=120,  # one trainer task occupies the entire node
        mem_per_task=1200 * 1024,  # 1.2TB per task
        container_image=cluster_config.gpu_image,
        container_mounts=cluster_config.mount,
        env_vars=get_env_vars(cluster_config.cluster_name),
    )

    try:
        launcher.wait(
            check_status=(
                JobState.CANCELLED,
                JobState.FAILED,
                JobState.NOT_FOUND,
                JobState.COMPLETED,
            ),
            remove_status=(),
        )
    except (KeyboardInterrupt, JobException, TimeoutError) as e:
        launcher.stop_all("SIGTERM")
        raise e
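The port arithmetic in the SGLang branch above gives each server instance on a node a pair of consecutive ports (server port and distributed-init port) offset from the base port, wrapping per node. A standalone sketch of the same computation, with illustrative values (base port 27010 as in parse_args, 8 GPUs per node, TP size 2):

base_port = 27010
n_gpus_per_node = 8
sglang_tp_size = 2
n_server_per_node = n_gpus_per_node // sglang_tp_size  # 4 server instances per node

for i in range(6):  # six server instances spread over two nodes
    server_port = base_port + (i % n_server_per_node) * 2
    dist_port = server_port + 1
    base_gpu_id = i * sglang_tp_size % n_gpus_per_node
    print(f"instance {i}: ports ({server_port}, {dist_port}), first GPU {base_gpu_id}")
# instances 0-3 get ports (27010, 27011) ... (27016, 27017) on the first node;
# instance 4 wraps around to (27010, 27011) again, which is safe because it
# lands on a different node.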
arealite/launcher/utils.py
@@ -0,0 +1,26 @@
from omegaconf import DictConfig, OmegaConf

def find_config(config: DictConfig, name: str) -> DictConfig | None:
    # iterate through the nested DictConfig and find the first matching config with name
    for key, value in config.items():
        if key == name:
            return value
        if isinstance(value, DictConfig):
            found = find_config(value, name)
            if found:
                return found
    return None

def amend_config(config: DictConfig, config_cls):
    default_config = OmegaConf.structured(config_cls)
    config = OmegaConf.merge(default_config, config)
    config = OmegaConf.to_object(config)
    assert isinstance(config, config_cls)
    return config

def find_and_amend_config(config: DictConfig, name: str, config_cls):
    # Find the config with the given name and amend it with the given config_cls
    found = find_config(config, name)
    if found is not None:
        return amend_config(found, config_cls)
    return None
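A brief sketch of how these helpers combine; the nested config layout and the dataclass below are made up for illustration:

from dataclasses import dataclass
from omegaconf import OmegaConf
from arealite.launcher.utils import find_and_amend_config

@dataclass
class DummyClusterCfg:  # illustrative stand-in for a structured config class
    name: str = "local"
    n_gpus_per_node: int = 8

raw = OmegaConf.create({"experiment": {"cluster": {"n_gpus_per_node": 4}}})

# find_config locates the first nested key called "cluster"; amend_config then
# merges it over the structured defaults and returns a DummyClusterCfg instance.
cluster = find_and_amend_config(raw, "cluster", DummyClusterCfg)
assert isinstance(cluster, DummyClusterCfg) and cluster.n_gpus_per_node == 4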