mirror of https://github.com/inclusionAI/AReaL
80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
# Copyright 2025 Ant Group Inc.
|
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
|
|
import importlib
|
|
import os
|
|
import traceback
|
|
from typing import Type
|
|
|
|
import realhf.api.core.system_api
|
|
import realhf.base.logging as logging
|
|
|
|
logger = logging.getLogger("system")
|
|
|
|
# NOTE: Workers are configured in the following order.
|
|
# Take special care when adding a new worker type.
|
|
WORKER_TYPES = ["model_worker", "master_worker"]
|
|
USE_V2_WORKER = os.getenv("REAL_USE_V2_WORKER", "0") == "1"
|
|
|
|
|
|
def load_worker(worker_type: str) -> Type:
|
|
assert worker_type in WORKER_TYPES, f"Invalid worker type {worker_type}"
|
|
module = importlib.import_module(worker_type_to_module(worker_type))
|
|
class_name = worker_type_to_class_name(worker_type)
|
|
return getattr(module, class_name)
|
|
|
|
|
|
def worker_type_to_module(worker_type: str):
|
|
if worker_type == "master_worker" and USE_V2_WORKER:
|
|
return "realhf.system.v2." + worker_type
|
|
return "realhf.system." + worker_type
|
|
|
|
|
|
def worker_type_to_class_name(worker_type: str):
|
|
return "".join([w.capitalize() for w in worker_type.split("_")])
|
|
|
|
|
|
def run_worker(
|
|
worker_type, experiment_name, trial_name, worker_name, worker_server_type
|
|
):
|
|
"""Run one worker
|
|
Args:
|
|
worker_type: string, one of the worker types listed above,
|
|
experiment_name: string, the experiment this worker belongs to,
|
|
trial_name: string, the specific trial this worker belongs to,
|
|
worker_name: name given to the worker, typically "<worker_type>/<worker_index>"
|
|
worker_server_type: string, either 'zmq' or 'ray'.
|
|
"""
|
|
worker_class = load_worker(worker_type)
|
|
make_server_fn = getattr(
|
|
importlib.import_module("realhf.system.worker_control"), "make_server"
|
|
)
|
|
server = make_server_fn(
|
|
type_=worker_server_type,
|
|
experiment_name=experiment_name,
|
|
trial_name=trial_name,
|
|
worker_name=worker_name,
|
|
)
|
|
worker = worker_class(server=server)
|
|
try:
|
|
worker.run()
|
|
except Exception as e:
|
|
logger.error("Worker %s failed with exception: %s", worker_name, e)
|
|
logger.error(traceback.format_exc())
|
|
raise e
|
|
|
|
|
|
def make_controller(type_, experiment_name, trial_name):
|
|
module = importlib.import_module("realhf.system.controller")
|
|
if type_ == "zmq":
|
|
control_module = importlib.import_module("realhf.system.worker_control")
|
|
panel = getattr(control_module, "make_control")(
|
|
"zmq", experiment_name, trial_name
|
|
)
|
|
return getattr(module, "Controller")(experiment_name, trial_name, panel)
|
|
elif type_ == "ray":
|
|
return getattr(module, "RayController")(experiment_name, trial_name)
|
|
else:
|
|
raise NotImplementedError(f"Unknown controller type {type_}.")
|