# AReaL/realhf/impl/model/comm/param_realloc.py
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import itertools
from collections import defaultdict
from typing import *

import numpy as np
import scipy.optimize
import torch.distributed
import torch.nn as nn

from realhf.api.core import model_api
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants, topology
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.impl.model.nn.flatten_param import (
ContiguousParamSpec,
build_param_spec,
param_intervals_from_keys,
param_size_from_keys,
)
from realhf.impl.model.nn.real_llm_base import keys_from_layer_indices
from realhf.impl.model.nn.real_llm_parallel import (
partition_pipeline_layers,
pipeline_repartition_strategy,
)


_TRAINABLE: Dict[ModelName, bool] = {}


def set_trainable(model_name: ModelName, trainable: bool):
    _TRAINABLE[model_name] = trainable


def is_trainable(model_name: ModelName) -> bool:
    return _TRAINABLE[model_name]
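# Illustrative usage (hypothetical model names; ModelName construction follows
# realhf.api.core.config): register which replica is trained before deriving
# a reallocation plan, then query the registry during reallocation.
#
#   set_trainable(ModelName("default", 0), True)
#   assert is_trainable(ModelName("default", 0))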


@dataclasses.dataclass(unsafe_hash=True)
class ParamReallocPair:
src: ModelName
src_dp_rank: int
src_tp_rank: int
src_pp_rank: int
dst: ModelName
dst_tp_rank: int
dst_pp_rank: int


@dataclasses.dataclass(unsafe_hash=True)
class ParamReallocModelPair:
src: ModelName
dst: ModelName


@dataclasses.dataclass
class ParamReallocInfo:
# Groups for parameter synchronization.
param_realloc_model_group: Dict[
ParamReallocModelPair, torch.distributed.ProcessGroup
]
param_realloc_model_cpu_group: Dict[
ParamReallocModelPair, torch.distributed.ProcessGroup
]
param_realloc_groups: Dict[ParamReallocPair, torch.distributed.ProcessGroup]
param_realloc_src_ranks: Dict[ParamReallocPair, int]
param_realloc_dst_ranks: Dict[ParamReallocPair, List[int]]


def _max_match(_src_ranks: List[int], _grouped_dst_ranks: List[List[int]]):
    # Minimum-cost bipartite matching between source ranks and groups of
    # destination ranks: matching a source rank to a group it already belongs
    # to costs 0, any other match costs 1.
    cost_matrix = []
    for source in _src_ranks:
        costs = []
        for destinations in _grouped_dst_ranks:
            cost = 0 if source in destinations else 1
            costs.append(cost)
        cost_matrix.append(costs)
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
    return row_ind, col_ind
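# Worked example (illustrative): with _src_ranks=[0, 8] and
# _grouped_dst_ranks=[[8, 9], [16, 17]], the cost matrix is [[1, 1], [0, 1]]
# and scipy's linear_sum_assignment returns row_ind=[0, 1], col_ind=[1, 0]:
# rank 8 is matched to its own group [8, 9] at zero cost.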


def _group_mwids_by_node(ranks: List[int]) -> Dict[int, List[int]]:
    # Group model worker ranks by node, assuming 8 GPUs (hence 8 worker
    # ranks) per node.
    node2ranks = defaultdict(list)
    for r in ranks:
        node2ranks[r // 8].append(r)
    return {k: node2ranks[k] for k in sorted(node2ranks.keys())}


def _squeeze_mwids_by_node(ranks: List[int]) -> List[int]:
    # Keep only the first rank on each node.
    node2ranks = _group_mwids_by_node(ranks)
    return [rs[0] for rs in node2ranks.values()]
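# Worked example (illustrative, assuming 8 ranks per node):
#   _group_mwids_by_node([0, 1, 8, 9, 16]) -> {0: [0, 1], 1: [8, 9], 2: [16]}
#   _squeeze_mwids_by_node([0, 1, 8, 9, 16]) -> [0, 8, 16]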


def _assign_src_to_dsts(
    node2srcs: Dict[int, List[int]], node2dsts: Dict[int, List[int]]
) -> Dict[int, List[int]]:
    """Assign source ranks to destination ranks with a greedy algorithm.

    Every rank in the values of node2srcs holds the data required by all
    destination ranks. A source rank may be assigned zero or more
    destination ranks, but every destination rank must be assigned to
    exactly one source rank.

    Args:
        node2srcs (Dict[int, List[int]]): Node index -> source ranks.
        node2dsts (Dict[int, List[int]]): Node index -> destination ranks.

    Returns:
        Dict[int, List[int]]: Source rank -> destination ranks.
    """
# First, assign all destination ranks to source nodes.
# If a destination node is also a source node, assign it to itself.
# Otherwise, find the node with the minimum workload for load balancing.
dst_src_nodes = {}
for dst_node in node2dsts.keys():
if dst_node in node2srcs:
# dst node is also src node
dst_src_nodes[dst_node] = dst_node
src_node_workloads = {k: int(k in dst_src_nodes) for k in node2srcs}
assert sum(src_node_workloads.values()) == len(dst_src_nodes)
    for dst_node in node2dsts.keys():
        if dst_node not in node2srcs:
            # Find the source node with the minimum workload and record the
            # new assignment, so that unmatched destination nodes spread
            # across source nodes instead of all landing on the same one.
            src_node = min(src_node_workloads, key=src_node_workloads.get)
            src_node_workloads[src_node] += 1
            dst_src_nodes[dst_node] = src_node
assert all(dst_node in dst_src_nodes for dst_node in node2dsts)
# Revert the key-value of the dict.
src_dst_nodes = defaultdict(list)
for dst_node, src_node in dst_src_nodes.items():
src_dst_nodes[src_node].append(dst_node)
    # Next, pick an appropriate source rank on each source node. Prefer a
    # source rank that is itself one of the destination ranks; otherwise
    # fall back to the first source rank on the node.
assignment = {}
for src_node, dst_nodes in src_dst_nodes.items():
assigned = False
src_ranks = node2srcs[src_node]
dst_ranks = sum([node2dsts[dst_node] for dst_node in dst_nodes], start=[])
for s in src_ranks:
if s in dst_ranks:
assignment[s] = dst_ranks
assigned = True
break
if not assigned:
assignment[src_ranks[0]] = dst_ranks
assert len(set(assignment.keys())) == len(assignment.keys())
assert len(set(sum(assignment.values(), []))) == len(sum(assignment.values(), []))
return assignment
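# Worked example (illustrative): with node2srcs={0: [0, 1]} and
# node2dsts={0: [1, 2], 1: [8, 9]}, destination node 0 is also a source node
# and is assigned to itself, while node 1 falls back to source node 0. Within
# node 0, rank 1 is preferred as the sender because it is itself a destination
# rank, so the result is {1: [1, 2, 8, 9]}.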


def _create_param_realloc_groups(
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
src: ModelName,
dst: ModelName,
msid2mwid: Dict[ModelShardID, int],
param_realloc_groups: Dict[ParamReallocPair, torch.distributed.ProcessGroup],
param_realloc_src_ranks: Dict[ParamReallocPair, int],
param_realloc_dst_ranks: Dict[ParamReallocPair, List[int]],
):
mwid2msid: Dict[int, Dict[ModelName, ModelShardID]] = defaultdict(dict)
for k, v in msid2mwid.items():
mwid2msid[v][k.model_name] = k
for pp_i, pp_j in itertools.product(
range(from_topo.get_dim("pipe")), range(to_topo.get_dim("pipe"))
):
# create tensor reshard groups
src_tp_size = from_topo.get_dim("tensor")
dst_tp_size = to_topo.get_dim("tensor")
for tp_j in range(dst_tp_size):
_all_dst_ranks = filter_match_mwids(
dst, to_topo, msid2mwid, pipe=pp_j, tensor=tp_j
)
if src_tp_size > dst_tp_size:
factor = src_tp_size // dst_tp_size
tp_is = list(range(factor * tp_j, factor * (tp_j + 1)))
_all_src_ranks = [
filter_match_mwids(
src, from_topo, msid2mwid, tensor=tp_i, pipe=pp_i
)
for tp_i in tp_is
]
else:
factor = dst_tp_size // src_tp_size
_all_src_ranks = [
filter_match_mwids(
src,
from_topo,
msid2mwid,
tensor=tp_j // factor,
pipe=pp_i,
)
]
# All GPUs in _src_ranks have the data required by (pp_j, tp_j)
for _src_ranks in _all_src_ranks:
            # NOTE: Inter-node communication is significantly more expensive
            # than intra-node communication, so we select only one sender per
            # node to prevent multiple senders from competing for the same
            # network bandwidth. This is not optimal for intra-node
            # communication, because a source rank that is also a destination
            # rank may exist, yet we forcibly pick the first source rank on
            # each node here.
assignment = _assign_src_to_dsts(
_group_mwids_by_node(_src_ranks),
_group_mwids_by_node(_all_dst_ranks),
)
_idle_src_ranks = [r for r in _src_ranks if r not in assignment]
for _src_rank in _idle_src_ranks:
                    coord = from_topo.get_coord(
                        mwid2msid[_src_rank][src].parallelism_rank
                    )
                    dp_i, tp_i = coord.data, coord.tensor
key = ParamReallocPair(
src=src,
src_dp_rank=dp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=dst,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
param_realloc_dst_ranks[key] = []
param_realloc_groups[key] = None
param_realloc_src_ranks[key] = _src_rank
for _src_rank, _dst_ranks in assignment.items():
                    coord = from_topo.get_coord(
                        mwid2msid[_src_rank][src].parallelism_rank
                    )
                    dp_i, tp_i = coord.data, coord.tensor
key = ParamReallocPair(
src=src,
src_dp_rank=dp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=dst,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
param_realloc_dst_ranks[key] = _dst_ranks
if _src_rank not in _dst_ranks:
_dst_ranks = [_src_rank] + _dst_ranks
assert len(set(_dst_ranks)) == len(_dst_ranks)
if len(_dst_ranks) > 1:
if torch.distributed.is_initialized():
param_realloc_groups[key] = topology.new_or_get_group(
_dst_ranks
)
else:
# for estimating parameter realloc cost
param_realloc_groups[key] = 1
else:
param_realloc_groups[key] = None
param_realloc_src_ranks[key] = _src_rank
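# Worked example (illustrative): when resharding from TP size 4 to TP size 2,
# destination tp_j=0 gathers its parameters from source ranks with tp_i in
# {0, 1} and tp_j=1 from tp_i in {2, 3} (factor=2); each source group then
# elects at most one sender per node via _assign_src_to_dsts.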


def setup_param_realloc(
model_topos: Optional[Dict[str, topology.ProcessTopology]] = None,
msid2mwid: Optional[Dict[ModelShardID, int]] = None,
param_realloc_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
) -> ParamReallocInfo:
param_realloc_groups = {}
param_realloc_src_ranks = {}
param_realloc_dst_ranks = {}
param_realloc_model_group = {}
param_realloc_model_cpu_group = {}
if param_realloc_pairs is not None:
for src, dst in param_realloc_pairs:
_create_param_realloc_groups(
model_topos[src],
model_topos[dst],
src,
dst,
msid2mwid,
param_realloc_groups,
param_realloc_src_ranks,
param_realloc_dst_ranks,
)
pair_mw_ranks = set()
topo1 = model_topos[src]
topo2 = model_topos[dst]
for i in range(topo1.world_size()):
pair_mw_ranks.add(
msid2mwid[ModelShardID.from_parallelism_rank(src, topo1, i)]
)
for j in range(topo2.world_size()):
pair_mw_ranks.add(
msid2mwid[ModelShardID.from_parallelism_rank(dst, topo2, j)]
)
param_realloc_model_group[ParamReallocModelPair(src, dst)] = (
topology.new_or_get_group(list(sorted(pair_mw_ranks)))
)
param_realloc_model_cpu_group[ParamReallocModelPair(src, dst)] = (
topology.new_or_get_group(list(sorted(pair_mw_ranks)), backend="gloo")
)
return ParamReallocInfo(
param_realloc_groups=param_realloc_groups,
param_realloc_src_ranks=param_realloc_src_ranks,
param_realloc_dst_ranks=param_realloc_dst_ranks,
param_realloc_model_group=param_realloc_model_group,
param_realloc_model_cpu_group=param_realloc_model_cpu_group,
)
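# A minimal usage sketch (hypothetical names): build all reallocation groups
# once at startup, then look up the broadcast group and root rank for a given
# ParamReallocPair during reallocation.
#
#   info = setup_param_realloc(
#       model_topos={actor_name: actor_topo, ref_name: ref_topo},
#       msid2mwid=shard_to_worker_rank,
#       param_realloc_pairs=[(actor_name, ref_name)],
#   )
#   group = info.param_realloc_groups[key]
#   src_rank = info.param_realloc_src_ranks[key]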


@dataclasses.dataclass
class ReparallelizeSenderStep:
rank: int
sender_tp_portion_id: int
receiver_tp_portion_id: int
param_keys: List[str]
param_intervals_cpu: List[Tuple[int, int]]
param_intervals_cuda: torch.Tensor
max_interval_size: int
param_size: int
group: torch.distributed.ProcessGroup
dst_ranks: List[int]
remove: bool = False


@dataclasses.dataclass
class ReparallelizeReceiverStep:
rank: int
sender_tp_portion_id: int
receiver_tp_portion_id: int
sender_param_intervals_cpu: List[Tuple[int, int]]
sender_param_intervals_cuda: torch.Tensor
sender_max_interval_size: int
receiver_param_intervals_cpu: List[Tuple[int, int]]
receiver_param_intervals_cuda: torch.Tensor
receiver_max_interval_size: int
param_size: int
param_keys: List[str]
param_dtype: torch.dtype
src: int
dst_ranks: List[int]
group: torch.distributed.ProcessGroup


def _derive_reparallelize_comm_plan(
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
from_model_config: model_api.ReaLModelConfig,
to_model_config: model_api.ReaLModelConfig,
pg_info: ParamReallocInfo,
dtype: Optional[torch.dtype] = torch.float16,
) -> List[ReparallelizeReceiverStep | ReparallelizeSenderStep]:
src_tp_size = from_topo.get_dim("tensor")
dst_tp_size = to_topo.get_dim("tensor")
assert src_tp_size % dst_tp_size == 0 or dst_tp_size % src_tp_size == 0
for k, v in dataclasses.asdict(to_model_config).items():
if k not in ["is_critic"] and v != getattr(from_model_config, k):
raise ValueError(
f"Can't load a checkpoint with different config (key `{k}`, "
f"value in checkpoint is `{v}`, current value is `{getattr(from_model_config, k)}`)."
)
if from_model_config.n_kv_heads > 1 and (
from_model_config.n_kv_heads % src_tp_size == 0
) != (from_model_config.n_kv_heads % dst_tp_size == 0):
raise ValueError("Whether to partition kv heads should remain the same.")
from_layer_mapping = partition_pipeline_layers(
from_model_config,
from_topo.get_dim("pipe"),
)
from_layer_mapping = {
k: list(range(v[0], v[1])) for k, v in from_layer_mapping.items()
}
to_layer_mapping = partition_pipeline_layers(
to_model_config,
to_topo.get_dim("pipe"),
)
to_layer_mapping = {k: list(range(v[0], v[1])) for k, v in to_layer_mapping.items()}
repart_strat = pipeline_repartition_strategy(from_layer_mapping, to_layer_mapping)
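    # Worked example (illustrative): with 4 layers, going from 2 to 4 pipeline
    # stages gives from_layer_mapping={0: [0, 1], 1: [2, 3]} and
    # to_layer_mapping={0: [0], 1: [1], 2: [2], 3: [3]}, so the non-empty
    # entries of repart_strat are (0, 0) -> [0], (0, 1) -> [1],
    # (1, 2) -> [2], and (1, 3) -> [3].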
from_model_head_param_point_to_embedding = (
from_model_config.tied_embedding
and not from_model_config.is_critic
and from_topo.get_dim("pipe") == 1
)
to_model_head_param_point_to_embedding = (
to_model_config.tied_embedding
and not to_model_config.is_critic
and to_topo.get_dim("pipe") == 1
)
if constants.has_model_name(from_model_name):
with constants.model_scope(from_model_name):
from_layer_indices = from_layer_mapping[constants.pipe_parallel_rank()]
from_model_param_specs, _ = build_param_spec(
from_layer_indices,
from_model_config,
tp_size=from_topo.get_dim("tensor"),
dp_size=from_topo.get_dim("data"),
pp_size=from_topo.get_dim("pipe"),
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
if constants.has_model_name(to_model_name):
with constants.model_scope(to_model_name):
to_layer_indices = to_layer_mapping[constants.pipe_parallel_rank()]
to_model_param_specs, _ = build_param_spec(
to_layer_indices,
to_model_config,
tp_size=to_topo.get_dim("tensor"),
pp_size=to_topo.get_dim("pipe"),
dp_size=to_topo.get_dim("data"),
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
)
comm_plan = []
src_dp_size = from_topo.get_dim("data")
src_pp_size = from_topo.get_dim("pipe")
dst_pp_size = to_topo.get_dim("pipe")
# derive a global NCCL communication plan
for (pp_i, pp_j), layer_indices in repart_strat.items():
if len(layer_indices) == 0:
continue
for tp_i in range(src_tp_size):
if dst_tp_size > src_tp_size:
factor = dst_tp_size // src_tp_size
tp_js = [i + factor * tp_i for i in range(factor)]
receiver_tp_portion_id = 0
else:
factor = src_tp_size // dst_tp_size
tp_js = [tp_i // factor]
receiver_tp_portion_id = tp_i % factor
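            # Example (illustrative): resharding TP 2 -> 4 gives factor=2,
            # so tp_i=0 sends portions to tp_js=[0, 1] and tp_i=1 to
            # tp_js=[2, 3], each receiver taking a whole portion
            # (receiver_tp_portion_id=0). Resharding TP 4 -> 2 gives
            # factor=2, so tp_i=3 sends to tp_js=[1] as the second half,
            # receiver_tp_portion_id = 3 % 2 = 1.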
for sender_tp_portion_id, tp_j in enumerate(tp_js):
for dp_i in range(src_dp_size):
key = ParamReallocPair(
src=from_model_name,
src_dp_rank=dp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=to_model_name,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
src = pg_info.param_realloc_src_ranks[key]
group = pg_info.param_realloc_groups[key]
dst_ranks = pg_info.param_realloc_dst_ranks[key]
                    param_intervals_cpu = receiver_param_intervals_cpu = None
                    param_intervals_cuda = receiver_param_intervals_cuda = None
                    max_interval_size = max_receiver_interval_size = None
param_keys = keys_from_layer_indices(
from_model_config, layer_indices
)
param_size = param_size_from_keys(
config=from_model_config,
src_tp_size=src_tp_size,
sd_keys=param_keys,
src2dst_tp_size=max(dst_tp_size // src_tp_size, 1),
src2dst_tp_rank=sender_tp_portion_id,
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
if torch.distributed.is_initialized():
# torch.distributed is not initialized when estimating param realloc cost
if torch.distributed.get_rank() == src:
param_intervals_cpu = param_intervals_from_keys(
model_name=from_model_name,
config=from_model_config,
tp_size=src_tp_size,
param_spec=from_model_param_specs,
sd_keys=param_keys,
portion_size=max(dst_tp_size // src_tp_size, 1),
portion_rank=sender_tp_portion_id,
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
param_intervals_cuda = torch.tensor(
param_intervals_cpu,
dtype=torch.long,
device=constants.current_device(),
)
max_interval_size = max(
y - x for x, y in param_intervals_cpu
)
if torch.distributed.get_rank() in dst_ranks:
receiver_param_intervals_cpu = param_intervals_from_keys(
model_name=to_model_name,
config=to_model_config,
tp_size=dst_tp_size,
param_spec=to_model_param_specs,
sd_keys=param_keys,
portion_size=max(src_tp_size // dst_tp_size, 1),
portion_rank=receiver_tp_portion_id,
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
)
receiver_param_intervals_cuda = torch.tensor(
receiver_param_intervals_cpu,
dtype=torch.long,
device=constants.current_device(),
)
max_receiver_interval_size = max(
y - x for x, y in receiver_param_intervals_cpu
)
for dst_rank in dst_ranks:
comm_plan.append(
ReparallelizeReceiverStep(
rank=dst_rank,
sender_tp_portion_id=sender_tp_portion_id,
receiver_tp_portion_id=receiver_tp_portion_id,
param_keys=param_keys,
sender_param_intervals_cpu=param_intervals_cpu,
sender_param_intervals_cuda=param_intervals_cuda,
sender_max_interval_size=max_interval_size,
receiver_param_intervals_cpu=receiver_param_intervals_cpu,
receiver_param_intervals_cuda=receiver_param_intervals_cuda,
receiver_max_interval_size=max_receiver_interval_size,
param_size=param_size,
param_dtype=dtype,
src=src,
dst_ranks=dst_ranks,
group=group,
)
)
comm_plan.append(
ReparallelizeSenderStep(
rank=src,
sender_tp_portion_id=sender_tp_portion_id,
receiver_tp_portion_id=receiver_tp_portion_id,
param_keys=param_keys,
param_intervals_cpu=param_intervals_cpu,
param_intervals_cuda=param_intervals_cuda,
max_interval_size=max_interval_size,
param_size=param_size,
group=group,
dst_ranks=dst_ranks,
)
)
    # Mark a sender step as removable when no later sender step on the same
    # rank sends the same parameter keys.
    for i, step in enumerate(comm_plan):
        if isinstance(step, ReparallelizeReceiverStep):
            continue
        step: ReparallelizeSenderStep
        required_by_next_steps = False
        for next_step in comm_plan[i + 1 :]:
            if (
                isinstance(next_step, ReparallelizeSenderStep)
                and next_step.rank == step.rank
                and next_step.param_keys == step.param_keys
            ):
                required_by_next_steps = True
                break
        step.remove = not required_by_next_steps
    return comm_plan
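# Note: for each sender above, the matching receiver steps are appended to the
# plan before the sender step itself, and the `remove` flag marks the last
# send of a given key set from a rank, presumably so that the flattened
# parameter buffer can be freed once it is no longer needed.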


@dataclasses.dataclass
class ReparallelizeTraget:
comm_plan: List[Union[ReparallelizeSenderStep, ReparallelizeReceiverStep]]
to_param_spec: Dict[str, ContiguousParamSpec]
to_param_size: int
to_layers_handle: nn.ModuleList
to_layer_start_idx: int
to_layer_end_idx: int