diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py
index 743eba5..62296f1 100644
--- a/arealite/api/cli_args.py
+++ b/arealite/api/cli_args.py
@@ -624,6 +624,23 @@ class DatasetConfig:
     drop_last: bool = field(default=True)
 
 
+
+@dataclass
+class RayLauncherConfig:
+    """Ray-specific launcher configuration."""
+    main_func_name: str = field(
+        default="main",
+        metadata={"help": "Name of the main function in the entrypoint file to run in Ray workers."},
+    )
+
+@dataclass
+class SlurmLauncherConfig:
+    """Slurm-specific launcher configuration."""
+    srun_cmd_template: Optional[str] = field(
+        default=None,
+        metadata={"help": "Template of the srun command used to launch jobs."},
+    )
+
 @dataclass
 class LauncherConfig:
     """Configuration for launching the SGLang server."""
@@ -662,6 +679,14 @@ class LauncherConfig:
         default=27015,
         metadata={"help": "Trainer port used for torch.distributed initialization."},
     )
+    ray: RayLauncherConfig = field(
+        default_factory=RayLauncherConfig,
+        metadata={"help": "Ray launcher configuration."},
+    )
+    slurm: SlurmLauncherConfig = field(
+        default_factory=SlurmLauncherConfig,
+        metadata={"help": "Slurm launcher configuration."},
+    )
 
 
 @dataclass
diff --git a/arealite/launcher/ray.py b/arealite/launcher/ray.py
index e69de29..ddd3451 100644
--- a/arealite/launcher/ray.py
+++ b/arealite/launcher/ray.py
@@ -0,0 +1,353 @@
+import os
+import sys
+import time
+import getpass
+import pathlib
+from typing import List, Optional, Dict
+import importlib.util
+
+import ray
+import ray.exceptions
+from ray.runtime_env import RuntimeEnv
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from ray.util.placement_group import PlacementGroup
+from ray.job_submission import JobStatus as RayJobStatus
+
+import realhf.base.logging as logging
+from arealite.api.cli_args import (
+    ClusterSpecConfig,
+    LauncherConfig,
+    SGLangConfig,
+    parse_cli_args,
+    to_structured_cfg,
+)
+from arealite.api.io_struct import AllocationMode, AllocationType
+from arealite.utils.launcher import (
+    get_env_vars,
+    wait_sglang_server_addrs,
+)
+from realhf.base import logging, name_resolve
+from realhf.scheduler.client import JobException
+
+logger = logging.getLogger("RayLauncher")
+
+RAY_WAIT_CHECK_TIME_INTERVAL = 5  # seconds
+
+
+def run_func(file_path, function_name, *args, **kwargs):
+    # Convert the file path to a unique module name
+    module_name = file_path.replace('/', '_').replace('.', '_')
+
+    # Load the module from file path
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+
+    # Get the function and execute it
+    function = getattr(module, function_name)
+    return function(*args, **kwargs)
+
+
+class RayLauncher:
+    def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
+        self.experiment_name = experiment_name
+        self.trial_name = trial_name
+        self.fileroot = fileroot
+
+        # job_name to ray future
+        self.jobs = {}
+
+    @property
+    def run_name(self):
+        return f"{self.experiment_name}_{self.trial_name}"
+
+    def log_path_of(self, job_name: str) -> str:
+        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
+        os.makedirs(log_path, exist_ok=True)
+        return os.path.join(log_path, f"{job_name}.log")
+
+    def submit(
+        self,
+        job_name: str,
+        file_path: str,
+        func_name: str,
+        gpus: int,
+        cpus: int,
+        mem: int,  # MB
+        env_vars: Optional[Dict] = None,
+        placement_group: Optional[PlacementGroup] = None,
+        bundle_index: Optional[int] = None,
+        *args,  # arguments to pass to the function
+        **kwargs,  # keyword arguments to pass to the function
+    ):
+        runtime_env = RuntimeEnv(
+            env_vars=env_vars or dict(),
+        )
+        scheduling_strategy = PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            bundle_index=bundle_index,
+        ) if placement_group is not None else "DEFAULT"
+        future = ray.remote(
+            num_cpus=cpus,
+            num_gpus=gpus,
+            memory=mem * 1024 * 1024,  # Convert MB to bytes
+            runtime_env=runtime_env,
+            scheduling_strategy=scheduling_strategy,
+        )(run_func).remote(
+            file_path, func_name, *args, **kwargs
+        )
+        self.jobs[job_name] = future
+        return future
+
+    def submit_array(
+        self,
+        job_name: str,
+        file_path: str,
+        func_name: str,
+        count: int,
+        nodes: int,
+        list_args: List[List],
+        gpus_per_task: int,
+        cpus_per_task: int,
+        mem_per_task: int,  # MB
+        list_kwargs: List[Dict] | None = None,
+        env_vars: Optional[Dict] = None,
+        amend_torch_dist_env: bool = False,
+    ):
+        """Submit an array of jobs to Ray with ray placement groups."""
+        if count % nodes != 0:
+            raise ValueError(
+                f"Count {count} is not divisible by nodes {nodes}. "
+                "Please ensure that count is a multiple of nodes."
+            )
+        assert len(list_args) == count, (
+            f"Length of list_args {len(list_args)} does not match count {count}."
+        )
+        if list_kwargs is not None:
+            assert len(list_kwargs) == count, (
+                f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."
+            )
+
+        tasks_per_node = count // nodes
+        gpus_per_node = gpus_per_task * tasks_per_node
+        cpus_per_node = cpus_per_task * tasks_per_node
+        mem_per_node = mem_per_task * tasks_per_node
+
+        placement_group = ray.util.placement_group(
+            bundles=[
+                {
+                    "CPU": cpus_per_node,
+                    "GPU": gpus_per_node,
+                    "memory": mem_per_node * 1024 * 1024,  # Convert MB to bytes
+                }
+            ] * nodes,
+            strategy="STRICT_SPREAD",
+        )
+        ray.get(placement_group.ready())
+
+        futures = []
+        for i in range(count):
+            args = list_args[i]
+            kwargs = list_kwargs[i] if list_kwargs is not None else {}
+
+            # manage environment variables
+            env_vars = env_vars or {}
+            assert (
+                "CUDA_VISIBLE_DEVICES" not in env_vars
+            ), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
+
+            gpu_id_start = (i % tasks_per_node) * gpus_per_task
+            gpu_id_end = ((i % tasks_per_node) + 1) * gpus_per_task
+            node_id = i // tasks_per_node
+            _env_vars = {
+                **env_vars,
+                "CUDA_VISIBLE_DEVICES": ",".join(
+                    str(x) for x in range(gpu_id_start, gpu_id_end)
+                ),
+            }
+
+            if amend_torch_dist_env:
+                assert gpus_per_task == 1
+                _env_vars.update({
+                    "RANK": str(i),
+                    "WORLD_SIZE": str(count),
+                    "LOCAL_RANK": str(i % tasks_per_node),
+                })
+
+            # Pass the leading parameters positionally so that `*args` can be
+            # forwarded to the target function without clashing with them.
+            future = self.submit(
+                f"{job_name}:{i}",
+                file_path,
+                func_name,
+                gpus_per_task,
+                cpus_per_task,
+                mem_per_task,
+                _env_vars,
+                placement_group,
+                node_id,
+                *args,
+                **kwargs,
+            )
+            futures.append(future)
+
+        return futures
+
+    def stop(self, job_name: str, force: bool = False):
+        """Stop a job by name."""
+        if job_name in self.jobs:
+            future = self.jobs[job_name]
+            try:
+                ray.cancel(future, force=force)
+            except Exception as e:
+                logger.error(f"Failed to cancel job {job_name}: {e}")
+                return
+            self.jobs.pop(job_name, None)
+            logger.info(f"Job {job_name} stopped.")
+        else:
+            logger.warning(f"Job {job_name} not found in running jobs.")
+
+    def stop_all(self, force: bool = False):
+        """Stop all jobs."""
+        for job_name in list(self.jobs.keys()):
+            self.stop(job_name, force=force)
+        logger.info("All jobs stopped.")
+        self.jobs.clear()
+
+    def wait(self):
+        """Poll the status of all jobs every RAY_WAIT_CHECK_TIME_INTERVAL seconds.
+        If any job fails, terminate all jobs and return.
+        If a job completes, remove it from self.jobs.
+        Return once all jobs have completed.
+        """
+        while self.jobs:
+            completed_jobs = []
+            for job_name, future in list(self.jobs.items()):
+                try:
+                    r = ray.get(future, timeout=0.1)
+                    logger.info(f"Job {job_name} completed with result: {r}")
+                    completed_jobs.append(job_name)
+                except ray.exceptions.GetTimeoutError:
+                    continue
+                except ray.exceptions.RayTaskError as e:
+                    logger.error(f"Job {job_name} failed with error: {e}, stopping all jobs.")
+                    self.stop_all(force=True)
+                    return
+            for job_name in completed_jobs:
+                self.jobs.pop(job_name, None)
+                logger.info(f"Job {job_name} completed. Removed.")
+            time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)
+
+
+if __name__ == "__main__":
+    # usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<cli_overrides>] launcher.ray.main_func_name=<func_name>
+    ray.init()
+    config, config_file = parse_cli_args(sys.argv[2:])
+
+    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
+    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
+    name_resolve.reconfigure(config.cluster.name_resolve)
+
+    n_nodes = config.n_nodes
+    n_gpus_per_node = config.n_gpus_per_node
+    if n_gpus_per_node < config.cluster.n_gpus_per_node:
+        raise ValueError(
+            f"Ray launcher requires at least {config.cluster.n_gpus_per_node} GPUs (#GPUs per node). For experiments with fewer GPUs, use LocalLauncher instead."
+        )
+    elif n_gpus_per_node > config.cluster.n_gpus_per_node:
+        raise ValueError(
+            f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({config.cluster.n_gpus_per_node})."
+        )
+
+    launcher = RayLauncher(
+        experiment_name=config.experiment_name,
+        trial_name=config.trial_name,
+        fileroot=config.cluster.fileroot,
+    )
+    allocation_mode = config.allocation_mode
+    allocation_mode = AllocationMode.from_str(allocation_mode)
+    sglang_cmds = []
+    sglang_addrs = []
+    n_sglang_nodes = 0
+    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
+        # Launcher should launch SGLang servers according to allocation mode.
+        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+        assert (
+            allocation_mode.gen_pp_size == 1
+        ), "Pipeline generation in SGLang is not supported for now."
+        assert (
+            allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
+        ), "Currently only support SGLang TP size <= #GPUs per node."
+        sglang_world_size = allocation_mode.gen_world_size
+        sglang_tp_size = allocation_mode.gen_tp_size
+        n_sglang_servers = allocation_mode.gen_dp_size
+        n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
+        n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
+
+        base_seed = config.sglang.random_seed
+
+        # Each entry is the positional-argument list for one server process,
+        # i.e. a single `argv` list passed to the SGLang server main function.
+        sglang_args_list = [
+            [sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]]
+            for i in range(n_sglang_servers)
+        ]
+        sglang_entry_point = str(pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py"))
+        sglang_main_func_name = "main_sglang_server"
+        launcher.submit_array(
+            job_name="llm_server",
+            file_path=sglang_entry_point,
+            func_name=sglang_main_func_name,
+            count=n_sglang_servers,
+            nodes=n_sglang_nodes,
+            list_args=sglang_args_list,
+            gpus_per_task=sglang_tp_size,
+            cpus_per_task=config.launcher.inference_server_cpus_per_gpu * sglang_tp_size,
+            mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
+            env_vars=get_env_vars(
+                config.cluster.cluster_name,
+                config.launcher.inference_server_env_vars,
+            ),
+        )
+
+        # Wait for SGLang servers to come up and collect their addresses.
+        sglang_addrs = wait_sglang_server_addrs(
+            config.experiment_name,
+            config.trial_name,
+            n_sglang_servers,
+        )
+
+    trainer_n_nodes = n_nodes - n_sglang_nodes
+    trainer_entry_point = sys.argv[1]
+    trainer_main_func_name = config.launcher.ray.main_func_name
+    n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
+    trainer_args_list = [
+        [sys.argv[2:]] for _ in range(n_trainer_processes)
+    ]
+    if not config.server_only:
+        # launch trainers
+        launcher.submit_array(
+            job_name="trainer",
+            file_path=trainer_entry_point,
+            func_name=trainer_main_func_name,
+            count=trainer_n_nodes * config.cluster.n_gpus_per_node,
+            nodes=trainer_n_nodes,
+            list_args=trainer_args_list,
+            gpus_per_task=1,
+            cpus_per_task=config.launcher.trainer_cpus_per_gpu,
+            mem_per_task=config.launcher.trainer_mem_per_gpu,
+            env_vars=dict(
+                **get_env_vars(
+                    config.cluster.cluster_name,
+                    config.launcher.trainer_env_vars,
+                ),
+                AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
+            ),
+            amend_torch_dist_env=True,
+        )
+
+    try:
+        launcher.wait()
+    except (KeyboardInterrupt, JobException, TimeoutError) as e:
+        launcher.stop_all(force=True)
+        raise e
\ No newline at end of file
diff --git a/arealite/launcher/sglang_server.py b/arealite/launcher/sglang_server.py
index f840114..9ef0a1d 100644
--- a/arealite/launcher/sglang_server.py
+++ b/arealite/launcher/sglang_server.py
@@ -147,9 +147,16 @@
             f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
         )
 
+    def __del__(self):
+        if self.server_process and self.server_process.poll() is None:
+            logger.info("Terminating SGLang server process...")
+            self.server_process.terminate()
+            self.server_process.wait()
+            logger.info("SGLang server process terminated.")
 
-if __name__ == "__main__":
-    config, _ = parse_cli_args(sys.argv[2:])
+
+def main_sglang_server(argv):
+    config, _ = parse_cli_args(argv)
     config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
     config.cluster.name_resolve = to_structured_cfg(
         config.cluster.name_resolve, NameResolveConfig
@@ -169,3 +176,7 @@
         n_gpus_per_node=config.n_gpus_per_node,
     )
     sglang_server.run()
+
+
+if __name__ == "__main__":
+    main_sglang_server(sys.argv[1:])
\ No newline at end of file
diff --git a/arealite/launcher/slurm.py b/arealite/launcher/slurm.py
index 72b3962..d3efea6 100644
--- a/arealite/launcher/slurm.py
+++ b/arealite/launcher/slurm.py
@@ -8,8 +8,6 @@ import sys
 import time
 from typing import Dict, List, Literal, Optional, Tuple
 
-from omegaconf import OmegaConf
-
 import realhf.base.logging as logging
 from arealite.api.cli_args import (
     BaseExperimentConfig,
@@ -20,197 +18,24 @@ from arealite.api.cli_args import (
     to_structured_cfg,
 )
 from arealite.api.io_struct import AllocationMode, AllocationType
+from arealite.utils.launcher import (
+    get_env_vars,
+    wait_sglang_server_addrs,
+)
+from arealite.utils.slurm import (
+    cancel_jobs,
+    query_jobs,
+    SBATCH_SCRIPT_TEMPLATE,
+    DEFAULT_SRUN_CMD_TEMPLATE,
+)
 from realhf.base import logging, name_resolve, names
 from realhf.scheduler.client import JobException, JobInfo, JobState
 
 logger = logging.getLogger("SlurmLauncher")
 
-
-SQUEUE_FIELDS = [
-    "JobID",
-    "State",
-    "SubmitTime",
-    "StartTime",
-    "Name",
-    "NodeList",
-    "UserName",
-    "MaxCPUs",
-    "cpus-per-task",
-    "NumTasks",
-    "tres-alloc",
-]
-STATUS_MAPPING = {
-    "RUNNING": JobState.RUNNING,
-    "COMPLETING": JobState.RUNNING,
-    "PENDING": JobState.PENDING,
-    "CANCELLED": JobState.CANCELLED,
-    "FAILED": JobState.FAILED,
-    "COMPLETED": JobState.COMPLETED,
-    "OUT_OF_MEMORY": JobState.FAILED,
-    "DEADLINE": JobState.COMPLETED,
-    "TIMEOUT": JobState.COMPLETED,
-}
+SLURM_WAIT_CHECK_TIME_INTERVAL = 5
 
-
-def cancel_jobs(
-    slurm_names: Optional[List[str]] = None,
-    slurm_ids: Optional[List[int]] = None,
-    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
-):
-    assert (
-        slurm_names is not None or slurm_ids is not None
-    ), "Must specify slurm_names or slurm_ids."
-    assert not (
-        slurm_names and slurm_ids
-    ), "Cannot specify both slurm_names and slurm_ids."
-    cmd = ["scancel", "-s", signal]
-    if slurm_names is not None:
-        cmd += ["-n", ",".join(slurm_names)]
-    elif slurm_ids is not None:
-        cmd += [",".join(str(s) for s in slurm_ids)]
-    subprocess.check_call(cmd)
-    logger.info(
-        f"Cancelled Slurm job with signal {signal}: "
-        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
-    )
-
-
-def query_jobs(
-    slurm_names: Optional[List[str]] = None,
-    slurm_ids: Optional[List[int]] = None,
-    status: str = "all",
-    delimiter: str = "__PSI__",
-) -> List[JobInfo]:
-    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
-    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
-    if slurm_names is not None:
-        cmd += ["-n", ",".join(slurm_names)]
-    if slurm_ids is not None:
-        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
-
-    output = (
-        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
-    )
-    rs = []
-    for line in output.split("\n")[1:]:
-        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
-            delimiter
-        )
-        rs.append(
-            JobInfo(
-                name=slurm_name,
-                state=STATUS_MAPPING[state],
-                host=nodelist,
-                submit_time=submit_time,
-                start_time=start_time,
-                slurm_id=int(job_id.strip()),
-            )
-        )
-    return rs
-
-
-def parse_slurm_nodelist(nodelist: str) -> List[str]:
-    return (
-        subprocess.check_output(
-            [
-                "scontrol",
-                "show",
-                "hostnames",
-                nodelist,
-            ]
-        )
-        .decode("utf-8")
-        .strip()
-        .split("\n")
-    )
-
-
-def get_slurm_host_ip(node: str):
-    try:
-        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
-        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
-    except subprocess.CalledProcessError:
-        logger.warning(f"Get slurm host ip for node {node} failed.")
-
-
-SGLANG_SERVER_TIMEOUT_SECONDS = 180
-SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
-
-
-SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
-{sbatch_options}
-
-# Getting the node names
-nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
-echo nodes=$nodes
-
-nodes_array=($nodes)
-echo node_array=$nodes_array
-
-head_node=${{nodes_array[0]}}
-echo head_node=$head_node
-
-# Getting the head node IP address
-head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
-echo head_node_ip=$head_node_ip
-
-# srun commands
-{srun_cmds}
-
-wait
-"""
-
-SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodelist=${{nodes_array[{node_id}]}} \\
-    --nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
-    --mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
-    {container_env_strings} \\
-    {container_image} \\
-    {cmd} &
-"""
-
-LOCAL_CACHE_DIR = "/tmp/arealite"
-PYTORCH_KERNEL_CACHE_PATH = (
-    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
-)
-TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
-os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
-os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
-BASE_ENVIRONS = {
-    "TOKENIZERS_PARALLELISM": "true",
-    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
-    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
-    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
-    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
-    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
-    "PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
-}
-NA132_ENVIRONS = {
-    "NCCL_SOCKET_IFNAME": "bond0",
-    "NCCL_NET_PLUGIN": "",
-    "NCCL_IB_GID_INDEX": "3",
-    "NCCL_IB_TIMEOUT": "2",
-    "NCCL_IB_RETRY_CNT": "7",
-    "NCCL_IB_SL": "5",
-    "NCCL_IB_TC": "136",
-    "NCCL_IB_HCA": "mlx5_bond",
-    "NCCL_IB_QPS_PER_CONNECTION": "8",
-    "NCCL_SET_THREAD_NAME": "1",
-    "NCCL_DEBUG": "WARN",
-    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
-}
-
-
-def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
-    """Returns the environment variables for the cluster."""
-    additional_env_vars = {}
-    if additional_env_vars:
-        additional_env_vars = dict(
-            item.split("=") for item in additional_env_vars.split(",")
-        )
-    if cluster_name == "na132":
-        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
-    else:
-        return {**BASE_ENVIRONS, **additional_env_vars}
 
 
 class SlurmLauncher:
@@ -261,6 +86,7 @@ class SlurmLauncher:
         self,
         job_name: str,
         cmd: List[str] | str,
+        srun_cmd_template: str,
         count: int,
         nodes: int,
         n_gpus_per_node: int,
@@ -271,13 +97,6 @@ class SlurmLauncher:
         env_vars: Optional[Dict] = None,
         nodelist: Optional[str] = None,
         exclude: Optional[str] = None,
-        apptainer_name: Optional[str] = "singularity",
-        apptainer_options: Optional[Tuple[str, ...]] = (
-            "--no-home",
-            "--writable-tmpfs",
-            "--nv",
-            "--pid",
-        ),
     ):
         """Submits and launch a job array with SBATCH. Note that a job array
         has one (unique) slurm name, and one (unique) slurm id.
@@ -359,15 +178,13 @@ class SlurmLauncher:
 
             # FIXME: only for debugging, remove and replace new image
             # job_cmd = f'bash -c "pip3 install -r requirements.txt; {job_cmd}"'
-            srun_cmd = SRUN_CMD_TEMPLATE.format(
+            srun_cmd = srun_cmd_template.format(
                 nodes=1,
                 ntasks=1,
                 node_id=node_id,
                 n_gpus_per_node=n_gpus_per_node,
                 cpus_per_task=cpus_per_task,
                 mem_per_cpu=mem_per_cpu,
-                apptainer_name=apptainer_name,
-                apptainer_options=" ".join(apptainer_options),
                 container_mounts=container_mounts or "",
                 container_env_strings=env_string,
                 container_image=container_image,
@@ -534,7 +351,7 @@ class SlurmLauncher:
                     left.remove(slurm_id)
             if update:
                 self.jobs.pop(slurm_info.slurm_id)
-            time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
+            time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)
 
     def _update_all(self):
         """Updates the status of all jobs."""
@@ -550,7 +367,7 @@
 
 
 if __name__ == "__main__":
-    # usage: python -m arealite.launcher.slurm []
+    # usage: python -m arealite.launcher.slurm <entry_point> --config <config_path> [<cli_overrides>]
     config, config_file = parse_cli_args(sys.argv[2:])
 
     config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
@@ -601,43 +418,30 @@
             )
             sglang_cmds.append(sglang_cmd)
 
-        if sglang_cmds:
-            launcher.submit_array(
-                job_name="sglang-server",
-                cmd=sglang_cmds,
-                count=n_sglang_servers,
-                nodes=n_sglang_nodes,
-                n_gpus_per_node=config.cluster.n_gpus_per_node,
-                cpus_per_task=config.launcher.inference_server_cpus_per_gpu
-                * sglang_tp_size,
-                mem_per_task=config.launcher.inference_server_mem_per_gpu
-                * sglang_tp_size,
-                container_image=config.cluster.gpu_infer_image,
-                container_mounts=config.cluster.mount,
-                env_vars=get_env_vars(
-                    config.cluster.cluster_name,
-                    config.launcher.inference_server_env_vars,
-                ),
-            )
-
+    launcher.submit_array(
+        job_name="llm_server",
+        cmd=sglang_cmds,
+        srun_cmd_template=config.launcher.slurm.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
+        count=n_sglang_servers,
+        nodes=n_sglang_nodes,
+        n_gpus_per_node=config.cluster.n_gpus_per_node,
+        cpus_per_task=config.launcher.inference_server_cpus_per_gpu
+        * sglang_tp_size,
+        mem_per_task=config.launcher.inference_server_mem_per_gpu
+        * sglang_tp_size,
+        container_image=config.cluster.gpu_infer_image,
+        container_mounts=config.cluster.mount,
+        env_vars=get_env_vars(
+            config.cluster.cluster_name,
+            config.launcher.inference_server_env_vars,
+        ),
+    )
     # Get SGLang slurm nodes, find the hosts
-    name = names.gen_servers(config.experiment_name, config.trial_name)
-    start = time.perf_counter()
-    while True:
-        sglang_addrs = name_resolve.get_subtree(name)
-        if len(sglang_addrs) >= n_sglang_servers:
-            logger.info(
-                f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
-            )
-            break
-
-        time.sleep(1)
-        if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
-            launcher.stop_all()
-            raise TimeoutError(
-                f"Timeout waiting for SGLang servers to be ready. "
-                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
-            )
+    sglang_addrs = wait_sglang_server_addrs(
+        config.experiment_name,
+        config.trial_name,
+        n_sglang_servers,
+    )
 
     trainer_n_nodes = n_nodes - n_sglang_nodes
     trainer_cmd_template = (
@@ -662,6 +466,7 @@ if __name__ == "__main__":
         launcher.submit_array(
             job_name="trainer",
             cmd=trainer_cmds,
+            srun_cmd_template=config.launcher.slurm.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
            count=trainer_n_nodes,
            nodes=trainer_n_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
diff --git a/arealite/launcher/utils.py b/arealite/launcher/utils.py
new file mode 100644
index 0000000..819338d
--- /dev/null
+++ b/arealite/launcher/utils.py
@@ -0,0 +1,77 @@
+import os
+import time
+import pathlib
+import getpass
+from typing import Dict, Optional
+
+from realhf.base import names, name_resolve, logging
+
+logger = logging.getLogger("Launcher Utils")
+
+LOCAL_CACHE_DIR = "/tmp/arealite"
+PYTORCH_KERNEL_CACHE_PATH = (
+    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
+)
+TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
+os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
+os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
+BASE_ENVIRONS = {
+    "TOKENIZERS_PARALLELISM": "true",
+    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
+    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
+    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
+    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
+    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
+    "PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent),
+}
+NA132_ENVIRONS = {
+    "NCCL_SOCKET_IFNAME": "bond0",
+    "NCCL_NET_PLUGIN": "",
+    "NCCL_IB_GID_INDEX": "3",
+    "NCCL_IB_TIMEOUT": "2",
+    "NCCL_IB_RETRY_CNT": "7",
+    "NCCL_IB_SL": "5",
+    "NCCL_IB_TC": "136",
+    "NCCL_IB_HCA": "mlx5_bond",
+    "NCCL_IB_QPS_PER_CONNECTION": "8",
+    "NCCL_SET_THREAD_NAME": "1",
+    "NCCL_DEBUG": "WARN",
+    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
+}
+SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
+
+
+def get_env_vars(cluster_name: str, additional_env_vars: Optional[str] = None) -> Dict[str, str]:
+    """Returns the environment variables for the cluster."""
+    if additional_env_vars:
+        additional_env_vars = dict(
+            item.split("=") for item in additional_env_vars.split(",")
+        )
+    else:
+        additional_env_vars = {}
+    if cluster_name == "na132":
+        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
+    else:
+        return {**BASE_ENVIRONS, **additional_env_vars}
+
+
+def wait_sglang_server_addrs(
+    experiment_name: str,
+    trial_name: str,
+    n_sglang_servers: int,
+):
+    # Wait until the expected number of SGLang servers have registered their addresses.
+    name = names.gen_servers(experiment_name, trial_name)
+    start = time.perf_counter()
+    while True:
+        sglang_addrs = name_resolve.get_subtree(name)
+        if len(sglang_addrs) >= n_sglang_servers:
+            logger.info(
+                f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
+            )
+            break
+
+        time.sleep(1)
+        if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
+            raise TimeoutError(
+                f"Timeout waiting for SGLang servers to be ready. "
+                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
+            )
+    return sglang_addrs
\ No newline at end of file
diff --git a/arealite/utils/launcher.py b/arealite/utils/launcher.py
new file mode 100644
index 0000000..a6fa44b
--- /dev/null
+++ b/arealite/utils/launcher.py
@@ -0,0 +1,77 @@
+import os
+import time
+import pathlib
+import getpass
+from typing import Dict, Optional
+
+from realhf.base import names, name_resolve, logging
+
+logger = logging.getLogger("Launcher Utils")
+
+LOCAL_CACHE_DIR = "/tmp/arealite"
+PYTORCH_KERNEL_CACHE_PATH = (
+    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
+)
+TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
+os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
+os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
+BASE_ENVIRONS = {
+    "TOKENIZERS_PARALLELISM": "true",
+    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
+    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
+    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
+    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
+    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
+    "PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent),
+}
+NA132_ENVIRONS = {
+    "NCCL_SOCKET_IFNAME": "bond0",
+    "NCCL_NET_PLUGIN": "",
+    "NCCL_IB_GID_INDEX": "3",
+    "NCCL_IB_TIMEOUT": "2",
+    "NCCL_IB_RETRY_CNT": "7",
+    "NCCL_IB_SL": "5",
+    "NCCL_IB_TC": "136",
+    "NCCL_IB_HCA": "mlx5_bond",
+    "NCCL_IB_QPS_PER_CONNECTION": "8",
+    "NCCL_SET_THREAD_NAME": "1",
+    "NCCL_DEBUG": "WARN",
+    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
+}
+SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
+
+
+def get_env_vars(cluster_name: str, additional_env_vars: Optional[str] = None) -> Dict[str, str]:
+    """Returns the environment variables for the cluster."""
+    if additional_env_vars:
+        additional_env_vars = dict(
+            item.split("=") for item in additional_env_vars.split(",")
+        )
+    else:
+        additional_env_vars = {}
+    if cluster_name == "na132":
+        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
+    else:
+        return {**BASE_ENVIRONS, **additional_env_vars}
+
+
+def wait_sglang_server_addrs(
+    experiment_name: str,
+    trial_name: str,
+    n_sglang_servers: int,
+):
+    # Wait until the expected number of SGLang servers have registered their addresses.
+    name = names.gen_servers(experiment_name, trial_name)
+    start = time.perf_counter()
+    while True:
+        sglang_addrs = name_resolve.get_subtree(name)
+        if len(sglang_addrs) >= n_sglang_servers:
+            logger.info(
+                f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
+            )
+            break
+
+        time.sleep(1)
+        if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
+            raise TimeoutError(
+                f"Timeout waiting for SGLang servers to be ready. "
+                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
+            )
+    return sglang_addrs
\ No newline at end of file
diff --git a/arealite/utils/slurm.py b/arealite/utils/slurm.py
new file mode 100644
index 0000000..f86264a
--- /dev/null
+++ b/arealite/utils/slurm.py
@@ -0,0 +1,147 @@
+import subprocess
+from typing import List, Optional, Literal
+
+from realhf.base import logging, name_resolve, names
+from realhf.scheduler.client import JobException, JobInfo, JobState
+
+logger = logging.getLogger("Slurm Utils")
+
+
+SQUEUE_FIELDS = [
+    "JobID",
+    "State",
+    "SubmitTime",
+    "StartTime",
+    "Name",
+    "NodeList",
+    "UserName",
+    "MaxCPUs",
+    "cpus-per-task",
+    "NumTasks",
+    "tres-alloc",
+]
+STATUS_MAPPING = {
+    "RUNNING": JobState.RUNNING,
+    "COMPLETING": JobState.RUNNING,
+    "PENDING": JobState.PENDING,
+    "CANCELLED": JobState.CANCELLED,
+    "FAILED": JobState.FAILED,
+    "COMPLETED": JobState.COMPLETED,
+    "OUT_OF_MEMORY": JobState.FAILED,
+    "DEADLINE": JobState.COMPLETED,
+    "TIMEOUT": JobState.COMPLETED,
+}
+
+SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
+{sbatch_options}
+
+# Getting the node names
+nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+echo nodes=$nodes
+
+nodes_array=($nodes)
+echo node_array=$nodes_array
+
+head_node=${{nodes_array[0]}}
+echo head_node=$head_node
+
+# Getting the head node IP address
+head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
+echo head_node_ip=$head_node_ip
+
+# srun commands
+{srun_cmds}
+
+wait
+"""
+
+# Default srun command template, using singularity as apptainer
+DEFAULT_SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD \\
+    --nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
+    --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
+    singularity exec --no-home --writable-tmpfs --nv --pid \\
+    --bind {container_mounts} \\
+    {container_env_strings} \\
+    {container_image} \\
+    {cmd} &
+"""
+
+
+def cancel_jobs(
+    slurm_names: Optional[List[str]] = None,
+    slurm_ids: Optional[List[int]] = None,
+    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
+):
+    assert (
+        slurm_names is not None or slurm_ids is not None
+    ), "Must specify slurm_names or slurm_ids."
+    assert not (
+        slurm_names and slurm_ids
+    ), "Cannot specify both slurm_names and slurm_ids."
+    cmd = ["scancel", "-s", signal]
+    if slurm_names is not None:
+        cmd += ["-n", ",".join(slurm_names)]
+    elif slurm_ids is not None:
+        cmd += [",".join(str(s) for s in slurm_ids)]
+    subprocess.check_call(cmd)
+    logger.info(
+        f"Cancelled Slurm job with signal {signal}: "
+        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
+    )
+
+
+def query_jobs(
+    slurm_names: Optional[List[str]] = None,
+    slurm_ids: Optional[List[int]] = None,
+    status: str = "all",
+    delimiter: str = "__PSI__",
+) -> List[JobInfo]:
+    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
+    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
+    if slurm_names is not None:
+        cmd += ["-n", ",".join(slurm_names)]
+    if slurm_ids is not None:
+        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
+
+    output = (
+        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
+    )
+    rs = []
+    for line in output.split("\n")[1:]:
+        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
+            delimiter
+        )
+        rs.append(
+            JobInfo(
+                name=slurm_name,
+                state=STATUS_MAPPING[state],
+                host=nodelist,
+                submit_time=submit_time,
+                start_time=start_time,
+                slurm_id=int(job_id.strip()),
+            )
+        )
+    return rs
+
+
+def parse_slurm_nodelist(nodelist: str) -> List[str]:
+    return (
+        subprocess.check_output(
+            [
+                "scontrol",
+                "show",
+                "hostnames",
+                nodelist,
+            ]
+        )
+        .decode("utf-8")
+        .strip()
+        .split("\n")
+    )
+
+
+def get_slurm_host_ip(node: str):
+    try:
+        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
+        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
+    except subprocess.CalledProcessError:
+        logger.warning(f"Get slurm host ip for node {node} failed.")
diff --git a/examples/arealite/boba.py b/examples/arealite/boba.py
index a916772..d4639b2 100644
--- a/examples/arealite/boba.py
+++ b/examples/arealite/boba.py
@@ -106,8 +106,8 @@ def boba_reward_fn(
     return label
 
 
-def main_grpo():
-    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
+def main_grpo(argv):
+    config, _ = load_expr_config(argv, GRPOConfig)
     config: GRPOConfig
 
     rank = int(os.getenv("RANK"))
@@ -238,4 +238,4 @@ def main_grpo():
 
 
 if __name__ == "__main__":
-    main_grpo()
+    main_grpo(sys.argv[1:])
diff --git a/examples/arealite/gsm8k_grpo.py b/examples/arealite/gsm8k_grpo.py
index 12b9491..37ace39 100644
--- a/examples/arealite/gsm8k_grpo.py
+++ b/examples/arealite/gsm8k_grpo.py
@@ -75,8 +75,8 @@ def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **k
     return int(sol.strip() == ans.strip())
 
 
-def main_grpo():
-    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
+def main_grpo(argv):
+    config, _ = load_expr_config(argv, GRPOConfig)
     config: GRPOConfig
 
     rank = int(os.getenv("RANK"))
@@ -250,4 +250,4 @@ def main_grpo():
 
 
 if __name__ == "__main__":
-    main_grpo()
+    main_grpo(sys.argv[1:])
diff --git a/examples/arealite/gsm8k_sft.py b/examples/arealite/gsm8k_sft.py
index b7ca34a..23f40bd 100644
--- a/examples/arealite/gsm8k_sft.py
+++ b/examples/arealite/gsm8k_sft.py
@@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
     return process_gsm8k_sft_dataset(dataset, tokenizer)
 
 
-def main_sft():
-    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
+def main_sft(argv):
+    config, _ = load_expr_config(argv, SFTConfig)
    config: SFTConfig
 
     rank = int(os.getenv("RANK"))
@@ -120,4 +120,4 @@ def main_sft():
 
 
 if __name__ == "__main__":
-    main_sft()
+    main_sft(sys.argv[1:])