AReaL/arealite/utils/launcher.py

77 lines
2.6 KiB
Python

import os
import time
import pathlib
import getpass
from typing import Dict
from realhf.base import names, name_resolve, logging
logger = logging.getLogger("Launcher Utils")
LOCAL_CACHE_DIR = "/tmp/arealite"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
BASE_ENVIRONS = {
"TOKENIZERS_PARALLELISM": "true",
"PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
"TRITON_CACHE_DIR": TRITON_CACHE_PATH,
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
"HF_ENDPOINT": "https://hf-mirror.com", # FIXME: move to user option
"PYTHONPATH": pathlib.Path(__file__).resolve().parent.parent.parent,
}
NA132_ENVIRONS = {
"NCCL_SOCKET_IFNAME": "bond0",
"NCCL_NET_PLUGIN": "",
"NCCL_IB_GID_INDEX": "3",
"NCCL_IB_TIMEOUT": "2",
"NCCL_IB_RETRY_CNT": "7",
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG": "WARN",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
}
SGLANG_SERVER_WAIT_TIMEOUT_SECONDS = 180
def get_env_vars(cluster_name: str, additional_env_vars: str = None) -> Dict[str, str]:
"""Returns the environment variables for the cluster."""
additional_env_vars = {}
if additional_env_vars:
additional_env_vars = dict(
item.split("=") for item in additional_env_vars.split(",")
)
if cluster_name == "na132":
return {**BASE_ENVIRONS, **NA132_ENVIRONS, **additional_env_vars}
else:
return {**BASE_ENVIRONS, **additional_env_vars}
def wait_sglang_server_addrs(
experiment_name: str,
trial_name: str,
n_sglang_servers: int,
):
# Get SGLang slurm nodes, find the hosts
name = names.gen_servers(experiment_name, trial_name)
start = time.perf_counter()
while True:
sglang_addrs = name_resolve.get_subtree(name)
if len(sglang_addrs) >= n_sglang_servers:
logger.info(
f"Found {n_sglang_servers} SGLang servers: {', '.join(sglang_addrs)}"
)
break
time.sleep(1)
if time.perf_counter() - start > SGLANG_SERVER_WAIT_TIMEOUT_SECONDS:
raise TimeoutError(
f"Timeout waiting for SGLang servers to be ready. "
f"Expected {n_sglang_servers} servers, found {len(sglang_addrs)}."
)
return sglang_addrs