# AReaL/realhf/search_engine/search.py

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import functools
import json
import os
import pickle
import pprint
import re
from typing import Any, Dict, List, Literal, Optional
import numpy as np
try:
import realhf._C.mdm_search as mdm_search
except ModuleNotFoundError:
mdm_search = None
import realhf.base.constants as constants
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.core.config import ModelInterfaceType
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation
from realhf.api.quickstart.search import RPCExecution
def search_rpc_allocations(
    device_mesh: DeviceMesh,
    rpcs: List[MFCDef],
    num_gen_tokens: int = 256,
    n_ppo_minibatches: int = 1,
    seq_len: int = 256,
    gradient_checkpointing: bool = True,
    use_cache: bool = False,
) -> List[RPCAllocation]:
    """Search device/parallelism allocations for all model RPCs via MCMC.

    Enumerates feasible executions per RPC, builds the dataflow graph, and
    runs the compiled ``mdm_search.multi_mcmc_search`` to pick one allocation
    per RPC. The result is dumped as JSON under the trial's log directory and
    reused by remote workers (``REAL_IS_REMOTE=1``) or when ``use_cache`` is
    set.

    Args:
        device_mesh: Global device mesh to allocate within.
        rpcs: Model function calls (MFCs) to place on the mesh.
        num_gen_tokens: Generation length assumed by the cost model.
        n_ppo_minibatches: PPO minibatch count assumed by the cost model.
        seq_len: Sequence length assumed by the cost model.
        gradient_checkpointing: Whether training uses activation checkpointing.
        use_cache: If True and a previous dump exists, load it instead of
            searching again.

    Returns:
        One ``RPCAllocation`` per RPC.

    Raises:
        ModuleNotFoundError: If the compiled ``realhf._C.mdm_search``
            extension is unavailable and a fresh search is required.
    """
    from realhf.search_engine.enumerate import build_graph
    from realhf.search_engine.estimate import get_param_realloc_stats

    # Remote (worker) processes never search; they load the result dumped by
    # the launcher process.
    from_file = os.environ.get("REAL_IS_REMOTE", "0") == "1"

    trial_root = os.path.join(
        constants.LOG_ROOT,
        constants.experiment_name(),
        constants.trial_name(),
    )
    # NOTE: despite the ".pkl" suffix, this file contains JSON (see json.dump
    # below). The name is kept for backward compatibility with existing trials.
    dump_dir = os.path.join(trial_root, "device_mapping.pkl")
    log_dir = os.path.join(trial_root, "device_mapping")
    rs_dir = os.path.join(trial_root, "raw_search_result")
    rpc_exe_dir = os.path.join(trial_root, "rpc_exe_info")

    if from_file or (use_cache and os.path.exists(dump_dir)):
        with open(dump_dir, "r") as f:
            return [RPCAllocation.from_dict(d) for d in json.load(f)]

    if mdm_search is None:
        # The module-level import is optional; fail loudly here instead of
        # with an opaque AttributeError at the call below.
        raise ModuleNotFoundError(
            "realhf._C.mdm_search is unavailable; the compiled search "
            "extension is required to run search_rpc_allocations."
        )

    os.makedirs(os.path.dirname(dump_dir), exist_ok=True)

    # Parameter-reallocation cost table, merged over all models.
    n_nodes = device_mesh.n_nodes
    table = {}
    for rpc in rpcs:
        print(f"Getting param realloc stats for {rpc.model_type} at {rpc.model_path}")
        t = get_param_realloc_stats(rpc.model_type, rpc.model_path, n_nodes, True)
        table.update(t)

    rpc_exe_list = make_rpc_exe_list(
        rpcs,
        device_mesh,
        num_gen_tokens=num_gen_tokens,
        n_ppo_minibatches=n_ppo_minibatches,
        seq_len=seq_len,
        gradient_checkpointing=gradient_checkpointing,
        log_dir=rpc_exe_dir,
        if_print=False,
    )
    graph = build_graph(rpcs, 5, 1, if_print=False)
    model_size_dict = make_model_size_dict(rpcs, if_print=False)

    search_time = 120
    rs: List[Dict[str, Any]] = mdm_search.multi_mcmc_search(
        rpcs,
        rpc_exe_list,
        graph,
        table,
        model_size_dict,
        0.001,  # beta min
        0.002,  # beta max
        0.001,  # beta step
        search_time,  # time limit for each search
        1,  # repeat
    )
    # We only reach this point on the launcher (from_file is False), so the
    # raw search result can always be logged.
    with open(rs_dir, "w") as f:
        pprint.pprint(rs, stream=f)

    # One candidate per repeat; take the last.
    r: Dict[str, Dict[str, Any]] = rs[-1]
    pprint.pprint(r)

    rpc_name_to_rpcs = {rpc.name: rpc for rpc in rpcs}
    rpc_allocs = []
    for rpc_name, alloc_info in r.items():
        # The result dict also carries scalar search metrics; skip them.
        if rpc_name in ["end_time", "mem_cost"]:
            continue
        rpc = rpc_name_to_rpcs[rpc_name]
        parallel = ParallelismConfig(
            pipeline_parallel_size=alloc_info["num_pp"],
            data_parallel_size=alloc_info["num_dp"],
            model_parallel_size=alloc_info["num_mp"],
            # Sequence parallelism only applies to training with tensor
            # parallelism enabled.
            use_sequence_parallel=(
                alloc_info["num_mp"] > 1
                and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
            ),
        )
        sub_device_mesh = DeviceMesh(
            n_nodes=device_mesh.n_nodes,
            n_gpus_per_node=device_mesh.n_gpus_per_node,
            mapping=alloc_info["device_mesh_mapping"],
            name=alloc_info["device_mesh_name"],
            global_mesh_name=device_mesh.global_mesh_name,
        )
        rpc_allocs.append(
            RPCAllocation(rpc=rpc, device_mesh=sub_device_mesh, parallel=parallel)
        )

    with open(dump_dir, "w") as f:
        json.dump([alloc.to_dict() for alloc in rpc_allocs], f, indent=4)
    with open(log_dir, "w") as f:
        pprint.pprint(rpc_allocs, stream=f)
    return rpc_allocs
def make_rpc_exe_list(
    rpcs: List[MFCDef],
    device_mesh: DeviceMesh,
    num_gen_tokens: int,
    n_ppo_minibatches: int,
    seq_len: int,
    gradient_checkpointing: bool,
    if_print: bool = False,
    log_dir: Optional[str] = None,
) -> List[RPCExecution]:
    """Enumerate feasible executions for every RPC on the device mesh.

    Args:
        rpcs: Model function calls to enumerate placements for.
        device_mesh: Global device mesh the executions must fit into.
        num_gen_tokens: Generation length assumed by the cost model.
        n_ppo_minibatches: PPO minibatch count assumed by the cost model.
        seq_len: Sequence length assumed by the cost model.
        gradient_checkpointing: Whether training uses activation checkpointing.
        if_print: If True, print a time-sorted summary per RPC to stdout.
        log_dir: If given, write the same summary to this file (truncated on
            the first RPC, appended for the rest).

    Returns:
        All feasible ``RPCExecution`` candidates, concatenated in
        enumeration order (the time-cost sort below is for logging only).
    """
    from realhf.search_engine.enumerate import enumerate_rpc_executions

    rpc_exe_list: List[RPCExecution] = []
    for rpc_idx, rpc in enumerate(rpcs):
        feasible = enumerate_rpc_executions(
            rpc,
            device_mesh,
            seq_len=seq_len,
            num_gen_tokens=num_gen_tokens,
            n_ppo_minibatches=n_ppo_minibatches,
            gradient_checkpointing=gradient_checkpointing,
        )
        # Extend before sorting so the returned list keeps enumeration order.
        rpc_exe_list.extend(feasible)

        if log_dir is not None or if_print:
            feasible.sort(key=lambda x: x.time_cost)

        if log_dir is not None:
            mode = "w" if rpc_idx == 0 else "a"
            with open(log_dir, mode) as f:
                f.write(f"{rpc.name} feasible: {len(feasible)}\n")
                for i, rpc_exe in enumerate(feasible):
                    f.write(
                        f"{i}: time_cost: {rpc_exe.time_cost} ms, "
                        f"sub_device_mesh: {rpc_exe.device_mesh}, "
                        f"parallel_strategy: {rpc_exe.parallel_strategy}, "
                        f"mem_cost: {rpc_exe.mem/(1024*1024*1024):.2f} GB, "
                        f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):.2f} GB\n"
                    )
                f.write("\n")

        if if_print:
            print(f"{rpc.name} feasible: {len(feasible)}")
            for i, rpc_exe in enumerate(feasible):
                print(
                    f"{i}: time_cost: {rpc_exe.time_cost} ms, "
                    f"sub_device_mesh: {rpc_exe.device_mesh}, "
                    f"parallel_strategy: {rpc_exe.parallel_strategy}, "
                    f"mem_cost: {rpc_exe.mem/(1024*1024*1024):.2f} GB, "
                    f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):.2f} GB"
                )
    return rpc_exe_list
def make_model_size_dict(rpcs: List[MFCDef], if_print: bool = False) -> Dict[str, int]:
model_size_dict = {}
for rpc in rpcs:
if rpc.model_name.role in model_size_dict:
continue
# model_configs = load_model_config(rpc)
# model_size_dict[rpc.model_name.role] = estimate_model_size(real_model_config)
model_size_dict[rpc.model_name.role] = rpc.model_type.size
if if_print:
print(
f"model_name: {rpc.model_name.role}, "
f"model_size: {rpc.model_type.size}"
)
return model_size_dict