mirror of https://github.com/inclusionAI/AReaL
PullRequest: 252 [Feature] Fix constants initialization. (#122)
Merge branch fw/gh/fix-init-constants of git@code.alipay.com:inclusionAI/AReaL.git into gh

https://code.alipay.com/inclusionAI/AReaL/pull_requests/252?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* remove LOG_ROOT
* remove MODEL_SAVE_PATH
* remove PARAM_REALLOC_PATH, DATASET_CACHE
* prepare for testing
* prepare for testing
* ready for run
* local run
* tests mainly pass
* format
* amend cluster.py
* fix
This commit is contained in:
parent b63eea9d07
commit 1ec1399f19
@@ -4,10 +4,10 @@ python3 training/main_sync_ppo.py \
     allocation_mode=sglang.d4p1m1+d2p2m1 \
     cluster.fileroot=/storage/testing/experiments \
     actor.type._class=qwen3 \
-    actor.path=/storage/testing/models/Qwen__Qwen3-1.7B \
+    actor.path=Qwen/Qwen3-1.7B \
     ref.type._class=qwen3 \
-    ref.path=/storage/testing/models/Qwen__Qwen3-1.7B \
+    ref.path=Qwen/Qwen3-1.7B \
-    dataset.path=/storage/testing/dataset/boba_106k_0319.jsonl \
+    dataset.path=hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl \
     dataset.train_bs_n_seqs=32 \
     group_size=8 \
     ppo.gen.max_new_tokens=4096 \
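The script now passes Hugging Face Hub identifiers (`Qwen/Qwen3-1.7B`, `hf-dataset://...`) instead of pre-downloaded local paths; `load_hf_or_local_file` (see the data API hunks below) resolves them at load time. A minimal sketch of what such a resolution can look like, using the public `huggingface_hub` API — the exact parsing of the `hf-dataset://` scheme lives in `realhf.utils` and is assumed here:

    from huggingface_hub import hf_hub_download

    # Hypothetical resolution of
    # "hf-dataset://inclusionAI/AReaL-RL-Data/data/boba_106k_0319.jsonl":
    local_path = hf_hub_download(
        repo_id="inclusionAI/AReaL-RL-Data",   # repo part of the URI
        filename="data/boba_106k_0319.jsonl",  # file path inside the dataset repo
        repo_type="dataset",
    )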
@@ -5,7 +5,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union

 from omegaconf import MISSING

-from realhf.base import pkg_version
+from realhf.base import logging, pkg_version

+logger = logging.getLogger("CLI args")
+
 ## Data and datasets. ##

@@ -351,10 +353,6 @@ class SGLangConfig:
             tp_size=tp_size,
             # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
             base_gpu_id=base_gpu_id,
-            file_storage_path=os.path.join(
-                constants.SGLANG_CACHE_PATH,
-                f"sglang_storage{server_index}",
-            ),
             # Data parallelism
             dp_size=1,  # TODO: check whether we require SGLang dp
             load_balance_method="round_robin",
@@ -870,6 +868,30 @@ def get_user_tmp():
     return user_tmp


+@dataclass
+class NameResolveConfig:
+    type: str = field(
+        default="nfs",
+        metadata={
+            "help": "Type of the distributed KV store for name resolving.",
+            "choices": ["nfs", "etcd3", "ray"],
+        },
+    )
+    nfs_record_root: str = field(
+        default="/tmp/areal/name_resolve",
+        metadata={
+            "help": "Record root for NFS name resolving. Should be available in all nodes."
+        },
+    )
+    etcd3_addr: str = field(
+        default="localhost:2379", metadata={"help": "Address of the ETCD3 server."}
+    )
+    ray_actor_name: str = field(
+        default="ray_kv_store",
+        metadata={"help": "Name of the distributed Ray KV store."},
+    )
+
+
 @dataclass
 class ClusterSpecConfig:
     config_path: str = field(
@@ -878,6 +900,10 @@ class ClusterSpecConfig:
             "help": "JSON config path. If not given, use the following CLI args."
         },
     )
+    name_resolve: NameResolveConfig = field(
+        default_factory=NameResolveConfig,
+        metadata={"help": "Name resolving configuration."},
+    )
     cluster_name: str = field(
         default="local",
         metadata={"help": "Name of the cluster. Used to set specific environs."},
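For reference, a hedged sketch of constructing the new name-resolve settings in code — the field names come from the dataclass above, while the etcd address is a made-up example. The resulting object is what `name_resolve.reconfigure(...)` consumes later in this PR (see the `make_experiment` hunk):

    from realhf.api.cli_args import NameResolveConfig

    nr = NameResolveConfig(
        type="etcd3",                # one of "nfs", "etcd3", "ray"
        etcd3_addr="10.0.0.1:2379",  # hypothetical ETCD3 server address
    )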
@@ -6,7 +6,6 @@ import dataclasses
 import enum
 from typing import Any, Dict, List, Optional

-import realhf.base.cluster as cluster
 import realhf.base.topology as topology


@@ -140,7 +139,7 @@ class ModelShardID:
         )

     def __repr__(self):
-        n = cluster.spec.suffix_n_digits
+        n = len(str(self.topo.world_size()))
         return f"{self.model_name}@pp{self.pp_rank:0{n}d}@tp{self.tp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"

     def __hash__(self):
@@ -40,7 +40,6 @@ from pydantic import field_validator, model_validator
 from realhf.api.cli_args import MicroBatchSpec
 from realhf.api.core import config as config_api
 from realhf.base import constants, datapack, logging, seeding
-from realhf.base.cluster import spec as cluster_spec
 from realhf.utils import load_hf_or_local_file

 logger = logging.getLogger("api.data")
@@ -754,11 +753,11 @@ def get_shuffle_indices(seed: int, size: int):

 def load_shuffle_split_dataset(
     util: DatasetUtility,
-    dataset_path: str,
+    dataset_path: Optional[str] = None,
     dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None,
 ):
-    dataset_path = load_hf_or_local_file(dataset_path)
     if dataset_path is not None:
+        dataset_path = load_hf_or_local_file(dataset_path)
         if dataset_path.endswith(".jsonl"):
             with open(dataset_path, "r") as f:
                 data = [json.loads(ff) for ff in f]
@@ -808,17 +807,12 @@ def make_dataset(
     dp_rank: int,
     world_size: int,
     tokenizer_or_tokenizer_name: Union[transformers.PreTrainedTokenizerFast, str],
-    experiment_name: str,
-    trial_name: str,
-    cache_root: Optional[str] = None,
 ) -> torch.utils.data.Dataset:
     if isinstance(cfg, str):
         cfg = config_api.DatasetAbstraction(type_=cfg)

     if isinstance(tokenizer_or_tokenizer_name, str):
         tokenizer = load_hf_tokenizer(tokenizer_or_tokenizer_name)
-    elif tokenizer_or_tokenizer_name is None:
-        raise RuntimeError("tokenizer_or_tokenizer_name cannot be None.")
     else:
         tokenizer = tokenizer_or_tokenizer_name
     util = DatasetUtility(
@@ -827,46 +821,8 @@ def make_dataset(
         world_size,
         tokenizer,
     )
-    if cache_root is None:
-        dataset_cls = ALL_DATASET_CLASSES[cfg.type_]
-        return dataset_cls(util=util, **cfg.args)
-
-    # Create and check cache path.
-    if not cache_root.startswith(cluster_spec.fileroot) and not cache_root.startswith(
-        "/home"
-    ):
-        raise ValueError(
-            f"Data cache path {cache_root} should be /home or under {cluster_spec.fileroot}."
-        )
-    if "_" in experiment_name or "_" in trial_name:
-        raise ValueError(f"Invalid experiment/trial name.")
-
-    output_path = os.path.join(
-        cache_root,
-        experiment_name,
-        trial_name,
-        cfg.type_,
-        f"seed{seed}",
-        f"world_size{world_size}",
-        f"rank{dp_rank}",
-    )
-    os.makedirs(output_path, exist_ok=True)
-
-    fname = "dataset.pt"
-    cache_found = os.path.isfile(os.path.join(output_path, fname))
-
-    tik = time.perf_counter()
-    if not cache_found:
-        logger.info(f"No data cache found for rank {dp_rank}. Create it from scratch.")
-        dataset = ALL_DATASET_CLASSES[cfg.type_](seed, dp_rank, world_size, **cfg.args)
-        torch.save(dataset, os.path.join(output_path, fname))
-    else:
-        logger.info(f"Rank {dp_rank} find existing data cache, load it.")
-        dataset = torch.load(os.path.join(output_path, fname))
-    logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")
-
-    return dataset
+    dataset_cls = ALL_DATASET_CLASSES[cfg.type_]
+    return dataset_cls(util=util, **cfg.args)


 def gather_stat(src: List[Dict]) -> Dict:
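After this change `make_dataset` is cache-free and purely a function of its config. A hedged usage sketch — the leading `cfg`/`seed` arguments, the module path, and the `"math_prompt"` dataset type are assumptions; only `dp_rank`, `world_size`, and the tokenizer argument are visible in the hunk above:

    from realhf.api.core import data_api  # assumed module location

    dataset = data_api.make_dataset(
        cfg="math_prompt",  # a string is wrapped via DatasetAbstraction(type_=cfg)
        seed=1,
        dp_rank=0,
        world_size=1,
        tokenizer_or_tokenizer_name="Qwen/Qwen3-1.7B",
    )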
@@ -25,7 +25,6 @@ from realhf.api.core.config import (
     StandaloneModelShardAbstraction,
 )
 from realhf.base import constants, topology
-from realhf.base.cluster import spec as cluster_spec


 class ExpStatus(Enum):
@@ -49,66 +48,6 @@ class Scheduling:
     begin: Optional[str] = None  # see "--begin" option for format
     deadline: Optional[str] = None  # see "--deadline" option for format

-    @staticmethod
-    def master_worker_default(**kwargs):
-        return Scheduling(
-            **{
-                "cpu": 16,
-                "mem": 20 * 1024,
-                "gpu": 0,
-                "container_image": cluster_spec.cpu_image,
-                **kwargs,
-            }
-        )
-
-    @staticmethod
-    def model_worker_default(**kwargs):
-        return Scheduling(
-            **{
-                "cpu": 2,
-                "gpu": 1,
-                "mem": 60 * 1024,
-                "container_image": cluster_spec.gpu_image,
-                **kwargs,
-            }
-        )
-
-    @staticmethod
-    def generation_server_default(**kwargs):
-        return Scheduling(
-            **{
-                "cpu": 4,
-                "gpu": 1,
-                "mem": 60 * 1024,
-                "container_image": cluster_spec.gpu_infer_image,
-                **kwargs,
-            }
-        )
-
-    @staticmethod
-    def gserver_manager_default(**kwargs):
-        return Scheduling(
-            **{
-                "cpu": 4,
-                "gpu": 0,
-                "mem": 10 * 1024,
-                "container_image": cluster_spec.gpu_image,
-                **kwargs,
-            }
-        )
-
-    @staticmethod
-    def rollout_worker_default(**kwargs):
-        return Scheduling(
-            **{
-                "cpu": 4,
-                "gpu": 0,
-                "mem": 20 * 1024,
-                "container_image": cluster_spec.gpu_image,
-                **kwargs,
-            }
-        )
-

 @dataclasses.dataclass
 class WorkerInformation:
@@ -159,8 +98,6 @@ class ModelWorker:
     # dataset, for source model workers
     tokenizer_name_or_path: Optional[str] = None
    datasets: Optional[List[Union[str, DatasetAbstraction]]] = None
-    use_dataset_cache: bool = False
-    dataset_cahce_root: str = constants.DATASET_CACHE_PATH
     shuffle_dataset: bool = True
     cuda_cache_cleanliness: bool = True
     cuda_cache_clear_freq: int = 10
@@ -215,8 +152,6 @@ class RolloutWorker:
     env: EnvServiceAbstraction
     agent: AgentAbstraction
     datasets: List[Union[str, DatasetAbstraction]]
-    use_dataset_cache: bool = False
-    dataset_cahce_root: str = constants.DATASET_CACHE_PATH
     worker_info: WorkerInformation = None

@@ -290,16 +225,9 @@ class ExperimentConfig:

         assert constants.trial_name() is not None
         assert constants.experiment_name() is not None
-        graph_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "dataflow_graph.png",
-        )
-        os.makedirs(os.path.dirname(graph_path), exist_ok=True)
         # If verbose set to True here, every worker will print the graph once
         # due to lazy init on workers.
-        G = dfg.build_graph(self.model_rpcs, verbose=False, graph_path=graph_path)
+        G = dfg.build_graph(self.model_rpcs, verbose=False)
         for rpc in self.model_rpcs:
             rpc._G = G

@@ -549,4 +477,12 @@ def register_experiment(name, cls):

 def make_experiment(name) -> Experiment:
     cls = ALL_EXPERIMENT_CLASSES[name]
-    return cls()
+    args = cls()
+    if args.cluster.config_path:
+        from realhf.base.cluster import load_spec_from_file
+
+        load_spec_from_file(args.cluster)
+    from realhf.base import name_resolve
+
+    name_resolve.reconfigure(args.cluster.name_resolve)
+    return args
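With this change, `make_experiment` becomes the single place where cluster and name-resolve configuration get wired up. A hedged sketch of the caller-visible behavior (the experiment name is hypothetical):

    exp = make_experiment("my-ppo-exp")  # hypothetical registered experiment name
    # Side effects, per the hunk above:
    #  - if exp.cluster.config_path is set, the JSON spec overwrites
    #    exp.cluster in place via load_spec_from_file;
    #  - name_resolve.reconfigure(exp.cluster.name_resolve) points the KV
    #    store (nfs / etcd3 / ray) at the configured backend.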
@@ -8,9 +8,8 @@ from typing import List, Optional, Tuple, Union

 import numpy as np

-from realhf.api.cli_args import ParallelismConfig
+from realhf.api.cli_args import ClusterSpecConfig, ParallelismConfig
 from realhf.api.core.dfg import MFCDef
-from realhf.base.cluster import spec as cluster_spec
 from realhf.base.slurm_utils import are_ones_contiguous, parse_nodelist


@@ -76,7 +75,6 @@ class DeviceMesh:
         return device_mesh

     def __post_init__(self):
-        n = cluster_spec.suffix_n_digits
         assert self._is_valid_mapping()

     def __eq__(self, other: "DeviceMesh"):
@@ -179,7 +177,10 @@ class DeviceMesh:


 def make_device_mesh_from_name(
-    global_mesh_name: str, name: str, n_gpus_per_node: int = 8
+    cluster: ClusterSpecConfig,
+    global_mesh_name: str,
+    name: str,
+    n_gpus_per_node: int = 8,
 ):
     """
     DeviceMesh name format: <prefix><node_indices>[:<gpu_ids>]
@@ -191,8 +192,8 @@ def make_device_mesh_from_name(

     Note: cluster device mesh name must occupy entire nodes.
     """
-    prefix = cluster_spec.node_name_prefix
-    node_list = parse_nodelist(global_mesh_name, prefix)
+    prefix = cluster.node_name_prefix
+    node_list = parse_nodelist(cluster, global_mesh_name, prefix)
     n_nodes = len(node_list)

     gpu_ids = None
@@ -202,7 +203,7 @@ def make_device_mesh_from_name(
         assert all(gpu_id < n_gpus_per_node for gpu_id in gpu_ids)
     else:
         node_names = name
-        node_names = parse_nodelist(node_names, prefix)
+        node_names = parse_nodelist(cluster, node_names, prefix)
     mapping = np.zeros((n_nodes, n_gpus_per_node), dtype=np.int32)
     if gpu_ids is None:
         node_indices = [node_list.index(node_name) for node_name in node_names]
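A hedged call sketch for the new signature, using the documented `<prefix><node_indices>[:<gpu_ids>]` name format — the node names below are made up, and `cluster` stands for a `ClusterSpecConfig` whose `node_name_prefix` is `"slurmd-"`:

    mesh = make_device_mesh_from_name(
        cluster,                            # now passed explicitly, not a global spec
        global_mesh_name="slurmd-[01-02]",  # hypothetical 2-node global mesh
        name="slurmd-01:0,1",               # hypothetical sub-mesh: node 01, GPUs 0 and 1
        n_gpus_per_node=8,
    )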
@@ -16,24 +16,22 @@ from hydra.core.config_store import ConfigStore
 from omegaconf import MISSING, OmegaConf

 import realhf.api.core.system_api as system_api
-from realhf.base.constants import init_constants
+from realhf.base.constants import (
+    QUICKSTART_EXPR_CACHE_PATH,
+    get_log_path,
+    get_save_path,
+)
 from realhf.base.ray_utils import check_ray_availability
 from realhf.base.slurm_utils import check_slurm_availability


 def kind_reminder(config_name, logger, args):
-    from realhf.base.constants import LOG_ROOT, MODEL_SAVE_ROOT
-
     logger.info(f"Running {config_name} experiment.")
+    logger.info(f"Logs will be dumped to {get_log_path(args)}")
     logger.info(
-        f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}"
-    )
-    logger.info(
-        f"Experiment configs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name, 'config.yaml')}"
-    )
-    logger.info(
-        f"Model checkpoints will be saved to {os.path.join(MODEL_SAVE_ROOT, args.experiment_name, args.trial_name)}"
+        f"Experiment configs will be dumped to {os.path.join(get_log_path(args), 'config.yaml')}"
     )
+    logger.info(f"Model checkpoints will be saved to {get_save_path(args)}")

     if args.mode == "slurm":
         slurm_available = check_slurm_availability()
@@ -82,12 +80,7 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
     trial_name = args.trial_name
     from realhf.apps.main import main_start, main_stop

-    init_constants(args)
-    from realhf.base.constants import LOG_ROOT, QUICKSTART_EXPR_CACHE_PATH
-
-    config_save_path = os.path.join(
-        LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
-    )
+    config_save_path = os.path.join(get_log_path(args), "config.yaml")
     os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
     with open(config_save_path, "w") as f:
         yaml.dump(
@@ -94,7 +94,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
         raise RuntimeError("Experiment initial setup failed.") from e

     evaluator = (
-        AutomaticEvaluator(exp_cfg.evaluator, exp_cfg.wandb, exp_cfg.swanlab)
+        AutomaticEvaluator(exp_cfg, exp_cfg.evaluator, exp_cfg.wandb, exp_cfg.swanlab)
         if exp_cfg.auto_eval
         else None
     )
@@ -116,7 +116,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
     is_recover_run = recover_count > 0
     if args.recover_mode == "auto":
         try:
-            recover.discover_ckpt(args.experiment_name, args.trial_name)
+            recover.discover_ckpt(experiment)
             is_recover_run = True
         except recover.InValidRecoverCkpt as e:
             logger.warning(
@@ -127,7 +127,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
         is_recover_run = False
     if is_recover_run:
         recover_ckpt_path, model_ckpt_dirs, recover_info = recover.discover_ckpt(
-            args.experiment_name, args.trial_name
+            experiment
         )
         logger.info(f"Will load recover info from {recover_ckpt_path}.")
         logger.info(f"Will load model checkpoints from {model_ckpt_dirs}.")
@@ -138,19 +138,9 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
     )
     save_recover_states = args.recover_mode != "disabled"

-    cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "")
-    if not cluster_spec_path:
-        logger.info(
-            "Environment variable CLUSTER_SPEC_PATH is not set. "
-            "Will use the fileroot specified in CLI args. "
-        )
-    else:
-        logger.warning(
-            "Environment variable CLUSTER_SPEC_PATH is set. "
-            "Will overwrite the cluster spec in CLI args. "
-        )
     # set env vars
     BASE_ENVIRONS = constants.get_env_vars(
+        experiment,
         REAL_MODE=args.mode.upper(),
         REAL_RECOVER_RUN="1" if is_recover_run else "0",
         REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
@@ -160,10 +150,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):

     # setup experiments
     sched = sched_client.make(
-        mode=scheduler_mode(args.mode),
-        expr_name=expr_name,
-        trial_name=trial_name,
-        schedule_strategy=args.schedule_strategy,
+        experiment,
         evaluator=evaluator,
         job_group_id=job_group_id,
         job_group_index=recover_count,
@@ -302,11 +289,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):


 def main_stop(args):
-    sched = sched_client.make(
-        mode=scheduler_mode(args.mode),
-        expr_name=args.experiment_name,
-        trial_name=args.trial_name,
-    )
+    experiment = config_package.make_experiment(args.experiment_name)
+    sched = sched_client.make(experiment)
     sched.find_all()
     sched.stop_all()

@@ -16,7 +16,6 @@ from rich.panel import Panel

 from realhf.api.cli_args import console, highlighter, print_config_help
 from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES, QUICKSTART_FN
-from realhf.base.cluster import spec as cluster_spec
 from realhf.base.importing import import_module
 from realhf.base.prologue import (
     PROLOGUE_EXTERNAL_CONFIG_NAME,
@@ -36,7 +35,6 @@ import_module(
     str(pathlib.Path(__file__).resolve().parent.parent / "experiments" / "async_exp"),
     re.compile(r".*_exp\.py$"),
 )
-import realhf.experiments.benchmark.profile_exp


 def print_help(exp_type):
@@ -144,7 +142,7 @@ def prepare_hydra_config(name: str, prologue_path: str):
     config = OmegaConf.load(prologue_path)
     experiment_name = get_experiment_name(config.get("experiment_name"))
     trial_name = get_trial_name(config.get("trial_name"))
-    config_dir = f"{cluster_spec.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
+    config_dir = f"{config.cluster.fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
     os.makedirs(config_dir, exist_ok=True)

     config.pop(PROLOGUE_EXTERNAL_CONFIG_NAME, {})
@@ -21,7 +21,7 @@ from omegaconf import OmegaConf
 multiprocessing.set_start_method("spawn", force=True)

 from realhf.api.quickstart.entrypoint import QUICKSTART_CONFIG_CLASSES
-from realhf.base import gpu_utils, importing, logging
+from realhf.base import gpu_utils, importing, logging, name_resolve
 from realhf.version import get_full_version_with_dirty_description

 logger = logging.getLogger("Main-Workers")
@@ -61,7 +61,6 @@ def main_worker(args):
     import realhf.api.core.system_api as system_api

     experiment = system_api.make_experiment(name=args.experiment_name)
-    constants.init_constants(experiment)

     worker_index_start = args.jobstep_id * args.wprocs_per_jobstep + args.wproc_offset
     worker_index_end = min(
@@ -166,6 +165,8 @@ def main_controller(args):
     constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
     _patch_external_impl(args.experiment_name, args.trial_name)

+    experiment = system_api.make_experiment(args.experiment_name)
+
     logger.debug("Running controller with args: %s", args)
     assert not args.experiment_name.startswith("/"), args.experiment_name
     try:
@@ -177,10 +178,6 @@ def main_controller(args):
         experiment_name=args.experiment_name,
         trial_name=args.trial_name,
     )
-    experiment = system_api.make_experiment(args.experiment_name)
-
-    # Initialize cluster infor from ENV or CLI args.
-    constants.init_constants(experiment)

     controller.start(
         experiment=experiment,
@@ -3,143 +3,23 @@
 # Licensed under the Apache License, Version 2.0 (the "License").

 import json
-import os
 from typing import TYPE_CHECKING, Dict

 if TYPE_CHECKING:
-    from realhf.api.cli_args import BaseExperimentConfig
+    from realhf.api.cli_args import ClusterSpecConfig


-class ClusterSpec:
-    def __init__(self):
-        # Set default values to comfort ray
-        from realhf.api.cli_args import BaseExperimentConfig
-
-        self.load_spec_from_args(BaseExperimentConfig())
-
-        self.__loaded = False
-
-    def load_spec_from_file(self, file_path: str):
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"Cluster spec file not found: {file_path}")
-
-        with open(file_path, "r") as f:
-            spec: Dict = json.load(f)
-
-        self.__cluster_type = spec["cluster_type"]
-        self.__cluster_name = spec["cluster_name"]
-        self.__fileroot = spec["fileroot"]
-        self.__gpu_type = spec.get("gpu_type", None)
-        self.__mount = spec.get("default_mount", None)
-        self.__gpu_image = spec.get("gpu_image", None)
-        self.__gpu_infer_image = spec.get("gpu_infer_image", self.__gpu_image)
-        self.__cpu_image = spec.get("cpu_image", None)
-        self.__node_name_prefix = spec.get("node_name_prefix", "slurmd-")
-        # self.__n_nodes decides number of digits in slurm hostnames
-        # e.g. if __n_nodes = 32, then the hostnames will be slurmd-{:02d}
-        # if __n_nodes = 128, then the hostnames will be slurmd-{:03d}
-        self.__n_nodes = int(spec.get("n_nodes", 32))
-        self.__n_gpus_per_node = int(spec.get("n_gpus_per_node", 8))
-        assert isinstance(self.__n_nodes, int)
-
-        self.__loaded = True
-
-    def load_spec_from_args(self, args: "BaseExperimentConfig"):
-        self.__cluster_type = args.mode
-        self.__cluster_name = args.cluster.cluster_name
-        self.__fileroot = args.cluster.fileroot
-        self.__gpu_type = args.cluster.gpu_type
-        self.__mount = args.cluster.mount
-        self.__gpu_image = args.cluster.gpu_image
-        self.__gpu_infer_image = args.cluster.gpu_infer_image
-        self.__cpu_image = args.cluster.cpu_image
-        self.__node_name_prefix = args.cluster.node_name_prefix
-        self.__n_nodes = args.cluster.n_nodes
-        self.__n_gpus_per_node = args.cluster.n_gpus_per_node
-        self.__loaded = True
-
-    @property
-    def name(self):
-        assert self.__loaded
-        return self.__cluster_name
-
-    @property
-    def gpu_type(self):
-        assert self.__loaded
-        return self.__gpu_type
-
-    @property
-    def fileroot(self) -> str:
-        """Return the root directory of the file system in the cluster.
-
-        When running experiments, files such as logs, checkpoints,
-        caches will be saved under this directory.
-        """
-        assert self.__loaded
-        return self.__fileroot
-
-    @fileroot.setter
-    def fileroot(self, root: str):
-        # Used for testing
-        self.__fileroot = root
-
-    @property
-    def mount(self) -> str:
-        """Directories that should be mounted to container that runs
-        workers."""
-        assert self.__loaded
-        return self.__mount
-
-    @property
-    def gpu_image(self) -> str:
-        """Return the default image for containers of GPU trainer workers."""
-        assert self.__loaded
-        return self.__gpu_image
-
-    @property
-    def gpu_infer_image(self) -> str:
-        """Return the default image for containers of GPU inference workers."""
-        assert self.__loaded
-        return self.__gpu_infer_image
-
-    @property
-    def cpu_image(self) -> str:
-        """Return the default image for containers of CPU workers."""
-        assert self.__loaded
-        return self.__cpu_image
-
-    @property
-    def node_name_prefix(self) -> str:
-        """Return the prefix of node names in slurm format."""
-        assert self.__loaded
-        return self.__node_name_prefix
-
-    @property
-    def n_nodes(self) -> int:
-        return self.__n_nodes
-
-    @property
-    def suffix_n_digits(self) -> int:
-        return len(str(self.__n_nodes))
-
-    @property
-    def n_gpus_per_node(self) -> int:
-        return self.__n_gpus_per_node
-
-    @property
-    def cluster_type(self) -> str:
-        return self.__cluster_type
-
-
-spec = ClusterSpec()
-
-
-def init_cluster_spec(args: "BaseExperimentConfig"):
-    global spec
-    CLUSTER_SPEC_PATH = os.environ.get("CLUSTER_SPEC_PATH", "")
-    if args.cluster.config_path:
-        spec.load_spec_from_file(args.cluster.config_path)
-    elif CLUSTER_SPEC_PATH:
-        spec.load_spec_from_file(CLUSTER_SPEC_PATH)
-    else:
-        spec.load_spec_from_args(args)
+def load_spec_from_file(config: "ClusterSpecConfig"):
+    with open(config.config_path, "r") as f:
+        spec: Dict = json.load(f)
+
+    config.cluster_name = spec["cluster_name"]
+    config.fileroot = spec["fileroot"]
+    config.gpu_type = spec.get("gpu_type", None)
+    config.mount = spec.get("default_mount", None)
+    config.gpu_image = spec.get("gpu_image", None)
+    config.gpu_infer_image = spec.get("gpu_infer_image", config.gpu_image)
+    config.cpu_image = spec.get("cpu_image", None)
+    config.node_name_prefix = spec.get("node_name_prefix", "slurmd-")
+    config.n_nodes = int(spec.get("n_nodes", 32))
+    config.n_gpus_per_node = int(spec.get("n_gpus_per_node", 8))
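For reference, a hedged example of a cluster spec file the new `load_spec_from_file` would accept — the keys mirror exactly what the function reads, the values and file path are made up:

    # cluster.json (hypothetical):
    # {
    #     "cluster_name": "na132",
    #     "fileroot": "/storage/experiments",
    #     "gpu_type": "h800",
    #     "default_mount": "/storage",
    #     "node_name_prefix": "slurmd-",
    #     "n_nodes": 32,
    #     "n_gpus_per_node": 8
    # }
    from realhf.api.cli_args import ClusterSpecConfig
    from realhf.base.cluster import load_spec_from_file

    cfg = ClusterSpecConfig(config_path="/path/to/cluster.json")  # hypothetical path
    load_spec_from_file(cfg)  # mutates cfg in place with the JSON values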
@@ -4,13 +4,10 @@

 # log format constants
 import contextlib
-import copy
 import datetime
 import getpass
 import os
 import pathlib
-import subprocess
-from collections import defaultdict
 from pathlib import Path
 from typing import *

@@ -69,139 +66,115 @@ TORCH_FORCE_CPU = False

 # constants in experiment instance scope
 LOCAL_CACHE_DIR = "/tmp/realhf"
+QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache")
+os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
+PORT_LOCKFILE_ROOT = os.getenv("AREAL_PORT_LOCKFILE_ROOT", "/tmp/areal/ports/")
+os.makedirs(PORT_LOCKFILE_ROOT, exist_ok=True)
+
 PYTORCH_KERNEL_CACHE_PATH = (
     f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
 )
 TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
-QUICKSTART_EXPR_CACHE_PATH = str(Path(__file__).parent.parent.parent / ".cache")
 os.makedirs(PYTORCH_KERNEL_CACHE_PATH, exist_ok=True)
 os.makedirs(TRITON_CACHE_PATH, exist_ok=True)
-os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)

-# Global constants that should be initialized after cluster initialization.
-MODEL_SAVE_ROOT = None
-LOG_ROOT = None
-RECOVER_ROOT = None
-SLURM_LOCK_FILE_NAME = None
-PORT_LOCK_FILE_ROOT = None
-DATASET_CACHE_PATH = None
-PROFILER_CACHE_PATH = None
-PARAM_REALLOC_PATH = None
-SGLANG_CACHE_PATH = None
-TORCH_EXTENSIONS_DIR = None
-BASE_ENVIRONS = None


-def init_constants(args: "BaseExperimentConfig"):
-    from realhf.base.cluster import init_cluster_spec
-    from realhf.base.cluster import spec as cluster_spec
-
-    init_cluster_spec(args)
-
-    globals_dict = globals()  # Get module's global variables
-
-    kwargs = dict(
-        MODEL_SAVE_ROOT=f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}",
-        LOG_ROOT=f"{cluster_spec.fileroot}/logs/{getpass.getuser()}",
-        RECOVER_ROOT=f"{cluster_spec.fileroot}/recover/{getpass.getuser()}",
-        SLURM_LOCK_FILE_NAME=f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock",
-        PORT_LOCK_FILE_ROOT=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports",
-        DATASET_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets",
-        PROFILER_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler",
-        PARAM_REALLOC_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc",
-        SGLANG_CACHE_PATH=f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang",
-        TORCH_EXTENSIONS_DIR=(
-            f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
-        ),
-    )
-    BASE_ENVIRONS = {
-        # "PYTHONPATH": "/realhf",
-        "REAL_IS_REMOTE": "1",
-        # "NCCL_P2P_DISABLE": "1",
-        # "NCCL_IB_DISABLE": "1",
-        "TRANSFORMERS_OFFLINE": "1",
-        "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
-        "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
-        "TOKENIZERS_PARALLELISM": "true",
-        "TORCH_EXTENSIONS_DIR": kwargs["TORCH_EXTENSIONS_DIR"],
-        # "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
-        # "NCCL_SOCKET_IFNAME": "ibp71s0",
-        # "GLOO_SOCKET_IFNAME": "ibp71s0",
-        # "TORCH_USE_CUDA_DSA": "1",
-        # "NCCL_IGNORE_DISABLED_P2P": "1",
-        # "CUDA_LAUNCH_BLOCKING": "1",  # NOTE: CUDAGraph Capturing will not work if CUDA_LAUNCH_BLOCKING is set to 1.
-        # "NCCL_COMM_BLOCKING": "1",  # NOTE: CUDAGraph Capturing will not work if NCCL_COMM_BLOCKING is set to 1.
-        # "NCCL_BLOCKING_WAIT": "1",  # NOTE: CUDAGraph Capturing will not work if NCCL_BLOCKING_WAIT is set to 1.
-        # "TORCH_SHOW_CPP_STACKTRACES": "1",
-        # "RAY_DEDUP_LOGS": "0",  # disable ray log deduplication
-        "CUDA_DEVICE_MAX_CONNECTIONS": "1",
-        "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
-        # torch.distributed.all_reduce does not free the input tensor until
-        # the synchronization point. This causes the memory usage to grow
-        # as the number of all_reduce calls increases. This env var disables
-        # this behavior.
-        # Related issue:
-        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
-        # Whether to enable time mark to plot timelines.
-        "REAL_CUDA_TMARK": os.getenv("REAL_CUDA_TMARK", "0"),
-        "REAL_DUMP_TRACE": os.getenv("REAL_DUMP_TRACE", "0"),
-        "REAL_DUMP_MEMORY": os.getenv("REAL_DUMP_MEMORY", "0"),
-        "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv(
-            "REAL_GPU_MEMORY_KILL_THRESHOLD", "1.0"
-        ),
-        "LC_ALL": "C",
-        "LANG": "C",
-        "NCCL_DEBUG": "WARN",
-    }
-    kwargs["BASE_ENVIRONS"] = BASE_ENVIRONS
-    # Set PPU-specific environment variables for stable training.
-    if cluster_spec.name == "wa180":
-        logger.warning("Detected PPU. Amending PPU-related environment variables.")
-        PPU_ENVIRONS = {
-            "NCCL_DEBUG": "INFO",
-            "NCCL_IB_DISABLE": "1",
-            "NCCL_DEBUG_SUBSYS": "INIT",
-            "NCCL_SET_THREAD_NAME": "1",
-            "NCCL_IB_HCA": "",
-            "NCCL_SOCKET_IFNAME": "bond0",
-            "PCCL_STATE_MONITOR_DISABLE": "1",
-        }
-        kwargs["BASE_ENVIRONS"].update(PPU_ENVIRONS)
-    elif cluster_spec.name == "na132":
-        # Specific environment variable for h800 cluster na132
-        NV_ENVIRONS = {
-            "NCCL_SOCKET_IFNAME": "bond0",
-            "NCCL_NET_PLUGIN": "",
-            "NCCL_IB_GID_INDEX": "3",
-            "NCCL_IB_TIMEOUT": "2",
-            "NCCL_IB_RETRY_CNT": "7",
-            "NCCL_IB_SL": "5",
-            "NCCL_IB_TC": "136",
-            "NCCL_IB_HCA": "mlx5_bond",
-            "NCCL_IB_QPS_PER_CONNECTION": "8",
-            "NCCL_SET_THREAD_NAME": "1",
-            "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
-        }
-        kwargs["BASE_ENVIRONS"].update(NV_ENVIRONS)
-
-    for key, value in kwargs.items():
-        if key not in globals_dict:
-            raise ValueError(f"Invalid constant name: {key}")
-        if globals_dict[key] is not None and globals_dict[key] != value:
-            raise RuntimeError(f"Constant '{key}' already initialized!")
-        globals_dict[key] = value
-
-    # make directories if does not exist
-    os.makedirs(globals_dict["PARAM_REALLOC_PATH"], exist_ok=True)
-    os.makedirs(globals_dict["MODEL_SAVE_ROOT"], exist_ok=True)
-    os.makedirs(globals_dict["LOG_ROOT"], exist_ok=True)
-    os.makedirs(globals_dict["RECOVER_ROOT"], exist_ok=True)
-    os.makedirs(globals_dict["DATASET_CACHE_PATH"], exist_ok=True)
-    os.makedirs(globals_dict["PROFILER_CACHE_PATH"], exist_ok=True)
-    os.makedirs(globals_dict["TORCH_EXTENSIONS_DIR"], exist_ok=True)
-    os.makedirs(globals_dict["PORT_LOCK_FILE_ROOT"], exist_ok=True)
-    os.makedirs(globals_dict["SGLANG_CACHE_PATH"], exist_ok=True)
+def get_cache_path(args: "BaseExperimentConfig") -> str:
+    path = f"{args.cluster.fileroot}/.cache/{getpass.getuser()}/{args.experiment_name}/{args.trial_name}"
+    os.makedirs(path, exist_ok=True)
+    return path
+
+
+def get_log_root(args: "BaseExperimentConfig") -> str:
+    log_root = f"{args.cluster.fileroot}/logs/{getpass.getuser()}"
+    os.makedirs(log_root, exist_ok=True)
+    return log_root
+
+
+def get_log_path(args: "BaseExperimentConfig") -> str:
+    log_path = f"{args.cluster.fileroot}/logs/{getpass.getuser()}/{args.experiment_name}/{args.trial_name}"
+    os.makedirs(log_path, exist_ok=True)
+    return log_path
+
+
+def get_save_root(args: "BaseExperimentConfig") -> str:
+    path = f"{args.cluster.fileroot}/checkpoints/{getpass.getuser()}"
+    os.makedirs(path, exist_ok=True)
+    return path
+
+
+def get_save_path(args: "BaseExperimentConfig") -> str:
+    path = f"{args.cluster.fileroot}/checkpoints/{getpass.getuser()}/{args.experiment_name}/{args.trial_name}"
+    os.makedirs(path, exist_ok=True)
+    return path
+
+
+def get_param_realloc_path(args: "BaseExperimentConfig"):
+    path = f"{args.cluster.fileroot}/.cache/{getpass.getuser()}/param_realloc"
+    os.makedirs(path, exist_ok=True)
+    return path
+
+
+BASE_ENVIRONS = {
+    # "PYTHONPATH": "/realhf",
+    "REAL_IS_REMOTE": "1",
+    # "NCCL_P2P_DISABLE": "1",
+    # "NCCL_IB_DISABLE": "1",
+    "TRANSFORMERS_OFFLINE": "1",
+    "TOKENIZERS_PARALLELISM": "true",
+    "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH,
+    "TRITON_CACHE_DIR": TRITON_CACHE_PATH,
+    # "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
+    # "NCCL_SOCKET_IFNAME": "ibp71s0",
+    # "GLOO_SOCKET_IFNAME": "ibp71s0",
+    # "TORCH_USE_CUDA_DSA": "1",
+    # "NCCL_IGNORE_DISABLED_P2P": "1",
+    # "CUDA_LAUNCH_BLOCKING": "1",  # NOTE: CUDAGraph Capturing will not work if CUDA_LAUNCH_BLOCKING is set to 1.
+    # "NCCL_COMM_BLOCKING": "1",  # NOTE: CUDAGraph Capturing will not work if NCCL_COMM_BLOCKING is set to 1.
+    # "NCCL_BLOCKING_WAIT": "1",  # NOTE: CUDAGraph Capturing will not work if NCCL_BLOCKING_WAIT is set to 1.
+    # "TORCH_SHOW_CPP_STACKTRACES": "1",
+    # "RAY_DEDUP_LOGS": "0",  # disable ray log deduplication
+    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
+    "OMP_NUM_THREADS": str(min(os.cpu_count(), 32)),
+    # torch.distributed.all_reduce does not free the input tensor until
+    # the synchronization point. This causes the memory usage to grow
+    # as the number of all_reduce calls increases. This env var disables
+    # this behavior.
+    # Related issue:
+    # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
+    # Whether to enable time mark to plot timelines.
+    "REAL_DUMP_TRACE": os.getenv("REAL_DUMP_TRACE", "0"),
+    "REAL_DUMP_MEMORY": os.getenv("REAL_DUMP_MEMORY", "0"),
+    "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv(
+        "REAL_GPU_MEMORY_KILL_THRESHOLD", "1.0"
+    ),
+    "LC_ALL": "C",
+    "LANG": "C",
+    "NCCL_DEBUG": "WARN",
+}
+PPU_ENVIRONS = {
+    "NCCL_DEBUG": "INFO",
+    "NCCL_IB_DISABLE": "1",
+    "NCCL_DEBUG_SUBSYS": "INIT",
+    "NCCL_SET_THREAD_NAME": "1",
+    "NCCL_IB_HCA": "",
+    "NCCL_SOCKET_IFNAME": "bond0",
+    "PCCL_STATE_MONITOR_DISABLE": "1",
+}
+NA132_ENVIRONS = {
+    "NCCL_SOCKET_IFNAME": "bond0",
+    "NCCL_NET_PLUGIN": "",
+    "NCCL_IB_GID_INDEX": "3",
+    "NCCL_IB_TIMEOUT": "2",
+    "NCCL_IB_RETRY_CNT": "7",
+    "NCCL_IB_SL": "5",
+    "NCCL_IB_TC": "136",
+    "NCCL_IB_HCA": "mlx5_bond",
+    "NCCL_IB_QPS_PER_CONNECTION": "8",
+    "NCCL_SET_THREAD_NAME": "1",
+    "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
+}


 # _model_name will be changed in the model_scope context manager
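The old one-shot `init_constants` globals are replaced by pure path helpers that derive everything from the experiment config. A short usage sketch — `args` stands for any `BaseExperimentConfig`-like object with `cluster.fileroot`, `experiment_name`, and `trial_name` set:

    from realhf.base import constants

    log_dir = constants.get_log_path(args)      # <fileroot>/logs/<user>/<exp>/<trial>
    ckpt_dir = constants.get_save_path(args)    # <fileroot>/checkpoints/<user>/<exp>/<trial>
    cache_dir = constants.get_cache_path(args)  # <fileroot>/.cache/<user>/<exp>/<trial>
    # Each helper creates its directory on first use (os.makedirs(..., exist_ok=True)).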
@@ -570,18 +543,21 @@ def get_repo_path() -> pathlib.Path:
     return pathlib.Path(__file__).resolve().parent.parent.parent


-def get_env_vars(**kwargs):
+def get_env_vars(exp_cfg: "BaseExperimentConfig", **kwargs):
     kwargs.update(
-        CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
         REAL_DUMP_TRACE=os.environ.get("REAL_DUMP_TRACE", "0"),
         REAL_RECORD_PERFORMANCE=os.environ.get("REAL_RECORD_PERFORMANCE", "0"),
         FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
         REAL_DUMP_MEMORY=os.environ.get("REAL_DUMP_MEMORY", "0"),
-        REAL_ETCD_ADDR=os.getenv("REAL_ETCD_ADDR", "localhost:2379"),
         REAL_OSS_TESTCASE_PATH=os.getenv("REAL_OSS_TESTCASE_PATH", ""),
     )
-    return {
+    envvars = {
         **kwargs,
         "REAL_PACKAGE_PATH": str(get_repo_path()),
         **BASE_ENVIRONS,
     }
+    if exp_cfg.cluster.cluster_name == "wa180":
+        envvars.update(**PPU_ENVIRONS)
+    if exp_cfg.cluster.cluster_name == "na132":
+        envvars.update(**NA132_ENVIRONS)
+    return envvars
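A hedged call sketch matching the `main_start` hunk earlier (the values are illustrative):

    envs = constants.get_env_vars(
        experiment,                 # now takes the experiment config, not just names
        REAL_MODE="SLURM",
        REAL_RECOVER_RUN="0",
        REAL_SAVE_RECOVER_STATES="1",
    )
    # PPU_ENVIRONS / NA132_ENVIRONS are merged in only when
    # experiment.cluster.cluster_name is "wa180" or "na132".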
@@ -358,106 +358,6 @@ def calculate_llama_gen_flops(
     return flops


-#################### CUDA Kernel Time Marking Start ####################
-# Used to create timeline plots.
-
-
-class CUDATimeMarkType(enum.Enum):
-    forward = "forward"
-    backward = "backward"
-    optim_step = "optim_step"
-    comm = "comm"
-    misc = "misc"
-    mem_layout = "memory_layout"
-
-
-@dataclasses.dataclass
-class TimeMarkEntry:
-    name: str
-    model_name: "ModelName"
-    type_: CUDATimeMarkType
-    start_time: int
-    end_time: int
-
-
-TIME_MARK_DB = []
-
-
-def cuda_tmark(name: str, type_: CUDATimeMarkType):
-    if os.getenv("REAL_CUDA_TMARK", None) == "1":
-
-        def wrapper(f: Callable):
-
-            def _wrapped_f(*args, **kwargs):
-                import torch
-
-                if constants.use_cuda():
-                    from realhf.base.constants import _model_name
-
-                    torch.cuda.synchronize()
-                    tik = time.time_ns()
-                    res = f(*args, **kwargs)
-                    torch.cuda.synchronize()
-                    tok = time.time_ns()
-                    global TIME_MARK_DB
-                    TIME_MARK_DB.append(
-                        TimeMarkEntry(name, _model_name, type_, tik, tok)
-                    )
-                else:
-                    res = f(*args, **kwargs)
-                return res
-
-            return _wrapped_f
-
-    else:
-
-        def wrapper(f):
-            return f
-
-    return wrapper
-
-
-@contextlib.contextmanager
-def cuda_tmarked(name: str, type_: CUDATimeMarkType):
-    if os.getenv("REAL_CUDA_TMARK", None) == "1":
-        import torch
-
-        if constants.use_cuda():
-            from realhf.base.constants import _model_name
-
-            torch.cuda.synchronize()
-            tik = time.time_ns()
-    yield
-    if os.getenv("REAL_CUDA_TMARK", None) == "1":
-        if constants.use_cuda():
-            torch.cuda.synchronize()
-            tok = time.time_ns()
-            global TIME_MARK_DB
-            TIME_MARK_DB.append(TimeMarkEntry(name, _model_name, type_, tik, tok))
-
-
-def fetch_latest_tmark():
-    global TIME_MARK_DB
-    return TIME_MARK_DB[-1]
-
-
-def dump_tmark_db(worker_idx):
-    if os.getenv("REAL_CUDA_TMARK", None) != "1":
-        return
-    fn = os.path.join(
-        constants.LOG_ROOT,
-        constants.experiment_name(),
-        constants.trial_name(),
-        f"time_marks{worker_idx}.pkl",
-    )
-    global TIME_MARK_DB
-    with open(fn, "wb") as f:
-        pickle.dump(TIME_MARK_DB, f)
-    TIME_MARK_DB.clear()
-
-
-#################### CUDA Kernel Time Marking End ####################
-
 #################### CUDA Kernel Time Statistics Start ####################
 # Categorizing CUDA kernels into computation, communication, memory IO, and MISC/IDLE,
 # used to plot the percentage of time spent on each category and show how much we can
@@ -11,7 +11,7 @@ import shutil
 import threading
 import time
 import uuid
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional

 import ray

@@ -22,6 +22,9 @@ except Exception:

 from realhf.base import logging, security, timeutil

+if TYPE_CHECKING:
+    from realhf.api.cli_args import NameResolveConfig
+
 logger = logging.getLogger("name-resolve")


@@ -45,12 +48,6 @@ class NameRecordRepository:
         except Exception as e:
             logger.info(f"Exception ignore when deleting NameResolveRepo {e}")

-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.reset()
-
     def add(
         self,
         name,
@@ -287,24 +284,20 @@ class MemoryNameRecordRepository(NameRecordRepository):


 class NfsNameRecordRepository(NameRecordRepository):
-    RECORD_ROOT = ""
-
-    def __init__(self, **kwargs):
+    def __init__(self, record_root="", **kwargs):
         self.__to_delete = set()
+        self.record_root = record_root

-    @staticmethod
-    def __dir_path(name):
-        if not NfsNameRecordRepository.RECORD_ROOT:
-            from realhf.base.cluster import spec as cluster_spec
-
-            RECORD_ROOT = f"{cluster_spec.fileroot}/name_resolve/"
-            os.makedirs(RECORD_ROOT, exist_ok=True)
-            NfsNameRecordRepository.RECORD_ROOT = RECORD_ROOT
-        return os.path.join(NfsNameRecordRepository.RECORD_ROOT, name)
-
-    @staticmethod
-    def __file_path(name):
-        return os.path.join(NfsNameRecordRepository.__dir_path(name), "ENTRY")
+    def __dir_path(self, name):
+        if not self.record_root:
+            raise RuntimeError(
+                f"The `record_root` of NfsNameRecordRepository is not properly reconfigured."
+            )
+        return os.path.join(self.record_root, name)
+
+    def __file_path(self, name):
+        return os.path.join(self.__dir_path(name), "ENTRY")

     def add(
         self,
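With the lazily-derived global `RECORD_ROOT` gone, the NFS backend is configured per instance, typically from `NameResolveConfig.nfs_record_root` via `name_resolve.reconfigure(...)`. A hedged direct-construction sketch (the path is made up):

    repo = NfsNameRecordRepository(record_root="/nfs/areal/name_resolve")  # hypothetical NFS path
    # An instance whose record_root is empty now fails fast with a RuntimeError
    # the first time a key path is resolved, instead of silently deriving a
    # root from the old global cluster spec.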
@@ -342,7 +335,7 @@ class NfsNameRecordRepository(NameRecordRepository):
             os.remove(path)
             while True:
                 path = os.path.dirname(path)
-                if path == NfsNameRecordRepository.RECORD_ROOT:
+                if path == self.record_root:
                     break
                 if len(os.listdir(path)) > 0:
                     break
@@ -385,7 +378,7 @@ class NfsNameRecordRepository(NameRecordRepository):
                     continue
                 if files[0] != "ENTRY":
                     continue
-                key = root.removeprefix(self.RECORD_ROOT)
+                key = root.removeprefix(self.record_root)
                 key = key.removeprefix("/")
                 rs.append(self.get(key))
             except NameEntryNotFoundError:
@ -402,7 +395,7 @@ class NfsNameRecordRepository(NameRecordRepository):
|
||||||
continue
|
continue
|
||||||
if files[0] != "ENTRY":
|
if files[0] != "ENTRY":
|
||||||
continue
|
continue
|
||||||
key = root.removeprefix(self.RECORD_ROOT)
|
key = root.removeprefix(self.record_root)
|
||||||
key = key.removeprefix("/")
|
key = key.removeprefix("/")
|
||||||
rs.append(key)
|
rs.append(key)
|
||||||
except NameEntryNotFoundError:
|
except NameEntryNotFoundError:
|
||||||
|
@ -571,15 +564,6 @@ class Etcd3NameRecordRepository(NameRecordRepository):
|
||||||
TTL-based expiration, atomic operations, and key watching functionality.
|
TTL-based expiration, atomic operations, and key watching functionality.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Default configuration
|
|
||||||
try:
|
|
||||||
host, port = os.getenv("REAL_ETCD_ADDR", "").split(":")
|
|
||||||
except ValueError:
|
|
||||||
host, port = "localhost", 2379
|
|
||||||
ETCD_HOST = host
|
|
||||||
ETCD_PORT = int(port)
|
|
||||||
ETCD_USER = None
|
|
||||||
ETCD_PASSWORD = None
|
|
||||||
KEEPALIVE_POLL_FREQUENCY = 1
|
KEEPALIVE_POLL_FREQUENCY = 1
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -604,14 +588,17 @@ class Etcd3NameRecordRepository(NameRecordRepository):
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
# Set connection parameters
|
# Set connection parameters
|
||||||
self._host = host or self.ETCD_HOST
|
self._host = host
|
||||||
self._port = port or self.ETCD_PORT
|
self._port = port
|
||||||
self._user = user or self.ETCD_USER
|
self._user = user
|
||||||
self._password = password or self.ETCD_PASSWORD
|
self._password = password
|
||||||
|
|
||||||
# Connect to etcd
|
# Connect to etcd
|
||||||
self._client = etcd3.client(
|
self._client = etcd3.client(
|
||||||
host=self._host, port=self._port, user=self._user, password=self._password
|
host=self._host,
|
||||||
|
port=self._port,
|
||||||
|
user=self._user,
|
||||||
|
password=self._password,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keep track of entries for cleanup and keepalive
|
# Keep track of entries for cleanup and keepalive
|
||||||
|
@ -835,16 +822,17 @@ class Etcd3NameRecordRepository(NameRecordRepository):
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Delete all keys added via this repository instance."""
|
"""Delete all keys added via this repository instance."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
count = 0
|
if hasattr(self, "_to_delete"):
|
||||||
for name in self._to_delete:
|
count = 0
|
||||||
if name in self._entries:
|
for name in self._to_delete:
|
||||||
try:
|
if name in self._entries:
|
||||||
self._delete_locked(name)
|
try:
|
||||||
count += 1
|
self._delete_locked(name)
|
||||||
except NameEntryNotFoundError:
|
count += 1
|
||||||
pass
|
except NameEntryNotFoundError:
|
||||||
self._to_delete = set()
|
pass
|
||||||
logger.info(f"Reset {count} saved etcd entries")
|
self._to_delete = set()
|
||||||
|
logger.info(f"Reset {count} saved etcd entries")
|
||||||
|
|
||||||
def _keepalive_thread_run(self):
|
def _keepalive_thread_run(self):
|
||||||
"""Background thread to keep leases alive."""
|
"""Background thread to keep leases alive."""
|
||||||
|
@ -1099,12 +1087,6 @@ class RayNameResolveRepository:
|
||||||
f"Exception ignored when deleting RayNameResolveRepository: {e}"
|
f"Exception ignored when deleting RayNameResolveRepository: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -1376,31 +1358,19 @@ class RayNameResolveRepository:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_repository(type_="nfs", **kwargs):
|
def make_repository(args: "NameResolveConfig"):
|
||||||
if type_ == "memory":
|
if args.type == "nfs":
|
||||||
return MemoryNameRecordRepository(**kwargs)
|
return NfsNameRecordRepository(args.nfs_record_root)
|
||||||
elif type_ == "nfs":
|
elif args.type == "etcd3":
|
||||||
return NfsNameRecordRepository(**kwargs)
|
host, port = args.etcd3_addr.split(":")
|
||||||
elif type_ == "redis":
|
return Etcd3NameRecordRepository(host=host, port=int(port))
|
||||||
return RedisNameRecordRepository(**kwargs)
|
elif args.type == "ray":
|
||||||
elif type_ == "etcd3":
|
return RayNameResolveRepository(actor_name=args.ray_actor_name)
|
||||||
return Etcd3NameRecordRepository(**kwargs)
|
|
||||||
elif type_ == "ray":
|
|
||||||
return RayNameResolveRepository(**kwargs)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"No such name resolver: {type_}")
|
raise NotImplementedError(f"No such name resolver: {args.type}")
|
||||||
|
|
||||||
|
|
||||||
# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs"
|
DEFAULT_REPOSITORY = NfsNameRecordRepository()
|
||||||
DEFAULT_REPOSITORY_TYPE = "nfs"
|
|
||||||
if etcd3 is not None and os.getenv("REAL_ETCD_ADDR", ""):
|
|
||||||
DEFAULT_REPOSITORY_TYPE = "etcd3"
|
|
||||||
if os.getenv("REAL_ETCD_ADDR", "") and etcd3 is None:
|
|
||||||
logger.warning(
|
|
||||||
f"Detected REAL_ETCD_ADDR but etcd3 client is not available. "
|
|
||||||
"Please run `pip install -r requirements.txt` if you want to use etcd name resolve."
|
|
||||||
)
|
|
||||||
DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE)
|
|
||||||
add = DEFAULT_REPOSITORY.add
|
add = DEFAULT_REPOSITORY.add
|
||||||
add_subentry = DEFAULT_REPOSITORY.add_subentry
|
add_subentry = DEFAULT_REPOSITORY.add_subentry
|
||||||
delete = DEFAULT_REPOSITORY.delete
|
delete = DEFAULT_REPOSITORY.delete
|
||||||
|
@ -1413,11 +1383,10 @@ reset = DEFAULT_REPOSITORY.reset
|
||||||
watch_names = DEFAULT_REPOSITORY.watch_names
|
watch_names = DEFAULT_REPOSITORY.watch_names
|
||||||
|
|
||||||
|
|
||||||
def reconfigure(*args, **kwargs):
|
def reconfigure(config: "NameResolveConfig"):
|
||||||
global DEFAULT_REPOSITORY, DEFAULT_REPOSITORY_TYPE
|
global DEFAULT_REPOSITORY
|
||||||
global add, add_subentry, delete, clear_subtree, get, get_subtree, find_subtree, wait, reset, watch_names
|
global add, add_subentry, delete, clear_subtree, get, get_subtree, find_subtree, wait, reset, watch_names
|
||||||
DEFAULT_REPOSITORY = make_repository(*args, **kwargs)
|
DEFAULT_REPOSITORY = make_repository(config)
|
||||||
DEFAULT_REPOSITORY_TYPE = args[0]
|
|
||||||
add = DEFAULT_REPOSITORY.add
|
add = DEFAULT_REPOSITORY.add
|
||||||
add_subentry = DEFAULT_REPOSITORY.add_subentry
|
add_subentry = DEFAULT_REPOSITORY.add_subentry
|
||||||
delete = DEFAULT_REPOSITORY.delete
|
delete = DEFAULT_REPOSITORY.delete
|
||||||
|
|
|
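The hunks above replace the class-attribute and environment-variable plumbing with an explicit `NameResolveConfig` argument. A minimal usage sketch under that assumption (the config class comes from `realhf.api.cli_args` as added in this commit; the record-root path below is illustrative only, not a path from the repo):

    from realhf.api.cli_args import NameResolveConfig
    from realhf.base import name_resolve

    # Select the NFS backend and point it at a root visible on all nodes.
    config = NameResolveConfig(type="nfs", nfs_record_root="/tmp/areal/name_resolve")
    name_resolve.reconfigure(config)

    # Module-level helpers now delegate to the repository built from the config.
    name_resolve.add("my/experiment/key", "value")
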
@@ -23,14 +23,20 @@ def gethostip():


 def find_free_port(
-    low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port"
+    low=1,
+    high=65536,
+    exclude_ports=None,
+    experiment_name="port",
+    trial_name="port",
+    lockfile_root=constants.PORT_LOCKFILE_ROOT,
 ):
     """Find a free port within the specified range, excluding certain ports."""

     ports_name = names.used_ports(experiment_name, trial_name, gethostip())

     free_port = None
-    lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip())
+    os.makedirs(lockfile_root, exist_ok=True)
+    lockfile = os.path.join(lockfile_root, gethostip())
     while True:
         with open(lockfile, "w") as fd:
             # This will block until lock is acquired
@@ -58,7 +64,12 @@ def find_free_port(


 def find_multiple_free_ports(
-    count, low=1, high=65536, experiment_name="port", trial_name="port"
+    count,
+    low=1,
+    high=65536,
+    experiment_name="port",
+    trial_name="port",
+    lockfile_root=constants.PORT_LOCKFILE_ROOT,
 ):
     """Find multiple mutually exclusive free ports."""
     free_ports = set()
@@ -69,6 +80,7 @@ def find_multiple_free_ports(
             exclude_ports=free_ports,
             experiment_name=experiment_name,
             trial_name=trial_name,
+            lockfile_root=lockfile_root,
         )
         free_ports.add(port)
     return list(free_ports)

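The port helpers now take an explicit `lockfile_root` instead of the removed `constants.PORT_LOCK_FILE_ROOT`. A hedged call sketch, assuming name resolving has already been configured and using an illustrative lockfile directory:

    from realhf.base import network

    port = network.find_free_port(
        low=10000,
        high=60000,
        experiment_name="my-exp",      # hypothetical names for illustration
        trial_name="my-trial",
        lockfile_root="/tmp/areal/ports",
    )
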
@@ -35,26 +35,6 @@ def global_init():
         if key not in os.environ:
             os.environ[key] = value

-    # resolve config path for cluster spec.
-    cluster_spec_path = os.environ.get("CLUSTER_SPEC_PATH", "")
-    if cluster_spec_path == "":
-        if external_configs.get("cluster_config"):
-            fileroot = external_configs.cluster_config.get("fileroot")
-            if fileroot is not None and fileroot != "":
-                experiment_name = get_experiment_name(config.get("experiment_name"))
-                trial_name = get_trial_name(config.get("trial_name"))
-                config_dir = f"{fileroot}/configs/{getpass.getuser()}/{experiment_name}/{trial_name}"
-                os.makedirs(config_dir, exist_ok=True)
-                cluster_spec_path = f"{config_dir}/cluster_config.json"
-                cluster_spec = OmegaConf.to_container(external_configs.cluster_config)
-                if "cluster_type" not in cluster_spec:
-                    cluster_spec["cluster_type"] = config.mode
-                if "cluster_name" not in cluster_spec:
-                    cluster_spec["cluster_name"] = f"{config.mode}_cluster"
-                with open(cluster_spec_path, "w") as f:
-                    json.dump(cluster_spec, f)
-                os.environ["CLUSTER_SPEC_PATH"] = cluster_spec_path
-

 def get_experiment_name(default_name: str = ""):
     if any("experiment_name=" in x for x in sys.argv):

@@ -40,13 +40,11 @@ class RecoverInfo:
     hash_vals_to_ignore: List[int] = dataclasses.field(default_factory=list)


-def dump_recover_info(recover_info: RecoverInfo):
+def dump_recover_info(args, recover_info: RecoverInfo):
     global RECOVER_INFO_PATH
     if RECOVER_INFO_PATH is None:
         RECOVER_INFO_PATH = os.path.join(
-            constants.RECOVER_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_save_path(args),
             "recover_info.pkl",
         )
         os.makedirs(os.path.dirname(RECOVER_INFO_PATH), exist_ok=True)
@@ -54,15 +52,13 @@ def dump_recover_info(recover_info: RecoverInfo):
         pickle.dump(recover_info, f)


-def load_recover_info() -> Tuple[int, Optional[RecoverInfo]]:
+def load_recover_info(args) -> Tuple[int, Optional[RecoverInfo]]:
     if os.environ.get("REAL_RECOVER_RUN", "0") != "1":
         return False, None
     global RECOVER_INFO_PATH
     if RECOVER_INFO_PATH is None:
         RECOVER_INFO_PATH = os.path.join(
-            constants.RECOVER_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_save_path(args),
             "recover_info.pkl",
         )
         os.makedirs(os.path.dirname(RECOVER_INFO_PATH), exist_ok=True)
@@ -81,15 +77,9 @@ class InValidRecoverCkpt(Exception):
     pass


-def discover_ckpt(
-    expr_name: str, trial_name: str
-) -> Tuple[str, List[str], RecoverInfo]:
-    recover_info_file = (
-        pathlib.Path(constants.RECOVER_ROOT)
-        / expr_name
-        / trial_name
-        / "recover_info.pkl"
-    )
+def discover_ckpt(args) -> Tuple[str, List[str], RecoverInfo]:
+    expr_name, trial_name = args.experiment_name, args.trial_name
+    recover_info_file = pathlib.Path(constants.get_save_path(args)) / "recover_info.pkl"
     if os.path.exists(str(recover_info_file)):
         with open(recover_info_file, "rb") as f:
             info: RecoverInfo = pickle.load(f)
@@ -100,9 +90,7 @@ def discover_ckpt(
                 f"but found {info.last_step_info.epoch}"
             )
             raise InValidRecoverCkpt(msg)
-        model_save_dir = (
-            pathlib.Path(constants.MODEL_SAVE_ROOT) / expr_name / trial_name
-        )
+        model_save_dir = pathlib.Path(constants.get_save_path(args))
         model_ckpt_dirs = []
         for role in os.listdir(model_save_dir):
             if "dataset_indices" in role:

@@ -9,19 +9,17 @@ from typing import List

 import numpy as np

-from realhf.base.cluster import spec as cluster_spec
-

 def parse_node_id(node_name: str, prefix: str) -> int:
     return int(node_name.split(prefix)[-1])


-def parse_nodelist(nodelist: str, prefix: str) -> List[str]:
+def parse_nodelist(cluster_config, nodelist: str, prefix: str) -> List[str]:
     if not nodelist.startswith(prefix):
         raise ValueError(
             f"Node list `{nodelist}` does not start with hostname prefix `{prefix}`."
         )
-    n = cluster_spec.suffix_n_digits
+    n = len(str(cluster_config.n_nodes))
     nodelist = nodelist.replace(prefix, "")
     if "[" not in nodelist:
         return [prefix + nodelist]
@@ -38,30 +36,6 @@ def parse_nodelist(nodelist: str, prefix: str) -> List[str]:
     return [f"{prefix}{node_id:0{n}d}" for node_id in node_ids]


-def nodelist_from_nodes(nodes: List[str], prefix: str) -> str:
-    n = cluster_spec.suffix_n_digits
-    node_ids = sorted([parse_node_id(node, prefix) for node in nodes])
-    assert len(node_ids) > 0
-    if len(node_ids) == 1:
-        return f"{prefix}{node_ids[0]:02d}"
-    else:
-        node_reprs = []
-        start, end = node_ids[0], node_ids[0]
-        for i in range(len(node_ids)):
-            node_id = node_ids[i]
-            next_node_id = node_ids[i + 1] if i + 1 < len(node_ids) else -1
-            if node_id + 1 == next_node_id:
-                end = next_node_id
-            else:
-                if start == end:
-                    node_reprs.append(f"{start:0{n}d}")
-                else:
-                    node_reprs.append(f"{start:0{n}d}-{end:0{n}d}")
-                start = next_node_id
-                end = next_node_id
-        return f"{prefix}[{','.join(node_reprs)}]"
-
-
 def are_ones_contiguous(binary_array: np.ndarray):
     one_indices = np.where(binary_array == 1)[0]
     if len(one_indices) == 0:

@@ -19,6 +19,7 @@ import torch
 import torch.distributed as dist
 import torch.utils.data

+from realhf.api.cli_args import BaseExperimentConfig, NameResolveConfig
 from realhf.api.core.data_api import SequenceSample
 from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology
 from realhf.base.topology import (
@@ -92,6 +93,14 @@ class StandaloneTestingProcess(mp.Process):
         if constants.use_cuda():
             torch.cuda.set_device(0)

+        from realhf.api.cli_args import NameResolveConfig
+
+        name_resolve.reconfigure(
+            NameResolveConfig(
+                "nfs", f"/tmp/areal/testing/{self.expr_name}/{self.trial_name}/"
+            )
+        )
+
         self.barrier.wait()

         if self.setup_dist_torch:
@@ -103,7 +112,11 @@ class StandaloneTestingProcess(mp.Process):
         if self.dist_backend is None:
             self.dist_backend = "gloo" if not constants.use_cuda() else "nccl"
         setup_global_comm(
-            self.expr_name, self.trial_name, self.rank, backend=self.dist_backend
+            BaseExperimentConfig(),
+            self.expr_name,
+            self.trial_name,
+            self.rank,
+            backend=self.dist_backend,
         )

         # misc setup
@@ -149,6 +162,11 @@ class LocalMultiProcessTest:
         if torch.cuda.is_available():
            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
            os.environ["GPU_DEVICES_ISOLATED"] = str(1)
+        expr_name = expr_name if expr_name is not None else _DEFAULT_EXPR_NAME
+        trial_name = trial_name if trial_name is not None else _DEFAULT_TRIAL_NAME
+        name_resolve.reconfigure(
+            NameResolveConfig("nfs", f"/tmp/areal/testing/{expr_name}/{trial_name}/")
+        )
         clear_name_resolve(expr_name, trial_name)
         self.timeout_secs = timeout_secs
         self.processes = []

@@ -21,7 +21,6 @@ from typing import Dict, List, NamedTuple, Optional, Tuple
 import torch.distributed as dist

 import realhf.base.logging as logging
-from realhf.base.cluster import spec as cluster_spec
 from realhf.base.constants import NCCL_DEFAULT_TIMEOUT

 logger = logging.getLogger("Topology")
@@ -185,7 +184,7 @@ class ProcessTopology:
         omit_axes = frozenset(omit_axes)
         axes = [a for a in self.get_axis_names() if a not in omit_axes]
         names = []
-        n = cluster_spec.suffix_n_digits
+        n = len(str(len(self.mapping)))
         for ax in axes:
             ax_rank = getattr(self.get_coord(rank=rank), ax)
             names.append(f"{ax}{inner_sep}{ax_rank:0{n}d}")

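The suffix width for rank names is now derived from the topology's own size rather than the removed cluster spec. A small worked example of the padding rule (values illustrative):

    # With 16 ranks the suffix is zero-padded to len(str(16)) == 2 digits.
    mapping_size = 16
    n = len(str(mapping_size))
    print(f"dp_{3:0{n}d}")   # -> "dp_03"; with 128 ranks it would be "dp_003"
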
@@ -2,6 +2,7 @@

 import copy
 import dataclasses
+import os
 from typing import Any, Dict, List, Tuple

 import realhf.base.logging as logging
@@ -13,6 +14,7 @@ from realhf.api.core.config import (
 )
 from realhf.api.core.model_api import GenerationHyperparameters
 from realhf.api.quickstart.entrypoint import register_quickstart_exp
+from realhf.base import constants
 from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig
 from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
 from realhf.experiments.common.utils import asdict
@@ -29,6 +31,9 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
             "math-single-step",
             args=dict(
                 gconfig=self.generation_config,
+                answer_save_path=os.path.join(
+                    constants.get_log_path(self), "generated"
+                ),
                 tokenizer_path=self.actor.path,
                 success_rate_lb=self.success_rate_lb,
                 success_rate_ub=self.success_rate_ub,
@@ -40,7 +45,10 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
     @property
     def env(self) -> EnvServiceAbstraction:
         return EnvServiceAbstraction(
-            "math-code-single-step", args=dict(dataset_path=self.dataset.path)
+            "math-code-single-step",
+            args=dict(
+                dataset_path=self.dataset.path,
+            ),
         )

     @property

@@ -38,8 +38,6 @@ from realhf.api.core.system_api import (
     TasksGroup,
 )
 from realhf.api.quickstart.device_mesh import RPCAllocation
-from realhf.base.cluster import spec as cluster_spec
-from realhf.experiments.common.check import check_valid_sglang, check_valid_vllm
 from realhf.experiments.common.common import CommonExperimentConfig
 from realhf.experiments.common.utils import (
     AllocationMode,
@@ -92,49 +90,57 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
         return ExperimentScheduling(
             master_worker=TasksGroup(
                 count=1,
-                scheduling=Scheduling.master_worker_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_master_worker,
+                    gpu=0,
                     mem=self.mem_per_master_worker,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.cpu_image,
                 ),
             ),
             model_worker=TasksGroup(
                 count=train_world_size,
-                scheduling=Scheduling.model_worker_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_model_worker,
                     gpu=1,
                     mem=self.mem_per_model_worker,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.gpu_image,
                 ),
             ),
             generation_server=TasksGroup(
                 count=gen_world_size // gen_tp_size,
-                scheduling=Scheduling.generation_server_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_generation_server,
                     gpu=gen_tp_size,
                     mem=self.mem_per_generation_server,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.gpu_infer_image,
                 ),
             ),
             gserver_manager=TasksGroup(
                 count=1,
-                scheduling=Scheduling.gserver_manager_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_gserver_manager,
+                    gpu=0,
                     mem=self.mem_per_gserver_manager,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.cpu_image,
                 ),
             ),
             rollout_worker=TasksGroup(
                 count=self.n_rollout_workers or train_world_size,
-                scheduling=Scheduling.rollout_worker_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_rollout_worker,
+                    gpu=0,
                     mem=self.mem_per_rollout_worker,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.cpu_image,
                 ),
             ),
         )
@@ -162,7 +168,11 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
             # NOTE: here we use puller stream to wrap the original dataset
             datasets=[
                 DatasetAbstraction(
-                    "puller_stream", args=dict(dataset_cfgs=self.datasets)
+                    "puller_stream",
+                    args=dict(
+                        dataset_cfgs=self.datasets,
+                        args=self,
+                    ),
                 )
             ],
             torch_cache_mysophobia=self.torch_cache_mysophobia,

@@ -1,259 +0,0 @@
-# Copyright 2025 Ant Group Inc.
-# Copyright 2024 Wei Fu & Zhiyu Mei
-# Licensed under the Apache License, Version 2.0 (the "License").
-
-import copy
-import dataclasses
-import itertools
-import json
-import os
-from typing import *
-
-from omegaconf import OmegaConf
-
-from realhf.api.cli_args import (
-    MFCConfig,
-    ModelTrainEvalConfig,
-    ParallelismConfig,
-    PromptOnlyDatasetConfig,
-)
-from realhf.api.core.config import (
-    DatasetAbstraction,
-    ModelInterfaceAbstraction,
-    ModelInterfaceType,
-)
-from realhf.api.core.dfg import MFCDef
-from realhf.api.core.system_api import ExperimentConfig
-from realhf.api.quickstart.entrypoint import register_quickstart_exp
-from realhf.base import constants, logging
-from realhf.base.topology import decompose_to_three_factors
-from realhf.experiments.common.common import CommonExperimentConfig
-
-logger = logging.getLogger("Profiling Experiment", "system")
-
-
-def default_parallel_config(n_gpus: int) -> List[Dict[str, Any]]:
-    factors = decompose_to_three_factors(n_gpus)
-    x = [
-        {
-            "data_parallel_size": dp,
-            "tensor_parallel_size": tp,
-            "pipeline_parallel_size": pp,
-            "use_sequence_parallel": tp > 1,
-        }
-        for dp, tp, pp in factors
-    ]
-    x += [
-        {
-            "data_parallel_size": dp,
-            "tensor_parallel_size": tp,
-            "pipeline_parallel_size": pp,
-            "use_sequence_parallel": False,
-        }
-        for dp, tp, pp in factors
-        if tp > 1
-    ]
-    return x
-
-
-def dataclass_from_dict(klass, d):
-    try:
-        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
-        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
-    except:
-        return d  # Not a dataclass field
-
-
-@dataclasses.dataclass
-class ProfileConfig(CommonExperimentConfig):
-    """The experiment configuration for profiling layers and interfaces.
-
-    The `initial_setup` method in this experiment will return a list of
-    experiment configurations, which will be run sequentially.
-    All configurations share the same experiment name, trial name,
-    and the scheduling configuration. They can have different models,
-    datasets, or parallel strategies, as long as they always occupy
-    a fixed number of GPUs.
-
-    It's important to note that, if any error occurs during the execution,
-    the experiment will terminate immediately. In particular, the OOM error
-    should not appear because the profiling setup usually uses a small model.
-    """
-
-    interfaces_jsonl: str = ""
-    allocations_jsonl: Optional[str] = None
-    handle_names: Optional[List[str]] = None
-    n_mbs: Optional[List[int]] = None
-    batch_sizes: Optional[List[int]] = None
-    models_jsonl: str = ""
-    datasets_jsonl: str = ""
-
-    def __post_init__(self):
-        # Check that handle_name belones to ["train_step", "generate", "inference"]
-        self.handle_names = list(set(self.handle_names))
-        if any(
-            k not in ["train_step", "generate", "inference"] for k in self.handle_names
-        ):
-            raise NotImplementedError(f"Unknown handle_name: {self.handle_name}")
-
-        # Check the configuration of interfaces
-        if not os.path.exists(self.interfaces_jsonl):
-            raise FileNotFoundError(
-                f"File not found: {self.interfaces_jsonl}. "
-                "It should be a JSONL file specifying the arguments "
-                "for the interface implementation."
-            )
-        with open(self.interfaces_jsonl, "r") as f:
-            self.interface_kwargs = [json.loads(l) for l in f.readlines()]
-
-        # Check the configuration of parallel strategies.
-        if self.allocations_jsonl is None:
-            self.parallel_kwargs = default_parallel_config(
-                self.n_nodes * self.n_gpus_per_node
-            )
-        else:
-            assert self.allocations_jsonl.endswith(".jsonl")
-            assert os.path.exists(self.allocations_jsonl)
-            with open(self.allocations_jsonl, "r") as f:
-                self.parallel_kwargs = [json.loads(l) for l in f.readlines()]
-            for pcfg in self.parallel_kwargs:
-                assert isinstance(pcfg, dict), type(pcfg)
-                assert all(
-                    k
-                    in [
-                        "data_parallel_size",
-                        "tensor_parallel_size",
-                        "pipeline_parallel_size",
-                        "use_sequence_parallel",
-                    ]
-                    for k in pcfg.keys()
-                ), pcfg.keys()
-                assert (self.n_nodes * self.n_gpus_per_node) == (
-                    pcfg.get("data_parallel_size", 1)
-                    * pcfg.get("tensor_parallel_size", 1)
-                    * pcfg.get("pipeline_parallel_size", 1)
-                )
-
-        if self.n_mbs is None:
-            self.n_mbs = [1]
-        else:
-            self.n_mbs = OmegaConf.to_container(self.n_mbs)
-        assert isinstance(self.n_mbs, list), type(self.n_mbs)
-        assert all(isinstance(x, int) for x in self.n_mbs)
-
-        assert self.batch_sizes is not None
-
-        assert os.path.exists(self.models_jsonl)
-        with open(self.models_jsonl, "r") as f:
-            self.model_kwargs = [json.loads(l) for l in f.readlines()]
-
-        assert os.path.exists(self.datasets_jsonl)
-        with open(self.datasets_jsonl, "r") as f:
-            self.dataset_kwargs = [json.loads(l) for l in f.readlines()]
-        assert all(x["type_"] == "prompt" for x in self.dataset_kwargs)
-
-    @property
-    def allocations(self):
-        return dict(default=self._tmp_allocation)
-
-    @property
-    def models(self):
-        return dict(default=self._tmp_model)
-
-    @property
-    def tokenizer_name_or_path(self):
-        return self._tmp_model.path
-
-    @property
-    def max_prompt_len(self):
-        return self._tmp_dataset.args["max_length"]
-
-    @property
-    def datasets(self):
-        return [self._tmp_dataset]
-
-    @property
-    def rpcs(self):
-        return dict(default=self._tmp_rpc)
-
-    def initial_setup(self) -> List[ExperimentConfig]:
-        self.allocation_mode = "manual"
-        setups = []
-        setup_log_path = os.path.join(
-            constants.LOG_ROOT,
-            self.experiment_name,
-            self.trial_name,
-            "setups.jsonl",
-        )
-        logger.info(
-            f"Experiment setup configurations of the profiling experiment "
-            f"will be saved to: {setup_log_path}"
-        )
-        with open(setup_log_path, "w") as f:
-            # batch size in the most outer loop to delay the possible OOM error
-            for (
-                bs,
-                pcfg,
-                n_mbs,
-                model_cfg,
-                dataset_cfg,
-                handle_name,
-                interface_cfg,
-            ) in itertools.product(
-                self.batch_sizes,
-                self.parallel_kwargs,
-                self.n_mbs,
-                self.model_kwargs,
-                self.dataset_kwargs,
-                self.handle_names,
-                self.interface_kwargs,
-            ):
-                if handle_name == "generate" and pcfg["use_sequence_parallel"]:
-                    continue
-
-                kwargs_stat = dict(
-                    parallel=pcfg,
-                    n_mbs=n_mbs,
-                    model=model_cfg,
-                    dataset=dataset_cfg,
-                    interface=interface_cfg,
-                    bs=bs,
-                )
-                f.write(json.dumps(kwargs_stat) + "\n")
-
-                # Create tmp object for constructing experiment setups
-                self._tmp_allocation = MFCConfig(
-                    parallel=ParallelismConfig(**pcfg), n_mbs=n_mbs
-                )
-                self._tmp_model = dataclass_from_dict(ModelTrainEvalConfig, model_cfg)
-                self._tmp_dataset = DatasetAbstraction(**dataset_cfg)
-                if handle_name == "train_step":
-                    interface_type = ModelInterfaceType.TRAIN_STEP
-                elif handle_name == "inference":
-                    interface_type = ModelInterfaceType.INFERENCE
-                elif handle_name == "generate":
-                    interface_type = ModelInterfaceType.GENERATE
-                else:
-                    raise NotImplementedError(
-                        f"Unknown which handle to run in the interface: {self.handle_name}"
-                    )
-                self._tmp_rpc = MFCDef(
-                    n_seqs=bs,
-                    name="default",
-                    n_mbs=n_mbs,
-                    interface_type=interface_type,
-                    interface_impl=ModelInterfaceAbstraction(**interface_cfg),
-                    model_name="default",
-                    input_keys=["packed_prompts"],
-                    log_return_value=False,
-                    balanced_dp=True,
-                )
-
-                setup = copy.deepcopy(super().initial_setup())
-                for m in setup.model_worker:
-                    m.profile_mode = True
-                setups.append(setup)
-        return setups
-
-
-register_quickstart_exp("profile", ProfileConfig)

@@ -42,7 +42,6 @@ from realhf.api.quickstart.device_mesh import (
     RPCAllocation,
     make_device_mesh_from_name,
 )
-from realhf.base.cluster import spec as cluster_spec
 from realhf.experiments.common.check import (
     check_valid_model_and_path,
     check_valid_optimizer,
@@ -164,21 +163,24 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
         return ExperimentScheduling(
             master_worker=TasksGroup(
                 count=1,
-                scheduling=Scheduling.master_worker_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_master_worker,
+                    gpu=0,
                     mem=self.mem_per_master_worker,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.cpu_image,
                 ),
             ),
             model_worker=TasksGroup(
                 count=self.n_nodes * self.n_gpus_per_node,
-                scheduling=Scheduling.model_worker_default(
+                scheduling=Scheduling(
                     cpu=self.cpus_per_model_worker,
                     gpu=1,
                     mem=self.mem_per_model_worker,
                     nodelist=self.nodelist,
                     exclude=self.exclude,
+                    container_image=self.cluster.gpu_image,
                 ),
             ),
         )
@@ -326,6 +328,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
                 rpc=rpc,
                 device_mesh=(
                     make_device_mesh_from_name(
+                        self.cluster,
                         self.nodelist,
                         self.allocations[rpc_type].device_mesh,
                         self.global_device_mesh.n_gpus_per_node,
@@ -579,7 +582,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
         )
         if self.n_gpus_per_node > self.cluster.n_gpus_per_node:
             raise ValueError(
-                f"Number of 7used GPUs per node {self.n_gpus_per_node} should not be larger than the cluster limit {self.cluster.n_gpus_per_node}"
+                f"Number of used GPUs per node {self.n_gpus_per_node} should not be larger than the cluster limit {self.cluster.n_gpus_per_node}"
             )
         if self.n_nodes > 1 and self.n_gpus_per_node != self.cluster.n_gpus_per_node:
             raise ValueError(

@@ -2,9 +2,9 @@
 # Copyright 2024 Wei Fu & Zhiyu Mei
 # Licensed under the Apache License, Version 2.0 (the "License").
 import dataclasses
+import os
 from typing import Dict

-import realhf.base.logging as logging
 from realhf.api.cli_args import MathCodeEvalOptions, ModelTrainEvalConfig
 from realhf.api.core.config import (
     DatasetAbstraction,
@@ -13,6 +13,7 @@ from realhf.api.core.config import (
 )
 from realhf.api.core.dfg import MFCDef
 from realhf.api.quickstart.entrypoint import register_quickstart_exp
+from realhf.base import constants, logging
 from realhf.experiments.common.common import CommonExperimentConfig
 from realhf.experiments.common.utils import asdict

@@ -55,6 +56,9 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
                 dataset_path=self.dataset.path,
                 tokenizer_path=self.actor.path,
                 rw_type=self.rw_type,
+                answer_save_path=os.path.join(
+                    constants.get_log_path(self), "generated"
+                ),
                 check_xml_format=self.check_xml_format,
                 group_size=self.group_size,
                 check_verifier_status=self.check_verifier_status,

@@ -6,7 +6,6 @@ import dataclasses
 import os
 from typing import Dict

-import realhf.base.logging as logging
 from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
 from realhf.api.core.config import (
     DatasetAbstraction,
@@ -16,6 +15,7 @@ from realhf.api.core.config import (
 from realhf.api.core.dfg import MFCDef, ParamReallocHook
 from realhf.api.core.system_api import ExperimentConfig
 from realhf.api.quickstart.entrypoint import register_quickstart_exp
+from realhf.base import constants, logging
 from realhf.experiments.common.common import CommonExperimentConfig
 from realhf.experiments.common.utils import (
     asdict,
@@ -132,6 +132,9 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
                 check_xml_format=self.check_xml_format,
                 group_size=self.group_size,
                 check_verifier_status=self.check_verifier_status,
+                answer_save_path=os.path.join(
+                    constants.get_log_path(self), "generated"
+                ),
             ),
         )

@@ -32,6 +32,7 @@ class MathMultiTurnAgent(Agent):
         self,
         gconfig,
         tokenizer_path,
+        answer_save_path,
         reward_scaling=1.0,
         reward_bias=0.0,
         turn_level_discount: float = 1.0,
@@ -39,6 +40,7 @@ class MathMultiTurnAgent(Agent):
     ):
         self.gconfig = gconfig.new(n=1)
         self.tokenizer = load_hf_tokenizer(tokenizer_path)
+        self.answer_save_path = answer_save_path

         self.reward_scaling = reward_scaling
         self.reward_bias = reward_bias
@@ -245,10 +247,7 @@ class MathMultiTurnAgent(Agent):
         for group_idx in range(group_size):
             # NOTE: we can ensure that only one process is logging this query id
             gen_file_path = os.path.join(
-                constants.LOG_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
-                "generated",
+                self.answer_save_path,
                 str(version_starts[group_idx]),
                 f"{qid}.txt",
             )
@@ -271,10 +270,7 @@ class MathMultiTurnAgent(Agent):
                     _f.write(info + "\n")

         train_pass_monitor_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "training_monitor",
+            self.answer_save_path,
             str(version_starts[group_idx]),
             f"{qid}.jsonl",
         )

@@ -25,6 +25,7 @@ class MathSingleStepAgent(Agent):
         self,
         gconfig,
         tokenizer_path,
+        answer_save_path,
         success_rate_lb,
         success_rate_ub,
         reward_scaling=1.0,
@@ -32,6 +33,7 @@ class MathSingleStepAgent(Agent):
     ):
         self.gconfig = gconfig
         self.tokenizer = load_hf_tokenizer(tokenizer_path)
+        self.answer_save_path = answer_save_path

         self.success_rate_lb = success_rate_lb
         self.success_rate_ub = success_rate_ub
@@ -198,10 +200,7 @@ class MathSingleStepAgent(Agent):
         for group_idx in range(group_size):
             # NOTE: we can ensure that only one process is logging this query id
             gen_file_path = os.path.join(
-                constants.LOG_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
-                "generated",
+                self.answer_save_path,
                 str(version_starts[group_idx]),
                 f"{qid}.txt",
             )
@@ -224,10 +223,7 @@ class MathSingleStepAgent(Agent):
                     _f.write(info + "\n")

         train_pass_monitor_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "training_monitor",
+            self.answer_save_path,
             str(version_starts[group_idx]),
             f"{qid}.jsonl",
         )

@@ -21,7 +21,6 @@ from realhf.api.core import model_api
 from realhf.api.core.data_api import SequenceSample
 from realhf.base import constants, logging, pkg_version
 from realhf.base.datapack import flat2d
-from realhf.base.monitor import CUDATimeMarkType, cuda_tmarked
 from realhf.impl.model.backend.inference import PipelinableInferenceEngine
 from realhf.impl.model.backend.pipe_runner import PipelineRunner, PipeTrainInstrSet
 from realhf.impl.model.modules.mlp import get_activation_fn
@@ -304,7 +303,6 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
         self._no_sync_context.__exit__(None, None, None)
         self._no_sync_context = None

-    @cuda_tmarked("bwd", CUDATimeMarkType.backward)
     def _exec_backward_pass(
         self,
         module: ReaLModel,
@@ -342,7 +340,6 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
         # self.engine.ddp.start_grad_sync()
         self.engine.finalize_grads()

-    @cuda_tmarked("opt", CUDATimeMarkType.optim_step)
     def _exec_optimizer_step(
         self,
         module: ReaLModel,
@@ -489,8 +486,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
                 loss_scale *= constants.data_parallel_world_size()
             loss_scale *= self.engine.optim.get_loss_scale().item()
             loss *= loss_scale
-            with cuda_tmarked("bwd", CUDATimeMarkType.backward):
-                loss.backward()
+            loss.backward()

             self.engine.finalize_grads()
             return self._step(version_steps)
@@ -530,7 +526,6 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
         )

     # wrapper for profiler
-    @cuda_tmarked("opt", CUDATimeMarkType.optim_step)
     def _step(self, version_steps):
         # omit the number of zeros in grads
         update_successful, grad_norm, _ = self.engine.optim.step()

@@ -32,7 +32,6 @@ from realhf.api.core.model_api import (
     register_backend,
 )
 from realhf.base import (
-    cluster,
     constants,
     datapack,
     gpu_utils,
@@ -431,9 +430,9 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
     def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
         if constants.pipe_parallel_world_size() != 1:
             raise RuntimeError("SGLang does not support pipe parallel size > 1.")
-        if constants.tensor_parallel_world_size() > cluster.spec.n_gpus_per_node:
+        if constants.tensor_parallel_world_size() > torch.cuda.device_count():
             raise RuntimeError(
-                "AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
+                "AReaL's SGLang integration does not support model parallel size > torch.cuda.device_count()."
             )

         additional_args = dataclasses.asdict(self)
@@ -453,6 +452,9 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
                 high=60000,
                 experiment_name=constants.experiment_name(),
                 trial_name=constants.trial_name(),
+                lockfile_root=os.path.join(
+                    constants.get_cache_path(self.args), "ports"
+                ),
             ),
             group=constants.data_parallel_group(),
         )
@@ -475,10 +477,6 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
             tp_size=constants.tensor_parallel_world_size(),
             # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
             base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
-            file_storage_path=os.path.join(
-                constants.SGLANG_CACHE_PATH,
-                f"sglang_storage{constants.data_parallel_rank()}",
-            ),
             # Data parallelism
             dp_size=1,  # TODO: check whether we require SGLang dp
             load_balance_method="round_robin",

@@ -46,6 +46,7 @@ def filter_match_mwids(


 def setup_global_comm(
+    args,
     expr_name: str,
     trial_name: str,
     worker_index: int,
@@ -87,10 +88,12 @@ def setup_global_comm(
         )

     if constants.use_cuda():
-        assert len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1, os.environ[
-            "CUDA_VISIBLE_DEVICES"
-        ]
-        local_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
+        if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1:
+            local_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
+        else:
+            local_gpu_id = int(
+                os.environ["CUDA_VISIBLE_DEVICES"].split(",")[worker_index]
+            )
     else:
         local_gpu_id = global_rank

@@ -100,7 +103,11 @@ def setup_global_comm(

     if worker_index == 0:
         host_ip = socket.gethostbyname(socket.gethostname())
-        port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name)
+        port = network.find_free_port(
+            experiment_name=expr_name,
+            trial_name=trial_name,
+            lockfile_root=os.path.join(constants.get_cache_path(args), "ports"),
+        )
         pg_init_addr = f"tcp://{host_ip}:{port}"
         name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300)
     else:

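setup_global_comm now tolerates a multi-GPU CUDA_VISIBLE_DEVICES by indexing with worker_index instead of asserting a single device. A worked sketch of just that selection logic, with illustrative values:

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
    worker_index = 2
    devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    # One visible device: take it directly; several: pick by worker index.
    local_gpu_id = int(devices[worker_index]) if len(devices) > 1 else int(devices[0])
    # -> 6 for this worker
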
@@ -181,6 +181,7 @@ def retokenize_and_verify(
 class MultiTaskRewardInterface(model_api.ModelInterface):
     dataset_path: str = ""
     tokenizer_path: str = "/storage/models/Qwen__Qwen2.5-1.5B"
+    answer_save_path: str = "."
     output_scaling: float = 1.0
     output_bias: float = 0.0
     rw_type: str = "sparse"
@@ -363,10 +364,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
     ):
         tik = time.perf_counter()
         gen_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "generated",
+            self.answer_save_path,
             task_type,
             f"v{model.version.global_step}r{dist.get_rank()}.txt",
         )
@@ -386,10 +384,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
             _f.write(info + "\n")

         gen_file_path = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
-            "generated_jsonl",
+            self.answer_save_path,
             task_type,
             f"v{model.version.global_step}r{dist.get_rank()}.jsonl",
         )
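With the new answer_save_path field, the reward interface dumps generated answers under a user-configured directory instead of the removed LOG_ROOT tree. A small sketch of the resulting file layout (the directory value is hypothetical; the file-name pattern is taken from the hunks above):

import os

answer_save_path = "/storage/testing/experiments/answers"  # hypothetical value
task_type = "math"
global_step, rank = 10, 0

txt = os.path.join(answer_save_path, task_type, f"v{global_step}r{rank}.txt")
jsonl = os.path.join(answer_save_path, task_type, f"v{global_step}r{rank}.jsonl")
print(txt)    # .../answers/math/v10r0.txt
print(jsonl)  # .../answers/math/v10r0.jsonl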
@@ -17,7 +17,6 @@ import transformers
 from realhf.api.core import model_api
 from realhf.api.core.config import ModelName
 from realhf.base import constants, logging, topology
-from realhf.base.monitor import CUDATimeMarkType, cuda_tmark, cuda_tmarked
 from realhf.impl.model.comm.global_comm import NCCLProcessGroupInfo
 from realhf.impl.model.comm.param_realloc import (
     ReparallelizeReceiverStep,
@@ -470,13 +469,11 @@ class ReaLModel(nn.Module):
         pp_input_buf[:batch_length] = x.pp_input
         x.pp_input = pp_input_buf

-        tmark_type = CUDATimeMarkType.forward
-        with cuda_tmarked("fwd", tmark_type):
-            # Main forward calls.
-            if not self._offloaded:
-                x, ys = self.__forward(x, ys)
-            else:
-                x, ys = self.__overlapped_load_forward(x, ys)
+        # Main forward calls.
+        if not self._offloaded:
+            x, ys = self.__forward(x, ys)
+        else:
+            x, ys = self.__overlapped_load_forward(x, ys)

         # Resume from padding.
         if (
@@ -644,7 +641,6 @@ class ReaLModel(nn.Module):
         self._reparallelize_targets[(from_model_name, to_model_name)] = rtgt

     # FIXME: we can get topo given model name from constants
-    @cuda_tmark("param_realloc", CUDATimeMarkType.mem_layout)
     def build_reparallelized_layers_async(
         self,
         from_model_name: ModelName,

@@ -166,7 +166,7 @@ def build_leave_one_indices(
     )


-def _gather_logprobs(
+def gather_logprobs(
     logits: torch.Tensor,
     labels: torch.Tensor,
 ):
@@ -186,24 +186,6 @@ def _gather_logprobs(
     return log_probs_labels


-_gather_logprobs_compiled = None
-
-
-def gather_logprobs(
-    logits: torch.Tensor,
-    labels: torch.Tensor,
-):
-    from realhf.base import cluster
-
-    if cluster.spec.name == "wa180":
-        # torch.compile doesn't work on PPU
-        return _gather_logprobs(logits, labels)
-    global _gather_logprobs_compiled
-    if _gather_logprobs_compiled is None:
-        _gather_logprobs_compiled = torch.compile(_gather_logprobs)
-    return _gather_logprobs_compiled(logits, labels)
-
-
 def gather_packed_shifted_log_probs(
     logits: torch.FloatTensor,
     cu_seqlens: torch.Tensor,
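The lazily-compiled wrapper and its per-cluster special case are deleted; _gather_logprobs is simply renamed to gather_logprobs and runs eagerly. The function body sits outside this hunk, but the standard way to gather per-token log-probabilities looks like the following sketch (not necessarily the exact implementation):

import torch

def gather_logprobs_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [num_tokens, vocab_size]; labels: [num_tokens].
    log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    # Select the log-probability assigned to each label token.
    return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(5, 100)
labels = torch.randint(0, 100, (5,))
print(gather_logprobs_sketch(logits, labels).shape)  # torch.Size([5])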
@@ -5,9 +5,10 @@
 import dataclasses
 import enum
 import subprocess
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

-from realhf.base.cluster import spec as cluster_spec
+if TYPE_CHECKING:
+    from realhf.api.cli_args import BaseExperimentConfig


 class JobState(enum.Enum):
@@ -50,10 +51,11 @@ class JobInfo:

 class SchedulerClient:

-    def __init__(self, expr_name, trial_name):
-        self.expr_name = expr_name
-        self.trial_name = trial_name
-        self.run_name = f"{expr_name}_{trial_name}"
+    def __init__(self, args: "BaseExperimentConfig"):
+        self.args = args
+        self.expr_name = args.experiment_name
+        self.trial_name = args.trial_name
+        self.run_name = f"{self.expr_name}_{self.trial_name}"

     def submit(self, worker_type, cmd, **kwargs):
         """Submits a job to the scheduler. Raises exception if the job is
@@ -120,16 +122,10 @@ class SchedulerClient:
     raise NotImplementedError()


-def get_python3_path():
-    if cluster_spec.cluster_type == "ray":
-        return subprocess.check_output(["which", "python3"]).decode("utf-8").strip()
-    return "python3"
-
-
 def remote_worker_cmd(expr_name, trial_name, debug, worker_type):
     # requires information in scheduler package
     return (
-        f"{get_python3_path()} {'' if debug else '-O'} -m realhf.apps.remote worker -w {worker_type} "
+        f"python3 {'' if debug else '-O'} -m realhf.apps.remote worker -w {worker_type} "
         f"-e {expr_name} -f {trial_name} -i {{jobstep_id}} -g {{n_jobsteps}} -r {{worker_submission_index}} "
         f"-p {{wprocs_per_jobstep}} -j {{wprocs_in_job}} -o {{wproc_offset}}"
     )
@@ -137,7 +133,7 @@ def remote_worker_cmd(expr_name, trial_name, debug, worker_type):

 def setup_cmd(expr_name, trial_name, debug):
     bash_cmd = (  # f"pip3 install -e $REAL_PACKAGE_PATH --no-build-isolation && "
-        f"{get_python3_path()} {'' if debug else '-O'} -m realhf.apps.remote "
+        f"python3 {'' if debug else '-O'} -m realhf.apps.remote "
         f"reset_name_resolve -e {expr_name} -f {trial_name}"
     )
     # return f"bash -c \"{bash_cmd}\""
@@ -146,7 +142,7 @@ def setup_cmd(expr_name, trial_name, debug):

 def control_cmd(expr_name, trial_name, debug, ignore_worker_error, controller_type):
     bash_cmd = (  # f"pip3 install -e $REAL_PACKAGE_PATH --no-build-isolation && "
-        f"{get_python3_path()} {'' if debug else '-O'} -m realhf.apps.remote controller "
+        f"python3 {'' if debug else '-O'} -m realhf.apps.remote controller "
         f"-e {expr_name} -f {trial_name} "
         f"--{'ignore_worker_error' if ignore_worker_error else 'raise_worker_error'} "
         f"--type {controller_type}"
@@ -155,25 +151,23 @@ def control_cmd(expr_name, trial_name, debug, ignore_worker_error, controller_ty
     return bash_cmd


-def make(mode, expr_name, trial_name, **kwargs) -> SchedulerClient:
-    if mode == "slurm":
+def make(args: "BaseExperimentConfig", **kwargs) -> SchedulerClient:
+    if args.mode == "slurm":
         from realhf.scheduler.slurm.client import SlurmSchedulerClient

-        schedule_strategy = kwargs.get("schedule_strategy", "empty_first")
-        evaluator = kwargs.get("evaluator", None)
         job_group_id = kwargs.get("job_group_id", None)
         job_group_index = kwargs.get("job_group_index", None)
+        evaluator = kwargs.get("evaluator", None)
         return SlurmSchedulerClient(
-            expr_name,
-            trial_name,
-            schedule_strategy,
+            args,
+            args.schedule_strategy,
             evaluator,
             job_group_id,
             job_group_index,
         )
-    elif mode == "local":
+    elif args.mode == "local":
         from realhf.scheduler.local.client import LocalSchedulerClient

-        return LocalSchedulerClient(expr_name, trial_name)
+        return LocalSchedulerClient(args)
     else:
         raise NotImplementedError(f"Scheduler {mode} not found")
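make() now receives the whole experiment config instead of loose mode/expr_name/trial_name arguments, so the scheduler clients can read cluster settings from one object. A hedged sketch of the same dispatch pattern, with a stand-in dataclass for BaseExperimentConfig (field names taken from the hunks above):

from dataclasses import dataclass

@dataclass
class _Args:  # stand-in for BaseExperimentConfig; not the real class
    mode: str
    experiment_name: str
    trial_name: str
    schedule_strategy: str = "empty_first"

def describe_scheduler(args: _Args) -> str:
    # Mirrors make(): one args object carries everything the clients
    # previously took as separate positional arguments.
    if args.mode == "slurm":
        return f"SlurmSchedulerClient({args.experiment_name}, {args.schedule_strategy})"
    elif args.mode == "local":
        return f"LocalSchedulerClient({args.experiment_name}_{args.trial_name})"
    raise NotImplementedError(f"Scheduler {args.mode} not found")

print(describe_scheduler(_Args("local", "my-exp", "trial0")))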
@@ -6,13 +6,14 @@ import pathlib
 import re
 import subprocess
 import time
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 import swanlab
 import wandb

 import realhf.api.core.system_api as config_pkg
-from realhf.base import cluster, constants, logging
+from realhf.api.cli_args import BaseExperimentConfig
+from realhf.base import constants, logging

 logger = logging.getLogger("AutomaticEvaluator", "colored")

@@ -27,6 +28,7 @@ class EvaluationStepStatus(enum.Enum):

 @dataclasses.dataclass
 class EvaluationStep:
+    args: BaseExperimentConfig
     global_step: int
     status: EvaluationStepStatus
     start_time: Optional[float] = None
@@ -34,7 +36,7 @@ class EvaluationStep:
     process: Optional[subprocess.Popen] = None

     @staticmethod
-    def from_ckpt_dir(ckpt_dir):
+    def from_ckpt_dir(args, ckpt_dir):
         # NOTE: ckpt_dir should be absolute path
         if pathlib.Path(ckpt_dir).is_symlink():
             return None
@@ -44,13 +46,14 @@ class EvaluationStep:
             return None
         _, _, global_step = map(int, match.groups())
         return EvaluationStep(
+            args=args,
             global_step=global_step,
             status=EvaluationStepStatus.PENDING,
             ckpt_dir=ckpt_dir,
         )

     @staticmethod
-    def from_output_dir(output_dir):
+    def from_output_dir(args, output_dir):
         # NOTE: output_dir should be absolute path
         # Should only be called in recover.
         _dir = os.path.basename(output_dir)
@@ -59,15 +62,13 @@ class EvaluationStep:
             return None
         global_step = int(match.groups()[0])
         return EvaluationStep(
-            global_step=global_step, status=EvaluationStepStatus.LOGGED
+            args=args, global_step=global_step, status=EvaluationStepStatus.LOGGED
         )

     @property
     def output_dir(self):
         return os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_log_path(self.args),
             "eval_output",
             f"globalstep{self.global_step}",
         )
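constants.get_log_path(args) replaces the LOG_ROOT/experiment/trial join that used to be spelled out at every call site. Its body is not part of this diff; the following is only a sketch of what such a helper plausibly does, assuming the log root lives under cluster.fileroot (the slurm lock file further below is opened under {fileroot}/logs, which supports that assumption):

import os
from dataclasses import dataclass

@dataclass
class _Cluster:  # stand-in for the cluster section of the CLI args
    fileroot: str

@dataclass
class _Args:  # stand-in for BaseExperimentConfig
    cluster: _Cluster
    experiment_name: str
    trial_name: str

def get_log_path_sketch(args: _Args) -> str:
    # Assumed layout: <fileroot>/logs/<experiment>/<trial>. The real helper
    # in realhf.base.constants may add components (e.g. the user name).
    return os.path.join(args.cluster.fileroot, "logs", args.experiment_name, args.trial_name)

args = _Args(_Cluster("/storage/testing/experiments"), "my-exp", "trial0")
print(os.path.join(get_log_path_sketch(args), "eval_output", "globalstep10"))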
@@ -77,7 +78,7 @@ class EvaluationStep:
         cmd = (
             f"srun --mpi=pmi2 -J {slurm_job_name} --ntasks=1 --cpus-per-task=128 --gres=gpu:8 --mem-per-cpu=12G "
             f"singularity exec --no-home --nv --pid --writable-tmpfs --bind /storage:/storage "
-            f"{config.eval_job_image or cluster.spec.gpu_image} "
+            f"{config.eval_job_image or self.args.cluster.gpu_image} "
             f"bash ./evaluation/sh/install_deps_and_eval.sh {self.ckpt_dir} {self.output_dir} "
             f"{config.data_names} {config.max_gen_tokens} {config.prompt_type}"
         )
@@ -86,7 +87,7 @@ class EvaluationStep:
     def submit(self, config: config_pkg.AutomaticEvaluator):
         os.makedirs(self.output_dir, exist_ok=True)
         log_file = open(os.path.join(self.output_dir, "output.log"), "w")
-        if cluster.spec.cluster_type == "slurm":
+        if self.args.mode == "slurm":
             cmd = self.slurm_eval_cmd(config)
         else:
             raise NotImplementedError(
@@ -155,10 +156,12 @@ class AutomaticEvaluator:

     def __init__(
         self,
+        args: BaseExperimentConfig,
         config: config_pkg.AutomaticEvaluator,
         wandb_config: config_pkg.WandBConfig,
         swanlab_config: config_pkg.SwanlabConfig,
     ):
+        self.args = args
         self.__eval_steps: Dict[int, EvaluationStep] = {}
         self.__max_concurrent_jobs = config.max_concurrent_jobs
         self.__wandb_config = wandb_config
@@ -174,15 +177,13 @@ class AutomaticEvaluator:
         # Resubmiting or waiting for these jobs will probably result in
         # unexpected behaviors.
         output_parent = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_log_path(args),
             "eval_output",
         )
         if os.path.exists(output_parent):
             for output_dir in os.listdir(output_parent):
                 output_dir = os.path.join(output_parent, output_dir)
-                eval_step = EvaluationStep.from_output_dir(output_dir)
+                eval_step = EvaluationStep.from_output_dir(self.args, output_dir)
                 if eval_step:
                     self.__eval_steps[eval_step.global_step] = eval_step

@@ -197,12 +198,13 @@ class AutomaticEvaluator:
             )
         if self.__config.initial_checkpoint_path and 0 not in self.__eval_steps:
             self.__eval_steps[0] = EvaluationStep(
+                args=self.args,
                 global_step=0,
                 status=EvaluationStepStatus.PENDING,
                 ckpt_dir=self.__config.initial_checkpoint_path,
             )

-        if not cluster.spec.cluster_type == "slurm":
+        if not self.args.mode == "slurm":
             raise NotImplementedError(
                 "Currently only support automatic evaluation for slurm"
             )
@@ -224,9 +226,7 @@ class AutomaticEvaluator:
                 notes=self.__wandb_config.notes,
                 tags=self.__wandb_config.tags,
                 config=self.__wandb_config.config,
-                dir=os.path.join(
-                    constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
-                ),
+                dir=constants.get_log_path(self.args),
                 force=True,
                 id=f"{constants.experiment_name()}_{constants.trial_name()}_eval",
                 resume="allow",
@@ -270,15 +270,13 @@ class AutomaticEvaluator:
     def step(self):
         # Check whether a new evaluation step should be created
         ckpt_parent = os.path.join(
-            constants.MODEL_SAVE_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_save_path(self.args),
             "actor",
         )
         if os.path.exists(ckpt_parent):
             for ckpt_dir in os.listdir(ckpt_parent):
                 ckpt_dir = os.path.join(ckpt_parent, ckpt_dir)
-                eval_step = EvaluationStep.from_ckpt_dir(ckpt_dir)
+                eval_step = EvaluationStep.from_ckpt_dir(self.args, ckpt_dir)
                 if eval_step is None:
                     continue
                 if eval_step.global_step in self.__eval_steps:

@@ -14,7 +14,7 @@ import psutil

 import realhf.base.logging as logging
 from realhf.base import gpu_utils
-from realhf.base.constants import LOG_ROOT
+from realhf.base.constants import get_log_path
 from realhf.scheduler.client import (
     JobException,
     JobInfo,
@@ -75,14 +75,12 @@ class LocalSchedulerClient(SchedulerClient):

     def log_path_of(self, worker_type) -> str:
         return os.path.join(
-            LOG_ROOT,
-            self.expr_name,
-            self.trial_name,
+            get_log_path(self.args),
             f"{worker_type}-0",
         )

-    def __init__(self, expr_name, trial_name):
-        super().__init__(expr_name, trial_name)
+    def __init__(self, args):
+        super().__init__(args)
         self._jobs: Dict[str, subprocess.Popen] = {}
         self._running_worker_types = []

@@ -15,9 +15,7 @@ from typing import Dict, List, Literal, Optional, Tuple
 import colorama

 import realhf.base.logging as logging
-from realhf.base.cluster import spec as cluster_spec
-from realhf.base.constants import LOG_ROOT
-from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME
+from realhf.base.constants import get_log_path
 from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient
 from realhf.scheduler.evaluator import AutomaticEvaluator
 from realhf.scheduler.slurm.utils import (
@@ -82,14 +80,13 @@ class SlurmSchedulerClient(SchedulerClient):

     def __init__(
         self,
-        expr_name: str,
-        trial_name: str,
+        args,
         schedule_strategy: str,
         evaluator: Optional[AutomaticEvaluator],
         job_group_id: str,
         job_group_index: int,
     ):
-        super().__init__(expr_name, trial_name)
+        super().__init__(args)

         self.__schedule_strategy = schedule_strategy

@@ -124,11 +121,31 @@ class SlurmSchedulerClient(SchedulerClient):
         deadline: str = None,
         time_limit: str = None,
     ):
-        container_image = container_image or cluster_spec.cpu_image
-        container_mounts = container_mounts or cluster_spec.mount
+        container_image = container_image or self.args.cluster.cpu_image
+        container_mounts = container_mounts or self.args.cluster.mount
         # record launch information, do not submit to slurm until `wait()` is called
         # NOTE: fractional GPU requirement will be resolved automatically in `__post_init__` of SlurnLaunchInfo
+        log_path = os.path.join(
+            get_log_path(self.args),
+            f"{worker_type}-{self.__submission_counter[worker_type]}.log",
+        )
+        multiprog_path = os.path.join(
+            get_log_path(self.args),
+            "slurm",
+            "multiprog",
+            f"{worker_type}-{self.__submission_counter[worker_type]}.multiprog",
+        )
+        os.makedirs(os.path.dirname(multiprog_path), exist_ok=True)
+        hostfile_path = os.path.join(
+            get_log_path(self.args),
+            "slurm",
+            "hostfile",
+            f"{worker_type}-{self.__submission_counter[worker_type]}.hostfile",
+        )
+        os.makedirs(os.path.dirname(hostfile_path), exist_ok=True)
+
         launch_info = SlurmLaunchInfo(
+            args=self.args,
             worker_type=worker_type,
             wprocs_in_job=count,
             resource_requirement=SlurmResource(mem=mem, cpu=cpu, gpu=gpu),
@@ -149,6 +166,9 @@ class SlurmSchedulerClient(SchedulerClient):
             time_limit=time_limit,
             job_group_id=self.__job_group_id,
             job_group_index=self.__job_group_index,
+            log_path=log_path,
+            multiprog_path=multiprog_path,
+            hostfile_path=hostfile_path,
         )

         if (
@@ -177,9 +197,9 @@ class SlurmSchedulerClient(SchedulerClient):
             wproc_offset=self.__wprocs_counter[worker_type],
         )
         wrap_cmd = "singularity exec "
-        if cluster_spec.name == "na132":
+        if self.args.cluster.cluster_name == "na132":
             wrap_cmd += "--pid "
-        if cluster_spec.gpu_type == "tesla":
+        if self.args.cluster.gpu_type == "tesla":
             wrap_cmd += "--nv "
         wrap_cmd += "--no-home --writable-tmpfs "
         if len(launch_info.env_vars) > 0:
@@ -199,7 +219,9 @@ class SlurmSchedulerClient(SchedulerClient):
         start_time = time.monotonic()
         while True:
             try:
-                fp = open(LOCK_FILE_NAME, "w")
+                fp = open(
+                    f"{self.args.cluster.fileroot}/logs/slurm_scheduler.lock", "w"
+                )
                 fcntl.flock(fp, fcntl.LOCK_EX)
                 infos = list(self.__pending_jobs.values())
                 infos = allocate_resources(infos, strategy=self.__schedule_strategy)
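The submission loop above serializes resource allocation through an exclusive flock on a lock file under the cluster fileroot, replacing the removed SLURM_LOCK_FILE_NAME constant. A minimal self-contained sketch of the same pattern (the lock path here is illustrative):

import fcntl

def with_file_lock(lock_path: str, critical_section) -> None:
    # Exclusive advisory lock: concurrent schedulers sharing the same
    # fileroot take turns allocating resources, as in the loop above.
    fp = open(lock_path, "w")
    try:
        fcntl.flock(fp, fcntl.LOCK_EX)
        critical_section()
    finally:
        fcntl.flock(fp, fcntl.LOCK_UN)
        fp.close()

with_file_lock("/tmp/slurm_scheduler.lock", lambda: print("allocating..."))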
@@ -297,9 +319,7 @@ class SlurmSchedulerClient(SchedulerClient):
         threads = []
         stop_events = []

-        merged_log_path = os.path.join(
-            LOG_ROOT, self.expr_name, self.trial_name, "main.log"
-        )
+        merged_log_path = os.path.join(get_log_path(self.args), "main.log")

         for job_name, launch_info in self.__committed_jobs.items():
             stop_event = threading.Event()

@@ -20,10 +20,10 @@ from typing import Callable, Dict, List, Literal, Optional, Union

 import pandas as pd

-import realhf.base.cluster as cluster
 import realhf.base.logging as logging
 import realhf.version as version
-from realhf.base.constants import LOG_ROOT
+from realhf.api.cli_args import BaseExperimentConfig
+from realhf.base.constants import get_log_path
 from realhf.scheduler.client import JobException, JobInfo, JobState

 logger = logging.getLogger("scheduler.slurm.utils")
@@ -190,6 +190,7 @@ class SlurmLaunchInfo:
         multiprog_content (str, optional): The content of the multiprog file.
     """

+    args: BaseExperimentConfig
     run_name: str
     exper_name: str
     trial_name: str
@@ -199,6 +200,10 @@ class SlurmLaunchInfo:
     job_group_id: str
     job_group_index: str

+    log_path: str
+    multiprog_path: str
+    hostfile_path: str
+
     resource_requirement: SlurmResource
     cmd: str
     container_image: str
@@ -252,41 +257,6 @@ class SlurmLaunchInfo:
         else:
             return None

-    @property
-    def log_path(self) -> str:
-        return os.path.join(
-            LOG_ROOT,
-            self.exper_name,
-            self.trial_name,
-            f"{self.worker_type}-{self.worker_submission_idx}.log",
-        )
-
-    @property
-    def multiprog_path(self) -> str:
-        path = os.path.join(
-            LOG_ROOT,
-            self.exper_name,
-            self.trial_name,
-            "slurm",
-            "multiprog",
-            f"{self.worker_type}-{self.worker_submission_idx}.multiprog",
-        )
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        return path
-
-    @property
-    def hostfile_path(self) -> str:
-        path = os.path.join(
-            LOG_ROOT,
-            self.exper_name,
-            self.trial_name,
-            "slurm",
-            "hostfile",
-            f"{self.worker_type}-{self.worker_submission_idx}.hostfile",
-        )
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        return path
-
     def show_log(self):
         try:
             terminal_columns = os.get_terminal_size().columns
@@ -364,14 +334,14 @@ class SlurmLaunchInfo:
         # head
         gres_line = ""
         if gpu >= 1:
-            assert (gpu * ntasks) % cluster.spec.n_gpus_per_node == 0
+            assert (gpu * ntasks) % self.args.cluster.n_gpus_per_node == 0
             # In current slurm cluster setup, we can only use "--gres" to
             # allocate PPUs per node. There are no options to allocate customized
             # gres per tasks.
-            if cluster.spec.gpu_type == "ppu":
-                gres_line = f"--gres=ppu:{cluster.spec.n_gpus_per_node}"
+            if self.args.cluster.gpu_type == "ppu":
+                gres_line = f"--gres=ppu:{self.args.cluster.n_gpus_per_node}"
             else:
-                gres_line = f"--gres=gpu:{cluster.spec.n_gpus_per_node}"
+                gres_line = f"--gres=gpu:{self.args.cluster.n_gpus_per_node}"

         srun_env = os.environ.copy()
         job_metadata = {
@@ -391,7 +361,7 @@ class SlurmLaunchInfo:
             f"#SBATCH --output={self.log_path}",
             "#SBATCH --open-mode=append",
             f"#SBATCH --ntasks={ntasks}",
-            f"#SBATCH {gres_line}",
+            f"#SBATCH {gres_line}" if gpu >= 1 else "",
             f"#SBATCH --cpus-per-task={cpu}",
             f"#SBATCH --mem-per-cpu={mem // max(1, cpu)}M",
             "#SBATCH --distribution=arbitrary" if self.hostfile else "",
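With gres_line now emitted only when gpu >= 1, an empty string can end up in the #SBATCH list, as already happens for the --distribution entry. The usual trick, which the script builder presumably applies when joining (an assumption, the joining code is not in this hunk), is to drop falsy entries:

lines = [
    "#!/bin/bash",
    "#SBATCH --ntasks=8",
    "#SBATCH --gres=gpu:8" if True else "",   # mirrors the conditional above
    "#SBATCH --distribution=arbitrary" if False else "",
]
script = "\n".join(l for l in lines if l)  # drop the empty placeholders
print(script)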
@@ -722,6 +692,9 @@ def allocate_resources(
         infos, key=lambda x: x.n_jobsteps * x.resource_requirement, reverse=True
     )
     prioritized_hosts = set()
+    if len(infos) == 0:
+        return infos
+    cluster_config = infos[0].args.cluster
     for info_idx, info in enumerate(infos):
         valid_hostnames = available_hostnames(
             nodelist=info.nodelist,
@@ -764,12 +737,12 @@ allocate_resources
             gpu_per_task = info.resource_requirement.gpu
             if gpu_per_task > 0:
                 assert (
-                    task_left * gpu_per_task % cluster.spec.n_gpus_per_node == 0
+                    task_left * gpu_per_task % cluster_config.n_gpus_per_node == 0
                 ), (task_left, gpu_per_task)
                 assert (
-                    cluster.spec.n_gpus_per_node % gpu_per_task == 0
+                    cluster_config.n_gpus_per_node % gpu_per_task == 0
                 ), gpu_per_task
-                batched_ntasks = int(cluster.spec.n_gpus_per_node // gpu_per_task)
+                batched_ntasks = int(cluster_config.n_gpus_per_node // gpu_per_task)
                 batched_requirement = batched_ntasks * info.resource_requirement
                 try:
                     resource = resource - batched_requirement
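The new early return also guards the cluster_config = infos[0].args.cluster line against an empty job list, where indexing would raise IndexError. The asserts then require that tasks tile whole nodes; a quick worked example of the batching arithmetic (the numbers are illustrative):

n_gpus_per_node = 8      # cluster_config.n_gpus_per_node
gpu_per_task = 2         # info.resource_requirement.gpu
task_left = 12           # remaining job steps to place

assert task_left * gpu_per_task % n_gpus_per_node == 0
assert n_gpus_per_node % gpu_per_task == 0
batched_ntasks = n_gpus_per_node // gpu_per_task
print(batched_ntasks)  # 4: tasks are placed four at a time, one full node each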
@@ -787,10 +760,10 @@ allocate_resources
                 allocated[hostname] = tmp - task_left
                 all_resources[hostname] = resource
             if task_left > 0:
-                if cluster.spec.gpu_type == "ppu" and info.resource_requirement.gpu > 0:
+                if cluster_config.gpu_type == "ppu" and info.resource_requirement.gpu > 0:
                     logger.warning(
                         "For PPU resources, we can only allocate tasks in the "
-                        f"granularity of nodes ({cluster.spec.n_gpus_per_node} PPUs)"
+                        f"granularity of nodes ({cluster_config.n_gpus_per_node} PPUs)"
                     )
                 logger.warning(
                     f'Unable to allocate {info.n_jobsteps} Jobs with name "{info.slurm_name}". '

@@ -61,7 +61,7 @@ def run_worker(
     )
     worker = worker_class(server=server)
     try:
-        if worker_type in ["rollout_worker", "master_worker", "gserver_manager"]:
+        if worker_type in ["rollout_worker", "master_worker"]:
             asyncio.run(worker.run_async())
         else:
             worker.run()
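gserver_manager drops out of the async list because GserverManager becomes a plain Worker further below (its base class and _poll method change in a later hunk). The dispatch pattern itself is simple; a self-contained sketch with stand-in worker classes:

import asyncio

class SyncWorker:
    def run(self):
        print("sync poll loop")

class AsyncWorker:
    async def run_async(self):
        print("async poll loop")

def run_worker_sketch(worker, is_async: bool):
    # Mirrors the branch above: async workers enter an event loop,
    # everything else uses the blocking run().
    if is_async:
        asyncio.run(worker.run_async())
    else:
        worker.run()

run_worker_sketch(SyncWorker(), is_async=False)
run_worker_sketch(AsyncWorker(), is_async=True)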
@ -27,7 +27,6 @@ from omegaconf import OmegaConf
|
||||||
|
|
||||||
import realhf.api.core.system_api as system_api
|
import realhf.api.core.system_api as system_api
|
||||||
from realhf.base import constants, gpu_utils, logging, name_resolve, names, pkg_version
|
from realhf.base import constants, gpu_utils, logging, name_resolve, names, pkg_version
|
||||||
from realhf.base.cluster import spec as cluster_spec
|
|
||||||
from realhf.system import WORKER_TYPES, load_worker, worker_base, worker_control
|
from realhf.system import WORKER_TYPES, load_worker, worker_base, worker_control
|
||||||
from realhf.system.worker_base import WorkerServerStatus as Wss
|
from realhf.system.worker_base import WorkerServerStatus as Wss
|
||||||
|
|
||||||
|
@ -247,9 +246,7 @@ class Controller:
|
||||||
|
|
||||||
# If a log exists, find the last failed setup and run it.
|
# If a log exists, find the last failed setup and run it.
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
prev_logfile = os.path.join(
|
prev_logfile = os.path.join(constants.get_log_path(experiment), "ctl-0")
|
||||||
constants.LOG_ROOT, self.experiment_name, self.trial_name, "ctl-0"
|
|
||||||
)
|
|
||||||
if os.path.exists(prev_logfile):
|
if os.path.exists(prev_logfile):
|
||||||
with open(prev_logfile, "r") as f:
|
with open(prev_logfile, "r") as f:
|
||||||
for l in f.readlines():
|
for l in f.readlines():
|
||||||
|
@ -670,6 +667,7 @@ class RayController:
|
||||||
]
|
]
|
||||||
|
|
||||||
env_vars = constants.get_env_vars(
|
env_vars = constants.get_env_vars(
|
||||||
|
experiment,
|
||||||
REAL_MODE=os.environ.get("REAL_MODE", ""),
|
REAL_MODE=os.environ.get("REAL_MODE", ""),
|
||||||
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
||||||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||||
|
|
|
@ -24,6 +24,7 @@ blogger = logging.getLogger("benchmark")
|
||||||
class FunctionExecutor:
|
class FunctionExecutor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
args,
|
||||||
rpcs: List[MFCDef],
|
rpcs: List[MFCDef],
|
||||||
msid2mwid: Dict[ModelShardID, int],
|
msid2mwid: Dict[ModelShardID, int],
|
||||||
stream: NameResolvingRequestClient,
|
stream: NameResolvingRequestClient,
|
||||||
|
@ -35,6 +36,8 @@ class FunctionExecutor:
|
||||||
shuffle_dataset: bool,
|
shuffle_dataset: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
self.args = args
|
||||||
|
|
||||||
self.func_calls: Dict[str, ModelFunctionCall] = {}
|
self.func_calls: Dict[str, ModelFunctionCall] = {}
|
||||||
self.ctrl = ctrl
|
self.ctrl = ctrl
|
||||||
|
|
||||||
|
@ -42,7 +45,9 @@ class FunctionExecutor:
|
||||||
self.msid2mwid = msid2mwid
|
self.msid2mwid = msid2mwid
|
||||||
|
|
||||||
self.storage_tracker = GlobalStorageTracker(self.n_model_workers)
|
self.storage_tracker = GlobalStorageTracker(self.n_model_workers)
|
||||||
self.redistrib_planner = RedistribPlanner(self.storage_tracker)
|
self.redistrib_planner = RedistribPlanner(
|
||||||
|
self.args.cluster, self.storage_tracker
|
||||||
|
)
|
||||||
|
|
||||||
self.rpcs = rpcs
|
self.rpcs = rpcs
|
||||||
self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0]
|
self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0]
|
||||||
|
@ -51,6 +56,7 @@ class FunctionExecutor:
|
||||||
# Create model function calls.
|
# Create model function calls.
|
||||||
for rpc in self.rpcs:
|
for rpc in self.rpcs:
|
||||||
func_call = ModelFunctionCall(
|
func_call = ModelFunctionCall(
|
||||||
|
args=self.args,
|
||||||
rpc=rpc,
|
rpc=rpc,
|
||||||
src_rpc=self.src_rpc,
|
src_rpc=self.src_rpc,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
|
|
@ -20,7 +20,6 @@ from realhf.base import (
|
||||||
pkg_version,
|
pkg_version,
|
||||||
seeding,
|
seeding,
|
||||||
)
|
)
|
||||||
from realhf.base.cluster import spec as cluster_spec
|
|
||||||
from realhf.system.worker_base import PollResult, Worker
|
from realhf.system.worker_base import PollResult, Worker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -139,7 +138,7 @@ class GenerationServer(Worker):
|
||||||
map(str, range(gpu_utils.gpu_count()))
|
map(str, range(gpu_utils.gpu_count()))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
|
servers_per_node = self.args.cluster.n_gpus_per_node // self.config.tp_size
|
||||||
idx_on_this_node = self.worker_index % servers_per_node
|
idx_on_this_node = self.worker_index % servers_per_node
|
||||||
self.base_gpu_id = idx_on_this_node * self.config.tp_size
|
self.base_gpu_id = idx_on_this_node * self.config.tp_size
|
||||||
|
|
||||||
|
@ -159,7 +158,7 @@ class GenerationServer(Worker):
|
||||||
# NOTE: Ports returned by `find_multiple_free_ports` are unique,
|
# NOTE: Ports returned by `find_multiple_free_ports` are unique,
|
||||||
# but SGLang servers still encounter conflicts.
|
# but SGLang servers still encounter conflicts.
|
||||||
# Use a clearance period to hack over this issue.
|
# Use a clearance period to hack over this issue.
|
||||||
servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
|
servers_per_node = self.args.cluster.n_gpus_per_node // self.config.tp_size
|
||||||
idx_on_this_node = self.worker_index % servers_per_node
|
idx_on_this_node = self.worker_index % servers_per_node
|
||||||
time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node)
|
time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node)
|
||||||
|
|
||||||
|
@ -169,6 +168,7 @@ class GenerationServer(Worker):
|
||||||
high=60000,
|
high=60000,
|
||||||
experiment_name=self.experiment_name,
|
experiment_name=self.experiment_name,
|
||||||
trial_name=self.trial_name,
|
trial_name=self.trial_name,
|
||||||
|
lockfile_root=os.path.join(constants.get_cache_path(self.args), "ports"),
|
||||||
)
|
)
|
||||||
server_port = ports[0]
|
server_port = ports[0]
|
||||||
nccl_port = ports[1]
|
nccl_port = ports[1]
|
||||||
|
|
|
@ -29,7 +29,7 @@ class AllocateRolloutInput:
|
||||||
qid: str
|
qid: str
|
||||||
|
|
||||||
|
|
||||||
class GserverManager(AsyncWorker):
|
class GserverManager(Worker):
|
||||||
"""This worker has the following functionalities:
|
"""This worker has the following functionalities:
|
||||||
1. As a router, it schedules generation requests and returns the
|
1. As a router, it schedules generation requests and returns the
|
||||||
best server urls to clients for submitting generation requests.
|
best server urls to clients for submitting generation requests.
|
||||||
|
@ -71,7 +71,7 @@ class GserverManager(AsyncWorker):
|
||||||
self.server_urls = []
|
self.server_urls = []
|
||||||
|
|
||||||
# recover info
|
# recover info
|
||||||
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
self.__recover_run, self.__recover_info = recover.load_recover_info(self.args)
|
||||||
if self.__recover_run:
|
if self.__recover_run:
|
||||||
# update weights will be automatically triggered upon the first schedule_request
|
# update weights will be automatically triggered upon the first schedule_request
|
||||||
# self._last_param_realloc_step will also be updated
|
# self._last_param_realloc_step will also be updated
|
||||||
|
@ -110,11 +110,7 @@ class GserverManager(AsyncWorker):
|
||||||
epoch = self.__recover_info.last_step_info.epoch + 1
|
epoch = self.__recover_info.last_step_info.epoch + 1
|
||||||
epochstep = self.__recover_info.last_step_info.epoch_step + 1
|
epochstep = self.__recover_info.last_step_info.epoch_step + 1
|
||||||
globalstep = self.__recover_info.last_step_info.global_step + 1
|
globalstep = self.__recover_info.last_step_info.global_step + 1
|
||||||
save_root = os.path.join(
|
save_root = constants.get_save_path(self.args)
|
||||||
constants.MODEL_SAVE_ROOT,
|
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
)
|
|
||||||
role_path = os.path.join(save_root, role)
|
role_path = os.path.join(save_root, role)
|
||||||
if not os.path.exists(role_path):
|
if not os.path.exists(role_path):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -150,9 +146,7 @@ class GserverManager(AsyncWorker):
|
||||||
self._loaded_recover_weights = True
|
self._loaded_recover_weights = True
|
||||||
else:
|
else:
|
||||||
realloc_dir = os.path.join(
|
realloc_dir = os.path.join(
|
||||||
constants.PARAM_REALLOC_PATH,
|
constants.get_param_realloc_path(self.args),
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
self.model_name.role,
|
self.model_name.role,
|
||||||
str(realloc_version),
|
str(realloc_version),
|
||||||
)
|
)
|
||||||
|
@ -213,7 +207,7 @@ class GserverManager(AsyncWorker):
|
||||||
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
|
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
|
||||||
return self.server_urls.index(url)
|
return self.server_urls.index(url)
|
||||||
|
|
||||||
async def _poll_async(self):
|
def _poll(self):
|
||||||
if not self.thread:
|
if not self.thread:
|
||||||
# Find addresses of generation servers
|
# Find addresses of generation servers
|
||||||
self.server_urls = self._discover_servers(self.config.n_servers)
|
self.server_urls = self._discover_servers(self.config.n_servers)
|
||||||
|
@ -292,9 +286,7 @@ class GserverManager(AsyncWorker):
|
||||||
|
|
||||||
# clear old weights
|
# clear old weights
|
||||||
realloc_root = os.path.join(
|
realloc_root = os.path.join(
|
||||||
constants.PARAM_REALLOC_PATH,
|
constants.get_param_realloc_path(self.args),
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
self.model_name.role,
|
self.model_name.role,
|
||||||
)
|
)
|
||||||
if os.path.exists(realloc_root):
|
if os.path.exists(realloc_root):
|
||||||
|
@ -483,6 +475,7 @@ class GserverManager(AsyncWorker):
|
||||||
port = network.find_free_port(
|
port = network.find_free_port(
|
||||||
experiment_name=self.experiment_name,
|
experiment_name=self.experiment_name,
|
||||||
trial_name=self.trial_name,
|
trial_name=self.trial_name,
|
||||||
|
lockfile_root=os.path.join(constants.get_cache_path(self.args), "ports"),
|
||||||
)
|
)
|
||||||
self.manager_addr = f"{network.gethostip()}:{port}"
|
self.manager_addr = f"{network.gethostip()}:{port}"
|
||||||
|
|
||||||
|
|
|
@ -97,15 +97,8 @@ class MasterWorker(worker_base.AsyncWorker):
|
||||||
freq_sec=config.exp_ctrl.eval_freq_secs,
|
freq_sec=config.exp_ctrl.eval_freq_secs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.MODEL_SAVE_ROOT = os.path.join(
|
|
||||||
constants.MODEL_SAVE_ROOT,
|
|
||||||
config.worker_info.experiment_name,
|
|
||||||
config.worker_info.trial_name,
|
|
||||||
)
|
|
||||||
os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True)
|
|
||||||
|
|
||||||
self.__initialized = False
|
self.__initialized = False
|
||||||
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
self.__recover_run, self.__recover_info = recover.load_recover_info(self.args)
|
||||||
if self.__recover_info is not None:
|
if self.__recover_info is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
|
f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
|
||||||
|
@ -305,9 +298,7 @@ class MasterWorker(worker_base.AsyncWorker):
|
||||||
notes=self.wandb_config.notes,
|
notes=self.wandb_config.notes,
|
||||||
tags=self.wandb_config.tags,
|
tags=self.wandb_config.tags,
|
||||||
config=self.wandb_config.config,
|
config=self.wandb_config.config,
|
||||||
dir=os.path.join(
|
dir=constants.get_log_path(self.args),
|
||||||
constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
|
|
||||||
),
|
|
||||||
force=True,
|
force=True,
|
||||||
id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
|
id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
|
||||||
resume="allow",
|
resume="allow",
|
||||||
|
@ -355,6 +346,7 @@ class MasterWorker(worker_base.AsyncWorker):
|
||||||
# Create coroutines for model RPCs.
|
# Create coroutines for model RPCs.
|
||||||
logger.debug(f"Creating asyncio coroutines...")
|
logger.debug(f"Creating asyncio coroutines...")
|
||||||
self.func_executor = FunctionExecutor(
|
self.func_executor = FunctionExecutor(
|
||||||
|
args=self.args,
|
||||||
rpcs=self.__model_rpcs,
|
rpcs=self.__model_rpcs,
|
||||||
msid2mwid=self.config.msid2mwid,
|
msid2mwid=self.config.msid2mwid,
|
||||||
stream=self.__stream,
|
stream=self.__stream,
|
||||||
|
@ -599,7 +591,7 @@ class MasterWorker(worker_base.AsyncWorker):
|
||||||
hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
|
hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
|
||||||
)
|
)
|
||||||
|
|
||||||
recover.dump_recover_info(recover_info)
|
recover.dump_recover_info(self.args, recover_info)
|
||||||
logger.info("Dumped recover info to file.")
|
logger.info("Dumped recover info to file.")
|
||||||
logger.info(f"Will recover from: {recover_info.recover_start}")
|
logger.info(f"Will recover from: {recover_info.recover_start}")
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -56,6 +56,7 @@ class RPCCorountineControl:
|
||||||
class ModelFunctionCall:
|
class ModelFunctionCall:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
args,
|
||||||
rpc: dfg.MFCDef,
|
rpc: dfg.MFCDef,
|
||||||
src_rpc: dfg.MFCDef,
|
src_rpc: dfg.MFCDef,
|
||||||
stream: request_reply_stream.NameResolvingRequestClient,
|
stream: request_reply_stream.NameResolvingRequestClient,
|
||||||
|
@ -68,6 +69,8 @@ class ModelFunctionCall:
|
||||||
summary_writer: SummaryWriter | None,
|
summary_writer: SummaryWriter | None,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
self.args = args
|
||||||
|
|
||||||
self.rpc = rpc
|
self.rpc = rpc
|
||||||
self.src_rpc = src_rpc
|
self.src_rpc = src_rpc
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
@ -82,12 +85,6 @@ class ModelFunctionCall:
|
||||||
for msid, mwid in msid2mwid.items():
|
for msid, mwid in msid2mwid.items():
|
||||||
self.mwid2msids[mwid].append(msid)
|
self.mwid2msids[mwid].append(msid)
|
||||||
|
|
||||||
self.model_save_root = os.path.join(
|
|
||||||
constants.MODEL_SAVE_ROOT,
|
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.rpc_ctrl = ctrl
|
self.rpc_ctrl = ctrl
|
||||||
self.buffers = buffers
|
self.buffers = buffers
|
||||||
self.redistrib_planner = redistrib_planner
|
self.redistrib_planner = redistrib_planner
|
||||||
|
@ -221,7 +218,7 @@ class ModelFunctionCall:
|
||||||
for p in payloads.values():
|
for p in payloads.values():
|
||||||
p.post_hooks.append("save")
|
p.post_hooks.append("save")
|
||||||
save_dir = os.path.join(
|
save_dir = os.path.join(
|
||||||
self.model_save_root,
|
constants.get_log_path(self.args),
|
||||||
rpc.model_name.role,
|
rpc.model_name.role,
|
||||||
f"epoch{ctrl.step_info.epoch + 1}"
|
f"epoch{ctrl.step_info.epoch + 1}"
|
||||||
f"epochstep{ctrl.step_info.epoch_step + 1}"
|
f"epochstep{ctrl.step_info.epoch_step + 1}"
|
||||||
|
|
|
@ -43,12 +43,6 @@ from realhf.base import (
|
||||||
timeutil,
|
timeutil,
|
||||||
topology,
|
topology,
|
||||||
)
|
)
|
||||||
from realhf.base.monitor import (
|
|
||||||
CUDATimeMarkType,
|
|
||||||
cuda_tmark,
|
|
||||||
cuda_tmarked,
|
|
||||||
dump_tmark_db,
|
|
||||||
)
|
|
||||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||||
from realhf.impl.model.utils import cuda_graph
|
from realhf.impl.model.utils import cuda_graph
|
||||||
from realhf.system import request_reply_stream, worker_base
|
from realhf.system import request_reply_stream, worker_base
|
||||||
|
@ -136,7 +130,7 @@ class ModelWorker(worker_base.Worker):
|
||||||
r = self.config.worker_info
|
r = self.config.worker_info
|
||||||
|
|
||||||
# recover info
|
# recover info
|
||||||
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
self.__recover_run, self.__recover_info = recover.load_recover_info(self.args)
|
||||||
|
|
||||||
# Whether to enable profiling is controlled by the following environment variables.
|
# Whether to enable profiling is controlled by the following environment variables.
|
||||||
self.__enable_profiler = os.getenv("REAL_DUMP_TRACE", "0") == "1"
|
self.__enable_profiler = os.getenv("REAL_DUMP_TRACE", "0") == "1"
|
||||||
|
@ -174,11 +168,7 @@ class ModelWorker(worker_base.Worker):
|
||||||
epoch = self.__recover_info.last_step_info.epoch + 1
|
epoch = self.__recover_info.last_step_info.epoch + 1
|
||||||
epochstep = self.__recover_info.last_step_info.epoch_step + 1
|
epochstep = self.__recover_info.last_step_info.epoch_step + 1
|
||||||
globalstep = self.__recover_info.last_step_info.global_step + 1
|
globalstep = self.__recover_info.last_step_info.global_step + 1
|
||||||
save_root = os.path.join(
|
save_root = constants.get_save_path(self.args)
|
||||||
constants.MODEL_SAVE_ROOT,
|
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
)
|
|
||||||
if epoch > 0:
|
if epoch > 0:
|
||||||
role_path = os.path.join(save_root, role)
|
role_path = os.path.join(save_root, role)
|
||||||
if os.path.exists(role_path):
|
if os.path.exists(role_path):
|
||||||
|
@ -251,6 +241,7 @@ class ModelWorker(worker_base.Worker):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.__pg_info = global_comm.setup_global_comm(
|
self.__pg_info = global_comm.setup_global_comm(
|
||||||
|
args=self.args,
|
||||||
expr_name=self.__experiment_name,
|
expr_name=self.__experiment_name,
|
||||||
trial_name=self.__trial_name,
|
trial_name=self.__trial_name,
|
||||||
worker_index=self.__worker_index,
|
worker_index=self.__worker_index,
|
||||||
|
@ -312,13 +303,6 @@ class ModelWorker(worker_base.Worker):
|
||||||
self.__dataset_dp_rank,
|
self.__dataset_dp_rank,
|
||||||
self.__dataset_dp_size,
|
self.__dataset_dp_size,
|
||||||
self.config.tokenizer_name_or_path,
|
self.config.tokenizer_name_or_path,
|
||||||
self.config.worker_info.experiment_name,
|
|
||||||
self.config.worker_info.trial_name,
|
|
||||||
cache_root=(
|
|
||||||
None
|
|
||||||
if not self.config.use_dataset_cache
|
|
||||||
else self.config.dataset_cahce_root
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for d in self.config.datasets
|
for d in self.config.datasets
|
||||||
]
|
]
|
||||||
|
@ -388,9 +372,7 @@ class ModelWorker(worker_base.Worker):
|
||||||
and hasattr(d, "filter")
|
and hasattr(d, "filter")
|
||||||
):
|
):
|
||||||
dataset_indices_path = os.path.join(
|
dataset_indices_path = os.path.join(
|
||||||
constants.MODEL_SAVE_ROOT,
|
constants.get_save_path(self.args),
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
"dataset_indices",
|
"dataset_indices",
|
||||||
f"{self._dp_rank}_{i}.npy",
|
f"{self._dp_rank}_{i}.npy",
|
||||||
)
|
)
|
||||||
|
@ -438,13 +420,6 @@ class ModelWorker(worker_base.Worker):
|
||||||
s.id.dp_rank,
|
s.id.dp_rank,
|
||||||
s.id.topo.get_dim("data"),
|
s.id.topo.get_dim("data"),
|
||||||
self.__models[s.id.model_name].tokenizer,
|
self.__models[s.id.model_name].tokenizer,
|
||||||
self.config.worker_info.experiment_name,
|
|
||||||
self.config.worker_info.trial_name,
|
|
||||||
cache_root=(
|
|
||||||
None
|
|
||||||
if not self.config.use_dataset_cache
|
|
||||||
else self.config.dataset_cahce_root
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
eval_dataloader = torch.utils.data.DataLoader(
|
eval_dataloader = torch.utils.data.DataLoader(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
|
@ -497,16 +472,15 @@ class ModelWorker(worker_base.Worker):
|
||||||
self.__param_realloc(hook_data)
|
self.__param_realloc(hook_data)
|
||||||
elif hook == "offload":
|
elif hook == "offload":
|
||||||
# NOTE: Profiling (or cuda synchronization) will cause an overhead ~0.5s.
|
# NOTE: Profiling (or cuda synchronization) will cause an overhead ~0.5s.
|
||||||
with cuda_tmarked("offload", CUDATimeMarkType.mem_layout):
|
m = self.__unwrapped_models[hook_data["model_name"]]
|
||||||
m = self.__unwrapped_models[hook_data["model_name"]]
|
if not isinstance(m, ReaLModel):
|
||||||
if not isinstance(m, ReaLModel):
|
logger.warning(
|
||||||
logger.warning(
|
f"Model {hook_data['model_name']} (type={type(m)}) is not a ReaLModel, "
|
||||||
f"Model {hook_data['model_name']} (type={type(m)}) is not a ReaLModel, "
|
f"so it can't use offload."
|
||||||
f"so it can't use offload."
|
)
|
||||||
)
|
return
|
||||||
return
|
if not m._offloaded:
|
||||||
if not m._offloaded:
|
m.async_offload()
|
||||||
m.async_offload()
|
|
||||||
elif hook == "save":
|
elif hook == "save":
|
||||||
self.__save_model(hook_data)
|
self.__save_model(hook_data)
|
||||||
elif hook == "evaluate":
|
elif hook == "evaluate":
|
||||||
|
@ -602,15 +576,11 @@ class ModelWorker(worker_base.Worker):
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Upon the first fetch request, filter dataset and create dataloader.
|
# Upon the first fetch request, filter dataset and create dataloader.
|
||||||
eval_scores_path = os.path.join(
|
eval_scores_path = os.path.join(
|
||||||
constants.MODEL_SAVE_ROOT,
|
constants.get_save_path(self.args),
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
"dataset_eval_scores.json",
|
"dataset_eval_scores.json",
|
||||||
)
|
)
|
||||||
dataset_indices_path = os.path.join(
|
dataset_indices_path = os.path.join(
|
||||||
constants.MODEL_SAVE_ROOT,
|
constants.get_save_path(self.args),
|
||||||
constants.experiment_name(),
|
|
||||||
constants.trial_name(),
|
|
||||||
"dataset_indices",
|
"dataset_indices",
|
||||||
f"{dp_rank}_{dataset_id}.npy",
|
f"{dp_rank}_{dataset_id}.npy",
|
||||||
)
|
)
|
||||||
|
@ -668,7 +638,7 @@ class ModelWorker(worker_base.Worker):
|
||||||
continue
|
continue
|
||||||
if self.data_manager.has_data(x.ids[0]):
|
if self.data_manager.has_data(x.ids[0]):
|
||||||
continue
|
continue
|
||||||
data_loaded.append(x)
|
data_loaded.append(x.cpu())
|
||||||
self.data_manager.store(x)
|
self.data_manager.store(x)
|
||||||
assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded)
|
assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded)
|
||||||
|
|
||||||
|
@@ -700,25 +670,23 @@ class ModelWorker(worker_base.Worker):
                 "dataset_size": self.dataset_size,
             }
         elif request.handle_name == "clear_data_cache":
-            with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
-                ids = request.data
-                self.data_manager.remove(ids)
-                gc.collect()
-                if (
-                    self.config.cuda_cache_cleanliness
-                    and self.__clear_cache_frequency.check()
-                ):
-                    st = time.monotonic()
-                    self._clear_memory(force=True)
-                    et = time.monotonic()
-                    blogger.debug(
-                        f"Model worker {self.__worker_index} cleared cache in {et-st:.4f}s. "
-                    )
+            ids = request.data
+            self.data_manager.remove(ids)
+            gc.collect()
+            if (
+                self.config.cuda_cache_cleanliness
+                and self.__clear_cache_frequency.check()
+            ):
+                st = time.monotonic()
+                self._clear_memory(force=True)
+                et = time.monotonic()
+                blogger.debug(
+                    f"Model worker {self.__worker_index} cleared cache in {et-st:.4f}s. "
+                )
             logger.debug(
-                "Get clear_data_cache, dump cuda tmark. "
+                "Get clear_data_cache. "
                 f"Remaining data in local storage: {self.data_manager.storage_size()}. "
             )
-            dump_tmark_db(self.__worker_index)
             res = request_reply_stream.NoResponse()
             self.__reply_queue.put_nowait((request, res))
             self.__request_sample_size[request.request_id] = 1
@@ -820,9 +788,7 @@ class ModelWorker(worker_base.Worker):
         tik = time.perf_counter()
         global_step = self.__models[model_name].version.global_step
         realloc_dir = os.path.join(
-            constants.PARAM_REALLOC_PATH,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_param_realloc_path(self.args),
             model_name.role,
             str(global_step),
         )
@@ -852,9 +818,7 @@ class ModelWorker(worker_base.Worker):

     def _get_setup_logdir(self, name):
         subdir = os.path.join(
-            constants.LOG_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_log_path(self.args),
             name,
             f"setup{self._setup_counter}",
         )
@@ -990,9 +954,7 @@ class ModelWorker(worker_base.Worker):
             raise NotImplementedError(f"Unknown MFC type: {request.handle_name}.")

         eval_scores_path = os.path.join(
-            constants.MODEL_SAVE_ROOT,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_save_path(self.args),
             "dataset_eval_scores.json",
         )
         eval_scores = {}
@@ -1061,7 +1023,6 @@ class ModelWorker(worker_base.Worker):
             dist.barrier(group=constants.cpu_parallelism_group())
         return res

-    @cuda_tmark("data_transfer", CUDATimeMarkType.comm)
     def __data_transfer_among_workers(self, hook_data: Dict[str, Any]):
         meta_sample = hook_data["meta_sample"]

@@ -1130,9 +1091,7 @@ class ModelWorker(worker_base.Worker):
         global_step = int(global_step.item())

         realloc_dir = os.path.join(
-            constants.PARAM_REALLOC_PATH,
-            constants.experiment_name(),
-            constants.trial_name(),
+            constants.get_param_realloc_path(self.args),
             from_model_name.role,
             str(global_step),
         )
@@ -1284,7 +1243,6 @@ class ModelWorker(worker_base.Worker):
                 f"Time consumption: {float(t):.4f}s."
             )

-    @cuda_tmark("post_response", CUDATimeMarkType.misc)
     def maybe_post_responses(self):
         ready_to_post = []
         while True:
@@ -1335,7 +1293,6 @@ class ModelWorker(worker_base.Worker):
             time.sleep(_MODEL_WORKER_POLL_REQUESTS_INTERVAL_SECS)
             pass

-    @cuda_tmark("receive_request", CUDATimeMarkType.misc)
     def maybe_receive_requests(self):
         tik = time.perf_counter()
         while time.perf_counter() - tik < _MODEL_WORKER_POLL_REQUESTS_SECS:
@@ -1,4 +1,5 @@
 import logging
+import os
 from queue import Empty as QueueEmpty
 from typing import Any, Dict, List, Optional, Union

@@ -6,7 +7,7 @@ import orjson
 import zmq
 from zmq.utils.strtypes import asbytes

-from realhf.base import logging, name_resolve, names, network
+from realhf.base import constants, logging, name_resolve, names, network

 logger = logging.getLogger("ZMQ Push-Pull Stream")

@@ -160,12 +161,16 @@ class NameResolvingZmqPusher(ZMQJsonPusher):


 class NameResolvingZmqPuller(ZMQJsonPuller):
-    def __init__(self, experiment_name, trial_name, puller_index, **kwargs):
+    def __init__(self, args, puller_index, **kwargs):
+        experiment_name = args.experiment_name
+        trial_name = args.trial_name
         name = names.push_pull_stream(
             experiment_name, trial_name, stream_name=f"puller{puller_index}"
         )
         host, port = network.gethostip(), network.find_free_port(
-            experiment_name=experiment_name, trial_name=trial_name
+            experiment_name=experiment_name,
+            trial_name=trial_name,
+            lockfile_root=os.path.join(constants.get_cache_path(args), "ports"),
         )
         addr = f"{host}:{port}"
         name_resolve.add(name, addr)
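A sketch of constructing the puller under the new signature. The module path and the usability of `BaseExperimentConfig` defaults for `experiment_name`/`trial_name` are assumptions; what the diff confirms is that the puller now takes the whole config object and reads those fields (plus the port-lockfile root) from it:

    from realhf.api.cli_args import BaseExperimentConfig
    # Assumed module path; the class is also re-exported where stream_dataset imports it.
    from realhf.system.push_pull_stream import NameResolvingZmqPuller

    args = BaseExperimentConfig()
    # Old call shape: NameResolvingZmqPuller(experiment_name, trial_name, puller_index=0)
    puller = NameResolvingZmqPuller(args, puller_index=0)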
@@ -6,7 +6,7 @@ import itertools
 from collections import defaultdict
 from typing import *

-from realhf.base.cluster import spec as cluster_spec
+from realhf.api.cli_args import ClusterSpecConfig


 class GlobalStorageTracker:

@@ -70,7 +70,10 @@ class RedistribStep:


 class RedistribPlanner:
-    def __init__(self, storage_tracker: GlobalStorageTracker):
+    def __init__(
+        self, cluster_config: ClusterSpecConfig, storage_tracker: GlobalStorageTracker
+    ):
+        self.cluster_config = cluster_config
         self.storage_tracker = storage_tracker

     def derive_plan(

@@ -269,8 +272,8 @@ class RedistribPlanner:
         return self._group_bcast_transfers()

     def _on_same_node(self, i, j) -> bool:
-        return (i // cluster_spec.n_gpus_per_node) == (
-            j // cluster_spec.n_gpus_per_node
+        return (i // self.cluster_config.n_gpus_per_node) == (
+            j // self.cluster_config.n_gpus_per_node
         )

     def _select_best_bcast_source(self, source_gpus, target_gpus):
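The planner now receives the cluster spec explicitly instead of importing a process-global one, which is what makes the per-test `ClusterSpecConfig` overrides below possible. A self-contained sketch of the node-locality arithmetic `_on_same_node` performs with the injected config:

    from realhf.api.cli_args import ClusterSpecConfig

    cfg = ClusterSpecConfig(n_gpus_per_node=8)

    def on_same_node(i: int, j: int) -> bool:
        # Global GPU ranks i and j share a node iff they fall in the
        # same n_gpus_per_node-sized block.
        return i // cfg.n_gpus_per_node == j // cfg.n_gpus_per_node

    assert on_same_node(0, 7) and not on_same_node(7, 8)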
@@ -87,7 +87,7 @@ class RolloutWorker(AsyncWorker):
         self.rollout_stat = RolloutStat()

         # recover info
-        self.__recover_run, self.__recover_info = recover.load_recover_info()
+        self.__recover_run, self.__recover_info = recover.load_recover_info(self.args)

         return config.worker_info

@@ -101,13 +101,6 @@ class RolloutWorker(AsyncWorker):
                 self.worker_index,
                 self.worker_count,
                 self.config.tokenizer_path,
-                self.config.worker_info.experiment_name,
-                self.config.worker_info.trial_name,
-                cache_root=(
-                    None
-                    if not self.config.use_dataset_cache
-                    else self.config.dataset_cahce_root
-                ),
             )
             for d in self.config.datasets
         ]

@@ -129,9 +122,7 @@ class RolloutWorker(AsyncWorker):
         # Recover indices for dynamic dataset
         if hasattr(self.dataset, "filter"):
             dataset_indices_path = os.path.join(
-                constants.MODEL_SAVE_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
+                constants.get_log_path(self.args),
                 f"dataset_indices_{self.worker_index}.npy",
             )
             if os.path.exists(dataset_indices_path):

@@ -156,15 +147,11 @@ class RolloutWorker(AsyncWorker):
             self.is_new_epoch = True
             # Upon the first fetch request, filter dataset and create dataloader.
             eval_scores_path = os.path.join(
-                constants.MODEL_SAVE_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
+                constants.get_log_path(self.args),
                 "dataset_eval_scores.json",
             )
             dataset_indices_path = os.path.join(
-                constants.MODEL_SAVE_ROOT,
-                constants.experiment_name(),
-                constants.trial_name(),
+                constants.get_log_path(self.args),
                 f"dataset_indices_{self.worker_index}.npy",
             )
             if hasattr(self.dataset, "filter") and os.path.exists(eval_scores_path):
@@ -2,6 +2,7 @@ import queue
 import sys
 import threading
 import time
+import traceback
 from typing import Any, List, Optional

 from torch.utils.data import ConcatDataset, Dataset

@@ -23,6 +24,7 @@ class PullerStreamDataset(Dataset):
     def __init__(
         self,
         util: DatasetUtility,
+        args,
         dataset_cfgs: List[DatasetAbstraction],
         pull_timeout_ms=100,
     ):

@@ -35,8 +37,6 @@ class PullerStreamDataset(Dataset):
                 dp_rank=util.dp_rank,
                 world_size=util.world_size,
                 tokenizer_or_tokenizer_name=util.tokenizer,
-                experiment_name=constants.experiment_name(),
-                trial_name=constants.trial_name(),
             )
             for dataset_cfg in dataset_cfgs
         ]

@@ -51,6 +51,8 @@ class PullerStreamDataset(Dataset):
         self.data_queue = queue.Queue(maxsize=self.dataset_size * util.world_size)
         self._stop_event = threading.Event()

+        self.args = args
+
         # Pass ZMQ context (thread-safe) and let worker create the socket
         self.util = util
         self.worker_thread = threading.Thread(target=self._pull_data_worker)

@@ -60,33 +62,33 @@ class PullerStreamDataset(Dataset):
         """Worker thread that creates its own ZMQ puller and streams data."""
         # Initialize the puller inside the worker thread
         stream = NameResolvingZmqPuller(
-            constants.experiment_name(),
-            constants.trial_name(),
+            self.args,
             puller_index=self.util.dp_rank,
         )
-        try:
-            while not self._stop_event.is_set():
-                try:
-                    data = stream.pull(timeout_ms=self.pull_timeout_ms)
-                    processed_data = [
-                        SequenceSample.from_json_compatible(x) for x in data
-                    ]
-                    logger.debug(
-                        f"Get data {[x.ids[0] for x in processed_data]} from puller stream."
-                    )
-                    self.data_queue.put(processed_data)
-                except queue.Empty:
-                    logger.debug(f"No data from puller stream.")
-                    time.sleep(0.1)
-                    continue
-        finally:
-            # Ensure socket is closed in the same thread
-            del stream
-            # Exit if this thread has an error
-            sys.exit(1)
+        processed_data = None
+        while not self._stop_event.is_set():
+            if processed_data is not None:
+                try:
+                    self.data_queue.put_nowait(processed_data)
+                    processed_data = None
+                except queue.Full:
+                    time.sleep(0.1)
+                    continue
+            try:
+                data = stream.pull(timeout_ms=self.pull_timeout_ms)
+                processed_data = [SequenceSample.from_json_compatible(x) for x in data]
+                logger.debug(
+                    f"Get data {[x.ids[0] for x in processed_data]} from puller stream."
+                )
+            except queue.Empty:
+                logger.debug(f"No data from puller stream.")
+                time.sleep(0.1)
+                continue

     def __getitem__(self, idx: int) -> Optional[Any]:
         samples = []
+        if not self.worker_thread.is_alive():
+            raise RuntimeError("Stream dataset puller thread is not alive.")
         while True:
             try:
                 samples += self.data_queue.get_nowait()

@@ -99,8 +101,6 @@ class PullerStreamDataset(Dataset):

     def __del__(self):
         self._stop_event.set()
-        if self.worker_thread.is_alive():
-            self.worker_thread.join(timeout=1.0)


 register_dataset("puller_stream", PullerStreamDataset)
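The rewritten worker loop replaces the blocking `data_queue.put(...)` with `put_nowait` plus a retry, so a full queue can no longer wedge the thread past `_stop_event`; it also drops the `finally: sys.exit(1)` teardown, letting the new liveness check in `__getitem__` surface failures instead. A standalone sketch of the same control flow, with `stream` standing in for the ZMQ puller:

    import queue
    import threading
    import time

    def pull_loop(stream, out_queue: queue.Queue, stop: threading.Event, timeout_ms: int = 100):
        pending = None
        while not stop.is_set():
            if pending is not None:
                try:
                    out_queue.put_nowait(pending)  # never blocks; stop stays responsive
                    pending = None
                except queue.Full:
                    time.sleep(0.1)
                    continue
            try:
                pending = stream.pull(timeout_ms=timeout_ms)
            except queue.Empty:
                time.sleep(0.1)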
@@ -568,6 +568,7 @@ class Worker:
         self.__worker_index = worker_info.worker_index

         experiment = system_api.make_experiment(name=worker_info.experiment_name)
+        self.args = experiment

         expr_config = experiment.initial_setup()
         if isinstance(expr_config, list):
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+from pathlib import Path
 from unittest.mock import AsyncMock, MagicMock, patch

 import numpy as np

@@ -21,7 +22,7 @@ def mock_env():


 @pytest.fixture
-def agent_config():
+def agent_config(tmp_path):
     return {
         "gconfig": MagicMock(n=2),
         "tokenizer_path": "/storage/openpsi/models/Qwen__Qwen2.5-0.5B-Instruct/",

@@ -29,6 +30,7 @@ def agent_config():
         "success_rate_ub": 1.0,
         "reward_scaling": 2.0,
         "reward_bias": 0.1,
+        "answer_save_path": tmp_path,
     }


@@ -140,48 +142,34 @@ async def test_collect_trajectory_empty_act_queue(agent, mock_env, mock_prompt):

 def test_log_rewards_to_file(agent, tmp_path):
     # Setup test directories
-    with (
-        patch("realhf.base.constants.LOG_ROOT", tmp_path),
-        patch("realhf.base.constants.experiment_name", return_value="test_exp"),
-        patch("realhf.base.constants.trial_name", return_value="test_trial"),
-    ):
-        agent.log_rewards_to_file(
-            qid="123",
-            prompt="test_prompt",
-            prompt_len=3,
-            answers=["answer1", "answer2"],
-            seqlens=[5, 6],
-            rewards=[0.5, 0.7],
-            success=[True, False],
-            version_starts=[1, 2],
-            version_ends=[2, 3],
-        )
+    agent.log_rewards_to_file(
+        qid="123",
+        prompt="test_prompt",
+        prompt_len=3,
+        answers=["answer1", "answer2"],
+        seqlens=[5, 6],
+        rewards=[0.5, 0.7],
+        success=[True, False],
+        version_starts=[1, 2],
+        version_ends=[2, 3],
+    )

     # Check generated file
-    gen_file_path = (
-        tmp_path / "test_exp" / "test_trial" / "generated" / "1" / "123.txt"
-    )
+    gen_file_path = Path(agent.answer_save_path) / "1" / "123.txt"
     assert gen_file_path.exists()
     with open(gen_file_path) as f:
         content = f.read()
     assert "idx: 1 / 2" in content
     assert "seqlen: 5" in content
     assert "test_prompt" in content

     # Check monitor file
-    monitor_file_path = (
-        tmp_path
-        / "test_exp"
-        / "test_trial"
-        / "training_monitor"
-        / "1"
-        / "123.jsonl"
-    )
+    monitor_file_path = Path(agent.answer_save_path) / "1" / "123.jsonl"
     assert monitor_file_path.exists()
     with open(monitor_file_path) as f:
         data = json.loads(f.readline())
     assert data["version_start"] == 1
     assert data["prompt_len"] == 3


 def test_reward_calculation(agent):
@@ -14,6 +14,7 @@ import pytest
 import torch
 import torch.distributed as dist

+from realhf.api.cli_args import ClusterSpecConfig
 from realhf.api.core.config import ModelName, ModelShardID
 from realhf.api.core.data_api import SequenceSample
 from realhf.base import constants, testing, topology

@@ -166,7 +167,7 @@ def _test_data_transfer(
     data_manager.setup_process_groups()

     storage_tracker = GlobalStorageTracker(dist.get_world_size())
-    planner = RedistribPlanner(storage_tracker)
+    planner = RedistribPlanner(ClusterSpecConfig(), storage_tracker)

     key = "input_ids"

@@ -4,12 +4,12 @@

 import random
 import time
+import uuid

 import pytest
 import torch

-import realhf.base.constants as constants
-import realhf.base.testing as testing
+from realhf.base import constants, name_resolve, testing


 # This is a test for grouped_gemm experts implementation of MoE.

@@ -99,10 +99,7 @@ def run_grouped_mlp(num_tokens, tp_size, token_dispatch_strategy, seed=1):
     )


-@pytest.mark.skipif(
-    not torch.cuda.is_available(),
-    reason="This test requires GPU to run",
-)
+@pytest.mark.skip("grouped_gemm is not used for now.")
 @pytest.mark.parametrize("num_tokens", [200])
 @pytest.mark.parametrize("tp_size", [1, 2])
 @pytest.mark.parametrize("token_dispatch_strategy", ["random"])

@@ -123,10 +120,7 @@ def test_grouped_mlp(
     test.launch()


-@pytest.mark.skipif(
-    not torch.cuda.is_available(),
-    reason="This test requires GPU to run",
-)
+@pytest.mark.skip("grouped_gemm is not used for now.")
 @pytest.mark.gpu
 def test_grouped_gemm():
     torch.manual_seed(1)
@@ -36,7 +36,7 @@ def maybe_synchronize_cuda():
         torch.cuda.synchronize()


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.")
+@pytest.mark.skip("interval_ops are not used now.")
 @pytest.mark.parametrize(
     "n_intervals", list(reversed([1, 100, 500, 1000, 2000, 4000, 10000, 100000]))
 )

@@ -86,7 +86,7 @@ def test_get(n_intervals: int, dtype: torch.dtype):
 )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires a GPU.")
+@pytest.mark.skip("interval_ops are not used now.")
 @pytest.mark.parametrize(
     "n_intervals", list(reversed([1, 10, 100, 500, 1000, 1000, 10000, 100000]))
 )
@@ -21,8 +21,6 @@ def _validate_dataset(cfg: config_api.DatasetAbstraction, tokenizer):
         dp_rank=0,
         world_size=1,
         tokenizer_or_tokenizer_name=tokenizer,
-        experiment_name=str(uuid.uuid4()),
-        trial_name=str(uuid.uuid4()),
     )
     dataloader = DataLoader(
         dataset,
@@ -21,13 +21,13 @@ BACKENDS = [
     ("nfs", {}),
     ("ray", {}),
 ]
-if os.environ.get("REAL_ETCD_ADDR"):
+if os.environ.get("TESTING_ETCD_ADDR"):
     BACKENDS.append(
         (
             "etcd3",
             {
-                "host": os.getenv("REAL_ETCD_ADDR").split(":")[0],
-                "port": int(os.getenv("REAL_ETCD_ADDR").split(":")[1]),
+                "host": os.getenv("TESTING_ETCD_ADDR").split(":")[0],
+                "port": int(os.getenv("TESTING_ETCD_ADDR").split(":")[1]),
             },
         )
     )

@@ -43,12 +43,9 @@ def name_resolve(request):
         temp_dir = tempfile.mkdtemp()
         from realhf.base.name_resolve import NfsNameRecordRepository

-        original_root = NfsNameRecordRepository.RECORD_ROOT
-        NfsNameRecordRepository.RECORD_ROOT = temp_dir
-        repo = NfsNameRecordRepository()
+        repo = NfsNameRecordRepository(temp_dir)
         yield repo
         repo.reset()
-        NfsNameRecordRepository.RECORD_ROOT = original_root
         shutil.rmtree(temp_dir)
     elif backend_type == "memory":
         from realhf.base.name_resolve import MemoryNameRecordRepository

@@ -208,17 +205,6 @@ def test_reset(name_resolve):
     name_resolve.delete("test_key_no_delete")


-def test_context_manager(name_resolve):
-    """Test context manager functionality."""
-    with name_resolve.__class__() as repo:
-        repo.add("test_key", "test_value", delete_on_exit=True)
-        assert repo.get("test_key") == "test_value"
-
-    # Key should be deleted after context exits
-    with pytest.raises(NameEntryNotFoundError):
-        name_resolve.get("test_key")
-
-
 def test_concurrent_access(name_resolve):
     """Test concurrent access to the same key."""
     name_resolve.add("test_key", "initial_value")

@@ -648,7 +634,9 @@ def test_corner_case_get_same_as_prefix(name_resolve):
     assert set(keys) == {"prefix", "prefix/child"}


-@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is None, reason="ETCD3 not configured")
+@pytest.mark.skipif(
+    os.getenv("TESTING_ETCD_ADDR") is None, reason="ETCD3 not configured"
+)
 def test_etcd3_specific_features(name_resolve):
     if not isinstance(name_resolve, Etcd3NameRecordRepository):
         pytest.skip("ETCD3 specific test")

@@ -663,7 +651,9 @@ def test_etcd3_specific_features(name_resolve):
     name_resolve.get("test_key")


-@pytest.mark.skipif(os.getenv("REAL_ETCD_ADDR") is not None, reason="NFS specific test")
+@pytest.mark.skipif(
+    os.getenv("TESTING_ETCD_ADDR") is not None, reason="NFS specific test"
+)
 def test_nfs_specific_features(name_resolve):
     """Test features specific to NFS backend."""
     from realhf.base.name_resolve import NfsNameRecordRepository
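The fixture change above reflects a constructor-injection API: the NFS record root is now passed to `NfsNameRecordRepository` directly instead of being patched onto a class attribute and restored afterwards. A sketch of the resulting pattern, assuming only the constructor and the `add`/`get`/`reset` calls already exercised by these tests:

    import shutil
    import tempfile

    from realhf.base.name_resolve import NfsNameRecordRepository

    temp_dir = tempfile.mkdtemp()
    repo = NfsNameRecordRepository(temp_dir)  # record root is per-instance now
    try:
        repo.add("test_key", "test_value")
        assert repo.get("test_key") == "test_value"
    finally:
        repo.reset()
        shutil.rmtree(temp_dir)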
@@ -8,6 +8,7 @@ from typing import *
 import pytest

 from realhf.api.cli_args import (
+    ClusterSpecConfig,
     ExperimentSaveEvalControl,
     MFCConfig,
     ModelTrainEvalConfig,

@@ -15,7 +16,7 @@ from realhf.api.cli_args import (
     PromptAnswerDatasetConfig,
     PromptOnlyDatasetConfig,
 )
-from realhf.base import cluster, logging, name_resolve, testing
+from realhf.base import logging, name_resolve, testing
 from realhf.experiments.common.null_exp import NullPPOConfig, NullSFTConfig
 from tests.experiments.utils import run_test_exp
 from tests.fixtures import *

@@ -64,11 +65,8 @@ def test_buffer_recover(
 ):
     _, dataset_size = math_code_dataset_with_size
     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("buffer-recover")
-    cluster.spec.fileroot = str(log_root)
     expr_name = str(uuid.uuid4())
     trial_name = str(uuid.uuid4())
-    testing.clear_name_resolve(expr_name, trial_name)
     constants.set_experiment_trial_names(expr_name, trial_name)

     exp_cfg = NullPPOConfig(

@@ -114,6 +112,10 @@ def test_buffer_recover(
             save_freq_steps=2,
             benchmark_steps=0,
         ),
+        cluster=ClusterSpecConfig(
+            fileroot=str(tmp_path_factory.mktemp("buffer-recover")),
+            n_gpus_per_node=16,
+        ),
     )

     os.environ["REAL_SAVE_RECOVER_STATES"] = "1"
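The same substitution runs through all the experiment tests that follow: instead of mutating the global `cluster.spec.fileroot`, each experiment config now owns its cluster spec. A sketch of the per-test pattern, assuming pytest's standard `tmp_path_factory` fixture:

    from realhf.api.cli_args import ClusterSpecConfig

    def make_test_cluster(tmp_path_factory, name: str = "exp") -> ClusterSpecConfig:
        # Every test gets an isolated fileroot; no shared global state to reset.
        return ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp(name)))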
@@ -8,6 +8,7 @@ from typing import *
 import pytest

 from realhf.api.cli_args import (
+    ClusterSpecConfig,
     ExperimentSaveEvalControl,
     GenerationHyperparameters,
     MFCConfig,

@@ -17,7 +18,7 @@ from realhf.api.cli_args import (
     PPOHyperparameters,
     PromptOnlyDatasetConfig,
 )
-from realhf.base import cluster, testing
+from realhf.base import testing
 from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
 from tests.experiments.utils import run_test_exp
 from tests.fixtures import *

@@ -73,8 +74,6 @@ def test_ppo_symm(
     mp,
 ):
     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("ppo")
-    cluster.spec.fileroot = str(log_root)
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

@@ -117,6 +116,7 @@ def test_ppo_symm(
             ),
         ),
         group_size=2,
+        cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))),
     )

     run_test_exp(exp_cfg)

@@ -152,8 +152,6 @@ def test_ppo_decoupled(
     gmp,
 ):
     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("ppo")
-    cluster.spec.fileroot = str(log_root)
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

@@ -245,6 +243,7 @@ def test_ppo_decoupled(
             ),
         ),
         group_size=2,
+        cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))),
     )

     run_test_exp(exp_cfg)

@@ -275,8 +274,6 @@ def test_ppo_global_reshard(
     rew_inf,
 ):
     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("ppo-global-reshard")
-    cluster.spec.fileroot = str(log_root)
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

@@ -368,6 +365,7 @@ def test_ppo_global_reshard(
                 pipeline_parallel_size=critic_train[2],
             ),
         ),
+        cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))),
     )
     run_test_exp(exp_cfg)

@@ -388,8 +386,6 @@ def test_ppo_param_realloc_sub_device_mesh(
     critic_inf,
 ):
     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("ppo-submesh")
-    cluster.spec.fileroot = str(log_root)
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

@@ -484,6 +480,7 @@ def test_ppo_param_realloc_sub_device_mesh(
                 pipeline_parallel_size=2,
             ),
         ),
+        cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))),
     )

     run_test_exp(exp_cfg)

@@ -503,13 +500,9 @@ def test_ppo_save(
     bs,
 ):
     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("ppo")
-    cluster.spec.fileroot = str(log_root)
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )
-    shutil.rmtree(constants.MODEL_SAVE_ROOT, ignore_errors=True)
-    os.makedirs(constants.MODEL_SAVE_ROOT, exist_ok=True)

     total_train_epochs = 3

@@ -604,7 +597,11 @@ def test_ppo_save(
                 pipeline_parallel_size=1,
             ),
         ),
+        cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("ppo"))),
     )
+    shutil.rmtree(constants.get_save_path(exp_cfg), ignore_errors=True)
+    os.makedirs(constants.get_save_path(exp_cfg), exist_ok=True)

     exp_cfg.actor.vllm.hybrid_train = True
     exp_cfg.actor.vllm.enforce_eager = True

@@ -636,9 +633,7 @@ def test_ppo_save(
             int(os.path.basename(f).split("globalstep")[-1])
             for f in os.listdir(
                 os.path.join(
-                    constants.MODEL_SAVE_ROOT,
-                    testing._DEFAULT_EXPR_NAME,
-                    testing._DEFAULT_EXPR_NAME,
+                    constants.get_save_path(exp_cfg),
                     model_name,
                 )
             )
@@ -6,11 +6,12 @@ from typing import *
 import pytest

 from realhf.api.cli_args import (
+    ClusterSpecConfig,
     ExperimentSaveEvalControl,
     ModelTrainEvalConfig,
     PromptAnswerDatasetConfig,
 )
-from realhf.base import cluster, testing
+from realhf.base import testing
 from realhf.experiments.common.sft_exp import SFTConfig
 from tests.experiments.utils import run_test_exp
 from tests.fixtures import *

@@ -32,9 +33,6 @@ def model_class(request):
         (1, 2, 4),
         (2, 4, 1),
         (2, 1, 4),
-        (4, 2, 2),
-        (2, 4, 2),
-        (2, 2, 4),
     ],
 )
 def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):

@@ -61,8 +59,6 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
 def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):

     # Setup experiment env. Should be done before any other operations.
-    log_root = tmp_path_factory.mktemp("sft")
-    cluster.spec.fileroot = str(log_root)
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

@@ -89,6 +85,7 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
             valid_bs_n_seqs=minbs,
             fill_to_max_length=False,
         ),
+        cluster=ClusterSpecConfig(fileroot=str(tmp_path_factory.mktemp("sft"))),
     )

     run_test_exp(exp_cfg)
@@ -1,4 +1,5 @@
 # Copyright 2025 Ant Group Inc.
+import asyncio
 import functools
 import multiprocessing as mp
 from typing import *

@@ -6,7 +7,7 @@ from typing import *
 import pytest

 from realhf.api.core.system_api import Experiment, register_experiment
-from realhf.base import cluster, constants, logging, testing
+from realhf.base import constants, logging, testing
 from realhf.system.worker_base import WorkerServerStatus
 from tests.fixtures import *

@@ -74,7 +75,6 @@ def run_test_exp(
     logger.info("Configuring master worker...")
     mas.configure(setup_id=0, worker_info=exp_setup.master_worker[0].worker_info)
     logger.info("Configuring master worker... Done.")
-    initd = False

     # Run model workers in subprocesses
     barrier = mp.Barrier(len(exp_setup.model_worker))

@@ -98,13 +98,17 @@ def run_test_exp(
     testcase.start()

     # Run the master worker.
-    for _ in range(int(1e4)):
-        if mas.status == WorkerServerStatus.PAUSED:
-            break
-        if not initd:
-            logger.info("Running master worker lazy initialization...")
-        mas._poll()
-        if not initd:
-            logger.info("Running master worker lazy initialization... Done.")
-            initd = True
+    async def run_master_worker():
+        initd = False
+        for _ in range(int(1e4)):
+            if mas.status == WorkerServerStatus.PAUSED:
+                break
+            if not initd:
+                logger.info("Running master worker lazy initialization...")
+            await mas._poll_async()
+            if not initd:
+                logger.info("Running master worker lazy initialization... Done.")
+                initd = True
+
+    asyncio.run(run_master_worker())
     testcase.wait(timeout=0.1)
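The harness change above tracks the master worker's poll step becoming awaitable. A sketch of the driving loop in isolation, assuming `_poll_async` is the only coroutine involved and `paused_status` is `WorkerServerStatus.PAUSED` in the real code:

    import asyncio

    async def drive_master(mas, paused_status, max_steps: int = 10_000):
        initd = False
        for _ in range(max_steps):
            if mas.status == paused_status:
                break
            await mas._poll_async()  # was the synchronous mas._poll()
            if not initd:
                initd = True

    # asyncio.run(drive_master(mas, WorkerServerStatus.PAUSED))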
@@ -16,7 +16,7 @@ from realhf.api.core.data_api import (
     load_hf_tokenizer,
 )
 from realhf.api.core.model_api import FinetuneSpec, Model
-from realhf.base import constants, network, testing
+from realhf.base import constants, name_resolve, network, testing
 from tests.fixtures import *


@@ -61,8 +61,12 @@ def math_code_dataset(request, save_path):
     "tokenizer_path", ["/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"]
 )
 def test_multi_task_reward_interface(save_path, tokenizer_path, math_code_dataset):
+    from realhf.api.cli_args import NameResolveConfig
     from realhf.impl.dataset.math_code_dataset import MATHCodePromptDataset

+    name_resolve.reconfigure(
+        NameResolveConfig("nfs", f"/tmp/areal/{str(uuid.uuid4())}/")
+    )
     dist.init_process_group(
         rank=0, world_size=1, init_method=f"tcp://localhost:{network.find_free_port()}"
     )
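`name_resolve.reconfigure` now takes a `NameResolveConfig` dataclass rather than a backend string plus kwargs, which is the pattern every test below adopts. A sketch of selecting each backend; the etcd3 address is illustrative:

    from realhf.api.cli_args import NameResolveConfig
    from realhf.base import name_resolve

    # NFS-backed KV store rooted at a throwaway directory:
    name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/my-test"))

    # Equivalent etcd3 selection via keyword fields (address is an example):
    name_resolve.reconfigure(NameResolveConfig(type="etcd3", etcd3_addr="localhost:2379"))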
@@ -13,7 +13,9 @@ from typing import Optional
 import aiohttp
 import pytest

+from realhf.api.cli_args import BaseExperimentConfig, NameResolveConfig
 from realhf.api.core.config import ModelName
+from realhf.api.core.system_api import ExpStatus
 from realhf.api.core.system_api import GserverManager as GserverManagerConfig
 from realhf.api.core.system_api import WorkerInformation
 from realhf.base import constants, name_resolve, names, network, testing

@@ -42,6 +44,9 @@ def mock_servers():
     from fastapi import FastAPI
     from fastapi.responses import ORJSONResponse, PlainTextResponse

+    name_resolve.reconfigure(
+        NameResolveConfig("nfs", "/tmp/areal/test-gserver-manager")
+    )
     ports = network.find_multiple_free_ports(N_SERVERS)

     # Create mock server responses

@@ -124,7 +129,6 @@ def mock_servers():
 @pytest.fixture
 def gserver_manager(request, mock_servers):
     train_batch_size, offpolicyness = request.param
-    testing.clear_name_resolve()
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

@@ -135,6 +139,9 @@ def gserver_manager(request, mock_servers):
     name_resolve.add_subentry(name, server_urls[0])
     name_resolve.add_subentry(name, server_urls[1])

+    name = names.experiment_status(constants.experiment_name(), constants.trial_name())
+    name_resolve.add(name, ExpStatus.RUNNING)
+
     # Mock requests.get for metrics endpoint
     m = GserverManager()
     config = GserverManagerConfig(

@@ -153,6 +160,7 @@ def gserver_manager(request, mock_servers):
             worker_index=0,
         ),
     )
+    m.args = BaseExperimentConfig()
     m._configure(config)
     # launch the server
     m._poll()
@@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from transformers import PreTrainedTokenizerFast

+from realhf.api.cli_args import NameResolveConfig
 from realhf.api.core.model_api import (
     APIGenerateInput,
     APIGenerateOutput,

@@ -67,6 +68,10 @@ def partial_rollout_manager():
     request_queue = asyncio.Queue()
     reply_queue = asyncio.Queue()

+    name_resolve.reconfigure(
+        NameResolveConfig("nfs", "/tmp/areal/test-partial-rollout")
+    )
+
     testing.clear_name_resolve()
     constants.set_experiment_trial_names(
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
@@ -7,9 +7,10 @@ import pytest
 import torch
 from torch.utils.data import DataLoader

+from realhf.api.cli_args import BaseExperimentConfig, NameResolveConfig
 from realhf.api.core import config as config_api
 from realhf.api.core import data_api
-from realhf.base import constants, testing
+from realhf.base import constants, name_resolve, testing
 from tests.fixtures import *


@@ -64,6 +65,7 @@ def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller):
         testing._DEFAULT_EXPR_NAME, testing._DEFAULT_TRIAL_NAME
     )

+    name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset"))
     testing.clear_name_resolve()

     util = data_api.DatasetUtility(

@@ -74,7 +76,12 @@ def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller):
     )

     # Test initialization
-    dataset = PullerStreamDataset(util, prompt_dataset_cfg, pull_timeout_ms=100)
+    dataset = PullerStreamDataset(
+        util,
+        args=BaseExperimentConfig(),
+        dataset_cfgs=prompt_dataset_cfg,
+        pull_timeout_ms=100,
+    )
     assert len(dataset) > 0  # Should have non-zero size from prompt dataset
     assert dataset.data_queue.empty()

@@ -91,12 +98,14 @@ def test_load_stream_dataset(prompt_dataset_cfg, tokenizer, mock_puller):
     assert len(items1) + len(items2) > 0

     # Test cleanup
+    dataset._stop_event.set()
     del dataset


 def test_puller_stream_dataset_timeout(prompt_dataset_cfg, tokenizer):
     from realhf.system.stream_dataset import PullerStreamDataset

+    name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset"))
     testing.clear_name_resolve()

     util = data_api.DatasetUtility(

@@ -109,15 +118,19 @@ def test_puller_stream_dataset_timeout(prompt_dataset_cfg, tokenizer):
     with patch("realhf.system.stream_dataset.NameResolvingZmqPuller") as mock_puller:
         mock_puller.return_value.pull.side_effect = queue.Empty()

-        dataset = PullerStreamDataset(util, prompt_dataset_cfg, pull_timeout_ms=10)
+        dataset = PullerStreamDataset(
+            util, BaseExperimentConfig(), prompt_dataset_cfg, pull_timeout_ms=10
+        )
         # Should handle timeout gracefully
         assert dataset[0] == []
+        dataset._stop_event.set()
         del dataset


 def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_puller):
     from realhf.system.stream_dataset import PullerStreamDataset

+    name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset"))
     testing.clear_name_resolve()

     util = data_api.DatasetUtility(

@@ -127,7 +140,7 @@ def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_puller):
         tokenizer=tokenizer,
     )

-    dataset = PullerStreamDataset(util, prompt_dataset_cfg)
+    dataset = PullerStreamDataset(util, BaseExperimentConfig(), prompt_dataset_cfg)
     assert not dataset._stop_event.is_set()

     # Trigger stop event and verify thread stops

@@ -140,6 +153,7 @@ def test_puller_stream_dataset_stop_event(prompt_dataset_cfg, tokenizer, mock_puller):
 def test_puller_stream_dataset_worker_thread_exception(prompt_dataset_cfg, tokenizer):
     from realhf.system.stream_dataset import PullerStreamDataset

+    name_resolve.reconfigure(NameResolveConfig("nfs", "/tmp/areal/test-stream-dataset"))
     testing.clear_name_resolve()

     util = data_api.DatasetUtility(

@@ -152,7 +166,7 @@ def test_puller_stream_dataset_worker_thread_exception(prompt_dataset_cfg, tokenizer):
     with patch("realhf.system.stream_dataset.NameResolvingZmqPuller") as mock_puller:
         mock_puller.return_value.pull.side_effect = Exception("Test error")

-        dataset = PullerStreamDataset(util, prompt_dataset_cfg)
+        dataset = PullerStreamDataset(util, BaseExperimentConfig(), prompt_dataset_cfg)
         time.sleep(0.1)  # Give thread time to crash
         assert not dataset.worker_thread.is_alive()
         del dataset
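Note the teardown pattern these tests adopt: because `__del__` no longer joins the worker thread, the tests set `_stop_event` explicitly before dropping the dataset, and `__getitem__` now raises if the puller thread has died. A sketch under the assumption that `util` and `dataset_cfgs` are built as in the fixtures above:

    from realhf.api.cli_args import BaseExperimentConfig
    from realhf.system.stream_dataset import PullerStreamDataset

    def run_and_teardown(util, dataset_cfgs):
        dataset = PullerStreamDataset(util, BaseExperimentConfig(), dataset_cfgs)
        try:
            return dataset[0]  # raises RuntimeError if the puller thread died
        finally:
            dataset._stop_event.set()  # __del__ no longer joins the thread
            del dataset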
@@ -8,7 +8,6 @@ import yaml
 from omegaconf import MISSING, OmegaConf

 from realhf.api.quickstart.entrypoint import kind_reminder
-from realhf.base.constants import init_constants
 from realhf.experiments.async_exp.async_ppo_math_exp import AsyncPPOMATHConfig
 from training.utils import run_experiment

@@ -37,14 +36,10 @@ def main_ppo_math(args):
     if args.mode != "ray":
         raise RuntimeError("This script only supports the `ray` mode.")

-    init_constants(args)
-
-    from realhf.base.constants import LOG_ROOT
+    from realhf.base.constants import get_log_path

     # Save overwritten configuration to yaml
-    config_save_path = os.path.join(
-        LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
-    )
+    config_save_path = os.path.join(get_log_path(args), "config.yaml")
     os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
     with open(config_save_path, "w") as f:
         config_dict: Dict = dataclasses.asdict(args)
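All three entry scripts converge on this shape, so the config dump no longer depends on `init_constants` having run. A compact sketch of the resulting helper usage:

    import dataclasses
    import os

    import yaml

    from realhf.base.constants import get_log_path

    def save_config(args) -> str:
        # One helper call replaces os.path.join(LOG_ROOT, experiment, trial, ...).
        path = os.path.join(get_log_path(args), "config.yaml")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            yaml.dump(dataclasses.asdict(args), f)
        return path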
@@ -8,7 +8,6 @@ import yaml
 from omegaconf import MISSING, OmegaConf

 from realhf.api.quickstart.entrypoint import kind_reminder
-from realhf.base.constants import init_constants
 from realhf.experiments.common.sft_exp import SFTConfig
 from training.utils import run_experiment

@@ -37,14 +36,10 @@ def main(args):
     if args.mode != "ray":
         raise RuntimeError("This script only supports the `ray` mode.")

-    init_constants(args)
-
-    from realhf.base.constants import LOG_ROOT
+    from realhf.base.constants import get_log_path

     # Save overwritten configuration to yaml
-    config_save_path = os.path.join(
-        LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
-    )
+    config_save_path = os.path.join(get_log_path(args), "config.yaml")
     os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
     with open(config_save_path, "w") as f:
         config_dict: Dict = dataclasses.asdict(args)
@@ -8,7 +8,6 @@ import yaml
 from omegaconf import MISSING, OmegaConf
 
 from realhf.api.quickstart.entrypoint import kind_reminder
-from realhf.base.constants import init_constants
 from realhf.experiments.common.ppo_math_exp import PPOMATHConfig
 from training.utils import run_experiment
 
@@ -37,14 +36,10 @@ def main(args):
     if args.mode != "ray":
         raise RuntimeError("This script only supports the `ray` mode.")
 
-    init_constants(args)
+    from realhf.base.constants import get_log_path
 
-    from realhf.base.constants import LOG_ROOT
-
     # Save overwritten configuration to yaml
-    config_save_path = os.path.join(
-        LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
-    )
+    config_save_path = os.path.join(get_log_path(args), "config.yaml")
     os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
     with open(config_save_path, "w") as f:
         config_dict: Dict = dataclasses.asdict(args)
@@ -12,6 +12,7 @@ from typing import Any, List
 import psutil
 import ray
 
+from realhf.api.cli_args import NameResolveConfig
 from realhf.api.core.system_api import Experiment, ExperimentScheduling, TasksGroup
 from realhf.base import constants, logging, name_resolve, names
 from realhf.system import WORKER_TYPES, load_worker
@@ -64,6 +65,7 @@ class RayWorker:
 
     def __init__(
         self,
+        args,
         worker_type: str,
         worker_cls,
         kv_store_name,
@@ -74,15 +76,16 @@ class RayWorker:
 
         os.environ["REAL_MODE"] = "RAY"
 
-        name_resolve.reconfigure("ray", actor_name=kv_store_name)
+        name_recolve_config = NameResolveConfig("ray", ray_actor_name=kv_store_name)
+        name_resolve.reconfigure(name_recolve_config)
         self.worker: Worker | AsyncWorker = worker_cls()
         self.worker_type = worker_type
+        self.args = args
 
     def __repr__(self):
         return "".join([c.capitalize() for c in self.worker_type.split("_")])
 
     def configure(self, cfg: Any, expr_config: Any):
-        constants.init_constants(expr_config)
-
         worker_info = cfg.worker_info
         idx = worker_info.worker_index
@@ -92,6 +95,7 @@ class RayWorker:
         self.worker.wandb_config = expr_config.wandb
         self.worker.swanlab_config = expr_config.swanlab
         self.worker.tensorboard_config = expr_config.tensorboard
+        self.worker.args = self.args
         self.logger = logging.getLogger(f"{self.worker_type} {idx}", "benchmark")
         self.logger.info(f"Configuring {self.worker_type}...")
         self.worker._configure(cfg)
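Together these hunks thread the experiment config from the Ray driver into every worker: _run_experiment passes args=exp_cfg when creating each actor (see the hunk at old line 210 below), RayWorker.__init__ stores it, and configure forwards it to the wrapped worker in place of the deleted constants.init_constants call. A condensed sketch of the flow, abridged from the diff:

    # Condensed sketch; only the lines relevant to the config plumbing are kept.
    class RayWorker:
        def __init__(self, args, worker_type, worker_cls, kv_store_name):
            name_resolve.reconfigure(
                NameResolveConfig("ray", ray_actor_name=kv_store_name)
            )
            self.worker = worker_cls()
            self.worker_type = worker_type
            self.args = args  # full experiment config, replacing global constants

        def configure(self, cfg, expr_config):
            self.worker.args = self.args  # workers now read settings from args
            self.worker._configure(cfg)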
@@ -125,6 +129,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
 
     # Initialize ray in the Ray cluster
     env_vars = constants.get_env_vars(
+        exp_cfg,
         WADNB_MODE=exp_cfg.wandb.mode,
         SWANLAB_MODE=exp_cfg.swanlab.mode,
         REAL_MODE="ray",
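get_env_vars now takes the experiment config as a new first positional argument, presumably so environment variables can be computed from the config rather than from module-level constants (the function body is outside this diff). Call sites take this shape; note that WADNB_MODE is spelled that way in the source:

    env_vars = constants.get_env_vars(
        exp_cfg,  # new leading argument
        WADNB_MODE=exp_cfg.wandb.mode,
        SWANLAB_MODE=exp_cfg.swanlab.mode,
        REAL_MODE="ray",
    )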
@@ -145,7 +150,8 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
     logger.info("Ray initialized! Ready to run workers.")
 
     ray_kv_store_name = f"{expr_name}/{trial_name}/ray_kv_store"
-    name_resolve.reconfigure("ray", actor_name=ray_kv_store_name)
+    name_recolve_config = NameResolveConfig("ray", ray_actor_name=ray_kv_store_name)
+    name_resolve.reconfigure(name_recolve_config)
 
     name_resolve.clear_subtree(
         names.trial_root(experiment_name=expr_name, trial_name=trial_name)
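name_resolve.reconfigure now takes a typed NameResolveConfig instead of loose positional arguments, so every backend is configured uniformly. A sketch of the three variants (the ray form is the one used here; the nfs and etcd3 values shown are the dataclass defaults, given only for illustration):

    from realhf.api.cli_args import NameResolveConfig
    from realhf.base import name_resolve

    # Ray actor-backed KV store, as in this file:
    name_resolve.reconfigure(
        NameResolveConfig("ray", ray_actor_name=f"{expr_name}/{trial_name}/ray_kv_store")
    )
    # NFS-backed records; the root must be reachable from every node:
    name_resolve.reconfigure(NameResolveConfig("nfs", nfs_record_root="/tmp/areal/name_resolve"))
    # etcd3 server:
    name_resolve.reconfigure(NameResolveConfig("etcd3", etcd3_addr="localhost:2379"))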
@@ -210,6 +216,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
                 num_gpus=sch.scheduling.gpu,
                 memory=sch.scheduling.mem * 1024**2,
             ).remote(
+                args=exp_cfg,
                 worker_type=worker_type,
                 worker_cls=load_worker(worker_type),
                 kv_store_name=ray_kv_store_name,
@@ -239,7 +246,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
     run_jobs = []
     for worker_type in all_workers:
         workers = all_workers[worker_type]
-        if worker_type in ["master_worker", "rollout_worker", "gserver_manager"]:
+        if worker_type in ["master_worker", "rollout_worker"]:
             # Only the rollout worker is asynchronous
             jobs = [w.run_async.remote() for w in workers]
         else:
@@ -270,7 +277,7 @@ class DualOutput:
 
 
 def run_experiment(exp_cfg, expr_name, trial_name):
-    log_path = os.path.join(constants.LOG_ROOT, expr_name, trial_name, "main.log")
+    log_path = os.path.join(constants.get_log_path(exp_cfg), "main.log")
     with open(log_path, "a") as f:
         # Create dual output handler
         dual_out = DualOutput(f, sys.stdout)