mirror of https://github.com/inclusionAI/AReaL
PullRequest: 6 Implement master worker v2 for refactoring and uvloop support
Merge branch fw/uvloop of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/6 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> * . * fw/fix-dataloading-not-shuffle * . * . * . * . * . * add v2 master worker * cpu test pass * ppo run * . * format * fix * merge and format * change default env vars to v1 worker
This commit is contained in:
parent
ad5baa74ee
commit
c1bb4770ff
|
@ -828,3 +828,22 @@ def make_dataset(
|
|||
logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def gather_stat(src: List[Dict]) -> Dict:
|
||||
cnt, stats = {}, {}
|
||||
for reply in src:
|
||||
for k, v in reply.items():
|
||||
cnt[k] = cnt.get(k, 0) + 1
|
||||
stats[k] = stats.get(k, 0) + v
|
||||
res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())}
|
||||
for k, c in cnt.items():
|
||||
if c != len(src):
|
||||
logger.warning(f"Gathered `{k}` is not present in every returned stats.")
|
||||
for k, v in res.items():
|
||||
if any(abs(v - x.get(k, None)) > 1e-4 for x in src):
|
||||
logger.warning(
|
||||
f"Gathered `{k}` is not all-reduced "
|
||||
f"before returning: ({[x.get(k, None) for x in src]}, {v})."
|
||||
)
|
||||
return res
|
||||
|
|
|
@ -161,6 +161,7 @@ def main_start(args, recover_count: int = 0):
|
|||
REAL_RECOVER_RUN="1" if is_recover_run else "0",
|
||||
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
|
||||
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
|
||||
)
|
||||
for k, v in BASE_ENVIRONS.items():
|
||||
os.environ[k] = v
|
||||
|
|
|
@ -31,8 +31,8 @@ from realhf.api.core.system_api import (
|
|||
ModelWorker,
|
||||
Scheduling,
|
||||
TasksGroup,
|
||||
WandBConfig,
|
||||
TensorBoardConfig,
|
||||
WandBConfig,
|
||||
)
|
||||
from realhf.api.quickstart.device_mesh import (
|
||||
DeviceMesh,
|
||||
|
@ -56,6 +56,7 @@ from realhf.experiments.common.check import (
|
|||
)
|
||||
from realhf.experiments.common.utils import (
|
||||
AllocationMode,
|
||||
asdict,
|
||||
get_topo,
|
||||
make_inf_backend_config,
|
||||
make_train_backend_config,
|
||||
|
@ -203,7 +204,9 @@ class CommonExperimentConfig(Experiment):
|
|||
partition: str = "dev"
|
||||
schedule_strategy: str = "empty_first"
|
||||
wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig)
|
||||
tensorboard: TensorBoardConfig = dataclasses.field(default_factory=TensorBoardConfig)
|
||||
tensorboard: TensorBoardConfig = dataclasses.field(
|
||||
default_factory=TensorBoardConfig
|
||||
)
|
||||
image_name: Optional[str] = None
|
||||
recover_mode: str = "disabled"
|
||||
recover_retries: int = 1
|
||||
|
@ -561,9 +564,7 @@ class CommonExperimentConfig(Experiment):
|
|||
and self.gen_device_mesh.mapping[i, j]
|
||||
):
|
||||
gen_rpc_alloc = next(
|
||||
alloc
|
||||
for alloc in rpc_allocs
|
||||
if alloc.rpc.interface_type == ModelInterfaceType.GENERATE
|
||||
alloc for alloc in rpc_allocs if alloc.rpc.is_generate()
|
||||
)
|
||||
model_name = gen_rpc_alloc.rpc.model_name
|
||||
topo = get_topo(
|
||||
|
@ -620,6 +621,10 @@ class CommonExperimentConfig(Experiment):
|
|||
model_rpc_allocs,
|
||||
) in model_name_to_rpc_allocs.items():
|
||||
rpcs = [rpc_alloc.rpc for rpc_alloc in model_rpc_allocs]
|
||||
if self._allocation_mode.is_decoupled() and all(
|
||||
rpc.is_generate() for rpc in rpcs
|
||||
):
|
||||
continue
|
||||
rpc_alloc = model_rpc_allocs[0]
|
||||
model_cfg = self.models[model_name.role]
|
||||
model = get_real_model_config(
|
||||
|
@ -677,9 +682,7 @@ class CommonExperimentConfig(Experiment):
|
|||
rpc.is_generate() for rpc in rpcs
|
||||
):
|
||||
assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
|
||||
vllm_dict_args: Dict[str, Any] = OmegaConf.to_container(
|
||||
model_cfg.vllm, resolve=True
|
||||
)
|
||||
vllm_dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
|
||||
backend = ModelBackendAbstraction(
|
||||
"vllm",
|
||||
args=dict(
|
||||
|
@ -689,6 +692,17 @@ class CommonExperimentConfig(Experiment):
|
|||
)
|
||||
else:
|
||||
backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
|
||||
if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
|
||||
"vllm",
|
||||
"sglang",
|
||||
]:
|
||||
print(rpcs, model_name, backend.type_)
|
||||
raise ValueError(
|
||||
"vLLM or SGLang is not enabled for generation. "
|
||||
"This behavior has been deprecated. "
|
||||
"Please set model.vllm.hybrid_train=True "
|
||||
"or model.sglang.hybrid_train=True."
|
||||
)
|
||||
|
||||
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
|
||||
if mapping[i, j]:
|
||||
|
|
|
@ -23,7 +23,11 @@ from realhf.api.quickstart.device_mesh import MFCConfig
|
|||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
|
||||
from realhf.experiments.common.utils import (
|
||||
asdict,
|
||||
resolve_replica_ids,
|
||||
resolve_rpc_hooks,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("PPO Math exp", "colored")
|
||||
|
||||
|
@ -283,11 +287,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
# It is used for unifying the profiling API, which requires to
|
||||
# pass external interface configurations in the launch command.
|
||||
# Customized dataclass objects will not work in that case.
|
||||
"generation_config": (
|
||||
OmegaConf.to_container(self.ppo.gen, resolve=True)
|
||||
if isinstance(self.ppo.gen, (OmegaConf, DictConfig))
|
||||
else dataclasses.asdict(self.ppo.gen)
|
||||
),
|
||||
"generation_config": asdict(self.ppo.gen),
|
||||
"early_stop_imp_ratio": self.ppo.early_stop_imp_ratio,
|
||||
"adv_norm": self.ppo.adv_norm,
|
||||
"group_size": self.group_size,
|
||||
|
|
|
@ -10,7 +10,7 @@ import re
|
|||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from realhf.api.core.config import (
|
||||
ModelBackendAbstraction,
|
||||
|
@ -317,3 +317,9 @@ class AllocationMode:
|
|||
return
|
||||
allocs[k] = v["*"]
|
||||
return allocs
|
||||
|
||||
|
||||
def asdict(cfg):
|
||||
if isinstance(cfg, (OmegaConf, DictConfig)):
|
||||
return OmegaConf.to_container(cfg, resolve=True)
|
||||
return dataclasses.asdict(cfg)
|
||||
|
|
|
@ -32,7 +32,7 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet):
|
|||
Used for testing only.
|
||||
"""
|
||||
|
||||
optimizer: torch.optim.Optimizer
|
||||
optim: torch.optim.Optimizer
|
||||
|
||||
def _exec_backward_pass(
|
||||
self,
|
||||
|
@ -78,14 +78,19 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet):
|
|||
micro_batch_id: int,
|
||||
step_id: int,
|
||||
):
|
||||
self.optimizer.step()
|
||||
self.optim.step()
|
||||
|
||||
|
||||
class AdamWithLossScale(torch.optim.Adam):
|
||||
def get_loss_scale(self) -> torch.Tensor:
|
||||
return torch.tensor([1.0], device=constants.current_device())
|
||||
|
||||
|
||||
class MockTrainEngine(model_api.PipelinableEngine):
|
||||
|
||||
def __init__(self, module: ReaLModel, optimizer: torch.optim.Optimizer):
|
||||
def __init__(self, module: ReaLModel, optimizer: AdamWithLossScale):
|
||||
self.module = module
|
||||
self.optimizer = optimizer
|
||||
self.optim = optimizer
|
||||
|
||||
self.inf_engine = PipelinableInferenceEngine(module)
|
||||
if constants.pipe_parallel_world_size() > 1:
|
||||
|
@ -111,10 +116,10 @@ class MockTrainEngine(model_api.PipelinableEngine):
|
|||
token_normalize_scope: str,
|
||||
version_steps: int,
|
||||
):
|
||||
self.optimizer.zero_grad()
|
||||
self.optim.zero_grad()
|
||||
if constants.pipe_parallel_world_size() > 1:
|
||||
# Fusing the minibatched forward-backward in a pipeline training schedule.
|
||||
instr_set = MockPipeTrainInstrSet(self.optimizer)
|
||||
instr_set = MockPipeTrainInstrSet(self, self.optim)
|
||||
# NOTE: When training with pipeline parallel, num micro batches should be
|
||||
# larger than 2 x num_pipeline_stages to avoid idle time.
|
||||
return self.pipe_runner.train_batch(
|
||||
|
@ -222,7 +227,7 @@ class MockTrainBackend(model_api.ModelBackend):
|
|||
raise ValueError("MegatronTrainBackend only supports ReaLModel.")
|
||||
|
||||
if self.optimizer_name == "adam":
|
||||
optimizer = torch.optim.Adam(module.parameters(), **self.optimizer_config)
|
||||
optimizer = AdamWithLossScale(module.parameters(), **self.optimizer_config)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Optimizer {self.optimizer_name} not implemented for testing."
|
||||
|
|
|
@ -15,6 +15,7 @@ logger = logging.getLogger("system")
|
|||
# NOTE: Workers are configured in the following order.
|
||||
# Take special care when adding a new worker type.
|
||||
WORKER_TYPES = ["model_worker", "master_worker"]
|
||||
USE_V2_WORKER = os.getenv("REAL_USE_V2_WORKER", "0") == "1"
|
||||
|
||||
|
||||
def load_worker(worker_type: str) -> Type:
|
||||
|
@ -25,6 +26,8 @@ def load_worker(worker_type: str) -> Type:
|
|||
|
||||
|
||||
def worker_type_to_module(worker_type: str):
|
||||
if worker_type == "master_worker" and USE_V2_WORKER:
|
||||
return "realhf.system.v2." + worker_type
|
||||
return "realhf.system." + worker_type
|
||||
|
||||
|
||||
|
|
|
@ -147,6 +147,7 @@ class AsyncIOSequenceBuffer:
|
|||
rpcs: List[dfg.MFCDef],
|
||||
max_size: int,
|
||||
):
|
||||
self.rpcs = rpcs
|
||||
self._lock = asyncio.Condition(asyncio.Lock())
|
||||
|
||||
# Buffer indicators, should be locked by self._lock.
|
||||
|
@ -269,6 +270,62 @@ class AsyncIOSequenceBuffer:
|
|||
)
|
||||
return indices
|
||||
|
||||
async def put_batch(self, samples: List[SequenceSample]):
|
||||
n = len(samples)
|
||||
|
||||
if n == 0:
|
||||
return np.array([], dtype=np.int64)
|
||||
|
||||
async with self._lock:
|
||||
self._assert_valid_indicator()
|
||||
|
||||
indices = np.where(self._is_empty)[0][:n]
|
||||
|
||||
if len(indices) < n:
|
||||
raise BufferFull(
|
||||
"You are probably using a large dataset. "
|
||||
"The default buffer size 1M is not large enough. "
|
||||
"Please set a larger buffer size by setting "
|
||||
"the environment variable, e.g., REAL_MASTER_BUFFER_SIZE=3000000."
|
||||
)
|
||||
self._is_empty[indices] = False
|
||||
self._is_being_put[indices] = True
|
||||
|
||||
self.__buffer.put_batch(indices, samples)
|
||||
|
||||
# Set a slight difference in birth time to let the order
|
||||
# be deterministic.
|
||||
self._birth_time[indices] = time.monotonic_ns() + np.arange(
|
||||
len(indices), dtype=np.int64
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self.__buffer._update_has_keys(indices)
|
||||
|
||||
has_keys = self.__buffer._get_has_keys(indices) # [bs, #keys]
|
||||
rpc_key_mask = self._rpc_key_mask # [#keys, #rpcs]
|
||||
self._ready_for_rpcs[indices] = (
|
||||
has_keys[:, :, None] >= rpc_key_mask[None, :, :]
|
||||
).all(axis=1)
|
||||
|
||||
self._is_being_put[indices] = False
|
||||
self._is_idle[indices] = True
|
||||
|
||||
self._buf_size += len(samples)
|
||||
if self._buf_size >= 0.95 * self.__max_size:
|
||||
logger.warning(
|
||||
f"Buffer is 95% full. The current buffer size is {self._buf_size} "
|
||||
f"while the maximum size is {self.__max_size}. "
|
||||
f"If your dataset has more than 1M sequences, consider enlarge "
|
||||
f"the default batch size in the master worker."
|
||||
)
|
||||
|
||||
can_do_rpcs = {rpc.name: self._can_do_rpc(rpc) for rpc in self.rpcs}
|
||||
logger.info(f"After putting batch, can do RPCs? {can_do_rpcs}.")
|
||||
|
||||
self._lock.notify(len(self._rpc_names))
|
||||
return indices
|
||||
|
||||
async def amend_batch(self, indices: List[int], samples: List[SequenceSample]):
|
||||
async with self._lock:
|
||||
await self._lock.wait_for(
|
||||
|
@ -298,6 +355,17 @@ class AsyncIOSequenceBuffer:
|
|||
if self._is_idle[indices].any():
|
||||
self._lock.notify(len(self._rpc_names))
|
||||
|
||||
def _can_do_rpc(self, rpc: dfg.MFCDef) -> bool:
|
||||
rpc_idx = self._rpc_names.index(rpc.name)
|
||||
ready_indices = np.nonzero(
|
||||
(self._is_idle | self._is_being_read)
|
||||
& self._ready_for_rpcs[:, rpc_idx]
|
||||
& ~self._completed_rpc[:, rpc_idx]
|
||||
)[0]
|
||||
if len(ready_indices) < rpc.n_seqs:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_batch_for_rpc(
|
||||
self, rpc: dfg.MFCDef
|
||||
) -> Tuple[List[int], SequenceSample]:
|
||||
|
@ -306,20 +374,10 @@ class AsyncIOSequenceBuffer:
|
|||
)
|
||||
rpc_idx = self._rpc_names.index(rpc.name)
|
||||
|
||||
def _can_do_rpc() -> bool:
|
||||
ready_indices = np.nonzero(
|
||||
(self._is_idle | self._is_being_read)
|
||||
& self._ready_for_rpcs[:, rpc_idx]
|
||||
& ~self._completed_rpc[:, rpc_idx]
|
||||
)[0]
|
||||
if len(ready_indices) < rpc.n_seqs:
|
||||
return False
|
||||
return True
|
||||
|
||||
async with self._lock:
|
||||
# await self._lock.wait_for(_can_do_rpc)
|
||||
|
||||
while not _can_do_rpc():
|
||||
while not self._can_do_rpc(rpc):
|
||||
await self._lock.wait()
|
||||
|
||||
logger.info(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!")
|
||||
|
|
|
@ -619,6 +619,7 @@ class RayController:
|
|||
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
|
||||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
|
||||
)
|
||||
runtime_env = {
|
||||
"env_vars": env_vars,
|
||||
|
|
|
@ -39,11 +39,6 @@ from realhf.base import (
|
|||
timeutil,
|
||||
topology,
|
||||
)
|
||||
from realhf.base.asyncio_utils import (
|
||||
raise_asyncio_exception,
|
||||
setup_run_until_complete,
|
||||
teardown_run_util_complete,
|
||||
)
|
||||
from realhf.system.buffer import AsyncIOSequenceBuffer
|
||||
from realhf.system.flops_counter import FlopsCounter
|
||||
|
||||
|
@ -482,7 +477,7 @@ async def model_rpc_reply_func(
|
|||
if isinstance(responses[-1], data_api.SequenceSample):
|
||||
res = data_api.SequenceSample.gather(responses)
|
||||
else:
|
||||
res = _gather_stat(responses)
|
||||
res = data_api.gather_stat(responses)
|
||||
|
||||
if rpc.log_return_value:
|
||||
logger.info(f"RPC name {rpc.name} returns {res}")
|
||||
|
@ -491,7 +486,9 @@ async def model_rpc_reply_func(
|
|||
wandb.log(res, step=ctrl.step_info.global_step)
|
||||
if summary_writer is not None:
|
||||
for key, val in res.items():
|
||||
summary_writer.add_scalar(f"{key}", val, ctrl.step_info.global_step)
|
||||
summary_writer.add_scalar(
|
||||
f"{key}", val, ctrl.step_info.global_step
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Model rpc {rpc.name} finished. Run time {time.perf_counter() - tik:.4f}s."
|
||||
|
@ -518,25 +515,6 @@ async def model_rpc_reply_func(
|
|||
await stream.gather_async(other_req_ids)
|
||||
|
||||
|
||||
def _gather_stat(src: List[Dict]) -> Dict:
|
||||
cnt, stats = {}, {}
|
||||
for reply in src:
|
||||
for k, v in reply.items():
|
||||
cnt[k] = cnt.get(k, 0) + 1
|
||||
stats[k] = stats.get(k, 0) + v
|
||||
res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())}
|
||||
for k, c in cnt.items():
|
||||
if c != len(src):
|
||||
logger.warning(f"Gathered `{k}` is not present in every returned stats.")
|
||||
for k, v in res.items():
|
||||
if any(abs(v - x.get(k, None)) > 1e-4 for x in src):
|
||||
logger.warning(
|
||||
f"Gathered `{k}` is not all-reduced "
|
||||
f"before returning: ({[x.get(k, None) for x in src]}, {v})."
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
class MasterWorker(worker_base.Worker):
|
||||
os.makedirs(constants.MODEL_SAVE_ROOT, exist_ok=True)
|
||||
global_exp_tik = time.perf_counter()
|
||||
|
@ -952,6 +930,9 @@ class MasterWorker(worker_base.Worker):
|
|||
)
|
||||
coroutine_tasks += [request_task, reply_task]
|
||||
|
||||
# Import here to avoid the conflict with nvloop.
|
||||
from realhf.base.asyncio_utils import setup_run_until_complete
|
||||
|
||||
# Set up a run context of EventLoop.run_util_complete, baiscally copy-paste from cpython.
|
||||
# With this context, we can call the non-block EventLoop._run_once (similar to worker._poll).
|
||||
self.__asyncio_tasks: List[asyncio.Task] = coroutine_tasks
|
||||
|
@ -993,6 +974,9 @@ class MasterWorker(worker_base.Worker):
|
|||
)
|
||||
|
||||
def _poll(self):
|
||||
# Import here to avoid the conflict with nvloop.
|
||||
from realhf.base.asyncio_utils import raise_asyncio_exception
|
||||
|
||||
is_new_epoch = False
|
||||
first_poll = not self.__initialized
|
||||
|
||||
|
@ -1276,6 +1260,9 @@ class MasterWorker(worker_base.Worker):
|
|||
self.__rpc_ctrl.ids_to_clear.clear()
|
||||
|
||||
def experiment_complete_exit(self):
|
||||
# Import here to avoid the conflict with nvloop.
|
||||
from realhf.base.asyncio_utils import teardown_run_util_complete
|
||||
|
||||
self.__rpc_ctrl.stop.set()
|
||||
for task in self.__asyncio_tasks:
|
||||
task.cancel()
|
||||
|
|
|
@ -340,6 +340,8 @@ class ModelWorker(worker_base.Worker):
|
|||
for tmp_sample in self.__dataloader:
|
||||
self.__raw_samples += tmp_sample.meta().unpack()
|
||||
|
||||
self.__data_generator = enumerate(self.__dataloader)
|
||||
|
||||
self.__models: Dict[ModelName, model_api.Model] = dict()
|
||||
self.__model_is_handle: Dict[ModelName, bool] = dict()
|
||||
self.__interfaces: Dict[ModelName, model_api.ModelInterface] = dict()
|
||||
|
@ -581,7 +583,10 @@ class ModelWorker(worker_base.Worker):
|
|||
elif request.handle_name == "fetch":
|
||||
dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1))
|
||||
assert self.__has_dataset
|
||||
if request.data["first_batch"]:
|
||||
# Fetch.
|
||||
try:
|
||||
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
|
||||
except StopIteration:
|
||||
# Upon the first fetch request, filter dataset and create dataloader.
|
||||
eval_scores_path = os.path.join(
|
||||
constants.MODEL_SAVE_ROOT,
|
||||
|
@ -595,10 +600,8 @@ class ModelWorker(worker_base.Worker):
|
|||
constants.trial_name(),
|
||||
f"dataset_indices_{dp_rank}.npy",
|
||||
)
|
||||
if (
|
||||
hasattr(self.__dataset, "filter")
|
||||
and not request.data["first_poll"]
|
||||
and os.path.exists(eval_scores_path)
|
||||
if hasattr(self.__dataset, "filter") and os.path.exists(
|
||||
eval_scores_path
|
||||
):
|
||||
# Don't filter dataset on the first poll after recover.
|
||||
with open(eval_scores_path, "r", encoding="utf-8") as f:
|
||||
|
@ -621,9 +624,7 @@ class ModelWorker(worker_base.Worker):
|
|||
generator=g,
|
||||
)
|
||||
self.__data_generator = enumerate(self.__dataloader)
|
||||
|
||||
# Fetch.
|
||||
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
|
||||
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
|
||||
|
||||
# Defer data that has not been used in the previous epoch.
|
||||
data_loaded = []
|
||||
|
@ -707,9 +708,6 @@ class ModelWorker(worker_base.Worker):
|
|||
# e.g., data transfer, parameter syncrhonization.
|
||||
pass
|
||||
elif request.handle_name == "initialize":
|
||||
assert not self.__model_is_handle[
|
||||
request.handler.model_name
|
||||
], request.handler.model_name
|
||||
self.__models[request.handler.model_name] = self._backend.initialize(
|
||||
self._model, data
|
||||
)
|
||||
|
|
|
@ -0,0 +1,437 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Hashable, List, Set, Tuple
|
||||
|
||||
import wandb
|
||||
|
||||
import realhf.api.core.config as config_api
|
||||
import realhf.api.core.data_api as data_api
|
||||
import realhf.api.core.dfg as dfg
|
||||
import realhf.api.core.system_api as config_pkg
|
||||
import realhf.base.recover as recover
|
||||
import realhf.system.request_reply_stream as request_reply_stream
|
||||
from realhf.api.core.config import ModelName
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import constants, logging, topology
|
||||
from realhf.system.buffer import AsyncIOSequenceBuffer
|
||||
from realhf.system.flops_counter import FlopsCounter
|
||||
|
||||
logger = logging.getLogger(__name__, "system")
|
||||
blogger = logging.getLogger("benchmark")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RPCCorountineControl:
|
||||
# for counting the number of finished training steps
|
||||
# one training step corresponds to traversal of the whole DFG
|
||||
train_count: asyncio.Queue
|
||||
# For flushing requests
|
||||
topo_level_count: asyncio.Queue
|
||||
|
||||
# for training data management and data cleaning after each step
|
||||
ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
|
||||
flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
|
||||
|
||||
data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
should_save: bool = False
|
||||
should_eval: bool = False
|
||||
should_ckpt: bool = False
|
||||
step_info: recover.StepInfo = dataclasses.field(default_factory=recover.StepInfo)
|
||||
|
||||
# recover information
|
||||
used_hash_vals_this_epoch: List[int] = dataclasses.field(default_factory=list)
|
||||
hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
class FunctionCall:
|
||||
def __init__(
|
||||
self,
|
||||
rpc: dfg.MFCDef,
|
||||
src_rpc: dfg.MFCDef,
|
||||
stream: request_reply_stream.NameResolvingRequestClient,
|
||||
msid2mwid: Dict[config_pkg.ModelShardID, int],
|
||||
model_topos: Dict[str, topology.PipeModelDataParallelTopology],
|
||||
model_configs: Dict[str, None | ReaLModelConfig],
|
||||
ctrl: RPCCorountineControl,
|
||||
buffer: AsyncIOSequenceBuffer,
|
||||
):
|
||||
|
||||
self.rpc = rpc
|
||||
self.src_rpc = src_rpc
|
||||
self.stream = stream
|
||||
|
||||
self.msid2mwid = msid2mwid
|
||||
self.model_topos = model_topos
|
||||
self.model_configs = model_configs
|
||||
|
||||
self.model_save_root = os.path.join(
|
||||
constants.MODEL_SAVE_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
)
|
||||
|
||||
self.rpc_ctrl = ctrl
|
||||
|
||||
self.buffer = buffer
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.model_topos[self.rpc.model_name].get_dim("data")
|
||||
|
||||
@property
|
||||
def pp_size(self):
|
||||
return self.model_topos[self.rpc.model_name].get_dim("pipe")
|
||||
|
||||
def attach_payloads_with_hooks(
|
||||
self,
|
||||
payloads: Dict[config_api.ModelShardID, request_reply_stream.Payload],
|
||||
mwids: List[int],
|
||||
main_handlers: List[config_pkg.ModelShardID],
|
||||
hook_type: str,
|
||||
) -> Tuple[Dict[config_api.ModelShardID, request_reply_stream.Payload], List[int]]:
|
||||
assert hook_type in ["pre", "post"], hook_type
|
||||
|
||||
rpc = self.rpc
|
||||
model_topos = self.model_topos
|
||||
model_configs = self.model_configs
|
||||
|
||||
main_mwids = set([self.msid2mwid[h] for h in main_handlers])
|
||||
for hook in getattr(rpc, f"_{hook_type}_hooks"):
|
||||
if isinstance(hook, dfg.ParamReallocHook):
|
||||
assert (hook.source is None) != (hook.target is None), hook
|
||||
if hook.source is None:
|
||||
src_topo = model_topos[rpc.model_name]
|
||||
dst_topo = model_topos[hook.target]
|
||||
dst_config = model_configs[hook.target]
|
||||
src_model_name, dst_model_name = rpc.model_name, hook.target
|
||||
other_model_name = hook.target
|
||||
other_topo = dst_topo
|
||||
else:
|
||||
src_topo = model_topos[hook.source]
|
||||
dst_topo = model_topos[rpc.model_name]
|
||||
dst_config = model_configs[rpc.model_name]
|
||||
src_model_name, dst_model_name = hook.source, rpc.model_name
|
||||
other_model_name = hook.source
|
||||
other_topo = src_topo
|
||||
|
||||
ps_data = {
|
||||
"from_model_name": src_model_name,
|
||||
"to_model_name": dst_model_name,
|
||||
"from_topo": src_topo,
|
||||
"to_topo": dst_topo,
|
||||
"to_model_config": dst_config,
|
||||
"eta": hook.eta,
|
||||
}
|
||||
for h in main_handlers:
|
||||
getattr(payloads[h], f"{hook_type}_hooks").append("param_realloc")
|
||||
getattr(payloads[h], f"{hook_type}_hook_data").append(ps_data)
|
||||
other_handlers = [
|
||||
config_api.ModelShardID.from_parallelism_rank(
|
||||
other_model_name, other_topo, j
|
||||
)
|
||||
for j in range(other_topo.world_size())
|
||||
]
|
||||
for h in other_handlers:
|
||||
if self.msid2mwid[h] not in mwids:
|
||||
payloads[h] = request_reply_stream.Payload(
|
||||
handler=h,
|
||||
handle_name="empty",
|
||||
)
|
||||
setattr(payloads[h], f"{hook_type}_hooks", ["param_realloc"])
|
||||
setattr(payloads[h], f"{hook_type}_hook_data", [ps_data])
|
||||
mwids.append(self.msid2mwid[h])
|
||||
elif self.msid2mwid[h] not in main_mwids:
|
||||
hh = next(
|
||||
hh
|
||||
for hh in payloads
|
||||
if self.msid2mwid[hh] == self.msid2mwid[h]
|
||||
)
|
||||
getattr(payloads[hh], f"{hook_type}_hooks").append(
|
||||
"param_realloc"
|
||||
)
|
||||
getattr(payloads[hh], f"{hook_type}_hook_data").append(ps_data)
|
||||
|
||||
elif isinstance(hook, dfg.OffloadHook):
|
||||
for h in main_handlers:
|
||||
getattr(payloads[h], f"{hook_type}_hooks").append("offload")
|
||||
getattr(payloads[h], f"{hook_type}_hook_data").append(
|
||||
dict(model_name=h.model_name)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown hook type: {hook}")
|
||||
return payloads, mwids
|
||||
|
||||
def request(
|
||||
self,
|
||||
producer_names: Dict[str, str],
|
||||
producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]],
|
||||
producer_mappings: Dict[str, Dict[str, List[int]]],
|
||||
target_mapping: Dict[str, List[int]],
|
||||
meta_sample: data_api.SequenceSample,
|
||||
handlers: List[config_pkg.ModelShardID],
|
||||
) -> Tuple[List[uuid.UUID], List[uuid.UUID]]:
|
||||
|
||||
rpc = self.rpc
|
||||
ctrl = self.rpc_ctrl
|
||||
|
||||
dt_data = {
|
||||
"keys": rpc.input_keys,
|
||||
"target": rpc.model_name,
|
||||
"producer_names": producer_names,
|
||||
"producer_mappings": producer_mappings,
|
||||
"target_mapping": target_mapping,
|
||||
"handle_name": rpc.interface_type.value,
|
||||
"rpc_name": rpc.name,
|
||||
"meta_sample": meta_sample,
|
||||
}
|
||||
|
||||
payloads = {
|
||||
handler: request_reply_stream.Payload(
|
||||
handler=handler,
|
||||
handle_name=rpc.interface_type.value,
|
||||
pre_hooks=["data_transfer"],
|
||||
pre_hook_data=[dt_data],
|
||||
data=rpc.name,
|
||||
)
|
||||
for handler in handlers
|
||||
}
|
||||
if ctrl.should_eval:
|
||||
for p in payloads.values():
|
||||
p.post_hooks.append("evaluate")
|
||||
p.post_hook_data.append(dict(model_name=rpc.model_name))
|
||||
if (
|
||||
ctrl.should_save or ctrl.should_ckpt
|
||||
) and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
|
||||
for p in payloads.values():
|
||||
p.post_hooks.append("save")
|
||||
save_dir = os.path.join(
|
||||
self.model_save_root,
|
||||
rpc.model_name.role,
|
||||
f"epoch{ctrl.step_info.epoch + 1}"
|
||||
f"epochstep{ctrl.step_info.epoch_step + 1}"
|
||||
f"globalstep{ctrl.step_info.global_step + 1}",
|
||||
)
|
||||
p.post_hook_data.append(
|
||||
dict(
|
||||
model_name=rpc.model_name,
|
||||
save_dir=save_dir,
|
||||
recover_only=not ctrl.should_save,
|
||||
)
|
||||
)
|
||||
mwids = [self.msid2mwid[h] for h in handlers]
|
||||
assert len(mwids) == len(set(mwids))
|
||||
|
||||
for producer_name in producer_names.values():
|
||||
for h in producer_name2producer_handlers[producer_name]:
|
||||
if self.msid2mwid[h] not in mwids:
|
||||
payloads[h] = request_reply_stream.Payload(
|
||||
handler=h,
|
||||
handle_name="empty",
|
||||
pre_hooks=["data_transfer"],
|
||||
pre_hook_data=[dt_data],
|
||||
)
|
||||
mwids.append(self.msid2mwid[h])
|
||||
|
||||
payloads, mwids = self.attach_payloads_with_hooks(
|
||||
payloads,
|
||||
mwids,
|
||||
main_handlers=handlers,
|
||||
hook_type="pre",
|
||||
)
|
||||
payloads, mwids = self.attach_payloads_with_hooks(
|
||||
payloads,
|
||||
mwids,
|
||||
main_handlers=handlers,
|
||||
hook_type="post",
|
||||
)
|
||||
main_payloads = [p for h, p in payloads.items() if h in handlers]
|
||||
other_payloads = [p for h, p in payloads.items() if h not in handlers]
|
||||
all_req_ids = self.stream.request(
|
||||
payloads=main_payloads + other_payloads,
|
||||
)
|
||||
return all_req_ids[: len(main_payloads)], all_req_ids[len(main_payloads) :]
|
||||
|
||||
def data_parallel_dispatch(
|
||||
self, buf_indices: List[int], sample: data_api.SequenceSample
|
||||
) -> Tuple[List[int], data_api.SequenceSample, List[Tuple[int, int]]]:
|
||||
# Dispatch data to different data parallel ranks.
|
||||
if self.rpc.is_generate():
|
||||
# The workload of generation is decided by batch size, instead of the generated length.
|
||||
samples, forward_indices, _ = sample.split_with_lengths(
|
||||
mb_spec=data_api.MicroBatchSpec(n_mbs=self.dp_size),
|
||||
lens=[1 for _ in range(sample.bs)],
|
||||
)
|
||||
else:
|
||||
samples, forward_indices, _ = sample.split(
|
||||
data_api.MicroBatchSpec(n_mbs=self.dp_size)
|
||||
)
|
||||
blogger.info(
|
||||
f"DP split (DP size {self.dp_size}) for RPC {self.rpc.name}: "
|
||||
f"#seqs: {[s.bs for s in samples]}, "
|
||||
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}"
|
||||
)
|
||||
sample = data_api.SequenceSample.gather(samples)
|
||||
buf_indices = [buf_indices[i] for i in forward_indices]
|
||||
|
||||
partitions = data_api.SequenceSplitSpec(
|
||||
sizes=[s.bs for s in samples]
|
||||
).partitions
|
||||
return buf_indices, sample, partitions
|
||||
|
||||
async def run_step(self, buf_indices, sample):
|
||||
rpc = self.rpc
|
||||
topo = self.model_topos[rpc.model_name]
|
||||
ctrl = self.rpc_ctrl
|
||||
|
||||
handlers = [
|
||||
config_pkg.ModelShardID.from_parallelism_rank(rpc.model_name, topo, j)
|
||||
for j in range(topo.world_size())
|
||||
]
|
||||
|
||||
producer_names = {} # data key -> model name
|
||||
for k in rpc.input_keys:
|
||||
if k in rpc.data_producers:
|
||||
producer_names[k] = rpc.data_producers[k]
|
||||
else:
|
||||
producer_names[k] = self.src_rpc.model_name
|
||||
keys_to_send = defaultdict(list) # model name -> List[keys] to send
|
||||
for k in producer_names:
|
||||
keys_to_send[producer_names[k]].append(k)
|
||||
|
||||
# convert producer model name to ModelShardID
|
||||
producer_name2producer_handlers = {}
|
||||
for producer_name in keys_to_send:
|
||||
producer_name2producer_handlers[producer_name] = [
|
||||
config_pkg.ModelShardID.from_parallelism_rank(
|
||||
producer_name, self.model_topos[producer_name], j
|
||||
)
|
||||
for j in range(self.model_topos[producer_name].world_size())
|
||||
]
|
||||
|
||||
dp_head_indices = [
|
||||
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
|
||||
for i in range(self.dp_size)
|
||||
]
|
||||
|
||||
ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])
|
||||
|
||||
# logger.info(f"Model rpc {rpc.name} requesting.")
|
||||
|
||||
# Sample may be reordered here.
|
||||
buf_indices, sample, partitions = self.data_parallel_dispatch(
|
||||
buf_indices, sample
|
||||
)
|
||||
target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}
|
||||
|
||||
# Set data owner of produced data by this RPC, such that downstream RPCs can know
|
||||
# where to fetch these data.
|
||||
for dp_idx, (st, ed) in enumerate(partitions):
|
||||
for i in range(st, ed):
|
||||
for k in rpc.output_keys:
|
||||
self.rpc_ctrl.data_owner[sample.ids[i], k] = (
|
||||
rpc.model_name,
|
||||
dp_idx,
|
||||
)
|
||||
|
||||
# Get the data owner of this RPC's input data.
|
||||
# We use it to determine the source of data transfer.
|
||||
producer_mappings = {}
|
||||
for k in rpc.input_keys:
|
||||
names, dp_indices = [], []
|
||||
for sample_id in sample.ids:
|
||||
owner_name, dp_idx = self.rpc_ctrl.data_owner[(sample_id, k)]
|
||||
names.append(owner_name)
|
||||
dp_indices.append(dp_idx)
|
||||
assert len(set(names)) == 1
|
||||
producer_mapping = defaultdict(list)
|
||||
for i, dp_idx in enumerate(dp_indices):
|
||||
producer_mapping[dp_idx].append(i)
|
||||
producer_mapping = {k: sorted(v) for k, v in producer_mapping.items()}
|
||||
producer_mappings[names[0], k] = producer_mapping
|
||||
|
||||
# send partitioned data to model workers
|
||||
req_ids, other_req_ids = self.request(
|
||||
producer_names=producer_names,
|
||||
producer_name2producer_handlers=producer_name2producer_handlers,
|
||||
producer_mappings=producer_mappings,
|
||||
target_mapping=target_mapping,
|
||||
meta_sample=sample,
|
||||
handlers=handlers,
|
||||
)
|
||||
tik = time.perf_counter()
|
||||
|
||||
await ctrl.topo_level_count.put(1)
|
||||
logger.info(f"Model rpc {rpc.name} requested.")
|
||||
|
||||
# Then, wait for all main requests to finish.
|
||||
responses = await self.stream.gather_async(request_ids=req_ids)
|
||||
# logger.info(f"rpc {rpc.name} received responses {req_ids}")
|
||||
|
||||
# Filter out responses other than DP heads.
|
||||
# Other repsonses are duplicated or None.
|
||||
responses = [responses[i] for i in dp_head_indices]
|
||||
|
||||
# If the returned data is a SequenceSample, it is the data returned by
|
||||
# model function calls. The data shoulbe be amended into buffer.
|
||||
# Otherwise, it's the train statistics and should be reduced and logged.
|
||||
if isinstance(responses[-1], data_api.SequenceSample):
|
||||
res = data_api.SequenceSample.gather(responses)
|
||||
else:
|
||||
res = data_api.gather_stat(responses)
|
||||
|
||||
if rpc.log_return_value:
|
||||
logger.info(f"RPC name {rpc.name} returns {res}")
|
||||
|
||||
if isinstance(res, Dict):
|
||||
wandb.log(res, step=ctrl.step_info.global_step)
|
||||
|
||||
logger.info(
|
||||
f"Model rpc {rpc.name} finished. "
|
||||
f"Run time {time.perf_counter() - tik:.4f}s."
|
||||
)
|
||||
|
||||
# If this RPC is the final node in the dataflow graph,
|
||||
# update the train counter.
|
||||
# Otherwise, amend data in the buffer.
|
||||
if rpc.is_dst:
|
||||
ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
|
||||
await ctrl.train_count.put(1)
|
||||
else:
|
||||
logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
|
||||
await self.buffer.amend_batch(buf_indices, res.unpack())
|
||||
|
||||
# Wait for all side-effect requests to finish.
|
||||
# Side-effect or empty requests are required for data transfer
|
||||
# and parameter synchronization.
|
||||
# Wait them after the main request to log the oorrect MFC time.
|
||||
await self.stream.gather_async(other_req_ids)
|
||||
|
||||
async def run(self):
|
||||
rpc = self.rpc
|
||||
topo = self.model_topos[rpc.model_name]
|
||||
ctrl = self.rpc_ctrl
|
||||
|
||||
logger.info(
|
||||
f"Running Model RPC, interface_type=#{rpc.interface_type}# "
|
||||
f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
|
||||
)
|
||||
|
||||
consumed = 0
|
||||
while True:
|
||||
buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc)
|
||||
|
||||
await self.run_step(buf_indices, sample)
|
||||
consumed += sample.bs
|
||||
|
||||
# Ensure that parent RPCs will not be over-consumed.
|
||||
if all(consumed >= c.n_seqs for c in rpc.all_successors()):
|
||||
break
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
import asyncio
|
||||
import random
|
||||
from typing import *
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.api.core.data_api import DataBatchMeta, SequenceSample
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import logging
|
||||
from realhf.base.topology import PipeModelDataParallelTopology
|
||||
from realhf.system.buffer import AsyncIOSequenceBuffer
|
||||
from realhf.system.request_reply_stream import NameResolvingRequestClient
|
||||
from realhf.system.v2.function_call import FunctionCall, RPCCorountineControl
|
||||
|
||||
logger = logging.getLogger(__name__, "system")
|
||||
blogger = logging.getLogger("benchmark")
|
||||
|
||||
|
||||
class FunctionExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
rpcs: List[MFCDef],
|
||||
msid2mwid: Dict[ModelShardID, int],
|
||||
stream: NameResolvingRequestClient,
|
||||
buffer: AsyncIOSequenceBuffer,
|
||||
model_topos: Dict[str, PipeModelDataParallelTopology],
|
||||
model_configs: Dict[str, None | ReaLModelConfig],
|
||||
ctrl: RPCCorountineControl,
|
||||
):
|
||||
|
||||
self.func_calls: Dict[str, FunctionCall] = {}
|
||||
self.ctrl = ctrl
|
||||
|
||||
self.n_model_workers = len(set(msid2mwid.values()))
|
||||
|
||||
self.rpcs = rpcs
|
||||
self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0]
|
||||
self.src_dp_size = model_topos[self.src_rpc.model_name].get_dim("data")
|
||||
|
||||
# Create model function calls.
|
||||
for rpc in self.rpcs:
|
||||
func_call = FunctionCall(
|
||||
rpc=rpc,
|
||||
src_rpc=self.src_rpc,
|
||||
stream=stream,
|
||||
msid2mwid=msid2mwid,
|
||||
model_topos=model_topos,
|
||||
model_configs=model_configs,
|
||||
ctrl=ctrl,
|
||||
buffer=buffer,
|
||||
)
|
||||
self.func_calls[rpc.name] = func_call
|
||||
|
||||
self.stream = stream
|
||||
self.buffer = buffer
|
||||
|
||||
# Sort all MFCs in the topological order and
|
||||
# calculate the width of each level.
|
||||
# These numbers will determine when to flush MFC requests.
|
||||
self.topo_widths = []
|
||||
for generation in nx.topological_generations(rpcs[0]._G):
|
||||
self.topo_widths.append(len(generation))
|
||||
|
||||
def get_leaf_tasks(self) -> List[str]:
|
||||
dst_rpcs = list(filter(lambda rpc: rpc.is_dst, self.rpcs))
|
||||
return [rpc.name for rpc in dst_rpcs]
|
||||
|
||||
async def flush_calls(self):
|
||||
for level, w in enumerate(self.topo_widths):
|
||||
for _ in range(w):
|
||||
await self.ctrl.topo_level_count.get()
|
||||
logger.info(f"DFG level {level}. Flushing {w} function calls.")
|
||||
self.stream.request(
|
||||
handlers=list(range(self.n_model_workers)), handle_type="flush"
|
||||
)
|
||||
|
||||
async def finish_traverse(self):
|
||||
for _ in range(len(self.get_leaf_tasks())):
|
||||
await self.ctrl.train_count.get()
|
||||
|
||||
async def load_data(self):
|
||||
src_rpc = self.src_rpc
|
||||
src_rpc_model_name = src_rpc.model_name
|
||||
buffer = self.buffer
|
||||
ctrl = self.ctrl
|
||||
|
||||
dp_idx = -1
|
||||
received_ids = set()
|
||||
|
||||
while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
|
||||
|
||||
all_data = []
|
||||
|
||||
dp_idx += 1
|
||||
dp_idx %= self.src_dp_size
|
||||
|
||||
resps = await self.stream.call_async(
|
||||
handlers=[f"__data{dp_idx}__"],
|
||||
handle_type="fetch",
|
||||
datas=[None],
|
||||
verbose=False,
|
||||
)
|
||||
x: DataBatchMeta | None = resps[0]
|
||||
|
||||
if x is None:
|
||||
continue
|
||||
if x.meta_sample is None:
|
||||
continue
|
||||
|
||||
# Store the owner information of the data.
|
||||
# RPCs corountines will use this information to
|
||||
# determine the src and dst of data transfer.
|
||||
for xx in x.meta_sample.unpack():
|
||||
if xx.ids[0] in received_ids:
|
||||
raise ValueError(
|
||||
f"Duplicate data id {xx.ids[0]}. Is the final batch? {is_final_batch}."
|
||||
)
|
||||
received_ids.add(xx.ids[0])
|
||||
for k in xx.keys:
|
||||
self.ctrl.data_owner[(xx.ids[0], k)] = (
|
||||
src_rpc_model_name,
|
||||
dp_idx,
|
||||
)
|
||||
all_data += x.meta_sample.unpack()
|
||||
|
||||
filtered_data = []
|
||||
for xx in x.meta_sample.unpack():
|
||||
if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
|
||||
ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
|
||||
ctrl.ids_to_clear.add(xx.ids[0])
|
||||
else:
|
||||
filtered_data.append(xx)
|
||||
all_data = filtered_data
|
||||
|
||||
# We load data in a round-robin manner across different DP ranks,
|
||||
# so we also need to shuffle the data to fuse different dataset splits.
|
||||
random.shuffle(all_data)
|
||||
|
||||
# Store into buffer!
|
||||
buffer_indices = await buffer.put_batch(all_data)
|
||||
assert len(buffer_indices) == len(all_data)
|
||||
|
||||
blogger.info(
|
||||
f"Master worker loaded {len(all_data)} pieces of data. "
|
||||
f"Remaining number of data to ignore: {len(self.ctrl.hash_vals_to_ignore_in_recover)}. "
|
||||
f"Current buffer size: {buffer.size}/{buffer.max_size}. "
|
||||
)
|
||||
|
||||
def execute_step(self):
|
||||
logger.info("Waiting for the finish of the execution graph.")
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
|
||||
loop.create_task(self.flush_calls()),
|
||||
loop.create_task(self.load_data()),
|
||||
]
|
||||
|
||||
completion_future = loop.create_task(self.finish_traverse())
|
||||
|
||||
loop.run_until_complete(completion_future)
|
||||
|
||||
for task in tasks:
|
||||
loop.run_until_complete(task)
|
||||
|
||||
logger.info("Execution finished!")
|
|
@ -0,0 +1,501 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
import colorama
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import uvloop
|
||||
import wandb
|
||||
|
||||
import realhf.api.core.dfg as dfg
|
||||
import realhf.api.core.model_api as model_api
|
||||
import realhf.api.core.system_api as config_pkg
|
||||
import realhf.base.recover as recover
|
||||
import realhf.system.request_reply_stream as request_reply_stream
|
||||
import realhf.system.worker_base as worker_base
|
||||
from realhf.api.core.config import ModelName
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import constants, logging, name_resolve, names, timeutil, topology
|
||||
from realhf.system.buffer import AsyncIOSequenceBuffer
|
||||
from realhf.system.v2.function_call import RPCCorountineControl
|
||||
from realhf.system.v2.function_executor import FunctionExecutor
|
||||
|
||||
logger = logging.getLogger("master worker", "system")
|
||||
blogger = logging.getLogger("benchmark")
|
||||
|
||||
uvloop.install()
|
||||
|
||||
|
||||
class MasterWorker(worker_base.Worker):
|
||||
global_exp_tik = time.perf_counter()
|
||||
|
||||
def _configure(self, config: config_pkg.MasterWorker):
|
||||
self.config = config
|
||||
|
||||
self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
|
||||
config.model_topos
|
||||
)
|
||||
|
||||
# Build execution graph and initialize concurrency utilities.
|
||||
self.__model_rpcs = config.model_rpcs
|
||||
|
||||
# Sort all MFCs in the topological order and
|
||||
# calculate the width of each level.
|
||||
# These numbers will determine when to flush MFC requests.
|
||||
self.__topo_widths = []
|
||||
for generation in nx.topological_generations(self.__model_rpcs[0]._G):
|
||||
self.__topo_widths.append(len(generation))
|
||||
logger.info("Topological widths: " + str(self.__topo_widths))
|
||||
|
||||
self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs))
|
||||
self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs))
|
||||
|
||||
# Save and eval control.
|
||||
self.__total_train_epochs = config.exp_ctrl.total_train_epochs
|
||||
self.__save_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=config.exp_ctrl.save_freq_epochs,
|
||||
freq_step=config.exp_ctrl.save_freq_steps,
|
||||
freq_sec=config.exp_ctrl.save_freq_secs,
|
||||
)
|
||||
if (
|
||||
config.exp_ctrl.ckpt_freq_epochs is None
|
||||
and config.exp_ctrl.ckpt_freq_steps is None
|
||||
and config.exp_ctrl.ckpt_freq_secs is None
|
||||
):
|
||||
self.__ckpt_ctl = self.__save_ctl
|
||||
else:
|
||||
self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=config.exp_ctrl.ckpt_freq_epochs,
|
||||
freq_step=config.exp_ctrl.ckpt_freq_steps,
|
||||
freq_sec=config.exp_ctrl.ckpt_freq_secs,
|
||||
)
|
||||
self.__eval_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=config.exp_ctrl.eval_freq_epochs,
|
||||
freq_step=config.exp_ctrl.eval_freq_steps,
|
||||
freq_sec=config.exp_ctrl.eval_freq_secs,
|
||||
)
|
||||
|
||||
self.MODEL_SAVE_ROOT = os.path.join(
|
||||
constants.MODEL_SAVE_ROOT,
|
||||
config.worker_info.experiment_name,
|
||||
config.worker_info.trial_name,
|
||||
)
|
||||
os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True)
|
||||
|
||||
self.__initialized = False
|
||||
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
||||
if self.__recover_info is not None:
|
||||
logger.info(
|
||||
f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
|
||||
f"last_step_info={self.__recover_info.last_step_info}."
|
||||
)
|
||||
logger.info(
|
||||
f"Number of used data in recover info: {len(self.__recover_info.hash_vals_to_ignore)}. "
|
||||
f"The previous experiment probably ran for {len(self.__recover_info.hash_vals_to_ignore) // self.__rpc_srcs[0].n_seqs} steps in the epoch."
|
||||
)
|
||||
|
||||
# Create corountine control objects for running the dataflow graph.
|
||||
self.__rpc_ctrl = RPCCorountineControl(
|
||||
train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
|
||||
topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
|
||||
# NOTE: We should accumulate the used data hashes in the same epoch
|
||||
# to prevent loading data used before.
|
||||
used_hash_vals_this_epoch=(
|
||||
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
|
||||
if self.__recover_run
|
||||
else list()
|
||||
),
|
||||
hash_vals_to_ignore_in_recover=(
|
||||
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
|
||||
if self.__recover_run
|
||||
else list()
|
||||
),
|
||||
)
|
||||
|
||||
if self.__recover_run:
|
||||
self.__rpc_ctrl.step_info = copy.deepcopy(self.__recover_info.recover_start)
|
||||
|
||||
self.__eval_ctl.load_state_dict(self.__recover_info.eval_ctl_info)
|
||||
self.__save_ctl.load_state_dict(self.__recover_info.save_ctl_info)
|
||||
self.__ckpt_ctl.load_state_dict(self.__recover_info.ckpt_ctl_info)
|
||||
|
||||
logger.info(
|
||||
f"Recovering from previous run. "
|
||||
f"Epoch: {self.__rpc_ctrl.step_info.epoch + 1}, "
|
||||
f"Epoch Step: {self.__rpc_ctrl.step_info.epoch_step + 1} "
|
||||
f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}."
|
||||
)
|
||||
|
||||
# for benchmark
|
||||
self.e2e_time_history = []
|
||||
self.__benchmark_steps = config.exp_ctrl.benchmark_steps
|
||||
|
||||
return config.worker_info
|
||||
|
||||
def initialize_models(self):
|
||||
# Initialize model backends.
|
||||
model_names = list(self.__model_topos.keys())
|
||||
self.logger.info(f"Initialize model backends with order: {model_names}.")
|
||||
train_rpcs = list(
|
||||
filter(
|
||||
lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
|
||||
self.__model_rpcs,
|
||||
)
|
||||
)
|
||||
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
|
||||
if len(train_rpcs) > 0:
|
||||
ft_spec = model_api.FinetuneSpec(
|
||||
total_train_epochs=self.config.exp_ctrl.total_train_epochs,
|
||||
dataset_size=self._dataset_size,
|
||||
train_batch_size=train_rpcs[0].n_seqs,
|
||||
)
|
||||
else:
|
||||
ft_spec = model_api.FinetuneSpec(
|
||||
total_train_epochs=self.config.exp_ctrl.total_train_epochs,
|
||||
dataset_size=self._dataset_size,
|
||||
train_batch_size=self.__src_rpc.n_seqs,
|
||||
)
|
||||
_initialized_roles = []
|
||||
for model_name in model_names:
|
||||
topo = self.config.model_topos[model_name]
|
||||
# Build FinetuneSpec, which is required to initialize backends.
|
||||
_handlers = [
|
||||
config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
|
||||
for j in range(topo.world_size())
|
||||
]
|
||||
|
||||
init_payloads = [
|
||||
request_reply_stream.Payload(
|
||||
handler=_h,
|
||||
handle_name="initialize",
|
||||
data=ft_spec,
|
||||
)
|
||||
for _h in _handlers
|
||||
]
|
||||
|
||||
# Send initialization requests then immediately flush them.
|
||||
self.__stream.request(
|
||||
payloads=init_payloads,
|
||||
)
|
||||
self.__stream.request(
|
||||
handlers=_handlers,
|
||||
handle_type="flush",
|
||||
no_syn=True,
|
||||
)
|
||||
|
||||
_initialized_roles.append(model_name.role)
|
||||
|
||||
self._ft_spec = ft_spec
|
||||
logger.info("Initializations of models and backends complete.")
|
||||
|
||||
def get_dataset_model_info(self):
|
||||
src_rpc = self.__rpc_srcs[0]
|
||||
src_rpc_topo = self.config.model_topos[src_rpc.model_name]
|
||||
src_rpc_dp_size = src_rpc_topo.get_dim("data")
|
||||
|
||||
# Request training specification from data workers.
|
||||
all_data = sum(
|
||||
self.__stream.call(
|
||||
handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
|
||||
datas=[None for i in range(src_rpc_dp_size)],
|
||||
handle_type="spec",
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
# NOTE: For dynamic datasets, we still count epoch according to the initial number of data,
|
||||
# such that the learning rate decay is not affected.
|
||||
seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data]
|
||||
self._dataset_size = len(all_data)
|
||||
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
|
||||
self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch
|
||||
self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data]
|
||||
|
||||
# Request model configs from model workers.
|
||||
# Return None if the model is not a ReaLModel.
|
||||
self.__model_configs: Dict[ModelName, None | ReaLModelConfig] = {}
|
||||
for model_name, topo in self.config.model_topos.items():
|
||||
h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, 0)
|
||||
self.__model_configs[model_name] = self.__stream.call(
|
||||
handlers=[h],
|
||||
datas=[None],
|
||||
handle_type="model_config",
|
||||
)[0]
|
||||
|
||||
def __lazy_init(self):
|
||||
# Set up streams.
|
||||
handler_routing = copy.deepcopy(self.config.msid2mwid)
|
||||
src_rpc = self.__rpc_srcs[0]
|
||||
src_rpc_topo = self.config.model_topos[src_rpc.model_name]
|
||||
src_rpc_dp_size = src_rpc_topo.get_dim("data")
|
||||
src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
|
||||
for i in range(src_rpc_dp_size):
|
||||
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
|
||||
handler_routing[f"__data{i}__"] = self.config.msid2mwid[
|
||||
config_pkg.ModelShardID.from_parallelism_rank(
|
||||
model_name=src_rpc.model_name,
|
||||
topo=src_rpc_topo,
|
||||
parallelism_rank=rank,
|
||||
)
|
||||
]
|
||||
handler_routing.update({i: i for i in range(self.config.n_model_workers)})
|
||||
self.__stream = request_reply_stream.make_master_stream(
|
||||
self.config.worker_info,
|
||||
n_subscribers=self.config.n_model_workers,
|
||||
handler_routing=handler_routing,
|
||||
)
|
||||
self.__stream: request_reply_stream.NameResolvingRequestClient
|
||||
|
||||
self.__src_rpc = src_rpc = [
|
||||
rpc for rpc in self.config.model_rpcs if rpc.is_src
|
||||
][0]
|
||||
|
||||
self.get_dataset_model_info()
|
||||
|
||||
self.initialize_models()
|
||||
|
||||
self.__seqbuffer = AsyncIOSequenceBuffer(
|
||||
self.__model_rpcs,
|
||||
max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
|
||||
)
|
||||
|
||||
# Create coroutines for model RPCs.
|
||||
logger.info(f"Creating asyncio coroutines...")
|
||||
self.func_executor = FunctionExecutor(
|
||||
rpcs=self.__model_rpcs,
|
||||
msid2mwid=self.config.msid2mwid,
|
||||
stream=self.__stream,
|
||||
buffer=self.__seqbuffer,
|
||||
model_topos=self.__model_topos,
|
||||
model_configs=self.__model_configs,
|
||||
ctrl=self.__rpc_ctrl,
|
||||
)
|
||||
logger.info(f"Coroutines created. The master worker is ready to run.")
|
||||
|
||||
# wandb init, connect to remote wandb host
|
||||
wandb.login()
|
||||
wandb.init(
|
||||
mode=self.wandb_config.mode,
|
||||
entity=self.wandb_config.entity,
|
||||
project=self.wandb_config.project or constants.experiment_name(),
|
||||
name=self.wandb_config.name or constants.trial_name(),
|
||||
job_type=self.wandb_config.job_type,
|
||||
group=self.wandb_config.group,
|
||||
notes=self.wandb_config.notes,
|
||||
tags=self.wandb_config.tags,
|
||||
config=self.wandb_config.config,
|
||||
dir=os.path.join(
|
||||
constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
|
||||
),
|
||||
force=True,
|
||||
resume="allow",
|
||||
settings=wandb.Settings(start_method="fork"),
|
||||
)
|
||||
|
||||
self.__initialized = True
|
||||
self._train_start_time = time.perf_counter()
|
||||
|
||||
self.__last_step_info = recover.StepInfo(
|
||||
epoch=-1,
|
||||
epoch_step=-1,
|
||||
global_step=-1,
|
||||
)
|
||||
|
||||
def _poll(self):
|
||||
is_new_epoch = False
|
||||
|
||||
if not self.__initialized:
|
||||
self.__lazy_init()
|
||||
|
||||
# Main execution steps. The graph runs under-the-hood in RPC & stream threads.
|
||||
# Wait for the finish of the traversal of the execution graph.
|
||||
execution_start = time.perf_counter()
|
||||
if self.__rpc_ctrl.ids_to_clear:
|
||||
# Send clear cache requests to model workers.
|
||||
# Clearing the data used in the last step.
|
||||
self._clear_gpu_cache()
|
||||
|
||||
is_new_epoch = self._ft_spec.is_new_epoch(self.__rpc_ctrl.step_info)
|
||||
is_epoch_last_step = self._ft_spec.is_epoch_last_step(self.__rpc_ctrl.step_info)
|
||||
|
||||
# Check whether we should evaluate or save models.
|
||||
self.__rpc_ctrl.should_eval = self.__eval_ctl.check(
|
||||
epochs=int(is_epoch_last_step), steps=1
|
||||
)
|
||||
self.__rpc_ctrl.should_save = self.__save_ctl.check(
|
||||
epochs=int(is_epoch_last_step), steps=1
|
||||
)
|
||||
self.__rpc_ctrl.should_ckpt = self.__ckpt_ctl.check(
|
||||
epochs=int(is_epoch_last_step), steps=1
|
||||
)
|
||||
|
||||
# Traverse over the dataflow graph for once.
|
||||
self.func_executor.execute_step()
|
||||
|
||||
# Post-process.
|
||||
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
|
||||
self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
|
||||
|
||||
self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)
|
||||
|
||||
if is_epoch_last_step:
|
||||
self.__rpc_ctrl.used_hash_vals_this_epoch = (
|
||||
self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
|
||||
)
|
||||
|
||||
if is_new_epoch:
|
||||
self.__rpc_ctrl.step_info.epoch += 1
|
||||
self.__rpc_ctrl.step_info.epoch_step = 0
|
||||
|
||||
# Logging.
|
||||
time_since_configure = time.perf_counter() - self._train_start_time
|
||||
e2e_time = time.perf_counter() - execution_start
|
||||
self.e2e_time_history.append(e2e_time)
|
||||
|
||||
self._log_training_stats(e2e_time, time_since_configure)
|
||||
|
||||
# Updata counters.
|
||||
self.__rpc_ctrl.step_info.epoch_step += 1
|
||||
self.__rpc_ctrl.step_info.global_step += 1
|
||||
|
||||
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
|
||||
self.__recover_save()
|
||||
|
||||
# Pause the worker if experiment or system-wise benchmark completes.
|
||||
if (
|
||||
self.__benchmark_steps is not None
|
||||
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
|
||||
) or (
|
||||
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
|
||||
>= self.__total_train_epochs * self._dataset_size
|
||||
):
|
||||
# We don't know whether it is the last step of the current epoch,
|
||||
# so we exit at the first step of the next epoch.
|
||||
if self.__benchmark_steps is not None:
|
||||
logger.info(
|
||||
f"Finished benchmark {self.__benchmark_steps}. "
|
||||
f"Time consumption of this setup: {time_since_configure:.3f}"
|
||||
)
|
||||
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
|
||||
return self.experiment_complete_exit()
|
||||
|
||||
return worker_base.PollResult(sample_count=1, batch_count=1)
    def _log_training_stats(self, e2e_time: float, time_since_configure: float):
        # Calculate FLOPs.
        #########################################
        if not all(
            isinstance(v, ReaLModelConfig) for v in self.__model_configs.values()
        ):
            logger.warning(
                "Not all models are ReaLModels. Unable to calculate FLOP/s."
            )
            flops = None
            tflops_per_gpu = float("inf")
        else:
            flops = self.__rpc_ctrl.flops_counter.get_flops()
            tflops = flops / (e2e_time * (10**12))
            tflops_per_gpu = flops / (
                e2e_time * self.config.n_model_workers * (10**12)
            )
            self.__rpc_ctrl.flops_counter.clear()
        #########################################

        epoch = self.__rpc_ctrl.step_info.epoch + 1
        epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
        global_step = self.__rpc_ctrl.step_info.global_step + 1
        s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
        s += f"step {epoch_step}/{self._steps_per_epoch} "
        s += f"(global step {global_step}) finishes. "
        s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. "
        s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
        s += f"Total time consumption: {time_since_configure:.3f}s. "
        if len(self.e2e_time_history) > 2:
            remaining_steps = self._steps_per_epoch - epoch_step
            remaining_epochs = self.__total_train_epochs - epoch
            avg_t = np.mean(self.e2e_time_history[2:])
            remain_t = avg_t * remaining_steps
            remain_t += avg_t * self._steps_per_epoch * remaining_epochs
            s += f"Estimated remaining time: {remain_t:.3f}s. "
        if flops is not None:
            s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
        logger.info(s)
        logger.info(
            f"Time taken so far across all configurations: "
            f"{time.perf_counter() - self.global_exp_tik:.2f}s"
        )
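As a sanity check on the formula above, a small sketch with hypothetical FLOP counts and timings (not part of the diff):

# Hypothetical values, for illustration only.
flops = 3.6e15          # total FLOPs counted for this step across all models
e2e_time = 25.0         # end-to-end step time in seconds
n_model_workers = 8     # number of GPUs (one model worker per GPU)

tflops = flops / (e2e_time * 10**12)                            # 144.00 total TFLOP/s
tflops_per_gpu = flops / (e2e_time * n_model_workers * 10**12)  # 18.00 per GPU

print(f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}")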
    def _clear_gpu_cache(self):
        self.__stream.request(
            handlers=list(range(self.config.n_model_workers)),
            handle_type="clear_data_cache",
            datas=[
                self.__rpc_ctrl.ids_to_clear
                for _ in range(self.config.n_model_workers)
            ],
            no_syn=True,
        )
        self.__rpc_ctrl.ids_to_clear.clear()

    def experiment_complete_exit(self):
        logger.info(
            colorama.Style.RESET_ALL
            + colorama.Fore.YELLOW
            + colorama.Style.BRIGHT
            + "\033[1m"
            + "Experiment Completes! Yeah!!!!!!!!"
            + colorama.Style.RESET_ALL
        )

        # Send requests to pause model workers.
        # Model workers will not respond to this message.
        self.__stream.request(
            handlers=list(range(self.config.n_model_workers)),
            handle_type="reset",
            datas=[None for _ in range(self.config.n_model_workers)],
        )
        self.__stream.close()
        constants.reset_run()
        # Reset names used for distributed training.
        # The next round of training will set up a new distributed environment.
        name_resolve.clear_subtree(
            names.distributed_root(constants.experiment_name(), constants.trial_name())
        )
        name_resolve.clear_subtree(
            names.request_reply_stream_root(
                constants.experiment_name(), constants.trial_name()
            )
        )

        wandb.finish()
        gc.collect()
        self.__initialized = False
        self.pause()
        return worker_base.PollResult(0, 0)

    def __recover_save(self):
        # Save step info for recovery.
        if os.getenv("REAL_SAVE_RECOVER_STATES", "0") != "1":
            return
        this_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
        recover_info = recover.RecoverInfo(
            recover_start=this_step_info,
            last_step_info=self.__last_step_info,
            save_ctl_info=self.__save_ctl.state_dict(),
            ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
            eval_ctl_info=self.__eval_ctl.state_dict(),
            hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
        )

        recover.dump_recover_info(recover_info)
        logger.info("Dumped recover info to file.")
        logger.info(f"Will recover from: {recover_info.recover_start}")
        logger.info(
            f"Number of data used in this epoch: {len(recover_info.hash_vals_to_ignore)}"
        )
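A small sketch (hypothetical values, not part of the diff) of how `used_hash_vals_this_epoch` is trimmed at an epoch boundary before being stored as `hash_vals_to_ignore`, so a recovery run presumably only skips data from the unfinished epoch:

# Hypothetical sizes, for illustration only.
dataset_size = 4
used_hash_vals_this_epoch = ["a", "b", "c", "d", "e", "f"]  # 1 full epoch + 2 extra

# At the last step of an epoch, the first `dataset_size` entries (the completed
# epoch) are dropped; only the remainder is kept for the recover info.
used_hash_vals_this_epoch = used_hash_vals_this_epoch[dataset_size:]
assert used_hash_vals_this_epoch == ["e", "f"]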
@ -13,11 +13,7 @@ import time
from typing import Any, Dict, List, Optional, Tuple

import realhf.api.core.system_api as system_api
from realhf.base import (
    logging,
    name_resolve,
    names,
)
from realhf.base import logging, name_resolve, names
from realhf.base.gpu_utils import set_cuda_device

logger = logging.getLogger("worker")
@ -47,6 +47,8 @@ def math_dataset(request, save_path):
    return dataset


# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
    "dp,pp,mp",
    [
@ -66,6 +68,7 @@ def test_ppo_symm(
    dp,
    pp,
    mp,
    use_v2_worker,
):
    # Setup experiment env. Should be done before any other operations.
    log_root = tmp_path_factory.mktemp("ppo")
@ -117,8 +120,10 @@ def test_ppo_symm(
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg)
    run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)


# The global resharding strategy, where all MFCs
@ -238,6 +243,8 @@ def test_ppo_global_reshard(
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg)
@ -351,6 +358,8 @@ def test_ppo_param_realloc_sub_device_mesh(
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg)
@ -470,6 +479,8 @@ def test_ppo_save(
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg)
@ -38,6 +38,8 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
    test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp)


# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
    "dp,pp,tp",
    [
@ -47,7 +49,9 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
        (1, 1, 2),
    ],
)
def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
def test_sft(
    tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker
):

    # Setup experiment env. Should be done before any other operations.
    log_root = tmp_path_factory.mktemp("sft")
@ -79,4 +83,4 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
        ),
    )

    run_test_exp(exp_cfg)
    run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)
@ -44,13 +44,22 @@ def run_model_worker(cfg, mw, barrier):
        initd = True


def run_test_exp(exp_cfg: Experiment, expr_name=None, trial_name=None):
def run_test_exp(
    exp_cfg: Experiment,
    expr_name=None,
    trial_name=None,
    use_v2_worker: bool = False,
):
    constants.set_force_cpu(True)
    # Register all datasets and models
    import realhf.impl.dataset  # isort: skip
    import realhf.impl.model  # isort: skip
    from realhf.api.core import system_api
    from realhf.system.master_worker import MasterWorker

    if not use_v2_worker:
        from realhf.system.master_worker import MasterWorker
    else:
        from realhf.system.v2.master_worker import MasterWorker

    system_api.ALL_EXPERIMENT_CLASSES = {}
    register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)
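As a usage sketch (not part of the diff), a new test in the same module could opt into the v2 worker via the parametrized flag, assuming a hypothetical `my_exp_cfg` fixture that builds a config the same way as the PPO/SFT tests above:

import pytest


# NOTE: the v1 and v2 workers cannot be exercised in the same test session,
# so the flag is parametrized with a single value, as in the tests above.
@pytest.mark.parametrize("use_v2_worker", [True])
def test_v2_master_worker(use_v2_worker, my_exp_cfg):
    # run_test_exp is the helper defined above; with use_v2_worker=True it
    # imports MasterWorker from realhf.system.v2.master_worker.
    run_test_exp(my_exp_cfg, use_v2_worker=use_v2_worker)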