mirror of https://github.com/inclusionAI/AReaL
438 lines · 17 KiB · Python
# Copyright 2025 Ant Group Inc.
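
"""Coroutine-based execution of model function calls (MFCs) over the dataflow
graph (DFG): dispatching batches across data parallel ranks, issuing requests
to model worker shards, and collecting the results."""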

import asyncio
import dataclasses
import os
import time
import uuid
from collections import defaultdict
from typing import Dict, Hashable, List, Set, Tuple

import wandb

import realhf.api.core.config as config_api
import realhf.api.core.data_api as data_api
import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter

logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")


@dataclasses.dataclass
class RPCCorountineControl:
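    """Shared control state for the coroutines that execute model function calls.

    The two queues act as counters: `train_count` tracks finished training
    steps, and `topo_level_count` is used for flushing requests (one put per
    issued MFC). The remaining fields record buffered-data ownership,
    save/eval/checkpoint flags, and recover metadata.
    """
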
    # For counting the number of finished training steps.
    # One training step corresponds to one traversal of the whole DFG.
    train_count: asyncio.Queue
    # For flushing requests.
    topo_level_count: asyncio.Queue

    # For training data management and data cleaning after each step.
    ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
    flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)

    data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]] = dataclasses.field(
        default_factory=dict
    )

    should_save: bool = False
    should_eval: bool = False
    should_ckpt: bool = False
    step_info: recover.StepInfo = dataclasses.field(default_factory=recover.StepInfo)

    # Recover information.
    used_hash_vals_this_epoch: List[int] = dataclasses.field(default_factory=list)
    hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)


class FunctionCall:
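    """Executes a single model function call (MFC), i.e., one node of the
    DFG, by dispatching requests to the model workers that host the shards
    of the target model.
    """
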
    def __init__(
        self,
        rpc: dfg.MFCDef,
        src_rpc: dfg.MFCDef,
        stream: request_reply_stream.NameResolvingRequestClient,
        msid2mwid: Dict[config_pkg.ModelShardID, int],
        model_topos: Dict[str, topology.PipeModelDataParallelTopology],
        model_configs: Dict[str, None | ReaLModelConfig],
        ctrl: RPCCorountineControl,
        buffer: AsyncIOSequenceBuffer,
    ):

        self.rpc = rpc
        self.src_rpc = src_rpc
        self.stream = stream

        self.msid2mwid = msid2mwid
        self.model_topos = model_topos
        self.model_configs = model_configs

        self.model_save_root = os.path.join(
            constants.MODEL_SAVE_ROOT,
            constants.experiment_name(),
            constants.trial_name(),
        )

        self.rpc_ctrl = ctrl
        self.buffer = buffer

    @property
    def dp_size(self):
        return self.model_topos[self.rpc.model_name].get_dim("data")

    @property
    def pp_size(self):
        return self.model_topos[self.rpc.model_name].get_dim("pipe")

    def attach_payloads_with_hooks(
        self,
        payloads: Dict[config_api.ModelShardID, request_reply_stream.Payload],
        mwids: List[int],
        main_handlers: List[config_pkg.ModelShardID],
        hook_type: str,
    ) -> Tuple[Dict[config_api.ModelShardID, request_reply_stream.Payload], List[int]]:
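        """Attach parameter-reallocation and offload hooks to request payloads.

        Workers that participate in a hook but do not run the main request
        receive "empty" payloads that carry only the hook. Returns the updated
        payloads and the list of involved model worker indices.
        """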
        assert hook_type in ["pre", "post"], hook_type

        rpc = self.rpc
        model_topos = self.model_topos
        model_configs = self.model_configs

        main_mwids = set([self.msid2mwid[h] for h in main_handlers])
        for hook in getattr(rpc, f"_{hook_type}_hooks"):
            if isinstance(hook, dfg.ParamReallocHook):
                assert (hook.source is None) != (hook.target is None), hook
                if hook.source is None:
                    src_topo = model_topos[rpc.model_name]
                    dst_topo = model_topos[hook.target]
                    dst_config = model_configs[hook.target]
                    src_model_name, dst_model_name = rpc.model_name, hook.target
                    other_model_name = hook.target
                    other_topo = dst_topo
                else:
                    src_topo = model_topos[hook.source]
                    dst_topo = model_topos[rpc.model_name]
                    dst_config = model_configs[rpc.model_name]
                    src_model_name, dst_model_name = hook.source, rpc.model_name
                    other_model_name = hook.source
                    other_topo = src_topo

                ps_data = {
                    "from_model_name": src_model_name,
                    "to_model_name": dst_model_name,
                    "from_topo": src_topo,
                    "to_topo": dst_topo,
                    "to_model_config": dst_config,
                    "eta": hook.eta,
                }
                for h in main_handlers:
                    getattr(payloads[h], f"{hook_type}_hooks").append("param_realloc")
                    getattr(payloads[h], f"{hook_type}_hook_data").append(ps_data)
                other_handlers = [
                    config_api.ModelShardID.from_parallelism_rank(
                        other_model_name, other_topo, j
                    )
                    for j in range(other_topo.world_size())
                ]
                for h in other_handlers:
                    if self.msid2mwid[h] not in mwids:
                        payloads[h] = request_reply_stream.Payload(
                            handler=h,
                            handle_name="empty",
                        )
                        setattr(payloads[h], f"{hook_type}_hooks", ["param_realloc"])
                        setattr(payloads[h], f"{hook_type}_hook_data", [ps_data])
                        mwids.append(self.msid2mwid[h])
                    elif self.msid2mwid[h] not in main_mwids:
                        hh = next(
                            hh
                            for hh in payloads
                            if self.msid2mwid[hh] == self.msid2mwid[h]
                        )
                        getattr(payloads[hh], f"{hook_type}_hooks").append(
                            "param_realloc"
                        )
                        getattr(payloads[hh], f"{hook_type}_hook_data").append(ps_data)

            elif isinstance(hook, dfg.OffloadHook):
                for h in main_handlers:
                    getattr(payloads[h], f"{hook_type}_hooks").append("offload")
                    getattr(payloads[h], f"{hook_type}_hook_data").append(
                        dict(model_name=h.model_name)
                    )
            else:
                raise NotImplementedError(f"Unknown hook type: {hook}")
        return payloads, mwids

    def request(
        self,
        producer_names: Dict[str, str],
        producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]],
        producer_mappings: Dict[str, Dict[str, List[int]]],
        target_mapping: Dict[str, List[int]],
        meta_sample: data_api.SequenceSample,
        handlers: List[config_pkg.ModelShardID],
    ) -> Tuple[List[uuid.UUID], List[uuid.UUID]]:
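        """Send the MFC request to all handlers of the target model.

        Every main payload carries a "data_transfer" pre-hook; auxiliary
        "empty" payloads are sent to data producers and hook participants.
        Returns the main request IDs followed by the auxiliary request IDs.
        """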
        rpc = self.rpc
        ctrl = self.rpc_ctrl

        dt_data = {
            "keys": rpc.input_keys,
            "target": rpc.model_name,
            "producer_names": producer_names,
            "producer_mappings": producer_mappings,
            "target_mapping": target_mapping,
            "handle_name": rpc.interface_type.value,
            "rpc_name": rpc.name,
            "meta_sample": meta_sample,
        }

        payloads = {
            handler: request_reply_stream.Payload(
                handler=handler,
                handle_name=rpc.interface_type.value,
                pre_hooks=["data_transfer"],
                pre_hook_data=[dt_data],
                data=rpc.name,
            )
            for handler in handlers
        }
        if ctrl.should_eval:
            for p in payloads.values():
                p.post_hooks.append("evaluate")
                p.post_hook_data.append(dict(model_name=rpc.model_name))
        if (
            ctrl.should_save or ctrl.should_ckpt
        ) and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
            for p in payloads.values():
                p.post_hooks.append("save")
                save_dir = os.path.join(
                    self.model_save_root,
                    rpc.model_name.role,
                    f"epoch{ctrl.step_info.epoch + 1}"
                    f"epochstep{ctrl.step_info.epoch_step + 1}"
                    f"globalstep{ctrl.step_info.global_step + 1}",
                )
                p.post_hook_data.append(
                    dict(
                        model_name=rpc.model_name,
                        save_dir=save_dir,
                        recover_only=not ctrl.should_save,
                    )
                )
        mwids = [self.msid2mwid[h] for h in handlers]
        assert len(mwids) == len(set(mwids))

        for producer_name in producer_names.values():
            for h in producer_name2producer_handlers[producer_name]:
                if self.msid2mwid[h] not in mwids:
                    payloads[h] = request_reply_stream.Payload(
                        handler=h,
                        handle_name="empty",
                        pre_hooks=["data_transfer"],
                        pre_hook_data=[dt_data],
                    )
                    mwids.append(self.msid2mwid[h])

        payloads, mwids = self.attach_payloads_with_hooks(
            payloads,
            mwids,
            main_handlers=handlers,
            hook_type="pre",
        )
        payloads, mwids = self.attach_payloads_with_hooks(
            payloads,
            mwids,
            main_handlers=handlers,
            hook_type="post",
        )
        main_payloads = [p for h, p in payloads.items() if h in handlers]
        other_payloads = [p for h, p in payloads.items() if h not in handlers]
        all_req_ids = self.stream.request(
            payloads=main_payloads + other_payloads,
        )
        return all_req_ids[: len(main_payloads)], all_req_ids[len(main_payloads) :]

    def data_parallel_dispatch(
        self, buf_indices: List[int], sample: data_api.SequenceSample
    ) -> Tuple[List[int], data_api.SequenceSample, List[Tuple[int, int]]]:
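        """Split a batch across data parallel ranks.

        Returns the (possibly reordered) buffer indices, the regathered
        sample, and the `(start, end)` partition owned by each DP rank.
        """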
        # Dispatch data to different data parallel ranks.
        if self.rpc.is_generate():
            # The workload of generation is determined by the batch size
            # rather than by the generated length, so split with uniform lengths.
            samples, forward_indices, _ = sample.split_with_lengths(
                mb_spec=data_api.MicroBatchSpec(n_mbs=self.dp_size),
                lens=[1 for _ in range(sample.bs)],
            )
        else:
            samples, forward_indices, _ = sample.split(
                data_api.MicroBatchSpec(n_mbs=self.dp_size)
            )
        blogger.info(
            f"DP split (DP size {self.dp_size}) for RPC {self.rpc.name}: "
            f"#seqs: {[s.bs for s in samples]}, "
            f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}"
        )
        sample = data_api.SequenceSample.gather(samples)
        buf_indices = [buf_indices[i] for i in forward_indices]

        partitions = data_api.SequenceSplitSpec(
            sizes=[s.bs for s in samples]
        ).partitions
        return buf_indices, sample, partitions

    async def run_step(self, buf_indices, sample):
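        """Run one step of this MFC over a dispatched batch.

        Dispatches data across DP ranks, issues requests, gathers DP-head
        responses, then either bumps the train counter (if this RPC is a
        sink of the DFG) or amends the produced data into the buffer.
        """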
        rpc = self.rpc
        topo = self.model_topos[rpc.model_name]
        ctrl = self.rpc_ctrl

        handlers = [
            config_pkg.ModelShardID.from_parallelism_rank(rpc.model_name, topo, j)
            for j in range(topo.world_size())
        ]

        producer_names = {}  # data key -> model name
        for k in rpc.input_keys:
            if k in rpc.data_producers:
                producer_names[k] = rpc.data_producers[k]
            else:
                producer_names[k] = self.src_rpc.model_name
        keys_to_send = defaultdict(list)  # model name -> List[keys] to send
        for k in producer_names:
            keys_to_send[producer_names[k]].append(k)

        # Convert producer model names to ModelShardIDs.
        producer_name2producer_handlers = {}
        for producer_name in keys_to_send:
            producer_name2producer_handlers[producer_name] = [
                config_pkg.ModelShardID.from_parallelism_rank(
                    producer_name, self.model_topos[producer_name], j
                )
                for j in range(self.model_topos[producer_name].world_size())
            ]

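        # The head of each data parallel group is the shard on the last
        # pipeline stage with model-parallel rank 0; only its response
        # carries the actual result.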
        dp_head_indices = [
            topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
            for i in range(self.dp_size)
        ]

        ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])

        # logger.info(f"Model rpc {rpc.name} requesting.")

        # The sample may be reordered here.
        buf_indices, sample, partitions = self.data_parallel_dispatch(
            buf_indices, sample
        )
        target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}

        # Set the owner of the data produced by this RPC, so that downstream
        # RPCs know where to fetch it.
        for dp_idx, (st, ed) in enumerate(partitions):
            for i in range(st, ed):
                for k in rpc.output_keys:
                    self.rpc_ctrl.data_owner[sample.ids[i], k] = (
                        rpc.model_name,
                        dp_idx,
                    )

        # Get the owners of this RPC's input data.
        # We use them to determine the sources of data transfer.
        producer_mappings = {}
        for k in rpc.input_keys:
            names, dp_indices = [], []
            for sample_id in sample.ids:
                owner_name, dp_idx = self.rpc_ctrl.data_owner[(sample_id, k)]
                names.append(owner_name)
                dp_indices.append(dp_idx)
            assert len(set(names)) == 1
            producer_mapping = defaultdict(list)
            for i, dp_idx in enumerate(dp_indices):
                producer_mapping[dp_idx].append(i)
            producer_mapping = {k: sorted(v) for k, v in producer_mapping.items()}
            producer_mappings[names[0], k] = producer_mapping

        # Send the partitioned data to model workers.
        req_ids, other_req_ids = self.request(
            producer_names=producer_names,
            producer_name2producer_handlers=producer_name2producer_handlers,
            producer_mappings=producer_mappings,
            target_mapping=target_mapping,
            meta_sample=sample,
            handlers=handlers,
        )
        tik = time.perf_counter()

        await ctrl.topo_level_count.put(1)
        logger.info(f"Model rpc {rpc.name} requested.")

        # Then, wait for all main requests to finish.
        responses = await self.stream.gather_async(request_ids=req_ids)
        # logger.info(f"rpc {rpc.name} received responses {req_ids}")

        # Filter out responses other than those from DP heads.
        # The other responses are duplicated or None.
        responses = [responses[i] for i in dp_head_indices]

        # If the returned data is a SequenceSample, it is the data returned by
        # model function calls and should be amended into the buffer.
        # Otherwise, it is the train statistics and should be reduced and logged.
        if isinstance(responses[-1], data_api.SequenceSample):
            res = data_api.SequenceSample.gather(responses)
        else:
            res = data_api.gather_stat(responses)

        if rpc.log_return_value:
            logger.info(f"RPC name {rpc.name} returns {res}")

        if isinstance(res, Dict):
            wandb.log(res, step=ctrl.step_info.global_step)

        logger.info(
            f"Model rpc {rpc.name} finished. "
            f"Run time {time.perf_counter() - tik:.4f}s."
        )

        # If this RPC is the final node in the dataflow graph,
        # update the train counter.
        # Otherwise, amend data in the buffer.
        if rpc.is_dst:
            ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
            await ctrl.train_count.put(1)
        else:
            logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
            await self.buffer.amend_batch(buf_indices, res.unpack())

        # Wait for all side-effect requests to finish.
        # Side-effect or empty requests are required for data transfer
        # and parameter synchronization.
        # Wait for them after the main request so that the correct MFC time is logged.
        await self.stream.gather_async(other_req_ids)

    async def run(self):
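        """Repeatedly consume batches from the buffer and run steps until one
        training step's worth of data has been produced for all successors.
        """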
        rpc = self.rpc
        topo = self.model_topos[rpc.model_name]
        ctrl = self.rpc_ctrl

        logger.info(
            f"Running Model RPC, interface_type=#{rpc.interface_type}# "
            f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
        )

        consumed = 0
        while True:
            buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc)

            await self.run_step(buf_indices, sample)
            consumed += sample.bs

            # Ensure that parent RPCs will not be over-consumed.
            if all(consumed >= c.n_seqs for c in rpc.all_successors()):
                break