mirror of https://github.com/inclusionAI/AReaL
# Copyright 2025 Ant Group Inc.

import asyncio
import dataclasses
import itertools
from collections import defaultdict
from typing import *

from realhf.api.cli_args import ClusterSpecConfig


class GlobalStorageTracker:
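    """Track which data entries are stored on which rank.

    ``storages[rank]`` maps each data ID held by ``rank`` to the list of keys
    stored for that ID, and ``data_owner`` maps an ``(id, key)`` pair to the
    rank that owns it. ``add_data`` and ``clear_data`` guard updates with an
    asyncio lock; ``add_data_synced`` performs the same update without it.
    """
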
    def __init__(self, world_size: int):
        self.lock = asyncio.Lock()
        # Per-rank mapping from data ID to the keys stored for that ID.
        self.storages: List[Dict[Hashable, List[str]]] = [
            {} for _ in range(world_size)
        ]
        # Maps (data ID, key) to the rank that owns it.
        self.data_owner: Dict[Tuple[Hashable, str], int] = {}

    async def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
        async with self.lock:
            for data_id in ids:
                if data_id not in self.storages[rank]:
                    self.storages[rank][data_id] = [key]
                elif key not in self.storages[rank][data_id]:
                    self.storages[rank][data_id].append(key)
                if is_owner:
                    self.data_owner[(data_id, key)] = rank

    def add_data_synced(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
        # Same bookkeeping as `add_data`, but without acquiring the async lock.
        for data_id in ids:
            if data_id not in self.storages[rank]:
                self.storages[rank][data_id] = [key]
            elif key not in self.storages[rank][data_id]:
                self.storages[rank][data_id].append(key)
            if is_owner:
                self.data_owner[(data_id, key)] = rank

    async def clear_data(self, ids: List[Hashable]):
        async with self.lock:
            for storage in self.storages:
                for i in ids:
                    if i in storage:
                        storage.pop(i)
            # Drop ownership records for the cleared IDs.
            for i, k in list(self.data_owner.keys()):
                if i in ids:
                    self.data_owner.pop((i, k))


@dataclasses.dataclass
class RedistribStep:
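    """A single collective-communication step of a redistribution plan.

    ``comm_type`` is one of ``"gather"``, ``"scatter"``, or ``"bcast"``.
    ``root`` is the root rank of the collective, ``srcs`` and ``dsts`` are the
    participating source/destination ranks (either may be ``None`` depending
    on the communication type), ``ids`` lists the data IDs to transfer, and
    ``keys`` lists the data keys being moved.
    """
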
    comm_type: str
    root: int | None
    srcs: List[int] | None
    dsts: List[int] | None
    ids: List[List[Hashable]]
    keys: List[str]

    def __repr__(self) -> str:
        if self.comm_type == "gather":
            return f"Gather {self.keys} to {self.root} from {self.srcs}."
        if self.comm_type == "scatter":
            return f"Scatter {self.keys} from {self.root} to {self.dsts}."
        if self.comm_type == "bcast":
            return f"Bcast {self.keys} from {self.root} to {self.dsts}."
        raise NotImplementedError()


class RedistribPlanner:
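    """Derive data redistribution plans from the global storage layout.

    Given, for every destination rank, the data IDs it needs, the planner
    emits a list of ``RedistribStep``s following either a gather-scatter
    pattern (gather each key onto a root rank, then scatter to all
    destinations) or a broadcast pattern (each missing entry is broadcast
    from a nearby source rank, grouped into conflict-free stages).
    """
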
    def __init__(
        self, cluster_config: ClusterSpecConfig, storage_tracker: GlobalStorageTracker
    ):
        self.cluster_config = cluster_config
        self.storage_tracker = storage_tracker

    def derive_plan(
        self,
        dests: Dict[int, List[Hashable]],
        keys: List[str],
        pattern: str = "gather-scatter",
    ) -> List[RedistribStep]:
        if pattern == "gather-scatter":
            return self.derive_plan_gather_scatter(dests, keys)
        elif pattern == "bcast":
            return self.derive_plan_bcast(dests, keys)
        raise NotImplementedError(f"Unknown data redistribution pattern: {pattern}")

    def derive_plan_gather_scatter(
        self, dests: Dict[int, List[Hashable]], keys: List[str]
    ) -> List[RedistribStep]:
        self.dests = dests

        # Collect all data IDs required by any destination.
        all_data_ids = set()
        for all_samples in dests.values():
            for data_id in all_samples:
                all_data_ids.add(data_id)

        transfer_plan = []
        for key in keys:
            owners = sorted(
                {self.storage_tracker.data_owner[(i, key)] for i in all_data_ids}
            )
            gather_ids = []
            for owner in owners:
                this_owner_ids = []
                for i in all_data_ids:
                    if (
                        i in self.storage_tracker.storages[owner]
                        and key in self.storage_tracker.storages[owner][i]
                    ):
                        this_owner_ids.append(i)
                gather_ids.append(sorted(this_owner_ids))
            gather_step = RedistribStep(
                comm_type="gather",
                root=owners[0],
                srcs=owners,
                dsts=None,
                ids=gather_ids,
                keys=[key],
            )

            scatter_dsts = sorted([i for i in dests if dests[i]])
            scatter_ids = [sorted(dests[i]) for i in scatter_dsts]
            scatter_step = RedistribStep(
                comm_type="scatter",
                root=owners[0],
                dsts=scatter_dsts,
                srcs=None,
                ids=scatter_ids,
                keys=[key],
            )
            transfer_plan += [gather_step, scatter_step]

        # Prune the plan.
        pop_indices = []
        for idx, step in enumerate(transfer_plan):
            # 1. Omit the gather step if the data has already been gathered on the root.
            if step.comm_type == "gather":
                all_gather_ids = list(itertools.chain.from_iterable(step.ids))
                key = step.keys[0]
                if any(
                    i not in self.storage_tracker.storages[step.root]
                    for i in all_gather_ids
                ):
                    continue
                if any(
                    key not in self.storage_tracker.storages[step.root][i]
                    for i in all_gather_ids
                ):
                    continue
                pop_indices.append(idx)
            # 2. Omit the gather + scatter steps if the data already exists on all dst GPUs.
            if step.comm_type == "scatter":
                key = step.keys[0]
                all_exists = True
                for dst, ids in zip(step.dsts, step.ids):
                    if any(i not in self.storage_tracker.storages[dst] for i in ids):
                        all_exists = False
                        break
                    if any(
                        key not in self.storage_tracker.storages[dst][i] for i in ids
                    ):
                        all_exists = False
                        break
                if all_exists:
                    pop_indices.append(idx)
                    pop_indices.append(idx - 1)
        for pop_idx in reversed(sorted(set(pop_indices))):
            transfer_plan.pop(pop_idx)

        # Merge the gather/scatter steps of different keys when they share the
        # same root, participants, and data IDs.
        gather_plan = {}
        scatter_plan = {}
        for step in transfer_plan:
            if step.comm_type == "gather":
                plan_key = (
                    step.root,
                    tuple(sorted(step.srcs)),
                    tuple(tuple(sorted(ids)) for ids in step.ids),
                )
                if plan_key not in gather_plan:
                    gather_plan[plan_key] = step
                else:
                    assert all(
                        k not in gather_plan[plan_key].keys for k in step.keys
                    ), (
                        gather_plan[plan_key],
                        step,
                        plan_key,
                    )
                    gather_plan[plan_key].keys += step.keys
            if step.comm_type == "scatter":
                plan_key = (
                    step.root,
                    tuple(sorted(step.dsts)),
                    tuple(tuple(sorted(ids)) for ids in step.ids),
                )
                if plan_key not in scatter_plan:
                    scatter_plan[plan_key] = step
                else:
                    assert all(
                        k not in scatter_plan[plan_key].keys for k in step.keys
                    ), (
                        scatter_plan[plan_key],
                        step,
                        plan_key,
                    )
                    scatter_plan[plan_key].keys += step.keys

        # Prioritize gathers over scatters.
        return list(gather_plan.values()) + list(scatter_plan.values())

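    # Illustrative example (hypothetical ranks, IDs, and key): if rank 0 owns
    # samples {0, 1} and rank 1 owns samples {2, 3} under key "input_ids", and
    # both ranks request all four samples, the derived gather-scatter plan is:
    #   Gather ['input_ids'] to 0 from [0, 1].
    #   Scatter ['input_ids'] from 0 to [0, 1].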
    def derive_plan_bcast(
        self, dests: Dict[int, List[Hashable]], keys: List[str] | Tuple[str]
    ) -> List[RedistribStep]:
        assert isinstance(keys, (list, tuple)), type(keys)
        keys = list(keys)
        self.dests = dests

        # Get all required data IDs.
        all_data_ids = set()
        for all_samples in self.dests.values():
            for data_id in all_samples:
                all_data_ids.add(data_id)

        # The producers of each required data entry.
        id2gpu_src = {}
        for data_id in all_data_ids:
            for key in keys:
                id2gpu_src[(data_id, key)] = []
                for gpu_id, ids2keys in enumerate(self.storage_tracker.storages):
                    if data_id in ids2keys and key in ids2keys[data_id]:
                        id2gpu_src[(data_id, key)].append(gpu_id)

        # The consumers of each required data entry.
        id2gpu_dst = {}
        for data_id in all_data_ids:
            for key in keys:
                id2gpu_dst[(data_id, key)] = []
                for gpu_id, ids in self.dests.items():
                    if data_id in ids:
                        id2gpu_dst[(data_id, key)].append(gpu_id)

        self.transfer_plan = {}

        for data_id, key in itertools.product(all_data_ids, keys):
            source_gpus = id2gpu_src[(data_id, key)]
            target_gpus = id2gpu_dst[(data_id, key)]

            assert len(source_gpus) > 0, (data_id, key, id2gpu_src, id2gpu_dst)

            # Omit the transfer if the data already exists on the target GPU.
            target_gpus = [gpu for gpu in target_gpus if gpu not in source_gpus]
            if not target_gpus:
                continue

            # Find the "nearest" GPU for the data transfer.
            best_src = self._select_best_bcast_source(source_gpus, target_gpus)

            self.transfer_plan[(data_id, key)] = {"src": best_src, "dsts": target_gpus}

        return self._group_bcast_transfers()

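    # Illustrative example (hypothetical ranks, IDs, and key): for the same
    # two-rank layout as above, the broadcast pattern yields one stage with two
    # concurrent steps: rank 0 broadcasts samples {0, 1} to rank 1, and rank 1
    # broadcasts samples {2, 3} to rank 0.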
    def _on_same_node(self, i, j) -> bool:
        return (i // self.cluster_config.n_gpus_per_node) == (
            j // self.cluster_config.n_gpus_per_node
        )

    def _select_best_bcast_source(self, source_gpus, target_gpus):
        same_node_counts = {}
        for src in source_gpus:
            same_node_count = sum(
                1 for dst in target_gpus if self._on_same_node(src, dst)
            )
            same_node_counts[src] = same_node_count

        # Find the sources that maximize locality.
        max_same_node = max(same_node_counts.values())
        best_sources = [
            src for src, count in same_node_counts.items() if count == max_same_node
        ]

        # Among them, pick the source with the smallest current workload.
        src_load = defaultdict(int)
        for plan in self.transfer_plan.values():
            src_gpu = plan["src"]
            src_load[src_gpu] += len(plan["dsts"])
        return min(best_sources, key=lambda src: src_load[src])

    def _group_bcast_transfers(self) -> List[RedistribStep]:
        # Group data IDs that should be transferred from the same "src" to the same "dsts".
        src_to_transfers = defaultdict(lambda: defaultdict(list))
        for (data_id, key), plan in self.transfer_plan.items():
            src_to_transfers[(plan["src"], key)][tuple(sorted(plan["dsts"]))].append(
                data_id
            )

        stages = []
        while any(src_to_transfers.values()):
            stage = []
            used_dsts = set()
            used_srcs = set()

            for (src, key), transfers in list(src_to_transfers.items()):
                if src in used_srcs:
                    continue
                if not transfers:
                    continue

                # Find a transfer that can be executed concurrently within this stage.
                pop_key = None
                for dsts in transfers:
                    if not any(dst in used_dsts for dst in dsts):
                        pop_key = dsts
                        break

                if pop_key is not None:
                    data_ids = transfers.pop(pop_key)
                    stage.append(
                        RedistribStep(
                            comm_type="bcast",
                            root=src,
                            srcs=[src],
                            keys=[key],
                            dsts=pop_key,
                            ids=data_ids,
                        )
                    )
                    used_dsts.update(pop_key)
                    used_srcs.add(src)

            if stage:
                stages += stage
            else:
                # Fallback: no conflict-free transfer was found, so schedule an
                # arbitrary pending transfer to guarantee progress.
                for (src, key), transfers in list(src_to_transfers.items()):
                    if transfers:
                        dsts, data_ids = transfers.popitem()
                        stages.append(
                            RedistribStep(
                                comm_type="bcast",
                                srcs=[src],
                                root=src,
                                dsts=dsts,
                                ids=data_ids,
                                keys=[key],
                            )
                        )
                        break

        return stages
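

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. It assumes a
# 2-GPU world, made-up sample IDs, and the key "input_ids"; a SimpleNamespace
# stands in for ClusterSpecConfig because the planner only reads
# `n_gpus_per_node` from the cluster config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    tracker = GlobalStorageTracker(world_size=2)
    # Rank 0 produced "input_ids" for samples 0 and 1; rank 1 for samples 2 and 3.
    tracker.add_data_synced(rank=0, ids=[0, 1], key="input_ids", is_owner=True)
    tracker.add_data_synced(rank=1, ids=[2, 3], key="input_ids", is_owner=True)

    planner = RedistribPlanner(
        cluster_config=SimpleNamespace(n_gpus_per_node=8),  # stand-in config
        storage_tracker=tracker,
    )
    dests = {0: [0, 1, 2, 3], 1: [0, 1, 2, 3]}  # both ranks need all four samples

    for step in planner.derive_plan(dests, keys=["input_ids"]):
        print(step)  # gather to rank 0, then scatter to ranks 0 and 1
    for step in planner.derive_plan(dests, keys=["input_ids"], pattern="bcast"):
        print(step)  # two broadcasts: rank 0 -> rank 1 and rank 1 -> rank 0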