# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

# log format constants
import contextlib
import copy
import datetime
import getpass
import os
import pathlib
import subprocess
from collections import defaultdict
from pathlib import Path
from typing import *

import numpy as np
import torch

import realhf.base.logging as logging

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from realhf.api.cli_args import BaseExperimentConfig
    from realhf.api.core.config import ModelName
    from realhf.api.core.system_api import ModelShardID
    from realhf.base.topology import ParallelGrid, ProcessTopology


class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.

    Caller should ensure that buffers of the same name are not used
    concurrently.
    """

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name, force_zero: bool = False):
        device = current_device()
        required_len = int(np.prod(tensor_shape))
        if self.buffer.get((name, dtype), None) is None:
            self.buffer[(name, dtype)] = torch.empty(
                required_len,
                dtype=dtype,
                device=device,
                requires_grad=False,
            )
        elif self.buffer[(name, dtype)].numel() < required_len:
            self.buffer[(name, dtype)] = torch.nn.functional.pad(
                self.buffer[(name, dtype)],
                (0, required_len - self.buffer[(name, dtype)].numel()),
                value=0,
            )
        res = self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
        if force_zero:
            res.zero_()
        return res


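# Sketch of how a caller might reuse a named scratch buffer through this
# module instead of allocating a fresh tensor on every step. The buffer name
# "attn-scratch" and the shape are illustrative; the same name must not be
# used for two buffers that are alive at the same time.
def _example_reuse_scratch_buffer():
    buf = get_global_memory_buffer().get_tensor(
        (8, 1024), torch.float16, name="attn-scratch", force_zero=True
    )
    return buf  # a zeroed view over the cached flat buffer

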
# For large models, generation may consume more than 3600s.
# We set a large value to avoid NCCL timeout issues during generation.
NCCL_DEFAULT_TIMEOUT = datetime.timedelta(seconds=7200)

# We may want to use CPU for testing even when CUDA is available.
TORCH_FORCE_CPU = False

# constants in experiment instance scope
LOCAL_CACHE_DIR = "/tmp/realhf"
PYTORCH_KERNEL_CACHE_PATH = (
    f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache")
os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)

# Global constants that should be initialized after cluster initialization.
MODEL_SAVE_ROOT = None
LOG_ROOT = None
RECOVER_ROOT = None
SLURM_LOCK_FILE_NAME = None
PORT_LOCK_FILE_ROOT = None
DATASET_CACHE_PATH = None
PROFILER_CACHE_PATH = None
PARAM_REALLOC_PATH = None
SGLANG_CACHE_PATH = None
TORCH_EXTENSIONS_DIR = None
BASE_ENVIRONS = None


def init_constants(args: "BaseExperimentConfig"):
    from realhf.base.cluster import init_cluster_spec
    from realhf.base.cluster import spec as cluster_spec

    init_cluster_spec(args)

    globals_dict = globals()  # Get this module's global variables.

    kwargs = dict(
        MODEL_SAVE_ROOT=f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}",
        LOG_ROOT=f"{cluster_spec.fileroot}/logs/{getpass.getuser()}",
        RECOVER_ROOT=f"{cluster_spec.fileroot}/recover/{getpass.getuser()}",
        SLURM_LOCK_FILE_NAME=f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock",
        PORT_LOCK_FILE_ROOT=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports",
        DATASET_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets",
        PROFILER_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler",
        PARAM_REALLOC_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc",
        SGLANG_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang",
        TORCH_EXTENSIONS_DIR=(
            f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
        ),
    )
    BASE_ENVIRONS = {
        # "PYTHONPATH": "/realhf",
        "REAL_IS_REMOTE": "1",
        # "NCCL_P2P_DISABLE": "1",
        # "NCCL_IB_DISABLE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
        "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
        "TOKENIZERS_PARALLELISM": "true",
        "TORCH_EXTENSIONS_DIR": kwargs["TORCH_EXTENSIONS_DIR"],
        # "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
        # "NCCL_SOCKET_IFNAME": "ibp71s0",
        # "GLOO_SOCKET_IFNAME": "ibp71s0",
        # "TORCH_USE_CUDA_DSA": "1",
        # "NCCL_IGNORE_DISABLED_P2P": "1",
        # "CUDA_LAUNCH_BLOCKING": "1",  # NOTE: CUDA graph capturing will not work if CUDA_LAUNCH_BLOCKING is set to 1.
        # "NCCL_COMM_BLOCKING": "1",  # NOTE: CUDA graph capturing will not work if NCCL_COMM_BLOCKING is set to 1.
        # "NCCL_BLOCKING_WAIT": "1",  # NOTE: CUDA graph capturing will not work if NCCL_BLOCKING_WAIT is set to 1.
        # "TORCH_SHOW_CPP_STACKTRACES": "1",
        # "RAY_DEDUP_LOGS": "0",  # disable ray log deduplication
        "CUDA_DEVICE_MAX_CONNECTIONS": "1",
        "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        # Whether to enable time marks to plot timelines.
        "REAL_CUDA_TMARK": os.getenv("REAL_CUDA_TMARK", "0"),
        "REAL_DUMP_TRACE": os.getenv("REAL_DUMP_TRACE", "0"),
        "REAL_DUMP_MEMORY": os.getenv("REAL_DUMP_MEMORY", "0"),
        "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv(
            "REAL_GPU_MEMORY_KILL_THRESHOLD", "0.95"
        ),
        "LC_ALL": "C",
        "LANG": "C",
        "NCCL_DEBUG": "WARN",
    }
    kwargs["BASE_ENVIRONS"] = BASE_ENVIRONS
    # Set PPU-specific environment variables for stable training.
    if cluster_spec.name == "wa180":
        logger.warning("Detected PPU. Amending PPU-related environment variables.")
        PPU_ENVIRONS = {
            "NCCL_DEBUG": "INFO",
            "NCCL_IB_DISABLE": "1",
            "NCCL_DEBUG_SUBSYS": "INIT",
            "NCCL_SET_THREAD_NAME": "1",
            "NCCL_IB_HCA": "",
            "NCCL_SOCKET_IFNAME": "bond0",
            "PCCL_STATE_MONITOR_DISABLE": "1",
        }
        kwargs["BASE_ENVIRONS"].update(PPU_ENVIRONS)
    elif cluster_spec.name == "na132":
        # Environment variables specific to the H800 cluster na132.
        NV_ENVIRONS = {
            "NCCL_SOCKET_IFNAME": "bond0",
            "NCCL_NET_PLUGIN": "",
            "NCCL_IB_GID_INDEX": "3",
            "NCCL_IB_TIMEOUT": "2",
            "NCCL_IB_RETRY_CNT": "7",
            "NCCL_IB_SL": "5",
            "NCCL_IB_TC": "136",
            "NCCL_IB_HCA": "mlx5_bond",
            "NCCL_IB_QPS_PER_CONNECTION": "8",
            "NCCL_SET_THREAD_NAME": "1",
            "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
        }
        kwargs["BASE_ENVIRONS"].update(NV_ENVIRONS)

    for key, value in kwargs.items():
        if key not in globals_dict:
            raise ValueError(f"Invalid constant name: {key}")
        if globals_dict[key] is not None and globals_dict[key] != value:
            raise RuntimeError(f"Constant '{key}' already initialized!")
        globals_dict[key] = value

    # Make directories if they do not exist.
    os.makedirs(globals_dict["PARAM_REALLOC_PATH"], exist_ok=True)
    os.makedirs(globals_dict["MODEL_SAVE_ROOT"], exist_ok=True)
    os.makedirs(globals_dict["LOG_ROOT"], exist_ok=True)
    os.makedirs(globals_dict["RECOVER_ROOT"], exist_ok=True)
    os.makedirs(globals_dict["DATASET_CACHE_PATH"], exist_ok=True)
    os.makedirs(globals_dict["PROFILER_CACHE_PATH"], exist_ok=True)
    os.makedirs(globals_dict["TORCH_EXTENSIONS_DIR"], exist_ok=True)
    os.makedirs(globals_dict["PORT_LOCK_FILE_ROOT"], exist_ok=True)
    os.makedirs(globals_dict["SGLANG_CACHE_PATH"], exist_ok=True)


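# A minimal sketch of how this module is typically brought up, assuming an
# experiment config `args` (a BaseExperimentConfig) whose cluster spec points
# `fileroot` at a shared directory. After `init_constants(args)`, path
# constants such as MODEL_SAVE_ROOT and LOG_ROOT resolve under that fileroot
# and the corresponding directories exist on disk; the experiment and trial
# names below are illustrative.
def _example_bootstrap(args: "BaseExperimentConfig"):
    init_constants(args)
    set_experiment_trial_names("my-exp", "trial-0")
    return MODEL_SAVE_ROOT, LOG_ROOT, get_env_vars()

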
# _model_name will be changed in the model_scope context manager.
_model_name: "ModelName" = None

# constants in worker/process scope
_experiment_name = None
_trial_name = None

_grids: Dict["ModelName", "ParallelGrid"] = {}
# The values are torch.distributed.ProcessGroup objects; no type hint here to
# avoid importing torch at type-check time.
_pgroups: Dict["ModelName", Any] = {}
_cpu_pgroups: Dict["ModelName", Any] = {}
_pgroup_ranks: Dict["ModelName", List[int]] = {}
_self_group = None
_rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
_global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer()


# TODO: As in Megatron, we can set NCCL group options. Is it necessary?


def reset_run():
    global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer
    _model_name = None
    _grids = {}
    _pgroups = {}
    _pgroup_ranks = {}
    _self_group = None
    _rank_mapping = {}
    _global_memory_buffer = GlobalMemoryBuffer()


@contextlib.contextmanager
def model_scope(model_name: "ModelName"):
    global _model_name
    assert _model_name is None
    _model_name = model_name
    yield
    assert _model_name == model_name
    _model_name = None


@contextlib.contextmanager
def model_scope_disabled():
    global _model_name
    assert _model_name is not None
    t, _model_name = _model_name, None
    yield
    _model_name = t


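# Sketch: `model_name()`, `grid()` and the other per-model accessors defined
# below are only valid inside a `model_scope`. `name` is assumed to be a
# ModelName whose grid was registered by the runtime via `set_grid`; outside
# the `with` block, `model_name()` raises a RuntimeError again.
def _example_scoped_access(name: "ModelName"):
    with model_scope(name):
        return model_name(), grid()

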
################# setter functions #################
def set_force_cpu(val: bool):
    global TORCH_FORCE_CPU
    TORCH_FORCE_CPU = val


def set_experiment_trial_names(expr_name: str, trial_name: str):
    global _experiment_name, _trial_name
    if _experiment_name is not None and _experiment_name != expr_name:
        raise RuntimeError("Experiment name has been set.")
    if _trial_name is not None and _trial_name != trial_name:
        raise RuntimeError("Trial name has been set.")
    _experiment_name = expr_name
    _trial_name = trial_name


def set_grid(model_name: "ModelName", grid: "ParallelGrid"):
    global _grids
    if model_name in _grids:
        raise RuntimeError(f"Grid for model {model_name} is already set.")
    _grids[model_name] = grid


def set_parallelism_group(model_name: "ModelName", pgroup, ranks):
    global _pgroups
    if model_name in _pgroups:
        raise RuntimeError(f"Parallelism group for model {model_name} is already set.")
    _pgroups[model_name] = pgroup
    _pgroup_ranks[model_name] = ranks


def set_cpu_parallelism_group(model_name: "ModelName", pgroup):
    global _cpu_pgroups
    if model_name in _cpu_pgroups:
        raise RuntimeError(
            f"CPU parallelism group for model {model_name} is already set."
        )
    _cpu_pgroups[model_name] = pgroup


def set_self_group(pgroup):
    global _self_group
    if _self_group is not None:
        raise RuntimeError("Self group is already set.")
    _self_group = pgroup


def set_rank_mapping(
    model_name: "ModelName",
    topo: "ProcessTopology",
    msid2mwid: Optional[Dict["ModelShardID", int]] = None,
):
    global _rank_mapping
    if model_name in _rank_mapping:
        raise RuntimeError(f"Rank mapping for model {model_name} is already set.")
    if msid2mwid is None:
        _rank_mapping[model_name] = {i: i for i in range(topo.world_size())}
    else:
        msid2mwid = {k: v for k, v in msid2mwid.items() if k.model_name == model_name}
        _rank_mapping[model_name] = {
            topo.get_rank(data=s.dp_rank, tensor=s.tp_rank, pipe=s.pp_rank): mw_id
            for s, mw_id in msid2mwid.items()
        }


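# Sketch of the identity branch of `set_rank_mapping`: when no msid2mwid
# mapping is supplied, rank i inside the model's parallelism group simply maps
# to model worker i. `topo` is assumed to be a ProcessTopology built elsewhere
# by the runtime.
def _example_identity_rank_mapping(name: "ModelName", topo: "ProcessTopology"):
    set_rank_mapping(name, topo)
    return rank_mapping_of_model(name)  # {0: 0, 1: 1, ..., world_size - 1: world_size - 1}

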
################# attribute functions #################
def current_device() -> torch.device:
    global TORCH_FORCE_CPU
    if TORCH_FORCE_CPU or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.cuda.current_device()


def use_cuda() -> bool:
    return not TORCH_FORCE_CPU and torch.cuda.is_available()


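# Sketch: forcing CPU execution, e.g. for unit tests. After `set_force_cpu(True)`,
# `current_device()` returns a CPU device even on a machine with GPUs and
# `use_cuda()` reports False; resetting the flag restores the default behavior.
def _example_force_cpu_for_tests():
    set_force_cpu(True)
    try:
        return current_device(), use_cuda()
    finally:
        set_force_cpu(False)

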
def use_te_impl() -> bool:
    try:
        import transformer_engine.pytorch as te

        TE_ENABLED = True
    except ImportError:
        TE_ENABLED = False
    return TE_ENABLED and os.getenv("REAL_LLM_USE_TE") == "1"


def sequence_parallel() -> bool:
    return grid().topology().sequence_parallel


def gradient_accumulation_fusion() -> bool:
    _grad_accum_fusion_available = True
    try:
        import fused_weight_gradient_mlp_cuda
    except ImportError:
        _grad_accum_fusion_available = False
    return _grad_accum_fusion_available and getattr(
        grid().topology(), "gradient_accumulation_fusion", False
    )


def max_prompt_len() -> int:
    return grid().topology().max_prompt_len


def gradient_checkpointing() -> bool:
    return getattr(grid().topology(), "gradient_checkpointing", False)


def has_model_name(name: str) -> bool:
    return name in _grids and _grids[name].global_rank != -1


def self_group():
    global _self_group
    assert _self_group is not None
    return _self_group


def model_name():
    if _model_name is None:
        raise RuntimeError(
            "Global constant `model_name` should be accessed in the `model_scope` context."
        )
    return _model_name


def experiment_name():
    if _experiment_name is None:
        raise RuntimeError("Global constant `experiment_name` is accessed before set.")
    return _experiment_name


def trial_name():
    if _trial_name is None:
        raise RuntimeError("Global constant `trial_name` is accessed before set.")
    return _trial_name


def grid() -> "ParallelGrid":
    if _model_name is None:
        raise RuntimeError("Global constant `model_name` is accessed before set.")
    if _grids.get(_model_name, None) is None:
        raise RuntimeError(f"Grid for model {_model_name} is not set.")
    return _grids[_model_name]


def grid_of_model(model_name: str) -> "ParallelGrid":
    if _grids.get(model_name, None) is None:
        raise RuntimeError(f"Grid for model {model_name} is not set.")
    return _grids[model_name]


def parallelism_group():
    """Returns the 3D parallelism group of a specific model."""
    if _model_name is None:
        raise RuntimeError("Global constant `model_name` is accessed before set.")
    if _pgroups.get(_model_name, None) is None:
        raise RuntimeError(f"Parallelism group for model {_model_name} is not set.")
    return _pgroups[_model_name]


def cpu_parallelism_group():
    """Returns the GLOO 3D parallelism group of a specific model."""
    if _model_name is None:
        raise RuntimeError("Global constant `model_name` is accessed before set.")
    if _cpu_pgroups.get(_model_name, None) is None:
        raise RuntimeError(f"CPU parallelism group for model {_model_name} is not set.")
    return _cpu_pgroups[_model_name]


def parallelism_group_ranks():
    if _model_name is None:
        raise RuntimeError("Global constant `model_name` is accessed before set.")
    if _pgroup_ranks.get(_model_name, None) is None:
        raise RuntimeError(
            f"Parallelism group ranks for model {_model_name} are not set."
        )
    return _pgroup_ranks[_model_name]


def parallelism_group_size() -> int:
    """The 3D parallelism group size of a specific model, normally dp_size *
    pp_size * tp_size."""
    import torch.distributed as dist

    return dist.get_world_size(group=parallelism_group())


def parallelism_rank() -> int:
    """Return the rank of a specific model in its 3D parallelism group."""
    import torch.distributed as dist

    return dist.get_rank(group=parallelism_group())


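# Worked example for the size relation documented above: a model sharded with
# dp_size=2, pp_size=2 and tp_size=4 occupies 2 * 2 * 4 = 16 ranks, so
# `parallelism_group_size()` returns 16 and `parallelism_rank()` is this
# process's index within those 16 ranks.
def _example_3d_group_size(dp: int = 2, pp: int = 2, tp: int = 4) -> int:
    return dp * pp * tp  # 16 for the layout above

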
def to_global_pg_rank(local_rank: int) -> int:
    global _rank_mapping
    if _rank_mapping is None or model_name() not in _rank_mapping:
        raise RuntimeError("Rank mapping is not set.")
    return _rank_mapping[model_name()][local_rank]


def rank_mapping_of_model(model_name: str) -> Dict["ModelShardID", int]:
    global _rank_mapping
    if _rank_mapping is None or _rank_mapping.get(model_name, None) is None:
        raise RuntimeError(f"Rank mapping for model {model_name} is not set.")
    return _rank_mapping[model_name]


def pipe_parallel_rank() -> int:
    return grid().get_pipe_parallel_rank()


def pipe_parallel_world_size() -> int:
    return grid().get_pipe_parallel_world_size()


def pipe_parallel_group():
    return grid().get_pipe_parallel_group()


def pipe_parallel_cpu_group():
    return grid().pp_proc_group_gloo


def is_last_pipe_stage():
    return pipe_parallel_rank() == pipe_parallel_world_size() - 1


def is_first_pipe_stage():
    return pipe_parallel_rank() == 0


def next_pipe_stage():
    return (pipe_parallel_rank() + 1) % pipe_parallel_world_size()


def prev_pipe_stage():
    return (
        pipe_parallel_world_size() + pipe_parallel_rank() - 1
    ) % pipe_parallel_world_size()


def is_dp_head():
    return is_last_pipe_stage() and tensor_parallel_rank() == 0


def tensor_parallel_rank() -> int:
    """Return the rank inside the tensor parallelism group."""
    return grid().get_tensor_model_parallel_rank()


def tensor_parallel_world_size() -> int:
    """Return the world size of the tensor parallelism group."""
    return grid().get_tensor_model_parallel_world_size()


def tensor_parallel_group():
    """Return the NCCL tensor parallelism process group."""
    return grid().get_tensor_model_parallel_group()


def tensor_parallel_cpu_group():
    """Return the GLOO tensor parallelism process group."""
    return grid().get_tensor_model_parallel_cpu_group()


def tp_and_pp_group():
    """Used as the world group of vLLM."""
    return grid().get_model_parallel_group()


def tp_and_pp_cpu_group():
    return grid().ds_model_proc_group_gloo


def tp_and_pp_rank():
    """Used as the rank in the world group of vLLM."""
    return grid().get_model_parallel_rank()


def tp_and_pp_world_size():
    """Used as the world size of vLLM."""
    return grid().get_model_parallel_world_size()


def data_parallel_rank() -> int:
    return grid().get_data_parallel_rank()


def data_parallel_world_size() -> int:
    return grid().get_data_parallel_world_size()


def data_parallel_group():
    return grid().get_data_parallel_group()


def get_global_memory_buffer():
    global _global_memory_buffer
    assert _global_memory_buffer is not None, "global memory buffer is not set"
    return _global_memory_buffer


def clear_global_memory_buffer():
    global _global_memory_buffer
    _global_memory_buffer = GlobalMemoryBuffer()


def get_repo_path() -> pathlib.Path:
    return pathlib.Path(__file__).resolve().parent.parent.parent


def get_env_vars(**kwargs):
    return {
        **kwargs,
        "REAL_PACKAGE_PATH": str(get_repo_path()),
        **BASE_ENVIRONS,
    }


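# Sketch: `get_env_vars` layers BASE_ENVIRONS on top of caller-supplied
# overrides, so a key that appears in both (NCCL_DEBUG below) keeps the
# BASE_ENVIRONS value. It assumes `init_constants` has already run so that
# BASE_ENVIRONS is populated; the override values are illustrative.
def _example_launcher_environ():
    return get_env_vars(NCCL_DEBUG="INFO", REAL_DUMP_TRACE="1")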