# Copyright 2025 Ant Group Inc.

import asyncio
import random
from typing import Dict, List

import networkx as nx

from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import DataBatchMeta, SequenceSample
from realhf.api.core.dfg import MFCDef
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.request_reply_stream import NameResolvingRequestClient
from realhf.system.v2.function_call import FunctionCall, RPCCorountineControl

logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")


class FunctionExecutor:
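    """Executes one pass over the dataflow graph (DFG) of model function calls.

    Each MFCDef in ``rpcs`` is wrapped in a FunctionCall coroutine; data is
    streamed into a shared sequence buffer, and buffered requests are flushed
    to model workers level by level in topological order.
    """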

    def __init__(
        self,
        rpcs: List[MFCDef],
        msid2mwid: Dict[ModelShardID, int],
        stream: NameResolvingRequestClient,
        buffer: AsyncIOSequenceBuffer,
        model_topos: Dict[str, PipeModelDataParallelTopology],
        model_configs: Dict[str, None | ReaLModelConfig],
        ctrl: RPCCorountineControl,
    ):
        self.func_calls: Dict[str, FunctionCall] = {}
        self.ctrl = ctrl

        self.n_model_workers = len(set(msid2mwid.values()))

        self.rpcs = rpcs
        self.src_rpc = next(rpc for rpc in rpcs if rpc.is_src)
        self.src_dp_size = model_topos[self.src_rpc.model_name].get_dim("data")

        # Create model function calls.
        for rpc in self.rpcs:
            func_call = FunctionCall(
                rpc=rpc,
                src_rpc=self.src_rpc,
                stream=stream,
                msid2mwid=msid2mwid,
                model_topos=model_topos,
                model_configs=model_configs,
                ctrl=ctrl,
                buffer=buffer,
            )
            self.func_calls[rpc.name] = func_call

        self.stream = stream
        self.buffer = buffer

        # Sort all MFCs in topological order and record the width of each
        # level. These widths determine when to flush MFC requests.
        self.topo_widths = []
        for generation in nx.topological_generations(rpcs[0]._G):
            self.topo_widths.append(len(generation))
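
        # Example (hypothetical PPO-style DFG, not taken from this file):
        #   >>> G = nx.DiGraph(
        #   ...     [("gen", "inf"), ("gen", "rew"), ("inf", "train"), ("rew", "train")]
        #   ... )
        #   >>> [len(g) for g in nx.topological_generations(G)]
        #   [1, 2, 1]
        # flush_calls() would then wait for 1, 2, and 1 dispatched calls
        # before flushing levels 0, 1, and 2, respectively.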

    def get_leaf_tasks(self) -> List[str]:
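        """Return the names of sink MFCs, i.e., those without successors in the DFG."""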
        dst_rpcs = [rpc for rpc in self.rpcs if rpc.is_dst]
        return [rpc.name for rpc in dst_rpcs]

    async def flush_calls(self):
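        """Flush buffered MFC requests level by level.

        Waits until every function call in the current topological level has
        dispatched its request, then asks all model workers to flush.
        """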
        for level, w in enumerate(self.topo_widths):
            for _ in range(w):
                await self.ctrl.topo_level_count.get()
            logger.info(f"DFG level {level}. Flushing {w} function calls.")
            self.stream.request(
                handlers=list(range(self.n_model_workers)),
                handle_type="flush",
            )

    async def finish_traverse(self):
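        """Block until every sink MFC has reported completion for this step."""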
        for _ in range(len(self.get_leaf_tasks())):
            await self.ctrl.train_count.get()

    async def load_data(self):
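        """Fetch data from data workers into the sequence buffer.

        Cycles round-robin over the source RPC's data-parallel ranks until
        the buffer holds enough sequences for the largest MFC of this step.
        """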
        src_rpc = self.src_rpc
        src_rpc_model_name = src_rpc.model_name
        buffer = self.buffer
        ctrl = self.ctrl

        dp_idx = -1
        received_ids = set()

        while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
            dp_idx += 1
            dp_idx %= self.src_dp_size

            resps = await self.stream.call_async(
                handlers=[f"__data{dp_idx}__"],
                handle_type="fetch",
                datas=[None],
                verbose=False,
            )
            x: DataBatchMeta | None = resps[0]

            if x is None:
                continue
            if x.meta_sample is None:
                continue

            # Store the owner information of the data.
            # RPC coroutines will use this information to
            # determine the src and dst of data transfers.
            for xx in x.meta_sample.unpack():
                if xx.ids[0] in received_ids:
                    raise ValueError(f"Duplicate data id {xx.ids[0]}.")
                received_ids.add(xx.ids[0])
                for k in xx.keys:
                    self.ctrl.data_owner[(xx.ids[0], k)] = (
                        src_rpc_model_name,
                        dp_idx,
                    )

            # Drop data that was already consumed before a recovery restart.
            all_data = []
            for xx in x.meta_sample.unpack():
                if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
                    ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
                    ctrl.ids_to_clear.add(xx.ids[0])
                else:
                    all_data.append(xx)

            # We load data in a round-robin manner across different DP ranks,
            # so we also need to shuffle the data to fuse different dataset splits.
            random.shuffle(all_data)

            # Store into the buffer.
            buffer_indices = await buffer.put_batch(all_data)
            assert len(buffer_indices) == len(all_data)

            blogger.info(
                f"Master worker loaded {len(all_data)} pieces of data. "
                f"Remaining number of data to ignore: {len(self.ctrl.hash_vals_to_ignore_in_recover)}. "
                f"Current buffer size: {buffer.size}/{buffer.max_size}. "
            )

    def execute_step(self):
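        """Run one DFG step to completion on the current event loop."""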
        logger.info("Waiting for the execution graph to finish.")
        loop = asyncio.get_event_loop()

        # Function calls, the level-wise flush, and the data loader run
        # concurrently. finish_traverse() resolves once every sink MFC
        # reports completion.
        tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
            loop.create_task(self.flush_calls()),
            loop.create_task(self.load_data()),
        ]

        completion_future = loop.create_task(self.finish_traverse())

        loop.run_until_complete(completion_future)

        for task in tasks:
            loop.run_until_complete(task)

        logger.info("Execution finished!")
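
# Usage sketch (hypothetical wiring; in AReaL the master worker constructs
# these arguments from the experiment configuration):
#
#     executor = FunctionExecutor(
#         rpcs=rpcs,
#         msid2mwid=msid2mwid,
#         stream=stream,
#         buffer=buffer,
#         model_topos=model_topos,
#         model_configs=model_configs,
#         ctrl=ctrl,
#     )
#     for _ in range(total_train_steps):  # `total_train_steps` is illustrative
#         executor.execute_step()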