From c1bb4770ffef07d6854384bc352a1561f1c94df9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=9A=E6=83=9F?= Date: Fri, 7 Mar 2025 11:07:54 +0800 Subject: [PATCH] PullRequest: 6 Implement master worker v2 for refactoring and uvloop support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge branch fw/uvloop of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/6 Signed-off-by: 晓雷 * . * fw/fix-dataloading-not-shuffle * . * . * . * . * . * add v2 master worker * cpu test pass * ppo run * . * format * fix * merge and format * change default env vars to v1 worker --- realhf/api/core/data_api.py | 19 + realhf/apps/main.py | 1 + realhf/experiments/common/common.py | 30 +- realhf/experiments/common/ppo_math_exp.py | 12 +- realhf/experiments/common/utils.py | 8 +- realhf/impl/model/backend/mock_train.py | 19 +- realhf/system/__init__.py | 3 + realhf/system/buffer.py | 80 +++- realhf/system/controller.py | 1 + realhf/system/master_worker.py | 39 +- realhf/system/model_worker.py | 20 +- realhf/system/v2/function_call.py | 437 +++++++++++++++++++ realhf/system/v2/function_executor.py | 168 ++++++++ realhf/system/v2/master_worker.py | 501 ++++++++++++++++++++++ realhf/system/worker_base.py | 6 +- tests/experiments/test_math_ppo.py | 13 +- tests/experiments/test_sft.py | 8 +- tests/experiments/utils.py | 13 +- 18 files changed, 1298 insertions(+), 80 deletions(-) create mode 100644 realhf/system/v2/function_call.py create mode 100644 realhf/system/v2/function_executor.py create mode 100644 realhf/system/v2/master_worker.py diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py index 3cb7c2c..9bb31da 100644 --- a/realhf/api/core/data_api.py +++ b/realhf/api/core/data_api.py @@ -828,3 +828,22 @@ def make_dataset( logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s") return dataset + + +def gather_stat(src: List[Dict]) -> Dict: + cnt, stats = {}, {} + for reply in src: + for k, v in reply.items(): + cnt[k] = cnt.get(k, 0) + 1 + stats[k] = stats.get(k, 0) + v + res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())} + for k, c in cnt.items(): + if c != len(src): + logger.warning(f"Gathered `{k}` is not present in every returned stats.") + for k, v in res.items(): + if any(abs(v - x.get(k, None)) > 1e-4 for x in src): + logger.warning( + f"Gathered `{k}` is not all-reduced " + f"before returning: ({[x.get(k, None) for x in src]}, {v})." 
+ ) + return res diff --git a/realhf/apps/main.py b/realhf/apps/main.py index 6528d7f..a1ae9ca 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -161,6 +161,7 @@ def main_start(args, recover_count: int = 0): REAL_RECOVER_RUN="1" if is_recover_run else "0", REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0", REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""), + REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"), ) for k, v in BASE_ENVIRONS.items(): os.environ[k] = v diff --git a/realhf/experiments/common/common.py b/realhf/experiments/common/common.py index d601c45..6ef9705 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -31,8 +31,8 @@ from realhf.api.core.system_api import ( ModelWorker, Scheduling, TasksGroup, - WandBConfig, TensorBoardConfig, + WandBConfig, ) from realhf.api.quickstart.device_mesh import ( DeviceMesh, @@ -56,6 +56,7 @@ from realhf.experiments.common.check import ( ) from realhf.experiments.common.utils import ( AllocationMode, + asdict, get_topo, make_inf_backend_config, make_train_backend_config, @@ -203,7 +204,9 @@ class CommonExperimentConfig(Experiment): partition: str = "dev" schedule_strategy: str = "empty_first" wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig) - tensorboard: TensorBoardConfig = dataclasses.field(default_factory=TensorBoardConfig) + tensorboard: TensorBoardConfig = dataclasses.field( + default_factory=TensorBoardConfig + ) image_name: Optional[str] = None recover_mode: str = "disabled" recover_retries: int = 1 @@ -561,9 +564,7 @@ class CommonExperimentConfig(Experiment): and self.gen_device_mesh.mapping[i, j] ): gen_rpc_alloc = next( - alloc - for alloc in rpc_allocs - if alloc.rpc.interface_type == ModelInterfaceType.GENERATE + alloc for alloc in rpc_allocs if alloc.rpc.is_generate() ) model_name = gen_rpc_alloc.rpc.model_name topo = get_topo( @@ -620,6 +621,10 @@ class CommonExperimentConfig(Experiment): model_rpc_allocs, ) in model_name_to_rpc_allocs.items(): rpcs = [rpc_alloc.rpc for rpc_alloc in model_rpc_allocs] + if self._allocation_mode.is_decoupled() and all( + rpc.is_generate() for rpc in rpcs + ): + continue rpc_alloc = model_rpc_allocs[0] model_cfg = self.models[model_name.role] model = get_real_model_config( @@ -677,9 +682,7 @@ class CommonExperimentConfig(Experiment): rpc.is_generate() for rpc in rpcs ): assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs - vllm_dict_args: Dict[str, Any] = OmegaConf.to_container( - model_cfg.vllm, resolve=True - ) + vllm_dict_args: Dict[str, Any] = asdict(model_cfg.vllm) backend = ModelBackendAbstraction( "vllm", args=dict( @@ -689,6 +692,17 @@ class CommonExperimentConfig(Experiment): ) else: backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel) + if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [ + "vllm", + "sglang", + ]: + print(rpcs, model_name, backend.type_) + raise ValueError( + "vLLM or SGLang is not enabled for generation. " + "This behavior has been deprecated. " + "Please set model.vllm.hybrid_train=True " + "or model.sglang.hybrid_train=True." 
+ ) check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs) if mapping[i, j]: diff --git a/realhf/experiments/common/ppo_math_exp.py b/realhf/experiments/common/ppo_math_exp.py index 0987b3c..0be5fa3 100644 --- a/realhf/experiments/common/ppo_math_exp.py +++ b/realhf/experiments/common/ppo_math_exp.py @@ -23,7 +23,11 @@ from realhf.api.quickstart.device_mesh import MFCConfig from realhf.api.quickstart.entrypoint import register_quickstart_exp from realhf.api.quickstart.model import ModelTrainEvalConfig from realhf.experiments.common.common import CommonExperimentConfig -from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks +from realhf.experiments.common.utils import ( + asdict, + resolve_replica_ids, + resolve_rpc_hooks, +) logger = logging.getLogger("PPO Math exp", "colored") @@ -283,11 +287,7 @@ class PPOMATHConfig(CommonExperimentConfig): # It is used for unifying the profiling API, which requires to # pass external interface configurations in the launch command. # Customized dataclass objects will not work in that case. - "generation_config": ( - OmegaConf.to_container(self.ppo.gen, resolve=True) - if isinstance(self.ppo.gen, (OmegaConf, DictConfig)) - else dataclasses.asdict(self.ppo.gen) - ), + "generation_config": asdict(self.ppo.gen), "early_stop_imp_ratio": self.ppo.early_stop_imp_ratio, "adv_norm": self.ppo.adv_norm, "group_size": self.group_size, diff --git a/realhf/experiments/common/utils.py b/realhf/experiments/common/utils.py index 8de0dca..6d6ffed 100644 --- a/realhf/experiments/common/utils.py +++ b/realhf/experiments/common/utils.py @@ -10,7 +10,7 @@ import re from typing import * import numpy as np -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from realhf.api.core.config import ( ModelBackendAbstraction, @@ -317,3 +317,9 @@ class AllocationMode: return allocs[k] = v["*"] return allocs + + +def asdict(cfg): + if isinstance(cfg, (OmegaConf, DictConfig)): + return OmegaConf.to_container(cfg, resolve=True) + return dataclasses.asdict(cfg) diff --git a/realhf/impl/model/backend/mock_train.py b/realhf/impl/model/backend/mock_train.py index 8e6ff36..541cae5 100644 --- a/realhf/impl/model/backend/mock_train.py +++ b/realhf/impl/model/backend/mock_train.py @@ -32,7 +32,7 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet): Used for testing only. """ - optimizer: torch.optim.Optimizer + optim: torch.optim.Optimizer def _exec_backward_pass( self, @@ -78,14 +78,19 @@ class MockPipeTrainInstrSet(PipeTrainInstrSet): micro_batch_id: int, step_id: int, ): - self.optimizer.step() + self.optim.step() + + +class AdamWithLossScale(torch.optim.Adam): + def get_loss_scale(self) -> torch.Tensor: + return torch.tensor([1.0], device=constants.current_device()) class MockTrainEngine(model_api.PipelinableEngine): - def __init__(self, module: ReaLModel, optimizer: torch.optim.Optimizer): + def __init__(self, module: ReaLModel, optimizer: AdamWithLossScale): self.module = module - self.optimizer = optimizer + self.optim = optimizer self.inf_engine = PipelinableInferenceEngine(module) if constants.pipe_parallel_world_size() > 1: @@ -111,10 +116,10 @@ class MockTrainEngine(model_api.PipelinableEngine): token_normalize_scope: str, version_steps: int, ): - self.optimizer.zero_grad() + self.optim.zero_grad() if constants.pipe_parallel_world_size() > 1: # Fusing the minibatched forward-backward in a pipeline training schedule. 
- instr_set = MockPipeTrainInstrSet(self.optimizer) + instr_set = MockPipeTrainInstrSet(self, self.optim) # NOTE: When training with pipeline parallel, num micro batches should be # larger than 2 x num_pipeline_stages to avoid idle time. return self.pipe_runner.train_batch( @@ -222,7 +227,7 @@ class MockTrainBackend(model_api.ModelBackend): raise ValueError("MegatronTrainBackend only supports ReaLModel.") if self.optimizer_name == "adam": - optimizer = torch.optim.Adam(module.parameters(), **self.optimizer_config) + optimizer = AdamWithLossScale(module.parameters(), **self.optimizer_config) else: raise NotImplementedError( f"Optimizer {self.optimizer_name} not implemented for testing." diff --git a/realhf/system/__init__.py b/realhf/system/__init__.py index 874c088..5a2b3a7 100644 --- a/realhf/system/__init__.py +++ b/realhf/system/__init__.py @@ -15,6 +15,7 @@ logger = logging.getLogger("system") # NOTE: Workers are configured in the following order. # Take special care when adding a new worker type. WORKER_TYPES = ["model_worker", "master_worker"] +USE_V2_WORKER = os.getenv("REAL_USE_V2_WORKER", "0") == "1" def load_worker(worker_type: str) -> Type: @@ -25,6 +26,8 @@ def load_worker(worker_type: str) -> Type: def worker_type_to_module(worker_type: str): + if worker_type == "master_worker" and USE_V2_WORKER: + return "realhf.system.v2." + worker_type return "realhf.system." + worker_type diff --git a/realhf/system/buffer.py b/realhf/system/buffer.py index 6f375fa..d1db076 100644 --- a/realhf/system/buffer.py +++ b/realhf/system/buffer.py @@ -147,6 +147,7 @@ class AsyncIOSequenceBuffer: rpcs: List[dfg.MFCDef], max_size: int, ): + self.rpcs = rpcs self._lock = asyncio.Condition(asyncio.Lock()) # Buffer indicators, should be locked by self._lock. @@ -269,6 +270,62 @@ class AsyncIOSequenceBuffer: ) return indices + async def put_batch(self, samples: List[SequenceSample]): + n = len(samples) + + if n == 0: + return np.array([], dtype=np.int64) + + async with self._lock: + self._assert_valid_indicator() + + indices = np.where(self._is_empty)[0][:n] + + if len(indices) < n: + raise BufferFull( + "You are probably using a large dataset. " + "The default buffer size 1M is not large enough. " + "Please set a larger buffer size by setting " + "the environment variable, e.g., REAL_MASTER_BUFFER_SIZE=3000000." + ) + self._is_empty[indices] = False + self._is_being_put[indices] = True + + self.__buffer.put_batch(indices, samples) + + # Set a slight difference in birth time to let the order + # be deterministic. + self._birth_time[indices] = time.monotonic_ns() + np.arange( + len(indices), dtype=np.int64 + ) + + async with self._lock: + self.__buffer._update_has_keys(indices) + + has_keys = self.__buffer._get_has_keys(indices) # [bs, #keys] + rpc_key_mask = self._rpc_key_mask # [#keys, #rpcs] + self._ready_for_rpcs[indices] = ( + has_keys[:, :, None] >= rpc_key_mask[None, :, :] + ).all(axis=1) + + self._is_being_put[indices] = False + self._is_idle[indices] = True + + self._buf_size += len(samples) + if self._buf_size >= 0.95 * self.__max_size: + logger.warning( + f"Buffer is 95% full. The current buffer size is {self._buf_size} " + f"while the maximum size is {self.__max_size}. " + f"If your dataset has more than 1M sequences, consider enlarge " + f"the default batch size in the master worker." + ) + + can_do_rpcs = {rpc.name: self._can_do_rpc(rpc) for rpc in self.rpcs} + logger.info(f"After putting batch, can do RPCs? 
{can_do_rpcs}.") + + self._lock.notify(len(self._rpc_names)) + return indices + async def amend_batch(self, indices: List[int], samples: List[SequenceSample]): async with self._lock: await self._lock.wait_for( @@ -298,6 +355,17 @@ class AsyncIOSequenceBuffer: if self._is_idle[indices].any(): self._lock.notify(len(self._rpc_names)) + def _can_do_rpc(self, rpc: dfg.MFCDef) -> bool: + rpc_idx = self._rpc_names.index(rpc.name) + ready_indices = np.nonzero( + (self._is_idle | self._is_being_read) + & self._ready_for_rpcs[:, rpc_idx] + & ~self._completed_rpc[:, rpc_idx] + )[0] + if len(ready_indices) < rpc.n_seqs: + return False + return True + async def get_batch_for_rpc( self, rpc: dfg.MFCDef ) -> Tuple[List[int], SequenceSample]: @@ -306,20 +374,10 @@ class AsyncIOSequenceBuffer: ) rpc_idx = self._rpc_names.index(rpc.name) - def _can_do_rpc() -> bool: - ready_indices = np.nonzero( - (self._is_idle | self._is_being_read) - & self._ready_for_rpcs[:, rpc_idx] - & ~self._completed_rpc[:, rpc_idx] - )[0] - if len(ready_indices) < rpc.n_seqs: - return False - return True - async with self._lock: # await self._lock.wait_for(_can_do_rpc) - while not _can_do_rpc(): + while not self._can_do_rpc(rpc): await self._lock.wait() logger.info(f"Input keys ({rpc.input_keys}) for MFC {rpc.name} are ready!") diff --git a/realhf/system/controller.py b/realhf/system/controller.py index f10360d..c23f9db 100644 --- a/realhf/system/controller.py +++ b/realhf/system/controller.py @@ -619,6 +619,7 @@ class RayController: REAL_RECOVER_RUN=os.environ.get("REAL_RECOVER_RUN", ""), REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""), REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""), + REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"), ) runtime_env = { "env_vars": env_vars, diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index 371cccf..e7fff4f 100644 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -39,11 +39,6 @@ from realhf.base import ( timeutil, topology, ) -from realhf.base.asyncio_utils import ( - raise_asyncio_exception, - setup_run_until_complete, - teardown_run_util_complete, -) from realhf.system.buffer import AsyncIOSequenceBuffer from realhf.system.flops_counter import FlopsCounter @@ -482,7 +477,7 @@ async def model_rpc_reply_func( if isinstance(responses[-1], data_api.SequenceSample): res = data_api.SequenceSample.gather(responses) else: - res = _gather_stat(responses) + res = data_api.gather_stat(responses) if rpc.log_return_value: logger.info(f"RPC name {rpc.name} returns {res}") @@ -491,7 +486,9 @@ async def model_rpc_reply_func( wandb.log(res, step=ctrl.step_info.global_step) if summary_writer is not None: for key, val in res.items(): - summary_writer.add_scalar(f"{key}", val, ctrl.step_info.global_step) + summary_writer.add_scalar( + f"{key}", val, ctrl.step_info.global_step + ) logger.info( f"Model rpc {rpc.name} finished. Run time {time.perf_counter() - tik:.4f}s." 
@@ -518,25 +515,6 @@ async def model_rpc_reply_func( await stream.gather_async(other_req_ids) -def _gather_stat(src: List[Dict]) -> Dict: - cnt, stats = {}, {} - for reply in src: - for k, v in reply.items(): - cnt[k] = cnt.get(k, 0) + 1 - stats[k] = stats.get(k, 0) + v - res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())} - for k, c in cnt.items(): - if c != len(src): - logger.warning(f"Gathered `{k}` is not present in every returned stats.") - for k, v in res.items(): - if any(abs(v - x.get(k, None)) > 1e-4 for x in src): - logger.warning( - f"Gathered `{k}` is not all-reduced " - f"before returning: ({[x.get(k, None) for x in src]}, {v})." - ) - return res - - class MasterWorker(worker_base.Worker): os.makedirs(constants.MODEL_SAVE_ROOT, exist_ok=True) global_exp_tik = time.perf_counter() @@ -952,6 +930,9 @@ class MasterWorker(worker_base.Worker): ) coroutine_tasks += [request_task, reply_task] + # Import here to avoid the conflict with uvloop. + from realhf.base.asyncio_utils import setup_run_until_complete + # Set up a run context of EventLoop.run_until_complete, basically copy-paste from cpython. # With this context, we can call the non-block EventLoop._run_once (similar to worker._poll). self.__asyncio_tasks: List[asyncio.Task] = coroutine_tasks @@ -993,6 +974,9 @@ class MasterWorker(worker_base.Worker): ) def _poll(self): + # Import here to avoid the conflict with uvloop. + from realhf.base.asyncio_utils import raise_asyncio_exception + is_new_epoch = False first_poll = not self.__initialized @@ -1276,6 +1260,9 @@ class MasterWorker(worker_base.Worker): self.__rpc_ctrl.ids_to_clear.clear() def experiment_complete_exit(self): + # Import here to avoid the conflict with uvloop. + from realhf.base.asyncio_utils import teardown_run_util_complete + self.__rpc_ctrl.stop.set() for task in self.__asyncio_tasks: task.cancel() diff --git a/realhf/system/model_worker.py b/realhf/system/model_worker.py index d2f4a78..c0bf939 100644 --- a/realhf/system/model_worker.py +++ b/realhf/system/model_worker.py @@ -340,6 +340,8 @@ class ModelWorker(worker_base.Worker): for tmp_sample in self.__dataloader: self.__raw_samples += tmp_sample.meta().unpack() + self.__data_generator = enumerate(self.__dataloader) + self.__models: Dict[ModelName, model_api.Model] = dict() self.__model_is_handle: Dict[ModelName, bool] = dict() self.__interfaces: Dict[ModelName, model_api.ModelInterface] = dict() @@ -581,7 +583,10 @@ class ModelWorker(worker_base.Worker): elif request.handle_name == "fetch": dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1)) assert self.__has_dataset - if request.data["first_batch"]: + # Fetch. + try: + self.__dataset_batch_counter, cur_sample = next(self.__data_generator) + except StopIteration: # Upon the first fetch request, filter dataset and create dataloader. eval_scores_path = os.path.join( constants.MODEL_SAVE_ROOT, @@ -595,10 +600,8 @@ class ModelWorker(worker_base.Worker): constants.trial_name(), f"dataset_indices_{dp_rank}.npy", ) - if ( - hasattr(self.__dataset, "filter") - and not request.data["first_poll"] - and os.path.exists(eval_scores_path) + if hasattr(self.__dataset, "filter") and os.path.exists( + eval_scores_path ): # Don't filter dataset on the first poll after recover. with open(eval_scores_path, "r", encoding="utf-8") as f: @@ -621,9 +624,7 @@ class ModelWorker(worker_base.Worker): generator=g, ) self.__data_generator = enumerate(self.__dataloader) - - # Fetch.
- self.__dataset_batch_counter, cur_sample = next(self.__data_generator) + self.__dataset_batch_counter, cur_sample = next(self.__data_generator) # Defer data that has not been used in the previous epoch. data_loaded = [] @@ -707,9 +708,6 @@ class ModelWorker(worker_base.Worker): # e.g., data transfer, parameter syncrhonization. pass elif request.handle_name == "initialize": - assert not self.__model_is_handle[ - request.handler.model_name - ], request.handler.model_name self.__models[request.handler.model_name] = self._backend.initialize( self._model, data ) diff --git a/realhf/system/v2/function_call.py b/realhf/system/v2/function_call.py new file mode 100644 index 0000000..9eb406f --- /dev/null +++ b/realhf/system/v2/function_call.py @@ -0,0 +1,437 @@ +# Copyright 2025 Ant Group Inc. + +import asyncio +import dataclasses +import os +import time +import uuid +from collections import defaultdict +from typing import Dict, Hashable, List, Set, Tuple + +import wandb + +import realhf.api.core.config as config_api +import realhf.api.core.data_api as data_api +import realhf.api.core.dfg as dfg +import realhf.api.core.system_api as config_pkg +import realhf.base.recover as recover +import realhf.system.request_reply_stream as request_reply_stream +from realhf.api.core.config import ModelName +from realhf.api.core.model_api import ReaLModelConfig +from realhf.base import constants, logging, topology +from realhf.system.buffer import AsyncIOSequenceBuffer +from realhf.system.flops_counter import FlopsCounter + +logger = logging.getLogger(__name__, "system") +blogger = logging.getLogger("benchmark") + + +@dataclasses.dataclass +class RPCCorountineControl: + # for counting the number of finished training steps + # one training step corresponds to traversal of the whole DFG + train_count: asyncio.Queue + # For flushing requests + topo_level_count: asyncio.Queue + + # for training data management and data cleaning after each step + ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set) + flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter) + + data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]] = dataclasses.field( + default_factory=dict + ) + + should_save: bool = False + should_eval: bool = False + should_ckpt: bool = False + step_info: recover.StepInfo = dataclasses.field(default_factory=recover.StepInfo) + + # recover information + used_hash_vals_this_epoch: List[int] = dataclasses.field(default_factory=list) + hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list) + + +class FunctionCall: + def __init__( + self, + rpc: dfg.MFCDef, + src_rpc: dfg.MFCDef, + stream: request_reply_stream.NameResolvingRequestClient, + msid2mwid: Dict[config_pkg.ModelShardID, int], + model_topos: Dict[str, topology.PipeModelDataParallelTopology], + model_configs: Dict[str, None | ReaLModelConfig], + ctrl: RPCCorountineControl, + buffer: AsyncIOSequenceBuffer, + ): + + self.rpc = rpc + self.src_rpc = src_rpc + self.stream = stream + + self.msid2mwid = msid2mwid + self.model_topos = model_topos + self.model_configs = model_configs + + self.model_save_root = os.path.join( + constants.MODEL_SAVE_ROOT, + constants.experiment_name(), + constants.trial_name(), + ) + + self.rpc_ctrl = ctrl + + self.buffer = buffer + + @property + def dp_size(self): + return self.model_topos[self.rpc.model_name].get_dim("data") + + @property + def pp_size(self): + return self.model_topos[self.rpc.model_name].get_dim("pipe") + + def attach_payloads_with_hooks( + self, + 
payloads: Dict[config_api.ModelShardID, request_reply_stream.Payload], + mwids: List[int], + main_handlers: List[config_pkg.ModelShardID], + hook_type: str, + ) -> Tuple[Dict[config_api.ModelShardID, request_reply_stream.Payload], List[int]]: + assert hook_type in ["pre", "post"], hook_type + + rpc = self.rpc + model_topos = self.model_topos + model_configs = self.model_configs + + main_mwids = set([self.msid2mwid[h] for h in main_handlers]) + for hook in getattr(rpc, f"_{hook_type}_hooks"): + if isinstance(hook, dfg.ParamReallocHook): + assert (hook.source is None) != (hook.target is None), hook + if hook.source is None: + src_topo = model_topos[rpc.model_name] + dst_topo = model_topos[hook.target] + dst_config = model_configs[hook.target] + src_model_name, dst_model_name = rpc.model_name, hook.target + other_model_name = hook.target + other_topo = dst_topo + else: + src_topo = model_topos[hook.source] + dst_topo = model_topos[rpc.model_name] + dst_config = model_configs[rpc.model_name] + src_model_name, dst_model_name = hook.source, rpc.model_name + other_model_name = hook.source + other_topo = src_topo + + ps_data = { + "from_model_name": src_model_name, + "to_model_name": dst_model_name, + "from_topo": src_topo, + "to_topo": dst_topo, + "to_model_config": dst_config, + "eta": hook.eta, + } + for h in main_handlers: + getattr(payloads[h], f"{hook_type}_hooks").append("param_realloc") + getattr(payloads[h], f"{hook_type}_hook_data").append(ps_data) + other_handlers = [ + config_api.ModelShardID.from_parallelism_rank( + other_model_name, other_topo, j + ) + for j in range(other_topo.world_size()) + ] + for h in other_handlers: + if self.msid2mwid[h] not in mwids: + payloads[h] = request_reply_stream.Payload( + handler=h, + handle_name="empty", + ) + setattr(payloads[h], f"{hook_type}_hooks", ["param_realloc"]) + setattr(payloads[h], f"{hook_type}_hook_data", [ps_data]) + mwids.append(self.msid2mwid[h]) + elif self.msid2mwid[h] not in main_mwids: + hh = next( + hh + for hh in payloads + if self.msid2mwid[hh] == self.msid2mwid[h] + ) + getattr(payloads[hh], f"{hook_type}_hooks").append( + "param_realloc" + ) + getattr(payloads[hh], f"{hook_type}_hook_data").append(ps_data) + + elif isinstance(hook, dfg.OffloadHook): + for h in main_handlers: + getattr(payloads[h], f"{hook_type}_hooks").append("offload") + getattr(payloads[h], f"{hook_type}_hook_data").append( + dict(model_name=h.model_name) + ) + else: + raise NotImplementedError(f"Unknown hook type: {hook}") + return payloads, mwids + + def request( + self, + producer_names: Dict[str, str], + producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]], + producer_mappings: Dict[str, Dict[str, List[int]]], + target_mapping: Dict[str, List[int]], + meta_sample: data_api.SequenceSample, + handlers: List[config_pkg.ModelShardID], + ) -> Tuple[List[uuid.UUID], List[uuid.UUID]]: + + rpc = self.rpc + ctrl = self.rpc_ctrl + + dt_data = { + "keys": rpc.input_keys, + "target": rpc.model_name, + "producer_names": producer_names, + "producer_mappings": producer_mappings, + "target_mapping": target_mapping, + "handle_name": rpc.interface_type.value, + "rpc_name": rpc.name, + "meta_sample": meta_sample, + } + + payloads = { + handler: request_reply_stream.Payload( + handler=handler, + handle_name=rpc.interface_type.value, + pre_hooks=["data_transfer"], + pre_hook_data=[dt_data], + data=rpc.name, + ) + for handler in handlers + } + if ctrl.should_eval: + for p in payloads.values(): + p.post_hooks.append("evaluate") + 
p.post_hook_data.append(dict(model_name=rpc.model_name)) + if ( + ctrl.should_save or ctrl.should_ckpt + ) and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP: + for p in payloads.values(): + p.post_hooks.append("save") + save_dir = os.path.join( + self.model_save_root, + rpc.model_name.role, + f"epoch{ctrl.step_info.epoch + 1}" + f"epochstep{ctrl.step_info.epoch_step + 1}" + f"globalstep{ctrl.step_info.global_step + 1}", + ) + p.post_hook_data.append( + dict( + model_name=rpc.model_name, + save_dir=save_dir, + recover_only=not ctrl.should_save, + ) + ) + mwids = [self.msid2mwid[h] for h in handlers] + assert len(mwids) == len(set(mwids)) + + for producer_name in producer_names.values(): + for h in producer_name2producer_handlers[producer_name]: + if self.msid2mwid[h] not in mwids: + payloads[h] = request_reply_stream.Payload( + handler=h, + handle_name="empty", + pre_hooks=["data_transfer"], + pre_hook_data=[dt_data], + ) + mwids.append(self.msid2mwid[h]) + + payloads, mwids = self.attach_payloads_with_hooks( + payloads, + mwids, + main_handlers=handlers, + hook_type="pre", + ) + payloads, mwids = self.attach_payloads_with_hooks( + payloads, + mwids, + main_handlers=handlers, + hook_type="post", + ) + main_payloads = [p for h, p in payloads.items() if h in handlers] + other_payloads = [p for h, p in payloads.items() if h not in handlers] + all_req_ids = self.stream.request( + payloads=main_payloads + other_payloads, + ) + return all_req_ids[: len(main_payloads)], all_req_ids[len(main_payloads) :] + + def data_parallel_dispatch( + self, buf_indices: List[int], sample: data_api.SequenceSample + ) -> Tuple[List[int], data_api.SequenceSample, List[Tuple[int, int]]]: + # Dispatch data to different data parallel ranks. + if self.rpc.is_generate(): + # The workload of generation is decided by batch size, instead of the generated length. 
+ samples, forward_indices, _ = sample.split_with_lengths( + mb_spec=data_api.MicroBatchSpec(n_mbs=self.dp_size), + lens=[1 for _ in range(sample.bs)], + ) + else: + samples, forward_indices, _ = sample.split( + data_api.MicroBatchSpec(n_mbs=self.dp_size) + ) + blogger.info( + f"DP split (DP size {self.dp_size}) for RPC {self.rpc.name}: " + f"#seqs: {[s.bs for s in samples]}, " + f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}" + ) + sample = data_api.SequenceSample.gather(samples) + buf_indices = [buf_indices[i] for i in forward_indices] + + partitions = data_api.SequenceSplitSpec( + sizes=[s.bs for s in samples] + ).partitions + return buf_indices, sample, partitions + + async def run_step(self, buf_indices, sample): + rpc = self.rpc + topo = self.model_topos[rpc.model_name] + ctrl = self.rpc_ctrl + + handlers = [ + config_pkg.ModelShardID.from_parallelism_rank(rpc.model_name, topo, j) + for j in range(topo.world_size()) + ] + + producer_names = {} # data key -> model name + for k in rpc.input_keys: + if k in rpc.data_producers: + producer_names[k] = rpc.data_producers[k] + else: + producer_names[k] = self.src_rpc.model_name + keys_to_send = defaultdict(list) # model name -> List[keys] to send + for k in producer_names: + keys_to_send[producer_names[k]].append(k) + + # convert producer model name to ModelShardID + producer_name2producer_handlers = {} + for producer_name in keys_to_send: + producer_name2producer_handlers[producer_name] = [ + config_pkg.ModelShardID.from_parallelism_rank( + producer_name, self.model_topos[producer_name], j + ) + for j in range(self.model_topos[producer_name].world_size()) + ] + + dp_head_indices = [ + topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0) + for i in range(self.dp_size) + ] + + ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name]) + + # logger.info(f"Model rpc {rpc.name} requesting.") + + # Sample may be reordered here. + buf_indices, sample, partitions = self.data_parallel_dispatch( + buf_indices, sample + ) + target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)} + + # Set data owner of produced data by this RPC, such that downstream RPCs can know + # where to fetch these data. + for dp_idx, (st, ed) in enumerate(partitions): + for i in range(st, ed): + for k in rpc.output_keys: + self.rpc_ctrl.data_owner[sample.ids[i], k] = ( + rpc.model_name, + dp_idx, + ) + + # Get the data owner of this RPC's input data. + # We use it to determine the source of data transfer. + producer_mappings = {} + for k in rpc.input_keys: + names, dp_indices = [], [] + for sample_id in sample.ids: + owner_name, dp_idx = self.rpc_ctrl.data_owner[(sample_id, k)] + names.append(owner_name) + dp_indices.append(dp_idx) + assert len(set(names)) == 1 + producer_mapping = defaultdict(list) + for i, dp_idx in enumerate(dp_indices): + producer_mapping[dp_idx].append(i) + producer_mapping = {k: sorted(v) for k, v in producer_mapping.items()} + producer_mappings[names[0], k] = producer_mapping + + # send partitioned data to model workers + req_ids, other_req_ids = self.request( + producer_names=producer_names, + producer_name2producer_handlers=producer_name2producer_handlers, + producer_mappings=producer_mappings, + target_mapping=target_mapping, + meta_sample=sample, + handlers=handlers, + ) + tik = time.perf_counter() + + await ctrl.topo_level_count.put(1) + logger.info(f"Model rpc {rpc.name} requested.") + + # Then, wait for all main requests to finish. 
+ responses = await self.stream.gather_async(request_ids=req_ids) + # logger.info(f"rpc {rpc.name} received responses {req_ids}") + + # Filter out responses other than DP heads. + # Other responses are duplicated or None. + responses = [responses[i] for i in dp_head_indices] + + # If the returned data is a SequenceSample, it is the data returned by + # model function calls. The data should be amended into the buffer. + # Otherwise, it's the train statistics and should be reduced and logged. + if isinstance(responses[-1], data_api.SequenceSample): + res = data_api.SequenceSample.gather(responses) + else: + res = data_api.gather_stat(responses) + + if rpc.log_return_value: + logger.info(f"RPC name {rpc.name} returns {res}") + + if isinstance(res, Dict): + wandb.log(res, step=ctrl.step_info.global_step) + + logger.info( + f"Model rpc {rpc.name} finished. " + f"Run time {time.perf_counter() - tik:.4f}s." + ) + + # If this RPC is the final node in the dataflow graph, + # update the train counter. + # Otherwise, amend data in the buffer. + if rpc.is_dst: + ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids) + await ctrl.train_count.put(1) + else: + logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}") + await self.buffer.amend_batch(buf_indices, res.unpack()) + + # Wait for all side-effect requests to finish. + # Side-effect or empty requests are required for data transfer + # and parameter synchronization. + # Wait for them after the main request to log the correct MFC time. + await self.stream.gather_async(other_req_ids) + + async def run(self): + rpc = self.rpc + topo = self.model_topos[rpc.model_name] + ctrl = self.rpc_ctrl + + logger.info( + f"Running Model RPC, interface_type=#{rpc.interface_type}# " + f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*" + ) + + consumed = 0 + while True: + buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc) + + await self.run_step(buf_indices, sample) + consumed += sample.bs + + # Stop once enough data has been consumed for all successor RPCs. + if all(consumed >= c.n_seqs for c in rpc.all_successors()): + break diff --git a/realhf/system/v2/function_executor.py b/realhf/system/v2/function_executor.py new file mode 100644 index 0000000..d81d5a0 --- /dev/null +++ b/realhf/system/v2/function_executor.py @@ -0,0 +1,168 @@ +# Copyright 2025 Ant Group Inc.
+import asyncio +import random +from typing import * + +import networkx as nx + +from realhf.api.core.config import ModelName, ModelShardID +from realhf.api.core.data_api import DataBatchMeta, SequenceSample +from realhf.api.core.dfg import MFCDef +from realhf.api.core.model_api import ReaLModelConfig +from realhf.base import logging +from realhf.base.topology import PipeModelDataParallelTopology +from realhf.system.buffer import AsyncIOSequenceBuffer +from realhf.system.request_reply_stream import NameResolvingRequestClient +from realhf.system.v2.function_call import FunctionCall, RPCCorountineControl + +logger = logging.getLogger(__name__, "system") +blogger = logging.getLogger("benchmark") + + +class FunctionExecutor: + def __init__( + self, + rpcs: List[MFCDef], + msid2mwid: Dict[ModelShardID, int], + stream: NameResolvingRequestClient, + buffer: AsyncIOSequenceBuffer, + model_topos: Dict[str, PipeModelDataParallelTopology], + model_configs: Dict[str, None | ReaLModelConfig], + ctrl: RPCCorountineControl, + ): + + self.func_calls: Dict[str, FunctionCall] = {} + self.ctrl = ctrl + + self.n_model_workers = len(set(msid2mwid.values())) + + self.rpcs = rpcs + self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0] + self.src_dp_size = model_topos[self.src_rpc.model_name].get_dim("data") + + # Create model function calls. + for rpc in self.rpcs: + func_call = FunctionCall( + rpc=rpc, + src_rpc=self.src_rpc, + stream=stream, + msid2mwid=msid2mwid, + model_topos=model_topos, + model_configs=model_configs, + ctrl=ctrl, + buffer=buffer, + ) + self.func_calls[rpc.name] = func_call + + self.stream = stream + self.buffer = buffer + + # Sort all MFCs in the topological order and + # calculate the width of each level. + # These numbers will determine when to flush MFC requests. + self.topo_widths = [] + for generation in nx.topological_generations(rpcs[0]._G): + self.topo_widths.append(len(generation)) + + def get_leaf_tasks(self) -> List[str]: + dst_rpcs = list(filter(lambda rpc: rpc.is_dst, self.rpcs)) + return [rpc.name for rpc in dst_rpcs] + + async def flush_calls(self): + for level, w in enumerate(self.topo_widths): + for _ in range(w): + await self.ctrl.topo_level_count.get() + logger.info(f"DFG level {level}. Flushing {w} function calls.") + self.stream.request( + handlers=list(range(self.n_model_workers)), handle_type="flush" + ) + + async def finish_traverse(self): + for _ in range(len(self.get_leaf_tasks())): + await self.ctrl.train_count.get() + + async def load_data(self): + src_rpc = self.src_rpc + src_rpc_model_name = src_rpc.model_name + buffer = self.buffer + ctrl = self.ctrl + + dp_idx = -1 + received_ids = set() + + while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs): + + all_data = [] + + dp_idx += 1 + dp_idx %= self.src_dp_size + + resps = await self.stream.call_async( + handlers=[f"__data{dp_idx}__"], + handle_type="fetch", + datas=[None], + verbose=False, + ) + x: DataBatchMeta | None = resps[0] + + if x is None: + continue + if x.meta_sample is None: + continue + + # Store the owner information of the data. + # RPC coroutines will use this information to + # determine the src and dst of data transfer. + for xx in x.meta_sample.unpack(): + if xx.ids[0] in received_ids: + raise ValueError( + f"Duplicate data id {xx.ids[0]}."
+ ) + received_ids.add(xx.ids[0]) + for k in xx.keys: + self.ctrl.data_owner[(xx.ids[0], k)] = ( + src_rpc_model_name, + dp_idx, + ) + all_data += x.meta_sample.unpack() + + filtered_data = [] + for xx in x.meta_sample.unpack(): + if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover: + ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0]) + ctrl.ids_to_clear.add(xx.ids[0]) + else: + filtered_data.append(xx) + all_data = filtered_data + + # We load data in a round-robin manner across different DP ranks, + # so we also need to shuffle the data to fuse different dataset splits. + random.shuffle(all_data) + + # Store into buffer! + buffer_indices = await buffer.put_batch(all_data) + assert len(buffer_indices) == len(all_data) + + blogger.info( + f"Master worker loaded {len(all_data)} pieces of data. " + f"Remaining number of data to ignore: {len(self.ctrl.hash_vals_to_ignore_in_recover)}. " + f"Current buffer size: {buffer.size}/{buffer.max_size}. " + ) + + def execute_step(self): + logger.info("Waiting for the finish of the execution graph.") + loop = asyncio.get_event_loop() + + tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [ + loop.create_task(self.flush_calls()), + loop.create_task(self.load_data()), + ] + + completion_future = loop.create_task(self.finish_traverse()) + + loop.run_until_complete(completion_future) + + for task in tasks: + loop.run_until_complete(task) + + logger.info("Execution finished!") diff --git a/realhf/system/v2/master_worker.py b/realhf/system/v2/master_worker.py new file mode 100644 index 0000000..81bbae2 --- /dev/null +++ b/realhf/system/v2/master_worker.py @@ -0,0 +1,501 @@ +# Copyright 2025 Ant Group Inc. +# Copyright 2024 Wei Fu & Zhiyu Mei +# Licensed under the Apache License, Version 2.0 (the "License"). + +import asyncio +import copy +import gc +import os +import time +from typing import Dict + +import colorama +import networkx as nx +import numpy as np +import uvloop +import wandb + +import realhf.api.core.dfg as dfg +import realhf.api.core.model_api as model_api +import realhf.api.core.system_api as config_pkg +import realhf.base.recover as recover +import realhf.system.request_reply_stream as request_reply_stream +import realhf.system.worker_base as worker_base +from realhf.api.core.config import ModelName +from realhf.api.core.model_api import ReaLModelConfig +from realhf.base import constants, logging, name_resolve, names, timeutil, topology +from realhf.system.buffer import AsyncIOSequenceBuffer +from realhf.system.v2.function_call import RPCCorountineControl +from realhf.system.v2.function_executor import FunctionExecutor + +logger = logging.getLogger("master worker", "system") +blogger = logging.getLogger("benchmark") + +uvloop.install() + + +class MasterWorker(worker_base.Worker): + global_exp_tik = time.perf_counter() + + def _configure(self, config: config_pkg.MasterWorker): + self.config = config + + self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = ( + config.model_topos + ) + + # Build execution graph and initialize concurrency utilities. + self.__model_rpcs = config.model_rpcs + + # Sort all MFCs in the topological order and + # calculate the width of each level. + # These numbers will determine when to flush MFC requests. 
+ self.__topo_widths = [] + for generation in nx.topological_generations(self.__model_rpcs[0]._G): + self.__topo_widths.append(len(generation)) + logger.info("Topological widths: " + str(self.__topo_widths)) + + self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs)) + self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs)) + + # Save and eval control. + self.__total_train_epochs = config.exp_ctrl.total_train_epochs + self.__save_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=config.exp_ctrl.save_freq_epochs, + freq_step=config.exp_ctrl.save_freq_steps, + freq_sec=config.exp_ctrl.save_freq_secs, + ) + if ( + config.exp_ctrl.ckpt_freq_epochs is None + and config.exp_ctrl.ckpt_freq_steps is None + and config.exp_ctrl.ckpt_freq_secs is None + ): + self.__ckpt_ctl = self.__save_ctl + else: + self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=config.exp_ctrl.ckpt_freq_epochs, + freq_step=config.exp_ctrl.ckpt_freq_steps, + freq_sec=config.exp_ctrl.ckpt_freq_secs, + ) + self.__eval_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=config.exp_ctrl.eval_freq_epochs, + freq_step=config.exp_ctrl.eval_freq_steps, + freq_sec=config.exp_ctrl.eval_freq_secs, + ) + + self.MODEL_SAVE_ROOT = os.path.join( + constants.MODEL_SAVE_ROOT, + config.worker_info.experiment_name, + config.worker_info.trial_name, + ) + os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True) + + self.__initialized = False + self.__recover_run, self.__recover_info = recover.load_recover_info() + if self.__recover_info is not None: + logger.info( + f"Loaded recover info: recover_start={self.__recover_info.recover_start}, " + f"last_step_info={self.__recover_info.last_step_info}." + ) + logger.info( + f"Number of used data in recover info: {len(self.__recover_info.hash_vals_to_ignore)}. " + f"The previous experiment probably ran for {len(self.__recover_info.hash_vals_to_ignore) // self.__rpc_srcs[0].n_seqs} steps in the epoch." + ) + + # Create corountine control objects for running the dataflow graph. + self.__rpc_ctrl = RPCCorountineControl( + train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)), + topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)), + # NOTE: We should accumulate the used data hashes in the same epoch + # to prevent loading data used before. + used_hash_vals_this_epoch=( + copy.deepcopy(self.__recover_info.hash_vals_to_ignore) + if self.__recover_run + else list() + ), + hash_vals_to_ignore_in_recover=( + copy.deepcopy(self.__recover_info.hash_vals_to_ignore) + if self.__recover_run + else list() + ), + ) + + if self.__recover_run: + self.__rpc_ctrl.step_info = copy.deepcopy(self.__recover_info.recover_start) + + self.__eval_ctl.load_state_dict(self.__recover_info.eval_ctl_info) + self.__save_ctl.load_state_dict(self.__recover_info.save_ctl_info) + self.__ckpt_ctl.load_state_dict(self.__recover_info.ckpt_ctl_info) + + logger.info( + f"Recovering from previous run. " + f"Epoch: {self.__rpc_ctrl.step_info.epoch + 1}, " + f"Epoch Step: {self.__rpc_ctrl.step_info.epoch_step + 1} " + f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}." + ) + + # for benchmark + self.e2e_time_history = [] + self.__benchmark_steps = config.exp_ctrl.benchmark_steps + + return config.worker_info + + def initialize_models(self): + # Initialize model backends. 
+ model_names = list(self.__model_topos.keys()) + self.logger.info(f"Initialize model backends with order: {model_names}.") + train_rpcs = list( + filter( + lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP, + self.__model_rpcs, + ) + ) + assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs) + if len(train_rpcs) > 0: + ft_spec = model_api.FinetuneSpec( + total_train_epochs=self.config.exp_ctrl.total_train_epochs, + dataset_size=self._dataset_size, + train_batch_size=train_rpcs[0].n_seqs, + ) + else: + ft_spec = model_api.FinetuneSpec( + total_train_epochs=self.config.exp_ctrl.total_train_epochs, + dataset_size=self._dataset_size, + train_batch_size=self.__src_rpc.n_seqs, + ) + _initialized_roles = [] + for model_name in model_names: + topo = self.config.model_topos[model_name] + # Build FinetuneSpec, which is required to initialize backends. + _handlers = [ + config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j) + for j in range(topo.world_size()) + ] + + init_payloads = [ + request_reply_stream.Payload( + handler=_h, + handle_name="initialize", + data=ft_spec, + ) + for _h in _handlers + ] + + # Send initialization requests then immediately flush them. + self.__stream.request( + payloads=init_payloads, + ) + self.__stream.request( + handlers=_handlers, + handle_type="flush", + no_syn=True, + ) + + _initialized_roles.append(model_name.role) + + self._ft_spec = ft_spec + logger.info("Initializations of models and backends complete.") + + def get_dataset_model_info(self): + src_rpc = self.__rpc_srcs[0] + src_rpc_topo = self.config.model_topos[src_rpc.model_name] + src_rpc_dp_size = src_rpc_topo.get_dim("data") + + # Request training specification from data workers. + all_data = sum( + self.__stream.call( + handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)], + datas=[None for i in range(src_rpc_dp_size)], + handle_type="spec", + ), + [], + ) + + # NOTE: For dynamic datasets, we still count epoch according to the initial number of data, + # such that the learning rate decay is not affected. + seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data] + self._dataset_size = len(all_data) + self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs + self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch + self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data] + + # Request model configs from model workers. + # Return None if the model is not a ReaLModel. + self.__model_configs: Dict[ModelName, None | ReaLModelConfig] = {} + for model_name, topo in self.config.model_topos.items(): + h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, 0) + self.__model_configs[model_name] = self.__stream.call( + handlers=[h], + datas=[None], + handle_type="model_config", + )[0] + + def __lazy_init(self): + # Set up streams. 
+ handler_routing = copy.deepcopy(self.config.msid2mwid) + src_rpc = self.__rpc_srcs[0] + src_rpc_topo = self.config.model_topos[src_rpc.model_name] + src_rpc_dp_size = src_rpc_topo.get_dim("data") + src_rpc_pp_size = src_rpc_topo.get_dim("pipe") + for i in range(src_rpc_dp_size): + rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0) + handler_routing[f"__data{i}__"] = self.config.msid2mwid[ + config_pkg.ModelShardID.from_parallelism_rank( + model_name=src_rpc.model_name, + topo=src_rpc_topo, + parallelism_rank=rank, + ) + ] + handler_routing.update({i: i for i in range(self.config.n_model_workers)}) + self.__stream = request_reply_stream.make_master_stream( + self.config.worker_info, + n_subscribers=self.config.n_model_workers, + handler_routing=handler_routing, + ) + self.__stream: request_reply_stream.NameResolvingRequestClient + + self.__src_rpc = src_rpc = [ + rpc for rpc in self.config.model_rpcs if rpc.is_src + ][0] + + self.get_dataset_model_info() + + self.initialize_models() + + self.__seqbuffer = AsyncIOSequenceBuffer( + self.__model_rpcs, + max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))), + ) + + # Create coroutines for model RPCs. + logger.info(f"Creating asyncio coroutines...") + self.func_executor = FunctionExecutor( + rpcs=self.__model_rpcs, + msid2mwid=self.config.msid2mwid, + stream=self.__stream, + buffer=self.__seqbuffer, + model_topos=self.__model_topos, + model_configs=self.__model_configs, + ctrl=self.__rpc_ctrl, + ) + logger.info(f"Coroutines created. The master worker is ready to run.") + + # wandb init, connect to remote wandb host + wandb.login() + wandb.init( + mode=self.wandb_config.mode, + entity=self.wandb_config.entity, + project=self.wandb_config.project or constants.experiment_name(), + name=self.wandb_config.name or constants.trial_name(), + job_type=self.wandb_config.job_type, + group=self.wandb_config.group, + notes=self.wandb_config.notes, + tags=self.wandb_config.tags, + config=self.wandb_config.config, + dir=os.path.join( + constants.LOG_ROOT, constants.experiment_name(), constants.trial_name() + ), + force=True, + resume="allow", + settings=wandb.Settings(start_method="fork"), + ) + + self.__initialized = True + self._train_start_time = time.perf_counter() + + self.__last_step_info = recover.StepInfo( + epoch=-1, + epoch_step=-1, + global_step=-1, + ) + + def _poll(self): + is_new_epoch = False + + if not self.__initialized: + self.__lazy_init() + + # Main execution steps. The graph runs under-the-hood in RPC & stream threads. + # Wait for the finish of the traversal of the execution graph. + execution_start = time.perf_counter() + if self.__rpc_ctrl.ids_to_clear: + # Send clear cache requests to model workers. + # Clearing the data used in the last step. + self._clear_gpu_cache() + + is_new_epoch = self._ft_spec.is_new_epoch(self.__rpc_ctrl.step_info) + is_epoch_last_step = self._ft_spec.is_epoch_last_step(self.__rpc_ctrl.step_info) + + # Check whether we should evaluate or save models. + self.__rpc_ctrl.should_eval = self.__eval_ctl.check( + epochs=int(is_epoch_last_step), steps=1 + ) + self.__rpc_ctrl.should_save = self.__save_ctl.check( + epochs=int(is_epoch_last_step), steps=1 + ) + self.__rpc_ctrl.should_ckpt = self.__ckpt_ctl.check( + epochs=int(is_epoch_last_step), steps=1 + ) + + # Traverse over the dataflow graph for once. + self.func_executor.execute_step() + + # Post-process. 
+ if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt: + self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info) + + self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear) + + if is_epoch_last_step: + self.__rpc_ctrl.used_hash_vals_this_epoch = ( + self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :] + ) + + if is_new_epoch: + self.__rpc_ctrl.step_info.epoch += 1 + self.__rpc_ctrl.step_info.epoch_step = 0 + + # Logging. + time_since_configure = time.perf_counter() - self._train_start_time + e2e_time = time.perf_counter() - execution_start + self.e2e_time_history.append(e2e_time) + + self._log_training_stats(e2e_time, time_since_configure) + + # Updata counters. + self.__rpc_ctrl.step_info.epoch_step += 1 + self.__rpc_ctrl.step_info.global_step += 1 + + if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt: + self.__recover_save() + + # Pause the worker if experiment or system-wise benchmark completes. + if ( + self.__benchmark_steps is not None + and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps + ) or ( + self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs + >= self.__total_train_epochs * self._dataset_size + ): + # We don't know whether it is the last step of the current epoch, + # so we exit at the first step of the next epoch. + if self.__benchmark_steps is not None: + logger.info( + f"Finished benchmark {self.__benchmark_steps}. " + f"Time consumption of this setup: {time_since_configure:.3f}" + ) + logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*") + return self.experiment_complete_exit() + + return worker_base.PollResult(sample_count=1, batch_count=1) + + def _log_training_stats(self, e2e_time: float, time_since_configure: float): + # calculate flops + ######################################### + if not all( + isinstance(v, ReaLModelConfig) for v in self.__model_configs.values() + ): + logger.warning( + f"Not all models are ReaLModels. Unable to calculate FLOP/s." + ) + flops = None + tflops_per_gpu = float("inf") + else: + flops = self.__rpc_ctrl.flops_counter.get_flops() + tflops = flops / (e2e_time * (10**12)) + tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12)) + self.__rpc_ctrl.flops_counter.clear() + ######################################### + + epoch = self.__rpc_ctrl.step_info.epoch + 1 + epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1 + global_step = self.__rpc_ctrl.step_info.global_step + 1 + s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} " + s += f"step {epoch_step}/{self._steps_per_epoch} " + s += f"(global step {global_step}) finishes. " + s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. " + s += f"#End to end# execution time: *{e2e_time:.3f}*s. " + s += f"Total time consumption: {time_since_configure:.3f}s. " + if len(self.e2e_time_history) > 2: + remaining_steps = self._steps_per_epoch - epoch_step + remaining_epochs = self.__total_train_epochs - epoch + avg_t = np.mean(self.e2e_time_history[2:]) + remain_t = avg_t * remaining_steps + remain_t += avg_t * self._steps_per_epoch * remaining_epochs + s += f"Estimated remaining time: {remain_t:.3f}s. " + if flops is not None: + s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}." 
+ logger.info(s) + logger.info( + f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s" + ) + + def _clear_gpu_cache(self): + self.__stream.request( + handlers=list(range(self.config.n_model_workers)), + handle_type="clear_data_cache", + datas=[ + self.__rpc_ctrl.ids_to_clear + for _ in list(range(self.config.n_model_workers)) + ], + no_syn=True, + ) + self.__rpc_ctrl.ids_to_clear.clear() + + def experiment_complete_exit(self): + logger.info( + colorama.Style.RESET_ALL + + colorama.Fore.YELLOW + + colorama.Style.BRIGHT + + "\033[1m" + + "Experiment Completes! Yeah!!!!!!!!" + + colorama.Style.RESET_ALL + ) + + # Send requests to pause model workers. + # Model workers will not respond to this message. + self.__stream.request( + handlers=list(range(self.config.n_model_workers)), + handle_type="reset", + datas=[None for _ in list(range(self.config.n_model_workers))], + ) + self.__stream.close() + constants.reset_run() + # Reset names used for distributed training. + # The next round of training will set up a new distributed environment. + name_resolve.clear_subtree( + names.distributed_root(constants.experiment_name(), constants.trial_name()) + ) + name_resolve.clear_subtree( + names.request_reply_stream_root( + constants.experiment_name(), constants.trial_name() + ) + ) + + wandb.finish() + gc.collect() + self.__initialized = False + self.pause() + return worker_base.PollResult(0, 0) + + def __recover_save(self): + # save step info for recover + if os.getenv("REAL_SAVE_RECOVER_STATES", "0") != "1": + return + # save step info for recover + this_step_info = copy.deepcopy(self.__rpc_ctrl.step_info) + recover_info = recover.RecoverInfo( + recover_start=this_step_info, + last_step_info=self.__last_step_info, + save_ctl_info=self.__save_ctl.state_dict(), + ckpt_ctl_info=self.__ckpt_ctl.state_dict(), + eval_ctl_info=self.__eval_ctl.state_dict(), + hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch, + ) + + recover.dump_recover_info(recover_info) + logger.info("Dumped recover info to file.") + logger.info(f"Will recover from: {recover_info.recover_start}") + logger.info( + f"Number of data used in this epoch: {len(recover_info.hash_vals_to_ignore)}" + ) diff --git a/realhf/system/worker_base.py b/realhf/system/worker_base.py index bfe256b..c63e876 100644 --- a/realhf/system/worker_base.py +++ b/realhf/system/worker_base.py @@ -13,11 +13,7 @@ import time from typing import Any, Dict, List, Optional, Tuple import realhf.api.core.system_api as system_api -from realhf.base import ( - logging, - name_resolve, - names, -) +from realhf.base import logging, name_resolve, names from realhf.base.gpu_utils import set_cuda_device logger = logging.getLogger("worker") diff --git a/tests/experiments/test_math_ppo.py b/tests/experiments/test_math_ppo.py index f9e6a72..185526c 100644 --- a/tests/experiments/test_math_ppo.py +++ b/tests/experiments/test_math_ppo.py @@ -47,6 +47,8 @@ def math_dataset(request, save_path): return dataset +# NOTE: we can't test v1 and v2 at the same time. +@pytest.mark.parametrize("use_v2_worker", [True]) @pytest.mark.parametrize( "dp,pp,mp", [ @@ -66,6 +68,7 @@ def test_ppo_symm( dp, pp, mp, + use_v2_worker, ): # Setup experiment env. Should be done before any other operations. 
log_root = tmp_path_factory.mktemp("ppo") @@ -117,8 +120,10 @@ def test_ppo_symm( ), ), ) + exp_cfg.actor.vllm.hybrid_train = True + exp_cfg.actor.vllm.enforce_eager = True - run_test_exp(exp_cfg) + run_test_exp(exp_cfg, use_v2_worker=use_v2_worker) # The global resharding strategy, where all MFCs @@ -238,6 +243,8 @@ def test_ppo_global_reshard( ), ), ) + exp_cfg.actor.vllm.hybrid_train = True + exp_cfg.actor.vllm.enforce_eager = True run_test_exp(exp_cfg) @@ -351,6 +358,8 @@ def test_ppo_param_realloc_sub_device_mesh( ), ), ) + exp_cfg.actor.vllm.hybrid_train = True + exp_cfg.actor.vllm.enforce_eager = True run_test_exp(exp_cfg) @@ -470,6 +479,8 @@ def test_ppo_save( ), ), ) + exp_cfg.actor.vllm.hybrid_train = True + exp_cfg.actor.vllm.enforce_eager = True run_test_exp(exp_cfg) diff --git a/tests/experiments/test_sft.py b/tests/experiments/test_sft.py index 8407673..04c1e4d 100644 --- a/tests/experiments/test_sft.py +++ b/tests/experiments/test_sft.py @@ -38,6 +38,8 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp) +# NOTE: we can't test v1 and v2 at the same time. +@pytest.mark.parametrize("use_v2_worker", [True]) @pytest.mark.parametrize( "dp,pp,tp", [ @@ -47,7 +49,9 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp (1, 1, 2), ], ) -def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp): +def test_sft( + tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker +): # Setup experiment env. Should be done before any other operations. log_root = tmp_path_factory.mktemp("sft") @@ -79,4 +83,4 @@ def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp): ), ) - run_test_exp(exp_cfg) + run_test_exp(exp_cfg, use_v2_worker=use_v2_worker) diff --git a/tests/experiments/utils.py b/tests/experiments/utils.py index b1fdd1b..c773a3e 100644 --- a/tests/experiments/utils.py +++ b/tests/experiments/utils.py @@ -44,13 +44,22 @@ def run_model_worker(cfg, mw, barrier): initd = True -def run_test_exp(exp_cfg: Experiment, expr_name=None, trial_name=None): +def run_test_exp( + exp_cfg: Experiment, + expr_name=None, + trial_name=None, + use_v2_worker: bool = False, +): constants.set_force_cpu(True) # Register all datasets and models import realhf.impl.dataset # isort: skip import realhf.impl.model # isort: skip from realhf.api.core import system_api - from realhf.system.master_worker import MasterWorker + + if not use_v2_worker: + from realhf.system.master_worker import MasterWorker + else: + from realhf.system.v2.master_worker import MasterWorker system_api.ALL_EXPERIMENT_CLASSES = {} register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)