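"""Ray-based launcher for AReaL-lite (arealite) experiments.

Launches SGLang inference servers and trainer processes as Ray remote tasks,
using placement groups to pin tasks to nodes and to recover the master address
and port needed for torch.distributed initialization.

Usage:
    python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>]
"""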
import getpass
import importlib.util
import os
import pathlib
import sys
import time
from typing import Dict, List, Optional

import ray
import ray.exceptions
from ray.runtime_env import RuntimeEnv
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from arealite.api.cli_args import (
    ClusterSpecConfig,
    LauncherConfig,
    SGLangConfig,
    parse_cli_args,
    to_structured_cfg,
)
from arealite.api.io_struct import AllocationMode, AllocationType
from arealite.utils.launcher import (
    get_env_vars,
    validate_config_for_distributed_launcher,
    wait_sglang_server_addrs,
)
from arealite.utils.ray import get_placement_group_master_ip_and_port
from realhf.base import logging, name_resolve, names
from realhf.scheduler.client import JobException, JobState

logger = logging.getLogger("RayLauncher")

RAY_WAIT_CHECK_TIME_INTERVAL = 5  # seconds
DEFAULT_MAIN_FUNC_NAME = "main"

def run_func(file_path, function_name, *args, **kwargs):
    """Import `file_path` as a module and call `function_name(*args, **kwargs)`."""
    # Convert the file path to a module name
    module_name = file_path.replace("/", "_").replace(".", "_")

    # Load the module from the file path
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    # Get the function and execute it
    try:
        function = getattr(module, function_name)
    except AttributeError as e:
        raise ValueError(
            f"Function '{function_name}' not found in module '{module_name}'. "
            f"Please ensure the name of the main function in your entry point "
            f"is '{function_name}'."
        ) from e
    return function(*args, **kwargs)

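# A minimal usage sketch for `run_func` (hypothetical path and arguments, for
# illustration only): run_func("examples/my_trainer.py", "main", ["--config", "cfg.yaml"])
# imports the file as a module named "examples_my_trainer_py" and then calls its
# `main(["--config", "cfg.yaml"])`.
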
class RayLauncher:
    def __init__(self, experiment_name: str, trial_name: str, fileroot: str):
        self.experiment_name = experiment_name
        self.trial_name = trial_name
        self.fileroot = fileroot

        # job_name to ray future
        self.jobs = {}

    @property
    def run_name(self):
        return f"{self.experiment_name}_{self.trial_name}"

    def log_path_of(self, job_name: str) -> str:
        log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}"
        os.makedirs(log_path, exist_ok=True)
        return os.path.join(log_path, f"{job_name}.log")

    def submit(
        self,
        job_name: str,
        file_path: str,
        func_name: str,
        args: List[str],  # arguments to pass to the function
        gpus: int,
        cpus: int,
        mem: int,  # MB
        env_vars: Optional[Dict] = None,
        placement_group: Optional[PlacementGroup] = None,
        bundle_index: int = -1,
        kwargs: Optional[
            Dict[str, str]
        ] = None,  # keyword arguments to pass to the function
    ):
        """Submit a single function-call job to Ray and return its future."""
        if kwargs is None:
            kwargs = {}
        runtime_env = RuntimeEnv(
            env_vars=env_vars or dict(),
        )
        scheduling_strategy = (
            PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_bundle_index=bundle_index,
                placement_group_capture_child_tasks=True,
            )
            if placement_group is not None
            else "DEFAULT"
        )
        future = ray.remote(
            num_cpus=cpus,
            num_gpus=gpus,
            memory=mem * 1024 * 1024,  # Convert MB to bytes
            runtime_env=runtime_env,
            scheduling_strategy=scheduling_strategy,
        )(run_func).remote(file_path, func_name, *args, **kwargs)
        self.jobs[job_name] = future
        return future

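    # A single-job sketch (hypothetical values, for illustration only): the call below
    # would wrap `run_func` in a Ray task reserving 1 GPU, 4 CPUs, and 16384 MB, then
    # invoke `main(["--config", "cfg.yaml"])` from "examples/my_trainer.py":
    #
    #   launcher.submit(
    #       job_name="trainer:0",
    #       file_path="examples/my_trainer.py",
    #       func_name="main",
    #       args=[["--config", "cfg.yaml"]],
    #       gpus=1, cpus=4, mem=16384,
    #   )
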
    def submit_array(
        self,
        job_name: str,
        file_path: str,
        func_name: str,
        count: int,
        nodes: int,
        list_args: List[List],
        gpus_per_task: int,
        cpus_per_task: int,
        mem_per_task: int,  # MB
        list_kwargs: List[Dict] | None = None,
        env_vars: Optional[Dict] = None,
        amend_torch_dist_env: bool = False,
    ):
        """Submit an array of jobs to Ray using a placement group.

        Note: We use `ray.remote` instead of `ray job submit` because `ray job submit`
        does not support placement groups and cannot pin a job to a specific node,
        so we would not know the job IP addresses required for torch distributed
        initialization.
        """
        if count % nodes != 0:
            raise ValueError(
                f"Count {count} is not divisible by nodes {nodes}. "
                "Please ensure that count is a multiple of nodes."
            )
        assert (
            len(list_args) == count
        ), f"Length of list_args {len(list_args)} does not match count {count}."
        if list_kwargs is not None:
            assert (
                len(list_kwargs) == count
            ), f"Length of list_kwargs {len(list_kwargs)} does not match count {count}."

        tasks_per_node = count // nodes
        gpus_per_node = gpus_per_task * tasks_per_node
        cpus_per_node = cpus_per_task * tasks_per_node
        mem_per_node = mem_per_task * tasks_per_node

        placement_group = ray.util.placement_group(
            bundles=[
                {
                    "CPU": cpus_per_node,
                    "GPU": gpus_per_node,
                    "memory": mem_per_node * 1024 * 1024,  # Convert MB to bytes
                }
            ]
            * nodes,
            strategy="STRICT_SPREAD",
        )
        try:
            ray.get(placement_group.ready(), timeout=30)
        except ray.exceptions.GetTimeoutError as e:
            logger.error(
                "Ray placement group timed out. Please check whether the resource "
                "requirement of your experiment exceeds the available resources in the cluster. \n"
                f"ray.nodes(): {ray.nodes()} \n"
                f"Placement group bundles: "
                f"cpus_per_node={cpus_per_node}, gpus_per_node={gpus_per_node}, "
                f"mem_per_node={mem_per_node}MB, nodes={nodes}"
            )
            raise e

        if amend_torch_dist_env:
            host_ip, port = get_placement_group_master_ip_and_port(placement_group)
            logger.info(
                f"Amending torch distributed env vars: "
                f"MASTER_ADDR={host_ip}, PORT={port}"
            )

        futures = []
        for i in range(count):
            args = list_args[i]
            kwargs = list_kwargs[i] if list_kwargs is not None else {}

            # Manage environment variables
            env_vars = env_vars or {}
            if "CUDA_VISIBLE_DEVICES" in env_vars:
                logger.warning(
                    "Setting CUDA_VISIBLE_DEVICES before running ray jobs may result in unexpected behavior."
                )

            node_id = i // tasks_per_node
            _env_vars = {
                **env_vars,
            }

            if amend_torch_dist_env:
                assert gpus_per_task == 1
                # NOTE: Here we only provide environment variables for torch distributed
                # initialization, and LOCAL_RANK for torch.device.
                # Other environment variables automatically set by torchrun are not set,
                # and they should never be accessed in trainer code.
                _env_vars.update(
                    {
                        "RANK": str(i),
                        "WORLD_SIZE": str(count),
                        # Ray will automatically isolate CUDA_VISIBLE_DEVICES for each GPU
                        "LOCAL_RANK": "0",
                        "MASTER_ADDR": str(host_ip),
                        "MASTER_PORT": str(port),
                    }
                )
            future = self.submit(
                job_name=f"{job_name}:{i}",
                file_path=file_path,
                func_name=func_name,
                args=args,
                gpus=gpus_per_task,
                cpus=cpus_per_task,
                mem=mem_per_task,
                env_vars=_env_vars,
                placement_group=placement_group,
                bundle_index=node_id,
                kwargs=kwargs,
            )
            futures.append(future)

        return futures

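    # A resource-accounting sketch (made-up numbers): submit_array(count=16, nodes=2,
    # gpus_per_task=1, cpus_per_task=4, mem_per_task=32768, ...) gives
    # tasks_per_node = 16 // 2 = 8, so each of the two STRICT_SPREAD bundles reserves
    # {"CPU": 32, "GPU": 8, "memory": 8 * 32768 MB}; tasks 0-7 land in bundle 0 and
    # tasks 8-15 in bundle 1. With amend_torch_dist_env=True each task additionally
    # receives RANK=i, WORLD_SIZE=16, LOCAL_RANK=0, and the placement group master's
    # MASTER_ADDR/MASTER_PORT.
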
    def stop(self, job_name: str, force: bool = False):
        """Stop a job by name."""
        if job_name in self.jobs:
            future = self.jobs[job_name]
            try:
                ray.cancel(future, force=force)
            except Exception as e:
                logger.error(f"Failed to cancel job {job_name}: {e}")
                return
            self.jobs.pop(job_name, None)
            logger.info(f"Job {job_name} stopped.")
        else:
            logger.warning(f"Job {job_name} not found in running jobs.")

    def stop_all(self, force: bool = False):
        """Stop all jobs."""
        for job_name in list(self.jobs.keys()):
            self.stop(job_name, force=force)
        logger.info("All jobs stopped.")
        self.jobs.clear()

    def wait(
        self, check_status=(JobState.FAILED,), remove_status=(JobState.COMPLETED,)
    ):
        """Poll the status of all jobs every RAY_WAIT_CHECK_TIME_INTERVAL seconds.

        When a ray job returns, its status becomes JobState.COMPLETED; when it raises
        an error, its status becomes JobState.FAILED. If any job reaches a status in
        `check_status`, all jobs are stopped at once. Jobs whose status is in
        `remove_status` are removed from the job list. This method returns when all
        jobs have been removed from the job list, or when some job hits a status in
        `check_status`.
        """
        for status in list(check_status) + list(remove_status):
            assert status in [
                JobState.COMPLETED,
                JobState.FAILED,
            ], "In RayLauncher.wait, we only check completed or failed jobs."
        logger.info(f"Waiting for {len(self.jobs)} jobs.")
        while self.jobs:
            job_status = {}
            for job_name, future in list(self.jobs.items()):
                try:
                    r = ray.get(future, timeout=0.1)
                    logger.info(f"Job {job_name} completed with result: {r}")
                    job_status[job_name] = JobState.COMPLETED
                except ray.exceptions.RayTaskError as e:
                    logger.error(f"Job {job_name} failed with error: {e}.")
                    job_status[job_name] = JobState.FAILED
                except ray.exceptions.GetTimeoutError:
                    continue

            for job_name, status in job_status.items():
                if status in check_status:
                    logger.info(f"Job {job_name} is {status}, stopping all jobs.")
                    self.stop_all(force=True)
                    return
                if status in remove_status:
                    logger.info(f"Job {job_name} is {status}, removed.")
                    self.jobs.pop(job_name)

            time.sleep(RAY_WAIT_CHECK_TIME_INTERVAL)

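# A typical driver flow (a sketch with hypothetical names; `ray_main` below is the
# real entry point):
#
#   launcher = RayLauncher("my_exp", "trial0", fileroot="/tmp/areal")
#   launcher.submit_array(job_name="trainer", file_path="train.py", func_name="main",
#                         count=8, nodes=1, list_args=[[argv] for _ in range(8)],
#                         gpus_per_task=1, cpus_per_task=4, mem_per_task=32768,
#                         amend_torch_dist_env=True)
#   launcher.wait()  # stops everything if a job fails; returns when all jobs complete
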
def ray_main():
    # usage: python -m arealite.launcher.ray <entry_point> --config <config_path> [<additional_args>]
    ray.init()
    config, config_file = parse_cli_args(sys.argv[2:])
    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
    config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
    validate_config_for_distributed_launcher(config)

    name_resolve.reconfigure(config.cluster.name_resolve)
    name_resolve.clear_subtree(
        names.trial_root(
            experiment_name=config.experiment_name, trial_name=config.trial_name
        )
    )

    n_nodes = config.cluster.n_nodes
    n_gpus_per_node = config.cluster.n_gpus_per_node
    launcher = RayLauncher(
        experiment_name=config.experiment_name,
        trial_name=config.trial_name,
        fileroot=config.cluster.fileroot,
    )
    allocation_mode = config.allocation_mode
    allocation_mode = AllocationMode.from_str(allocation_mode)
    sglang_cmds = []
    sglang_addrs = []
    n_sglang_nodes = 0
    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
        # Launcher should launch SGLang servers according to allocation mode.
        sglang_tp_size = allocation_mode.gen_tp_size
        n_sglang_servers = allocation_mode.gen_dp_size
        n_sglang_nodes = allocation_mode.gen_world_size // n_gpus_per_node

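        # A concrete sketch with made-up numbers: if the generation allocation has
        # world size 8 with TP=2 (so gen_dp_size == 4) and n_gpus_per_node == 8, this
        # launches 4 SGLang servers of 2 GPUs each on 8 // 8 = 1 node.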
        base_seed = config.sglang.random_seed
        sglang_args_list = [
            [sys.argv[2:] + [f"sglang.random_seed={base_seed + i}"]]
            for i in range(n_sglang_servers)
        ]
        sglang_entry_point = str(
            pathlib.Path(__file__).resolve().parent.joinpath("sglang_server.py")
        )
        launcher.submit_array(
            job_name="llm_server",
            file_path=sglang_entry_point,
            func_name=DEFAULT_MAIN_FUNC_NAME,
            count=n_sglang_servers,
            nodes=n_sglang_nodes,
            list_args=sglang_args_list,
            gpus_per_task=sglang_tp_size,
            cpus_per_task=config.launcher.inference_server_cpus_per_gpu
            * sglang_tp_size,
            mem_per_task=config.launcher.inference_server_mem_per_gpu * sglang_tp_size,
            env_vars=get_env_vars(
                config.cluster.cluster_name,
                config.launcher.inference_server_env_vars,
            ),
        )
        # Get SGLang server addresses via name_resolve
        try:
            sglang_addrs = wait_sglang_server_addrs(
                config.experiment_name,
                config.trial_name,
                n_sglang_servers,
            )
        except TimeoutError as e:
            launcher.stop_all(force=True)
            raise e

    trainer_n_nodes = n_nodes - n_sglang_nodes
    trainer_entry_point = sys.argv[1]
    n_trainer_processes = trainer_n_nodes * config.cluster.n_gpus_per_node
    trainer_args_list = [[sys.argv[2:]] for _ in range(n_trainer_processes)]
    if not config.server_only:
        # In Ray, we launch the trainer at the granularity of processes (1 GPU per process).
        # We amend environment variables similar to torchrun to ensure correct
        # initialization of torch distributed.
        launcher.submit_array(
            job_name="trainer",
            file_path=trainer_entry_point,
            func_name=DEFAULT_MAIN_FUNC_NAME,
            count=n_trainer_processes,
            nodes=trainer_n_nodes,
            list_args=trainer_args_list,
            gpus_per_task=1,
            cpus_per_task=config.launcher.trainer_cpus_per_gpu,
            mem_per_task=config.launcher.trainer_mem_per_gpu,
            env_vars=dict(
                **get_env_vars(
                    config.cluster.cluster_name,
                    config.launcher.trainer_env_vars,
                ),
                AREAL_LLM_SERVER_ADDRS=",".join(sglang_addrs),
            ),
            amend_torch_dist_env=True,
        )

    try:
        launcher.wait(check_status=(JobState.COMPLETED, JobState.FAILED))
    except (KeyboardInterrupt, JobException, TimeoutError) as e:
        launcher.stop_all(force=True)
        raise e

if __name__ == "__main__":
    # usage: python -m arealite.launcher.ray \
    #   <entry_point> --config <config_path> [<additional_args>] \
    #   launcher.ray.main_func_name=<main_func_name_in_entry_point>
    ray_main()