PullRequest: 6 Implement master worker v2 for refactoring and uvloop support

Merge branch fw/uvloop of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/6

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* fw/fix-dataloading-not-shuffle
* .
* .
* .
* .
* .
* add v2 master worker
* cpu test pass
* ppo run
* .
* format
* fix
* merge and format
* change default env vars to v1 worker
This commit is contained in:
博惟 2025-03-07 11:07:54 +08:00
parent ad5baa74ee
commit c1bb4770ff
18 changed files with 1298 additions and 80 deletions

View File

@ -828,3 +828,22 @@ def make_dataset(
logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")
return dataset
def gather_stat(src: List[Dict]) -> Dict:
cnt, stats = {}, {}
for reply in src:
for k, v in reply.items():
cnt[k] = cnt.get(k, 0) + 1
stats[k] = stats.get(k, 0) + v
res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())}
for k, c in cnt.items():
if c != len(src):
logger.warning(f"Gathered `{k}` is not present in every returned stats.")
for k, v in res.items():
if any(abs(v - x.get(k, None)) > 1e-4 for x in src):
logger.warning(
f"Gathered `{k}` is not all-reduced "
f"before returning: ({[x.get(k, None) for x in src]}, {v})."
)
return res
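For context, gather_stat averages the per-DP-rank statistics returned by a model function call; a small sketch with made-up values:

# Hypothetical statistics returned by two DP ranks for one train step:
replies = [{"loss": 0.50, "grad_norm": 1.2}, {"loss": 0.70, "grad_norm": 1.2}]
# gather_stat(replies) averages each key over the replies:
#   {"loss": 0.6, "grad_norm": 1.2}
# and warns that "loss" differs across replies by more than 1e-4,
# i.e. it was not all-reduced before being returned.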

View File

@ -161,6 +161,7 @@ def main_start(args, recover_count: int = 0):
REAL_RECOVER_RUN="1" if is_recover_run else "0",
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
)
for k, v in BASE_ENVIRONS.items():
os.environ[k] = v

View File

@ -31,8 +31,8 @@ from realhf.api.core.system_api import (
ModelWorker,
Scheduling,
TasksGroup,
WandBConfig,
TensorBoardConfig,
WandBConfig,
)
from realhf.api.quickstart.device_mesh import (
DeviceMesh,
@ -56,6 +56,7 @@ from realhf.experiments.common.check import (
)
from realhf.experiments.common.utils import (
AllocationMode,
asdict,
get_topo,
make_inf_backend_config,
make_train_backend_config,
@ -203,7 +204,9 @@ class CommonExperimentConfig(Experiment):
partition: str = "dev"
schedule_strategy: str = "empty_first"
wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig)
tensorboard: TensorBoardConfig = dataclasses.field(default_factory=TensorBoardConfig)
tensorboard: TensorBoardConfig = dataclasses.field(
default_factory=TensorBoardConfig
)
image_name: Optional[str] = None
recover_mode: str = "disabled"
recover_retries: int = 1
@ -561,9 +564,7 @@ class CommonExperimentConfig(Experiment):
and self.gen_device_mesh.mapping[i, j]
):
gen_rpc_alloc = next(
alloc
for alloc in rpc_allocs
if alloc.rpc.interface_type == ModelInterfaceType.GENERATE
alloc for alloc in rpc_allocs if alloc.rpc.is_generate()
)
model_name = gen_rpc_alloc.rpc.model_name
topo = get_topo(
@ -620,6 +621,10 @@ class CommonExperimentConfig(Experiment):
model_rpc_allocs,
) in model_name_to_rpc_allocs.items():
rpcs = [rpc_alloc.rpc for rpc_alloc in model_rpc_allocs]
if self._allocation_mode.is_decoupled() and all(
rpc.is_generate() for rpc in rpcs
):
continue
rpc_alloc = model_rpc_allocs[0]
model_cfg = self.models[model_name.role]
model = get_real_model_config(
@ -677,9 +682,7 @@ class CommonExperimentConfig(Experiment):
rpc.is_generate() for rpc in rpcs
):
assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
vllm_dict_args: Dict[str, Any] = OmegaConf.to_container(
model_cfg.vllm, resolve=True
)
vllm_dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
backend = ModelBackendAbstraction(
"vllm",
args=dict(
@ -689,6 +692,17 @@ class CommonExperimentConfig(Experiment):
)
else:
backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
"vllm",
"sglang",
]:
print(rpcs, model_name, backend.type_)
raise ValueError(
"vLLM or SGLang is not enabled for generation. "
"This behavior has been deprecated. "
"Please set model.vllm.hybrid_train=True "
"or model.sglang.hybrid_train=True."
)
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
if mapping[i, j]:
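Configs that trip this new check can opt in the same way the updated tests later in this PR do, by enabling hybrid generation on the actor's vLLM config (sketch only; exp_cfg stands for a PPOMATHConfig-style experiment config):

exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True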

View File

@ -23,7 +23,11 @@ from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
from realhf.experiments.common.utils import (
asdict,
resolve_replica_ids,
resolve_rpc_hooks,
)
logger = logging.getLogger("PPO Math exp", "colored")
@ -283,11 +287,7 @@ class PPOMATHConfig(CommonExperimentConfig):
# It is used for unifying the profiling API, which requires passing
# external interface configurations in the launch command.
# Customized dataclass objects will not work in that case.
"generation_config": (
OmegaConf.to_container(self.ppo.gen, resolve=True)
if isinstance(self.ppo.gen, (OmegaConf, DictConfig))
else dataclasses.asdict(self.ppo.gen)
),
"generation_config": asdict(self.ppo.gen),
"early_stop_imp_ratio": self.ppo.early_stop_imp_ratio,
"adv_norm": self.ppo.adv_norm,
"group_size": self.group_size,

View File

@ -10,7 +10,7 @@ import re
from typing import *
import numpy as np
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from realhf.api.core.config import (
ModelBackendAbstraction,
@ -317,3 +317,9 @@ class AllocationMode:
return
allocs[k] = v["*"]
return allocs
def asdict(cfg):
if isinstance(cfg, (OmegaConf, DictConfig)):
return OmegaConf.to_container(cfg, resolve=True)
return dataclasses.asdict(cfg)
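A short usage sketch of the new asdict helper (GenConfig is a hypothetical dataclass used only for illustration):

import dataclasses
from omegaconf import OmegaConf

@dataclasses.dataclass
class GenConfig:
    max_new_tokens: int = 512
    top_p: float = 1.0

asdict(GenConfig())                      # plain dataclass -> {"max_new_tokens": 512, "top_p": 1.0}
asdict(OmegaConf.structured(GenConfig))  # DictConfig -> the same plain dict, resolved by OmegaConf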

View File

@ -32,7 +32,7 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet):
Used for testing only.
"""
optimizer: torch.optim.Optimizer
optim: torch.optim.Optimizer
def _exec_backward_pass(
self,
@ -78,14 +78,19 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet):
micro_batch_id: int,
step_id: int,
):
self.optimizer.step()
self.optim.step()
class AdamWithLossScale(torch.optim.Adam):
def get_loss_scale(self) -> torch.Tensor:
return torch.tensor([1.0], device=constants.current_device())
class MockTrainEngine(model_api.PipelinableEngine):
def __init__(self, module: ReaLModel, optimizer: torch.optim.Optimizer):
def __init__(self, module: ReaLModel, optimizer: AdamWithLossScale):
self.module = module
self.optimizer = optimizer
self.optim = optimizer
self.inf_engine = PipelinableInferenceEngine(module)
if constants.pipe_parallel_world_size() > 1:
@ -111,10 +116,10 @@ class MockTrainEngine(model_api.PipelinableEngine):
token_normalize_scope: str,
version_steps: int,
):
self.optimizer.zero_grad()
self.optim.zero_grad()
if constants.pipe_parallel_world_size() > 1:
# Fusing the minibatched forward-backward in a pipeline training schedule.
instr_set = MockPipeTrainInstrSet(self.optimizer)
instr_set = MockPipeTrainInstrSet(self, self.optim)
# NOTE: When training with pipeline parallel, num micro batches should be
# larger than 2 x num_pipeline_stages to avoid idle time.
return self.pipe_runner.train_batch(
@ -222,7 +227,7 @@ class MockTrainBackend(model_api.ModelBackend):
raise ValueError("MegatronTrainBackend only supports ReaLModel.")
if self.optimizer_name == "adam":
optimizer = torch.optim.Adam(module.parameters(), **self.optimizer_config)
optimizer = AdamWithLossScale(module.parameters(), **self.optimizer_config)
else:
raise NotImplementedError(
f"Optimizer {self.optimizer_name} not implemented for testing."

View File

@ -15,6 +15,7 @@ logger = logging.getLogger("system")
# NOTE: Workers are configured in the following order.
# Take special care when adding a new worker type.
WORKER_TYPES = ["model_worker", "master_worker"]
USE_V2_WORKER = os.getenv("REAL_USE_V2_WORKER", "0") == "1"
def load_worker(worker_type: str) -> Type:
@ -25,6 +26,8 @@ def load_worker(worker_type: str) -> Type:
def worker_type_to_module(worker_type: str):
if worker_type == "master_worker" and USE_V2_WORKER:
return "realhf.system.v2." + worker_type
return "realhf.system." + worker_type

View File

@ -147,6 +147,7 @@ class AsyncIOSequenceBuffer:
rpcs: List[dfg.MFCDef],
max_size: int,
):
self.rpcs = rpcs
self._lock = asyncio.Condition(asyncio.Lock())
# Buffer indicators, should be locked by self._lock.
@ -269,6 +270,62 @@ class AsyncIOSequenceBuffer:
)
return indices
async def put_batch(self, samples: List[SequenceSample]):
n = len(samples)
if n == 0:
return np.array([], dtype=np.int64)
async with self._lock:
self._assert_valid_indicator()
indices = np.where(self._is_empty)[0][:n]
if len(indices) < n:
raise BufferFull(
"You are probably using a large dataset. "
"The default buffer size 1M is not large enough. "
"Please set a larger buffer size by setting "
"the environment variable, e.g., REAL_MASTER_BUFFER_SIZE=3000000."
)
self._is_empty[indices] = False
self._is_being_put[indices] = True
self.__buffer.put_batch(indices, samples)
# Add a slight offset to the birth times so that the ordering
# is deterministic.
self._birth_time[indices] = time.monotonic_ns() + np.arange(
len(indices), dtype=np.int64
)
async with self._lock:
self.__buffer._update_has_keys(indices)
has_keys = self.__buffer._get_has_keys(indices) # [bs, #keys]
rpc_key_mask = self._rpc_key_mask # [#keys, #rpcs]
self._ready_for_rpcs[indices] = (
has_keys[:, :, None] >= rpc_key_mask[None, :, :]
).all(axis=1)
self._is_being_put[indices] = False
self._is_idle[indices] = True
self._buf_size += len(samples)
if self._buf_size >= 0.95 * self.__max_size:
logger.warning(
f"Buffer is 95% full. The current buffer size is {self._buf_size} "
f"while the maximum size is {self.__max_size}. "
f"If your dataset has more than 1M sequences, consider enlarge "
f"the default batch size in the master worker."
)
can_do_rpcs = {rpc.name: self._can_do_rpc(rpc) for rpc in self.rpcs}
logger.info(f"After putting batch, can do RPCs? {can_do_rpcs}.")
self._lock.notify(len(self._rpc_names))
return indices
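The readiness mask set above is a broadcast comparison between per-sample key availability and per-RPC key requirements; a standalone numpy sketch with made-up shapes:

import numpy as np

has_keys = np.array([[1, 1, 0],
                     [1, 1, 1]], dtype=bool)    # [bs=2, #keys=3]: keys each sample already has
rpc_key_mask = np.array([[1, 0],
                         [1, 1],
                         [0, 1]], dtype=bool)    # [#keys=3, #rpcs=2]: keys each RPC requires
ready_for_rpcs = (has_keys[:, :, None] >= rpc_key_mask[None, :, :]).all(axis=1)
# ready_for_rpcs[i, j] is True iff sample i has every key required by RPC j:
# array([[ True, False],
#        [ True,  True]])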
async def amend_batch(self, indices: List[int], samples: List[SequenceSample]):
async with self._lock:
await self._lock.wait_for(
@ -298,6 +355,17 @@ class AsyncIOSequenceBuffer:
if self._is_idle[indices].any():
self._lock.notify(len(self._rpc_names))
def _can_do_rpc(self, rpc: dfg.MFCDef) -> bool:
rpc_idx = self._rpc_names.index(rpc.name)
ready_indices = np.nonzero(
(self._is_idle | self._is_being_read)
& self._ready_for_rpcs[:, rpc_idx]
& ~self._completed_rpc[:, rpc_idx]
)[0]
if len(ready_indices) < rpc.n_seqs:
return False
return True
async def get_batch_for_rpc(
self, rpc: dfg.MFCDef
) -> Tuple[List[int], SequenceSample]:
@ -306,20 +374,10 @@ class AsyncIOSequenceBuffer:
)
rpc_idx = self._rpc_names.index(rpc.name)
def _can_do_rpc() -> bool:
ready_indices = np.nonzero(
(self._is_idle | self._is_being_read)
& self._ready_for_rpcs[:, rpc_idx]
& ~self._completed_rpc[:, rpc_idx]
)[0]
if len(ready_indices) < rpc.n_seqs:
return False
return True
async with self._lock:
# await self._lock.wait_for(_can_do_rpc)
while not _can_do_rpc():
while not self._can_do_rpc(rpc):
await self._lock.wait()
logger.info(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!")

View File

@ -619,6 +619,7 @@ class RayController:
REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""),
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
)
runtime_env = {
"env_vars": env_vars,

View File

@ -39,11 +39,6 @@ from realhf.base import (
timeutil,
topology,
)
from realhf.base.asyncio_utils import (
raise_asyncio_exception,
setup_run_until_complete,
teardown_run_util_complete,
)
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
@ -482,7 +477,7 @@ async def model_rpc_reply_func(
if isinstance(responses[-1], data_api.SequenceSample):
res = data_api.SequenceSample.gather(responses)
else:
res = _gather_stat(responses)
res = data_api.gather_stat(responses)
if rpc.log_return_value:
logger.info(f"RPC name {rpc.name} returns {res}")
@ -491,7 +486,9 @@ async def model_rpc_reply_func(
wandb.log(res, step=ctrl.step_info.global_step)
if summary_writer is not None:
for key, val in res.items():
summary_writer.add_scalar(f"{key}", val, ctrl.step_info.global_step)
summary_writer.add_scalar(
f"{key}", val, ctrl.step_info.global_step
)
logger.info(
f"Model rpc {rpc.name} finished. Run time {time.perf_counter() - tik:.4f}s."
@ -518,25 +515,6 @@ async def model_rpc_reply_func(
await stream.gather_async(other_req_ids)
def _gather_stat(src: List[Dict]) -> Dict:
cnt, stats = {}, {}
for reply in src:
for k, v in reply.items():
cnt[k] = cnt.get(k, 0) + 1
stats[k] = stats.get(k, 0) + v
res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())}
for k, c in cnt.items():
if c != len(src):
logger.warning(f"Gathered `{k}` is not present in every returned stats.")
for k, v in res.items():
if any(abs(v - x.get(k, None)) > 1e-4 for x in src):
logger.warning(
f"Gathered `{k}` is not all-reduced "
f"before returning: ({[x.get(k, None) for x in src]}, {v})."
)
return res
class MasterWorker(worker_base.Worker):
os.makedirs(constants.MODEL_SAVE_ROOT, exist_ok=True)
global_exp_tik = time.perf_counter()
@ -952,6 +930,9 @@ class MasterWorker(worker_base.Worker):
)
coroutine_tasks += [request_task, reply_task]
# Import here to avoid the conflict with uvloop.
from realhf.base.asyncio_utils import setup_run_until_complete
# Set up a run context for EventLoop.run_until_complete, basically copy-pasted from CPython.
# With this context, we can call the non-blocking EventLoop._run_once (similar to worker._poll).
self.__asyncio_tasks: List[asyncio.Task] = coroutine_tasks
@ -993,6 +974,9 @@ class MasterWorker(worker_base.Worker):
)
def _poll(self):
# Import here to avoid the conflict with uvloop.
from realhf.base.asyncio_utils import raise_asyncio_exception
is_new_epoch = False
first_poll = not self.__initialized
@ -1276,6 +1260,9 @@ class MasterWorker(worker_base.Worker):
self.__rpc_ctrl.ids_to_clear.clear()
def experiment_complete_exit(self):
# Import here to avoid the conflict with uvloop.
from realhf.base.asyncio_utils import teardown_run_util_complete
self.__rpc_ctrl.stop.set()
for task in self.__asyncio_tasks:
task.cancel()

View File

@ -340,6 +340,8 @@ class ModelWorker(worker_base.Worker):
for tmp_sample in self.__dataloader:
self.__raw_samples += tmp_sample.meta().unpack()
self.__data_generator = enumerate(self.__dataloader)
self.__models: Dict[ModelName, model_api.Model] = dict()
self.__model_is_handle: Dict[ModelName, bool] = dict()
self.__interfaces: Dict[ModelName, model_api.ModelInterface] = dict()
@ -581,7 +583,10 @@ class ModelWorker(worker_base.Worker):
elif request.handle_name == "fetch":
dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1))
assert self.__has_dataset
if request.data["first_batch"]:
# Fetch.
try:
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
except StopIteration:
# Upon the first fetch request, filter dataset and create dataloader.
eval_scores_path = os.path.join(
constants.MODEL_SAVE_ROOT,
@ -595,10 +600,8 @@ class ModelWorker(worker_base.Worker):
constants.trial_name(),
f"dataset_indices_{dp_rank}.npy",
)
if (
hasattr(self.__dataset, "filter")
and not request.data["first_poll"]
and os.path.exists(eval_scores_path)
if hasattr(self.__dataset, "filter") and os.path.exists(
eval_scores_path
):
# Don't filter dataset on the first poll after recover.
with open(eval_scores_path, "r", encoding="utf-8") as f:
@ -621,9 +624,7 @@ class ModelWorker(worker_base.Worker):
generator=g,
)
self.__data_generator = enumerate(self.__dataloader)
# Fetch.
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
# Defer data that has not been used in the previous epoch.
data_loaded = []
@ -707,9 +708,6 @@ class ModelWorker(worker_base.Worker):
# e.g., data transfer, parameter synchronization.
pass
elif request.handle_name == "initialize":
assert not self.__model_is_handle[
request.handler.model_name
], request.handler.model_name
self.__models[request.handler.model_name] = self._backend.initialize(
self._model, data
)

View File

@ -0,0 +1,437 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import dataclasses
import os
import time
import uuid
from collections import defaultdict
from typing import Dict, Hashable, List, Set, Tuple
import wandb
import realhf.api.core.config as config_api
import realhf.api.core.data_api as data_api
import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")
@dataclasses.dataclass
class RPCCorountineControl:
# for counting the number of finished training steps
# one training step corresponds to traversal of the whole DFG
train_count: asyncio.Queue
# For flushing requests
topo_level_count: asyncio.Queue
# for training data management and data cleaning after each step
ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]] = dataclasses.field(
default_factory=dict
)
should_save: bool = False
should_eval: bool = False
should_ckpt: bool = False
step_info: recover.StepInfo = dataclasses.field(default_factory=recover.StepInfo)
# recover information
used_hash_vals_this_epoch: List[int] = dataclasses.field(default_factory=list)
hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)
class FunctionCall:
def __init__(
self,
rpc: dfg.MFCDef,
src_rpc: dfg.MFCDef,
stream: request_reply_stream.NameResolvingRequestClient,
msid2mwid: Dict[config_pkg.ModelShardID, int],
model_topos: Dict[str, topology.PipeModelDataParallelTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
buffer: AsyncIOSequenceBuffer,
):
self.rpc = rpc
self.src_rpc = src_rpc
self.stream = stream
self.msid2mwid = msid2mwid
self.model_topos = model_topos
self.model_configs = model_configs
self.model_save_root = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
)
self.rpc_ctrl = ctrl
self.buffer = buffer
@property
def dp_size(self):
return self.model_topos[self.rpc.model_name].get_dim("data")
@property
def pp_size(self):
return self.model_topos[self.rpc.model_name].get_dim("pipe")
def attach_payloads_with_hooks(
self,
payloads: Dict[config_api.ModelShardID, request_reply_stream.Payload],
mwids: List[int],
main_handlers: List[config_pkg.ModelShardID],
hook_type: str,
) -> Tuple[Dict[config_api.ModelShardID, request_reply_stream.Payload], List[int]]:
assert hook_type in ["pre", "post"], hook_type
rpc = self.rpc
model_topos = self.model_topos
model_configs = self.model_configs
main_mwids = set([self.msid2mwid[h] for h in main_handlers])
for hook in getattr(rpc, f"_{hook_type}_hooks"):
if isinstance(hook, dfg.ParamReallocHook):
assert (hook.source is None) != (hook.target is None), hook
if hook.source is None:
src_topo = model_topos[rpc.model_name]
dst_topo = model_topos[hook.target]
dst_config = model_configs[hook.target]
src_model_name, dst_model_name = rpc.model_name, hook.target
other_model_name = hook.target
other_topo = dst_topo
else:
src_topo = model_topos[hook.source]
dst_topo = model_topos[rpc.model_name]
dst_config = model_configs[rpc.model_name]
src_model_name, dst_model_name = hook.source, rpc.model_name
other_model_name = hook.source
other_topo = src_topo
ps_data = {
"from_model_name": src_model_name,
"to_model_name": dst_model_name,
"from_topo": src_topo,
"to_topo": dst_topo,
"to_model_config": dst_config,
"eta": hook.eta,
}
for h in main_handlers:
getattr(payloads[h], f"{hook_type}_hooks").append("param_realloc")
getattr(payloads[h], f"{hook_type}_hook_data").append(ps_data)
other_handlers = [
config_api.ModelShardID.from_parallelism_rank(
other_model_name, other_topo, j
)
for j in range(other_topo.world_size())
]
for h in other_handlers:
if self.msid2mwid[h] not in mwids:
payloads[h] = request_reply_stream.Payload(
handler=h,
handle_name="empty",
)
setattr(payloads[h], f"{hook_type}_hooks", ["param_realloc"])
setattr(payloads[h], f"{hook_type}_hook_data", [ps_data])
mwids.append(self.msid2mwid[h])
elif self.msid2mwid[h] not in main_mwids:
hh = next(
hh
for hh in payloads
if self.msid2mwid[hh] == self.msid2mwid[h]
)
getattr(payloads[hh], f"{hook_type}_hooks").append(
"param_realloc"
)
getattr(payloads[hh], f"{hook_type}_hook_data").append(ps_data)
elif isinstance(hook, dfg.OffloadHook):
for h in main_handlers:
getattr(payloads[h], f"{hook_type}_hooks").append("offload")
getattr(payloads[h], f"{hook_type}_hook_data").append(
dict(model_name=h.model_name)
)
else:
raise NotImplementedError(f"Unknown hook type: {hook}")
return payloads, mwids
def request(
self,
producer_names: Dict[str, str],
producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]],
producer_mappings: Dict[str, Dict[str, List[int]]],
target_mapping: Dict[str, List[int]],
meta_sample: data_api.SequenceSample,
handlers: List[config_pkg.ModelShardID],
) -> Tuple[List[uuid.UUID], List[uuid.UUID]]:
rpc = self.rpc
ctrl = self.rpc_ctrl
dt_data = {
"keys": rpc.input_keys,
"target": rpc.model_name,
"producer_names": producer_names,
"producer_mappings": producer_mappings,
"target_mapping": target_mapping,
"handle_name": rpc.interface_type.value,
"rpc_name": rpc.name,
"meta_sample": meta_sample,
}
payloads = {
handler: request_reply_stream.Payload(
handler=handler,
handle_name=rpc.interface_type.value,
pre_hooks=["data_transfer"],
pre_hook_data=[dt_data],
data=rpc.name,
)
for handler in handlers
}
if ctrl.should_eval:
for p in payloads.values():
p.post_hooks.append("evaluate")
p.post_hook_data.append(dict(model_name=rpc.model_name))
if (
ctrl.should_save or ctrl.should_ckpt
) and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
for p in payloads.values():
p.post_hooks.append("save")
save_dir = os.path.join(
self.model_save_root,
rpc.model_name.role,
f"epoch{ctrl.step_info.epoch + 1}"
f"epochstep{ctrl.step_info.epoch_step + 1}"
f"globalstep{ctrl.step_info.global_step + 1}",
)
p.post_hook_data.append(
dict(
model_name=rpc.model_name,
save_dir=save_dir,
recover_only=not ctrl.should_save,
)
)
mwids = [self.msid2mwid[h] for h in handlers]
assert len(mwids) == len(set(mwids))
for producer_name in producer_names.values():
for h in producer_name2producer_handlers[producer_name]:
if self.msid2mwid[h] not in mwids:
payloads[h] = request_reply_stream.Payload(
handler=h,
handle_name="empty",
pre_hooks=["data_transfer"],
pre_hook_data=[dt_data],
)
mwids.append(self.msid2mwid[h])
payloads, mwids = self.attach_payloads_with_hooks(
payloads,
mwids,
main_handlers=handlers,
hook_type="pre",
)
payloads, mwids = self.attach_payloads_with_hooks(
payloads,
mwids,
main_handlers=handlers,
hook_type="post",
)
main_payloads = [p for h, p in payloads.items() if h in handlers]
other_payloads = [p for h, p in payloads.items() if h not in handlers]
all_req_ids = self.stream.request(
payloads=main_payloads + other_payloads,
)
return all_req_ids[: len(main_payloads)], all_req_ids[len(main_payloads) :]
def data_parallel_dispatch(
self, buf_indices: List[int], sample: data_api.SequenceSample
) -> Tuple[List[int], data_api.SequenceSample, List[Tuple[int, int]]]:
# Dispatch data to different data parallel ranks.
if self.rpc.is_generate():
# The workload of generation is determined by the batch size rather than the generated length.
samples, forward_indices, _ = sample.split_with_lengths(
mb_spec=data_api.MicroBatchSpec(n_mbs=self.dp_size),
lens=[1 for _ in range(sample.bs)],
)
else:
samples, forward_indices, _ = sample.split(
data_api.MicroBatchSpec(n_mbs=self.dp_size)
)
blogger.info(
f"DP split (DP size {self.dp_size}) for RPC {self.rpc.name}: "
f"#seqs: {[s.bs for s in samples]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}"
)
sample = data_api.SequenceSample.gather(samples)
buf_indices = [buf_indices[i] for i in forward_indices]
partitions = data_api.SequenceSplitSpec(
sizes=[s.bs for s in samples]
).partitions
return buf_indices, sample, partitions
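To make the dispatch bookkeeping concrete: with a DP size of 2 and five sequences split as [3, 2], the partitions returned here and the target_mapping later built from them in run_step would look like the following (made-up sizes):

partitions = [(0, 3), (3, 5)]                      # (start, end) per DP rank after the split
target_mapping = {i: list(range(st, ed)) for i, (st, ed) in enumerate(partitions)}
# target_mapping == {0: [0, 1, 2], 1: [3, 4]}, i.e. DP rank 0 receives the first
# three reordered samples and DP rank 1 receives the remaining two.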
async def run_step(self, buf_indices, sample):
rpc = self.rpc
topo = self.model_topos[rpc.model_name]
ctrl = self.rpc_ctrl
handlers = [
config_pkg.ModelShardID.from_parallelism_rank(rpc.model_name, topo, j)
for j in range(topo.world_size())
]
producer_names = {} # data key -> model name
for k in rpc.input_keys:
if k in rpc.data_producers:
producer_names[k] = rpc.data_producers[k]
else:
producer_names[k] = self.src_rpc.model_name
keys_to_send = defaultdict(list) # model name -> List[keys] to send
for k in producer_names:
keys_to_send[producer_names[k]].append(k)
# convert producer model name to ModelShardID
producer_name2producer_handlers = {}
for producer_name in keys_to_send:
producer_name2producer_handlers[producer_name] = [
config_pkg.ModelShardID.from_parallelism_rank(
producer_name, self.model_topos[producer_name], j
)
for j in range(self.model_topos[producer_name].world_size())
]
dp_head_indices = [
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
for i in range(self.dp_size)
]
ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])
# logger.info(f"Model rpc {rpc.name} requesting.")
# Sample may be reordered here.
buf_indices, sample, partitions = self.data_parallel_dispatch(
buf_indices, sample
)
target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}
# Record the owner of the data produced by this RPC, so that downstream RPCs
# know where to fetch it.
for dp_idx, (st, ed) in enumerate(partitions):
for i in range(st, ed):
for k in rpc.output_keys:
self.rpc_ctrl.data_owner[sample.ids[i], k] = (
rpc.model_name,
dp_idx,
)
# Get the data owner of this RPC's input data.
# We use it to determine the source of data transfer.
producer_mappings = {}
for k in rpc.input_keys:
names, dp_indices = [], []
for sample_id in sample.ids:
owner_name, dp_idx = self.rpc_ctrl.data_owner[(sample_id, k)]
names.append(owner_name)
dp_indices.append(dp_idx)
assert len(set(names)) == 1
producer_mapping = defaultdict(list)
for i, dp_idx in enumerate(dp_indices):
producer_mapping[dp_idx].append(i)
producer_mapping = {k: sorted(v) for k, v in producer_mapping.items()}
producer_mappings[names[0], k] = producer_mapping
# send partitioned data to model workers
req_ids, other_req_ids = self.request(
producer_names=producer_names,
producer_name2producer_handlers=producer_name2producer_handlers,
producer_mappings=producer_mappings,
target_mapping=target_mapping,
meta_sample=sample,
handlers=handlers,
)
tik = time.perf_counter()
await ctrl.topo_level_count.put(1)
logger.info(f"Model rpc {rpc.name} requested.")
# Then, wait for all main requests to finish.
responses = await self.stream.gather_async(request_ids=req_ids)
# logger.info(f"rpc {rpc.name} received responses {req_ids}")
# Filter out responses other than DP heads.
# Other responses are duplicated or None.
responses = [responses[i] for i in dp_head_indices]
# If the returned data is a SequenceSample, it is the data returned by
# model function calls. The data should be amended into the buffer.
# Otherwise, it's the train statistics and should be reduced and logged.
if isinstance(responses[-1], data_api.SequenceSample):
res = data_api.SequenceSample.gather(responses)
else:
res = data_api.gather_stat(responses)
if rpc.log_return_value:
logger.info(f"RPC name {rpc.name} returns {res}")
if isinstance(res, Dict):
wandb.log(res, step=ctrl.step_info.global_step)
logger.info(
f"Model rpc {rpc.name} finished. "
f"Run time {time.perf_counter() - tik:.4f}s."
)
# If this RPC is the final node in the dataflow graph,
# update the train counter.
# Otherwise, amend data in the buffer.
if rpc.is_dst:
ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
await ctrl.train_count.put(1)
else:
logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
await self.buffer.amend_batch(buf_indices, res.unpack())
# Wait for all side-effect requests to finish.
# Side-effect or empty requests are required for data transfer
# and parameter synchronization.
# Wait for them after the main request so that the correct MFC time is logged.
await self.stream.gather_async(other_req_ids)
async def run(self):
rpc = self.rpc
topo = self.model_topos[rpc.model_name]
ctrl = self.rpc_ctrl
logger.info(
f"Running Model RPC, interface_type=#{rpc.interface_type}# "
f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
)
consumed = 0
while True:
buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc)
await self.run_step(buf_indices, sample)
consumed += sample.bs
# Ensure that parent RPCs will not be over-consumed.
if all(consumed >= c.n_seqs for c in rpc.all_successors()):
break

View File

@ -0,0 +1,168 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import random
from typing import *
import networkx as nx
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import DataBatchMeta, SequenceSample
from realhf.api.core.dfg import MFCDef
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.request_reply_stream import NameResolvingRequestClient
from realhf.system.v2.function_call import FunctionCall, RPCCorountineControl
logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")
class FunctionExecutor:
def __init__(
self,
rpcs: List[MFCDef],
msid2mwid: Dict[ModelShardID, int],
stream: NameResolvingRequestClient,
buffer: AsyncIOSequenceBuffer,
model_topos: Dict[str, PipeModelDataParallelTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
):
self.func_calls: Dict[str, FunctionCall] = {}
self.ctrl = ctrl
self.n_model_workers = len(set(msid2mwid.values()))
self.rpcs = rpcs
self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0]
self.src_dp_size = model_topos[self.src_rpc.model_name].get_dim("data")
# Create model function calls.
for rpc in self.rpcs:
func_call = FunctionCall(
rpc=rpc,
src_rpc=self.src_rpc,
stream=stream,
msid2mwid=msid2mwid,
model_topos=model_topos,
model_configs=model_configs,
ctrl=ctrl,
buffer=buffer,
)
self.func_calls[rpc.name] = func_call
self.stream = stream
self.buffer = buffer
# Sort all MFCs in the topological order and
# calculate the width of each level.
# These numbers will determine when to flush MFC requests.
self.topo_widths = []
for generation in nx.topological_generations(rpcs[0]._G):
self.topo_widths.append(len(generation))
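A standalone sketch of this level-width computation, assuming a PPO-like graph with hypothetical MFC names:

import networkx as nx

# actor_gen feeds two inference MFCs that may run concurrently; both feed actor_train.
G = nx.DiGraph([("actor_gen", "ref_inf"), ("actor_gen", "rew_inf"),
                ("ref_inf", "actor_train"), ("rew_inf", "actor_train")])
topo_widths = [len(gen) for gen in nx.topological_generations(G)]
# topo_widths == [1, 2, 1]; flush_calls below waits for that many topo_level_count
# tokens before flushing each level's requests to the model workers.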
def get_leaf_tasks(self) -> List[str]:
dst_rpcs = list(filter(lambda rpc: rpc.is_dst, self.rpcs))
return [rpc.name for rpc in dst_rpcs]
async def flush_calls(self):
for level, w in enumerate(self.topo_widths):
for _ in range(w):
await self.ctrl.topo_level_count.get()
logger.info(f"DFG level {level}. Flushing {w} function calls.")
self.stream.request(
handlers=list(range(self.n_model_workers)), handle_type="flush"
)
async def finish_traverse(self):
for _ in range(len(self.get_leaf_tasks())):
await self.ctrl.train_count.get()
async def load_data(self):
src_rpc = self.src_rpc
src_rpc_model_name = src_rpc.model_name
buffer = self.buffer
ctrl = self.ctrl
dp_idx = -1
received_ids = set()
while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
all_data = []
dp_idx += 1
dp_idx %= self.src_dp_size
resps = await self.stream.call_async(
handlers=[f"__data{dp_idx}__"],
handle_type="fetch",
datas=[None],
verbose=False,
)
x: DataBatchMeta | None = resps[0]
if x is None:
continue
if x.meta_sample is None:
continue
# Store the owner information of the data.
# RPC coroutines will use this information to
# determine the src and dst of data transfer.
for xx in x.meta_sample.unpack():
if xx.ids[0] in received_ids:
raise ValueError(
f"Duplicate data id {xx.ids[0]}. Is the final batch? {is_final_batch}."
)
received_ids.add(xx.ids[0])
for k in xx.keys:
self.ctrl.data_owner[(xx.ids[0], k)] = (
src_rpc_model_name,
dp_idx,
)
all_data += x.meta_sample.unpack()
filtered_data = []
for xx in x.meta_sample.unpack():
if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
ctrl.ids_to_clear.add(xx.ids[0])
else:
filtered_data.append(xx)
all_data = filtered_data
# We load data in a round-robin manner across different DP ranks,
# so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data)
# Store into buffer!
buffer_indices = await buffer.put_batch(all_data)
assert len(buffer_indices) == len(all_data)
blogger.info(
f"Master worker loaded {len(all_data)} pieces of data. "
f"Remaining number of data to ignore: {len(self.ctrl.hash_vals_to_ignore_in_recover)}. "
f"Current buffer size: {buffer.size}/{buffer.max_size}. "
)
def execute_step(self):
logger.info("Waiting for the finish of the execution graph.")
loop = asyncio.get_event_loop()
tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
loop.create_task(self.flush_calls()),
loop.create_task(self.load_data()),
]
completion_future = loop.create_task(self.finish_traverse())
loop.run_until_complete(completion_future)
for task in tasks:
loop.run_until_complete(task)
logger.info("Execution finished!")

View File

@ -0,0 +1,501 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import copy
import gc
import os
import time
from typing import Dict
import colorama
import networkx as nx
import numpy as np
import uvloop
import wandb
import realhf.api.core.dfg as dfg
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, name_resolve, names, timeutil, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.v2.function_call import RPCCorountineControl
from realhf.system.v2.function_executor import FunctionExecutor
logger = logging.getLogger("master worker", "system")
blogger = logging.getLogger("benchmark")
uvloop.install()
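uvloop.install() replaces the default asyncio event-loop policy, so the loop that later drives the function-call coroutines is a uvloop loop; this is also why the v1 master worker above now imports realhf.base.asyncio_utils lazily. A minimal sketch:

import asyncio
import uvloop

uvloop.install()                  # swap the default asyncio event loop policy
loop = asyncio.new_event_loop()
print(type(loop))                 # <class 'uvloop.Loop'>
loop.close()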
class MasterWorker(worker_base.Worker):
global_exp_tik = time.perf_counter()
def _configure(self, config: config_pkg.MasterWorker):
self.config = config
self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
config.model_topos
)
# Build execution graph and initialize concurrency utilities.
self.__model_rpcs = config.model_rpcs
# Sort all MFCs in the topological order and
# calculate the width of each level.
# These numbers will determine when to flush MFC requests.
self.__topo_widths = []
for generation in nx.topological_generations(self.__model_rpcs[0]._G):
self.__topo_widths.append(len(generation))
logger.info("Topological widths: " + str(self.__topo_widths))
self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs))
self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs))
# Save and eval control.
self.__total_train_epochs = config.exp_ctrl.total_train_epochs
self.__save_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.save_freq_epochs,
freq_step=config.exp_ctrl.save_freq_steps,
freq_sec=config.exp_ctrl.save_freq_secs,
)
if (
config.exp_ctrl.ckpt_freq_epochs is None
and config.exp_ctrl.ckpt_freq_steps is None
and config.exp_ctrl.ckpt_freq_secs is None
):
self.__ckpt_ctl = self.__save_ctl
else:
self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.ckpt_freq_epochs,
freq_step=config.exp_ctrl.ckpt_freq_steps,
freq_sec=config.exp_ctrl.ckpt_freq_secs,
)
self.__eval_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.eval_freq_epochs,
freq_step=config.exp_ctrl.eval_freq_steps,
freq_sec=config.exp_ctrl.eval_freq_secs,
)
self.MODEL_SAVE_ROOT = os.path.join(
constants.MODEL_SAVE_ROOT,
config.worker_info.experiment_name,
config.worker_info.trial_name,
)
os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True)
self.__initialized = False
self.__recover_run, self.__recover_info = recover.load_recover_info()
if self.__recover_info is not None:
logger.info(
f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
f"last_step_info={self.__recover_info.last_step_info}."
)
logger.info(
f"Number of used data in recover info: {len(self.__recover_info.hash_vals_to_ignore)}. "
f"The previous experiment probably ran for {len(self.__recover_info.hash_vals_to_ignore) // self.__rpc_srcs[0].n_seqs} steps in the epoch."
)
# Create corountine control objects for running the dataflow graph.
self.__rpc_ctrl = RPCCorountineControl(
train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
# NOTE: We should accumulate the used data hashes in the same epoch
# to prevent loading data used before.
used_hash_vals_this_epoch=(
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
if self.__recover_run
else list()
),
hash_vals_to_ignore_in_recover=(
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
if self.__recover_run
else list()
),
)
if self.__recover_run:
self.__rpc_ctrl.step_info = copy.deepcopy(self.__recover_info.recover_start)
self.__eval_ctl.load_state_dict(self.__recover_info.eval_ctl_info)
self.__save_ctl.load_state_dict(self.__recover_info.save_ctl_info)
self.__ckpt_ctl.load_state_dict(self.__recover_info.ckpt_ctl_info)
logger.info(
f"Recovering from previous run. "
f"Epoch: {self.__rpc_ctrl.step_info.epoch + 1}, "
f"Epoch Step: {self.__rpc_ctrl.step_info.epoch_step + 1} "
f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}."
)
# for benchmark
self.e2e_time_history = []
self.__benchmark_steps = config.exp_ctrl.benchmark_steps
return config.worker_info
def initialize_models(self):
# Initialize model backends.
model_names = list(self.__model_topos.keys())
self.logger.info(f"Initialize model backends with order: {model_names}.")
train_rpcs = list(
filter(
lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
self.__model_rpcs,
)
)
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
if len(train_rpcs) > 0:
ft_spec = model_api.FinetuneSpec(
total_train_epochs=self.config.exp_ctrl.total_train_epochs,
dataset_size=self._dataset_size,
train_batch_size=train_rpcs[0].n_seqs,
)
else:
ft_spec = model_api.FinetuneSpec(
total_train_epochs=self.config.exp_ctrl.total_train_epochs,
dataset_size=self._dataset_size,
train_batch_size=self.__src_rpc.n_seqs,
)
_initialized_roles = []
for model_name in model_names:
topo = self.config.model_topos[model_name]
# Build FinetuneSpec, which is required to initialize backends.
_handlers = [
config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
for j in range(topo.world_size())
]
init_payloads = [
request_reply_stream.Payload(
handler=_h,
handle_name="initialize",
data=ft_spec,
)
for _h in _handlers
]
# Send initialization requests then immediately flush them.
self.__stream.request(
payloads=init_payloads,
)
self.__stream.request(
handlers=_handlers,
handle_type="flush",
no_syn=True,
)
_initialized_roles.append(model_name.role)
self._ft_spec = ft_spec
logger.info("Initializations of models and backends complete.")
def get_dataset_model_info(self):
src_rpc = self.__rpc_srcs[0]
src_rpc_topo = self.config.model_topos[src_rpc.model_name]
src_rpc_dp_size = src_rpc_topo.get_dim("data")
# Request training specification from data workers.
all_data = sum(
self.__stream.call(
handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
datas=[None for i in range(src_rpc_dp_size)],
handle_type="spec",
),
[],
)
# NOTE: For dynamic datasets, we still count epochs according to the initial dataset size,
# such that the learning rate decay is not affected.
seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data]
self._dataset_size = len(all_data)
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch
self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data]
# Request model configs from model workers.
# Return None if the model is not a ReaLModel.
self.__model_configs: Dict[ModelName, None | ReaLModelConfig] = {}
for model_name, topo in self.config.model_topos.items():
h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, 0)
self.__model_configs[model_name] = self.__stream.call(
handlers=[h],
datas=[None],
handle_type="model_config",
)[0]
def __lazy_init(self):
# Set up streams.
handler_routing = copy.deepcopy(self.config.msid2mwid)
src_rpc = self.__rpc_srcs[0]
src_rpc_topo = self.config.model_topos[src_rpc.model_name]
src_rpc_dp_size = src_rpc_topo.get_dim("data")
src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
for i in range(src_rpc_dp_size):
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
handler_routing[f"__data{i}__"] = self.config.msid2mwid[
config_pkg.ModelShardID.from_parallelism_rank(
model_name=src_rpc.model_name,
topo=src_rpc_topo,
parallelism_rank=rank,
)
]
handler_routing.update({i: i for i in range(self.config.n_model_workers)})
self.__stream = request_reply_stream.make_master_stream(
self.config.worker_info,
n_subscribers=self.config.n_model_workers,
handler_routing=handler_routing,
)
self.__stream: request_reply_stream.NameResolvingRequestClient
self.__src_rpc = src_rpc = [
rpc for rpc in self.config.model_rpcs if rpc.is_src
][0]
self.get_dataset_model_info()
self.initialize_models()
self.__seqbuffer = AsyncIOSequenceBuffer(
self.__model_rpcs,
max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
)
# Create coroutines for model RPCs.
logger.info(f"Creating asyncio coroutines...")
self.func_executor = FunctionExecutor(
rpcs=self.__model_rpcs,
msid2mwid=self.config.msid2mwid,
stream=self.__stream,
buffer=self.__seqbuffer,
model_topos=self.__model_topos,
model_configs=self.__model_configs,
ctrl=self.__rpc_ctrl,
)
logger.info(f"Coroutines created. The master worker is ready to run.")
# wandb init, connect to remote wandb host
wandb.login()
wandb.init(
mode=self.wandb_config.mode,
entity=self.wandb_config.entity,
project=self.wandb_config.project or constants.experiment_name(),
name=self.wandb_config.name or constants.trial_name(),
job_type=self.wandb_config.job_type,
group=self.wandb_config.group,
notes=self.wandb_config.notes,
tags=self.wandb_config.tags,
config=self.wandb_config.config,
dir=os.path.join(
constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
),
force=True,
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
self.__initialized = True
self._train_start_time = time.perf_counter()
self.__last_step_info = recover.StepInfo(
epoch=-1,
epoch_step=-1,
global_step=-1,
)
def _poll(self):
is_new_epoch = False
if not self.__initialized:
self.__lazy_init()
# Main execution steps. The graph runs under-the-hood in RPC & stream threads.
# Wait for the traversal of the execution graph to finish.
execution_start = time.perf_counter()
if self.__rpc_ctrl.ids_to_clear:
# Send clear cache requests to model workers.
# Clearing the data used in the last step.
self._clear_gpu_cache()
is_new_epoch = self._ft_spec.is_new_epoch(self.__rpc_ctrl.step_info)
is_epoch_last_step = self._ft_spec.is_epoch_last_step(self.__rpc_ctrl.step_info)
# Check whether we should evaluate or save models.
self.__rpc_ctrl.should_eval = self.__eval_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
self.__rpc_ctrl.should_save = self.__save_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
self.__rpc_ctrl.should_ckpt = self.__ckpt_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
# Traverse over the dataflow graph for once.
self.func_executor.execute_step()
# Post-process.
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)
if is_epoch_last_step:
self.__rpc_ctrl.used_hash_vals_this_epoch = (
self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
)
if is_new_epoch:
self.__rpc_ctrl.step_info.epoch += 1
self.__rpc_ctrl.step_info.epoch_step = 0
# Logging.
time_since_configure = time.perf_counter() - self._train_start_time
e2e_time = time.perf_counter() - execution_start
self.e2e_time_history.append(e2e_time)
self._log_training_stats(e2e_time, time_since_configure)
# Update counters.
self.__rpc_ctrl.step_info.epoch_step += 1
self.__rpc_ctrl.step_info.global_step += 1
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__recover_save()
# Pause the worker if experiment or system-wise benchmark completes.
if (
self.__benchmark_steps is not None
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
) or (
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
>= self.__total_train_epochs * self._dataset_size
):
# We don't know whether it is the last step of the current epoch,
# so we exit at the first step of the next epoch.
if self.__benchmark_steps is not None:
logger.info(
f"Finished benchmark {self.__benchmark_steps}. "
f"Time consumption of this setup: {time_since_configure:.3f}"
)
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
return self.experiment_complete_exit()
return worker_base.PollResult(sample_count=1, batch_count=1)
def _log_training_stats(self, e2e_time: float, time_since_configure: float):
# calculate flops
#########################################
if not all(
isinstance(v, ReaLModelConfig) for v in self.__model_configs.values()
):
logger.warning(
f"Not all models are ReaLModels. Unable to calculate FLOP/s."
)
flops = None
tflops_per_gpu = float("inf")
else:
flops = self.__rpc_ctrl.flops_counter.get_flops()
tflops = flops / (e2e_time * (10**12))
tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12))
self.__rpc_ctrl.flops_counter.clear()
#########################################
epoch = self.__rpc_ctrl.step_info.epoch + 1
epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
global_step = self.__rpc_ctrl.step_info.global_step + 1
s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
s += f"step {epoch_step}/{self._steps_per_epoch} "
s += f"(global step {global_step}) finishes. "
s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. "
if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch
avg_t = np.mean(self.e2e_time_history[2:])
remain_t = avg_t * remaining_steps
remain_t += avg_t * self._steps_per_epoch * remaining_epochs
s += f"Estimated remaining time: {remain_t:.3f}s. "
if flops is not None:
s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
logger.info(s)
logger.info(
f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s"
)
def _clear_gpu_cache(self):
self.__stream.request(
handlers=list(range(self.config.n_model_workers)),
handle_type="clear_data_cache",
datas=[
self.__rpc_ctrl.ids_to_clear
for _ in list(range(self.config.n_model_workers))
],
no_syn=True,
)
self.__rpc_ctrl.ids_to_clear.clear()
def experiment_complete_exit(self):
logger.info(
colorama.Style.RESET_ALL
+ colorama.Fore.YELLOW
+ colorama.Style.BRIGHT
+ "\033[1m"
+ "Experiment Completes! Yeah!!!!!!!!"
+ colorama.Style.RESET_ALL
)
# Send requests to pause model workers.
# Model workers will not respond to this message.
self.__stream.request(
handlers=list(range(self.config.n_model_workers)),
handle_type="reset",
datas=[None for _ in list(range(self.config.n_model_workers))],
)
self.__stream.close()
constants.reset_run()
# Reset names used for distributed training.
# The next round of training will set up a new distributed environment.
name_resolve.clear_subtree(
names.distributed_root(constants.experiment_name(), constants.trial_name())
)
name_resolve.clear_subtree(
names.request_reply_stream_root(
constants.experiment_name(), constants.trial_name()
)
)
wandb.finish()
gc.collect()
self.__initialized = False
self.pause()
return worker_base.PollResult(0, 0)
def __recover_save(self):
# save step info for recover
if os.getenv("REAL_SAVE_RECOVER_STATES", "0") != "1":
return
# save step info for recover
this_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
recover_info = recover.RecoverInfo(
recover_start=this_step_info,
last_step_info=self.__last_step_info,
save_ctl_info=self.__save_ctl.state_dict(),
ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
eval_ctl_info=self.__eval_ctl.state_dict(),
hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
)
recover.dump_recover_info(recover_info)
logger.info("Dumped recover info to file.")
logger.info(f"Will recover from: {recover_info.recover_start}")
logger.info(
f"Number of data used in this epoch: {len(recover_info.hash_vals_to_ignore)}"
)

View File

@ -13,11 +13,7 @@ import time
from typing import Any, Dict, List, Optional, Tuple
import realhf.api.core.system_api as system_api
from realhf.base import (
logging,
name_resolve,
names,
)
from realhf.base import logging, name_resolve, names
from realhf.base.gpu_utils import set_cuda_device
logger = logging.getLogger("worker")

View File

@ -47,6 +47,8 @@ def math_dataset(request, save_path):
return dataset
# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
"dp,pp,mp",
[
@ -66,6 +68,7 @@ def test_ppo_symm(
dp,
pp,
mp,
use_v2_worker,
):
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("ppo")
@ -117,8 +120,10 @@ def test_ppo_symm(
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg)
run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)
# The global resharding strategy, where all MFCs
@ -238,6 +243,8 @@ def test_ppo_global_reshard(
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg)
@ -351,6 +358,8 @@ def test_ppo_param_realloc_sub_device_mesh(
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg)
@ -470,6 +479,8 @@ def test_ppo_save(
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg)

View File

@ -38,6 +38,8 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp)
# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
"dp,pp,tp",
[
@ -47,7 +49,9 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
(1, 1, 2),
],
)
def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
def test_sft(
tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker
):
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("sft")
@ -79,4 +83,4 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
),
)
run_test_exp(exp_cfg)
run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)

View File

@ -44,13 +44,22 @@ def run_model_worker(cfg, mw, barrier):
initd = True
def run_test_exp(exp_cfg: Experiment, expr_name=None, trial_name=None):
def run_test_exp(
exp_cfg: Experiment,
expr_name=None,
trial_name=None,
use_v2_worker: bool = False,
):
constants.set_force_cpu(True)
# Register all datasets and models
import realhf.impl.dataset # isort: skip
import realhf.impl.model # isort: skip
from realhf.api.core import system_api
from realhf.system.master_worker import MasterWorker
if not use_v2_worker:
from realhf.system.master_worker import MasterWorker
else:
from realhf.system.v2.master_worker import MasterWorker
system_api.ALL_EXPERIMENT_CLASSES = {}
register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)