ray launcher before test

This commit is contained in:
晓雷 2025-07-15 14:22:45 +08:00
parent 29172e0e10
commit 44947fe7fa
10 changed files with 736 additions and 245 deletions

View File

@ -624,6 +624,23 @@ class DatasetConfig:
drop_last: bool = field(default=True)
@dataclass
class RayLauncherConfig:
"""Configuration for launching the SGLang server with Ray."""
main_func_name: str = field(
default="main",
metadata={"help": "Name of the main function in the entrypoint file to run in Ray workers."},
)
@dataclass
class SlurmLauncherConfig:
"""Configuration for launching the SGLang server with Slurm."""
srun_cmd_template: Optional[str] = field(
default=None,
metadata={"help": "Template for the srun command to launch the SGLang server."},
)
@dataclass
class LauncherConfig:
"""Configuration for launching the SGLang server."""
@ -662,6 +679,10 @@ class LauncherConfig:
default=27015,
metadata={"help": "Trainer port used for torch.distributed initialization."},
)
ray: RayLauncherConfig = field(
default_factory=RayLauncherConfig,
metadata={"help": "Ray launcher configuration."},
)
@dataclass
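For reference, a minimal self-contained sketch of how the new nested option reaches the launcher, assuming parse_cli_args follows the usual OmegaConf dotted-override convention; the trimmed dataclasses below keep only the new field:

    from dataclasses import dataclass, field
    from omegaconf import OmegaConf

    @dataclass
    class RayLauncherConfig:
        main_func_name: str = "main"

    @dataclass
    class LauncherConfig:
        ray: RayLauncherConfig = field(default_factory=RayLauncherConfig)

    # A CLI override such as `launcher.ray.main_func_name=main_grpo` merges in as:
    base = OmegaConf.structured(LauncherConfig)
    cli = OmegaConf.from_dotlist(["ray.main_func_name=main_grpo"])
    cfg = OmegaConf.merge(base, cli)
    print(cfg.ray.main_func_name)  # -> main_grpo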

View File

@ -0,0 +1,353 @@
import os
import sys
import time
import getpass
import pathlib
from typing import List, Optional, Dict
import importlib.util
import ray
import ray.exceptions
from ray.runtime_env import RuntimeEnv
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.placement_group import PlacementGroup
from ray.job_submission import JobStatus as RayJobStatus
import realhf.base.logging as logging
from arealite.api.cli_args import (
ClusterSpecConfig,
LauncherConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
get_env_vars,
wait_sglang_server_addrs
)
from realhf.base import logging, name_resolve
from realhf.scheduler.client import JobException
logger = logging.getLogger("RayLauncher")
RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds
def run_func(file_path, function_name, *args, **kwargs):
# Convert the file path to a module name
module_name = file_path.replace('/', '_').replace('.', '_')
# Load the module from file path
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
# Get the function and execute it
function = getattr(module, function_name)
return function(*args, **kwargs)
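As a standalone illustration of the dynamic-import pattern run_func relies on (the temporary file and its main function are made up for the example):

    import importlib.util
    import pathlib
    import sys
    import tempfile

    # Write a throwaway entry point; the file and its `main` are illustrative only.
    entry = pathlib.Path(tempfile.mkdtemp()) / "entry.py"
    entry.write_text("def main(name):\n    return f'hello, {name}'\n")

    module_name = str(entry).replace("/", "_").replace(".", "_")
    spec = importlib.util.spec_from_file_location(module_name, str(entry))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    print(getattr(module, "main")("world"))  # -> hello, world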
class RayLauncher:
def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
self.experiment_name = experiment_name
self.trial_name = trial_name
self.fileroot = fileroot
# job_name to ray future
self.jobs = {}
@property
def run_name(self):
return f"{self.experiment_name}_{self.trial_name}"
def log_path_of(self, job_name: str) -> str:
log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
os.makedirs(log_path, exist_ok=True)
return os.path.join(log_path, f"{job_name}.log")
def submit(
self,
job_name: str,
file_path: str,
func_name: str,
gpus: int,
cpus: int,
mem: int, # MB
env_vars: Optional[Dict] = None,
placement_group: Optional[PlacementGroup] = None,
bundle_index: Optional[int] = None,
*args, # arguments to pass to the function
**kwargs, # keyword arguments to pass to the function
):
runtime_env = RuntimeEnv(
env_vars=env_vars or dict(),
)
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
bundle_index=bundle_index,
) if placement_group is not None else "DEFAULT"
future = ray.remote(
num_cpus=cpus,
num_gpus=gpus,
memory=mem*1024*1024, # Convert MB to bytes
runtime_env=runtime_env,
scheduling_strategy=scheduling_strategy
)(run_func).remote(
file_path, func_name, *args, **kwargs
)
self.jobs[job_name] = future
return future
def submit_array(
self,
job_name: str,
file_path: str,
func_name: str,
count: int,
nodes: int,
list_args: List[List],
gpus_per_task: int,
cpus_per_task: int,
mem_per_task: int, # MB
list_kwargs: List[Dict] | None = None,
env_vars: Optional[Dict] = None,
amend_torch_dist_env: bool = False,
):
"""Submit an array of jobs to Ray using Ray placement groups."""
if count % nodes != 0:
raise ValueError(
f"Count {count} is not divisible by nodes {nodes}. "
"Please ensure that count is a multiple of nodes."
)
assert len(list_args) == count, (
f"Length of list_args {len(list_args)} does not match count {count}."
)
if list_kwargs is not None:
assert len(list_kwargs) == count, (
f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."
)
tasks_per_node = count // nodes
gpus_per_node = gpus_per_task * tasks_per_node
cpus_per_node = cpus_per_task * tasks_per_node
mem_per_node = mem_per_task * tasks_per_node
placement_group = ray.util.placement_group(
bundles=[
{
"CPU": cpus_per_node,
"GPU": gpus_per_node,
"memory": mem_per_node * 1024 * 1024, # Convert MB to bytes
}
] * nodes,
strategy="STRICT_SPREAD",
)
ray.get(placement_group.ready())
futures = []
for i in range(count):
args = list_args[i]
kwargs = list_kwargs[i] if list_kwargs is not None else {}
# manage environment variables
env_vars = env_vars or {}
assert (
"CUDA_VISIBLE_DEVICES" not in env_vars
), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."
gpu_id_start = (i % tasks_per_node) * gpus_per_task
gpu_id_end = ((i % tasks_per_node) + 1) * gpus_per_task
node_id = i // tasks_per_node
_env_vars = {
**env_vars,
"CUDA_VISIBLE_DEVICES": ",".join(
str(x) for x in range(gpu_id_start, gpu_id_end)
),
}
if amend_torch_dist_env:
assert gpus_per_task == 1
_env_vars.update({
"RANK": str(i),
"WORLD_SIZE": str(count),
"LOCAL_RANK": str(i % tasks_per_node),
})
# Pass the fixed parameters positionally so *args can follow without
# clashing with keyword arguments of submit().
future = self.submit(
f"{job_name}:{i}",
file_path,
func_name,
gpus_per_task,
cpus_per_task,
mem_per_task,
_env_vars,
placement_group,
node_id,
*args,
**kwargs,
)
futures.append(future)
return futures
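For concreteness, a worked example of the per-task placement arithmetic in the loop above; the counts are hypothetical:

    count, nodes, gpus_per_task = 4, 2, 2
    tasks_per_node = count // nodes  # 2 tasks per node
    for i in range(count):
        node_id = i // tasks_per_node
        gpu_start = (i % tasks_per_node) * gpus_per_task
        gpus = ",".join(str(g) for g in range(gpu_start, gpu_start + gpus_per_task))
        print(f"task {i}: bundle_index={node_id}, CUDA_VISIBLE_DEVICES={gpus}")
    # task 0: bundle_index=0, CUDA_VISIBLE_DEVICES=0,1
    # task 1: bundle_index=0, CUDA_VISIBLE_DEVICES=2,3
    # task 2: bundle_index=1, CUDA_VISIBLE_DEVICES=0,1
    # task 3: bundle_index=1, CUDA_VISIBLE_DEVICES=2,3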
def stop(self, job_name: str, force: bool = False):
""" Stop a job by name. """
if job_name in self.jobs:
future = self.jobs[job_name]
try:
ray.cancel(future, force=force)
except Exception as e:
logger.error(f"Failed to cancel job {job_name}: {e}")
return
self.jobs.pop(job_name, None)
logger.info(f"Job {job_name} stopped.")
else:
logger.warning(f"Job {job_name} not found in running jobs.")
def stop_all(self, force: bool = False):
""" Stop all jobs. """
for job_name in list(self.jobs.keys()):
self.stop(job_name, force=force)
logger.info("All jobs stopped.")
self.jobs.clear()
def wait(self):
"""Check the status of all jobs every RAY_WAIT_CHECK_TIME_INTERVAL seconds.
If any job fails, terminate all jobs and return.
If a job completes, remove it from self.jobs.
Return once all jobs have completed.
"""
while self.jobs:
completed_jobs = []
for job_name, future in list(self.jobs.items()):
try:
r = ray.get(future, timeout=0.1)
logger.info(f"Job {job_name} completed with result: {r}")
completed_jobs.append(job_name)
except ray.exceptions.GetTimeoutError:
continue
except ray.exceptions.RayTaskError as e:
logger.error(f"Job {job_name} failed with error: {e}, stopping all jobs.")
self.stop_all(force=True)
return
for job_name in completed_jobs:
self.jobs.pop(job_name, None)
logger.info(f"Job {job_name} completed. Removed.")
time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)
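A hedged usage sketch of RayLauncher outside this launcher script; the entry point, argv, and resource figures are placeholders, and each element of list_args holds the positional arguments for one task:

    import ray
    from arealite.launcher.ray import RayLauncher

    ray.init()
    launcher = RayLauncher("my_exp", "trial0", fileroot="/tmp/arealite")
    argv = ["--config", "configs/gsm8k_grpo.yaml"]  # hypothetical config path
    futures = launcher.submit_array(
        job_name="trainer",
        file_path="examples/arealite/gsm8k_grpo.py",  # assumed entry point
        func_name="main_grpo",
        count=8,
        nodes=1,
        # The argv list is itself wrapped in a list: one positional argument per task.
        list_args=[[argv] for _ in range(8)],
        gpus_per_task=1,
        cpus_per_task=4,
        mem_per_task=32 * 1024,  # MB
        amend_torch_dist_env=True,  # sets RANK/WORLD_SIZE/LOCAL_RANK per task
    )
    launcher.wait()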
if __name__ == "__main__":
# usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>] launcher.ray.main_func_name=<main_func_name_in_entry_point>
ray.init()
config, config_file = parse_cli_args(sys.argv[2:])
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
name_resolve.reconfigure(config.cluster.name_resolve)
n_nodes = config.n_nodes
n_gpus_per_node = config.n_gpus_per_node
if n_gpus_per_node < config.cluster.n_gpus_per_node:
raise ValueError(
f"Ray launcher requires at least {config.cluster.n_gpus_per_node} GPUs per node (the cluster's #GPUs per node). For use cases with fewer GPUs, use LocalLauncher instead."
)
elif n_gpus_per_node > config.cluster.n_gpus_per_node:
raise ValueError(
f"#GPUs per node required by the experiment ({n_gpus_per_node}) exceeds the #GPUs per node available in the cluster ({config.cluster.n_gpus_per_node})."
)
launcher = RayLauncher(
experiment_name=config.experiment_name,
trial_name=config.trial_name,
fileroot=config.cluster.fileroot,
)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
sglang_addrs = []
n_sglang_nodes = 0
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
# Launcher should launch SGLang servers according to allocation mode.
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
assert (
allocation_mode.gen_pp_size == 1
), "Pipeline generation in SGLang is not supported for now."
assert (
allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
), "Currently only SGLang TP size <= #GPUs per node is supported."
sglang_world_size = allocation_mode.gen_world_size
sglang_tp_size = allocation_mode.gen_tp_size
n_sglang_servers = allocation_mode.gen_dp_size
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
base_seed = config.sglang.random_seed
# Each task receives a single positional argument: its argv list.
sglang_args_list = [
[sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]]
for i in range(n_sglang_servers)
]
sglang_entry_point = str(pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py"))
sglang_main_func_name = "main_sglang_server"
launcher.submit_array(
job_name="llm_server",
file_path=sglang_entry_point,
func_name=sglang_main_func_name,
count=n_sglang_servers,
nodes=n_sglang_nodes,
list_args=sglang_args_list,
gpus_per_task=sglang_tp_size,
cpus_per_task=config.launcher.inference_server_cpus_per_gpu * sglang_tp_size,
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
env_vars=get_env_vars(
config.cluster.cluster_name,
config.launcher.inference_server_env_vars,
),
)
# Wait for the SGLang servers to register their addresses.
sglang_addrs = wait_sglang_server_addrs(
config.experiment_name,
config.trial_name,
n_sglang_servers,
)
trainer_n_nodes = n_nodes - n_sglang_nodes
trainer_entry_point = sys.argv[1]
trainer_main_func_name = config.launcher.ray.main_func_name
n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
# As above, wrap each argv list so it is passed as one positional argument.
trainer_args_list = [
[sys.argv[2:]] for _ in range(n_trainer_processes)
]
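A worked example of the node accounting above, with hypothetical numbers:

    n_nodes, n_gpus_per_node = 2, 8
    gen_dp, gen_pp, gen_tp = 4, 1, 2           # decoupled SGLang generation parallelism
    gen_world_size = gen_dp * gen_pp * gen_tp  # 8 GPUs dedicated to SGLang
    n_sglang_servers = gen_dp                  # 4 servers, one per TP group
    n_sglang_nodes = gen_world_size // n_gpus_per_node   # 1 full node for inference
    trainer_n_nodes = n_nodes - n_sglang_nodes            # 1 node left for training
    n_trainer_processes = trainer_n_nodes * n_gpus_per_node  # 8 single-GPU trainer ranks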
if not config.server_only:
# launch trainers
launcher.submit_array(
job_name="trainer",
file_path=trainer_entry_point,
func_name=trainer_main_func_name,
count=trainer_n_nodes * config.cluster.n_gpus_per_node,
nodes=trainer_n_nodes,
list_args=trainer_args_list,
gpus_per_task=1,
cpus_per_task=config.launcher.trainer_cpus_per_gpu,
mem_per_task=config.launcher.trainer_mem_per_gpu,
env_vars=dict(
**get_env_vars(
config.cluster.cluster_name,
config.launcher.trainer_env_vars,
),
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
),
amend_torch_dist_env=True,
)
try:
launcher.wait()
except (KeyboardInterrupt, JobException, TimeoutError) as e:
launcher.stop_all(force=True)
raise e

View File

@ -147,9 +147,16 @@ class SGLangServerWrapper:
f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
)
def __del__(self):
if self.server_process and self.server_process.poll() is None:
logger.info("Terminating SGLang server process...")
self.server_process.terminate()
self.server_process.wait()
logger.info("SGLang server process terminated.")
if __name__ == "__main__":
config, _ = parse_cli_args(sys.argv[2:])
def main_sglang_server(argv):
config, _ = parse_cli_args(argv)
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
config.cluster.name_resolve = to_structured_cfg(
config.cluster.name_resolve, NameResolveConfig
@ -169,3 +176,7 @@ if __name__ == "__main__":
n_gpus_per_node=config.n_gpus_per_node,
)
sglang_server.run()
if __name__ == "__main__":
main_sglang_server(sys.argv[1:])
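For illustration, this refactor lets the Ray launcher import the file and call the function in-process instead of spawning a subprocess; a hedged sketch with placeholder paths:

    # Hypothetical direct invocation, mirroring what run_func does in arealite/launcher/ray.py.
    from arealite.launcher.ray import run_func

    run_func(
        "arealite/launcher/sglang_server.py",  # assumed location of this file
        "main_sglang_server",
        ["--config", "configs/sglang.yaml"],   # argv forwarded to parse_cli_args
    )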

View File

@ -8,8 +8,6 @@ import sys
import time
from typing import Dict, List, Literal, Optional, Tuple
from omegaconf import OmegaConf
import realhf.base.logging as logging
from arealite.api.cli_args import (
BaseExperimentConfig,
@ -20,197 +18,24 @@ from arealite.api.cli_args import (
to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
get_env_vars,
wait_sglang_server_addrs,
)
from arealite.utils.slurm import (
cancel_jobs,
query_jobs,
SBATCH_SCRIPT_TEMPLATE,
DEFAULT_SRUN_CMD_TEMPLATE
)
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("SlurmLauncher")
SQUEUE_FIELDS = [
"JobID",
"State",
"SubmitTime",
"StartTime",
"Name",
"NodeList",
"UserName",
"MaxCPUs",
"cpus-per-task",
"NumTasks",
"tres-alloc",
]
STATUS_MAPPING = {
"RUNNING": JobState.RUNNING,
"COMPLETING": JobState.RUNNING,
"PENDING": JobState.PENDING,
"CANCELLED": JobState.CANCELLED,
"FAILED": JobState.FAILED,
"COMPLETED": JobState.COMPLETED,
"OUT_OF_MEMORY": JobState.FAILED,
"DEADLINE": JobState.COMPLETED,
"TIMEOUT": JobState.COMPLETED,
}
SLURM_WAIT_CHECK_TIME_INTERVAL = 5
def cancel_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
assert (
slurm_names is not None or slurm_ids is not None
), "Must specify slurm_names or slurm_ids."
assert not (
slurm_names and slurm_ids
), "Cannot specify both slurm_names and slurm_ids."
cmd = ["scancel", "-s", signal]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
elif slurm_ids is not None:
cmd += [",".join(str(s) for s in slurm_ids)]
subprocess.check_call(cmd)
logger.info(
f"Cancelled Slurm job with signal {signal}: "
f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
)
def query_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
status: str = "all",
delimiter: str = "__PSI__",
) -> List[JobInfo]:
squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
if slurm_ids is not None:
cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
output = (
subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
)
rs = []
for line in output.split("\n")[1:]:
job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
delimiter
)
rs.append(
JobInfo(
name=slurm_name,
state=STATUS_MAPPING[state],
host=nodelist,
submit_time=submit_time,
start_time=start_time,
slurm_id=int(job_id.strip()),
)
)
return rs
def parse_slurm_nodelist(nodelist: str) -> List[str]:
return (
subprocess.check_output(
[
"scontrol",
"show",
"hostnames",
nodelist,
]
)
.decode("utf-8")
.strip()
.split("\n")
)
def get_slurm_host_ip(node: str):
try:
cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
except subprocess.CalledProcessError:
logger.warning(f"Get slurm host ip for node {node} failed.")
SGLANG_SERVER_TIMEOUT_SECONDS = 180
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}
# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes
nodes_array=($nodes)
echo node_array=$nodes_array
head_node=${{nodes_array[0]}}
echo head_node=$head_node
# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip
# srun commands
{srun_cmds}
wait
"""
SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodelist=${{nodes_array[{node_id}]}} \\
--nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
--mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
"HF_ENDPOINT": "https://hf-mirror.com", # FIXME: move to user option
"PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
"NCCL_IB_GID_INDEX": "3",
"NCCL_IB_TIMEOUT": "2",
"NCCL_IB_RETRY_CNT": "7",
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
"""Returns the environment variables for the cluster."""
additional_env_vars = {}
if additional_env_vars:
additional_env_vars = dict(
item.split("=") for item in additional_env_vars.split(",")
)
if cluster_name == "na132":
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
else:
return {**BASE_ENVIRONS, **additional_env_vars}
class SlurmLauncher:
@ -261,6 +86,7 @@ class SlurmLauncher:
self,
job_name: str,
cmd: List[str] | str,
srun_cmd_template: str,
count: int,
nodes: int,
n_gpus_per_node: int,
@ -271,13 +97,6 @@ class SlurmLauncher:
env_vars: Optional[Dict] = None,
nodelist: Optional[str] = None,
exclude: Optional[str] = None,
apptainer_name: Optional[str] = "singularity",
apptainer_options: Optional[Tuple[str, ...]] = (
"--no-home",
"--writable-tmpfs",
"--nv",
"--pid",
),
):
"""Submits and launches a job array with SBATCH.
Note that a job array has one (unique) slurm name, and one (unique) slurm id.
@ -359,15 +178,13 @@ class SlurmLauncher:
# FIXME: only for debugging, remove and replace new image
# job_cmd = f'bash -c "pip3 install -r requirements.txt; {job_cmd}"'
srun_cmd = SRUN_CMD_TEMPLATE.format(
srun_cmd = srun_cmd_template.format(
nodes=1,
ntasks=1,
node_id=node_id,
n_gpus_per_node=n_gpus_per_node,
cpus_per_task=cpus_per_task,
mem_per_cpu=mem_per_cpu,
apptainer_name=apptainer_name,
apptainer_options=" ".join(apptainer_options),
container_mounts=container_mounts or "",
container_env_strings=env_string,
container_image=container_image,
@ -534,7 +351,7 @@ class SlurmLauncher:
left.remove(slurm_id)
if update:
self.jobs.pop(slurm_info.slurm_id)
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)
def _update_all(self):
"""Updates the status of all jobs."""
@ -550,7 +367,7 @@ class SlurmLauncher:
if __name__ == "__main__":
# usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
# usage: python -m arealite.launcher.slurm <entry_point> --config <config_path> [<additional_args>]
config, config_file = parse_cli_args(sys.argv[2:])
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
@ -601,43 +418,30 @@ if __name__ == "__main__":
)
sglang_cmds.append(sglang_cmd)
if sglang_cmds:
launcher.submit_array(
job_name="sglang-server",
cmd=sglang_cmds,
count=n_sglang_servers,
nodes=n_sglang_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
* sglang_tp_size,
mem_per_task=config.launcher.inference_server_mem_per_gpu
* sglang_tp_size,
container_image=config.cluster.gpu_infer_image,
container_mounts=config.cluster.mount,
env_vars=get_env_vars(
config.cluster.cluster_name,
config.launcher.inference_server_env_vars,
),
)
launcher.submit_array(
job_name="llm_server",
cmd=sglang_cmds,
srun_cmd_template=config.launcher.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
count=n_sglang_servers,
nodes=n_sglang_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,
cpus_per_task=config.launcher.inference_server_cpus_per_gpu
* sglang_tp_size,
mem_per_task=config.launcher.inference_server_mem_per_gpu
* sglang_tp_size,
container_image=config.cluster.gpu_infer_image,
container_mounts=config.cluster.mount,
env_vars=get_env_vars(
config.cluster.cluster_name,
config.launcher.inference_server_env_vars,
),
)
# Get SGLang slurm nodes, find the hosts
name = names.gen_servers(config.experiment_name, config.trial_name)
start = time.perf_counter()
while True:
sglang_addrs = name_resolve.get_subtree(name)
if len(sglang_addrs) >= n_sglang_servers:
logger.info(
f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
)
break
time.sleep(1)
if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
launcher.stop_all()
raise TimeoutError(
f"Timeout waiting for SGLang servers to be ready. "
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
)
sglang_addrs = wait_sglang_server_addrs(
config.experiment_name,
config.trial_name,
n_sglang_servers,
)
trainer_n_nodes = n_nodes - n_sglang_nodes
trainer_cmd_template = (
@ -662,6 +466,7 @@ if __name__ == "__main__":
launcher.submit_array(
job_name="trainer",
cmd=trainer_cmds,
srun_cmd_template=config.launcher.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
count=trainer_n_nodes,
nodes=trainer_n_nodes,
n_gpus_per_node=config.cluster.n_gpus_per_node,

View File

@ -0,0 +1,77 @@
import os
import time
import pathlib
import getpass
from typing import Dict, Optional
from realhf.base import names, name_resolve, logging
logger = logging.getLogger("Launcher Utils")
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
"HF_ENDPOINT": "https://hf-mirror.com", # FIXME: move to user option
"PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
"NCCL_IB_GID_INDEX": "3",
"NCCL_IB_TIMEOUT": "2",
"NCCL_IB_RETRY_CNT": "7",
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
"""Returns the environment variables for the cluster."""
# Parse optional comma-separated KEY=VALUE overrides into a dict.
if additional_env_vars:
additional_env_vars = dict(
item.split("=") for item in additional_env_vars.split(",")
)
else:
additional_env_vars = {}
if cluster_name == "na132":
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
else:
return {**BASE_ENVIRONS, **additional_env_vars}
def wait_sglang_server_addrs(
experiment_name: str,
trial_name: str,
n_sglang_servers: int,
):
# Wait until all SGLang servers have registered their addresses via name_resolve.
name = names.gen_servers(experiment_name, trial_name)
start = time.perf_counter()
while True:
sglang_addrs = name_resolve.get_subtree(name)
if len(sglang_addrs) >= n_sglang_servers:
logger.info(
f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
)
break
time.sleep(1)
if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
raise TimeoutError(
f"Timeout waiting for SGLang servers to be ready. "
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
)
return sglang_addrs
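A small usage sketch of get_env_vars; the extra variable MY_EXTRA_FLAG is illustrative:

    from arealite.utils.launcher import get_env_vars

    env = get_env_vars("na132", "NCCL_DEBUG=INFO,MY_EXTRA_FLAG=1")
    # BASE_ENVIRONS is applied first, then the cluster-specific NA132_ENVIRONS,
    # then the comma-separated user overrides.
    assert env["NCCL_DEBUG"] == "INFO"              # user override wins
    assert env["TOKENIZERS_PARALLELISM"] == "true"  # from BASE_ENVIRONS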

View File

@ -0,0 +1,77 @@
import os
import time
import pathlib
import getpass
from typing import Dict
from realhf.base import names, name_resolve, logging
logger = logging.getLogger("Launcher Utils")
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
"HF_ENDPOINT": "https://hf-mirror.com", # FIXME: move to user option
"PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
"NCCL_IB_GID_INDEX": "3",
"NCCL_IB_TIMEOUT": "2",
"NCCL_IB_RETRY_CNT": "7",
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
"""Returns the environment variables for the cluster."""
# Parse optional comma-separated KEY=VALUE overrides into a dict.
if additional_env_vars:
additional_env_vars = dict(
item.split("=") for item in additional_env_vars.split(",")
)
else:
additional_env_vars = {}
if cluster_name == "na132":
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
else:
return {**BASE_ENVIRONS, **additional_env_vars}
def wait_sglang_server_addrs(
experiment_name: str,
trial_name: str,
n_sglang_servers: int,
):
# Wait until all SGLang servers have registered their addresses via name_resolve.
name = names.gen_servers(experiment_name, trial_name)
start = time.perf_counter()
while True:
sglang_addrs = name_resolve.get_subtree(name)
if len(sglang_addrs) >= n_sglang_servers:
logger.info(
f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
)
break
time.sleep(1)
if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
raise TimeoutError(
f"Timeout waiting for SGLang servers to be ready. "
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
)
return sglang_addrs

147
arealite/utils/slurm.py Normal file
View File

@ -0,0 +1,147 @@
import subprocess
from typing import List, Optional, Literal
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("Slurm Utils")
SQUEUE_FIELDS = [
"JobID",
"State",
"SubmitTime",
"StartTime",
"Name",
"NodeList",
"UserName",
"MaxCPUs",
"cpus-per-task",
"NumTasks",
"tres-alloc",
]
STATUS_MAPPING = {
"RUNNING": JobState.RUNNING,
"COMPLETING": JobState.RUNNING,
"PENDING": JobState.PENDING,
"CANCELLED": JobState.CANCELLED,
"FAILED": JobState.FAILED,
"COMPLETED": JobState.COMPLETED,
"OUT_OF_MEMORY": JobState.FAILED,
"DEADLINE": JobState.COMPLETED,
"TIMEOUT": JobState.COMPLETED,
}
SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}
# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes
nodes_array=($nodes)
echo node_array=$nodes_array
head_node=${{nodes_array[0]}}
echo head_node=$head_node
# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip
# srun commands
{srun_cmds}
wait
"""
# Default srun command template, using `singularity` as the Apptainer/Singularity executable.
DEFAULT_SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD \\
--nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
--gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
singularity exec --no-home --writable-tmpfs --nv --pid \\
--bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""
def cancel_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
assert (
slurm_names is not None or slurm_ids is not None
), "Must specify slurm_names or slurm_ids."
assert not (
slurm_names and slurm_ids
), "Cannot specify both slurm_names and slurm_ids."
cmd = ["scancel", "-s", signal]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
elif slurm_ids is not None:
cmd += [",".join(str(s) for s in slurm_ids)]
subprocess.check_call(cmd)
logger.info(
f"Cancelled Slurm job with signal {signal}: "
f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
)
def query_jobs(
slurm_names: Optional[List[str]] = None,
slurm_ids: Optional[List[int]] = None,
status: str = "all",
delimiter: str = "__PSI__",
) -> List[JobInfo]:
squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
if slurm_names is not None:
cmd += ["-n", ",".join(slurm_names)]
if slurm_ids is not None:
cmd += ["-j", ",".join([str(s) for s in slurm_ids])]
output = (
subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
)
rs = []
for line in output.split("\n")[1:]:
job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
delimiter
)
rs.append(
JobInfo(
name=slurm_name,
state=STATUS_MAPPING[state],
host=nodelist,
submit_time=submit_time,
start_time=start_time,
slurm_id=int(job_id.strip()),
)
)
return rs
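A hedged usage sketch of the two helpers above; the job name is a placeholder and both calls shell out to squeue/scancel:

    from arealite.utils.slurm import query_jobs, cancel_jobs

    infos = query_jobs(slurm_names=["my_exp_trial0:trainer"])
    for info in infos:
        print(info.name, info.state, info.host)
    cancel_jobs(slurm_names=["my_exp_trial0:trainer"], signal="SIGINT")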
def parse_slurm_nodelist(nodelist: str) -> List[str]:
return (
subprocess.check_output(
[
"scontrol",
"show",
"hostnames",
nodelist,
]
)
.decode("utf-8")
.strip()
.split("\n")
)
def get_slurm_host_ip(node: str):
try:
cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
except subprocess.CalledProcessError:
logger.warning(f"Failed to get Slurm host IP for node {node}.")
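And a hedged sketch of filling DEFAULT_SRUN_CMD_TEMPLATE; every value is a placeholder, and the --env string assumes Apptainer/Singularity's exec --env syntax:

    from arealite.utils.slurm import DEFAULT_SRUN_CMD_TEMPLATE

    srun_cmd = DEFAULT_SRUN_CMD_TEMPLATE.format(
        node_id=0,
        nodes=1,
        ntasks=1,
        n_gpus_per_node=8,
        cpus_per_task=16,
        mem_per_cpu=4096,  # MB
        container_mounts="/storage:/storage",
        container_env_strings="--env TOKENIZERS_PARALLELISM=true",
        container_image="/images/arealite.sif",
        cmd="python3 -m arealite.launcher.sglang_server --config cfg.yaml",
    )
    print(srun_cmd)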

View File

@ -106,8 +106,8 @@ def boba_reward_fn(
return label
def main_grpo():
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
def main_grpo(argv):
config, _ = load_expr_config(argv, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
@ -238,4 +238,4 @@ def main_grpo():
if __name__ == "__main__":
main_grpo()
main_grpo(sys.argv[1:])

View File

@ -75,8 +75,8 @@ def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **k
return int(sol.strip() == ans.strip())
def main_grpo():
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
def main_grpo(argv):
config, _ = load_expr_config(argv, GRPOConfig)
config: GRPOConfig
rank = int(os.getenv("RANK"))
@ -250,4 +250,4 @@ def main_grpo():
if __name__ == "__main__":
main_grpo()
main_grpo(sys.argv[1:])

View File

@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
return process_gsm8k_sft_dataset(dataset, tokenizer)
def main_sft():
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
def main_sft(argv):
config, _ = load_expr_config(argv, SFTConfig)
config: SFTConfig
rank = int(os.getenv("RANK"))
@ -120,4 +120,4 @@ def main_sft():
if __name__ == "__main__":
main_sft()
main_sft(sys.argv[1:])