mirror of https://github.com/inclusionAI/AReaL
ray launcher before test

commit 44947fe7fa (parent 29172e0e10)
@@ -624,6 +624,23 @@ class DatasetConfig:
    drop_last: bool = field(default=True)


@dataclass
class RayLauncherConfig:
    """Configuration for launching the SGLang server with Ray."""

    main_func_name: str = field(
        default="main",
        metadata={"help": "Name of the main function in the entrypoint file to run in Ray workers."},
    )


@dataclass
class SlurmLauncherConfig:
    """Configuration for launching the SGLang server with Slurm."""

    srun_cmd_template: Optional[str] = field(
        default=None,
        metadata={"help": "Template for the srun command to launch the SGLang server."},
    )


@dataclass
class LauncherConfig:
    """Configuration for launching the SGLang server."""

@@ -662,6 +679,10 @@ class LauncherConfig:
        default=27015,
        metadata={"help": "Trainer port used for torch.distributed initialization."},
    )
    ray: RayLauncherConfig = field(
        default_factory=RayLauncherConfig,
        metadata={"help": "Ray launcher configuration."},
    )


@dataclass
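For context, a minimal sketch (not part of the commit) of how the new launcher field can be read back after a CLI override such as launcher.ray.main_func_name=main_grpo; the override value, and the assumption that the script is invoked with the usual --config <yaml> argument, are illustrative only.

# Illustrative sketch: consuming the new Ray launcher field after parsing.
import sys

from arealite.api.cli_args import LauncherConfig, parse_cli_args, to_structured_cfg

config, _ = parse_cli_args(sys.argv[1:])  # e.g. --config my.yaml launcher.ray.main_func_name=main_grpo
config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
print(config.launcher.ray.main_func_name)  # "main_grpo" if overridden, "main" otherwise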
@@ -0,0 +1,353 @@
import os
import sys
import time
import getpass
import pathlib
from typing import List, Optional, Dict
import importlib.util

import ray
import ray.exceptions
from ray.runtime_env import RuntimeEnv
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.placement_group import PlacementGroup
from ray.job_submission import JobStatus as RayJobStatus

from arealite.api.cli_args import (
    ClusterSpecConfig,
    LauncherConfig,
    SGLangConfig,
    parse_cli_args,
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
    get_env_vars,
    wait_sglang_server_addrs,
)
from realhf.base import logging, name_resolve
from realhf.scheduler.client import JobException

logger = logging.getLogger("RayLauncher")

RAY_WAIT_CHECK_TIME_INTERVAL = 5  # seconds


def run_func(file_path, function_name, *args, **kwargs):
    # Convert the file path to a module name
    module_name = file_path.replace("/", "_").replace(".", "_")

    # Load the module from the file path
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    # Get the function and execute it
    function = getattr(module, function_name)
    return function(*args, **kwargs)


class RayLauncher:
    def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
        self.experiment_name = experiment_name
        self.trial_name = trial_name
        self.fileroot = fileroot

        # job_name -> ray future
        self.jobs = {}

    @property
    def run_name(self):
        return f"{self.experiment_name}_{self.trial_name}"

    def log_path_of(self, job_name: str) -> str:
        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(log_path, exist_ok=True)
        return os.path.join(log_path, f"{job_name}.log")

    def submit(
        self,
        job_name: str,
        file_path: str,
        func_name: str,
        gpus: int,
        cpus: int,
        mem: int,  # MB
        env_vars: Optional[Dict] = None,
        placement_group: Optional[PlacementGroup] = None,
        bundle_index: Optional[int] = None,
        *args,  # arguments to pass to the function
        **kwargs,  # keyword arguments to pass to the function
    ):
        runtime_env = RuntimeEnv(
            env_vars=env_vars or dict(),
        )
        scheduling_strategy = (
            PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                bundle_index=bundle_index,
            )
            if placement_group is not None
            else "DEFAULT"
        )
        future = ray.remote(
            num_cpus=cpus,
            num_gpus=gpus,
            memory=mem * 1024 * 1024,  # convert MB to bytes
            runtime_env=runtime_env,
            scheduling_strategy=scheduling_strategy,
        )(run_func).remote(file_path, func_name, *args, **kwargs)
        self.jobs[job_name] = future
        return future

    def submit_array(
        self,
        job_name: str,
        file_path: str,
        func_name: str,
        count: int,
        nodes: int,
        list_args: List[List],
        gpus_per_task: int,
        cpus_per_task: int,
        mem_per_task: int,  # MB
        list_kwargs: List[Dict] | None = None,
        env_vars: Optional[Dict] = None,
        amend_torch_dist_env: bool = False,
    ):
        """Submit an array of jobs to Ray with Ray placement groups."""
        if count % nodes != 0:
            raise ValueError(
                f"Count {count} is not divisible by nodes {nodes}. "
                "Please ensure that count is a multiple of nodes."
            )
        assert (
            len(list_args) == count
        ), f"Length of list_args {len(list_args)} does not match count {count}."
        if list_kwargs is not None:
            assert (
                len(list_kwargs) == count
            ), f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."

        tasks_per_node = count // nodes
        gpus_per_node = gpus_per_task * tasks_per_node
        cpus_per_node = cpus_per_task * tasks_per_node
        mem_per_node = mem_per_task * tasks_per_node

        placement_group = ray.util.placement_group(
            bundles=[
                {
                    "CPU": cpus_per_node,
                    "GPU": gpus_per_node,
                    "memory": mem_per_node * 1024 * 1024,  # convert MB to bytes
                }
            ]
            * nodes,
            strategy="STRICT_SPREAD",
        )
        ray.get(placement_group.ready())

        futures = []
        for i in range(count):
            args = list_args[i]
            kwargs = list_kwargs[i] if list_kwargs is not None else {}

            # manage environment variables
            env_vars = env_vars or {}
            assert (
                "CUDA_VISIBLE_DEVICES" not in env_vars
            ), "CUDA_VISIBLE_DEVICES should be automatically resolved by the launcher instead of manually assigned."

            gpu_id_start = (i % tasks_per_node) * gpus_per_task
            gpu_id_end = ((i % tasks_per_node) + 1) * gpus_per_task
            node_id = i // tasks_per_node
            _env_vars = {
                **env_vars,
                "CUDA_VISIBLE_DEVICES": ",".join(
                    str(x) for x in range(gpu_id_start, gpu_id_end)
                ),
            }

            if amend_torch_dist_env:
                assert gpus_per_task == 1
                _env_vars.update(
                    {
                        "RANK": str(i),
                        "WORLD_SIZE": str(count),
                        "LOCAL_RANK": str(i % tasks_per_node),
                    }
                )

            # Launcher arguments are passed positionally so that *args can be
            # forwarded to the entry-point function without clashing with them.
            future = self.submit(
                f"{job_name}:{i}",
                file_path,
                func_name,
                gpus_per_task,
                cpus_per_task,
                mem_per_task,
                _env_vars,
                placement_group,
                node_id,
                *args,
                **kwargs,
            )
            futures.append(future)

        return futures

    def stop(self, job_name: str, force: bool = False):
        """Stop a job by name."""
        if job_name in self.jobs:
            future = self.jobs[job_name]
            try:
                ray.cancel(future, force=force)
            except Exception as e:
                logger.error(f"Failed to cancel job {job_name}: {e}")
                return
            self.jobs.pop(job_name, None)
            logger.info(f"Job {job_name} stopped.")
        else:
            logger.warning(f"Job {job_name} not found in running jobs.")

    def stop_all(self, force: bool = False):
        """Stop all jobs."""
        for job_name in list(self.jobs.keys()):
            self.stop(job_name, force=force)
        logger.info("All jobs stopped.")
        self.jobs.clear()

    def wait(self):
        """Check the status of all jobs every RAY_WAIT_CHECK_TIME_INTERVAL seconds.

        If any job fails, terminate all jobs and return.
        If a job completes, remove it from self.jobs.
        If all jobs have completed, return.
        """
        while self.jobs:
            completed_jobs = []
            for job_name, future in list(self.jobs.items()):
                try:
                    r = ray.get(future, timeout=0.1)
                    logger.info(f"Job {job_name} completed with result: {r}")
                    completed_jobs.append(job_name)
                except ray.exceptions.GetTimeoutError:
                    continue
                except ray.exceptions.RayTaskError as e:
                    logger.error(
                        f"Job {job_name} failed with error: {e}, stopping all jobs."
                    )
                    self.stop_all(force=True)
                    return
            for job_name in completed_jobs:
                self.jobs.pop(job_name, None)
                logger.info(f"Job {job_name} completed. Removed.")
            time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)


if __name__ == "__main__":
    # usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>] launcher.ray.main_func_name=<main_func_name_in_entry_point>
    ray.init()
    config, config_file = parse_cli_args(sys.argv[2:])

    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
    name_resolve.reconfigure(config.cluster.name_resolve)

    n_nodes = config.n_nodes
    n_gpus_per_node = config.n_gpus_per_node
    if n_gpus_per_node < config.cluster.n_gpus_per_node:
        raise ValueError(
            f"The Ray launcher requires at least {config.cluster.n_gpus_per_node} GPUs "
            "(#GPUs per node). For use cases with fewer GPUs, use LocalLauncher instead."
        )
    elif n_gpus_per_node > config.cluster.n_gpus_per_node:
        raise ValueError(
            f"#GPUs per node required by the experiment ({n_gpus_per_node}) is larger "
            f"than #GPUs per node in the cluster ({config.cluster.n_gpus_per_node})."
        )

    launcher = RayLauncher(
        experiment_name=config.experiment_name,
        trial_name=config.trial_name,
        fileroot=config.cluster.fileroot,
    )
    allocation_mode = config.allocation_mode
    allocation_mode = AllocationMode.from_str(allocation_mode)
    sglang_addrs = []
    n_sglang_nodes = 0
    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
        # The launcher should launch SGLang servers according to the allocation mode.
        config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
        assert (
            allocation_mode.gen_pp_size == 1
        ), "Pipeline generation in SGLang is not supported for now."
        assert (
            allocation_mode.gen_tp_size <= config.cluster.n_gpus_per_node
        ), "Currently only SGLang TP size <= #GPUs per node is supported."
        sglang_world_size = allocation_mode.gen_world_size
        sglang_tp_size = allocation_mode.gen_tp_size
        n_sglang_servers = allocation_mode.gen_dp_size
        n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node
        n_sglang_servers_per_node = config.cluster.n_gpus_per_node // sglang_tp_size

        base_seed = config.sglang.random_seed

        sglang_args_list = [
            sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]
            for i in range(n_sglang_servers)
        ]
        # sglang_server.py lives next to this launcher module.
        sglang_entry_point = str(
            pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py")
        )
        sglang_main_func_name = "main_sglang_server"
        launcher.submit_array(
            job_name="llm_server",
            file_path=sglang_entry_point,
            func_name=sglang_main_func_name,
            count=n_sglang_servers,
            nodes=n_sglang_nodes,
            list_args=sglang_args_list,
            gpus_per_task=sglang_tp_size,
            cpus_per_task=config.launcher.inference_server_cpus_per_gpu * sglang_tp_size,
            mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
            env_vars=get_env_vars(
                config.cluster.cluster_name,
                config.launcher.inference_server_env_vars,
            ),
        )

        # Wait until the SGLang servers are up and find their addresses.
        sglang_addrs = wait_sglang_server_addrs(
            config.experiment_name,
            config.trial_name,
            n_sglang_servers,
        )

    trainer_n_nodes = n_nodes - n_sglang_nodes
    trainer_entry_point = sys.argv[1]
    trainer_main_func_name = config.launcher.ray.main_func_name
    n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
    trainer_args_list = [sys.argv[2:] for _ in range(n_trainer_processes)]
    if not config.server_only:
        # launch trainers
        launcher.submit_array(
            job_name="trainer",
            file_path=trainer_entry_point,
            func_name=trainer_main_func_name,
            count=trainer_n_nodes * config.cluster.n_gpus_per_node,
            nodes=trainer_n_nodes,
            list_args=trainer_args_list,
            gpus_per_task=1,
            cpus_per_task=config.launcher.trainer_cpus_per_gpu,
            mem_per_task=config.launcher.trainer_mem_per_gpu,
            env_vars=dict(
                **get_env_vars(
                    config.cluster.cluster_name,
                    config.launcher.trainer_env_vars,
                ),
                AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
            ),
            amend_torch_dist_env=True,
        )

    try:
        launcher.wait()
    except (KeyboardInterrupt, JobException, TimeoutError) as e:
        launcher.stop_all(force=True)
        raise e
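For reference, a minimal standalone sketch (not part of the commit) of the per-task device and rank layout that submit_array derives; the count/nodes/gpus_per_task values below are made up for illustration.

# Illustrative sketch: mirrors the layout formulas used in submit_array above,
# for count=8 tasks spread over nodes=2 with gpus_per_task=1 (made-up numbers).
count, nodes, gpus_per_task = 8, 2, 1
tasks_per_node = count // nodes
for i in range(count):
    node_id = i // tasks_per_node  # placement-group bundle index
    gpu_id_start = (i % tasks_per_node) * gpus_per_task
    gpu_id_end = ((i % tasks_per_node) + 1) * gpus_per_task
    cuda = ",".join(str(x) for x in range(gpu_id_start, gpu_id_end))
    # With amend_torch_dist_env=True: RANK=i, WORLD_SIZE=count, LOCAL_RANK=i % tasks_per_node.
    print(f"task {i}: node {node_id}, CUDA_VISIBLE_DEVICES={cuda}, LOCAL_RANK={i % tasks_per_node}")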
@@ -147,9 +147,16 @@ class SGLangServerWrapper:
            f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
        )

    def __del__(self):
        if self.server_process and self.server_process.poll() is None:
            logger.info("Terminating SGLang server process...")
            self.server_process.terminate()
            self.server_process.wait()
            logger.info("SGLang server process terminated.")


if __name__ == "__main__":
    config, _ = parse_cli_args(sys.argv[2:])

def main_sglang_server(argv):
    config, _ = parse_cli_args(argv)
    config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
    config.cluster.name_resolve = to_structured_cfg(
        config.cluster.name_resolve, NameResolveConfig

@@ -169,3 +176,7 @@ if __name__ == "__main__":
        n_gpus_per_node=config.n_gpus_per_node,
    )
    sglang_server.run()


if __name__ == "__main__":
    main_sglang_server(sys.argv[1:])
@@ -8,8 +8,6 @@ import sys
import time
from typing import Dict, List, Literal, Optional, Tuple

from omegaconf import OmegaConf

import realhf.base.logging as logging
from arealite.api.cli_args import (
    BaseExperimentConfig,
@@ -20,197 +18,24 @@ from arealite.api.cli_args import (
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
    get_env_vars,
    wait_sglang_server_addrs,
)
from arealite.utils.slurm import (
    cancel_jobs,
    query_jobs,
    SBATCH_SCRIPT_TEMPLATE,
    DEFAULT_SRUN_CMD_TEMPLATE,
)
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("SlurmLauncher")


SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}
SLURM_WAIT_CHECK_TIME_INTERVAL = 5


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=int(job_id.strip()),
            )
        )
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str):
    try:
        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Get slurm host ip for node {node} failed.")


SGLANG_SERVER_TIMEOUT_SECONDS = 180
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5


SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# srun commands
{srun_cmds}

wait
"""

SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodelist=${{nodes_array[{node_id}]}} \\
--nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
--mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""

LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
    "TOKENIZERS_PARALLELISM": "true",
    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
    "PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
    "NCCL_NET_PLUGIN": "",
    "NCCL_IB_GID_INDEX": "3",
    "NCCL_IB_TIMEOUT": "2",
    "NCCL_IB_RETRY_CNT": "7",
    "NCCL_IB_SL": "5",
    "NCCL_IB_TC": "136",
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG": "WARN",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}


def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
    """Returns the environment variables for the cluster."""
    additional_env_vars = {}
    if additional_env_vars:
        additional_env_vars = dict(
            item.split("=") for item in additional_env_vars.split(",")
        )
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
    else:
        return {**BASE_ENVIRONS, **additional_env_vars}


class SlurmLauncher:
@@ -261,6 +86,7 @@ class SlurmLauncher:
        self,
        job_name: str,
        cmd: List[str] | str,
        srun_cmd_template: str,
        count: int,
        nodes: int,
        n_gpus_per_node: int,

@@ -271,13 +97,6 @@ class SlurmLauncher:
        env_vars: Optional[Dict] = None,
        nodelist: Optional[str] = None,
        exclude: Optional[str] = None,
        apptainer_name: Optional[str] = "singularity",
        apptainer_options: Optional[Tuple[str, ...]] = (
            "--no-home",
            "--writable-tmpfs",
            "--nv",
            "--pid",
        ),
    ):
        """Submits and launches a job array with SBATCH.

        Note that a job array has one (unique) slurm name and one (unique) slurm id.
@@ -359,15 +178,13 @@ class SlurmLauncher:
        # FIXME: only for debugging, remove and replace new image
        # job_cmd = f'bash -c "pip3 install -r requirements.txt; {job_cmd}"'

        srun_cmd = SRUN_CMD_TEMPLATE.format(
        srun_cmd = srun_cmd_template.format(
            nodes=1,
            ntasks=1,
            node_id=node_id,
            n_gpus_per_node=n_gpus_per_node,
            cpus_per_task=cpus_per_task,
            mem_per_cpu=mem_per_cpu,
            apptainer_name=apptainer_name,
            apptainer_options=" ".join(apptainer_options),
            container_mounts=container_mounts or "",
            container_env_strings=env_string,
            container_image=container_image,
@@ -534,7 +351,7 @@ class SlurmLauncher:
                left.remove(slurm_id)
            if update:
                self.jobs.pop(slurm_info.slurm_id)
            time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
            time.sleep(SLURM_WAIT_CHECK_TIME_INTERVAL)

    def _update_all(self):
        """Updates the status of all jobs."""

@@ -550,7 +367,7 @@ class SlurmLauncher:


if __name__ == "__main__":
    # usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
    # usage: python -m arealite.launcher.slurm <entry_point> --config <config_path> [<additional_args>]
    config, config_file = parse_cli_args(sys.argv[2:])

    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
@@ -601,43 +418,30 @@ if __name__ == "__main__":
            )
            sglang_cmds.append(sglang_cmd)

        if sglang_cmds:
            launcher.submit_array(
                job_name="sglang-server",
                cmd=sglang_cmds,
                count=n_sglang_servers,
                nodes=n_sglang_nodes,
                n_gpus_per_node=config.cluster.n_gpus_per_node,
                cpus_per_task=config.launcher.inference_server_cpus_per_gpu
                * sglang_tp_size,
                mem_per_task=config.launcher.inference_server_mem_per_gpu
                * sglang_tp_size,
                container_image=config.cluster.gpu_infer_image,
                container_mounts=config.cluster.mount,
                env_vars=get_env_vars(
                    config.cluster.cluster_name,
                    config.launcher.inference_server_env_vars,
                ),
            )

        launcher.submit_array(
            job_name="llm_server",
            cmd=sglang_cmds,
            srun_cmd_template=config.launcher.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
            count=n_sglang_servers,
            nodes=n_sglang_nodes,
            n_gpus_per_node=config.cluster.n_gpus_per_node,
            cpus_per_task=config.launcher.inference_server_cpus_per_gpu
            * sglang_tp_size,
            mem_per_task=config.launcher.inference_server_mem_per_gpu
            * sglang_tp_size,
            container_image=config.cluster.gpu_infer_image,
            container_mounts=config.cluster.mount,
            env_vars=get_env_vars(
                config.cluster.cluster_name,
                config.launcher.inference_server_env_vars,
            ),
        )
        # Get SGLang slurm nodes, find the hosts
        name = names.gen_servers(config.experiment_name, config.trial_name)
        start = time.perf_counter()
        while True:
            sglang_addrs = name_resolve.get_subtree(name)
            if len(sglang_addrs) >= n_sglang_servers:
                logger.info(
                    f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
                )
                break

            time.sleep(1)
            if time.perf_counter() - start > SGLANG_SERVER_TIMEOUT_SECONDS:
                launcher.stop_all()
                raise TimeoutError(
                    f"Timeout waiting for SGLang servers to be ready. "
                    f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
                )
        sglang_addrs = wait_sglang_server_addrs(
            config.experiment_name,
            config.trial_name,
            n_sglang_servers,
        )

    trainer_n_nodes = n_nodes - n_sglang_nodes
    trainer_cmd_template = (

@@ -662,6 +466,7 @@ if __name__ == "__main__":
    launcher.submit_array(
        job_name="trainer",
        cmd=trainer_cmds,
        srun_cmd_template=config.launcher.srun_cmd_template or DEFAULT_SRUN_CMD_TEMPLATE,
        count=trainer_n_nodes,
        nodes=trainer_n_nodes,
        n_gpus_per_node=config.cluster.n_gpus_per_node,
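The new launcher.srun_cmd_template option replaces the hard-coded SRUN_CMD_TEMPLATE, so a deployment can swap out the default singularity wrapper. A minimal sketch (not part of the commit), assuming the custom template keeps the .format placeholders that submit_array fills; the bare-metal (container-free) variant below is illustrative only.

# Illustrative sketch: a bare-metal srun template passed via launcher.srun_cmd_template.
# Placeholders from the default template that are not needed here (container image,
# mounts, env strings) are simply omitted; str.format ignores unused keyword arguments.
CUSTOM_SRUN_CMD_TEMPLATE = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD \\
--nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
--gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
{cmd} &
"""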
@@ -0,0 +1,77 @@
import os
import time
import pathlib
import getpass
from typing import Dict

from realhf.base import names, name_resolve, logging

logger = logging.getLogger("Launcher Utils")

LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
    "TOKENIZERS_PARALLELISM": "true",
    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
    "HF_ENDPOINT": "https://hf-mirror.com",  # FIXME: move to user option
    "PYTHONPATH": str(pathlib.Path(__file__).resolve().parent.parent.parent),
}
NA132_ENVIRONS = {
    "NCCL_SOCKET_IFNAME": "bond0",
    "NCCL_NET_PLUGIN": "",
    "NCCL_IB_GID_INDEX": "3",
    "NCCL_IB_TIMEOUT": "2",
    "NCCL_IB_RETRY_CNT": "7",
    "NCCL_IB_SL": "5",
    "NCCL_IB_TC": "136",
    "NCCL_IB_HCA": "mlx5_bond",
    "NCCL_IB_QPS_PER_CONNECTION": "8",
    "NCCL_SET_THREAD_NAME": "1",
    "NCCL_DEBUG": "WARN",
    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180


def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
    """Returns the environment variables for the cluster."""
    # Parse the comma-separated "KEY=VALUE" string into a dict (empty if not given).
    if additional_env_vars:
        additional_env_vars = dict(
            item.split("=") for item in additional_env_vars.split(",")
        )
    else:
        additional_env_vars = {}
    if cluster_name == "na132":
        return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
    else:
        return {**BASE_ENVIRONS, **additional_env_vars}


def wait_sglang_server_addrs(
    experiment_name: str,
    trial_name: str,
    n_sglang_servers: int,
):
    # Get the SGLang server nodes and find the hosts.
    name = names.gen_servers(experiment_name, trial_name)
    start = time.perf_counter()
    while True:
        sglang_addrs = name_resolve.get_subtree(name)
        if len(sglang_addrs) >= n_sglang_servers:
            logger.info(
                f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
            )
            break

        time.sleep(1)
        if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
            raise TimeoutError(
                f"Timeout waiting for SGLang servers to be ready. "
                f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
            )
    return sglang_addrs
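A small usage sketch (not part of the commit) of how get_env_vars merges the comma-separated override string on top of the base environment; the cluster name and override values are made up, and the sketch assumes the override string is parsed as intended.

# Illustrative sketch: merging CLI-provided env overrides with BASE_ENVIRONS.
from arealite.utils.launcher import get_env_vars

env = get_env_vars("my_cluster", "NCCL_DEBUG=INFO,TOKENIZERS_PARALLELISM=false")
assert env["NCCL_DEBUG"] == "INFO"                # from the override string
assert env["CUDA_DEVICE_MAX_CONNECTIONS"] == "1"  # from BASE_ENVIRONS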
@@ -0,0 +1,147 @@
import subprocess
from typing import List, Optional, Literal

from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState

logger = logging.getLogger("Slurm Utils")


SQUEUE_FIELDS = [
    "JobID",
    "State",
    "SubmitTime",
    "StartTime",
    "Name",
    "NodeList",
    "UserName",
    "MaxCPUs",
    "cpus-per-task",
    "NumTasks",
    "tres-alloc",
]
STATUS_MAPPING = {
    "RUNNING": JobState.RUNNING,
    "COMPLETING": JobState.RUNNING,
    "PENDING": JobState.PENDING,
    "CANCELLED": JobState.CANCELLED,
    "FAILED": JobState.FAILED,
    "COMPLETED": JobState.COMPLETED,
    "OUT_OF_MEMORY": JobState.FAILED,
    "DEADLINE": JobState.COMPLETED,
    "TIMEOUT": JobState.COMPLETED,
}

SBATCH_SCRIPT_TEMPLATE = """#!/bin/bash
{sbatch_options}

# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
echo nodes=$nodes

nodes_array=($nodes)
echo node_array=$nodes_array

head_node=${{nodes_array[0]}}
echo head_node=$head_node

# Getting the head node IP address
head_node_ip=$(srun --overlap --mpi=pmi2 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist="$head_node" hostname --ip-address)
echo head_node_ip=$head_node_ip

# srun commands
{srun_cmds}

wait
"""

# Default srun command template, using singularity as the apptainer runtime.
DEFAULT_SRUN_CMD_TEMPLATE: str = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD \\
--nodelist=${{nodes_array[{node_id}]}} --nodes={nodes} --ntasks={ntasks} \\
--gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} --mem-per-cpu={mem_per_cpu}M \\
singularity exec --no-home --writable-tmpfs --nv --pid \\
--bind {container_mounts} \\
{container_env_strings} \\
{container_image} \\
{cmd} &
"""


def cancel_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    signal: Literal["SIGINT", "SIGKILL"] = "SIGKILL",
):
    assert (
        slurm_names is not None or slurm_ids is not None
    ), "Must specify slurm_names or slurm_ids."
    assert not (
        slurm_names and slurm_ids
    ), "Cannot specify both slurm_names and slurm_ids."
    cmd = ["scancel", "-s", signal]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    elif slurm_ids is not None:
        cmd += [",".join(str(s) for s in slurm_ids)]
    subprocess.check_call(cmd)
    logger.info(
        f"Cancelled Slurm job with signal {signal}: "
        f"slurm identifiers {slurm_names if slurm_ids is None else slurm_ids}. CMD: {cmd}"
    )


def query_jobs(
    slurm_names: Optional[List[str]] = None,
    slurm_ids: Optional[List[int]] = None,
    status: str = "all",
    delimiter: str = "__PSI__",
) -> List[JobInfo]:
    squeue_format = f":.{delimiter},".join(SQUEUE_FIELDS)
    cmd = ["squeue", "-O", squeue_format, f"-t{status}"]
    if slurm_names is not None:
        cmd += ["-n", ",".join(slurm_names)]
    if slurm_ids is not None:
        cmd += ["-j", ",".join([str(s) for s in slurm_ids])]

    output = (
        subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode("ascii").strip()
    )
    rs = []
    for line in output.split("\n")[1:]:
        job_id, state, submit_time, start_time, slurm_name, nodelist, *_ = line.split(
            delimiter
        )
        rs.append(
            JobInfo(
                name=slurm_name,
                state=STATUS_MAPPING[state],
                host=nodelist,
                submit_time=submit_time,
                start_time=start_time,
                slurm_id=int(job_id.strip()),
            )
        )
    return rs


def parse_slurm_nodelist(nodelist: str) -> List[str]:
    return (
        subprocess.check_output(
            [
                "scontrol",
                "show",
                "hostnames",
                nodelist,
            ]
        )
        .decode("utf-8")
        .strip()
        .split("\n")
    )


def get_slurm_host_ip(node: str):
    try:
        cmd = f"srun --overlap --mpi=pmi2 --immediate=1 --nodes=1 --ntasks=1 -n1 -c1 --mem=10M --nodelist={node} hostname --ip-address"
        return subprocess.check_output(cmd.split(" ")).decode("utf-8").strip()
    except subprocess.CalledProcessError:
        logger.warning(f"Getting the Slurm host IP for node {node} failed.")
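A brief usage sketch (not part of the commit) of the helpers moved into arealite.utils.slurm; the Slurm job name below is hypothetical.

# Illustrative sketch: poll a submitted job array and cancel it if anything failed.
from arealite.utils.slurm import cancel_jobs, query_jobs
from realhf.scheduler.client import JobState

job_names = ["trainer:my_experiment_my_trial"]  # hypothetical Slurm job name
infos = query_jobs(slurm_names=job_names)
if any(info.state == JobState.FAILED for info in infos):
    cancel_jobs(slurm_names=job_names, signal="SIGKILL")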
@@ -106,8 +106,8 @@ def boba_reward_fn(
    return label


def main_grpo():
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
def main_grpo(argv):
    config, _ = load_expr_config(argv, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))

@@ -238,4 +238,4 @@ def main_grpo():


if __name__ == "__main__":
    main_grpo()
    main_grpo(sys.argv[1:])
@@ -75,8 +75,8 @@ def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **k
    return int(sol.strip() == ans.strip())


def main_grpo():
    config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
def main_grpo(argv):
    config, _ = load_expr_config(argv, GRPOConfig)
    config: GRPOConfig

    rank = int(os.getenv("RANK"))

@@ -250,4 +250,4 @@ def main_grpo():


if __name__ == "__main__":
    main_grpo()
    main_grpo(sys.argv[1:])
@@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
    return process_gsm8k_sft_dataset(dataset, tokenizer)


def main_sft():
    config, _ = load_expr_config(sys.argv[1:], SFTConfig)
def main_sft(argv):
    config, _ = load_expr_config(argv, SFTConfig)
    config: SFTConfig

    rank = int(os.getenv("RANK"))

@@ -120,4 +120,4 @@ def main_sft():


if __name__ == "__main__":
    main_sft()
    main_sft(sys.argv[1:])