diff --git a/examples/run_sync_ppo.sh b/examples/run_sync_ppo.sh index d36964c..8a05e08 100644 --- a/examples/run_sync_ppo.sh +++ b/examples/run_sync_ppo.sh @@ -4,10 +4,10 @@ python3 training/main_sync_ppo.py \ allocation_mode=sglang.d4p1m1+d2p2m1 \ cluster.fileroot=/storage/testing/experiments \ actor.type._class=qwen3 \ - actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \ + actor.path=Qwen/Qwen3-1.7B \ ref.type._class=qwen3 \ - ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \ - dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \ + ref.path=Qwen/Qwen3-1.7B \ + dataset.path=hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl \ dataset.train_bs_n_seqs=32 \ group_size=8 \ ppo.gen.max_new_tokens=4096 \ diff --git a/realhf/api/cli_args.py b/realhf/api/cli_args.py index 8c4cb28..6de76c7 100644 --- a/realhf/api/cli_args.py +++ b/realhf/api/cli_args.py @@ -5,7 +5,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union from omegaconf import MISSING -from realhf.base import pkg_version +from realhf.base import logging, pkg_version + +logger = logging.getLogger("CLI args") ## Data and datasets. ## @@ -351,10 +353,6 @@ class SGLangConfig: tp_size=tp_size, # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process base_gpu_id=base_gpu_id, - file_storage_path=os.path.join( - constants.SGLANG_CACHE_PATH, - f"sglang_storage{server_index}", - ), # Data parallelism dp_size=1, # TODO: check whether we require SGLang dp load_balance_method="round_robin", @@ -870,6 +868,30 @@ def get_user_tmp(): return user_tmp +@dataclass +class NameResolveConfig: + type: str = field( + default="nfs", + metadata={ + "help": "Type of the distributed KV store for name resolving.", + "choices": ["nfs", "etcd3", "ray"], + }, + ) + nfs_record_root: str = field( + default="/tmp/areal/name_resolve", + metadata={ + "help": "Record root for NFS name resolving. Should be available in all nodes." + }, + ) + etcd3_addr: str = field( + default="localhost:2379", metadata={"help": "Address of the ETCD3 server."} + ) + ray_actor_name: str = field( + default="ray_kv_store", + metadata={"help": "Name of the distributed Ray KV store."}, + ) + + @dataclass class ClusterSpecConfig: config_path: str = field( @@ -878,6 +900,10 @@ class ClusterSpecConfig: "help": "JSON config path. If not given, use the following CLI args." }, ) + name_resolve: NameResolveConfig = field( + default_factory=NameResolveConfig, + metadata={"help": "Name resolving configuration."}, + ) cluster_name: str = field( default="local", metadata={"help": "Name of the cluster. 
Used to set specific environs."}, diff --git a/realhf/api/core/config.py b/realhf/api/core/config.py index b516d2d..df78394 100644 --- a/realhf/api/core/config.py +++ b/realhf/api/core/config.py @@ -6,7 +6,6 @@ import dataclasses import enum from typing import Any, Dict, List, Optional -import realhf.base.cluster as cluster import realhf.base.topology as topology @@ -140,7 +139,7 @@ class ModelShardID: ) def __repr__(self): - n = cluster.spec.suffix_n_digits + n = len(str(self.topo.world_size())) return f"{self.model_name}@pp{self.pp_rank:0{n}d}@tp{self.tp_rank:0{n}d}@dp{self.dp_rank:0{n}d}" def __hash__(self): diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py index e2e5951..ce6d9bf 100644 --- a/realhf/api/core/data_api.py +++ b/realhf/api/core/data_api.py @@ -40,7 +40,6 @@ from pydantic import field_validator, model_validator from realhf.api.cli_args import MicroBatchSpec from realhf.api.core import config as config_api from realhf.base import constants, datapack, logging, seeding -from realhf.base.cluster import spec as cluster_spec from realhf.utils import load_hf_or_local_file logger = logging.getLogger("api.data") @@ -754,11 +753,11 @@ def get_shuffle_indices(seed: int, size: int): def load_shuffle_split_dataset( util: DatasetUtility, - dataset_path: str, + dataset_path: Optional[str] = None, dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None, ): - dataset_path = load_hf_or_local_file(dataset_path) if dataset_path is not None: + dataset_path = load_hf_or_local_file(dataset_path) if dataset_path.endswith(".jsonl"): with open(dataset_path, "r") as f: data = [json.loads(ff) for ff in f] @@ -808,17 +807,12 @@ def make_dataset( dp_rank: int, world_size: int, tokenizer_or_tokenizer_name: Union[transformers.PreTrainedTokenizerFast, str], - experiment_name: str, - trial_name: str, - cache_root: Optional[str] = None, ) -> torch.utils.data.Dataset: if isinstance(cfg, str): cfg = config_api.DatasetAbstraction(type_=cfg) if isinstance(tokenizer_or_tokenizer_name, str): tokenizer = load_hf_tokenizer(tokenizer_or_tokenizer_name) - elif tokenizer_or_tokenizer_name is None: - raise RuntimeError("tokenizer_or_tokenizer_name cannot be None.") else: tokenizer = tokenizer_or_tokenizer_name util = DatasetUtility( @@ -827,46 +821,8 @@ def make_dataset( world_size, tokenizer, ) - - if cache_root is None: - dataset_cls = ALL_DATASET_CLASSES[cfg.type_] - return dataset_cls(util=util, **cfg.args) - - # Create and check cache path. - if not cache_root.startswith(cluster_spec.fileroot) and not cache_root.startswith( - "/home" - ): - raise ValueError( - f"Data cache path {cache_root} should be /home or under {cluster_spec.fileroot}." - ) - if "_" in experiment_name or "_" in trial_name: - raise ValueError(f"Invalid experiment/trial name.") - - output_path = os.path.join( - cache_root, - experiment_name, - trial_name, - cfg.type_, - f"seed{seed}", - f"world_size{world_size}", - f"rank{dp_rank}", - ) - os.makedirs(output_path, exist_ok=True) - - fname = "dataset.pt" - cache_found = os.path.isfile(os.path.join(output_path, fname)) - - tik = time.perf_counter() - if not cache_found: - logger.info(f"No data cache found for rank {dp_rank}. 
Create it from scratch.") - dataset = ALL_DATASET_CLASSES[cfg.type_](seed, dp_rank, world_size, **cfg.args) - torch.save(dataset, os.path.join(output_path, fname)) - else: - logger.info(f"Rank {dp_rank} find existing data cache, load it.") - dataset = torch.load(os.path.join(output_path, fname)) - logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s") - - return dataset + dataset_cls = ALL_DATASET_CLASSES[cfg.type_] + return dataset_cls(util=util, **cfg.args) def gather_stat(src: List[Dict]) -> Dict: diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index 6dda819..ea30213 100644 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -25,7 +25,6 @@ from realhf.api.core.config import ( StandaloneModelShardAbstraction, ) from realhf.base import constants, topology -from realhf.base.cluster import spec as cluster_spec class ExpStatus(Enum): @@ -49,66 +48,6 @@ class Scheduling: begin: Optional[str] = None # see "--begin" option for format deadline: Optional[str] = None # see "--deadline" option for format - @staticmethod - def master_worker_default(**kwargs): - return Scheduling( - **{ - "cpu": 16, - "mem": 20 * 1024, - "gpu": 0, - "container_image": cluster_spec.cpu_image, - **kwargs, - } - ) - - @staticmethod - def model_worker_default(**kwargs): - return Scheduling( - **{ - "cpu": 2, - "gpu": 1, - "mem": 60 * 1024, - "container_image": cluster_spec.gpu_image, - **kwargs, - } - ) - - @staticmethod - def generation_server_default(**kwargs): - return Scheduling( - **{ - "cpu": 4, - "gpu": 1, - "mem": 60 * 1024, - "container_image": cluster_spec.gpu_infer_image, - **kwargs, - } - ) - - @staticmethod - def gserver_manager_default(**kwargs): - return Scheduling( - **{ - "cpu": 4, - "gpu": 0, - "mem": 10 * 1024, - "container_image": cluster_spec.gpu_image, - **kwargs, - } - ) - - @staticmethod - def rollout_worker_default(**kwargs): - return Scheduling( - **{ - "cpu": 4, - "gpu": 0, - "mem": 20 * 1024, - "container_image": cluster_spec.gpu_image, - **kwargs, - } - ) - @dataclasses.dataclass class WorkerInformation: @@ -159,8 +98,6 @@ class ModelWorker: # dataset, for source model workers tokenizer_name_or_path: Optional[str] = None datasets: Optional[List[Union[str, DatasetAbstraction]]] = None - use_dataset_cache: bool = False - dataset_cahce_root: str = constants.DATASET_CACHE_PATH shuffle_dataset: bool = True cuda_cache_cleanliness: bool = True cuda_cache_clear_freq: int = 10 @@ -215,8 +152,6 @@ class RolloutWorker: env: EnvServiceAbstraction agent: AgentAbstraction datasets: List[Union[str, DatasetAbstraction]] - use_dataset_cache: bool = False - dataset_cahce_root: str = constants.DATASET_CACHE_PATH worker_info: WorkerInformation = None @@ -290,16 +225,9 @@ class ExperimentConfig: assert constants.trial_name() is not None assert constants.experiment_name() is not None - graph_path = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "dataflow_graph.png", - ) - os.makedirs(os.path.dirname(graph_path), exist_ok=True) # If verbose set to True here, every worker will print the graph once # due to lazy init on workers. 
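With the Scheduling.*_default factory helpers removed above, experiment configs construct Scheduling directly and take container images from the cluster config, as the updated call sites later in this patch do. A minimal sketch under those assumptions (resource numbers mirror the removed model_worker_default; exp_cfg is an assumed experiment config object, not part of this patch):

    from realhf.api.core.system_api import Scheduling

    def model_worker_scheduling(exp_cfg):
        # exp_cfg is assumed to expose .cluster (a ClusterSpecConfig).
        # Resource numbers mirror the removed Scheduling.model_worker_default.
        return Scheduling(
            cpu=2,
            gpu=1,
            mem=60 * 1024,
            container_image=exp_cfg.cluster.gpu_image,  # image now comes from the cluster config
        )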
- G = dfg.build_graph(self.model_rpcs, verbose=False, graph_path=graph_path) + G = dfg.build_graph(self.model_rpcs, verbose=False) for rpc in self.model_rpcs: rpc._G = G @@ -549,4 +477,12 @@ def register_experiment(name, cls): def make_experiment(name) -> Experiment: cls = ALL_EXPERIMENT_CLASSES[name] - return cls() + args = cls() + if args.cluster.config_path: + from realhf.base.cluster import load_spec_from_file + + load_spec_from_file(args.cluster) + from realhf.base import name_resolve + + name_resolve.reconfigure(args.cluster.name_resolve) + return args diff --git a/realhf/api/quickstart/device_mesh.py b/realhf/api/quickstart/device_mesh.py index 84528cb..49f3a44 100644 --- a/realhf/api/quickstart/device_mesh.py +++ b/realhf/api/quickstart/device_mesh.py @@ -8,9 +8,8 @@ from typing import List, Optional, Tuple, Union import numpy as np -from realhf.api.cli_args import ParallelismConfig +from realhf.api.cli_args import ClusterSpecConfig, ParallelismConfig from realhf.api.core.dfg import MFCDef -from realhf.base.cluster import spec as cluster_spec from realhf.base.slurm_utils import are_ones_contiguous, parse_nodelist @@ -76,7 +75,6 @@ class DeviceMesh: return device_mesh def __post_init__(self): - n = cluster_spec.suffix_n_digits assert self._is_valid_mapping() def __eq__(self, other: "DeviceMesh"): @@ -179,7 +177,10 @@ class DeviceMesh: def make_device_mesh_from_name( - global_mesh_name: str, name: str, n_gpus_per_node: int = 8 + cluster: ClusterSpecConfig, + global_mesh_name: str, + name: str, + n_gpus_per_node: int = 8, ): """ DeviceMesh name format: [:] @@ -191,8 +192,8 @@ def make_device_mesh_from_name( Note: cluster device mesh name must occupy entire nodes. """ - prefix = cluster_spec.node_name_prefix - node_list = parse_nodelist(global_mesh_name, prefix) + prefix = cluster.node_name_prefix + node_list = parse_nodelist(cluster, global_mesh_name, prefix) n_nodes = len(node_list) gpu_ids = None @@ -202,7 +203,7 @@ def make_device_mesh_from_name( assert all(gpu_id < n_gpus_per_node for gpu_id in gpu_ids) else: node_names = name - node_names = parse_nodelist(node_names, prefix) + node_names = parse_nodelist(cluster, node_names, prefix) mapping = np.zeros((n_nodes, n_gpus_per_node), dtype=np.int32) if gpu_ids is None: node_indices = [node_list.index(node_name) for node_name in node_names] diff --git a/realhf/api/quickstart/entrypoint.py b/realhf/api/quickstart/entrypoint.py index 389321a..97b8169 100644 --- a/realhf/api/quickstart/entrypoint.py +++ b/realhf/api/quickstart/entrypoint.py @@ -16,24 +16,22 @@ from hydra.core.config_store import ConfigStore from omegaconf import MISSING, OmegaConf import realhf.api.core.system_api as system_api -from realhf.base.constants import init_constants +from realhf.base.constants import ( + QUICKSTART_EXPR_CACHE_PATH, + get_log_path, + get_save_path, +) from realhf.base.ray_utils import check_ray_availability from realhf.base.slurm_utils import check_slurm_availability def kind_reminder(config_name, logger, args): - from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT - logger.info(f"Running {config_name} experiment.") + logger.info(f"Logs will be dumped to {get_log_path(args)}") logger.info( - f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}" - ) - logger.info( - f"Experiment configs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name, 'config.yaml')}" - ) - logger.info( - f"Model checkpoints will be saved to {os.path.join(MODEL_SAVE_ROOT, args.experiment_name, 
args.trial_name)}" + f"Experiment configs will be dumped to {os.path.join(get_log_path(args), 'config.yaml')}" ) + logger.info(f"Model checkpoints will be saved to {get_save_path(args)}") if args.mode == "slurm": slurm_available = check_slurm_availability() @@ -82,12 +80,7 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable): trial_name = args.trial_name from realhf.apps.main import main_start, main_stop - init_constants(args) - from realhf.base.constants import LOG_ROOT, QUICKSTART_EXPR_CACHE_PATH - - config_save_path = os.path.join( - LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml" - ) + config_save_path = os.path.join(get_log_path(args), "config.yaml") os.makedirs(os.path.dirname(config_save_path), exist_ok=True) with open(config_save_path, "w") as f: yaml.dump( diff --git a/realhf/apps/main.py b/realhf/apps/main.py index 146655e..e1ca087 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -94,7 +94,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): raise RuntimeError("Experiment initial setup failed.") from e evaluator = ( - AutomaticEvaluator(exp_cfg.evaluator, exp_cfg.wandb, exp_cfg.swanlab) + AutomaticEvaluator(exp_cfg, exp_cfg.evaluator, exp_cfg.wandb, exp_cfg.swanlab) if exp_cfg.auto_eval else None ) @@ -116,7 +116,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): is_recover_run = recover_count > 0 if args.recover_mode == "auto": try: - recover.discover_ckpt(args.experiment_name, args.trial_name) + recover.discover_ckpt(experiment) is_recover_run = True except recover.InValidRecoverCkpt as e: logger.warning( @@ -127,7 +127,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): is_recover_run = False if is_recover_run: recover_ckpt_path, model_ckpt_dirs, recover_info = recover.discover_ckpt( - args.experiment_name, args.trial_name + experiment ) logger.info(f"Will load recover info from {recover_ckpt_path}.") logger.info(f"Will load model checkpoints from {model_ckpt_dirs}.") @@ -138,19 +138,9 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): ) save_recover_states = args.recover_mode != "disabled" - cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "") - if not cluster_spec_path: - logger.info( - "Environment variable CLUSTER_SPEC_PATH is not set. " - "Will use the fileroot specified in CLI args. " - ) - else: - logger.warning( - "Environment variable CLUSTER_SPEC_PATH is set. " - "Will overwrite the cluster spec in CLI args. 
" - ) # set env vars BASE_ENVIRONS = constants.get_env_vars( + experiment, REAL_MODE=args.mode.upper(), REAL_RECOVER_RUN="1" if is_recover_run else "0", REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0", @@ -160,10 +150,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): # setup experiments sched = sched_client.make( - mode=scheduler_mode(args.mode), - expr_name=expr_name, - trial_name=trial_name, - schedule_strategy=args.schedule_strategy, + experiment, evaluator=evaluator, job_group_id=job_group_id, job_group_index=recover_count, @@ -302,11 +289,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): def main_stop(args): - sched = sched_client.make( - mode=scheduler_mode(args.mode), - expr_name=args.experiment_name, - trial_name=args.trial_name, - ) + experiment = config_package.make_experiment(args.experiment_name) + sched = sched_client.make(experiment) sched.find_all() sched.stop_all() diff --git a/realhf/apps/quickstart.py b/realhf/apps/quickstart.py index b2a5222..f49227e 100644 --- a/realhf/apps/quickstart.py +++ b/realhf/apps/quickstart.py @@ -16,7 +16,6 @@ from rich.panel import Panel from realhf.api.cli_args import console, highlighter, print_config_help from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES, QUICKSTART_FN -from realhf.base.cluster import spec as cluster_spec from realhf.base.importing import import_module from realhf.base.prologue import ( PROLOGUE_EXTERNAL_CONFIG_NAME, @@ -36,7 +35,6 @@ import_module( str(pathlib.Path(__file__).resolve().parent.parent / "experiments" / "async_exp"), re.compile(r".*_exp\.py$"), ) -import realhf.experiments.benchmark.profile_exp def print_help(exp_type): @@ -144,7 +142,7 @@ def prepare_hydra_config(name: str, prologue_path: str): config = OmegaConf.load(prologue_path) experiment_name = get_experiment_name(config.get("experiment_name")) trial_name = get_trial_name(config.get("trial_name")) - config_dir = f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}" + config_dir = f"{config.cluster.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}" os.makedirs(config_dir, exist_ok=True) config.pop(PROLOGUE_EXTERNAL_CONFIG_NAME, {}) diff --git a/realhf/apps/remote.py b/realhf/apps/remote.py index c73fe99..e858940 100644 --- a/realhf/apps/remote.py +++ b/realhf/apps/remote.py @@ -21,7 +21,7 @@ from omegaconf import OmegaConf multiprocessing.set_start_method("spawn", force=True) from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES -from realhf.base import gpu_utils, importing, logging +from realhf.base import gpu_utils, importing, logging, name_resolve from realhf.version import get_full_version_with_dirty_description logger = logging.getLogger("Main-Workers") @@ -61,7 +61,6 @@ def main_worker(args): import realhf.api.core.system_api as system_api experiment = system_api.make_experiment(name=args.experiment_name) - constants.init_constants(experiment) worker_index_start = args.jobstep_id * args.wprocs_per_jobstep + args.wproc_offset worker_index_end = min( @@ -166,6 +165,8 @@ def main_controller(args): constants.set_experiment_trial_names(args.experiment_name, args.trial_name) _patch_external_impl(args.experiment_name, args.trial_name) + experiment = system_api.make_experiment(args.experiment_name) + logger.debug("Running controller with args: %s", args) assert not args.experiment_name.startswith("/"), args.experiment_name try: @@ -177,10 +178,6 @@ def main_controller(args): 
experiment_name=args.experiment_name, trial_name=args.trial_name, ) - experiment = system_api.make_experiment(args.experiment_name) - - # Initialize cluster infor from ENV or CLI args. - constants.init_constants(experiment) controller.start( experiment=experiment, diff --git a/realhf/base/cluster.py b/realhf/base/cluster.py index f4d59d3..efda469 100644 --- a/realhf/base/cluster.py +++ b/realhf/base/cluster.py @@ -3,143 +3,23 @@ # Licensed under the Apache License, Version 2.0 (the "License"). import json -import os from typing import TYPE_CHECKING, Dict if TYPE_CHECKING: - from realhf.api.cli_args import BaseExperimentConfig + from realhf.api.cli_args import ClusterSpecConfig -class ClusterSpec: - def __init__(self): - # Set default values to comfort ray - from realhf.api.cli_args import BaseExperimentConfig +def load_spec_from_file(config: "ClusterSpecConfig"): + with open(config.config_path, "r") as f: + spec: Dict = json.load(f) - self.load_spec_from_args(BaseExperimentConfig()) - - self.__loaded = False - - def load_spec_from_file(self, file_path: str): - if not os.path.exists(file_path): - raise FileNotFoundError(f"Cluster spec file not found: {file_path}") - - with open(file_path, "r") as f: - spec: Dict = json.load(f) - - self.__cluster_type = spec["cluster_type"] - self.__cluster_name = spec["cluster_name"] - self.__fileroot = spec["fileroot"] - self.__gpu_type = spec.get("gpu_type", None) - self.__mount = spec.get("default_mount", None) - self.__gpu_image = spec.get("gpu_image", None) - self.__gpu_infer_image = spec.get("gpu_infer_image", self.__gpu_image) - self.__cpu_image = spec.get("cpu_image", None) - self.__node_name_prefix = spec.get("node_name_prefix", "slurmd-") - # self.__n_nodes decides number of digits in slurm hostnames - # e.g. if __n_nodes = 32, then the hostnames will be slurmd-{:02d} - # if __n_nodes = 128, then the hostnames will be slurmd-{:03d} - self.__n_nodes = int(spec.get("n_nodes", 32)) - self.__n_gpus_per_node = int(spec.get("n_gpus_per_node", 8)) - assert isinstance(self.__n_nodes, int) - - self.__loaded = True - - def load_spec_from_args(self, args: "BaseExperimentConfig"): - self.__cluster_type = args.mode - self.__cluster_name = args.cluster.cluster_name - self.__fileroot = args.cluster.fileroot - self.__gpu_type = args.cluster.gpu_type - self.__mount = args.cluster.mount - self.__gpu_image = args.cluster.gpu_image - self.__gpu_infer_image = args.cluster.gpu_infer_image - self.__cpu_image = args.cluster.cpu_image - self.__node_name_prefix = args.cluster.node_name_prefix - self.__n_nodes = args.cluster.n_nodes - self.__n_gpus_per_node = args.cluster.n_gpus_per_node - self.__loaded = True - - @property - def name(self): - assert self.__loaded - return self.__cluster_name - - @property - def gpu_type(self): - assert self.__loaded - return self.__gpu_type - - @property - def fileroot(self) -> str: - """Return the root directory of the file system in the cluster. - - When running experiments, files such as logs, checkpoints, - caches will be saved under this directory. 
- """ - assert self.__loaded - return self.__fileroot - - @fileroot.setter - def fileroot(self, root: str): - # Used for testing - self.__fileroot = root - - @property - def mount(self) -> str: - """Directories that should be mounted to container that runs - workers.""" - assert self.__loaded - return self.__mount - - @property - def gpu_image(self) -> str: - """Return the default image for containers of GPU trainer workers.""" - assert self.__loaded - return self.__gpu_image - - @property - def gpu_infer_image(self) -> str: - """Return the default image for containers of GPU inference workers.""" - assert self.__loaded - return self.__gpu_infer_image - - @property - def cpu_image(self) -> str: - """Return the default image for containers of CPU workers.""" - assert self.__loaded - return self.__cpu_image - - @property - def node_name_prefix(self) -> str: - """Return the prefix of node names in slurm format.""" - assert self.__loaded - return self.__node_name_prefix - - @property - def n_nodes(self) -> int: - return self.__n_nodes - - @property - def suffix_n_digits(self) -> int: - return len(str(self.__n_nodes)) - - @property - def n_gpus_per_node(self) -> int: - return self.__n_gpus_per_node - - @property - def cluster_type(self) -> str: - return self.__cluster_type - - -spec = ClusterSpec() - - -def init_cluster_spec(args: "BaseExperimentConfig"): - global spec - CLUSTER_SPEC_PATH = os.environ.get("CLUSTER_SPEC_PATH", "") - if args.cluster.config_path: - spec.load_spec_from_file(args.cluster.config_path) - elif CLUSTER_SPEC_PATH: - spec.load_spec_from_file(CLUSTER_SPEC_PATH) - else: - spec.load_spec_from_args(args) + config.cluster_name = spec["cluster_name"] + config.fileroot = spec["fileroot"] + config.gpu_type = spec.get("gpu_type", None) + config.mount = spec.get("default_mount", None) + config.gpu_image = spec.get("gpu_image", None) + config.gpu_infer_image = spec.get("gpu_infer_image", config.gpu_image) + config.cpu_image = spec.get("cpu_image", None) + config.node_name_prefix = spec.get("node_name_prefix", "slurmd-") + config.n_nodes = int(spec.get("n_nodes", 32)) + config.n_gpus_per_node = int(spec.get("n_gpus_per_node", 8)) diff --git a/realhf/base/constants.py b/realhf/base/constants.py index eb63956..73962e2 100644 --- a/realhf/base/constants.py +++ b/realhf/base/constants.py @@ -4,13 +4,10 @@ # log format constants import contextlib -import copy import datetime import getpass import os import pathlib -import subprocess -from collections import defaultdict from pathlib import Path from typing import * @@ -69,139 +66,115 @@ TORCH_FORCE_CPU = False # constants in experiment instance scope LOCAL_CACHE_DIR = "/tmp/realhf" +QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache") +os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True) +PORT_LOCKFILE_ROOT = os.getenv("AREAL_PORT_LOCKFILE_ROOT", "/tmp/areal/ports/") +os.makedirs(PORT_LOCKFILE_ROOT, exist_ok=True) + PYTORCH_KERNEL_CACHE_PATH = ( f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels" ) TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton" -QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache") os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True) os.makedirs(TRITON_CACHE_PATH, exist_ok=True) -os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True) - -# Global constants that should be initialized after cluster initialization. 
-MODEL_SAVE_ROOT = None -LOG_ROOT = None -RECOVER_ROOT = None -SLURM_LOCK_FILE_NAME = None -PORT_LOCK_FILE_ROOT = None -DATASET_CACHE_PATH = None -PROFILER_CACHE_PATH = None -PARAM_REALLOC_PATH = None -SGLANG_CACHE_PATH = None -TORCH_EXTENSIONS_DIR = None -BASE_ENVIRONS = None -def init_constants(args: "BaseExperimentConfig"): - from realhf.base.cluster import init_cluster_spec - from realhf.base.cluster import spec as cluster_spec +def get_cache_path(args: "BaseExperimentConfig") -> str: + path = f"{args.cluster.fileroot}/.cache/{getpass.getuser()}/{args.experiment_name}/{args.trial_name}" + os.makedirs(path, exist_ok=True) + return path - init_cluster_spec(args) - globals_dict = globals() # Get module's global variables +def get_log_root(args: "BaseExperimentConfig") -> str: + log_root = f"{args.cluster.fileroot}/logs/{getpass.getuser()}" + os.makedirs(log_root, exist_ok=True) + return log_root - kwargs = dict( - MODEL_SAVE_ROOT=f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}", - LOG_ROOT=f"{cluster_spec.fileroot}/logs/{getpass.getuser()}", - RECOVER_ROOT=f"{cluster_spec.fileroot}/recover/{getpass.getuser()}", - SLURM_LOCK_FILE_NAME=f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock", - PORT_LOCK_FILE_ROOT=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports", - DATASET_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets", - PROFILER_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler", - PARAM_REALLOC_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc", - SGLANG_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang", - TORCH_EXTENSIONS_DIR=( - f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions" - ), - ) - BASE_ENVIRONS = { - # "PYTHONPATH": "/realhf", - "REAL_IS_REMOTE": "1", - # "NCCL_P2P_DISABLE": "1", - # "NCCL_IB_DISABLE": "1", - "TRANSFORMERS_OFFLINE": "1", - "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH, - "TRITON_CACHE_DIR": TRITON_CACHE_PATH, - "TOKENIZERS_PARALLELISM": "true", - "TORCH_EXTENSIONS_DIR": kwargs["TORCH_EXTENSIONS_DIR"], - # "TORCH_DISTRIBUTED_DEBUG": "DETAIL", - # "NCCL_SOCKET_IFNAME": "ibp71s0", - # "GLOO_SOCKET_IFNAME": "ibp71s0", - # "TORCH_USE_CUDA_DSA": "1", - # "NCCL_IGNORE_DISABLED_P2P": "1", - # "CUDA_LAUNCH_BLOCKING": "1", # NOTE: CUDAGraph Capturing will not work if CUDA_LAUNCH_BLOCKING is set to 1. - # "NCCL_COMM_BLOCKING": "1", # NOTE: CUDAGraph Capturing will not work if NCCL_COMM_BLOCKING is set to 1. - # "NCCL_BLOCKING_WAIT": "1", # NOTE: CUDAGraph Capturing will not work if NCCL_BLOCKING_WAIT is set to 1. - # "TORCH_SHOW_CPP_STACKTRACES": "1", - # "RAY_DEDUP_LOGS": "0", # disable ray log deduplication - "CUDA_DEVICE_MAX_CONNECTIONS": "1", - "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)), - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", - # Whether to enable time mark to plot timelines. 
- "REAL_CUDA_TMARK": os.getenv("REAL_CUDA_TMARK", "0"), - "REAL_DUMP_TRACE": os.getenv("REAL_DUMP_TRACE", "0"), - "REAL_DUMP_MEMORY": os.getenv("REAL_DUMP_MEMORY", "0"), - "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv( - "REAL_GPU_MEMORY_KILL_THRESHOLD", "1.0" - ), - "LC_ALL": "C", - "LANG": "C", - "NCCL_DEBUG": "WARN", - } - kwargs["BASE_ENVIRONS"] = BASE_ENVIRONS - # Set PPU-specific environment variables for stable training. - if cluster_spec.name == "wa180": - logger.warning("Detected PPU. Amending PPU-related environment variables.") - PPU_ENVIRONS = { - "NCCL_DEBUG": "INFO", - "NCCL_IB_DISABLE": "1", - "NCCL_DEBUG_SUBSYS": "INIT", - "NCCL_SET_THREAD_NAME": "1", - "NCCL_IB_HCA": "", - "NCCL_SOCKET_IFNAME": "bond0", - "PCCL_STATE_MONITOR_DISABLE": "1", - } - kwargs["BASE_ENVIRONS"].update(PPU_ENVIRONS) - elif cluster_spec.name == "na132": - # Specific environment variable for h800 cluster na132 - NV_ENVIRONS = { - "NCCL_SOCKET_IFNAME": "bond0", - "NCCL_NET_PLUGIN": "", - "NCCL_IB_GID_INDEX": "3", - "NCCL_IB_TIMEOUT": "2", - "NCCL_IB_RETRY_CNT": "7", - "NCCL_IB_SL": "5", - "NCCL_IB_TC": "136", - "NCCL_IB_HCA": "mlx5_bond", - "NCCL_IB_QPS_PER_CONNECTION": "8", - "NCCL_SET_THREAD_NAME": "1", - "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH", - } - kwargs["BASE_ENVIRONS"].update(NV_ENVIRONS) - for key, value in kwargs.items(): - if key not in globals_dict: - raise ValueError(f"Invalid constant name: {key}") - if globals_dict[key] is not None and globals_dict[key] != value: - raise RuntimeError(f"Constant '{key}' already initialized!") - globals_dict[key] = value +def get_log_path(args: "BaseExperimentConfig") -> str: + log_path = f"{args.cluster.fileroot}/logs/{getpass.getuser()}/{args.experiment_name}/{args.trial_name}" + os.makedirs(log_path, exist_ok=True) + return log_path - # make directories if does not exist - os.makedirs(globals_dict["PARAM_REALLOC_PATH"], exist_ok=True) - os.makedirs(globals_dict["MODEL_SAVE_ROOT"], exist_ok=True) - os.makedirs(globals_dict["LOG_ROOT"], exist_ok=True) - os.makedirs(globals_dict["RECOVER_ROOT"], exist_ok=True) - os.makedirs(globals_dict["DATASET_CACHE_PATH"], exist_ok=True) - os.makedirs(globals_dict["PROFILER_CACHE_PATH"], exist_ok=True) - os.makedirs(globals_dict["TORCH_EXTENSIONS_DIR"], exist_ok=True) - os.makedirs(globals_dict["PORT_LOCK_FILE_ROOT"], exist_ok=True) - os.makedirs(globals_dict["SGLANG_CACHE_PATH"], exist_ok=True) + +def get_save_root(args: "BaseExperimentConfig") -> str: + path = f"{args.cluster.fileroot}/checkpoints/{getpass.getuser()}" + os.makedirs(path, exist_ok=True) + return path + + +def get_save_path(args: "BaseExperimentConfig") -> str: + path = f"{args.cluster.fileroot}/checkpoints/{getpass.getuser()}/{args.experiment_name}/{args.trial_name}" + os.makedirs(path, exist_ok=True) + return path + + +def get_param_realloc_path(args: "BaseExperimentConfig"): + path = f"{args.cluster.fileroot}/.cache/{getpass.getuser()}/param_realloc" + os.makedirs(path, exist_ok=True) + return path + + +BASE_ENVIRONS = { + # "PYTHONPATH": "/realhf", + "REAL_IS_REMOTE": "1", + # "NCCL_P2P_DISABLE": "1", + # "NCCL_IB_DISABLE": "1", + "TRANSFORMERS_OFFLINE": "1", + "TOKENIZERS_PARALLELISM": "true", + "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH, + "TRITON_CACHE_DIR": TRITON_CACHE_PATH, + # "TORCH_DISTRIBUTED_DEBUG": "DETAIL", + # "NCCL_SOCKET_IFNAME": "ibp71s0", + # "GLOO_SOCKET_IFNAME": "ibp71s0", + # "TORCH_USE_CUDA_DSA": "1", + # "NCCL_IGNORE_DISABLED_P2P": "1", + # "CUDA_LAUNCH_BLOCKING": "1", # NOTE: CUDAGraph Capturing will not work if 
CUDA_LAUNCH_BLOCKING is set to 1. + # "NCCL_COMM_BLOCKING": "1", # NOTE: CUDAGraph Capturing will not work if NCCL_COMM_BLOCKING is set to 1. + # "NCCL_BLOCKING_WAIT": "1", # NOTE: CUDAGraph Capturing will not work if NCCL_BLOCKING_WAIT is set to 1. + # "TORCH_SHOW_CPP_STACKTRACES": "1", + # "RAY_DEDUP_LOGS": "0", # disable ray log deduplication + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)), + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + # Whether to enable time mark to plot timelines. + "REAL_DUMP_TRACE": os.getenv("REAL_DUMP_TRACE", "0"), + "REAL_DUMP_MEMORY": os.getenv("REAL_DUMP_MEMORY", "0"), + "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv( + "REAL_GPU_MEMORY_KILL_THRESHOLD", "1.0" + ), + "LC_ALL": "C", + "LANG": "C", + "NCCL_DEBUG": "WARN", +} +PPU_ENVIRONS = { + "NCCL_DEBUG": "INFO", + "NCCL_IB_DISABLE": "1", + "NCCL_DEBUG_SUBSYS": "INIT", + "NCCL_SET_THREAD_NAME": "1", + "NCCL_IB_HCA": "", + "NCCL_SOCKET_IFNAME": "bond0", + "PCCL_STATE_MONITOR_DISABLE": "1", +} +NA132_ENVIRONS = { + "NCCL_SOCKET_IFNAME": "bond0", + "NCCL_NET_PLUGIN": "", + "NCCL_IB_GID_INDEX": "3", + "NCCL_IB_TIMEOUT": "2", + "NCCL_IB_RETRY_CNT": "7", + "NCCL_IB_SL": "5", + "NCCL_IB_TC": "136", + "NCCL_IB_HCA": "mlx5_bond", + "NCCL_IB_QPS_PER_CONNECTION": "8", + "NCCL_SET_THREAD_NAME": "1", + "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH", +} # _model_name will be changed in the model_scope context manager @@ -570,18 +543,21 @@ def get_repo_path() -> pathlib.Path: return pathlib.Path(__file__).resolve().parent.parent.parent -def get_env_vars(**kwargs): +def get_env_vars(exp_cfg: "BaseExperimentConfig", **kwargs): kwargs.update( - CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""), REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"), REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"), FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""), REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"), - REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"), REAL_OSS_TESTCASE_PATH=os.getenv("REAL_OSS_TESTCASE_PATH", ""), ) - return { + envvars = { **kwargs, "REAL_PACKAGE_PATH": str(get_repo_path()), **BASE_ENVIRONS, } + if exp_cfg.cluster.cluster_name == "wa180": + envvars.update(**PPU_ENVIRONS) + if exp_cfg.cluster.cluster_name == "na132": + envvars.update(**NA132_ENVIRONS) + return envvars diff --git a/realhf/base/monitor.py b/realhf/base/monitor.py index e2a6e3a..472c54d 100644 --- a/realhf/base/monitor.py +++ b/realhf/base/monitor.py @@ -358,106 +358,6 @@ def calculate_llama_gen_flops( return flops -#################### CUDA Kernel Time Marking Start #################### -# Used to create timeline plots. 
- - -class CUDATimeMarkType(enum.Enum): - forward = "forward" - backward = "backward" - optim_step = "optim_step" - comm = "comm" - misc = "misc" - mem_layout = "memory_layout" - - -@dataclasses.dataclass -class TimeMarkEntry: - name: str - model_name: "ModelName" - type_: CUDATimeMarkType - start_time: int - end_time: int - - -TIME_MARK_DB = [] - - -def cuda_tmark(name: str, type_: CUDATimeMarkType): - if os.getenv("REAL_CUDA_TMARK", None) == "1": - - def wrapper(f: Callable): - - def _wrapped_f(*args, **kwargs): - import torch - - if constants.use_cuda(): - from realhf.base.constants import _model_name - - torch.cuda.synchronize() - tik = time.time_ns() - res = f(*args, **kwargs) - torch.cuda.synchronize() - tok = time.time_ns() - global TIME_MARK_DB - TIME_MARK_DB.append( - TimeMarkEntry(name, _model_name, type_, tik, tok) - ) - else: - res = f(*args, **kwargs) - return res - - return _wrapped_f - - else: - - def wrapper(f): - return f - - return wrapper - - -@contextlib.contextmanager -def cuda_tmarked(name: str, type_: CUDATimeMarkType): - if os.getenv("REAL_CUDA_TMARK", None) == "1": - import torch - - if constants.use_cuda(): - from realhf.base.constants import _model_name - - torch.cuda.synchronize() - tik = time.time_ns() - yield - if os.getenv("REAL_CUDA_TMARK", None) == "1": - if constants.use_cuda(): - torch.cuda.synchronize() - tok = time.time_ns() - global TIME_MARK_DB - TIME_MARK_DB.append(TimeMarkEntry(name, _model_name, type_, tik, tok)) - - -def fetch_latest_tmark(): - global TIME_MARK_DB - return TIME_MARK_DB[-1] - - -def dump_tmark_db(worker_idx): - if os.getenv("REAL_CUDA_TMARK", None) != "1": - return - fn = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - f"time_marks{worker_idx}.pkl", - ) - global TIME_MARK_DB - with open(fn, "wb") as f: - pickle.dump(TIME_MARK_DB, f) - TIME_MARK_DB.clear() - - -#################### CUDA Kernel Time Marking End #################### - #################### CUDA Kernel Time Statistics Start #################### # Categorizing CUDA kernels into computation, communication, memory IO, and MISC/IDLE, # used to plot the percentage of time spent on each category and show how much we can diff --git a/realhf/base/name_resolve.py b/realhf/base/name_resolve.py index 3ecc53d..9c6ec25 100644 --- a/realhf/base/name_resolve.py +++ b/realhf/base/name_resolve.py @@ -11,7 +11,7 @@ import shutil import threading import time import uuid -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional import ray @@ -22,6 +22,9 @@ except Exception: from realhf.base import logging, security, timeutil +if TYPE_CHECKING: + from realhf.api.cli_args import NameResolveConfig + logger = logging.getLogger("name-resolve") @@ -45,12 +48,6 @@ class NameRecordRepository: except Exception as e: logger.info(f"Exception ignore when deleting NameResolveRepo {e}") - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.reset() - def add( self, name, @@ -287,24 +284,20 @@ class MemoryNameRecordRepository(NameRecordRepository): class NfsNameRecordRepository(NameRecordRepository): - RECORD_ROOT = "" - def __init__(self, **kwargs): + def __init__(self, record_root="", **kwargs): self.__to_delete = set() + self.record_root = record_root - @staticmethod - def __dir_path(name): - if not NfsNameRecordRepository.RECORD_ROOT: - from realhf.base.cluster import spec as cluster_spec + def __dir_path(self, name): + if not self.record_root: + raise RuntimeError( + 
f"The `record_root` of NfsNameRecordRepository is not properly reconfigured." + ) + return os.path.join(self.record_root, name) - RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/" - os.makedirs(RECORD_ROOT, exist_ok=True) - NfsNameRecordRepository.RECORD_ROOT = RECORD_ROOT - return os.path.join(NfsNameRecordRepository.RECORD_ROOT, name) - - @staticmethod - def __file_path(name): - return os.path.join(NfsNameRecordRepository.__dir_path(name), "ENTRY") + def __file_path(self, name): + return os.path.join(self.__dir_path(name), "ENTRY") def add( self, @@ -342,7 +335,7 @@ class NfsNameRecordRepository(NameRecordRepository): os.remove(path) while True: path = os.path.dirname(path) - if path == NfsNameRecordRepository.RECORD_ROOT: + if path == self.record_root: break if len(os.listdir(path)) > 0: break @@ -385,7 +378,7 @@ class NfsNameRecordRepository(NameRecordRepository): continue if files[0] != "ENTRY": continue - key = root.removeprefix(self.RECORD_ROOT) + key = root.removeprefix(self.record_root) key = key.removeprefix("/") rs.append(self.get(key)) except NameEntryNotFoundError: @@ -402,7 +395,7 @@ class NfsNameRecordRepository(NameRecordRepository): continue if files[0] != "ENTRY": continue - key = root.removeprefix(self.RECORD_ROOT) + key = root.removeprefix(self.record_root) key = key.removeprefix("/") rs.append(key) except NameEntryNotFoundError: @@ -571,15 +564,6 @@ class Etcd3NameRecordRepository(NameRecordRepository): TTL-based expiration, atomic operations, and key watching functionality. """ - # Default configuration - try: - host, port = os.getenv("REAL_ETCD_ADDR", "").split(":") - except ValueError: - host, port = "localhost", 2379 - ETCD_HOST = host - ETCD_PORT = int(port) - ETCD_USER = None - ETCD_PASSWORD = None KEEPALIVE_POLL_FREQUENCY = 1 @dataclasses.dataclass @@ -604,14 +588,17 @@ class Etcd3NameRecordRepository(NameRecordRepository): self._lock = threading.Lock() # Set connection parameters - self._host = host or self.ETCD_HOST - self._port = port or self.ETCD_PORT - self._user = user or self.ETCD_USER - self._password = password or self.ETCD_PASSWORD + self._host = host + self._port = port + self._user = user + self._password = password # Connect to etcd self._client = etcd3.client( - host=self._host, port=self._port, user=self._user, password=self._password + host=self._host, + port=self._port, + user=self._user, + password=self._password, ) # Keep track of entries for cleanup and keepalive @@ -835,16 +822,17 @@ class Etcd3NameRecordRepository(NameRecordRepository): def reset(self): """Delete all keys added via this repository instance.""" with self._lock: - count = 0 - for name in self._to_delete: - if name in self._entries: - try: - self._delete_locked(name) - count += 1 - except NameEntryNotFoundError: - pass - self._to_delete = set() - logger.info(f"Reset {count} saved etcd entries") + if hasattr(self, "_to_delete"): + count = 0 + for name in self._to_delete: + if name in self._entries: + try: + self._delete_locked(name) + count += 1 + except NameEntryNotFoundError: + pass + self._to_delete = set() + logger.info(f"Reset {count} saved etcd entries") def _keepalive_thread_run(self): """Background thread to keep leases alive.""" @@ -1099,12 +1087,6 @@ class RayNameResolveRepository: f"Exception ignored when deleting RayNameResolveRepository: {e}" ) - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.reset() - def add( self, name: str, @@ -1376,31 +1358,19 @@ class RayNameResolveRepository: ) -def 
make_repository(type_="nfs", **kwargs): - if type_ == "memory": - return MemoryNameRecordRepository(**kwargs) - elif type_ == "nfs": - return NfsNameRecordRepository(**kwargs) - elif type_ == "redis": - return RedisNameRecordRepository(**kwargs) - elif type_ == "etcd3": - return Etcd3NameRecordRepository(**kwargs) - elif type_ == "ray": - return RayNameResolveRepository(**kwargs) +def make_repository(args: "NameResolveConfig"): + if args.type == "nfs": + return NfsNameRecordRepository(args.nfs_record_root) + elif args.type == "etcd3": + host, port = args.etcd3_addr.split(":") + return Etcd3NameRecordRepository(host=host, port=int(port)) + elif args.type == "ray": + return RayNameResolveRepository(actor_name=args.ray_actor_name) else: - raise NotImplementedError(f"No such name resolver: {type_}") + raise NotImplementedError(f"No such name resolver: {args.type}") -# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs" -DEFAULT_REPOSITORY_TYPE = "nfs" -if etcd3 is not None and os.getenv("REAL_ETCD_ADDR", ""): - DEFAULT_REPOSITORY_TYPE = "etcd3" -if os.getenv("REAL_ETCD_ADDR", "") and etcd3 is None: - logger.warning( - f"Detected REAL_ETCD_ADDR but etcd3 client is not available. " - "Please run `pip install -r requirements.txt` if you want to use etcd name resolve." - ) -DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE) +DEFAULT_REPOSITORY = NfsNameRecordRepository() add = DEFAULT_REPOSITORY.add add_subentry = DEFAULT_REPOSITORY.add_subentry delete = DEFAULT_REPOSITORY.delete @@ -1413,11 +1383,10 @@ reset = DEFAULT_REPOSITORY.reset watch_names = DEFAULT_REPOSITORY.watch_names -def reconfigure(*args, **kwargs): - global DEFAULT_REPOSITORY, DEFAULT_REPOSITORY_TYPE +def reconfigure(config: "NameResolveConfig"): + global DEFAULT_REPOSITORY global add, add_subentry, delete, clear_subtree, get, get_subtree, find_subtree, wait, reset, watch_names - DEFAULT_REPOSITORY = make_repository(*args, **kwargs) - DEFAULT_REPOSITORY_TYPE = args[0] + DEFAULT_REPOSITORY = make_repository(config) add = DEFAULT_REPOSITORY.add add_subentry = DEFAULT_REPOSITORY.add_subentry delete = DEFAULT_REPOSITORY.delete diff --git a/realhf/base/network.py b/realhf/base/network.py index e9e338c..5fac8fd 100644 --- a/realhf/base/network.py +++ b/realhf/base/network.py @@ -23,14 +23,20 @@ def gethostip(): def find_free_port( - low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port" + low=1, + high=65536, + exclude_ports=None, + experiment_name="port", + trial_name="port", + lockfile_root=constants.PORT_LOCKFILE_ROOT, ): """Find a free port within the specified range, excluding certain ports.""" ports_name = names.used_ports(experiment_name, trial_name, gethostip()) free_port = None - lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip()) + os.makedirs(lockfile_root, exist_ok=True) + lockfile = os.path.join(lockfile_root, gethostip()) while True: with open(lockfile, "w") as fd: # This will block until lock is acquired @@ -58,7 +64,12 @@ def find_free_port( def find_multiple_free_ports( - count, low=1, high=65536, experiment_name="port", trial_name="port" + count, + low=1, + high=65536, + experiment_name="port", + trial_name="port", + lockfile_root=constants.PORT_LOCKFILE_ROOT, ): """Find multiple mutually exclusive free ports.""" free_ports = set() @@ -69,6 +80,7 @@ def find_multiple_free_ports( exclude_ports=free_ports, experiment_name=experiment_name, trial_name=trial_name, + lockfile_root=lockfile_root, ) free_ports.add(port) return 
list(free_ports) diff --git a/realhf/base/prologue.py b/realhf/base/prologue.py index 31f7a69..62efe5c 100644 --- a/realhf/base/prologue.py +++ b/realhf/base/prologue.py @@ -35,26 +35,6 @@ def global_init(): if key not in os.environ: os.environ[key] = value - # resolve config path for cluster spec. - cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "") - if cluster_spec_path == "": - if external_configs.get("cluster_config"): - fileroot = external_configs.cluster_config.get("fileroot") - if fileroot is not None and fileroot != "": - experiment_name = get_experiment_name(config.get("experiment_name")) - trial_name = get_trial_name(config.get("trial_name")) - config_dir = f"{fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}" - os.makedirs(config_dir, exist_ok=True) - cluster_spec_path = f"{config_dir}/cluster_config.json" - cluster_spec = OmegaConf.to_container(external_configs.cluster_config) - if "cluster_type" not in cluster_spec: - cluster_spec["cluster_type"] = config.mode - if "cluster_name" not in cluster_spec: - cluster_spec["cluster_name"] = f"{config.mode}_cluster" - with open(cluster_spec_path, "w") as f: - json.dump(cluster_spec, f) - os.environ["CLUSTER_SPEC_PATH"] = cluster_spec_path - def get_experiment_name(default_name: str = ""): if any("experiment_name=" in x for x in sys.argv): diff --git a/realhf/base/recover.py b/realhf/base/recover.py index 32685ea..765fcfd 100644 --- a/realhf/base/recover.py +++ b/realhf/base/recover.py @@ -40,13 +40,11 @@ class RecoverInfo: hash_vals_to_ignore: List[int] = dataclasses.field(default_factory=list) -def dump_recover_info(recover_info: RecoverInfo): +def dump_recover_info(args, recover_info: RecoverInfo): global RECOVER_INFO_PATH if RECOVER_INFO_PATH is None: RECOVER_INFO_PATH = os.path.join( - constants.RECOVER_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(args), "recover_info.pkl", ) os.makedirs(os.path.dirname(RECOVER_INFO_PATH), exist_ok=True) @@ -54,15 +52,13 @@ def dump_recover_info(recover_info: RecoverInfo): pickle.dump(recover_info, f) -def load_recover_info() -> Tuple[int, Optional[RecoverInfo]]: +def load_recover_info(args) -> Tuple[int, Optional[RecoverInfo]]: if os.environ.get("REAL_RECOVER_RUN", "0") != "1": return False, None global RECOVER_INFO_PATH if RECOVER_INFO_PATH is None: RECOVER_INFO_PATH = os.path.join( - constants.RECOVER_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(args), "recover_info.pkl", ) os.makedirs(os.path.dirname(RECOVER_INFO_PATH), exist_ok=True) @@ -81,15 +77,9 @@ class InValidRecoverCkpt(Exception): pass -def discover_ckpt( - expr_name: str, trial_name: str -) -> Tuple[str, List[str], RecoverInfo]: - recover_info_file = ( - pathlib.Path(constants.RECOVER_ROOT) - / expr_name - / trial_name - / "recover_info.pkl" - ) +def discover_ckpt(args) -> Tuple[str, List[str], RecoverInfo]: + expr_name, trial_name = args.experiment_name, args.trial_name + recover_info_file = pathlib.Path(constants.get_save_path(args)) / "recover_info.pkl" if os.path.exists(str(recover_info_file)): with open(recover_info_file, "rb") as f: info: RecoverInfo = pickle.load(f) @@ -100,9 +90,7 @@ def discover_ckpt( f"but found {info.last_step_info.epoch}" ) raise InValidRecoverCkpt(msg) - model_save_dir = ( - pathlib.Path(constants.MODEL_SAVE_ROOT) / expr_name / trial_name - ) + model_save_dir = pathlib.Path(constants.get_save_path(args)) model_ckpt_dirs = [] for role in os.listdir(model_save_dir): if "dataset_indices" in 
role: diff --git a/realhf/base/slurm_utils.py b/realhf/base/slurm_utils.py index c4a1445..ccdfa32 100644 --- a/realhf/base/slurm_utils.py +++ b/realhf/base/slurm_utils.py @@ -9,19 +9,17 @@ from typing import List import numpy as np -from realhf.base.cluster import spec as cluster_spec - def parse_node_id(node_name: str, prefix: str) -> int: return int(node_name.split(prefix)[-1]) -def parse_nodelist(nodelist: str, prefix: str) -> List[str]: +def parse_nodelist(cluster_config, nodelist: str, prefix: str) -> List[str]: if not nodelist.startswith(prefix): raise ValueError( f"Node list `{nodelist}` does not start with hostname prefix `{prefix}`." ) - n = cluster_spec.suffix_n_digits + n = len(str(cluster_config.n_nodes)) nodelist = nodelist.replace(prefix, "") if "[" not in nodelist: return [prefix + nodelist] @@ -38,30 +36,6 @@ def parse_nodelist(nodelist: str, prefix: str) -> List[str]: return [f"{prefix}{node_id:0{n}d}" for node_id in node_ids] -def nodelist_from_nodes(nodes: List[str], prefix: str) -> str: - n = cluster_spec.suffix_n_digits - node_ids = sorted([parse_node_id(node, prefix) for node in nodes]) - assert len(node_ids) > 0 - if len(node_ids) == 1: - return f"{prefix}{node_ids[0]:02d}" - else: - node_reprs = [] - start, end = node_ids[0], node_ids[0] - for i in range(len(node_ids)): - node_id = node_ids[i] - next_node_id = node_ids[i + 1] if i + 1 < len(node_ids) else -1 - if node_id + 1 == next_node_id: - end = next_node_id - else: - if start == end: - node_reprs.append(f"{start:0{n}d}") - else: - node_reprs.append(f"{start:0{n}d}-{end:0{n}d}") - start = next_node_id - end = next_node_id - return f"{prefix}[{','.join(node_reprs)}]" - - def are_ones_contiguous(binary_array: np.ndarray): one_indices = np.where(binary_array == 1)[0] if len(one_indices) == 0: diff --git a/realhf/base/testing.py b/realhf/base/testing.py index b5fcbbc..329f221 100644 --- a/realhf/base/testing.py +++ b/realhf/base/testing.py @@ -19,6 +19,7 @@ import torch import torch.distributed as dist import torch.utils.data +from realhf.api.cli_args import BaseExperimentConfig, NameResolveConfig from realhf.api.core.data_api import SequenceSample from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology from realhf.base.topology import ( @@ -92,6 +93,14 @@ class StandaloneTestingProcess(mp.Process): if constants.use_cuda(): torch.cuda.set_device(0) + from realhf.api.cli_args import NameResolveConfig + + name_resolve.reconfigure( + NameResolveConfig( + "nfs", f"/tmp/areal/testing/{self.expr_name}/{self.trial_name}/" + ) + ) + self.barrier.wait() if self.setup_dist_torch: @@ -103,7 +112,11 @@ class StandaloneTestingProcess(mp.Process): if self.dist_backend is None: self.dist_backend = "gloo" if not constants.use_cuda() else "nccl" setup_global_comm( - self.expr_name, self.trial_name, self.rank, backend=self.dist_backend + BaseExperimentConfig(), + self.expr_name, + self.trial_name, + self.rank, + backend=self.dist_backend, ) # misc setup @@ -149,6 +162,11 @@ class LocalMultiProcessTest: if torch.cuda.is_available(): os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" os.environ["GPU_DEVICES_ISOLATED"] = str(1) + expr_name = expr_name if expr_name is not None else _DEFAULT_EXPR_NAME + trial_name = trial_name if trial_name is not None else _DEFAULT_TRIAL_NAME + name_resolve.reconfigure( + NameResolveConfig("nfs", f"/tmp/areal/testing/{expr_name}/{trial_name}/") + ) clear_name_resolve(expr_name, trial_name) self.timeout_secs = timeout_secs self.processes = [] diff --git a/realhf/base/topology.py 
b/realhf/base/topology.py index 4faeaae..e0ab4a9 100644 --- a/realhf/base/topology.py +++ b/realhf/base/topology.py @@ -21,7 +21,6 @@ from typing import Dict, List, NamedTuple, Optional, Tuple import torch.distributed as dist import realhf.base.logging as logging -from realhf.base.cluster import spec as cluster_spec from realhf.base.constants import NCCL_DEFAULT_TIMEOUT logger = logging.getLogger("Topology") @@ -185,7 +184,7 @@ class ProcessTopology: omit_axes = frozenset(omit_axes) axes = [a for a in self.get_axis_names() if a not in omit_axes] names = [] - n = cluster_spec.suffix_n_digits + n = len(str(len(self.mapping))) for ax in axes: ax_rank = getattr(self.get_coord(rank=rank), ax) names.append(f"{ax}{inner_sep}{ax_rank:0{n}d}") diff --git a/realhf/experiments/async_exp/async_ppo_math_exp.py b/realhf/experiments/async_exp/async_ppo_math_exp.py index bb14c67..8c64e2d 100644 --- a/realhf/experiments/async_exp/async_ppo_math_exp.py +++ b/realhf/experiments/async_exp/async_ppo_math_exp.py @@ -2,6 +2,7 @@ import copy import dataclasses +import os from typing import Any, Dict, List, Tuple import realhf.base.logging as logging @@ -13,6 +14,7 @@ from realhf.api.core.config import ( ) from realhf.api.core.model_api import GenerationHyperparameters from realhf.api.quickstart.entrypoint import register_quickstart_exp +from realhf.base import constants from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig from realhf.experiments.common.ppo_math_exp import PPOMATHConfig from realhf.experiments.common.utils import asdict @@ -29,6 +31,9 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig): "math-single-step", args=dict( gconfig=self.generation_config, + answer_save_path=os.path.join( + constants.get_log_path(self), "generated" + ), tokenizer_path=self.actor.path, success_rate_lb=self.success_rate_lb, success_rate_ub=self.success_rate_ub, @@ -40,7 +45,10 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig): @property def env(self) -> EnvServiceAbstraction: return EnvServiceAbstraction( - "math-code-single-step", args=dict(dataset_path=self.dataset.path) + "math-code-single-step", + args=dict( + dataset_path=self.dataset.path, + ), ) @property diff --git a/realhf/experiments/async_exp/async_rl_exp.py b/realhf/experiments/async_exp/async_rl_exp.py index b1a2252..3c23f13 100755 --- a/realhf/experiments/async_exp/async_rl_exp.py +++ b/realhf/experiments/async_exp/async_rl_exp.py @@ -38,8 +38,6 @@ from realhf.api.core.system_api import ( TasksGroup, ) from realhf.api.quickstart.device_mesh import RPCAllocation -from realhf.base.cluster import spec as cluster_spec -from realhf.experiments.common.check import check_valid_sglang, check_valid_vllm from realhf.experiments.common.common import CommonExperimentConfig from realhf.experiments.common.utils import ( AllocationMode, @@ -92,49 +90,57 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): return ExperimentScheduling( master_worker=TasksGroup( count=1, - scheduling=Scheduling.master_worker_default( + scheduling=Scheduling( cpu=self.cpus_per_master_worker, + gpu=0, mem=self.mem_per_master_worker, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.cpu_image, ), ), model_worker=TasksGroup( count=train_world_size, - scheduling=Scheduling.model_worker_default( + scheduling=Scheduling( cpu=self.cpus_per_model_worker, gpu=1, mem=self.mem_per_model_worker, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.gpu_image, ), ), 
generation_server=TasksGroup( count=gen_world_size // gen_tp_size, - scheduling=Scheduling.generation_server_default( + scheduling=Scheduling( cpu=self.cpus_per_generation_server, gpu=gen_tp_size, mem=self.mem_per_generation_server, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.gpu_infer_image, ), ), gserver_manager=TasksGroup( count=1, - scheduling=Scheduling.gserver_manager_default( + scheduling=Scheduling( cpu=self.cpus_per_gserver_manager, + gpu=0, mem=self.mem_per_gserver_manager, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.cpu_image, ), ), rollout_worker=TasksGroup( count=self.n_rollout_workers or train_world_size, - scheduling=Scheduling.rollout_worker_default( + scheduling=Scheduling( cpu=self.cpus_per_rollout_worker, + gpu=0, mem=self.mem_per_rollout_worker, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.cpu_image, ), ), ) @@ -162,7 +168,11 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): # NOTE: here we use puller stream to wrap the original dataset datasets=[ DatasetAbstraction( - "puller_stream", args=dict(dataset_cfgs=self.datasets) + "puller_stream", + args=dict( + dataset_cfgs=self.datasets, + args=self, + ), ) ], torch_cache_mysophobia=self.torch_cache_mysophobia, diff --git a/realhf/experiments/benchmark/profile_exp.py b/realhf/experiments/benchmark/profile_exp.py deleted file mode 100644 index bd15b43..0000000 --- a/realhf/experiments/benchmark/profile_exp.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - -import copy -import dataclasses -import itertools -import json -import os -from typing import * - -from omegaconf import OmegaConf - -from realhf.api.cli_args import ( - MFCConfig, - ModelTrainEvalConfig, - ParallelismConfig, - PromptOnlyDatasetConfig, -) -from realhf.api.core.config import ( - DatasetAbstraction, - ModelInterfaceAbstraction, - ModelInterfaceType, -) -from realhf.api.core.dfg import MFCDef -from realhf.api.core.system_api import ExperimentConfig -from realhf.api.quickstart.entrypoint import register_quickstart_exp -from realhf.base import constants, logging -from realhf.base.topology import decompose_to_three_factors -from realhf.experiments.common.common import CommonExperimentConfig - -logger = logging.getLogger("Profiling Experiment", "system") - - -def default_parallel_config(n_gpus: int) -> List[Dict[str, Any]]: - factors = decompose_to_three_factors(n_gpus) - x = [ - { - "data_parallel_size": dp, - "tensor_parallel_size": tp, - "pipeline_parallel_size": pp, - "use_sequence_parallel": tp > 1, - } - for dp, tp, pp in factors - ] - x += [ - { - "data_parallel_size": dp, - "tensor_parallel_size": tp, - "pipeline_parallel_size": pp, - "use_sequence_parallel": False, - } - for dp, tp, pp in factors - if tp > 1 - ] - return x - - -def dataclass_from_dict(klass, d): - try: - fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)} - return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d}) - except: - return d # Not a dataclass field - - -@dataclasses.dataclass -class ProfileConfig(CommonExperimentConfig): - """The experiment configuration for profiling layers and interfaces. - - The `initial_setup` method in this experiment will return a list of - experiment configurations, which will be run sequentially. 
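Note on the name-resolve rework shown earlier in this patch: the repository backend is now chosen from an explicit NameResolveConfig rather than the REAL_ETCD_ADDR environment variable. A minimal sketch of the new flow, using only the fields added in cli_args.py (the path and address shown are the declared defaults, not requirements):

    from realhf.api.cli_args import NameResolveConfig
    from realhf.base import name_resolve

    # NFS-backed KV store rooted at a path visible to all nodes.
    name_resolve.reconfigure(NameResolveConfig("nfs", nfs_record_root="/tmp/areal/name_resolve"))

    # etcd3-backed store; make_repository() splits host:port internally.
    name_resolve.reconfigure(NameResolveConfig(type="etcd3", etcd3_addr="localhost:2379"))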
- All configurations share the same experiment name, trial name, - and the scheduling configuration. They can have different models, - datasets, or parallel strategies, as long as they always occupy - a fixed number of GPUs. - - It's important to note that, if any error occurs during the execution, - the experiment will terminate immediately. In particular, the OOM error - should not appear because the profiling setup usually uses a small model. - """ - - interfaces_jsonl: str = "" - allocations_jsonl: Optional[str] = None - handle_names: Optional[List[str]] = None - n_mbs: Optional[List[int]] = None - batch_sizes: Optional[List[int]] = None - models_jsonl: str = "" - datasets_jsonl: str = "" - - def __post_init__(self): - # Check that handle_name belones to ["train_step", "generate", "inference"] - self.handle_names = list(set(self.handle_names)) - if any( - k not in ["train_step", "generate", "inference"] for k in self.handle_names - ): - raise NotImplementedError(f"Unknown handle_name: {self.handle_name}") - - # Check the configuration of interfaces - if not os.path.exists(self.interfaces_jsonl): - raise FileNotFoundError( - f"File not found: {self.interfaces_jsonl}. " - "It should be a JSONL file specifying the arguments " - "for the interface implementation." - ) - with open(self.interfaces_jsonl, "r") as f: - self.interface_kwargs = [json.loads(l) for l in f.readlines()] - - # Check the configuration of parallel strategies. - if self.allocations_jsonl is None: - self.parallel_kwargs = default_parallel_config( - self.n_nodes * self.n_gpus_per_node - ) - else: - assert self.allocations_jsonl.endswith(".jsonl") - assert os.path.exists(self.allocations_jsonl) - with open(self.allocations_jsonl, "r") as f: - self.parallel_kwargs = [json.loads(l) for l in f.readlines()] - for pcfg in self.parallel_kwargs: - assert isinstance(pcfg, dict), type(pcfg) - assert all( - k - in [ - "data_parallel_size", - "tensor_parallel_size", - "pipeline_parallel_size", - "use_sequence_parallel", - ] - for k in pcfg.keys() - ), pcfg.keys() - assert (self.n_nodes * self.n_gpus_per_node) == ( - pcfg.get("data_parallel_size", 1) - * pcfg.get("tensor_parallel_size", 1) - * pcfg.get("pipeline_parallel_size", 1) - ) - - if self.n_mbs is None: - self.n_mbs = [1] - else: - self.n_mbs = OmegaConf.to_container(self.n_mbs) - assert isinstance(self.n_mbs, list), type(self.n_mbs) - assert all(isinstance(x, int) for x in self.n_mbs) - - assert self.batch_sizes is not None - - assert os.path.exists(self.models_jsonl) - with open(self.models_jsonl, "r") as f: - self.model_kwargs = [json.loads(l) for l in f.readlines()] - - assert os.path.exists(self.datasets_jsonl) - with open(self.datasets_jsonl, "r") as f: - self.dataset_kwargs = [json.loads(l) for l in f.readlines()] - assert all(x["type_"] == "prompt" for x in self.dataset_kwargs) - - @property - def allocations(self): - return dict(default=self._tmp_allocation) - - @property - def models(self): - return dict(default=self._tmp_model) - - @property - def tokenizer_name_or_path(self): - return self._tmp_model.path - - @property - def max_prompt_len(self): - return self._tmp_dataset.args["max_length"] - - @property - def datasets(self): - return [self._tmp_dataset] - - @property - def rpcs(self): - return dict(default=self._tmp_rpc) - - def initial_setup(self) -> List[ExperimentConfig]: - self.allocation_mode = "manual" - setups = [] - setup_log_path = os.path.join( - constants.LOG_ROOT, - self.experiment_name, - self.trial_name, - "setups.jsonl", - ) - logger.info( - 
f"Experiment setup configurations of the profiling experiment " - f"will be saved to: {setup_log_path}" - ) - with open(setup_log_path, "w") as f: - # batch size in the most outer loop to delay the possible OOM error - for ( - bs, - pcfg, - n_mbs, - model_cfg, - dataset_cfg, - handle_name, - interface_cfg, - ) in itertools.product( - self.batch_sizes, - self.parallel_kwargs, - self.n_mbs, - self.model_kwargs, - self.dataset_kwargs, - self.handle_names, - self.interface_kwargs, - ): - if handle_name == "generate" and pcfg["use_sequence_parallel"]: - continue - - kwargs_stat = dict( - parallel=pcfg, - n_mbs=n_mbs, - model=model_cfg, - dataset=dataset_cfg, - interface=interface_cfg, - bs=bs, - ) - f.write(json.dumps(kwargs_stat) + "\n") - - # Create tmp object for constructing experiment setups - self._tmp_allocation = MFCConfig( - parallel=ParallelismConfig(**pcfg), n_mbs=n_mbs - ) - self._tmp_model = dataclass_from_dict(ModelTrainEvalConfig, model_cfg) - self._tmp_dataset = DatasetAbstraction(**dataset_cfg) - if handle_name == "train_step": - interface_type = ModelInterfaceType.TRAIN_STEP - elif handle_name == "inference": - interface_type = ModelInterfaceType.INFERENCE - elif handle_name == "generate": - interface_type = ModelInterfaceType.GENERATE - else: - raise NotImplementedError( - f"Unknown which handle to run in the interface: {self.handle_name}" - ) - self._tmp_rpc = MFCDef( - n_seqs=bs, - name="default", - n_mbs=n_mbs, - interface_type=interface_type, - interface_impl=ModelInterfaceAbstraction(**interface_cfg), - model_name="default", - input_keys=["packed_prompts"], - log_return_value=False, - balanced_dp=True, - ) - - setup = copy.deepcopy(super().initial_setup()) - for m in setup.model_worker: - m.profile_mode = True - setups.append(setup) - return setups - - -register_quickstart_exp("profile", ProfileConfig) diff --git a/realhf/experiments/common/common.py b/realhf/experiments/common/common.py index ef47549..370f16f 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -42,7 +42,6 @@ from realhf.api.quickstart.device_mesh import ( RPCAllocation, make_device_mesh_from_name, ) -from realhf.base.cluster import spec as cluster_spec from realhf.experiments.common.check import ( check_valid_model_and_path, check_valid_optimizer, @@ -164,21 +163,24 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): return ExperimentScheduling( master_worker=TasksGroup( count=1, - scheduling=Scheduling.master_worker_default( + scheduling=Scheduling( cpu=self.cpus_per_master_worker, + gpu=0, mem=self.mem_per_master_worker, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.cpu_image, ), ), model_worker=TasksGroup( count=self.n_nodes * self.n_gpus_per_node, - scheduling=Scheduling.model_worker_default( + scheduling=Scheduling( cpu=self.cpus_per_model_worker, gpu=1, mem=self.mem_per_model_worker, nodelist=self.nodelist, exclude=self.exclude, + container_image=self.cluster.gpu_image, ), ), ) @@ -326,6 +328,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): rpc=rpc, device_mesh=( make_device_mesh_from_name( + self.cluster, self.nodelist, self.allocations[rpc_type].device_mesh, self.global_device_mesh.n_gpus_per_node, @@ -579,7 +582,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): ) if self.n_gpus_per_node > self.cluster.n_gpus_per_node: raise ValueError( - f"Number of 7used GPUs per node {self.n_gpus_per_node} should not be larger than the cluster limit {self.cluster.n_gpus_per_node}" + 
f"Number of used GPUs per node {self.n_gpus_per_node} should not be larger than the cluster limit {self.cluster.n_gpus_per_node}" ) if self.n_nodes > 1 and self.n_gpus_per_node != self.cluster.n_gpus_per_node: raise ValueError( diff --git a/realhf/experiments/common/math_code_eval_exp.py b/realhf/experiments/common/math_code_eval_exp.py index 1f2640d..bccab30 100644 --- a/realhf/experiments/common/math_code_eval_exp.py +++ b/realhf/experiments/common/math_code_eval_exp.py @@ -2,9 +2,9 @@ # Copyright 2024 Wei Fu & Zhiyu Mei # Licensed under the Apache License, Version 2.0 (the "License"). import dataclasses +import os from typing import Dict -import realhf.base.logging as logging from realhf.api.cli_args import MathCodeEvalOptions, ModelTrainEvalConfig from realhf.api.core.config import ( DatasetAbstraction, @@ -13,6 +13,7 @@ from realhf.api.core.config import ( ) from realhf.api.core.dfg import MFCDef from realhf.api.quickstart.entrypoint import register_quickstart_exp +from realhf.base import constants, logging from realhf.experiments.common.common import CommonExperimentConfig from realhf.experiments.common.utils import asdict @@ -55,6 +56,9 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig): dataset_path=self.dataset.path, tokenizer_path=self.actor.path, rw_type=self.rw_type, + answer_save_path=os.path.join( + constants.get_log_path(self), "generated" + ), check_xml_format=self.check_xml_format, group_size=self.group_size, check_verifier_status=self.check_verifier_status, diff --git a/realhf/experiments/common/ppo_math_exp.py b/realhf/experiments/common/ppo_math_exp.py index 9b8810e..48fff91 100644 --- a/realhf/experiments/common/ppo_math_exp.py +++ b/realhf/experiments/common/ppo_math_exp.py @@ -6,7 +6,6 @@ import dataclasses import os from typing import Dict -import realhf.base.logging as logging from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions from realhf.api.core.config import ( DatasetAbstraction, @@ -16,6 +15,7 @@ from realhf.api.core.config import ( from realhf.api.core.dfg import MFCDef, ParamReallocHook from realhf.api.core.system_api import ExperimentConfig from realhf.api.quickstart.entrypoint import register_quickstart_exp +from realhf.base import constants, logging from realhf.experiments.common.common import CommonExperimentConfig from realhf.experiments.common.utils import ( asdict, @@ -132,6 +132,9 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): check_xml_format=self.check_xml_format, group_size=self.group_size, check_verifier_status=self.check_verifier_status, + answer_save_path=os.path.join( + constants.get_log_path(self), "generated" + ), ), ) diff --git a/realhf/impl/agent/math_multi_turn_agent.py b/realhf/impl/agent/math_multi_turn_agent.py index c0f3775..12f9278 100644 --- a/realhf/impl/agent/math_multi_turn_agent.py +++ b/realhf/impl/agent/math_multi_turn_agent.py @@ -32,6 +32,7 @@ class MathMultiTurnAgent(Agent): self, gconfig, tokenizer_path, + answer_save_path, reward_scaling=1.0, reward_bias=0.0, turn_level_discount: float = 1.0, @@ -39,6 +40,7 @@ class MathMultiTurnAgent(Agent): ): self.gconfig = gconfig.new(n=1) self.tokenizer = load_hf_tokenizer(tokenizer_path) + self.answer_save_path = answer_save_path self.reward_scaling = reward_scaling self.reward_bias = reward_bias @@ -245,10 +247,7 @@ class MathMultiTurnAgent(Agent): for group_idx in range(group_size): # NOTE: we can ensure that only one process is logging this query id gen_file_path = os.path.join( - constants.LOG_ROOT, - 
constants.experiment_name(), - constants.trial_name(), - "generated", + self.answer_save_path, str(version_starts[group_idx]), f"{qid}.txt", ) @@ -271,10 +270,7 @@ class MathMultiTurnAgent(Agent): _f.write(info + "\n") train_pass_monitor_file_path = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "training_monitor", + self.answer_save_path, str(version_starts[group_idx]), f"{qid}.jsonl", ) diff --git a/realhf/impl/agent/math_single_step_agent.py b/realhf/impl/agent/math_single_step_agent.py index 5276d8c..2a3a981 100644 --- a/realhf/impl/agent/math_single_step_agent.py +++ b/realhf/impl/agent/math_single_step_agent.py @@ -25,6 +25,7 @@ class MathSingleStepAgent(Agent): self, gconfig, tokenizer_path, + answer_save_path, success_rate_lb, success_rate_ub, reward_scaling=1.0, @@ -32,6 +33,7 @@ class MathSingleStepAgent(Agent): ): self.gconfig = gconfig self.tokenizer = load_hf_tokenizer(tokenizer_path) + self.answer_save_path = answer_save_path self.success_rate_lb = success_rate_lb self.success_rate_ub = success_rate_ub @@ -198,10 +200,7 @@ class MathSingleStepAgent(Agent): for group_idx in range(group_size): # NOTE: we can ensure that only one process is logging this query id gen_file_path = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "generated", + self.answer_save_path, str(version_starts[group_idx]), f"{qid}.txt", ) @@ -224,10 +223,7 @@ class MathSingleStepAgent(Agent): _f.write(info + "\n") train_pass_monitor_file_path = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "training_monitor", + self.answer_save_path, str(version_starts[group_idx]), f"{qid}.jsonl", ) diff --git a/realhf/impl/model/backend/megatron.py b/realhf/impl/model/backend/megatron.py index eb1a4d8..648da94 100644 --- a/realhf/impl/model/backend/megatron.py +++ b/realhf/impl/model/backend/megatron.py @@ -21,7 +21,6 @@ from realhf.api.core import model_api from realhf.api.core.data_api import SequenceSample from realhf.base import constants, logging, pkg_version from realhf.base.datapack import flat2d -from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked from realhf.impl.model.backend.inference import PipelinableInferenceEngine from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet from realhf.impl.model.modules.mlp import get_activation_fn @@ -304,7 +303,6 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet): self._no_sync_context.__exit__(None, None, None) self._no_sync_context = None - @cuda_tmarked("bwd", CUDATimeMarkType.backward) def _exec_backward_pass( self, module: ReaLModel, @@ -342,7 +340,6 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet): # self.engine.ddp.start_grad_sync() self.engine.finalize_grads() - @cuda_tmarked("opt", CUDATimeMarkType.optim_step) def _exec_optimizer_step( self, module: ReaLModel, @@ -489,8 +486,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine): loss_scale *= constants.data_parallel_world_size() loss_scale *= self.engine.optim.get_loss_scale().item() loss *= loss_scale - with cuda_tmarked("bwd", CUDATimeMarkType.backward): - loss.backward() + loss.backward() self.engine.finalize_grads() return self._step(version_steps) @@ -530,7 +526,6 @@ class ReaLMegatronEngine(model_api.PipelinableEngine): ) # wrapper for profiler - @cuda_tmarked("opt", CUDATimeMarkType.optim_step) def _step(self, version_steps): # omit the number of zeros in grads update_successful, grad_norm, _ = 
self.engine.optim.step() diff --git a/realhf/impl/model/backend/sglang.py b/realhf/impl/model/backend/sglang.py index 8347e17..55ff9cb 100644 --- a/realhf/impl/model/backend/sglang.py +++ b/realhf/impl/model/backend/sglang.py @@ -32,7 +32,6 @@ from realhf.api.core.model_api import ( register_backend, ) from realhf.base import ( - cluster, constants, datapack, gpu_utils, @@ -431,9 +430,9 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig): def _initialize(self, model: Model, spec: FinetuneSpec) -> Model: if constants.pipe_parallel_world_size() != 1: raise RuntimeError("SGLang does not support pipe parallel size > 1.") - if constants.tensor_parallel_world_size() > cluster.spec.n_gpus_per_node: + if constants.tensor_parallel_world_size() > torch.cuda.device_count(): raise RuntimeError( - "AReaL's SGLang integration does not support model parallel size > n_gpus_per_node." + "AReaL's SGLang integration does not support model parallel size > torch.cuda.device_count()." ) additional_args = dataclasses.asdict(self) @@ -453,6 +452,9 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig): high=60000, experiment_name=constants.experiment_name(), trial_name=constants.trial_name(), + lockfile_root=os.path.join( + constants.get_cache_path(self.args), "ports" + ), ), group=constants.data_parallel_group(), ) @@ -475,10 +477,6 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig): tp_size=constants.tensor_parallel_world_size(), # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]), - file_storage_path=os.path.join( - constants.SGLANG_CACHE_PATH, - f"sglang_storage{constants.data_parallel_rank()}", - ), # Data parallelism dp_size=1, # TODO: check whether we require SGLang dp load_balance_method="round_robin", diff --git a/realhf/impl/model/comm/global_comm.py b/realhf/impl/model/comm/global_comm.py index f0820cd..4b60452 100644 --- a/realhf/impl/model/comm/global_comm.py +++ b/realhf/impl/model/comm/global_comm.py @@ -46,6 +46,7 @@ def filter_match_mwids( def setup_global_comm( + args, expr_name: str, trial_name: str, worker_index: int, @@ -87,10 +88,12 @@ def setup_global_comm( ) if constants.use_cuda(): - assert len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1, os.environ[ - "CUDA_VISIBLE_DEVICES" - ] - local_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"]) + if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1: + local_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"]) + else: + local_gpu_id = int( + os.environ["CUDA_VISIBLE_DEVICES"].split(",")[worker_index] + ) else: local_gpu_id = global_rank @@ -100,7 +103,11 @@ def setup_global_comm( if worker_index == 0: host_ip = socket.gethostbyname(socket.gethostname()) - port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name) + port = network.find_free_port( + experiment_name=expr_name, + trial_name=trial_name, + lockfile_root=os.path.join(constants.get_cache_path(args), "ports"), + ) pg_init_addr = f"tcp://{host_ip}:{port}" name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300) else: diff --git a/realhf/impl/model/interface/math_rw_interface.py b/realhf/impl/model/interface/math_rw_interface.py index 601a73e..266c37b 100644 --- a/realhf/impl/model/interface/math_rw_interface.py +++ b/realhf/impl/model/interface/math_rw_interface.py @@ -181,6 +181,7 @@ def retokenize_and_verify( class MultiTaskRewardInterface(model_api.ModelInterface): dataset_path: str = "" tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B" + 
answer_save_path: str = "." output_scaling: float = 1.0 output_bias: float = 0.0 rw_type: str = "sparse" @@ -363,10 +364,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface): ): tik = time.perf_counter() gen_file_path = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "generated", + self.answer_save_path, task_type, f"v{model.version.global_step}r{dist.get_rank()}.txt", ) @@ -386,10 +384,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface): _f.write(info + "\n") gen_file_path = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "generated_jsonl", + self.answer_save_path, task_type, f"v{model.version.global_step}r{dist.get_rank()}.jsonl", ) diff --git a/realhf/impl/model/nn/real_llm_api.py b/realhf/impl/model/nn/real_llm_api.py index bfc70ef..0c88fb1 100644 --- a/realhf/impl/model/nn/real_llm_api.py +++ b/realhf/impl/model/nn/real_llm_api.py @@ -17,7 +17,6 @@ import transformers from realhf.api.core import model_api from realhf.api.core.config import ModelName from realhf.base import constants, logging, topology -from realhf.base.monitor import CUDATimeMarkType, cuda_tmark, cuda_tmarked from realhf.impl.model.comm.global_comm import NCCLProcessGroupInfo from realhf.impl.model.comm.param_realloc import ( ReparallelizeReceiverStep, @@ -470,13 +469,11 @@ class ReaLModel(nn.Module): pp_input_buf[:batch_length] = x.pp_input x.pp_input = pp_input_buf - tmark_type = CUDATimeMarkType.forward - with cuda_tmarked("fwd", tmark_type): - # Main forward calls. - if not self._offloaded: - x, ys = self.__forward(x, ys) - else: - x, ys = self.__overlapped_load_forward(x, ys) + # Main forward calls. + if not self._offloaded: + x, ys = self.__forward(x, ys) + else: + x, ys = self.__overlapped_load_forward(x, ys) # Resume from padding. 
if ( @@ -644,7 +641,6 @@ class ReaLModel(nn.Module): self._reparallelize_targets[(from_model_name, to_model_name)] = rtgt # FIXME: we can get topo given model name from constants - @cuda_tmark("param_realloc", CUDATimeMarkType.mem_layout) def build_reparallelized_layers_async( self, from_model_name: ModelName, diff --git a/realhf/impl/model/utils/functional.py b/realhf/impl/model/utils/functional.py index acdeb9c..e9d0b6a 100644 --- a/realhf/impl/model/utils/functional.py +++ b/realhf/impl/model/utils/functional.py @@ -166,7 +166,7 @@ def build_leave_one_indices( ) -def _gather_logprobs( +def gather_logprobs( logits: torch.Tensor, labels: torch.Tensor, ): @@ -186,24 +186,6 @@ def _gather_logprobs( return log_probs_labels -_gather_logprobs_compiled = None - - -def gather_logprobs( - logits: torch.Tensor, - labels: torch.Tensor, -): - from realhf.base import cluster - - if cluster.spec.name == "wa180": - # torch.compile doesn't work on PPU - return _gather_logprobs(logits, labels) - global _gather_logprobs_compiled - if _gather_logprobs_compiled is None: - _gather_logprobs_compiled = torch.compile(_gather_logprobs) - return _gather_logprobs_compiled(logits, labels) - - def gather_packed_shifted_log_probs( logits: torch.FloatTensor, cu_seqlens: torch.Tensor, diff --git a/realhf/scheduler/client.py b/realhf/scheduler/client.py index cf6c3ec..fb2f20e 100644 --- a/realhf/scheduler/client.py +++ b/realhf/scheduler/client.py @@ -5,9 +5,10 @@ import dataclasses import enum import subprocess -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional -from realhf.base.cluster import spec as cluster_spec +if TYPE_CHECKING: + from realhf.api.cli_args import BaseExperimentConfig class JobState(enum.Enum): @@ -50,10 +51,11 @@ class JobInfo: class SchedulerClient: - def __init__(self, expr_name, trial_name): - self.expr_name = expr_name - self.trial_name = trial_name - self.run_name = f"{expr_name}_{trial_name}" + def __init__(self, args: "BaseExperimentConfig"): + self.args = args + self.expr_name = args.experiment_name + self.trial_name = args.trial_name + self.run_name = f"{self.expr_name}_{self.trial_name}" def submit(self, worker_type, cmd, **kwargs): """Submits a job to the scheduler. 
Raises exception if the job is @@ -120,16 +122,10 @@ class SchedulerClient: raise NotImplementedError() -def get_python3_path(): - if cluster_spec.cluster_type == "ray": - return subprocess.check_output(["which", "python3"]).decode("utf-8").strip() - return "python3" - - def remote_worker_cmd(expr_name, trial_name, debug, worker_type): # requires information in scheduler package return ( - f"{get_python3_path()} {'' if debug else '-O'} -m realhf.apps.remote worker -w {worker_type} " + f"python3 {'' if debug else '-O'} -m realhf.apps.remote worker -w {worker_type} " f"-e {expr_name} -f {trial_name} -i {{jobstep_id}} -g {{n_jobsteps}} -r {{worker_submission_index}} " f"-p {{wprocs_per_jobstep}} -j {{wprocs_in_job}} -o {{wproc_offset}}" ) @@ -137,7 +133,7 @@ def remote_worker_cmd(expr_name, trial_name, debug, worker_type): def setup_cmd(expr_name, trial_name, debug): bash_cmd = ( # f"pip3 install -e $REAL_PACKAGE_PATH --no-build-isolation && " - f"{get_python3_path()} {'' if debug else '-O'} -m realhf.apps.remote " + f"python3 {'' if debug else '-O'} -m realhf.apps.remote " f"reset_name_resolve -e {expr_name} -f {trial_name}" ) # return f"bash -c \"{bash_cmd}\"" @@ -146,7 +142,7 @@ def setup_cmd(expr_name, trial_name, debug): def control_cmd(expr_name, trial_name, debug, ignore_worker_error, controller_type): bash_cmd = ( # f"pip3 install -e $REAL_PACKAGE_PATH --no-build-isolation && " - f"{get_python3_path()} {'' if debug else '-O'} -m realhf.apps.remote controller " + f"python3 {'' if debug else '-O'} -m realhf.apps.remote controller " f"-e {expr_name} -f {trial_name} " f"--{'ignore_worker_error' if ignore_worker_error else 'raise_worker_error'} " f"--type {controller_type}" @@ -155,25 +151,23 @@ def control_cmd(expr_name, trial_name, debug, ignore_worker_error, controller_ty return bash_cmd -def make(mode, expr_name, trial_name, **kwargs) -> SchedulerClient: - if mode == "slurm": +def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient: + if args.mode == "slurm": from realhf.scheduler.slurm.client import SlurmSchedulerClient - schedule_strategy = kwargs.get("schedule_strategy", "empty_first") - evaluator = kwargs.get("evaluator", None) job_group_id = kwargs.get("job_group_id", None) job_group_index = kwargs.get("job_group_index", None) + evaluator = kwargs.get("evaluator", None) return SlurmSchedulerClient( - expr_name, - trial_name, - schedule_strategy, + args, + args.schedule_strategy, evaluator, job_group_id, job_group_index, ) - elif mode == "local": + elif args.mode == "local": from realhf.scheduler.local.client import LocalSchedulerClient - return LocalSchedulerClient(expr_name, trial_name) + return LocalSchedulerClient(args) else: raise NotImplementedError(f"Scheduler {mode} not found") diff --git a/realhf/scheduler/evaluator.py b/realhf/scheduler/evaluator.py index 7c5fef0..597a75b 100644 --- a/realhf/scheduler/evaluator.py +++ b/realhf/scheduler/evaluator.py @@ -6,13 +6,14 @@ import pathlib import re import subprocess import time -from typing import Dict, Optional +from typing import Any, Dict, Optional import swanlab import wandb import realhf.api.core.system_api as config_pkg -from realhf.base import cluster, constants, logging +from realhf.api.cli_args import BaseExperimentConfig +from realhf.base import constants, logging logger = logging.getLogger("AutomaticEvaluator", "colored") @@ -27,6 +28,7 @@ class EvaluationStepStatus(enum.Enum): @dataclasses.dataclass class EvaluationStep: + args: BaseExperimentConfig global_step: int status: EvaluationStepStatus 
start_time: Optional[float] = None @@ -34,7 +36,7 @@ class EvaluationStep: process: Optional[subprocess.Popen] = None @staticmethod - def from_ckpt_dir(ckpt_dir): + def from_ckpt_dir(args, ckpt_dir): # NOTE: ckpt_dir should be absolute path if pathlib.Path(ckpt_dir).is_symlink(): return None @@ -44,13 +46,14 @@ class EvaluationStep: return None _, _, global_step = map(int, match.groups()) return EvaluationStep( + args=args, global_step=global_step, status=EvaluationStepStatus.PENDING, ckpt_dir=ckpt_dir, ) @staticmethod - def from_output_dir(output_dir): + def from_output_dir(args, output_dir): # NOTE: output_dir should be absolute path # Should only be called in recover. _dir = os.path.basename(output_dir) @@ -59,15 +62,13 @@ class EvaluationStep: return None global_step = int(match.groups()[0]) return EvaluationStep( - global_step=global_step, status=EvaluationStepStatus.LOGGED + args=args, global_step=global_step, status=EvaluationStepStatus.LOGGED ) @property def output_dir(self): return os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_log_path(self.args), "eval_output", f"globalstep{self.global_step}", ) @@ -77,7 +78,7 @@ class EvaluationStep: cmd = ( f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G " f"singularity exec --no-home --nv --pid --writable-tmpfs --bind /storage:/storage " - f"{config.eval_job_image or cluster.spec.gpu_image} " + f"{config.eval_job_image or self.args.cluster.gpu_image} " f"bash ./evaluation/sh/install_deps_and_eval.sh {self.ckpt_dir} {self.output_dir} " f"{config.data_names} {config.max_gen_tokens} {config.prompt_type}" ) @@ -86,7 +87,7 @@ class EvaluationStep: def submit(self, config: config_pkg.AutomaticEvaluator): os.makedirs(self.output_dir, exist_ok=True) log_file = open(os.path.join(self.output_dir, "output.log"), "w") - if cluster.spec.cluster_type == "slurm": + if self.args.mode == "slurm": cmd = self.slurm_eval_cmd(config) else: raise NotImplementedError( @@ -155,10 +156,12 @@ class AutomaticEvaluator: def __init__( self, + args: BaseExperimentConfig, config: config_pkg.AutomaticEvaluator, wandb_config: config_pkg.WandBConfig, swanlab_config: config_pkg.SwanlabConfig, ): + self.args = args self.__eval_steps: Dict[int, EvaluationStep] = {} self.__max_concurrent_jobs = config.max_concurrent_jobs self.__wandb_config = wandb_config @@ -174,15 +177,13 @@ class AutomaticEvaluator: # Resubmiting or waiting for these jobs will probably result in # unexpected behaviors. 
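In this and the following hunks, hard-coded joins on `constants.LOG_ROOT` and `constants.MODEL_SAVE_ROOT` are replaced by helpers that take the experiment config. Their real implementations live in `realhf/base/constants.py` and are not part of this patch; the sketch below is only an assumed re-implementation of the pattern (the `logs` and `checkpoints` subdirectory names are guesses), showing that every path is now derived from `args.cluster.fileroot` instead of module-level globals.

import os

def get_log_path(args) -> str:
    # Assumed shape: <cluster fileroot>/logs/<experiment>/<trial>
    return os.path.join(args.cluster.fileroot, "logs", args.experiment_name, args.trial_name)

def get_save_path(args) -> str:
    # Assumed shape: <cluster fileroot>/checkpoints/<experiment>/<trial>
    return os.path.join(args.cluster.fileroot, "checkpoints", args.experiment_name, args.trial_name)
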
output_parent = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_log_path(args), "eval_output", ) if os.path.exists(output_parent): for output_dir in os.listdir(output_parent): output_dir = os.path.join(output_parent, output_dir) - eval_step = EvaluationStep.from_output_dir(output_dir) + eval_step = EvaluationStep.from_output_dir(self.args, output_dir) if eval_step: self.__eval_steps[eval_step.global_step] = eval_step @@ -197,12 +198,13 @@ class AutomaticEvaluator: ) if self.__config.initial_checkpoint_path and 0 not in self.__eval_steps: self.__eval_steps[0] = EvaluationStep( + args=self.args, global_step=0, status=EvaluationStepStatus.PENDING, ckpt_dir=self.__config.initial_checkpoint_path, ) - if not cluster.spec.cluster_type == "slurm": + if not self.args.mode == "slurm": raise NotImplementedError( "Currently only support automatic evaluation for slurm" ) @@ -224,9 +226,7 @@ class AutomaticEvaluator: notes=self.__wandb_config.notes, tags=self.__wandb_config.tags, config=self.__wandb_config.config, - dir=os.path.join( - constants.LOG_ROOT, constants.experiment_name(), constants.trial_name() - ), + dir=constants.get_log_path(self.args), force=True, id=f"{constants.experiment_name()}_{constants.trial_name()}_eval", resume="allow", @@ -270,15 +270,13 @@ class AutomaticEvaluator: def step(self): # Check whether a new evaluation step should be created ckpt_parent = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(self.args), "actor", ) if os.path.exists(ckpt_parent): for ckpt_dir in os.listdir(ckpt_parent): ckpt_dir = os.path.join(ckpt_parent, ckpt_dir) - eval_step = EvaluationStep.from_ckpt_dir(ckpt_dir) + eval_step = EvaluationStep.from_ckpt_dir(self.args, ckpt_dir) if eval_step is None: continue if eval_step.global_step in self.__eval_steps: diff --git a/realhf/scheduler/local/client.py b/realhf/scheduler/local/client.py index 57a2576..ffc461d 100644 --- a/realhf/scheduler/local/client.py +++ b/realhf/scheduler/local/client.py @@ -14,7 +14,7 @@ import psutil import realhf.base.logging as logging from realhf.base import gpu_utils -from realhf.base.constants import LOG_ROOT +from realhf.base.constants import get_log_path from realhf.scheduler.client import ( JobException, JobInfo, @@ -75,14 +75,12 @@ class LocalSchedulerClient(SchedulerClient): def log_path_of(self, worker_type) -> str: return os.path.join( - LOG_ROOT, - self.expr_name, - self.trial_name, + get_log_path(self.args), f"{worker_type}-0", ) - def __init__(self, expr_name, trial_name): - super().__init__(expr_name, trial_name) + def __init__(self, args): + super().__init__(args) self._jobs: Dict[str, subprocess.Popen] = {} self._running_worker_types = [] diff --git a/realhf/scheduler/slurm/client.py b/realhf/scheduler/slurm/client.py index aa3d1a3..772b601 100644 --- a/realhf/scheduler/slurm/client.py +++ b/realhf/scheduler/slurm/client.py @@ -15,9 +15,7 @@ from typing import Dict, List, Literal, Optional, Tuple import colorama import realhf.base.logging as logging -from realhf.base.cluster import spec as cluster_spec -from realhf.base.constants import LOG_ROOT -from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME +from realhf.base.constants import get_log_path from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient from realhf.scheduler.evaluator import AutomaticEvaluator from realhf.scheduler.slurm.utils import ( @@ -82,14 +80,13 @@ class 
SlurmSchedulerClient(SchedulerClient): def __init__( self, - expr_name: str, - trial_name: str, + args, schedule_strategy: str, evaluator: Optional[AutomaticEvaluator], job_group_id: str, job_group_index: int, ): - super().__init__(expr_name, trial_name) + super().__init__(args) self.__schedule_strategy = schedule_strategy @@ -124,11 +121,31 @@ class SlurmSchedulerClient(SchedulerClient): deadline: str = None, time_limit: str = None, ): - container_image = container_image or cluster_spec.cpu_image - container_mounts = container_mounts or cluster_spec.mount + container_image = container_image or self.args.cluster.cpu_image + container_mounts = container_mounts or self.args.cluster.mount # record launch information, do not submit to slurm until `wait()` is called # NOTE: fractional GPU requirement will be resolved automatically in `__post_init__` of SlurnLaunchInfo + log_path = os.path.join( + get_log_path(self.args), + f"{worker_type}-{self.__submission_counter[worker_type]}.log", + ) + multiprog_path = os.path.join( + get_log_path(self.args), + "slurm", + "multiprog", + f"{worker_type}-{self.__submission_counter[worker_type]}.multiprog", + ) + os.makedirs(os.path.dirname(multiprog_path), exist_ok=True) + hostfile_path = os.path.join( + get_log_path(self.args), + "slurm", + "hostfile", + f"{worker_type}-{self.__submission_counter[worker_type]}.hostfile", + ) + os.makedirs(os.path.dirname(hostfile_path), exist_ok=True) + launch_info = SlurmLaunchInfo( + args=self.args, worker_type=worker_type, wprocs_in_job=count, resource_requirement=SlurmResource(mem=mem, cpu=cpu, gpu=gpu), @@ -149,6 +166,9 @@ class SlurmSchedulerClient(SchedulerClient): time_limit=time_limit, job_group_id=self.__job_group_id, job_group_index=self.__job_group_index, + log_path=log_path, + multiprog_path=multiprog_path, + hostfile_path=hostfile_path, ) if ( @@ -177,9 +197,9 @@ class SlurmSchedulerClient(SchedulerClient): wproc_offset=self.__wprocs_counter[worker_type], ) wrap_cmd = "singularity exec " - if cluster_spec.name == "na132": + if self.args.cluster.cluster_name == "na132": wrap_cmd += "--pid " - if cluster_spec.gpu_type == "tesla": + if self.args.cluster.gpu_type == "tesla": wrap_cmd += "--nv " wrap_cmd += "--no-home --writable-tmpfs " if len(launch_info.env_vars) > 0: @@ -199,7 +219,9 @@ class SlurmSchedulerClient(SchedulerClient): start_time = time.monotonic() while True: try: - fp = open(LOCK_FILE_NAME, "w") + fp = open( + f"{self.args.cluster.fileroot}/logs/slurm_scheduler.lock", "w" + ) fcntl.flock(fp, fcntl.LOCK_EX) infos = list(self.__pending_jobs.values()) infos = allocate_resources(infos, strategy=self.__schedule_strategy) @@ -297,9 +319,7 @@ class SlurmSchedulerClient(SchedulerClient): threads = [] stop_events = [] - merged_log_path = os.path.join( - LOG_ROOT, self.expr_name, self.trial_name, "main.log" - ) + merged_log_path = os.path.join(get_log_path(self.args), "main.log") for job_name, launch_info in self.__committed_jobs.items(): stop_event = threading.Event() diff --git a/realhf/scheduler/slurm/utils.py b/realhf/scheduler/slurm/utils.py index ebf2e4e..ecf9696 100644 --- a/realhf/scheduler/slurm/utils.py +++ b/realhf/scheduler/slurm/utils.py @@ -20,10 +20,10 @@ from typing import Callable, Dict, List, Literal, Optional, Union import pandas as pd -import realhf.base.cluster as cluster import realhf.base.logging as logging import realhf.version as version -from realhf.base.constants import LOG_ROOT +from realhf.api.cli_args import BaseExperimentConfig +from realhf.base.constants import get_log_path 
from realhf.scheduler.client import JobException, JobInfo, JobState logger = logging.getLogger("scheduler.slurm.utils") @@ -190,6 +190,7 @@ class SlurmLaunchInfo: multiprog_content (str, optional): The content of the multiprog file. """ + args: BaseExperimentConfig run_name: str exper_name: str trial_name: str @@ -199,6 +200,10 @@ class SlurmLaunchInfo: job_group_id: str job_group_index: str + log_path: str + multiprog_path: str + hostfile_path: str + resource_requirement: SlurmResource cmd: str container_image: str @@ -252,41 +257,6 @@ class SlurmLaunchInfo: else: return None - @property - def log_path(self) -> str: - return os.path.join( - LOG_ROOT, - self.exper_name, - self.trial_name, - f"{self.worker_type}-{self.worker_submission_idx}.log", - ) - - @property - def multiprog_path(self) -> str: - path = os.path.join( - LOG_ROOT, - self.exper_name, - self.trial_name, - "slurm", - "multiprog", - f"{self.worker_type}-{self.worker_submission_idx}.multiprog", - ) - os.makedirs(os.path.dirname(path), exist_ok=True) - return path - - @property - def hostfile_path(self) -> str: - path = os.path.join( - LOG_ROOT, - self.exper_name, - self.trial_name, - "slurm", - "hostfile", - f"{self.worker_type}-{self.worker_submission_idx}.hostfile", - ) - os.makedirs(os.path.dirname(path), exist_ok=True) - return path - def show_log(self): try: terminal_columns = os.get_terminal_size().columns @@ -364,14 +334,14 @@ class SlurmLaunchInfo: # head gres_line = "" if gpu >= 1: - assert (gpu * ntasks) % cluster.spec.n_gpus_per_node == 0 + assert (gpu * ntasks) % self.args.cluster.n_gpus_per_node == 0 # In current slurm cluster setup, we can only use "--gres" to # allocate PPUs per node. There are no options to allocate customized # gres per tasks. - if cluster.spec.gpu_type == "ppu": - gres_line = f"--gres=ppu:{cluster.spec.n_gpus_per_node}" + if self.args.cluster.gpu_type == "ppu": + gres_line = f"--gres=ppu:{self.args.cluster.n_gpus_per_node}" else: - gres_line = f"--gres=gpu:{cluster.spec.n_gpus_per_node}" + gres_line = f"--gres=gpu:{self.args.cluster.n_gpus_per_node}" srun_env = os.environ.copy() job_metadata = { @@ -391,7 +361,7 @@ class SlurmLaunchInfo: f"#SBATCH --output={self.log_path}", "#SBATCH --open-mode=append", f"#SBATCH --ntasks={ntasks}", - f"#SBATCH {gres_line}", + f"#SBATCH {gres_line}" if gpu >= 1 else "", f"#SBATCH --cpus-per-task={cpu}", f"#SBATCH --mem-per-cpu={mem // max(1, cpu)}M", "#SBATCH --distribution=arbitrary" if self.hostfile else "", @@ -722,6 +692,9 @@ def allocate_resources( infos, key=lambda x: x.n_jobsteps * x.resource_requirement, reverse=True ) prioritized_hosts = set() + if len(infos) == 0: + return infos + cluster_config = infos[0].args.cluster for info_idx, info in enumerate(infos): valid_hostnames = available_hostnames( nodelist=info.nodelist, @@ -764,12 +737,12 @@ def allocate_resources( gpu_per_task = info.resource_requirement.gpu if gpu_per_task > 0: assert ( - task_left * gpu_per_task % cluster.spec.n_gpus_per_node == 0 + task_left * gpu_per_task % cluster_config.n_gpus_per_node == 0 ), (task_left, gpu_per_task) assert ( - cluster.spec.n_gpus_per_node % gpu_per_task == 0 + cluster_config.n_gpus_per_node % gpu_per_task == 0 ), gpu_per_task - batched_ntasks = int(cluster.spec.n_gpus_per_node // gpu_per_task) + batched_ntasks = int(cluster_config.n_gpus_per_node // gpu_per_task) batched_requirement = batched_ntasks * info.resource_requirement try: resource = resource - batched_requirement @@ -787,10 +760,10 @@ def allocate_resources( allocated[hostname] = tmp - task_left 
all_resources[hostname] = resource if task_left > 0: - if cluster.spec.gpu_type == "ppu" and info.resource_requirement.gpu > 0: + if cluster_config.gpu_type == "ppu" and info.resource_requirement.gpu > 0: logger.warning( "For PPU resources, we can only allocate tasks in the " - f"granularity of nodes ({cluster.spec.n_gpus_per_node} PPUs)" + f"granularity of nodes ({cluster_config.n_gpus_per_node} PPUs)" ) logger.warning( f'Unable to allocate {info.n_jobsteps} Jobs with name "{info.slurm_name}". ' diff --git a/realhf/system/__init__.py b/realhf/system/__init__.py index 809ed4c..91952ec 100644 --- a/realhf/system/__init__.py +++ b/realhf/system/__init__.py @@ -61,7 +61,7 @@ def run_worker( ) worker = worker_class(server=server) try: - if worker_type in ["rollout_worker", "master_worker", "gserver_manager"]: + if worker_type in ["rollout_worker", "master_worker"]: asyncio.run(worker.run_async()) else: worker.run() diff --git a/realhf/system/controller.py b/realhf/system/controller.py index 678fa9e..0e8b982 100644 --- a/realhf/system/controller.py +++ b/realhf/system/controller.py @@ -27,7 +27,6 @@ from omegaconf import OmegaConf import realhf.api.core.system_api as system_api from realhf.base import constants, gpu_utils, logging, name_resolve, names, pkg_version -from realhf.base.cluster import spec as cluster_spec from realhf.system import WORKER_TYPES, load_worker, worker_base, worker_control from realhf.system.worker_base import WorkerServerStatus as Wss @@ -247,9 +246,7 @@ class Controller: # If a log exists, find the last failed setup and run it. start_idx = 0 - prev_logfile = os.path.join( - constants.LOG_ROOT, self.experiment_name, self.trial_name, "ctl-0" - ) + prev_logfile = os.path.join(constants.get_log_path(experiment), "ctl-0") if os.path.exists(prev_logfile): with open(prev_logfile, "r") as f: for l in f.readlines(): @@ -670,6 +667,7 @@ class RayController: ] env_vars = constants.get_env_vars( + experiment, REAL_MODE=os.environ.get("REAL_MODE", ""), REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""), REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""), diff --git a/realhf/system/function_executor.py b/realhf/system/function_executor.py index 8add93c..2f82583 100644 --- a/realhf/system/function_executor.py +++ b/realhf/system/function_executor.py @@ -24,6 +24,7 @@ blogger = logging.getLogger("benchmark") class FunctionExecutor: def __init__( self, + args, rpcs: List[MFCDef], msid2mwid: Dict[ModelShardID, int], stream: NameResolvingRequestClient, @@ -35,6 +36,8 @@ class FunctionExecutor: shuffle_dataset: bool, ): + self.args = args + self.func_calls: Dict[str, ModelFunctionCall] = {} self.ctrl = ctrl @@ -42,7 +45,9 @@ class FunctionExecutor: self.msid2mwid = msid2mwid self.storage_tracker = GlobalStorageTracker(self.n_model_workers) - self.redistrib_planner = RedistribPlanner(self.storage_tracker) + self.redistrib_planner = RedistribPlanner( + self.args.cluster, self.storage_tracker + ) self.rpcs = rpcs self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0] @@ -51,6 +56,7 @@ class FunctionExecutor: # Create model function calls. 
for rpc in self.rpcs: func_call = ModelFunctionCall( + args=self.args, rpc=rpc, src_rpc=self.src_rpc, stream=stream, diff --git a/realhf/system/generation_server.py b/realhf/system/generation_server.py index 9bf39dc..227554d 100644 --- a/realhf/system/generation_server.py +++ b/realhf/system/generation_server.py @@ -20,7 +20,6 @@ from realhf.base import ( pkg_version, seeding, ) -from realhf.base.cluster import spec as cluster_spec from realhf.system.worker_base import PollResult, Worker logger = logging.getLogger(__name__) @@ -139,7 +138,7 @@ class GenerationServer(Worker): map(str, range(gpu_utils.gpu_count())) ) else: - servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size + servers_per_node = self.args.cluster.n_gpus_per_node // self.config.tp_size idx_on_this_node = self.worker_index % servers_per_node self.base_gpu_id = idx_on_this_node * self.config.tp_size @@ -159,7 +158,7 @@ class GenerationServer(Worker): # NOTE: Ports returned by `find_multiple_free_ports` are unique, # but SGLang servers still encounter conflicts. # Use a clearance period to hack over this issue. - servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size + servers_per_node = self.args.cluster.n_gpus_per_node // self.config.tp_size idx_on_this_node = self.worker_index % servers_per_node time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node) @@ -169,6 +168,7 @@ class GenerationServer(Worker): high=60000, experiment_name=self.experiment_name, trial_name=self.trial_name, + lockfile_root=os.path.join(constants.get_cache_path(self.args), "ports"), ) server_port = ports[0] nccl_port = ports[1] diff --git a/realhf/system/gserver_manager.py b/realhf/system/gserver_manager.py index aabdefc..e746fe9 100644 --- a/realhf/system/gserver_manager.py +++ b/realhf/system/gserver_manager.py @@ -29,7 +29,7 @@ class AllocateRolloutInput: qid: str -class GserverManager(AsyncWorker): +class GserverManager(Worker): """This worker has the following functionalities: 1. As a router, it schedules generation requests and returns the best server urls to clients for submitting generation requests. 
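The `GserverManager` docstring that opens the next hunk describes its router role; later in the same file's hunks, the scheduling policy simply routes each new rollout to the generation server with the lowest tracked token usage. An isolated sketch of that policy follows; the URLs and dictionary are hypothetical, and the helper mirrors `min(self.server_urls, key=lambda k: self._server_token_usage[k])` from the diff.

def choose_server(server_urls: list, token_usage: dict) -> int:
    # Pick the server with the smallest estimated token load.
    url = min(server_urls, key=lambda u: token_usage[u])
    return server_urls.index(url)

urls = ["http://gen-0:30000", "http://gen-1:30000"]
assert choose_server(urls, {urls[0]: 8192, urls[1]: 2048}) == 1  # lighter server wins
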
@@ -71,7 +71,7 @@ class GserverManager(AsyncWorker): self.server_urls = [] # recover info - self.__recover_run, self.__recover_info = recover.load_recover_info() + self.__recover_run, self.__recover_info = recover.load_recover_info(self.args) if self.__recover_run: # update weights will be automatically triggered upon the first schedule_request # self._last_param_realloc_step will also be updated @@ -110,11 +110,7 @@ class GserverManager(AsyncWorker): epoch = self.__recover_info.last_step_info.epoch + 1 epochstep = self.__recover_info.last_step_info.epoch_step + 1 globalstep = self.__recover_info.last_step_info.global_step + 1 - save_root = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), - ) + save_root = constants.get_save_path(self.args) role_path = os.path.join(save_root, role) if not os.path.exists(role_path): raise RuntimeError( @@ -150,9 +146,7 @@ class GserverManager(AsyncWorker): self._loaded_recover_weights = True else: realloc_dir = os.path.join( - constants.PARAM_REALLOC_PATH, - constants.experiment_name(), - constants.trial_name(), + constants.get_param_realloc_path(self.args), self.model_name.role, str(realloc_version), ) @@ -213,7 +207,7 @@ class GserverManager(AsyncWorker): url = min(self.server_urls, key=lambda k: self._server_token_usage[k]) return self.server_urls.index(url) - async def _poll_async(self): + def _poll(self): if not self.thread: # Find addresses of generation servers self.server_urls = self._discover_servers(self.config.n_servers) @@ -292,9 +286,7 @@ class GserverManager(AsyncWorker): # clear old weights realloc_root = os.path.join( - constants.PARAM_REALLOC_PATH, - constants.experiment_name(), - constants.trial_name(), + constants.get_param_realloc_path(self.args), self.model_name.role, ) if os.path.exists(realloc_root): @@ -483,6 +475,7 @@ class GserverManager(AsyncWorker): port = network.find_free_port( experiment_name=self.experiment_name, trial_name=self.trial_name, + lockfile_root=os.path.join(constants.get_cache_path(self.args), "ports"), ) self.manager_addr = f"{network.gethostip()}:{port}" diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index ef4d1ed..cabed4c 100644 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -97,15 +97,8 @@ class MasterWorker(worker_base.AsyncWorker): freq_sec=config.exp_ctrl.eval_freq_secs, ) - self.MODEL_SAVE_ROOT = os.path.join( - constants.MODEL_SAVE_ROOT, - config.worker_info.experiment_name, - config.worker_info.trial_name, - ) - os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True) - self.__initialized = False - self.__recover_run, self.__recover_info = recover.load_recover_info() + self.__recover_run, self.__recover_info = recover.load_recover_info(self.args) if self.__recover_info is not None: logger.info( f"Loaded recover info: recover_start={self.__recover_info.recover_start}, " @@ -305,9 +298,7 @@ class MasterWorker(worker_base.AsyncWorker): notes=self.wandb_config.notes, tags=self.wandb_config.tags, config=self.wandb_config.config, - dir=os.path.join( - constants.LOG_ROOT, constants.experiment_name(), constants.trial_name() - ), + dir=constants.get_log_path(self.args), force=True, id=f"{constants.experiment_name()}_{constants.trial_name()}_train", resume="allow", @@ -355,6 +346,7 @@ class MasterWorker(worker_base.AsyncWorker): # Create coroutines for model RPCs. 
logger.debug(f"Creating asyncio coroutines...") self.func_executor = FunctionExecutor( + args=self.args, rpcs=self.__model_rpcs, msid2mwid=self.config.msid2mwid, stream=self.__stream, @@ -599,7 +591,7 @@ class MasterWorker(worker_base.AsyncWorker): hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch, ) - recover.dump_recover_info(recover_info) + recover.dump_recover_info(self.args, recover_info) logger.info("Dumped recover info to file.") logger.info(f"Will recover from: {recover_info.recover_start}") logger.info( diff --git a/realhf/system/model_function_call.py b/realhf/system/model_function_call.py index 8bda855..87fb6e4 100644 --- a/realhf/system/model_function_call.py +++ b/realhf/system/model_function_call.py @@ -56,6 +56,7 @@ class RPCCorountineControl: class ModelFunctionCall: def __init__( self, + args, rpc: dfg.MFCDef, src_rpc: dfg.MFCDef, stream: request_reply_stream.NameResolvingRequestClient, @@ -68,6 +69,8 @@ class ModelFunctionCall: summary_writer: SummaryWriter | None, ): + self.args = args + self.rpc = rpc self.src_rpc = src_rpc self.stream = stream @@ -82,12 +85,6 @@ class ModelFunctionCall: for msid, mwid in msid2mwid.items(): self.mwid2msids[mwid].append(msid) - self.model_save_root = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), - ) - self.rpc_ctrl = ctrl self.buffers = buffers self.redistrib_planner = redistrib_planner @@ -221,7 +218,7 @@ class ModelFunctionCall: for p in payloads.values(): p.post_hooks.append("save") save_dir = os.path.join( - self.model_save_root, + constants.get_log_path(self.args), rpc.model_name.role, f"epoch{ctrl.step_info.epoch + 1}" f"epochstep{ctrl.step_info.epoch_step + 1}" diff --git a/realhf/system/model_worker.py b/realhf/system/model_worker.py index e363aa6..9a116af 100644 --- a/realhf/system/model_worker.py +++ b/realhf/system/model_worker.py @@ -43,12 +43,6 @@ from realhf.base import ( timeutil, topology, ) -from realhf.base.monitor import ( - CUDATimeMarkType, - cuda_tmark, - cuda_tmarked, - dump_tmark_db, -) from realhf.impl.model.nn.real_llm_api import ReaLModel from realhf.impl.model.utils import cuda_graph from realhf.system import request_reply_stream, worker_base @@ -136,7 +130,7 @@ class ModelWorker(worker_base.Worker): r = self.config.worker_info # recover info - self.__recover_run, self.__recover_info = recover.load_recover_info() + self.__recover_run, self.__recover_info = recover.load_recover_info(self.args) # Whether to enable profiling is controlled by the following environment variables. 
self.__enable_profiler = os.getenv("REAL_DUMP_TRACE", "0") == "1" @@ -174,11 +168,7 @@ class ModelWorker(worker_base.Worker): epoch = self.__recover_info.last_step_info.epoch + 1 epochstep = self.__recover_info.last_step_info.epoch_step + 1 globalstep = self.__recover_info.last_step_info.global_step + 1 - save_root = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), - ) + save_root = constants.get_save_path(self.args) if epoch > 0: role_path = os.path.join(save_root, role) if os.path.exists(role_path): @@ -251,6 +241,7 @@ class ModelWorker(worker_base.Worker): ) self.__pg_info = global_comm.setup_global_comm( + args=self.args, expr_name=self.__experiment_name, trial_name=self.__trial_name, worker_index=self.__worker_index, @@ -312,13 +303,6 @@ class ModelWorker(worker_base.Worker): self.__dataset_dp_rank, self.__dataset_dp_size, self.config.tokenizer_name_or_path, - self.config.worker_info.experiment_name, - self.config.worker_info.trial_name, - cache_root=( - None - if not self.config.use_dataset_cache - else self.config.dataset_cahce_root - ), ) for d in self.config.datasets ] @@ -388,9 +372,7 @@ class ModelWorker(worker_base.Worker): and hasattr(d, "filter") ): dataset_indices_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(self.args), "dataset_indices", f"{self._dp_rank}_{i}.npy", ) @@ -438,13 +420,6 @@ class ModelWorker(worker_base.Worker): s.id.dp_rank, s.id.topo.get_dim("data"), self.__models[s.id.model_name].tokenizer, - self.config.worker_info.experiment_name, - self.config.worker_info.trial_name, - cache_root=( - None - if not self.config.use_dataset_cache - else self.config.dataset_cahce_root - ), ) eval_dataloader = torch.utils.data.DataLoader( eval_dataset, @@ -497,16 +472,15 @@ class ModelWorker(worker_base.Worker): self.__param_realloc(hook_data) elif hook == "offload": # NOTE: Profiling (or cuda synchronization) will cause an overhead ~0.5s. - with cuda_tmarked("offload", CUDATimeMarkType.mem_layout): - m = self.__unwrapped_models[hook_data["model_name"]] - if not isinstance(m, ReaLModel): - logger.warning( - f"Model {hook_data['model_name']} (type={type(m)}) is not a ReaLModel, " - f"so it can't use offload." - ) - return - if not m._offloaded: - m.async_offload() + m = self.__unwrapped_models[hook_data["model_name"]] + if not isinstance(m, ReaLModel): + logger.warning( + f"Model {hook_data['model_name']} (type={type(m)}) is not a ReaLModel, " + f"so it can't use offload." + ) + return + if not m._offloaded: + m.async_offload() elif hook == "save": self.__save_model(hook_data) elif hook == "evaluate": @@ -602,15 +576,11 @@ class ModelWorker(worker_base.Worker): except StopIteration: # Upon the first fetch request, filter dataset and create dataloader. 
eval_scores_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(self.args), "dataset_eval_scores.json", ) dataset_indices_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(self.args), "dataset_indices", f"{dp_rank}_{dataset_id}.npy", ) @@ -668,7 +638,7 @@ class ModelWorker(worker_base.Worker): continue if self.data_manager.has_data(x.ids[0]): continue - data_loaded.append(x) + data_loaded.append(x.cpu()) self.data_manager.store(x) assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded) @@ -700,25 +670,23 @@ class ModelWorker(worker_base.Worker): "dataset_size": self.dataset_size, } elif request.handle_name == "clear_data_cache": - with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc): - ids = request.data - self.data_manager.remove(ids) - gc.collect() - if ( - self.config.cuda_cache_cleanliness - and self.__clear_cache_frequency.check() - ): - st = time.monotonic() - self._clear_memory(force=True) - et = time.monotonic() - blogger.debug( - f"Model worker {self.__worker_index} cleared cache in {et-st:.4f}s. " - ) + ids = request.data + self.data_manager.remove(ids) + gc.collect() + if ( + self.config.cuda_cache_cleanliness + and self.__clear_cache_frequency.check() + ): + st = time.monotonic() + self._clear_memory(force=True) + et = time.monotonic() + blogger.debug( + f"Model worker {self.__worker_index} cleared cache in {et-st:.4f}s. " + ) logger.debug( - "Get clear_data_cache, dump cuda tmark. " + "Get clear_data_cache. " f"Remaining data in local storage: {self.data_manager.storage_size()}. " ) - dump_tmark_db(self.__worker_index) res = request_reply_stream.NoResponse() self.__reply_queue.put_nowait((request, res)) self.__request_sample_size[request.request_id] = 1 @@ -820,9 +788,7 @@ class ModelWorker(worker_base.Worker): tik = time.perf_counter() global_step = self.__models[model_name].version.global_step realloc_dir = os.path.join( - constants.PARAM_REALLOC_PATH, - constants.experiment_name(), - constants.trial_name(), + constants.get_param_realloc_path(self.args), model_name.role, str(global_step), ) @@ -852,9 +818,7 @@ class ModelWorker(worker_base.Worker): def _get_setup_logdir(self, name): subdir = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_log_path(self.args), name, f"setup{self._setup_counter}", ) @@ -990,9 +954,7 @@ class ModelWorker(worker_base.Worker): raise NotImplementedError(f"Unknown MFC type: {request.handle_name}.") eval_scores_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_save_path(self.args), "dataset_eval_scores.json", ) eval_scores = {} @@ -1061,7 +1023,6 @@ class ModelWorker(worker_base.Worker): dist.barrier(group=constants.cpu_parallelism_group()) return res - @cuda_tmark("data_transfer", CUDATimeMarkType.comm) def __data_transfer_among_workers(self, hook_data: Dict[str, Any]): meta_sample = hook_data["meta_sample"] @@ -1130,9 +1091,7 @@ class ModelWorker(worker_base.Worker): global_step = int(global_step.item()) realloc_dir = os.path.join( - constants.PARAM_REALLOC_PATH, - constants.experiment_name(), - constants.trial_name(), + constants.get_param_realloc_path(self.args), from_model_name.role, str(global_step), ) @@ -1284,7 +1243,6 @@ class ModelWorker(worker_base.Worker): f"Time consumption: {float(t):.4f}s." 
) - @cuda_tmark("post_response", CUDATimeMarkType.misc) def maybe_post_responses(self): ready_to_post = [] while True: @@ -1335,7 +1293,6 @@ class ModelWorker(worker_base.Worker): time.sleep(_MODEL_WORKER_POLL_REQUESTS_INTERVAL_SECS) pass - @cuda_tmark("receive_request", CUDATimeMarkType.misc) def maybe_receive_requests(self): tik = time.perf_counter() while time.perf_counter() - tik < _MODEL_WORKER_POLL_REQUESTS_SECS: diff --git a/realhf/system/push_pull_stream.py b/realhf/system/push_pull_stream.py index a2f4019..e5fdb07 100644 --- a/realhf/system/push_pull_stream.py +++ b/realhf/system/push_pull_stream.py @@ -1,4 +1,5 @@ import logging +import os from queue import Empty as QueueEmpty from typing import Any, Dict, List, Optional, Union @@ -6,7 +7,7 @@ import orjson import zmq from zmq.utils.strtypes import asbytes -from realhf.base import logging, name_resolve, names, network +from realhf.base import constants, logging, name_resolve, names, network logger = logging.getLogger("ZMQ Push-Pull Stream") @@ -160,12 +161,16 @@ class NameResolvingZmqPusher(ZMQJsonPusher): class NameResolvingZmqPuller(ZMQJsonPuller): - def __init__(self, experiment_name, trial_name, puller_index, **kwargs): + def __init__(self, args, puller_index, **kwargs): + experiment_name = args.experiment_name + trial_name = args.trial_name name = names.push_pull_stream( experiment_name, trial_name, stream_name=f"puller{puller_index}" ) host, port = network.gethostip(), network.find_free_port( - experiment_name=experiment_name, trial_name=trial_name + experiment_name=experiment_name, + trial_name=trial_name, + lockfile_root=os.path.join(constants.get_cache_path(args), "ports"), ) addr = f"{host}:{port}" name_resolve.add(name, addr) diff --git a/realhf/system/redistributor.py b/realhf/system/redistributor.py index 61daa6d..05f6431 100644 --- a/realhf/system/redistributor.py +++ b/realhf/system/redistributor.py @@ -6,7 +6,7 @@ import itertools from collections import defaultdict from typing import * -from realhf.base.cluster import spec as cluster_spec +from realhf.api.cli_args import ClusterSpecConfig class GlobalStorageTracker: @@ -70,7 +70,10 @@ class RedistribStep: class RedistribPlanner: - def __init__(self, storage_tracker: GlobalStorageTracker): + def __init__( + self, cluster_config: ClusterSpecConfig, storage_tracker: GlobalStorageTracker + ): + self.cluster_config = cluster_config self.storage_tracker = storage_tracker def derive_plan( @@ -269,8 +272,8 @@ class RedistribPlanner: return self._group_bcast_transfers() def _on_same_node(self, i, j) -> bool: - return (i // cluster_spec.n_gpus_per_node) == ( - j // cluster_spec.n_gpus_per_node + return (i // self.cluster_config.n_gpus_per_node) == ( + j // self.cluster_config.n_gpus_per_node ) def _select_best_bcast_source(self, source_gpus, target_gpus): diff --git a/realhf/system/rollout_worker.py b/realhf/system/rollout_worker.py index 045bf06..1fe02fc 100644 --- a/realhf/system/rollout_worker.py +++ b/realhf/system/rollout_worker.py @@ -87,7 +87,7 @@ class RolloutWorker(AsyncWorker): self.rollout_stat = RolloutStat() # recover info - self.__recover_run, self.__recover_info = recover.load_recover_info() + self.__recover_run, self.__recover_info = recover.load_recover_info(self.args) return config.worker_info @@ -101,13 +101,6 @@ class RolloutWorker(AsyncWorker): self.worker_index, self.worker_count, self.config.tokenizer_path, - self.config.worker_info.experiment_name, - self.config.worker_info.trial_name, - cache_root=( - None - if not self.config.use_dataset_cache - 
else self.config.dataset_cahce_root - ), ) for d in self.config.datasets ] @@ -129,9 +122,7 @@ class RolloutWorker(AsyncWorker): # Recover indices for dynamic dataset if hasattr(self.dataset, "filter"): dataset_indices_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_log_path(self.args), f"dataset_indices_{self.worker_index}.npy", ) if os.path.exists(dataset_indices_path): @@ -156,15 +147,11 @@ class RolloutWorker(AsyncWorker): self.is_new_epoch = True # Upon the first fetch request, filter dataset and create dataloader. eval_scores_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_log_path(self.args), "dataset_eval_scores.json", ) dataset_indices_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), + constants.get_log_path(self.args), f"dataset_indices_{self.worker_index}.npy", ) if hasattr(self.dataset, "filter") and os.path.exists(eval_scores_path): diff --git a/realhf/system/stream_dataset.py b/realhf/system/stream_dataset.py index f688c61..2c0a27a 100644 --- a/realhf/system/stream_dataset.py +++ b/realhf/system/stream_dataset.py @@ -2,6 +2,7 @@ import queue import sys import threading import time +import traceback from typing import Any, List, Optional from torch.utils.data import ConcatDataset, Dataset @@ -23,6 +24,7 @@ class PullerStreamDataset(Dataset): def __init__( self, util: DatasetUtility, + args, dataset_cfgs: List[DatasetAbstraction], pull_timeout_ms=100, ): @@ -35,8 +37,6 @@ class PullerStreamDataset(Dataset): dp_rank=util.dp_rank, world_size=util.world_size, tokenizer_or_tokenizer_name=util.tokenizer, - experiment_name=constants.experiment_name(), - trial_name=constants.trial_name(), ) for dataset_cfg in dataset_cfgs ] @@ -51,6 +51,8 @@ class PullerStreamDataset(Dataset): self.data_queue = queue.Queue(maxsize=self.dataset_size * util.world_size) self._stop_event = threading.Event() + self.args = args + # Pass ZMQ context (thread-safe) and let worker create the socket self.util = util self.worker_thread = threading.Thread(target=self._pull_data_worker) @@ -60,33 +62,33 @@ class PullerStreamDataset(Dataset): """Worker thread that creates its own ZMQ puller and streams data.""" # Initialize the puller inside the worker thread stream = NameResolvingZmqPuller( - constants.experiment_name(), - constants.trial_name(), + self.args, puller_index=self.util.dp_rank, ) - try: - while not self._stop_event.is_set(): + processed_data = None + while not self._stop_event.is_set(): + if processed_data is not None: try: - data = stream.pull(timeout_ms=self.pull_timeout_ms) - processed_data = [ - SequenceSample.from_json_compatible(x) for x in data - ] - logger.debug( - f"Get data {[x.ids[0] for x in processed_data]} from puller stream." - ) - self.data_queue.put(processed_data) - except queue.Empty: - logger.debug(f"No data from puller stream.") + self.data_queue.put_nowait(processed_data) + processed_data = None + except queue.Full: time.sleep(0.1) continue - finally: - # Ensure socket is closed in the same thread - del stream - # Exit if this thread has an error - sys.exit(1) + try: + data = stream.pull(timeout_ms=self.pull_timeout_ms) + processed_data = [SequenceSample.from_json_compatible(x) for x in data] + logger.debug( + f"Get data {[x.ids[0] for x in processed_data]} from puller stream." 
+ ) + except queue.Empty: + logger.debug(f"No data from puller stream.") + time.sleep(0.1) + continue def __getitem__(self, idx: int) -> Optional[Any]: samples = [] + if not self.worker_thread.is_alive(): + raise RuntimeError("Stream dataset puller thread is not alive.") while True: try: samples += self.data_queue.get_nowait() @@ -99,8 +101,6 @@ class PullerStreamDataset(Dataset): def __del__(self): self._stop_event.set() - if self.worker_thread.is_alive(): - self.worker_thread.join(timeout=1.0) register_dataset("puller_stream", PullerStreamDataset) diff --git a/realhf/system/worker_base.py b/realhf/system/worker_base.py index d673f50..de41aa3 100644 --- a/realhf/system/worker_base.py +++ b/realhf/system/worker_base.py @@ -568,6 +568,7 @@ class Worker: self.__worker_index = worker_info.worker_index experiment = system_api.make_experiment(name=worker_info.experiment_name) + self.args = experiment expr_config = experiment.initial_setup() if isinstance(expr_config, list): diff --git a/tests/agent/test_math_single_step_agent.py b/tests/agent/test_math_single_step_agent.py index d9cfcac..ebcd85c 100644 --- a/tests/agent/test_math_single_step_agent.py +++ b/tests/agent/test_math_single_step_agent.py @@ -1,6 +1,7 @@ import asyncio import json import os +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import numpy as np @@ -21,7 +22,7 @@ def mock_env(): @pytest.fixture -def agent_config(): +def agent_config(tmp_path): return { "gconfig": MagicMock(n=2), "tokenizer_path": "/storage/openpsi/models/Qwen__Qwen2.5-0.5B-Instruct/", @@ -29,6 +30,7 @@ def agent_config(): "success_rate_ub": 1.0, "reward_scaling": 2.0, "reward_bias": 0.1, + "answer_save_path": tmp_path, } @@ -140,48 +142,34 @@ async def test_collect_trajectory_empty_act_queue(agent, mock_env, mock_prompt): def test_log_rewards_to_file(agent, tmp_path): # Setup test directories - with ( - patch("realhf.base.constants.LOG_ROOT", tmp_path), - patch("realhf.base.constants.experiment_name", return_value="test_exp"), - patch("realhf.base.constants.trial_name", return_value="test_trial"), - ): - agent.log_rewards_to_file( - qid="123", - prompt="test_prompt", - prompt_len=3, - answers=["answer1", "answer2"], - seqlens=[5, 6], - rewards=[0.5, 0.7], - success=[True, False], - version_starts=[1, 2], - version_ends=[2, 3], - ) + agent.log_rewards_to_file( + qid="123", + prompt="test_prompt", + prompt_len=3, + answers=["answer1", "answer2"], + seqlens=[5, 6], + rewards=[0.5, 0.7], + success=[True, False], + version_starts=[1, 2], + version_ends=[2, 3], + ) - # Check generated file - gen_file_path = ( - tmp_path / "test_exp" / "test_trial" / "generated" / "1" / "123.txt" - ) - assert gen_file_path.exists() - with open(gen_file_path) as f: - content = f.read() - assert "idx: 1 / 2" in content - assert "seqlen: 5" in content - assert "test_prompt" in content + # Check generated file + gen_file_path = Path(agent.answer_save_path) / "1" / "123.txt" + assert gen_file_path.exists() + with open(gen_file_path) as f: + content = f.read() + assert "idx: 1 / 2" in content + assert "seqlen: 5" in content + assert "test_prompt" in content - # Check monitor file - monitor_file_path = ( - tmp_path - / "test_exp" - / "test_trial" - / "training_monitor" - / "1" - / "123.jsonl" - ) - assert monitor_file_path.exists() - with open(monitor_file_path) as f: - data = json.loads(f.readline()) - assert data["version_start"] == 1 - assert data["prompt_len"] == 3 + # Check monitor file + monitor_file_path = Path(agent.answer_save_path) / "1" / 
"123.jsonl" + assert monitor_file_path.exists() + with open(monitor_file_path) as f: + data = json.loads(f.readline()) + assert data["version_start"] == 1 + assert data["prompt_len"] == 3 def test_reward_calculation(agent): diff --git a/tests/comm/test_data_transfer.py b/tests/comm/test_data_transfer.py index 568bfd2..b15b51e 100644 --- a/tests/comm/test_data_transfer.py +++ b/tests/comm/test_data_transfer.py @@ -14,6 +14,7 @@ import pytest import torch import torch.distributed as dist +from realhf.api.cli_args import ClusterSpecConfig from realhf.api.core.config import ModelName, ModelShardID from realhf.api.core.data_api import SequenceSample from realhf.base import constants, testing, topology @@ -166,7 +167,7 @@ def _test_data_transfer( data_manager.setup_process_groups() storage_tracker = GlobalStorageTracker(dist.get_world_size()) - planner = RedistribPlanner(storage_tracker) + planner = RedistribPlanner(ClusterSpecConfig(), storage_tracker) key = "input_ids" diff --git a/tests/cpp_extensions/test_grouped_gemm.py b/tests/cpp_extensions/test_grouped_gemm.py index d1bfce5..177591a 100644 --- a/tests/cpp_extensions/test_grouped_gemm.py +++ b/tests/cpp_extensions/test_grouped_gemm.py @@ -4,12 +4,12 @@ import random import time +import uuid import pytest import torch -import realhf.base.constants as constants -import realhf.base.testing as testing +from realhf.base import constants, name_resolve, testing # This is a test for grouped_gemm experts implementation of MoE. @@ -99,10 +99,7 @@ def run_grouped_mlp(num_tokens, tp_size, token_dispatch_strategy, seed=1): ) -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="This test requires GPU to run", -) +@pytest.mark.skip("grouped_gemm is not used for now.") @pytest.mark.parametrize("num_tokens", [200]) @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("token_dispatch_strategy", ["random"]) @@ -123,10 +120,7 @@ def test_grouped_mlp( test.launch() -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="This test requires GPU to run", -) +@pytest.mark.skip("grouped_gemm is not used for now.") @pytest.mark.gpu def test_grouped_gemm(): torch.manual_seed(1) diff --git a/tests/cpp_extensions/test_interval_ops.py b/tests/cpp_extensions/test_interval_ops.py index 6e326a4..fb7cf83 100644 --- a/tests/cpp_extensions/test_interval_ops.py +++ b/tests/cpp_extensions/test_interval_ops.py @@ -36,7 +36,7 @@ def maybe_synchronize_cuda(): torch.cuda.synchronize() -@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.") +@pytest.mark.skip("interval_ops are not used now.") @pytest.mark.parametrize( "n_intervals", list(reversed([1, 100, 500, 1000, 2000, 4000, 10000, 100000])) ) @@ -86,7 +86,7 @@ def test_get(n_intervals: int, dtype: torch.dtype): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.") +@pytest.mark.skip("interval_ops are not used now.") @pytest.mark.parametrize( "n_intervals", list(reversed([1, 10, 100, 500, 1000, 1000, 10000, 100000])) ) diff --git a/tests/data/test_load_data.py b/tests/data/test_load_data.py index b771dfd..e0cec40 100644 --- a/tests/data/test_load_data.py +++ b/tests/data/test_load_data.py @@ -21,8 +21,6 @@ def _validate_dataset(cfg: config_api.DatasetAbstraction, tokenizer): dp_rank=0, world_size=1, tokenizer_or_tokenizer_name=tokenizer, - experiment_name=str(uuid.uuid4()), - trial_name=str(uuid.uuid4()), ) dataloader = DataLoader( dataset, diff --git a/tests/distributed/test_name_resolve.py 
b/tests/distributed/test_name_resolve.py index 569571c..bb035e3 100644 --- a/tests/distributed/test_name_resolve.py +++ b/tests/distributed/test_name_resolve.py @@ -21,13 +21,13 @@ BACKENDS = [ ("nfs", {}), ("ray", {}), ] -if os.environ.get("REAL_ETCD_ADDR"): +if os.environ.get("TESTING_ETCD_ADDR"): BACKENDS.append( ( "etcd3", { - "host": os.getenv("REAL_ETCD_ADDR").split(":")[0], - "port": int(os.getenv("REAL_ETCD_ADDR").split(":")[1]), + "host": os.getenv("TESTING_ETCD_ADDR").split(":")[0], + "port": int(os.getenv("TESTING_ETCD_ADDR").split(":")[1]), }, ) ) @@ -43,12 +43,9 @@ def name_resolve(request): temp_dir = tempfile.mkdtemp() from realhf.base.name_resolve import NfsNameRecordRepository - original_root = NfsNameRecordRepository.RECORD_ROOT - NfsNameRecordRepository.RECORD_ROOT = temp_dir - repo = NfsNameRecordRepository() + repo = NfsNameRecordRepository(temp_dir) yield repo repo.reset() - NfsNameRecordRepository.RECORD_ROOT = original_root shutil.rmtree(temp_dir) elif backend_type == "memory": from realhf.base.name_resolve import MemoryNameRecordRepository @@ -208,17 +205,6 @@ def test_reset(name_resolve): name_resolve.delete("test_key_no_delete") -def test_context_manager(name_resolve): - """Test context manager functionality.""" - with name_resolve.__class__() as repo: - repo.add("test_key", "test_value", delete_on_exit=True) - assert repo.get("test_key") == "test_value" - - # Key should be deleted after context exits - with pytest.raises(NameEntryNotFoundError): - name_resolve.get("test_key") - - def test_concurrent_access(name_resolve): """Test concurrent access to the same key.""" name_resolve.add("test_key", "initial_value") @@ -648,7 +634,9 @@ def test_corner_case_get_same_as_prefix(name_resolve): assert set(keys) == {"prefix", "prefix/child"} -@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is None, reason="ETCD3 not configured") +@pytest.mark.skipif( + os.getenv("TESTING_ETCD_ADDR") is None, reason="ETCD3 not configured" +) def test_etcd3_specific_features(name_resolve): if not isinstance(name_resolve, Etcd3NameRecordRepository): pytest.skip("ETCD3 specific test") @@ -663,7 +651,9 @@ def test_etcd3_specific_features(name_resolve): name_resolve.get("test_key") -@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is not None, reason="NFS specific test") +@pytest.mark.skipif( + os.getenv("TESTING_ETCD_ADDR") is not None, reason="NFS specific test" +) def test_nfs_specific_features(name_resolve): """Test features specific to NFS backend.""" from realhf.base.name_resolve import NfsNameRecordRepository diff --git a/tests/experiments/test_buffer_recover.py b/tests/experiments/test_buffer_recover.py index 9b7151b..972bba4 100644 --- a/tests/experiments/test_buffer_recover.py +++ b/tests/experiments/test_buffer_recover.py @@ -8,6 +8,7 @@ from typing import * import pytest from realhf.api.cli_args import ( + ClusterSpecConfig, ExperimentSaveEvalControl, MFCConfig, ModelTrainEvalConfig, @@ -15,7 +16,7 @@ from realhf.api.cli_args import ( PromptAnswerDatasetConfig, PromptOnlyDatasetConfig, ) -from realhf.base import cluster, logging, name_resolve, testing +from realhf.base import logging, name_resolve, testing from realhf.experiments.common.null_exp import NullPPOConfig, NullSFTConfig from tests.experiments.utils import run_test_exp from tests.fixtures import * @@ -64,11 +65,8 @@ def test_buffer_recover( ): _, dataset_size = math_code_dataset_with_size # Setup experiment env. Should be done before any other operations. 
- log_root = tmp_path_factory.mktemp("buffer-recover") - cluster.spec.fileroot = str(log_root) expr_name = str(uuid.uuid4()) trial_name = str(uuid.uuid4()) - testing.clear_name_resolve(expr_name, trial_name) constants.set_experiment_trial_names(expr_name, trial_name) exp_cfg = NullPPOConfig( @@ -114,6 +112,10 @@ def test_buffer_recover( save_freq_steps=2, benchmark_steps=0, ), + cluster=ClusterSpecConfig( + fileroot=str(tmp_path_factory.mktemp("buffer-recover")), + n_gpus_per_node=16, + ), ) os.environ["REAL_SAVE_RECOVER_STATES"] = "1" diff --git a/tests/experiments/test_math_ppo.py b/tests/experiments/test_math_ppo.py index aa5e26b..39d2e91 100644 --- a/tests/experiments/test_math_ppo.py +++ b/tests/experiments/test_math_ppo.py @@ -8,6 +8,7 @@ from typing import * import pytest from realhf.api.cli_args import ( + ClusterSpecConfig, ExperimentSaveEvalControl, GenerationHyperparameters, MFCConfig, @@ -17,7 +18,7 @@ from realhf.api.cli_args import ( PPOHyperparameters, PromptOnlyDatasetConfig, ) -from realhf.base import cluster, testing +from realhf.base import testing from realhf.experiments.common.ppo_math_exp import PPOMATHConfig from tests.experiments.utils import run_test_exp from tests.fixtures import * @@ -73,8 +74,6 @@ def test_ppo_symm( mp, ): # Setup experiment env. Should be done before any other operations. - log_root = tmp_path_factory.mktemp("ppo") - cluster.spec.fileroot = str(log_root) constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) @@ -117,6 +116,7 @@ def test_ppo_symm( ), ), group_size=2, + cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))), ) run_test_exp(exp_cfg) @@ -152,8 +152,6 @@ def test_ppo_decoupled( gmp, ): # Setup experiment env. Should be done before any other operations. - log_root = tmp_path_factory.mktemp("ppo") - cluster.spec.fileroot = str(log_root) constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) @@ -245,6 +243,7 @@ def test_ppo_decoupled( ), ), group_size=2, + cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))), ) run_test_exp(exp_cfg) @@ -275,8 +274,6 @@ def test_ppo_global_reshard( rew_inf, ): # Setup experiment env. Should be done before any other operations. - log_root = tmp_path_factory.mktemp("ppo-global-reshard") - cluster.spec.fileroot = str(log_root) constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) @@ -368,6 +365,7 @@ def test_ppo_global_reshard( pipeline_parallel_size=critic_train[2], ), ), + cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))), ) run_test_exp(exp_cfg) @@ -388,8 +386,6 @@ def test_ppo_param_realloc_sub_device_mesh( critic_inf, ): # Setup experiment env. Should be done before any other operations. - log_root = tmp_path_factory.mktemp("ppo-submesh") - cluster.spec.fileroot = str(log_root) constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) @@ -484,6 +480,7 @@ def test_ppo_param_realloc_sub_device_mesh( pipeline_parallel_size=2, ), ), + cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))), ) run_test_exp(exp_cfg) @@ -503,13 +500,9 @@ def test_ppo_save( bs, ): # Setup experiment env. Should be done before any other operations. 
- log_root = tmp_path_factory.mktemp("ppo") - cluster.spec.fileroot = str(log_root) constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) - shutil.rmtree(constants.MODEL_SAVE_ROOT, ignore_errors=True) - os.makedirs(constants.MODEL_SAVE_ROOT, exist_ok=True) total_train_epochs = 3 @@ -604,7 +597,11 @@ def test_ppo_save( pipeline_parallel_size=1, ), ), + cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))), ) + shutil.rmtree(constants.get_save_path(exp_cfg), ignore_errors=True) + os.makedirs(constants.get_save_path(exp_cfg), exist_ok=True) + exp_cfg.actor.vllm.hybrid_train = True exp_cfg.actor.vllm.enforce_eager = True @@ -636,9 +633,7 @@ def test_ppo_save( int(os.path.basename(f).split("globalstep")[-1]) for f in os.listdir( os.path.join( - constants.MODEL_SAVE_ROOT, - testing._DEFAULT_EXPR_NAME, - testing._DEFAULT_EXPR_NAME, + constants.get_save_path(exp_cfg), model_name, ) ) diff --git a/tests/experiments/test_sft.py b/tests/experiments/test_sft.py index b0c54ed..bdce110 100644 --- a/tests/experiments/test_sft.py +++ b/tests/experiments/test_sft.py @@ -6,11 +6,12 @@ from typing import * import pytest from realhf.api.cli_args import ( + ClusterSpecConfig, ExperimentSaveEvalControl, ModelTrainEvalConfig, PromptAnswerDatasetConfig, ) -from realhf.base import cluster, testing +from realhf.base import testing from realhf.experiments.common.sft_exp import SFTConfig from tests.experiments.utils import run_test_exp from tests.fixtures import * @@ -32,9 +33,6 @@ def model_class(request): (1, 2, 4), (2, 4, 1), (2, 1, 4), - (4, 2, 2), - (2, 4, 2), - (2, 2, 4), ], ) def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp): @@ -61,8 +59,6 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp): # Setup experiment env. Should be done before any other operations. - log_root = tmp_path_factory.mktemp("sft") - cluster.spec.fileroot = str(log_root) constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) @@ -89,6 +85,7 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp): valid_bs_n_seqs=minbs, fill_to_max_length=False, ), + cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("sft"))), ) run_test_exp(exp_cfg) diff --git a/tests/experiments/utils.py b/tests/experiments/utils.py index 1821dd3..b459cac 100644 --- a/tests/experiments/utils.py +++ b/tests/experiments/utils.py @@ -1,4 +1,5 @@ # Copyright 2025 Ant Group Inc. +import asyncio import functools import multiprocessing as mp from typing import * @@ -6,7 +7,7 @@ from typing import * import pytest from realhf.api.core.system_api import Experiment, register_experiment -from realhf.base import cluster, constants, logging, testing +from realhf.base import constants, logging, testing from realhf.system.worker_base import WorkerServerStatus from tests.fixtures import * @@ -74,7 +75,6 @@ def run_test_exp( logger.info("Configuring master worker...") mas.configure(setup_id=0, worker_info=exp_setup.master_worker[0].worker_info) logger.info("Configuring master worker... Done.") - initd = False # Run model workers in subprocesses barrier = mp.Barrier(len(exp_setup.model_worker)) @@ -98,13 +98,17 @@ def run_test_exp( testcase.start() # Run the master worker. 
- for _ in range(int(1e4)): - if mas.status == WorkerServerStatus.PAUSED: - break - if not initd: - logger.info("Running master worker lazy initialization...") - mas._poll() - if not initd: - logger.info("Running master worker lazy initialization... Done.") - initd = True + async def run_master_worker(): + initd = False + for _ in range(int(1e4)): + if mas.status == WorkerServerStatus.PAUSED: + break + if not initd: + logger.info("Running master worker lazy initialization...") + await mas._poll_async() + if not initd: + logger.info("Running master worker lazy initialization... Done.") + initd = True + + asyncio.run(run_master_worker()) testcase.wait(timeout=0.1) diff --git a/tests/interfaces/test_multi_task_reward.py b/tests/interfaces/test_multi_task_reward.py index 0dda871..a39efb5 100644 --- a/tests/interfaces/test_multi_task_reward.py +++ b/tests/interfaces/test_multi_task_reward.py @@ -16,7 +16,7 @@ from realhf.api.core.data_api import ( load_hf_tokenizer, ) from realhf.api.core.model_api import FinetuneSpec, Model -from realhf.base import constants, network, testing +from realhf.base import constants, name_resolve, network, testing from tests.fixtures import * @@ -61,8 +61,12 @@ def math_code_dataset(request, save_path): "tokenizer_path", ["/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"] ) def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_dataset): + from realhf.api.cli_args import NameResolveConfig from realhf.impl.dataset.math_code_dataset import MATHCodePromptDataset + name_resolve.reconfigure( + NameResolveConfig("nfs", f"/tmp/areal/{str(uuid.uuid4())}/") + ) dist.init_process_group( rank=0, world_size=1, init_method=f"tcp://localhost:{network.find_free_port()}" ) diff --git a/tests/system/test_gserver_manager.py b/tests/system/test_gserver_manager.py index 10266b6..cd4bb29 100644 --- a/tests/system/test_gserver_manager.py +++ b/tests/system/test_gserver_manager.py @@ -13,7 +13,9 @@ from typing import Optional import aiohttp import pytest +from realhf.api.cli_args import BaseExperimentConfig, NameResolveConfig from realhf.api.core.config import ModelName +from realhf.api.core.system_api import ExpStatus from realhf.api.core.system_api import GserverManager as GserverManagerConfig from realhf.api.core.system_api import WorkerInformation from realhf.base import constants, name_resolve, names, network, testing @@ -42,6 +44,9 @@ def mock_servers(): from fastapi import FastAPI from fastapi.responses import ORJSONResponse, PlainTextResponse + name_resolve.reconfigure( + NameResolveConfig("nfs", "/tmp/areal/test-gserver-manager") + ) ports = network.find_multiple_free_ports(N_SERVERS) # Create mock server responses @@ -124,7 +129,6 @@ def mock_servers(): @pytest.fixture def gserver_manager(request, mock_servers): train_batch_size, offpolicyness = request.param - testing.clear_name_resolve() constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) @@ -135,6 +139,9 @@ def gserver_manager(request, mock_servers): name_resolve.add_subentry(name, server_urls[0]) name_resolve.add_subentry(name, server_urls[1]) + name = names.experiment_status(constants.experiment_name(), constants.trial_name()) + name_resolve.add(name, ExpStatus.RUNNING) + # Mock requests.get for metrics endpoint m = GserverManager() config = GserverManagerConfig( @@ -153,6 +160,7 @@ def gserver_manager(request, mock_servers): worker_index=0, ), ) + m.args = BaseExperimentConfig() m._configure(config) # launch the server m._poll() diff --git 
a/tests/system/test_partial_rollout.py b/tests/system/test_partial_rollout.py index bd1f8a0..5e63276 100644 --- a/tests/system/test_partial_rollout.py +++ b/tests/system/test_partial_rollout.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from transformers import PreTrainedTokenizerFast +from realhf.api.cli_args import NameResolveConfig from realhf.api.core.model_api import ( APIGenerateInput, APIGenerateOutput, @@ -67,6 +68,10 @@ def partial_rollout_manager(): request_queue = asyncio.Queue() reply_queue = asyncio.Queue() + name_resolve.reconfigure( + NameResolveConfig("nfs", "/tmp/areal/test-partial-rollout") + ) + testing.clear_name_resolve() constants.set_experiment_trial_names( testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME diff --git a/tests/system/test_stream_dataset.py b/tests/system/test_stream_dataset.py index c5a0309..3c1ab15 100644 --- a/tests/system/test_stream_dataset.py +++ b/tests/system/test_stream_dataset.py @@ -7,9 +7,10 @@ import pytest import torch from torch.utils.data import DataLoader +from realhf.api.cli_args import BaseExperimentConfig, NameResolveConfig from realhf.api.core import config as config_api from realhf.api.core import data_api -from realhf.base import constants, testing +from realhf.base import constants, name_resolve, testing from tests.fixtures import * @@ -64,6 +65,7 @@ def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller): testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME ) + name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset")) testing.clear_name_resolve() util = data_api.DatasetUtility( @@ -74,7 +76,12 @@ def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller): ) # Test initialization - dataset = PullerStreamDataset(util, prompt_dataset_cfg, pull_timeout_ms=100) + dataset = PullerStreamDataset( + util, + args=BaseExperimentConfig(), + dataset_cfgs=prompt_dataset_cfg, + pull_timeout_ms=100, + ) assert len(dataset) > 0 # Should have non-zero size from prompt dataset assert dataset.data_queue.empty() @@ -91,12 +98,14 @@ def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller): assert len(items1) + len(items2) > 0 # Test cleanup + dataset._stop_event.set() del dataset def test_puller_stream_dataset_timeout(prompt_dataset_cfg, tokenizer): from realhf.system.stream_dataset import PullerStreamDataset + name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset")) testing.clear_name_resolve() util = data_api.DatasetUtility( @@ -109,15 +118,19 @@ def test_puller_stream_dataset_timeout(prompt_dataset_cfg, tokenizer): with patch("realhf.system.stream_dataset.NameResolvingZmqPuller") as mock_puller: mock_puller.return_value.pull.side_effect = queue.Empty() - dataset = PullerStreamDataset(util, prompt_dataset_cfg, pull_timeout_ms=10) + dataset = PullerStreamDataset( + util, BaseExperimentConfig(), prompt_dataset_cfg, pull_timeout_ms=10 + ) # Should handle timeout gracefully assert dataset[0] == [] + dataset._stop_event.set() del dataset def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_puller): from realhf.system.stream_dataset import PullerStreamDataset + name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset")) testing.clear_name_resolve() util = data_api.DatasetUtility( @@ -127,7 +140,7 @@ def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_pu tokenizer=tokenizer, ) - dataset = PullerStreamDataset(util, 
prompt_dataset_cfg) + dataset = PullerStreamDataset(util, BaseExperimentConfig(), prompt_dataset_cfg) assert not dataset._stop_event.is_set() # Trigger stop event and verify thread stops @@ -140,6 +153,7 @@ def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_pu def test_puller_stream_dataset_worker_thread_exception(prompt_dataset_cfg, tokenizer): from realhf.system.stream_dataset import PullerStreamDataset + name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset")) testing.clear_name_resolve() util = data_api.DatasetUtility( @@ -152,7 +166,7 @@ def test_puller_stream_dataset_worker_thread_exception(prompt_dataset_cfg, token with patch("realhf.system.stream_dataset.NameResolvingZmqPuller") as mock_puller: mock_puller.return_value.pull.side_effect = Exception("Test error") - dataset = PullerStreamDataset(util, prompt_dataset_cfg) + dataset = PullerStreamDataset(util, BaseExperimentConfig(), prompt_dataset_cfg) time.sleep(0.1) # Give thread time to crash assert not dataset.worker_thread.is_alive() del dataset diff --git a/training/main_async_ppo.py b/training/main_async_ppo.py index e9343de..f355336 100644 --- a/training/main_async_ppo.py +++ b/training/main_async_ppo.py @@ -8,7 +8,6 @@ import yaml from omegaconf import MISSING, OmegaConf from realhf.api.quickstart.entrypoint import kind_reminder -from realhf.base.constants import init_constants from realhf.experiments.async_exp.async_ppo_math_exp import AsyncPPOMATHConfig from training.utils import run_experiment @@ -37,14 +36,10 @@ def main_ppo_math(args): if args.mode != "ray": raise RuntimeError("This script only supports the `ray` mode.") - init_constants(args) - - from realhf.base.constants import LOG_ROOT + from realhf.base.constants import get_log_path # Save overwritten configuration to yaml - config_save_path = os.path.join( - LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml" - ) + config_save_path = os.path.join(get_log_path(args), "config.yaml") os.makedirs(os.path.dirname(config_save_path), exist_ok=True) with open(config_save_path, "w") as f: config_dict: Dict = dataclasses.asdict(args) diff --git a/training/main_sft.py b/training/main_sft.py index 80f4b2f..57d0ec8 100644 --- a/training/main_sft.py +++ b/training/main_sft.py @@ -8,7 +8,6 @@ import yaml from omegaconf import MISSING, OmegaConf from realhf.api.quickstart.entrypoint import kind_reminder -from realhf.base.constants import init_constants from realhf.experiments.common.sft_exp import SFTConfig from training.utils import run_experiment @@ -37,14 +36,10 @@ def main(args): if args.mode != "ray": raise RuntimeError("This script only supports the `ray` mode.") - init_constants(args) - - from realhf.base.constants import LOG_ROOT + from realhf.base.constants import get_log_path # Save overwritten configuration to yaml - config_save_path = os.path.join( - LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml" - ) + config_save_path = os.path.join(get_log_path(args), "config.yaml") os.makedirs(os.path.dirname(config_save_path), exist_ok=True) with open(config_save_path, "w") as f: config_dict: Dict = dataclasses.asdict(args) diff --git a/training/main_sync_ppo.py b/training/main_sync_ppo.py index ffc6b37..fc87bf0 100644 --- a/training/main_sync_ppo.py +++ b/training/main_sync_ppo.py @@ -8,7 +8,6 @@ import yaml from omegaconf import MISSING, OmegaConf from realhf.api.quickstart.entrypoint import kind_reminder -from realhf.base.constants import init_constants from realhf.experiments.common.ppo_math_exp 
import PPOMATHConfig from training.utils import run_experiment @@ -37,14 +36,10 @@ def main(args): if args.mode != "ray": raise RuntimeError("This script only supports the `ray` mode.") - init_constants(args) - - from realhf.base.constants import LOG_ROOT + from realhf.base.constants import get_log_path # Save overwritten configuration to yaml - config_save_path = os.path.join( - LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml" - ) + config_save_path = os.path.join(get_log_path(args), "config.yaml") os.makedirs(os.path.dirname(config_save_path), exist_ok=True) with open(config_save_path, "w") as f: config_dict: Dict = dataclasses.asdict(args) diff --git a/training/utils.py b/training/utils.py index a20ecbf..e5d409b 100644 --- a/training/utils.py +++ b/training/utils.py @@ -12,6 +12,7 @@ from typing import Any, List import psutil import ray +from realhf.api.cli_args import NameResolveConfig from realhf.api.core.system_api import Experiment, ExperimentScheduling, TasksGroup from realhf.base import constants, logging, name_resolve, names from realhf.system import WORKER_TYPES, load_worker @@ -64,6 +65,7 @@ class RayWorker: def __init__( self, + args, worker_type: str, worker_cls, kv_store_name, @@ -74,15 +76,16 @@ class RayWorker: os.environ["REAL_MODE"] = "RAY" - name_resolve.reconfigure("ray", actor_name=kv_store_name) + name_resolve_config = NameResolveConfig("ray", ray_actor_name=kv_store_name) + name_resolve.reconfigure(name_resolve_config) self.worker: Worker | AsyncWorker = worker_cls() self.worker_type = worker_type + self.args = args def __repr__(self): return "".join([c.capitalize() for c in self.worker_type.split("_")]) def configure(self, cfg: Any, expr_config: Any): - constants.init_constants(expr_config) worker_info = cfg.worker_info idx = worker_info.worker_index @@ -92,6 +95,7 @@ class RayWorker: self.worker.wandb_config = expr_config.wandb self.worker.swanlab_config = expr_config.swanlab self.worker.tensorboard_config = expr_config.tensorboard + self.worker.args = self.args self.logger = logging.getLogger(f"{self.worker_type} {idx}", "benchmark") self.logger.info(f"Configuring {self.worker_type}...") self.worker._configure(cfg) @@ -125,6 +129,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name): # Initialize ray in the Ray cluster env_vars = constants.get_env_vars( + exp_cfg, WADNB_MODE=exp_cfg.wandb.mode, SWANLAB_MODE=exp_cfg.swanlab.mode, REAL_MODE="ray", @@ -145,7 +150,8 @@ def _run_experiment(exp_cfg, expr_name, trial_name): logger.info("Ray initialized!
Ready to run workers.") ray_kv_store_name = f"{expr_name}/{trial_name}/ray_kv_store" - name_resolve.reconfigure("ray", actor_name=ray_kv_store_name) + name_resolve_config = NameResolveConfig("ray", ray_actor_name=ray_kv_store_name) + name_resolve.reconfigure(name_resolve_config) name_resolve.clear_subtree( names.trial_root(experiment_name=expr_name, trial_name=trial_name) @@ -210,6 +216,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name): num_gpus=sch.scheduling.gpu, memory=sch.scheduling.mem * 1024**2, ).remote( + args=exp_cfg, worker_type=worker_type, worker_cls=load_worker(worker_type), kv_store_name=ray_kv_store_name, @@ -239,7 +246,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name): run_jobs = [] for worker_type in all_workers: workers = all_workers[worker_type] - if worker_type in ["master_worker", "rollout_worker", "gserver_manager"]: + if worker_type in ["master_worker", "rollout_worker"]: # Only the rollout worker is asynchronous jobs = [w.run_async.remote() for w in workers] else: @@ -270,7 +277,7 @@ class DualOutput: def run_experiment(exp_cfg, expr_name, trial_name): - log_path = os.path.join(constants.LOG_ROOT, expr_name, trial_name, "main.log") + log_path = os.path.join(constants.get_log_path(exp_cfg), "main.log") with open(log_path, "a") as f: # Create dual output handler dual_out = DualOutput(f, sys.stdout)
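
Usage sketch (editor's addition, not part of the patch): the hunks above consistently replace the module-level path constants (LOG_ROOT, MODEL_SAVE_ROOT, PARAM_REALLOC_PATH) and the global name-resolve setup with helpers that take the experiment config. The snippet below shows, under stated assumptions, how a launcher or test might wire the new APIs together; the BaseExperimentConfig defaults, the cluster attribute assignment, and all concrete paths are illustrative placeholders rather than code from this change.

import os

from realhf.api.cli_args import (
    BaseExperimentConfig,
    ClusterSpecConfig,
    NameResolveConfig,
)
from realhf.base import constants, name_resolve

# A stand-in for a full quickstart config (e.g. PPOMATHConfig); defaults only.
args = BaseExperimentConfig()
# Assumed: experiment configs carry a ClusterSpecConfig, as in the tests above.
args.cluster = ClusterSpecConfig(fileroot="/tmp/areal/experiments")

# The name-resolve backend is now selected via a config object.
name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/name_resolve"))

# Log and save locations are derived from the config instead of module constants.
config_save_path = os.path.join(constants.get_log_path(args), "config.yaml")
eval_scores_path = os.path.join(constants.get_save_path(args), "dataset_eval_scores.json")
print(config_save_path, eval_scores_path)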