# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import copy
import dataclasses
import enum
import functools
import getpass
import json
import os
import re
import sys
import time
import traceback
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import colorama
import ray
import ray.util.queue as rq
import torch
from omegaconf import OmegaConf

import realhf.api.core.system_api as system_api
from realhf.base import constants, gpu_utils, logging, name_resolve, names
from realhf.base.cluster import spec as cluster_spec
from realhf.system import WORKER_TYPES, load_worker, worker_base, worker_control
from realhf.system.worker_base import WorkerServerStatus as Wss

CONNECTION_RETRY_AFTER_SECONDS = 360

logger = logging.getLogger("controller", "colored")


@dataclasses.dataclass
class TrialStatus:
    experiment_name: str
    trial_name: str
    running_workers: Dict[str, List[str]] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class TrialHistory:
    experiment_name: str
    trial_name: str
    age_days: int


class ControllerExitStatus(enum.Enum):
    SUCCESS = 0
    TIMEOUT = 1
    INTERRUPTED = 9
    FAIL = 101
    LOST = 102
    UNKNOWN = 404


class Controller:

    def __init__(
        self, experiment_name, trial_name, panel: worker_base.WorkerControlPanel
    ):
        assert "_" not in experiment_name, (
            f"_ not allowed in experiment_name (args: -e) "
            f"{experiment_name}, use '-' instead."
        )
        assert (
            "_" not in trial_name
        ), f"_ not allowed in trial_name (args: -f) {trial_name}, use '-' instead."
        self.experiment_name = experiment_name
        self.trial_name = trial_name

        logger.info("Experiment: %s %s", self.experiment_name, self.trial_name)

        self.__control = panel
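
    # Illustrative driver sequence (a sketch; "my-exp" and "my-trial" are
    # hypothetical names, `panel` is a WorkerControlPanel wired to the same
    # experiment/trial, and `experiment` is a registered system_api.Experiment):
    #
    #     controller = Controller("my-exp", "my-trial", panel)
    #     controller.start(experiment)   # configure workers, run all setups, wait
    #
    # `reconnect()` re-attaches to the workers of an already-running trial
    # before calling `interrupt()` or `stop()`.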

    def reconnect(self):
        """Automatically reconnect to workers and list all jobs to the
        scheduler."""
        self.__control.auto_connect()

    def __check_consistent_scheduling(
        self,
        scheduling: system_api.ExperimentScheduling,
        setup: system_api.ExperimentConfig,
        verbose=False,
    ):
        # Scheduling and connecting to workers.
        workers_configs = [
            (k, getattr(setup, k), getattr(scheduling, k)) for k in WORKER_TYPES
        ]

        # Sanity check for scheduling and configuration.
        for _, worker_setups, schedules in workers_configs:
            if not isinstance(schedules, List):
                schedules = [schedules]
            if len(worker_setups) != sum(s.count for s in schedules):
                raise ValueError(
                    f"Configuration and scheduling mismatch. "
                    f"Number of worker configurations: {len(worker_setups)}, "
                    f"Scheduling configs: {schedules}."
                )

        for name, config, schedule in workers_configs:
            count = (
                sum([s.count for s in schedule])
                if isinstance(schedule, list)
                else schedule.count
            )
            if len(config) != count:
                logger.error(
                    "Scheduling and config mismatch, interrupting all workers."
                )
                self.interrupt()
                raise IndexError(
                    f"Configuration has {len(config)} {name}, {count} scheduled."
                )
            if verbose:
                logger.info(f"Configuration has {len(config)} {name}.")

    def start(self, experiment: system_api.Experiment, ignore_worker_error=False):
        if ignore_worker_error:
            check_worker_status = ()
            remove_worker_status = (
                Wss.COMPLETED,
                Wss.ERROR,
                Wss.LOST,
                Wss.UNKNOWN,
                Wss.PAUSED,
            )
        else:
            check_worker_status = (Wss.ERROR, Wss.LOST, Wss.UNKNOWN)
            remove_worker_status = (Wss.COMPLETED, Wss.PAUSED)

        scheduling = experiment.scheduling_setup()
        raw_experiment = copy.deepcopy(experiment)
        setups = experiment.initial_setup()
        if not isinstance(setups, list):
            setups = [setups]

        # Sanity check before launching workers.
        for i, setup in enumerate(setups):
            self.__check_consistent_scheduling(scheduling, setup, verbose=(i == 0))

        worker_counts = [(k, len(getattr(setups[0], k))) for k in WORKER_TYPES]

        name_resolve.add(
            names.trial_registry(self.experiment_name, self.trial_name),
            value=datetime.now().strftime("%Y%m%d"),
            delete_on_exit=False,
            replace=True,
        )
        name_resolve.add(
            names.worker_status(
                experiment_name=self.experiment_name,
                trial_name=self.trial_name,
                worker_name="ctl",
            ),
            value="READY",
            delete_on_exit=True,
        )

        while True:
            try:
                logger.info("Connecting to workers...")
                self.__control.connect(
                    [
                        self.__control.name(name, i)
                        for name, count in worker_counts
                        for i in range(count)
                    ],
                    progress=True,
                    timeout=CONNECTION_RETRY_AFTER_SECONDS,
                    raises_timeout_error=True,
                )
                break

            except TimeoutError:
                logger.info("Connecting to workers timed out. Retrying...")
            except KeyboardInterrupt as e:
                logger.info("Interrupted by user. Stopping all and exiting...")
                raise e

        name_resolve.delete(
            names.worker_status(
                experiment_name=self.experiment_name,
                trial_name=self.trial_name,
                worker_name="ctl",
            )
        )

        # If a log exists, find the last failed setup and run it.
        start_idx = 0
        prev_logfile = os.path.join(
            constants.LOG_ROOT, self.experiment_name, self.trial_name, "ctl-0"
        )
        if os.path.exists(prev_logfile):
            with open(prev_logfile, "r") as f:
                for l in f.readlines():
                    match = re.search(r"Entering setup (\d+)/(\d+)", l)
                    if match and int(match.group(2)) == len(setups):
                        last_end_idx = int(match.group(1)) - 1
                        if last_end_idx < len(setups) - 1:
                            start_idx = last_end_idx
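
        # Illustrative example of the resume logic above (log content is
        # hypothetical): a previous "ctl-0" log line such as
        #   "... Entering setup 2/3 ..."
        # with 3 setups in this run sets start_idx = 1, so the loop below
        # resumes from setup 2/3. Matches whose total setup count differs
        # from len(setups) are ignored.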

        # NOTE: Since worker processes are created and killed by the scheduler,
        # the controller cannot restart a dead worker when an error occurs,
        # and it's impossible to continue the experiment when any of the multiple setups fails.
        # We can only relaunch the entire experiment in this case.
        # In particular, while it seems possible to continue the experiment when
        # an OOM error occurs, OOM will cause NCCL communication to get stuck (e.g., send/recv),
        # which will finally throw a C++ exception in the watchdog thread after reaching the timeout.
        # We cannot catch this exception, so OOM is irrecoverable.
        for offset, setup in enumerate(setups[start_idx:]):
            i = offset + start_idx

            s = f" Entering setup {i+1}/{len(setups)}... ".center(80, "#")
            logger.info(colorama.Fore.RED + "#" * len(s) + colorama.Style.RESET_ALL)
            logger.info(colorama.Fore.RED + s + colorama.Style.RESET_ALL)
            logger.info(colorama.Fore.RED + "#" * len(s) + colorama.Style.RESET_ALL)

            # Configure workers.
            setup.set_worker_information(
                experiment_name=self.experiment_name, trial_name=self.trial_name
            )
            try:
                for name in WORKER_TYPES:
                    worker_infos = [x.worker_info for x in getattr(setup, name)]
                    logger.info(f"Configuring Workers: {name}...")

                    self.__control.group_request(
                        "configure",
                        worker_names=[
                            self.__control.name(name, i)
                            for i in range(len(worker_infos))
                        ],
                        worker_kwargs=[
                            dict(worker_info=wi, setup_id=i) for wi in worker_infos
                        ],
                        progress=True,
                    )
            except Exception as e:
                logger.error(f"Configuring Failed: {e}. Exiting Workers.")
                logger.error(traceback.format_exc())
                self.interrupt(wait_timeout=120)
                raise e

            logger.info("Start workers...")
            self.__control.group_request("start")
            logger.info("Started.")
            try:
                self.wait(
                    timeout=None,
                    check_status=check_worker_status,
                    remove_status=remove_worker_status,
                )
            except worker_base.WorkerException as e:
                logger.error(e)
                self.interrupt(wait_timeout=30)
            except KeyboardInterrupt:
                logger.info("Interrupted.")
                self.interrupt(wait_timeout=30)

            s = f" Finishing setup {i+1}/{len(setups)}, pausing workers... ".center(
                80, "#"
            )
            logger.info(colorama.Fore.RED + s + colorama.Style.RESET_ALL)

        logger.info(
            colorama.Fore.YELLOW
            + colorama.Style.BRIGHT
            + "\033[1m"
            + "=" * 80
            + colorama.Style.RESET_ALL
        )
        logger.info(
            colorama.Fore.YELLOW
            + colorama.Style.BRIGHT
            + "\033[1m"
            + (
                f" All {len(setups)} setups are done. "
                "You've done an excellent job! Congrats! "
            ).center(80, "=")
            + colorama.Style.RESET_ALL
        )
        logger.info(
            colorama.Fore.YELLOW
            + colorama.Style.BRIGHT
            + "\033[1m"
            + "=" * 80
            + colorama.Style.RESET_ALL
        )
        logger.info("Exiting all workers...")
        self.__control.group_request("exit")

    def wait(
        self,
        timeout: Optional[int],
        check_status: Tuple[Wss, ...],
        remove_status: Tuple[Wss, ...],
    ):
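        # Polling semantics (descriptive note): every 10 seconds the control
        # panel is pulsed; a worker whose status is in `check_status` aborts
        # the wait by raising WorkerException, a status in `remove_status`
        # removes the worker from the waiting list, and any other status just
        # updates the recorded state. The call returns once the list is empty.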
        deadline = None if timeout is None else time.time() + timeout
        left = set(self.__control.worker_names)
        num_jobs_left = len(left)
        logger.info(f"Waiting for {num_jobs_left} jobs.")
        current_status = {name: Wss.UNKNOWN for name in self.__control.worker_names}
        while len(left) > 0:
            logger.debug(
                f"JOBS LEFT: {[str(len([l for l in left if job_type in l])) + ' ' + job_type for job_type in set([job_id.split('/')[0] for job_id in left])]}"
            )
            if len(left) < num_jobs_left:
                num_jobs_left = len(left)
                logger.info(f"Waiting for {num_jobs_left} jobs.")
            if deadline is not None and time.time() > deadline:
                raise TimeoutError(
                    f"Timeout waiting for {self.experiment_name, self.trial_name}: {', '.join(sorted(left))}"
                )
            for worker_name, worker_status in self.__control.pulse().items():
                if worker_status in check_status:
                    raise worker_base.WorkerException(
                        worker_name, worker_status, "experiment is running."
                    )
                if worker_status in remove_status:
                    if worker_name in current_status:
                        logger.debug(
                            f"Worker {worker_name} is {worker_status}. Removed from waiting list."
                        )
                        current_status.pop(worker_name)
                    else:
                        pass
                else:
                    if current_status.get(worker_name, None) != worker_status:
                        current_status.update({worker_name: worker_status})
                        logger.debug(
                            f"Update worker status: {worker_name} -> {worker_status}"
                        )

            left = set(current_status.keys())
            time.sleep(10)

    def stop(self):
        """Stop the experiment.

        Note:
            This method assumes that the controller and scheduler are connected
            to the correct workers. To ensure this, call controller.reconnect
            before you call controller.stop.
        """
        raise NotImplementedError()

    def interrupt(self, wait_timeout=120):
        """Interrupt the experiment."""
        logger.info("Interrupting experiment")
        self.__control.group_request("interrupt", wait_response=False)
        try:
            self.wait(
                timeout=wait_timeout,
                check_status=(),
                remove_status=(
                    Wss.ERROR,
                    Wss.LOST,
                    Wss.COMPLETED,
                    Wss.INTERRUPTED,
                ),
            )
        except TimeoutError:
            raise RuntimeError(f"Failed to interrupt workers, timeout={wait_timeout}.")


def run_ray_worker(
    worker_type,
    idx,
    world_size,
    experiment_name,
    trial_name,
    comm: Tuple[rq.Queue, rq.Queue],
):

    constants.set_experiment_trial_names(experiment_name, trial_name)

    import realhf.api.core.system_api as system_api
    from realhf.api.quickstart.entrypoint import (
        QUICKSTART_CONFIG_CLASSES,
        QUICKSTART_EXPR_CACHE_PATH,
    )
    from realhf.base import importing

    if os.path.exists(QUICKSTART_EXPR_CACHE_PATH):
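        # The cache file is presumably written by the quickstart entrypoint.
        # An illustrative (hypothetical) "{experiment_name}_{trial_name}.json"
        # might look like:
        #   {
        #     "usercode_path": "/path/to/user_experiment.py",
        #     "config_name": "my-config",
        #     "args": {...}
        #   }
        # Only the three keys read below are required.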
        for exp_cache in os.listdir(QUICKSTART_EXPR_CACHE_PATH):
            target_cache_name = f"{experiment_name}_{trial_name}.json"
            if exp_cache != target_cache_name:
                continue
            cache_file = os.path.join(QUICKSTART_EXPR_CACHE_PATH, target_cache_name)
            with open(cache_file, "r") as f:
                cache = json.load(f)
            usercode_path = cache["usercode_path"]
            exp_cls_args = OmegaConf.create(cache["args"])
            config_name = cache["config_name"]
            # Import user code to register quickstart experiments.
            importing.import_usercode(usercode_path, "_realhf_user_code")
            # Register the internal experiment.
            exp_cls = QUICKSTART_CONFIG_CLASSES[config_name]
            system_api.register_experiment(
                experiment_name, functools.partial(exp_cls, **exp_cls_args)
            )

    # Isolate GPUs within the same slurm job, among different job steps.
    if torch.cuda.is_initialized():
        raise RuntimeError(
            "CUDA already initialized before isolating CUDA devices. This should not happen."
        )
    gpu_utils.isolate_cuda_device(
        worker_type,
        idx,
        world_size,
        experiment_name,
        trial_name,
    )
    if os.environ.get("CUDA_VISIBLE_DEVICES", None):
        logger.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])

    # NOTE: Importing these will initialize DeepSpeed/CUDA devices.
    # profiler.import_profiler_registers()
    import realhf.impl.dataset
    import realhf.impl.model
    import realhf.system

    worker_name = f"{worker_type}/{idx}"
    server = worker_control.make_server(
        "ray",
        worker_name=worker_name,
        experiment_name=experiment_name,
        trial_name=trial_name,
        comm=comm,
    )
    worker = load_worker(worker_type)(server=server)
    try:
        worker.run()
    except Exception as e:
        logging.error("Worker %s failed with exception: %s", worker_name, e)
        logging.error(traceback.format_exc())
        raise e


class RayController:
    """A controller that uses Ray to manage workers.

    It uses the basic Controller to configure workers. Besides, it
    launches all remote workers using Ray, instead of submitting them to
    the scheduler.
    """

    def __init__(self, experiment_name, trial_name):
        # The base controller is lazily initialized when launching workers.
        self.__experiment_name = experiment_name
        self.__trial_name = trial_name
        self.__base_controller = None

        self.__workers_reply_comm = None
        self.__workers_request_comm = None
        self.__workers_ref = None

    def _launch_workers(
        self, worker_counts: List[Tuple[str, int, system_api.TasksGroup]]
    ):
        # Launch remote workers.
        logger.info("Launching remote workers using Ray...")
        self.__workers_ref: Dict[str, ray.ObjectRef] = {}
        self.__workers_request_comm: Dict[str, rq.Queue] = dict()
        self.__workers_reply_comm: Dict[str, rq.Queue] = dict()

        # Count the total required resources and check whether Ray currently has enough of them.
        cpu = gpu = mem = 0.0
        for worker_type, _, schedule in worker_counts:
            if not isinstance(schedule, List):
                schedule = [schedule]
            for s in schedule:
                cpu += s.scheduling.cpu * s.count
                gpu += s.scheduling.gpu * s.count
                mem += s.scheduling.mem * s.count / 1024  # in GB
        available_resources = ray.available_resources()
        acpu = available_resources.get("CPU", 0)
        agpu = available_resources.get("GPU", 0)
        amem = available_resources.get("memory", 0) / 1024**3
        if acpu < cpu or agpu < gpu or amem < mem:
            logger.critical(
                f"Ray does not have enough resources to launch workers. "
                f"Required: {cpu} CPU, {gpu} GPU, {mem:.2f} GB memory. "
                f"Available: {acpu} CPU, {agpu} GPU, {amem:.2f} GB memory. "
                f"Please launch more Ray nodes, otherwise the experiment will get stuck."
            )
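
        # Worked example (hypothetical numbers): a schedule of 8 GPU workers,
        # each requesting cpu=4, gpu=1, mem=10240 (MB), contributes
        # cpu += 32, gpu += 8, and mem += 8 * 10240 / 1024 = 80 GB, which is
        # then compared against ray.available_resources() above.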

        # Launch ray jobs.
        for worker_type, count, schedule in worker_counts:
            all_schedules: List[system_api.TasksGroup] = []
            if isinstance(schedule, List):
                for s in schedule:
                    for _ in range(s.count):
                        s_ = copy.deepcopy(s)
                        s_.count = 1
                        all_schedules.append(s_)
            else:
                for _ in range(schedule.count):
                    s_ = copy.deepcopy(schedule)
                    s_.count = 1
                    all_schedules.append(s_)
            assert len(all_schedules) == count
            comms = [(rq.Queue(maxsize=8), rq.Queue(maxsize=8)) for _ in all_schedules]
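            # Each worker gets a (request, reply) queue pair: the first queue
            # carries requests from the controller to the worker's server, the
            # second carries replies back (see the assignments to
            # __workers_request_comm / __workers_reply_comm below).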
            world_size = len(all_schedules)
            if any(sch.scheduling.gpu > 0 for sch in all_schedules):
                # For GPU jobs, use a customized packed scheduling method
                # that sequentially allocates nodes.
                if not all(
                    sch.scheduling.gpu == all_schedules[0].scheduling.gpu == 1
                    for sch in all_schedules
                ):
                    raise ValueError(
                        "Ray scheduler only supports resource requirements where #GPU=1 or #GPU=0."
                    )
                available_nodes = [
                    k
                    for k in available_resources
                    if re.match(r"node:(\b(?:\d{1,3}\.){3}\d{1,3}\b)", k)
                ]
                total_gpus = available_resources["GPU"]
                if total_gpus % len(available_nodes) != 0:
                    raise ValueError(
                        "Cannot schedule Ray jobs to nodes with heterogeneous numbers of GPUs."
                    )
                n_gpus_per_node = int(total_gpus // len(available_nodes))
                if total_gpus < count:
                    raise RuntimeError(
                        "The number of available GPUs is smaller than the number of scheduled GPU workers."
                    )
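
                # Packed placement sketch (IPs are hypothetical): with 2 nodes
                # of 8 GPUs each, Ray exposes node resources like
                # "node:10.0.0.1" and "node:10.0.0.2"; workers 0-7 request
                # {"node:10.0.0.1": 1/8} and workers 8-15 request
                # {"node:10.0.0.2": 1/8}, so each node is filled sequentially.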

                jobs = []
                for node_idx, i in enumerate(range(0, count, n_gpus_per_node)):
                    _schedules = all_schedules[i : i + n_gpus_per_node]
                    _comms = comms[i : i + n_gpus_per_node]
                    for _idx, (comm, sch) in enumerate(zip(_comms, _schedules)):
                        # Schedule jobs one-by-one to maintain the order on remote nodes.
                        job = ray.remote(
                            num_cpus=sch.scheduling.cpu,
                            num_gpus=sch.scheduling.gpu,
                            memory=sch.scheduling.mem * 1024**2,
                            name=f"{worker_type}/{_idx + i}",
                            resources={available_nodes[node_idx]: 1 / n_gpus_per_node},
                        )(run_ray_worker).remote(
                            worker_type,
                            _idx + i,
                            world_size,
                            self.__experiment_name,
                            self.__trial_name,
                            comm,
                        )
                        try:
                            ray.get(job, timeout=0.1)
                        except ray.exceptions.GetTimeoutError:
                            pass
                        jobs.append(job)
            else:
                # Use the default Ray scheduler, which may have some randomness.
                jobs = [
                    ray.remote(
                        num_cpus=sch.scheduling.cpu,
                        num_gpus=sch.scheduling.gpu,
                        memory=sch.scheduling.mem * 1024**2,
                        name=f"{worker_type}/{idx}",
                    )(run_ray_worker).remote(
                        worker_type,
                        idx,
                        world_size,
                        self.__experiment_name,
                        self.__trial_name,
                        comm,
                    )
                    for idx, (comm, sch) in enumerate(zip(comms, all_schedules))
                ]
            for idx, (job, c) in enumerate(zip(jobs, comms)):
                name = f"{worker_type}/{idx}"
                self.__workers_ref[name] = job
                self.__workers_request_comm[name] = c[0]
                self.__workers_reply_comm[name] = c[1]
            # Perform a poll step on remote jobs to let them raise setup errors,
            # e.g., ImportError, ModuleNotFoundError, etc.
            try:
                ray.get(jobs, timeout=1)
            except ray.exceptions.GetTimeoutError:
                pass
            logger.info(f"Launched {count} {worker_type}.")

        panel = worker_control.make_control(
            "ray",
            self.__experiment_name,
            self.__trial_name,
            request_comms=self.__workers_request_comm,
            reply_comms=self.__workers_reply_comm,
        )
        self.__base_controller = Controller(
            self.__experiment_name, self.__trial_name, panel
        )
        logger.info("All Ray workers are launched.")

    def start(self, experiment: system_api.Experiment, ignore_worker_error=False):
        scheduling: system_api.ExperimentScheduling = experiment.scheduling_setup()
        setup = experiment.initial_setup()
        if not isinstance(setup, list):
            setup = [setup]
        worker_counts = [
            (k, len(getattr(setup[0], k)), getattr(scheduling, k)) for k in WORKER_TYPES
        ]

        env_vars = constants.get_env_vars(
            REAL_MODE=os.environ.get("REAL_MODE", ""),
            CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
            REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
            REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
            REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
            REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
        )
        runtime_env = {
            "env_vars": env_vars,
            "working_dir": os.getcwd(),
        }
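        # `working_dir` makes Ray ship the current working directory to remote
        # workers so that they run against the same code tree, and `env_vars`
        # forwards the REAL_* settings above to each worker process.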
        logger.info(f"Ray workers runtime env: {runtime_env}")
        ray.init(runtime_env=runtime_env)

        logger.info("Ray initialized! Ready to run workers.")

        try:
            self._launch_workers(worker_counts)
            self.__base_controller.start(experiment, ignore_worker_error)
        except Exception as e:
            ray.shutdown()
            raise e