# AReaL/realhf/system/master_worker.py
#
# 1358 lines
# 55 KiB
# Python

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import collections
import copy
import dataclasses
import gc
import os
import random
import time
import uuid
from collections import defaultdict
from typing import Dict, List, Set, Tuple
import colorama
import networkx as nx
import numpy as np
import wandb
from tensorboardX import SummaryWriter
import realhf.api.core.config as config_api
import realhf.api.core.data_api as data_api
import realhf.api.core.dfg as dfg
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import (
constants,
logging,
name_resolve,
names,
seeding,
timeutil,
topology,
)
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
# Logger used by the master worker for system-level messages
# (request/reply progress, initialization, recovery).
logger = logging.getLogger("master worker", "system")
# Logger used for benchmark-oriented messages, e.g. the per-RPC
# data-parallel split statistics emitted by the request coroutine.
blogger = logging.getLogger("benchmark")
def _attach_param_realloc_hooks(
    payload: request_reply_stream.Payload,
    msid2mwid: Dict[config_pkg.ModelShardID, int],
    from_model_name: ModelName,
    to_model_name: ModelName,
    from_topo: topology.PipeModelDataParallelTopology,
    to_topo: topology.PipeModelDataParallelTopology,
    to_model_config: ReaLModelConfig,
    pre: bool,
) -> request_reply_stream.Payload:
    """Append one ``param_realloc`` hook to ``payload`` and return it.

    The payload is modified in place; the same object is returned for
    convenience.

    Args:
        payload: The request whose hook lists are extended.
        msid2mwid: Mapping from model shard to model worker index. It is not
            needed to build the hook and is kept only so that existing
            callers keep working.
        from_model_name: Model that currently holds the parameters.
        to_model_name: Model that receives the parameters.
        from_topo: Parallelism topology of ``from_model_name``.
        to_topo: Parallelism topology of ``to_model_name``.
        to_model_config: Model config of the receiving model.
        pre: If True, the hook is appended to ``pre_hooks``/``pre_hook_data``;
            otherwise to ``post_hooks``/``post_hook_data``.

    Returns:
        The same ``payload`` object, with the hook attached.
    """
    # The hook always uses eta=1.0 here.
    ps_data = {
        "from_model_name": from_model_name,
        "to_model_name": to_model_name,
        "from_topo": from_topo,
        "to_topo": to_topo,
        "to_model_config": to_model_config,
        "eta": 1.0,
    }
    if pre:
        payload.pre_hooks.append("param_realloc")
        payload.pre_hook_data.append(ps_data)
    else:
        payload.post_hooks.append("param_realloc")
        payload.post_hook_data.append(ps_data)
    return payload
@dataclasses.dataclass
class RPCCorountineControl:
    """State shared between the master worker and the per-RPC coroutines.

    One instance is created by the master worker and handed to every
    request/reply coroutine of the dataflow graph.
    """

    ## Shared resources ##
    # Request and reply coroutines loop until this event is set.
    stop: asyncio.Event
    # for counting the number of finished training steps
    # one training step corresponds to traversal of the whole DFG
    # (the reply coroutine of a destination RPC puts one item per finished call)
    train_count: asyncio.Queue
    # For flushing requests
    # (the request coroutine puts one item after each issued request)
    topo_level_count: asyncio.Queue
    ## Per-coroutine resources ##
    # Used for counting the number of concurrent calls.
    # Keyed by RPC name; acquired before a request is sent and released
    # once the corresponding reply has been processed.
    rpc_concurrency: Dict[str, asyncio.Semaphore]
    # Keyed by RPC name; number of replies processed so far for that RPC.
    rpc_traversal: Dict[str, int]
    # for synchronizing req ids between req and reply coroutines
    # (keyed by RPC name)
    request_queues: Dict[str, asyncio.Queue]
    # for training data management and data cleaning after each step
    ids_to_clear: Set[int] = dataclasses.field(default_factory=set)
    # Accumulates the workload of each issued RPC.
    flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
    # Flags deciding which post hooks ("save" / "evaluate") are attached
    # to the requests issued in the current step.
    should_save: bool = False
    should_eval: bool = False
    should_ckpt: bool = False
    # Current epoch / epoch step / global step counters.
    step_info: recover.StepInfo = dataclasses.field(default_factory=recover.StepInfo)
    # recover information
    used_hash_vals_this_epoch: List[int] = dataclasses.field(default_factory=list)
    hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)
def _attach_payloads_with_hooks(
    rpc: dfg.MFCDef,
    payloads: Dict[config_api.ModelShardID, request_reply_stream.Payload],
    mwids: List[int],
    msid2mwid: Dict[config_pkg.ModelShardID, int],
    model_configs: Dict[str, None | ReaLModelConfig],
    model_topos: Dict[str, topology.PipeModelDataParallelTopology],
    main_handlers: List[config_pkg.ModelShardID],
    hook_type: str,
) -> Tuple[Dict[config_api.ModelShardID, request_reply_stream.Payload], List[int]]:
    """Attach the pre- or post-hooks declared on ``rpc`` to its payloads.

    Both ``payloads`` and ``mwids`` are updated in place and also returned.

    Args:
        rpc: The MFC whose ``_pre_hooks`` / ``_post_hooks`` are processed.
        payloads: Payloads to be sent, keyed by the shard that handles them.
        mwids: Indices of model workers that already receive a payload.
        msid2mwid: Mapping from model shard to model worker index.
        model_configs: Model config of every model, keyed by model name.
        model_topos: Parallelism topology of every model, keyed by model name.
        main_handlers: Shards that execute the RPC itself.
        hook_type: Either ``"pre"`` or ``"post"``.

    Returns:
        The (possibly extended) ``payloads`` and ``mwids``.

    Raises:
        NotImplementedError: If a hook is neither a parameter reallocation
            hook nor an offload hook.
    """
    assert hook_type in ["pre", "post"], hook_type
    hooks_attr = f"{hook_type}_hooks"
    hook_data_attr = f"{hook_type}_hook_data"
    main_mwids = {msid2mwid[shard] for shard in main_handlers}

    def _append_hook(payload, hook_name, hook_data):
        getattr(payload, hooks_attr).append(hook_name)
        getattr(payload, hook_data_attr).append(hook_data)

    for hook in getattr(rpc, f"_{hook_type}_hooks"):
        if not isinstance(hook, dfg.ParamReallocHook):
            if not isinstance(hook, dfg.OffloadHook):
                raise NotImplementedError(f"Unknown hook type: {hook}")
            for shard in main_handlers:
                _append_hook(payloads[shard], "offload", dict(model_name=shard.model_name))
            continue

        # Exactly one of source/target is given; the RPC's own model
        # plays the other role.
        assert (hook.source is None) != (hook.target is None), hook
        rpc_is_source = hook.source is None
        if rpc_is_source:
            src_model_name, dst_model_name = rpc.model_name, hook.target
        else:
            src_model_name, dst_model_name = hook.source, rpc.model_name
        src_topo = model_topos[src_model_name]
        dst_topo = model_topos[dst_model_name]
        dst_config = model_configs[dst_model_name]
        # "other" is the peer model, i.e. the one that is not the RPC's model.
        if rpc_is_source:
            other_model_name, other_topo = dst_model_name, dst_topo
        else:
            other_model_name, other_topo = src_model_name, src_topo

        ps_data = {
            "from_model_name": src_model_name,
            "to_model_name": dst_model_name,
            "from_topo": src_topo,
            "to_topo": dst_topo,
            "to_model_config": dst_config,
            "eta": hook.eta,
        }

        for shard in main_handlers:
            _append_hook(payloads[shard], "param_realloc", ps_data)

        for rank in range(other_topo.world_size()):
            peer = config_api.ModelShardID.from_parallelism_rank(
                other_model_name, other_topo, rank
            )
            peer_mwid = msid2mwid[peer]
            if peer_mwid not in mwids:
                # This worker receives nothing yet: send an empty request
                # that carries only the reallocation hook.
                payloads[peer] = request_reply_stream.Payload(
                    handler=peer,
                    handle_name="empty",
                )
                setattr(payloads[peer], hooks_attr, ["param_realloc"])
                setattr(payloads[peer], hook_data_attr, [ps_data])
                mwids.append(peer_mwid)
            elif peer_mwid not in main_mwids:
                # The worker already has a non-main payload: piggyback the
                # hook on the first payload routed to that worker.
                existing = next(
                    key for key in payloads if msid2mwid[key] == peer_mwid
                )
                _append_hook(payloads[existing], "param_realloc", ps_data)
    return payloads, mwids
def _request_model_function_call(
    rpc: dfg.MFCDef,
    stream: request_reply_stream.NameResolvingRequestClient,
    msid2mwid: Dict[config_pkg.ModelShardID, int],
    model_topos: Dict[str, topology.PipeModelDataParallelTopology],
    model_configs: Dict[str, None | ReaLModelConfig],
    producer_names: Dict[str, str],
    producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]],
    producer_mappings: Dict[Tuple[ModelName, str], Dict[int, List[int]]],
    target_mapping: Dict[int, List[int]],
    meta_sample: data_api.SequenceSample,
    handlers: List[config_pkg.ModelShardID],
    ctrl: RPCCorountineControl,
    model_save_root: str,
) -> Tuple[List[uuid.UUID], List[uuid.UUID]]:
    """Build and send all requests needed for one call of ``rpc``.

    Every shard in ``handlers`` receives the main request with a
    ``data_transfer`` pre-hook. Workers that only host producers of the input
    data, or only take part in a parameter reallocation, receive an ``empty``
    request that carries just the hooks.

    Args:
        rpc: The MFC to be executed.
        stream: Client used to send the requests.
        msid2mwid: Mapping from model shard to model worker index.
        model_topos: Parallelism topology of every model.
        model_configs: Model config of every model.
        producer_names: Data key -> name of the model that produced it.
        producer_name2producer_handlers: Producer model name -> its shards.
        producer_mappings: (producer name, data key) -> {producer DP rank ->
            positions of the sequences it owns}.
        target_mapping: Target DP rank -> positions of the sequences it gets.
        meta_sample: The batch that is dispatched.
        handlers: Shards of the model that executes ``rpc``.
        ctrl: Shared control state; provides the save/eval/ckpt flags and
            the step counters used in the save directory name.
        model_save_root: Root directory for saved models.

    Returns:
        A pair ``(main_req_ids, other_req_ids)``: the first
        ``len(handlers)`` ids returned by ``stream.request`` belong to the
        main requests, the remaining ones to the side-effect requests.
    """
    dt_data = {
        "keys": rpc.input_keys,
        "target": rpc.model_name,
        "producer_names": producer_names,
        "producer_mappings": producer_mappings,
        "target_mapping": target_mapping,
        "handle_name": rpc.interface_type.value,
        "rpc_name": rpc.name,
        "meta_sample": meta_sample,
    }

    payloads = {
        handler: request_reply_stream.Payload(
            handler=handler,
            handle_name=rpc.interface_type.value,
            pre_hooks=["data_transfer"],
            pre_hook_data=[dt_data],
            data=rpc.name,
        )
        for handler in handlers
    }
    if ctrl.should_eval:
        for p in payloads.values():
            p.post_hooks.append("evaluate")
            p.post_hook_data.append(dict(model_name=rpc.model_name))
    if (
        ctrl.should_save or ctrl.should_ckpt
    ) and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
        # The directory is identical for all shards, so compute it once.
        save_dir = os.path.join(
            model_save_root,
            rpc.model_name.role,
            f"epoch{ctrl.step_info.epoch + 1}"
            f"epochstep{ctrl.step_info.epoch_step + 1}"
            f"globalstep{ctrl.step_info.global_step + 1}",
        )
        for p in payloads.values():
            p.post_hooks.append("save")
            # Each payload gets its own dict so hook data is never shared.
            p.post_hook_data.append(
                dict(
                    model_name=rpc.model_name,
                    save_dir=save_dir,
                    recover_only=not ctrl.should_save,
                )
            )

    mwids = [msid2mwid[h] for h in handlers]
    assert len(mwids) == len(set(mwids))

    # Producer workers that do not run the RPC itself still need to take
    # part in the data transfer.
    for producer_name in producer_names.values():
        for h in producer_name2producer_handlers[producer_name]:
            if msid2mwid[h] not in mwids:
                payloads[h] = request_reply_stream.Payload(
                    handler=h,
                    handle_name="empty",
                    pre_hooks=["data_transfer"],
                    pre_hook_data=[dt_data],
                )
                mwids.append(msid2mwid[h])

    for hook_type in ("pre", "post"):
        payloads, mwids = _attach_payloads_with_hooks(
            rpc,
            payloads,
            mwids,
            msid2mwid=msid2mwid,
            model_configs=model_configs,
            model_topos=model_topos,
            main_handlers=handlers,
            hook_type=hook_type,
        )

    # Main payloads go first so that their request ids form a prefix.
    main_handler_set = set(handlers)
    main_payloads = [p for h, p in payloads.items() if h in main_handler_set]
    other_payloads = [p for h, p in payloads.items() if h not in main_handler_set]
    all_req_ids = stream.request(
        payloads=main_payloads + other_payloads,
    )
    n_main = len(main_payloads)
    return all_req_ids[:n_main], all_req_ids[n_main:]
async def model_rpc_request_func(
    rpc: dfg.MFCDef,
    msid2mwid: Dict[config_pkg.ModelShardID, int],
    src_rpc_model_name: ModelName,
    stream: request_reply_stream.NameResolvingRequestClient,
    buffer: AsyncIOSequenceBuffer,
    data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]],
    model_topos: Dict[str, topology.PipeModelDataParallelTopology],
    model_configs: Dict[str, None | ReaLModelConfig],
    model_save_root: str,
    ctrl: RPCCorountineControl,
):
    """Coroutine that keeps sending requests of one MFC to model workers.

    Each iteration takes a batch for ``rpc`` from ``buffer``, splits it
    across data-parallel ranks, records in ``data_owner`` which rank will own
    the produced keys, sends the requests, and hands the request ids to the
    reply coroutine through ``ctrl.request_queues[rpc.name]``. The loop ends
    once ``ctrl.stop`` is set.
    """
    topo = model_topos[rpc.model_name]
    logger.info(
        f"Requesting Model RPC, interface_type=#{rpc.interface_type}# "
        f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
    )

    handlers = [
        config_pkg.ModelShardID.from_parallelism_rank(rpc.model_name, topo, rank)
        for rank in range(topo.world_size())
    ]

    # data key -> name of the model producing it; keys without an explicit
    # producer come from the source RPC's model.
    producer_names = {
        key: (
            rpc.data_producers[key]
            if key in rpc.data_producers
            else src_rpc_model_name
        )
        for key in rpc.input_keys
    }

    # producer model name -> all shards of that model
    producer_name2producer_handlers = {}
    for producer_name in producer_names.values():
        if producer_name in producer_name2producer_handlers:
            continue
        producer_topo = model_topos[producer_name]
        producer_name2producer_handlers[producer_name] = [
            config_pkg.ModelShardID.from_parallelism_rank(
                producer_name, producer_topo, rank
            )
            for rank in range(producer_topo.world_size())
        ]

    request_queue = ctrl.request_queues[rpc.name]
    rpc_concurrency = ctrl.rpc_concurrency[rpc.name]

    n_consumed_seqs = 0
    while not ctrl.stop.is_set():
        await rpc_concurrency.acquire()

        # Do not run ahead of successor RPCs: wait while this RPC has
        # already consumed what any successor needs for its next call.
        while any(
            n_consumed_seqs >= (ctrl.rpc_traversal[succ.name] + 1) * succ.n_seqs
            for succ in rpc.all_successors()
        ):
            await asyncio.sleep(0.1)

        buf_indices, sample = await buffer.get_batch_for_rpc(rpc)

        ctrl.flops_counter.add_rpc(rpc, sample, model_configs[rpc.model_name])

        n_consumed_seqs += sample.bs

        # Split the batch across data-parallel ranks.
        dp_size = topo.get_dim("data")
        mb_spec = data_api.MicroBatchSpec(n_mbs=dp_size)
        if rpc.is_generate():
            # Generation workload depends on the batch size rather than on
            # the generated length, so every sequence counts as length 1.
            samples, forward_indices, _ = sample.split_with_lengths(
                mb_spec=mb_spec,
                lens=[1] * sample.bs,
            )
        else:
            samples, forward_indices, _ = sample.split(mb_spec)

        seqs_per_rank = [s.bs for s in samples]
        tokens_per_rank = [
            sum([sum(lens) for lens in s.seqlens[s._get_split_key()]])
            for s in samples
        ]
        blogger.info(
            f"DP split (DP size {dp_size}) for RPC {rpc.name}: "
            f"#seqs: {seqs_per_rank}, "
            f"#tokens: {tokens_per_rank}"
        )
        sample = data_api.SequenceSample.gather(samples)
        buf_indices = [buf_indices[i] for i in forward_indices]

        partitions = data_api.SequenceSplitSpec(
            sizes=[s.bs for s in samples]
        ).partitions
        target_mapping = {
            dp_rank: list(range(bounds[0], bounds[1]))
            for dp_rank, bounds in enumerate(partitions)
        }

        # Record which DP rank of this model will own each produced key, so
        # that downstream RPCs know where to fetch the data from.
        for dp_rank, (start, end) in enumerate(partitions):
            for pos in range(start, end):
                for out_key in rpc.output_keys:
                    data_owner[sample.ids[pos], out_key] = (rpc.model_name, dp_rank)

        # Look up the owners of this RPC's inputs; they are the sources of
        # the data transfer.
        producer_mappings = {}
        for in_key in rpc.input_keys:
            owners = [data_owner[(sample_id, in_key)] for sample_id in sample.ids]
            owner_names = [owner_name for owner_name, _ in owners]
            assert len(set(owner_names)) == 1
            positions_by_rank = defaultdict(list)
            for pos, (_, owner_rank) in enumerate(owners):
                positions_by_rank[owner_rank].append(pos)
            producer_mappings[owner_names[0], in_key] = {
                owner_rank: sorted(positions)
                for owner_rank, positions in positions_by_rank.items()
            }

        # Send the partitioned data to model workers.
        req_ids, other_req_ids = _request_model_function_call(
            rpc=rpc,
            stream=stream,
            msid2mwid=msid2mwid,
            model_topos=model_topos,
            model_configs=model_configs,
            producer_names=producer_names,
            producer_name2producer_handlers=producer_name2producer_handlers,
            producer_mappings=producer_mappings,
            target_mapping=target_mapping,
            meta_sample=sample,
            handlers=handlers,
            ctrl=ctrl,
            model_save_root=model_save_root,
        )
        request_record = (
            buf_indices,
            sample.ids,
            req_ids,
            other_req_ids,
            time.perf_counter(),
        )
        await request_queue.put(request_record)
        await ctrl.topo_level_count.put(1)
        logger.info(f"Model rpc {rpc.name} requested.")
async def model_rpc_reply_func(
    rpc: dfg.MFCDef,
    stream: request_reply_stream.NameResolvingRequestClient,
    buffer: AsyncIOSequenceBuffer,
    model_topos: Dict[str, topology.PipeModelDataParallelTopology],
    ctrl: RPCCorountineControl,
    summary_writer: SummaryWriter,
):
    """Coroutine that collects the replies of one MFC.

    For every request record put on ``ctrl.request_queues[rpc.name]`` it
    waits for the main replies, keeps those of the data-parallel heads,
    gathers them, optionally logs them, releases the RPC's semaphore, and
    then either signals a finished destination RPC or amends the results
    into ``buffer``. Side-effect requests are awaited last. The loop ends
    once ``ctrl.stop`` is set.
    """
    topo = model_topos[rpc.model_name]
    n_dp = topo.get_dim("data")
    # One reply per DP rank is kept: the shard on the last pipeline stage
    # with model-parallel rank 0.
    dp_head_indices = [
        topo.get_rank(data=dp_rank, pipe=topo.get_dim("pipe") - 1, model=0)
        for dp_rank in range(n_dp)
    ]

    request_queue = ctrl.request_queues[rpc.name]
    rpc_concurrency = ctrl.rpc_concurrency[rpc.name]

    while not ctrl.stop.is_set():
        # Wait until the request coroutine has issued a request.
        buf_indices, ids, req_ids, other_req_ids, tik = await request_queue.get()

        # Wait for all main requests to finish.
        all_responses = await stream.gather_async(request_ids=req_ids)

        # Keep only the DP heads; the other responses are duplicated or None.
        responses: List[request_reply_stream.Payload] = [
            all_responses[idx] for idx in dp_head_indices
        ]

        # A SequenceSample is data produced by the model function call and
        # has to be amended into the buffer. Anything else is treated as
        # train statistics, which are reduced and logged.
        returns_data = isinstance(responses[-1], data_api.SequenceSample)
        if returns_data:
            res = data_api.SequenceSample.gather(responses)
        else:
            res = data_api.gather_stat(responses)

        if rpc.log_return_value:
            logger.info(f"RPC name {rpc.name} returns {res}")

            if isinstance(res, Dict):
                global_step = ctrl.step_info.global_step
                wandb.log(res, step=global_step)

                if summary_writer is not None:
                    for key, val in res.items():
                        summary_writer.add_scalar(f"{key}", val, global_step)

        logger.info(
            f"Model rpc {rpc.name} finished. Run time {time.perf_counter() - tik:.4f}s."
        )

        # Let the request coroutine issue the next request.
        rpc_concurrency.release()
        ctrl.rpc_traversal[rpc.name] += 1

        if not rpc.is_dst:
            # Intermediate node: write the produced data back to the buffer.
            logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
            await buffer.amend_batch(buf_indices, res.unpack())
        else:
            # Final node of the dataflow graph: remember which data can be
            # cleared and bump the train counter.
            ctrl.ids_to_clear = ctrl.ids_to_clear.union(ids)
            await ctrl.train_count.put(1)

        # Side-effect or empty requests are required for data transfer and
        # parameter synchronization. They are awaited after the main
        # requests so that the logged MFC time is correct.
        await stream.gather_async(other_req_ids)
class MasterWorker(worker_base.Worker):
    """Worker that drives the dataflow graph of model function calls.

    It creates one request and one reply coroutine per MFC and advances
    them from its poll loop.
    """

    # Runs when the class body is evaluated, so the global save root exists
    # before any instance is configured.
    os.makedirs(constants.MODEL_SAVE_ROOT, exist_ok=True)
    # Class-level timestamp taken when the class body is evaluated.
    global_exp_tik = time.perf_counter()
def _configure(self, config: config_pkg.MasterWorker):
    """Store the configuration and prepare all state that needs no I/O
    with model workers.

    This seeds the RNG, records the dataflow graph layout, creates the
    save/checkpoint/eval frequency controllers, loads recover information,
    and builds the shared coroutine control object.

    Returns:
        ``config.worker_info``.
    """
    self.config = config

    seeding.set_random_seed(self.config.base_seed + self.config.n_model_workers)

    self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
        config.model_topos
    )

    # Build the execution graph and initialize concurrency utilities.
    self.__model_rpcs = config.model_rpcs

    # Width of every level of the topologically sorted DFG. These numbers
    # determine when MFC requests are flushed.
    self.__topo_widths = [
        len(level)
        for level in nx.topological_generations(self.__model_rpcs[0]._G)
    ]
    logger.info("Topological widths: " + str(self.__topo_widths))

    # model worker index -> shards hosted by that worker
    self.__mwid2msids = defaultdict(list)
    for shard_id, worker_idx in self.config.msid2mwid.items():
        self.__mwid2msids[worker_idx].append(shard_id)

    self.__rpc_srcs = [rpc for rpc in self.__model_rpcs if rpc.is_src]
    self.__rpc_dsts = [rpc for rpc in self.__model_rpcs if rpc.is_dst]
    self.__n_rpc_srcs = len(self.__rpc_srcs)
    self.__n_rpc_dsts = len(self.__rpc_dsts)

    # Save and eval control.
    exp_ctrl = config.exp_ctrl
    self.__total_train_epochs = exp_ctrl.total_train_epochs
    self.__save_ctl = timeutil.EpochStepTimeFreqCtl(
        freq_epoch=exp_ctrl.save_freq_epochs,
        freq_step=exp_ctrl.save_freq_steps,
        freq_sec=exp_ctrl.save_freq_secs,
    )
    ckpt_freqs = (
        exp_ctrl.ckpt_freq_epochs,
        exp_ctrl.ckpt_freq_steps,
        exp_ctrl.ckpt_freq_secs,
    )
    if all(freq is None for freq in ckpt_freqs):
        # Without a dedicated checkpoint frequency the save controller is
        # reused (same object).
        self.__ckpt_ctl = self.__save_ctl
    else:
        self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=ckpt_freqs[0],
            freq_step=ckpt_freqs[1],
            freq_sec=ckpt_freqs[2],
        )
    self.__eval_ctl = timeutil.EpochStepTimeFreqCtl(
        freq_epoch=exp_ctrl.eval_freq_epochs,
        freq_step=exp_ctrl.eval_freq_steps,
        freq_sec=exp_ctrl.eval_freq_secs,
    )

    worker_info = config.worker_info
    self.MODEL_SAVE_ROOT = os.path.join(
        constants.MODEL_SAVE_ROOT,
        worker_info.experiment_name,
        worker_info.trial_name,
    )
    os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True)

    self.__initialized = False
    self.__recover_run, self.__recover_info = recover.load_recover_info()
    if self.__recover_info is not None:
        logger.info(
            f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
            f"last_step_info={self.__recover_info.last_step_info}."
        )
        n_used = len(self.__recover_info.hash_vals_to_ignore)
        n_steps = n_used // self.__rpc_srcs[0].n_seqs
        logger.info(
            f"Number of used data in recover info: {n_used}. "
            f"The previous experiment probably ran for {n_steps} steps in the epoch."
        )

    # NOTE: The used data hashes are accumulated within an epoch to
    # prevent loading data that was already used.
    if self.__recover_run:
        used_hashes = copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
        ignored_hashes = copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
    else:
        used_hashes, ignored_hashes = [], []

    # Create coroutine control objects for running the dataflow graph.
    rpc_names = [rpc.name for rpc in self.__model_rpcs]
    self.__rpc_ctrl = RPCCorountineControl(
        stop=asyncio.Event(),
        train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
        topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
        rpc_traversal={name: 0 for name in rpc_names},
        request_queues={name: asyncio.Queue(1) for name in rpc_names},
        rpc_concurrency={name: asyncio.Semaphore(1) for name in rpc_names},
        used_hash_vals_this_epoch=used_hashes,
        hash_vals_to_ignore_in_recover=ignored_hashes,
    )

    if self.__recover_run:
        self.__rpc_ctrl.step_info = copy.deepcopy(self.__recover_info.recover_start)

        self.__eval_ctl.load_state_dict(self.__recover_info.eval_ctl_info)
        self.__save_ctl.load_state_dict(self.__recover_info.save_ctl_info)
        self.__ckpt_ctl.load_state_dict(self.__recover_info.ckpt_ctl_info)

        step_info = self.__rpc_ctrl.step_info
        logger.info(
            f"Recovering from previous run. "
            f"Epoch: {step_info.epoch + 1}, "
            f"Epoch Step: {step_info.epoch_step + 1} "
            f"Global Step: {step_info.global_step + 1}."
        )

    # for benchmark
    self.e2e_time_history = []
    self.__benchmark_steps = exp_ctrl.benchmark_steps

    return worker_info
def __lazy_init(self):
    """Finish the initialization that needs communication with model workers.

    Steps, in order:

    1. Build the master request stream. ``__data{i}__`` handlers are routed
       to the worker hosting the last-pipeline-stage, model-parallel-rank-0
       shard of DP rank ``i`` of the source RPC's model.
    2. Ask those handlers for the dataset specification and derive dataset
       size, steps per epoch and average tokens per batch.
    3. Collect handler lists and request every model's config.
    4. Send ``initialize`` requests model by model. Within a role that has
       several replicas, the trainable replica (if any) goes first; later
       replicas that receive parameters get pre/post ``param_realloc``
       hooks around their initialization.
    5. Create the sequence buffer, the optional tensorboard writer, one
       request and one reply coroutine per MFC, and the wandb run.
    """
    # Set up streams.
    handler_routing = copy.deepcopy(self.config.msid2mwid)
    src_rpc = self.__rpc_srcs[0]
    src_rpc_model_name = src_rpc.model_name
    src_rpc_topo = self.config.model_topos[src_rpc_model_name]
    src_rpc_dp_size = src_rpc_topo.get_dim("data")
    src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
    for i in range(src_rpc_dp_size):
        rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
        handler_routing[f"__data{i}__"] = self.config.msid2mwid[
            config_pkg.ModelShardID.from_parallelism_rank(
                model_name=src_rpc_model_name,
                topo=src_rpc_topo,
                parallelism_rank=rank,
            )
        ]
    self.__stream: request_reply_stream.NameResolvingRequestClient = (
        request_reply_stream.make_master_stream(
            self.config.worker_info,
            n_subscribers=self.config.n_model_workers,
            handler_routing=handler_routing,
        )
    )

    # `self.__rpc_srcs` was filtered from the same `config.model_rpcs` in
    # `_configure`, so the source RPC and its DP size can be reused here.
    self.__src_rpc = src_rpc
    self.__src_rpc_dp_size = src_rpc_dp_size

    # Request training specification from data workers.
    all_data = sum(
        self.__stream.call(
            handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
            datas=[None for i in range(src_rpc_dp_size)],
            handle_type="spec",
        ),
        [],
    )

    # NOTE: For dynamic datasets, we still count epoch according to the
    # initial number of data, such that the learning rate decay is not
    # affected.
    seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data]
    self._dataset_size = len(all_data)
    self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
    self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch
    self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data]

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    # Build some data required for subsequent model function calls.
    self.__all_model_handlers: List[config_pkg.ModelShardID] = []
    self.__all_mw_handlers: List[config_pkg.ModelShardID] = []
    _covered_mws = set()
    self.__dp0_model_handlers: List[config_pkg.ModelShardID] = []
    self.__trainable_model_handlers: List[config_pkg.ModelShardID] = []
    for model_name, topo in self.config.model_topos.items():
        model_shards = [
            config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
            for j in range(topo.world_size())
        ]
        # Keep one representative shard per model worker.
        for h in model_shards:
            _mw_id = self.config.msid2mwid[h]
            if _mw_id not in _covered_mws:
                _covered_mws.add(_mw_id)
                self.__all_mw_handlers.append(h)
        self.__all_model_handlers += model_shards

        if any(
            rpc.model_name == model_name
            and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP
            for rpc in self.__model_rpcs
        ):
            self.__trainable_model_handlers += model_shards
        self.__dp0_model_handlers += [
            config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
            for j in topo.filter_match(data=0)
        ]

    # Request model configs from model workers.
    # Return None if the model is not a ReaLModel.
    self.__model_configs: Dict[ModelName, None | ReaLModelConfig] = {}
    for model_name, model_topo in self.config.model_topos.items():
        # Use this model's own topology, not the one left over from the
        # loop above.
        h = config_pkg.ModelShardID.from_parallelism_rank(model_name, model_topo, 0)
        self.__model_configs[model_name] = self.__stream.call(
            handlers=[h],
            datas=[None],
            handle_type="model_config",
        )[0]

    # Initialize model backends.
    # For models with the same role, they share the same model parameters.
    # Therefore, we must call reallocate parameters from A to B
    # before we send requests to initialize B.
    _param_senders = [v[0] for v in self.config.sync_param_pairs]
    _param_receivers = [v[1] for v in self.config.sync_param_pairs]

    # The parameters are by default held by the trainable model.
    # If all replicas are not trainable, the parameters are held in replica 0.
    _model_is_trainable = collections.defaultdict(list)
    for rpc in self.__model_rpcs:
        _model_is_trainable[rpc.model_name].append(
            rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP
        )

    _model_is_trainable = {
        model_name: any(values)
        for model_name, values in _model_is_trainable.items()
    }

    _roles = set([rpc.model_name.role for rpc in self.__model_rpcs])
    _role_cnt = {
        role: len(
            set(
                [
                    rpc.model_name
                    for rpc in self.__model_rpcs
                    if rpc.model_name.role == role
                ]
            )
        )
        for role in _roles
    }
    _reordered_model_names = []
    for role in sorted(_roles):
        if _role_cnt[role] == 1:
            _reordered_model_names.append(ModelName(role, 0))
            continue
        _indices = list(range(_role_cnt[role]))
        _trainable_this_role = [
            _model_is_trainable[ModelName(role, i)] for i in range(_role_cnt[role])
        ]
        if any(_trainable_this_role):
            assert (
                sum(_trainable_this_role) == 1
            ), "only one train for each model is allowed"
            # The trainable replica holds the parameters, so it goes first.
            _trainable_idx = _trainable_this_role.index(True)
            _reordered_model_names.append(ModelName(role, _trainable_idx))
            _indices.remove(_trainable_idx)
        for i in _indices:
            _reordered_model_names.append(ModelName(role, i))

    # Send initialization requests.
    self.logger.info(
        f"Initialize model backends with order: {_reordered_model_names}."
    )
    train_rpcs = list(
        filter(
            lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
            self.__model_rpcs,
        )
    )
    assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
    # FinetuneSpec is required to initialize backends. Without any train
    # RPC, the batch size of the source RPC is used.
    ft_spec = model_api.FinetuneSpec(
        total_train_epochs=self.config.exp_ctrl.total_train_epochs,
        dataset_size=self._dataset_size,
        train_batch_size=(
            train_rpcs[0].n_seqs if len(train_rpcs) > 0 else self.__src_rpc.n_seqs
        ),
    )
    _initialized_roles = []
    for model_name in _reordered_model_names:
        topo = self.config.model_topos[model_name]
        _handlers = [
            config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
            for j in range(topo.world_size())
        ]

        init_payloads = [
            request_reply_stream.Payload(
                handler=_h,
                handle_name="initialize",
                data=ft_spec,
            )
            for _h in _handlers
        ]

        # Reallocate parameters if necessary.
        if model_name.role in _initialized_roles and model_name in _param_receivers:
            _param_realloc_src = _param_senders[_param_receivers.index(model_name)]

            # Update handler and payloads to send empty requests
            # with param realloc hooks to source workers.
            src_topo = self.config.model_topos[_param_realloc_src]
            other_handlers = [
                config_pkg.ModelShardID.from_parallelism_rank(
                    _param_realloc_src, src_topo, j
                )
                for j in range(src_topo.world_size())
            ]
            main_mw_ids = {self.config.msid2mwid[_h] for _h in _handlers}
            _other_handlers = [
                other_h
                for other_h in other_handlers
                if self.config.msid2mwid[other_h] not in main_mw_ids
            ]
            init_payloads += [
                request_reply_stream.Payload(
                    handler=_h,
                    handle_name="empty",
                )
                for _h in _other_handlers
            ]
            _handlers = _handlers + _other_handlers

            for i, p in enumerate(init_payloads):
                # Before the request: move parameters from the source
                # replica to this model.
                p = _attach_param_realloc_hooks(
                    payload=p,
                    msid2mwid=self.config.msid2mwid,
                    from_model_name=_param_realloc_src,
                    to_model_name=model_name,
                    from_topo=self.config.model_topos[_param_realloc_src],
                    to_topo=self.config.model_topos[model_name],
                    to_model_config=self.__model_configs[model_name],
                    pre=True,
                )
                # After the request: move them back to the source replica.
                init_payloads[i] = _attach_param_realloc_hooks(
                    payload=p,
                    msid2mwid=self.config.msid2mwid,
                    from_model_name=model_name,
                    to_model_name=_param_realloc_src,
                    to_topo=self.config.model_topos[_param_realloc_src],
                    from_topo=self.config.model_topos[model_name],
                    to_model_config=self.__model_configs[_param_realloc_src],
                    pre=False,
                )

        # Send initialization requests then immediately flush them.
        self.__stream.request(
            payloads=init_payloads,
        )
        self.__stream.request(
            handlers=_handlers,
            handle_type="flush",
            no_syn=True,
        )
        _initialized_roles.append(model_name.role)

    self._ft_spec = ft_spec
    logger.info("Initializations of models and backends complete.")

    self.__seqbuffer = AsyncIOSequenceBuffer(
        self.__model_rpcs,
        max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
    )

    # (sample id, data key) -> (model name, DP rank) owning that data;
    # filled in by the request coroutines.
    self.__data_owner = {}

    self.__summary_writer = None
    if self.tensorboard_config.path is not None:
        self.__summary_writer = SummaryWriter(log_dir=self.tensorboard_config.path)

    logger.info("Creating asyncio coroutines...")

    # Create coroutines for model RPCs.
    coroutine_tasks = []
    for rpc in self.__model_rpcs:
        request_task = event_loop.create_task(
            model_rpc_request_func(
                rpc=rpc,
                msid2mwid=self.config.msid2mwid,
                src_rpc_model_name=src_rpc_model_name,
                data_owner=self.__data_owner,
                stream=self.__stream,
                buffer=self.__seqbuffer,
                model_topos=self.__model_topos,
                model_configs=self.__model_configs,
                ctrl=self.__rpc_ctrl,
                model_save_root=self.MODEL_SAVE_ROOT,
            )
        )
        reply_task = event_loop.create_task(
            model_rpc_reply_func(
                rpc=rpc,
                stream=self.__stream,
                buffer=self.__seqbuffer,
                model_topos=self.__model_topos,
                ctrl=self.__rpc_ctrl,
                summary_writer=self.__summary_writer,
            )
        )
        coroutine_tasks += [request_task, reply_task]

    # Import here to avoid the conflict with nvloop.
    from realhf.base.asyncio_utils import setup_run_until_complete

    # Set up a run context of EventLoop.run_until_complete, basically
    # copy-paste from cpython. With this context, we can call the
    # non-blocking EventLoop._run_once (similar to worker._poll).
    self.__asyncio_tasks: List[asyncio.Task] = coroutine_tasks
    self.__asyncio_ctx = setup_run_until_complete(
        event_loop, asyncio.gather(*coroutine_tasks)
    )

    # wandb init, connect to remote wandb host
    wandb.login()
    wandb.init(
        mode=self.wandb_config.mode,
        entity=self.wandb_config.entity,
        project=self.wandb_config.project or constants.experiment_name(),
        name=self.wandb_config.name or f"{constants.trial_name()}_train",
        job_type=self.wandb_config.job_type,
        group=self.wandb_config.group
        or f"{constants.experiment_name()}_{constants.trial_name()}",
        notes=self.wandb_config.notes,
        tags=self.wandb_config.tags,
        config=self.wandb_config.config,
        dir=os.path.join(
            constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
        ),
        force=True,
        id=f"{constants.experiment_name()}_{constants.trial_name()}_train",
        resume="allow",
        settings=wandb.Settings(start_method="fork"),
    )

    logger.info("Coroutines created. The master worker is ready to run.")

    self.__initialized = True
    self._train_start_time = time.perf_counter()

    self.__last_step_info = recover.StepInfo(
        epoch=-1,
        epoch_step=-1,
        global_step=-1,
    )
def _poll(self):
# Run one step of the master worker's main loop.
#
# In order, the statements below: lazily initialize the worker, make sure the
# sequence buffer holds at least one batch of the source RPC, decide whether
# this step should eval/save/ckpt, drive the asyncio event loop until
# `train_count` has produced `__n_rpc_dsts` items, update the epoch/step
# counters, log statistics, optionally dump recover info, and finally return
# either the result of `experiment_complete_exit()` or a PollResult that
# accounts for one sample and one batch.
#
# NOTE(review): leading indentation is not visible in this view of the file,
# so these comments describe the statements themselves, not their nesting.
# Import here to avoid the conflict with nvloop.
from realhf.base.asyncio_utils import raise_asyncio_exception
is_new_epoch = False
# Captured before the lazy initialization below so that the data-loading
# requests can be told whether this is the very first poll.
first_poll = not self.__initialized
if not self.__initialized:
self.__lazy_init()
self._maybe_request_load_data(first_poll=first_poll)
# Main execution steps. The graph runs under-the-hood in RPC & stream threads.
# Wait for the finish of the traversal of the execution graph.
execution_start = time.perf_counter()
logger.info("Master worker is waiting for the finish of the execution graph.")
if self.__rpc_ctrl.ids_to_clear:
# Send clear cache requests to model workers.
# Clearing the data used in the last step.
self._clear_gpu_cache()
# Ask for data up to 10 more times; `_maybe_request_load_data` returns right
# away when the buffer already holds at least `n_seqs` sequences.
for _ in range(10):
self._maybe_request_load_data(first_poll=first_poll)
# Not enough data for a single batch: stop the experiment with an error.
if self.__seqbuffer.size < self.__src_rpc.n_seqs:
raise RuntimeError(
f"Buffer size {self.__seqbuffer.size} smaller than "
f"required batch size {self.__src_rpc.n_seqs} after loading data. "
"This should not happen, but we raise an error to stop the experiment. "
"Is your dataset size larger than the configured batch size?"
)
is_new_epoch = self._ft_spec.is_new_epoch(self.__rpc_ctrl.step_info)
is_epoch_last_step = self._ft_spec.is_epoch_last_step(self.__rpc_ctrl.step_info)
# Check whether we should evaluate or save models.
# Every controller is advanced by one step, and by one epoch only when this
# is the last step of an epoch.
self.__rpc_ctrl.should_eval = self.__eval_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
self.__rpc_ctrl.should_save = self.__save_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
self.__rpc_ctrl.should_ckpt = self.__ckpt_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
# Number of items taken from `train_count`; the loop below ends once it
# reaches `__n_rpc_dsts`.
_rpc_dst_cnt = 0
# Items taken from `topo_level_count` for the current DFG level, and the
# index of that level in `__topo_widths`.
_topo_level_cnt, _topo_level_idx = 0, 0
while _rpc_dst_cnt < self.__n_rpc_dsts:
try:
# Non-blocking poll of the `train_count` queue.
self.__rpc_ctrl.train_count.get_nowait()
_rpc_dst_cnt += 1
continue
except asyncio.QueueEmpty:
pass
try:
# Non-blocking poll of the `topo_level_count` queue.
self.__rpc_ctrl.topo_level_count.get_nowait()
_topo_level_cnt += 1
# Once as many items as the current level's width have arrived, send a
# "flush" request to all model-worker handlers and move to the next level.
if _topo_level_cnt >= self.__topo_widths[_topo_level_idx]:
logger.info(
f"Flushing the current level of the DFG with {self.__topo_widths[_topo_level_idx]} vertices."
)
self.__stream.request(
handlers=self.__all_mw_handlers,
handle_type="flush",
)
_topo_level_idx += 1
_topo_level_cnt = 0
continue
except asyncio.QueueEmpty:
pass
try:
# Similar to worker._poll. Run multiple times until a train step is finished.
self.__asyncio_ctx.loop._run_once()
# NOTE: The following line will propagate errors in coroutines back to the main thread.
# It raises asyncio.exceptions.InvalidStateError if the result is not ready.
# (In our use cases, the result will never be ready because coroutines run while-loops.)
# We just ignore this error and continue running.
self.__asyncio_ctx.future.result()
except asyncio.exceptions.InvalidStateError:
# Catch the exception when future.result() is not ready.
pass
except KeyboardInterrupt as e:
# Call raise_asyncio_exception with raise_error=False, then re-raise the
# KeyboardInterrupt itself.
raise_asyncio_exception(self.__asyncio_ctx, raise_error=False)
raise e
except:
# Any other exception is handed to raise_asyncio_exception.
raise_asyncio_exception(self.__asyncio_ctx)
logger.info("Execution finished!")
# Snapshot the step info of a step that saves or checkpoints; `__recover_save`
# stores it as `last_step_info` in the recover info.
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
# Ids scheduled for clearing are recorded as used in this epoch.
self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)
if is_epoch_last_step:
# Drop the first `_dataset_size` recorded ids at the last step of an epoch.
self.__rpc_ctrl.used_hash_vals_this_epoch = (
self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
)
if is_new_epoch:
self.__rpc_ctrl.step_info.epoch += 1
self.__rpc_ctrl.step_info.epoch_step = 0
# Logging.
time_since_configure = time.perf_counter() - self._train_start_time
e2e_time = time.perf_counter() - execution_start
self.e2e_time_history.append(e2e_time)
self._log_training_stats(e2e_time, time_since_configure)
# Update counters.
self.__rpc_ctrl.step_info.epoch_step += 1
self.__rpc_ctrl.step_info.global_step += 1
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__recover_save()
# Pause the worker if experiment or system-wise benchmark completes.
# The second condition compares the number of consumed sequences
# (global_step * n_seqs) with total_train_epochs * dataset size.
if (
self.__benchmark_steps is not None
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
) or (
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
>= self.__total_train_epochs * self._dataset_size
):
# We don't know whether it is the last step of the current epoch,
# so we exit at the first step of the next epoch.
if self.__benchmark_steps is not None:
logger.info(
f"Finished benchmark {self.__benchmark_steps}. "
f"Time consumption of this setup: {time_since_configure:.3f}"
)
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
return self.experiment_complete_exit()
return worker_base.PollResult(sample_count=1, batch_count=1)
def _maybe_request_load_data(self, first_poll: bool):
    """Refill the sequence buffer when it holds less than one source batch.

    Does nothing if the buffer already contains at least
    ``self.__src_rpc.n_seqs`` sequences. Otherwise it repeatedly sends
    ``fetch`` requests to every ``__data{i}__`` handler (one per
    data-parallel rank of the source RPC) until each rank has reported its
    final batch, records which rank owns every ``(sample id, key)`` pair,
    drops samples listed in ``hash_vals_to_ignore_in_recover`` (scheduling
    them for cache clearing instead), shuffles the remaining samples and
    puts them into the buffer.

    Args:
        first_poll: Forwarded unchanged to the ``fetch`` handlers, together
            with a per-rank ``first_batch`` flag.

    Raises:
        ValueError: If the same sample id is received more than once within
            this call.
    """
    should_load_data = self.__seqbuffer.size < self.__src_rpc.n_seqs
    if not should_load_data:
        return
    blogger.info(
        f"Current buffer size {self.__seqbuffer.size}/{self.__seqbuffer.max_size}. "
        f"The batch size of the source MFC is {self.__src_rpc.n_seqs}."
    )

    src_rpc_dp_size = self.__src_rpc_dp_size
    src_rpc_model_name = self.__src_rpc.model_name
    data_owner = self.__data_owner
    buffer = self.__seqbuffer
    stream = self.__stream
    ctrl = self.__rpc_ctrl

    # Fetch data from the dataloader to fill the sequence buffer.
    blogger.info(f"Filling data into the buffer in a new epoch.")
    fetch_data_start = time.perf_counter()

    # NOTE: PyTorch dataloader will shuffle data for us.
    all_data: List[data_api.SequenceSample] = []
    received_ids = set()

    # NOTE: Currently we send dataloading requests until iterating
    # over the entire dataset. This may lead to a huge memory waste
    # with super-large datasets. Empirically, it's fine.
    is_final_batch = [False for _ in range(src_rpc_dp_size)]
    is_first_batch = [True for _ in range(src_rpc_dp_size)]
    while not all(is_final_batch):
        # Send request to model workers to get the specification of data.
        # Data itself is not transferred to the master worker.
        data_batches: List[data_api.DataBatchMeta | None] = [
            None for _ in range(src_rpc_dp_size)
        ]
        for i in range(src_rpc_dp_size):
            if is_final_batch[i]:
                # This rank is exhausted; its slot stays None.
                continue
            data_batches[i] = stream.call(
                handlers=[f"__data{i}__"],
                handle_type="fetch",
                datas=[dict(first_batch=is_first_batch[i], first_poll=first_poll)],
                verbose=False,
            )[0]
            is_final_batch[i] = data_batches[i].is_final_batch
            is_first_batch[i] = False

        # Unpack batched sequences into individual sequences.
        for dp_rank, x in enumerate(data_batches):
            if x is None or x.meta_sample is None:
                continue
            # Unpack once and reuse the result for both the ownership
            # bookkeeping and the accumulated sample list.
            samples = x.meta_sample.unpack()
            for xx in samples:
                sample_id = xx.ids[0]
                if sample_id in received_ids:
                    raise ValueError(
                        f"Duplicate data id {xx.ids[0]}. Is the final batch? {is_final_batch}."
                    )
                received_ids.add(sample_id)
                # Store the owner information of the data.
                # RPC coroutines will use this information to
                # determine the src and dst of data transfer.
                for k in xx.keys:
                    data_owner[(sample_id, k)] = (src_rpc_model_name, dp_rank)
            all_data += samples

    # Samples that must be ignored during recover are not put into the
    # buffer; their ids are scheduled for clearing instead.
    filtered_data = []
    for x in all_data:
        if x.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
            ctrl.hash_vals_to_ignore_in_recover.remove(x.ids[0])
            ctrl.ids_to_clear.add(x.ids[0])
        else:
            filtered_data.append(x)
    all_data = filtered_data

    # We load data in a round-robin manner across different DP ranks,
    # so we also need to shuffle the data to fuse different dataset splits.
    random.shuffle(all_data)

    blogger.info(
        f"Master worker loaded {len(all_data)} pieces of data. "
        f"Training epoch {self.__rpc_ctrl.step_info.epoch + 1} approximately has {self._steps_per_epoch} steps. "
        f"Each batch has {self._avg_tokens_per_batch:.2f} tokens in average. "
        f"Remaining number of data to ignore: {len(self.__rpc_ctrl.hash_vals_to_ignore_in_recover)}."
    )

    # Store into buffer!
    buffer_indices = buffer.put_batch_synced(all_data)
    assert len(buffer_indices) == len(all_data)

    blogger.info(
        f"Filling data finished. Time consumption: "
        f"{time.perf_counter() - fetch_data_start:.3f}s."
    )

    # We should let model workers clear the data ignored during recover.
    if ctrl.ids_to_clear:
        self._clear_gpu_cache()
def _log_training_stats(self, e2e_time: float, time_since_configure: float):
# Log a one-line summary of the step that just finished: epoch/step counters,
# end-to-end execution time, total elapsed time, an estimate of the remaining
# time and, when it can be computed, the achieved TFLOP/s.
#   e2e_time: duration of this step; printed with an "s" suffix and used as
#       the divisor of the FLOP/s computation.
#   time_since_configure: printed as "Total time consumption" with an "s" suffix.
# NOTE(review): leading indentation is not visible in this view of the file,
# so these comments describe the statements themselves, not their nesting.
# calculate flops
#########################################
# FLOP/s is only computed when every model config is a ReaLModelConfig.
if not all(
isinstance(v, ReaLModelConfig) for v in self.__model_configs.values()
):
logger.warning(
f"Not all models are ReaLModels. Unable to calculate FLOP/s."
)
flops = None
tflops_per_gpu = float("inf")
else:
flops = self.__rpc_ctrl.flops_counter.get_flops()
# Divide by 10**12 to convert FLOP/s into TFLOP/s.
tflops = flops / (e2e_time * (10**12))
tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12))
# NOTE(review): it cannot be told from this view whether the counter is
# cleared on both branches or only on the `else` branch -- confirm.
self.__rpc_ctrl.flops_counter.clear()
#########################################
# The log shows the step_info counters plus one (presumably because they are
# zero-based when this method runs -- confirm against the caller).
epoch = self.__rpc_ctrl.step_info.epoch + 1
epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
global_step = self.__rpc_ctrl.step_info.global_step + 1
s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
s += f"step {epoch_step}/{self._steps_per_epoch} "
s += f"(global step {global_step}) finishes. "
s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. "
# Remaining-time estimate: the mean step time, skipping the first two
# recorded steps (presumably warm-up), times the steps left in this epoch
# plus the steps of all remaining epochs.
if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch
avg_t = np.mean(self.e2e_time_history[2:])
remain_t = avg_t * remaining_steps
remain_t += avg_t * self._steps_per_epoch * remaining_epochs
s += f"Estimated remaining time: {remain_t:.3f}s. "
# `tflops` is only assigned when `flops` is not None.
if flops is not None:
s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
logger.info(s)
logger.info(
f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s"
)
def _clear_gpu_cache(self):
    """Send a ``clear_data_cache`` request to all model-worker handlers.

    Every handler receives the current ``ids_to_clear`` collection of the
    RPC controller; the collection is emptied after the request is issued.
    """
    mw_handlers = self.__all_mw_handlers
    stale_ids = self.__rpc_ctrl.ids_to_clear
    # One payload entry per handler, all referring to the same id collection.
    per_handler_payload = [stale_ids for _handler in mw_handlers]
    self.__stream.request(
        handlers=mw_handlers,
        handle_type="clear_data_cache",
        datas=per_handler_payload,
        no_syn=True,
    )
    stale_ids.clear()
def experiment_complete_exit(self):
    """Shut down the experiment from the master worker's side.

    Stops the coroutines and the event loop, sends a ``reset`` request to
    all model-worker handlers, closes the stream, clears the name-resolve
    subtrees used by this trial, finishes the logging backends, marks the
    worker as uninitialized and pauses it.

    Returns:
        ``worker_base.PollResult(0, 0)``.
    """
    # Import here to avoid the conflict with nvloop.
    from realhf.base.asyncio_utils import teardown_run_util_complete

    run_ctx = self.__asyncio_ctx

    self.__rpc_ctrl.stop.set()
    for pending in self.__asyncio_tasks:
        pending.cancel()
    run_ctx.future.set_result(None)
    # NOTE: stopping the loop immediately after cancelling tasks may
    # raise warnings sometimes, but it doesn't matter.
    run_ctx.loop.stop()
    teardown_run_util_complete(run_ctx)

    banner = "".join(
        (
            colorama.Style.RESET_ALL,
            colorama.Fore.YELLOW,
            colorama.Style.BRIGHT,
            "\033[1m",
            "Experiment Completes! Yeah!!!!!!!!",
            colorama.Style.RESET_ALL,
        )
    )
    logger.info(banner)

    # Send requests to pause model workers.
    # Model workers will not respond to this message.
    mw_handlers = self.__all_mw_handlers
    self.__stream.request(
        handlers=mw_handlers,
        handle_type="reset",
        datas=[None for _handler in mw_handlers],
    )
    self.__stream.close()

    constants.reset_run()
    # Reset names used for distributed training.
    # The next round of training will set up a new distributed environment.
    # The experiment/trial names are looked up after reset_run(), once per
    # subtree, in the same order as before.
    for make_root in (names.distributed_root, names.request_reply_stream_root):
        subtree_root = make_root(
            constants.experiment_name(), constants.trial_name()
        )
        name_resolve.clear_subtree(subtree_root)

    wandb.finish()
    summary_writer = self.__summary_writer
    if summary_writer is not None:
        summary_writer.close()

    gc.collect()
    self.__initialized = False
    self.pause()
    return worker_base.PollResult(0, 0)
def __recover_save(self):
    """Dump the information needed to recover training from the current step.

    Only active when the environment variable ``REAL_SAVE_RECOVER_STATES``
    equals ``"1"``. The dumped record contains a copy of the current step
    info, the last saved step info, the state of the save/ckpt/eval
    controllers and the ids already used in this epoch.
    """
    recover_enabled = os.environ.get("REAL_SAVE_RECOVER_STATES", "0") == "1"
    if not recover_enabled:
        return

    ctrl = self.__rpc_ctrl
    # Deep-copy so that the dumped step info is independent of later updates.
    step_snapshot = copy.deepcopy(ctrl.step_info)
    info = recover.RecoverInfo(
        recover_start=step_snapshot,
        last_step_info=self.__last_step_info,
        save_ctl_info=self.__save_ctl.state_dict(),
        ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
        eval_ctl_info=self.__eval_ctl.state_dict(),
        hash_vals_to_ignore=ctrl.used_hash_vals_this_epoch,
    )
    recover.dump_recover_info(info)

    logger.info("Dumped recover info to file.")
    logger.info(f"Will recover from: {info.recover_start}")
    n_used = len(info.hash_vals_to_ignore)
    logger.info(f"Number of data used in this epoch: {n_used}")
# def _exit_hook(self, exit_status: worker_base.WorkerServerStatus):
# logger.info(f"Master worker exits with {exit_status}.")
# if os.getenv("REAL_SAVE_RECOVER_STATES", "0") != "1":
# return
# if exit_status == worker_base.WorkerServerStatus.ERROR:
# try:
# sleep_time = 600
# current_sleep_time = 0
# while current_sleep_time < sleep_time:
# logger.info(
# f"ERROR exit, waited {current_sleep_time} s for interruption ..."
# )
# time.sleep(10)
# current_sleep_time += 10
# except KeyboardInterrupt:
# logger.info("Received SIGINT, starting recover save")
# self.__recover_save()
# )
# time.sleep(10)
# current_sleep_time += 10
# except KeyboardInterrupt:
# logger.info("Received SIGINT, starting recover save")
# self.__recover_save()