mirror of https://github.com/inclusionAI/AReaL
ray launcher before test
This commit is contained in:
parent
29172e0e10
commit
44947fe7fa
@ -624,6 +624,23 @@ class DatasetConfig:
    drop_last: bool = field(default=True)


@dataclass
class RayLauncherConfig:
    """Configuration for launching the SGLang server with Ray."""

    main_func_name: str = field(
        default="main",
        metadata={"help": "Name of the main function in the entrypoint file to run in Ray workers."},
    )


@dataclass
class SlurmLauncherConfig:
    """Configuration for launching the SGLang server with Slurm."""

    srun_cmd_template: Optional[str] = field(
        default=None,
        metadata={"help": "Template for the srun command to launch the SGLang server."},
    )


@dataclass
class LauncherConfig:
    """Configuration for launching the SGLang server."""
@ -662,6 +679,10 @@ class LauncherConfig:
        default=27015,
        metadata={"help": "Trainer port used for torch.distributed initialization."},
    )
    ray: RayLauncherConfig = field(
        default_factory=RayLauncherConfig,
        metadata={"help": "Ray launcher configuration."},
    )


@dataclass
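For context, the dataclass-based launcher config above is meant to be overridable from dotted CLI keys such as launcher.ray.main_func_name. The snippet below is a minimal sketch of that idea using OmegaConf (which this codebase already imports elsewhere); it is illustrative only and not part of this commit, and the real entry points are parse_cli_args / to_structured_cfg.

# Sketch only: dotted-key override of a structured dataclass config.
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class RayLauncherConfig:
    main_func_name: str = field(default="main")

@dataclass
class LauncherConfig:
    ray: RayLauncherConfig = field(default_factory=RayLauncherConfig)

base = OmegaConf.structured(LauncherConfig)
cli = OmegaConf.from_dotlist(["ray.main_func_name=main_grpo"])
cfg = OmegaConf.merge(base, cli)
assert cfg.ray.main_func_name == "main_grpo"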
@ -0,0 +1,353 @@
import os
import sys
import time
import getpass
import pathlib
from typing import List, Optional, Dict
import importlib.util

import ray
import ray.exceptions
from ray.runtime_env import RuntimeEnv
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.placement_group import PlacementGroup
from ray.job_submission import JobStatus as RayJobStatus

import realhf.base.logging as logging
from arealite.api.cli_args import (
    ClusterSpecConfig,
    LauncherConfig,
    SGLangConfig,
    parse_cli_args,
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
    get_env_vars,
    wait_sglang_server_addrs,
)
from realhf.base import logging, name_resolve
from realhf.scheduler.client import JobException

logger = logging.getLogger("RayLauncher")

RAY_WAIT_CHECK_TIME_INTERVAL = 5  # seconds
def run_func(file_path, function_name, *args, **kwargs):
    # Convert the file path to a unique module name.
    module_name = file_path.replace("/", "_").replace(".", "_")

    # Load the module from the file path.
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    # Get the function and execute it.
    function = getattr(module, function_name)
    return function(*args, **kwargs)

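A minimal usage sketch of run_func (hypothetical path, function name, and arguments; not part of this commit):

# Suppose /path/to/entry.py defines `def main(argv): ...`; a Ray worker
# would then execute it in-process like this:
result = run_func("/path/to/entry.py", "main", ["--config", "cfg.yaml"])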
class RayLauncher:

    def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
        self.experiment_name = experiment_name
        self.trial_name = trial_name
        self.fileroot = fileroot

        # job_name to ray future
        self.jobs = {}

    @property
    def run_name(self):
        return f"{self.experiment_name}_{self.trial_name}"

    def log_path_of(self, job_name: str) -> str:
        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(log_path, exist_ok=True)
        return os.path.join(log_path, f"{job_name}.log")

    def submit(
        self,
        job_name: str,
        file_path: str,
        func_name: str,
        gpus: int,
        cpus: int,
        mem: int,  # MB
        env_vars: Optional[Dict] = None,
        placement_group: Optional[PlacementGroup] = None,
        bundle_index: Optional[int] = None,
        *args,  # arguments to pass to the function
        **kwargs,  # keyword arguments to pass to the function
    ):
        runtime_env = RuntimeEnv(
            env_vars=env_vars or dict(),
        )
        scheduling_strategy = (
            PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                bundle_index=bundle_index,
            )
            if placement_group is not None
            else "DEFAULT"
        )
        future = ray.remote(
            num_cpus=cpus,
            num_gpus=gpus,
            memory=mem * 1024 * 1024,  # Convert MB to bytes
            runtime_env=runtime_env,
            scheduling_strategy=scheduling_strategy,
        )(run_func).remote(file_path, func_name, *args, **kwargs)
        self.jobs[job_name] = future
        return future

    def submit_array(
        self,
        job_name: str,
        file_path: str,
        func_name: str,
        count: int,
        nodes: int,
        list_args: List[List],
        gpus_per_task: int,
        cpus_per_task: int,
        mem_per_task: int,  # MB
        list_kwargs: List[Dict] | None = None,
        env_vars: Optional[Dict] = None,
        amend_torch_dist_env: bool = False,
    ):
        """Submit an array of jobs to Ray with Ray placement groups."""

        if count % nodes != 0:
            raise ValueError(
                f"Count {count} is not divisible by nodes {nodes}. "
                "Please ensure that count is a multiple of nodes."
            )
        assert len(list_args) == count, (
            f"Length of list_args {len(list_args)} does not match count {count}."
        )
        if list_kwargs is not None:
            assert len(list_kwargs) == count, (
                f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."
            )

        tasks_per_node = count // nodes
        gpus_per_node = gpus_per_task * tasks_per_node
        cpus_per_node = cpus_per_task * tasks_per_node
        mem_per_node = mem_per_task * tasks_per_node

        placement_group = ray.util.placement_group(
            bundles=[
                {
                    "CPU": cpus_per_node,
                    "GPU": gpus_per_node,
                    "memory": mem_per_node * 1024 * 1024,  # Convert MB to bytes
                }
            ]
            * nodes,
            strategy="STRICT_SPREAD",
        )
        ray.get(placement_group.ready())

        futures = []
        for i in range(count):
            args = list_args[i]
            kwargs = list_kwargs[i] if list_kwargs is not None else {}

            # manage environment variables
            env_vars = env_vars or {}
            assert (
                "CUDA_VISIBLE_DEVICES" not in env_vars
            ), "CUDA_VISIBLE_DEVICES should be automatically resolved by Launcher instead of manually assigned."

            gpu_id_start = (i % tasks_per_node) * gpus_per_task
            gpu_id_end = ((i % tasks_per_node) + 1) * gpus_per_task
            node_id = i // tasks_per_node
            _env_vars = {
                **env_vars,
                "CUDA_VISIBLE_DEVICES": ",".join(
                    str(x) for x in range(gpu_id_start, gpu_id_end)
                ),
            }

            if amend_torch_dist_env:
                assert gpus_per_task == 1
                _env_vars.update(
                    {
                        "RANK": str(i),
                        "WORLD_SIZE": str(count),
                        "LOCAL_RANK": str(i % tasks_per_node),
                    }
                )

            # Pass the fixed arguments positionally so that the per-task
            # *args can be forwarded to the entrypoint function.
            future = self.submit(
                f"{job_name}:{i}",
                file_path,
                func_name,
                gpus_per_task,
                cpus_per_task,
                mem_per_task,
                _env_vars,
                placement_group,
                node_id,
                *args,
                **kwargs,
            )
            futures.append(future)

        return futures

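For intuition, a small worked sketch of the per-task placement arithmetic used in submit_array above (the values are illustrative only, not from this commit):

# Illustrative only: how a task index maps to a node bundle and GPU set.
count, nodes, gpus_per_task = 8, 2, 1
tasks_per_node = count // nodes  # 4 tasks per node
for i in range(count):
    node_id = i // tasks_per_node
    gpus = list(
        range((i % tasks_per_node) * gpus_per_task,
              ((i % tasks_per_node) + 1) * gpus_per_task)
    )
    print(i, node_id, gpus)  # e.g. task 5 -> node 1, CUDA_VISIBLE_DEVICES "1"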
    def stop(self, job_name: str, force: bool = False):
        """Stop a job by name."""
        if job_name in self.jobs:
            future = self.jobs[job_name]
            try:
                ray.cancel(future, force=force)
            except Exception as e:
                logger.error(f"Failed to cancel job {job_name}: {e}")
                return
            self.jobs.pop(job_name, None)
            logger.info(f"Job {job_name} stopped.")
        else:
            logger.warning(f"Job {job_name} not found in running jobs.")

    def stop_all(self, force: bool = False):
        """Stop all jobs."""
        for job_name in list(self.jobs.keys()):
            self.stop(job_name, force=force)
        logger.info("All jobs stopped.")
        self.jobs.clear()

    def wait(self):
        """Check the status of all jobs every RAY_WAIT_CHECK_TIME_INTERVAL seconds.

        If any job fails, terminate all jobs and return.
        If a job completes, remove it from self.jobs.
        If all jobs are completed, return.
        """
        while self.jobs:
            completed_jobs = []
            for job_name, future in list(self.jobs.items()):
                try:
                    r = ray.get(future, timeout=0.1)
                    logger.info(f"Job {job_name} completed with result: {r}")
                    completed_jobs.append(job_name)
                except ray.exceptions.GetTimeoutError:
                    continue
                except ray.exceptions.RayTaskError as e:
                    logger.error(
                        f"Job {job_name} failed with error: {e}, stopping all jobs."
                    )
                    self.stop_all(force=True)
                    return
            for job_name in completed_jobs:
                self.jobs.pop(job_name, None)
                logger.info(f"Job {job_name} completed. Removed.")
            time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)

if __name__ == "__main__":
|
||||||
|
# usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>] launcher.ray.main_func_name=<main_func_name_in_entry_point>
|
||||||
|
ray.init()
|
||||||
|
config, config_file = parse_cli_args(sys.argv[2:])
|
||||||
|
|
||||||
|
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
|
||||||
|
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||||
|
name_resolve.reconfigure(config.cluster.name_resolve)
|
||||||
|
|
||||||
|
n_nodes = config.n_nodes
|
||||||
|
n_gpus_per_node = config.n_gpus_per_node
|
||||||
|
if n_gpus_per_node < config.cluster.n_gpus_per_node:
|
||||||
|
raise ValueError(
|
||||||
|
f"Slurm Launcher requires at least {config.cluster.n_gpus_per_node} (#GPUs per node) GPU. For usecases of less GPUs, use LocalLauncher instead."
|
||||||
|
)
|
||||||
|
elif n_gpus_per_node > config.cluster.n_gpus_per_node:
|
||||||
|
raise ValueError(
|
||||||
|
f"#GPU per node required by experiment ({n_gpus_per_node}) is larger than #GPU per node in the cluster ({config.cluster.n_gpus_per_node})."
|
||||||
|
)
|
||||||
|
|
||||||
|
launcher = RayLauncher(
|
||||||
|
experiment_name=config.experiment_name,
|
||||||
|
trial_name=config.trial_name,
|
||||||
|
fileroot=config.cluster.fileroot,
|
||||||
|
)
|
||||||
|
allocation_mode = config.allocation_mode
|
||||||
|
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||||
|
sglang_cmds = []
|
||||||
|
sglang_addrs = []
|
||||||
|
n_sglang_nodes = 0
|
||||||
|
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
|
||||||
|
# Launcher should launch SGLang servers according to allocation mode.
|
||||||
|
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||||
|
assert (
|
||||||
|
allocation_mode.gen_pp_size == 1
|
||||||
|
), "Pipeline generation in SGLang is not supported for now."
|
||||||
|
assert (
|
||||||
|
allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
|
||||||
|
), "Currently only support SGLang TP size less <= #GPUs per node."
|
||||||
|
sglang_world_size = allocation_mode.gen_world_size
|
||||||
|
sglang_tp_size = allocation_mode.gen_tp_size
|
||||||
|
n_sglang_servers = allocation_mode.gen_dp_size
|
||||||
|
n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
|
||||||
|
n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size
|
||||||
|
|
||||||
|
base_seed = config.sglang.random_seed
|
||||||
|
|
||||||
|
sglang_args_list = [
|
||||||
|
sys.argv[2:] + [
|
||||||
|
f"sglang.random_seed={base_seed + i}"
|
||||||
|
] for i in range(n_sglang_servers)
|
||||||
|
]
|
||||||
|
sglang_entry_point = pathlib.Path(__file__).resolve().joinpath("sglang_server.py")
|
||||||
|
sglang_main_func_name = "main_sglang_server"
|
||||||
|
launcher.submit_array(
|
||||||
|
job_name="llm_server",
|
||||||
|
file_path=sglang_entry_point,
|
||||||
|
func_name=sglang_main_func_name,
|
||||||
|
count=n_sglang_servers,
|
||||||
|
nodes=n_sglang_nodes,
|
||||||
|
list_args=sglang_args_list,
|
||||||
|
gpus_per_task=sglang_tp_size,
|
||||||
|
cpus_per_task=config.launcher.inference_server_cpus_per_gpu * sglang_tp_size,
|
||||||
|
mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
|
||||||
|
env_vars=get_env_vars(
|
||||||
|
config.cluster.cluster_name,
|
||||||
|
config.launcher.inference_server_env_vars,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get SGLang slurm nodes, find the hosts
|
||||||
|
sglang_addrs = wait_sglang_server_addrs(
|
||||||
|
config.experiment_name,
|
||||||
|
config.trial_name,
|
||||||
|
n_sglang_servers,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer_n_nodes = n_nodes - n_sglang_nodes
|
||||||
|
trainer_entry_point = sys.argv[1]
|
||||||
|
trainer_main_func_name = config.launcher.ray.main_func_name
|
||||||
|
n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
|
||||||
|
trainer_args_list = [
|
||||||
|
sys.argv[2:] for _ in range(n_trainer_processes)
|
||||||
|
]
|
||||||
|
if not config.server_only:
|
||||||
|
# launch trainers
|
||||||
|
launcher.submit_array(
|
||||||
|
job_name="trainer",
|
||||||
|
file_path=trainer_entry_point,
|
||||||
|
func_name=trainer_main_func_name,
|
||||||
|
count=trainer_n_nodes * config.cluster.n_gpus_per_node,
|
||||||
|
nodes=trainer_n_nodes,
|
||||||
|
list_args=trainer_args_list,
|
||||||
|
gpus_per_task=1,
|
||||||
|
cpus_per_task=config.launcher.trainer_cpus_per_gpu,
|
||||||
|
mem_per_task=config.launcher.trainer_mem_per_gpu,
|
||||||
|
env_vars=dict(
|
||||||
|
**get_env_vars(
|
||||||
|
config.cluster.cluster_name,
|
||||||
|
config.launcher.trainer_env_vars,
|
||||||
|
),
|
||||||
|
AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
|
||||||
|
),
|
||||||
|
amend_torch_dist_env=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
launcher.wait()
|
||||||
|
except (KeyboardInterrupt, JobException, TimeoutError) as e:
|
||||||
|
launcher.stop_all(force=True)
|
||||||
|
raise e
|
|
@ -147,9 +147,16 @@ class SGLangServerWrapper:
            f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
        )

    def __del__(self):
        if self.server_process and self.server_process.poll() is None:
            logger.info("Terminating SGLang server process...")
            self.server_process.terminate()
            self.server_process.wait()
            logger.info("SGLang server process terminated.")


if __name__ == "__main__":
    config, _ = parse_cli_args(sys.argv[2:])
def main_sglang_server(argv):
    config, _ = parse_cli_args(argv)
    config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
    config.cluster.name_resolve = to_structured_cfg(
        config.cluster.name_resolve, NameResolveConfig
@ -169,3 +176,7 @@ if __name__ == "__main__":
        n_gpus_per_node=config.n_gpus_per_node,
    )
    sglang_server.run()


if __name__ == "__main__":
    main_sglang_server(sys.argv[1:])
@ -8,8 +8,6 @@ import sys
import time
from typing import Dict, List, Literal, Optional, Tuple

from omegaconf import OmegaConf

import realhf.base.logging as logging
from arealite.api.cli_args import (
    BaseExperimentConfig,
@ -20,197 +18,24 @@ from arealite.api.cli_args import (
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
    get_env_vars,
    wait_sglang_server_addrs,
)
from arealite.utils.slurm import (
    cancel_jobs,
    query_jobs,
    SBATCH_SCRIPT_TEMPLATE,
    DEFAULT_SRUN_CMD_TEMPLATE,
)
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("SlurmLauncher")

SLURM_WAIT_CHECK_TIME_INTERVAL = 5
SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=int(job_id.strip()),
            )
        )
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str):
    try:
        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Get slurm host ip for node {node} failed.")


SGLANG_SERVER_TIMEOUT_SECONDS = 180
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5


SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# srun commands
{srun_cmds}

wait
"""

SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodelist=${{nodes_array[{node_id}]}} \\
--nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
--mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""

LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
    "TOKENIZERS_PARALLELISM": "true",
    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
    "PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
    "NCCL_NET_PLUGIN": "",
    "NCCL_IB_GID_INDEX": "3",
    "NCCL_IB_TIMEOUT": "2",
    "NCCL_IB_RETRY_CNT": "7",
    "NCCL_IB_SL": "5",
    "NCCL_IB_TC": "136",
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG": "WARN",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}


def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
    """Returns the environment variables for the cluster."""
    additional_env_vars = {}
    if additional_env_vars:
        additional_env_vars = dict(
            item.split("=") for item in additional_env_vars.split(",")
        )
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
    else:
        return {**BASE_ENVIRONS, **additional_env_vars}


class SlurmLauncher:
@ -261,6 +86,7 @@ class SlurmLauncher:
        self,
        job_name: str,
        cmd: List[str] | str,
        srun_cmd_template: str,
        count: int,
        nodes: int,
        n_gpus_per_node: int,
@ -271,13 +97,6 @@ class SlurmLauncher:
        env_vars: Optional[Dict] = None,
        nodelist: Optional[str] = None,
        exclude: Optional[str] = None,
        apptainer_name: Optional[str] = "singularity",
        apptainer_options: Optional[Tuple[str, ...]] = (
            "--no-home",
            "--writable-tmpfs",
            "--nv",
            "--pid",
        ),
    ):
        """Submits and launch a job array with SBATCH.
        Note that a job array has one (unique) slurm name, and one (unique) slurm id.
@ -359,15 +178,13 @@ class SlurmLauncher:
        # FIXME: only for debugging, remove and replace new image
        # job_cmd = f'bash -c "pip3 install -r requirements.txt; {job_cmd}"'

            srun_cmd = SRUN_CMD_TEMPLATE.format(
            srun_cmd = srun_cmd_template.format(
                nodes=1,
                ntasks=1,
                node_id=node_id,
                n_gpus_per_node=n_gpus_per_node,
                cpus_per_task=cpus_per_task,
                mem_per_cpu=mem_per_cpu,
                apptainer_name=apptainer_name,
                apptainer_options=" ".join(apptainer_options),
                container_mounts=container_mounts or "",
                container_env_strings=env_string,
                container_image=container_image,
@ -534,7 +351,7 @@ class SlurmLauncher:
                    left.remove(slurm_id)
                if update:
                    self.jobs.pop(slurm_info.slurm_id)
            time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
            time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)

    def _update_all(self):
        """Updates the status of all jobs."""
@ -550,7 +367,7 @@ class SlurmLauncher:


if __name__ == "__main__":
    # usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
    # usage: python -m arealite.launcher.slurm <entry_point> --config <config_path> [<additional_args>]
    config, config_file = parse_cli_args(sys.argv[2:])

    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
@ -601,43 +418,30 @@ if __name__ == "__main__":
            )
            sglang_cmds.append(sglang_cmd)

    if sglang_cmds:
        launcher.submit_array(
            job_name="llm_server",
            job_name="sglang-server",
            cmd=sglang_cmds,
            srun_cmd_template=config.launcher.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
            count=n_sglang_servers,
            nodes=n_sglang_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
            cpus_per_task=config.launcher.inference_server_cpus_per_gpu
            * sglang_tp_size,
            mem_per_task=config.launcher.inference_server_mem_per_gpu
            * sglang_tp_size,
            container_image=config.cluster.gpu_infer_image,
            container_mounts=config.cluster.mount,
            env_vars=get_env_vars(
                config.cluster.cluster_name,
                config.launcher.inference_server_env_vars,
            ),
        )

    # Get SGLang slurm nodes, find the hosts
    name = names.gen_servers(config.experiment_name, config.trial_name)
    start = time.perf_counter()
    while True:
        sglang_addrs = name_resolve.get_subtree(name)
        if len(sglang_addrs) >= n_sglang_servers:
            logger.info(
                f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
            )
            break

        time.sleep(1)
        if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
            launcher.stop_all()
            raise TimeoutError(
                f"Timeout waiting for SGLang servers to be ready. "
                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
            )
    sglang_addrs = wait_sglang_server_addrs(
        config.experiment_name,
        config.trial_name,
        n_sglang_servers,
    )

    trainer_n_nodes = n_nodes - n_sglang_nodes
    trainer_cmd_template = (
@ -662,6 +466,7 @@ if __name__ == "__main__":
        launcher.submit_array(
            job_name="trainer",
            cmd=trainer_cmds,
            srun_cmd_template=config.launcher.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
            count=trainer_n_nodes,
            nodes=trainer_n_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
@ -0,0 +1,77 @@
import os
import time
import pathlib
import getpass
from typing import Dict, Optional

from realhf.base import names, name_resolve, logging

logger = logging.getLogger("Launcher Utils")

LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
    "TOKENIZERS_PARALLELISM": "true",
    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
    "PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
    "NCCL_NET_PLUGIN": "",
    "NCCL_IB_GID_INDEX": "3",
    "NCCL_IB_TIMEOUT": "2",
    "NCCL_IB_RETRY_CNT": "7",
    "NCCL_IB_SL": "5",
    "NCCL_IB_TC": "136",
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG": "WARN",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180


def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
    """Returns the environment variables for the cluster."""
    # Parse "KEY1=V1,KEY2=V2" into a dict; default to an empty dict.
    if additional_env_vars:
        additional_env_vars = dict(
            item.split("=") for item in additional_env_vars.split(",")
        )
    else:
        additional_env_vars = {}
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
    else:
        return {**BASE_ENVIRONS, **additional_env_vars}


def wait_sglang_server_addrs(
    experiment_name: str,
    trial_name: str,
    n_sglang_servers: int,
):
    # Wait until all SGLang servers have registered their addresses.
    name = names.gen_servers(experiment_name, trial_name)
    start = time.perf_counter()
    while True:
        sglang_addrs = name_resolve.get_subtree(name)
        if len(sglang_addrs) >= n_sglang_servers:
            logger.info(
                f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
            )
            break

        time.sleep(1)
        if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
            raise TimeoutError(
                f"Timeout waiting for SGLang servers to be ready. "
                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
            )
    return sglang_addrs
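A quick illustration of how get_env_vars merges the base environment with a comma-separated override string (the cluster name and variables below are made up for the example):

# Hypothetical values, for illustration only.
env = get_env_vars("na132", "NCCL_DEBUG=INFO,MY_FLAG=1")
# env now contains BASE_ENVIRONS and NA132_ENVIRONS, with NCCL_DEBUG
# overridden to "INFO" and MY_FLAG added on top.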
@ -0,0 +1,147 @@
import subprocess
from typing import List, Optional, Literal

from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("Slurm Utils")


SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}

SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# srun commands
{srun_cmds}

wait
"""

# Default srun command template, using singularity as the apptainer runtime.
DEFAULT_SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD \\
--nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
--gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
singularity exec --no-home --writable-tmpfs --nv --pid \\
--bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=int(job_id.strip()),
            )
        )
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str):
    try:
        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Get slurm host ip for node {node} failed.")
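For orientation, this is roughly how a launcher fills DEFAULT_SRUN_CMD_TEMPLATE for a single task; all values below are illustrative placeholders, not configuration from this commit:

# Illustrative only: formatting the default srun template for one task.
srun_cmd = DEFAULT_SRUN_CMD_TEMPLATE.format(
    node_id=0,
    nodes=1,
    ntasks=1,
    n_gpus_per_node=8,
    cpus_per_task=16,
    mem_per_cpu=1024,
    container_mounts="/data:/data",
    container_env_strings="--env NCCL_DEBUG=WARN",
    container_image="/images/arealite.sif",
    cmd="python -m arealite.launcher.sglang_server --config cfg.yaml",
)
# The resulting string is later embedded into SBATCH_SCRIPT_TEMPLATE via {srun_cmds}.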
@ -106,8 +106,8 @@ def boba_reward_fn(
    return label


def main_grpo():
def main_grpo(argv):
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
    config, _ = load_expr_config(argv, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
@ -238,4 +238,4 @@ def main_grpo():


if __name__ == "__main__":
    main_grpo()
    main_grpo(sys.argv[1:])
@ -75,8 +75,8 @@ def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **k
    return int(sol.strip() == ans.strip())


def main_grpo():
def main_grpo(argv):
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
    config, _ = load_expr_config(argv, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))
@ -250,4 +250,4 @@ def main_grpo():


if __name__ == "__main__":
    main_grpo()
    main_grpo(sys.argv[1:])
@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
    return process_gsm8k_sft_dataset(dataset, tokenizer)


def main_sft():
def main_sft(argv):
    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
    config, _ = load_expr_config(argv, SFTConfig)
    config: SFTConfig

    rank = int(os.getenv("RANK"))
@ -120,4 +120,4 @@ def main_sft():


if __name__ == "__main__":
    main_sft()
    main_sft(sys.argv[1:])
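The entrypoint signature changes above (main_grpo()/main_sft() now taking argv) are what allow the Ray launcher to call the training code in-process rather than through a shell command. Conceptually, and with made-up paths and arguments that are not part of this commit, a Ray worker does something like:

# Conceptual sketch only: the worker imports the entrypoint file and calls the
# configured function with the CLI argument list, so main_*() must accept argv.
entry = "examples/arealite/gsm8k_grpo.py"                  # hypothetical path
argv = ["--config", "examples/arealite/gsm8k_grpo.yaml"]   # hypothetical args
run_func(entry, "main_grpo", argv)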