# Mirror of https://github.com/inclusionAI/AReaL (Python, 159 lines, 5.6 KiB)
# Copyright 2025 Ant Group Inc.
|
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
|
|
from typing import List
|
|
|
|
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
|
|
from realhf.api.core.dfg import build_graph as build_dfg
|
|
from realhf.api.quickstart.device_mesh import DeviceMesh, find_parallel_strategies
|
|
from realhf.api.quickstart.search import MFCDef, RPCExecution, RPCInstance
|
|
from realhf.search_engine.estimate import (
|
|
estimate_rpc_memory_cost,
|
|
estimate_rpc_time_cost,
|
|
)
|
|
|
|
MEM_INDEX = 1.0 # heuristic value to scale estimated memory
|
|
|
|
|
|
def enumerate_rpc_executions(
    rpc: MFCDef,
    device_mesh: DeviceMesh,
    seq_len: int,
    num_gen_tokens: int,
    n_ppo_minibatches: int,
    gradient_checkpointing: bool,
) -> List[RPCExecution]:
    """Enumerate feasible executions (sub-mesh + parallel strategy) for one RPC.

    For every sub-device-mesh of ``device_mesh`` and every parallel strategy
    found on it, this prunes configurations that cannot hold the batch or are
    heuristically known to be slow, estimates memory and time costs, and keeps
    only those whose estimated memory fits in GPU memory.

    :param rpc: The model function call to place.
    :param device_mesh: The global device mesh to carve sub-meshes from.
    :param seq_len: Input sequence length used for cost estimation.
    :param num_gen_tokens: Number of generated tokens (for generation RPCs).
    :param n_ppo_minibatches: Number of PPO minibatches per train step.
    :param gradient_checkpointing: Whether gradient checkpointing is enabled.
    :return: All feasible executions with estimated time/memory costs attached.
    """
    feasible = []
    for sub_device_mesh in device_mesh.sub_device_meshes():
        for parallel in find_parallel_strategies(sub_device_mesh):
            num_dp = parallel.data_parallel_size
            num_pp = parallel.pipeline_parallel_size
            num_mp = parallel.model_parallel_size
            bs = rpc.n_seqs
            # Train steps need at least 2 sequences per (dp, pp) shard per
            # PPO minibatch; other interface types need one per shard.
            min_bs = (
                2 * num_dp * num_pp * n_ppo_minibatches
                if rpc.interface_type == ModelInterfaceType.TRAIN_STEP
                else num_dp * num_pp
            )
            if min_bs > bs:
                # Batch size too small to fill every data/pipeline shard.
                continue
            # Heuristics to filter out inherently slow configurations.
            if (
                num_mp * num_dp > device_mesh.n_gpus_per_node
                and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
            ):
                continue
            if num_mp > 8:
                continue
            if num_pp > max(device_mesh.n_nodes, 8):
                continue
            # Memory and time estimation.
            mem_cost, static_mem = estimate_rpc_memory_cost(
                rpc,
                parallel,
                bs,
                seq_len,
                gradient_checkpointing=gradient_checkpointing,
                n_ppo_minibatches=n_ppo_minibatches,
                num_gen_tokens=num_gen_tokens,
                # Reference/reward models can be offloaded when idle.
                offload=rpc.model_name.role in ["ref", "reward"],
            )
            # MEM_INDEX is a heuristic safety scale on estimated memory.
            mem_cost = int(mem_cost * MEM_INDEX)
            static_mem = int(static_mem * MEM_INDEX)
            time_cost = int(
                estimate_rpc_time_cost(
                    rpc,
                    parallel,
                    bs=bs,
                    seq_len=seq_len,
                    num_gen_tokens=num_gen_tokens,
                    gradient_checkpointing=gradient_checkpointing,
                    n_ppo_minibatches=n_ppo_minibatches,
                )
            )
            if mem_cost < device_mesh.gpu_memory_capacity:
                feasible.append(
                    RPCExecution(
                        rpc,
                        sub_device_mesh,
                        parallel,
                        time_cost,
                        mem_cost,
                        static_mem,
                    )
                )
    return feasible
|
|
|
|
|
|
def build_graph(
    rpcs: List[MFCDef],
    num_epoch: int = 5,
    epoch_dependency_interval: int = 1,
    if_print=False,
) -> List[RPCInstance]:
    """Build the model function call graph of multiple training epochs.

    args:
        rpcs: List[MFCDef], the model function calls of a single epoch
        num_epoch: int, number of training epochs
        epoch_dependency_interval: int, the interval of epoch dependency,
            e.g. if epoch_dependency_interval = 2, then the graph will have
            edges between epoch i and epoch i+2, i+4, ...
        if_print: bool, if True, print every built RPCInstance (debug aid)

    returns:
        rpc_instances: List[RPCInstance], the list of RPCInstance objects
    """
    # One-epoch dependency graph; also resolves parents/children per RPC.
    rpcs, edges = build_dfg(rpcs)
    rpc_names_mapping = {rpc.name: rpc for rpc in rpcs}
    rpc_instances = []

    # Multi-epoch graph: replicate the one-epoch graph per epoch and wire
    # cross-epoch edges between src/dst RPCs sharing the same model role.
    for epoch_id in range(num_epoch):
        for rpc in rpcs:
            children = []
            parents = []
            if rpc.is_src and epoch_id >= epoch_dependency_interval:
                for other in rpcs:
                    if other.is_dst and other.model_name.role == rpc.model_name.role:
                        # NOTE(review): the parent instance is built from `rpc`
                        # (not `other`) as in the original code — presumably
                        # intentional, since only (rpc, epoch) identify the
                        # instance here; TODO confirm against RPCInstance use.
                        parents.append(
                            RPCInstance(
                                rpc,
                                epoch_id - epoch_dependency_interval,
                                [],
                                [],
                            )
                        )
            # Guard hoisted out of the inner loop (it is loop-invariant).
            if rpc.is_dst and epoch_id + epoch_dependency_interval < num_epoch:
                for other in rpcs:
                    # Bug fix: original compared rpc's role with itself
                    # (always True). Symmetric with the parents branch above,
                    # the role check must be against `other`.
                    if other.is_src and other.model_name.role == rpc.model_name.role:
                        children.append(
                            RPCInstance(
                                rpc,
                                epoch_id + epoch_dependency_interval,
                                [],
                                [],
                            )
                        )
            # Intra-epoch dependencies taken from the one-epoch DFG.
            for parent in rpc.parents:
                p = rpc_names_mapping[parent]
                parents.append(RPCInstance(p, epoch_id, [], []))
            for child in rpc.children:
                c = rpc_names_mapping[child]
                children.append(RPCInstance(c, epoch_id, [], []))
            rpc_instances.append(RPCInstance(rpc, epoch_id, parents, children))
    if if_print:
        for ri in rpc_instances:
            print(ri)
    return rpc_instances
|