PullRequest: 9 Refactoring data transfer for v2 workers.

Merge branch fw/datatransfer-v2 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/9

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* fw/fix-dataloading-not-shuffle
* .
* .
* .
* .
* .
* add v2 master worker
* cpu test pass
* ppo run
* .
* pass sft test
* pass ppo dp test
* format
* fix
* run
* .
* cleanup
* .
* format
* run
* merge and format
* refactor
* sft pass
* .
* format
* format
* format
* .
* .
This commit is contained in:
博惟 2025-03-08 18:26:24 +08:00
parent b3bedd7b9d
commit ca42e43638
25 changed files with 1386 additions and 2126 deletions

View File

@ -39,7 +39,6 @@ RUN cd /vllm && \
python3 use_existing_torch.py && \
pip3 install -r requirements-build.txt && \
MAX_JOBS=64 pip3 install -e . --no-build-isolation
RUN yes | pip3 uninstall uvloop
RUN pip3 install opencv-python-headless==4.5.4.58
RUN apt-get update && apt-get install -y python3.10-venv

View File

@ -63,14 +63,7 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
return problem_map
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
debug=True,
test_case_batch_size=20,
)
global_problems = None
def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_testcase=6):
@ -81,6 +74,15 @@ def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_test
payload_list = []
global global_problems
if global_problems is None:
global_problems = load_problems_with_testcase_batch(
os.getenv(
"REAL_CODE_METADATA_PATH",
"/storage/datasets/codeparrot-apps-test.jsonl",
),
debug=True,
test_case_batch_size=20,
)
for idx, query_id in enumerate(query_ids):
if query_id not in global_problems:
payload_list.append(None)
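This diff defers loading the code-verification metadata from import time to the first call of `code_verify`, so merely importing the module no longer requires `REAL_CODE_METADATA_PATH` to point at an existing file (the math verifier in the next diff gets the same treatment). A minimal, self-contained sketch of the lazy-global pattern, with a hypothetical loader and environment variable:

import json
import os

_metadata = None  # filled on first use instead of at import time


def _load_metadata(path):
    # Hypothetical loader; the real module reads a JSONL of test cases.
    with open(path) as f:
        return {str(i): json.loads(line) for i, line in enumerate(f)}


def verify(query_id):
    global _metadata
    if _metadata is None:  # only the first call pays the loading cost
        _metadata = _load_metadata(
            os.getenv("EXAMPLE_METADATA_PATH", "/tmp/metadata.jsonl")
        )
    return query_id in _metadata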

View File

@ -19,17 +19,19 @@ def loadJson(dataDir):
return samples
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
id2info = None
def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) -> List:
start_time = time.time()
global id2info
if id2info is None:
id2info = loadJson(
os.getenv(
"REAL_MATH_MEATADATA_PATH",
"/storage/datasets/id2info.json",
)
)
assert len(generateds) == len(query_ids), (
len(generateds),
len(query_ids),

View File

@ -396,6 +396,7 @@ class SequenceSample:
group_indices = datapack.ffd_allocate(
lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
)
group_indices = sorted([sorted(g) for g in group_indices])
forward_indices = datapack.flat2d(group_indices)
sample = SequenceSample.reorder(self, forward_indices)
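The inserted `sorted` call canonicalizes the groups returned by `datapack.ffd_allocate` before they are flattened, so the resulting `forward_indices` no longer depend on the order in which the packer emits groups. A toy illustration of that normalization (the packer itself is not reproduced here):

def normalize(groups):
    # Same canonicalization as the added line: sort within each group,
    # then sort the groups themselves.
    return sorted(sorted(g) for g in groups)


def flat2d(groups):
    return [i for g in groups for i in g]


# Two packers that build identical groups but emit them in different orders.
a = [[3, 0], [2, 1]]
b = [[1, 2], [0, 3]]
assert flat2d(normalize(a)) == flat2d(normalize(b)) == [0, 3, 1, 2]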

View File

@ -463,13 +463,12 @@ class ExperimentConfig:
data_transfer_pairs.append((mn1, mn2))
src_rpcs = [rpc for rpc in self.model_rpcs if rpc.is_src]
data_src_rpc = src_rpcs[0]
for r in src_rpcs[1:]:
for r in src_rpcs:
if (
data_src_rpc.model_name,
r.model_name,
) not in data_transfer_pairs:
data_transfer_pairs.append((data_src_rpc.model_name, r.model_name))
data_transfer_pairs += [(mn, mn) for mn in model_names]
return data_transfer_pairs
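After this change every source RPC is paired with `data_src_rpc` (the first source no longer skips itself), and a self-pair `(mn, mn)` is appended for every model so that redistribution among a model's own ranks also has a registered transfer pair. A small sketch with hypothetical model names; as in the diff, the self-pairs are appended without deduplication:

data_src = "actor"                       # model name of the first source RPC
src_models = ["actor", "reward"]         # model names of all source RPCs
model_names = ["actor", "critic", "reward"]

pairs = []
for m in src_models:
    if (data_src, m) not in pairs:
        pairs.append((data_src, m))
pairs += [(mn, mn) for mn in model_names]
print(pairs)
# [('actor', 'actor'), ('actor', 'reward'), ('actor', 'actor'),
#  ('critic', 'critic'), ('reward', 'reward')]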
def _resolve_param_realloc_pairs(

View File

@ -162,7 +162,6 @@ def main_start(args, recover_count: int = 0):
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
)
for k, v in BASE_ENVIRONS.items():
os.environ[k] = v

View File

@ -1,80 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import dataclasses
import sys
import threading
from asyncio.base_events import _run_until_complete_cb
@dataclasses.dataclass
class AsyncRunUntilCompleteContext:
loop: asyncio.BaseEventLoop
future: asyncio.Future
new_task: bool
def setup_run_until_complete(
loop: asyncio.BaseEventLoop,
future: asyncio.Future,
) -> AsyncRunUntilCompleteContext:
loop._check_closed()
loop._check_running()
new_task = not asyncio.futures.isfuture(future)
future = asyncio.tasks.ensure_future(future, loop=loop)
if new_task:
# An exception is raised if the future didn't complete, so there
# is no need to log the "destroy pending task" message
future._log_destroy_pending = False
future.add_done_callback(_run_until_complete_cb)
# set up run forever
loop._set_coroutine_origin_tracking(loop._debug)
loop._old_agen_hooks = sys.get_asyncgen_hooks()
loop._thread_id = threading.get_ident()
sys.set_asyncgen_hooks(
firstiter=loop._asyncgen_firstiter_hook,
finalizer=loop._asyncgen_finalizer_hook,
)
asyncio.events._set_running_loop(loop)
return AsyncRunUntilCompleteContext(loop=loop, future=future, new_task=new_task)
def teardown_run_util_complete(ctx: AsyncRunUntilCompleteContext):
ctx.loop._stopping = False
ctx.loop._thread_id = None
asyncio.events._set_running_loop(None)
ctx.loop._set_coroutine_origin_tracking(False)
# Restore any pre-existing async generator hooks.
if ctx.loop._old_agen_hooks is not None:
sys.set_asyncgen_hooks(*ctx.loop._old_agen_hooks)
ctx.loop._old_agen_hooks = None
ctx.future.remove_done_callback(_run_until_complete_cb)
if not ctx.future.done():
raise RuntimeError("Event loop stopped before Future completed.")
def raise_asyncio_exception(
ctx: AsyncRunUntilCompleteContext, raise_error: bool = True
):
if ctx.new_task and ctx.future.done() and not ctx.future.cancelled():
# The coroutine raised a BaseException. Consume the exception
# to not log a warning, the caller doesn't have access to the
# local task.
ctx.future.exception()
try:
teardown_run_util_complete(ctx)
except RuntimeError as e:
if raise_error:
raise e
if raise_error:
raise

View File

@ -229,6 +229,8 @@ def resolve_rpc_hooks(
class AllocationType(enum.Enum):
DECOUPLED = 1
GLOBAL_HYBRID = 2
MANUAL = 3
SEARCH = 4
@dataclasses.dataclass
@ -244,6 +246,10 @@ class AllocationMode:
@classmethod
def from_str(cls, allocation_mode: str):
if allocation_mode == "manual":
return cls(AllocationType.MANUAL, None)
if allocation_mode == "search":
return cls(AllocationType.SEARCH, None)
alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
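The new `MANUAL` and `SEARCH` variants are matched by exact string before any of the parallelism-spec parsers run. A simplified, self-contained sketch of the dispatch; the field names and the fallback branch here are illustrative, not the real implementation:

import dataclasses
import enum
from typing import Optional


class AllocationType(enum.Enum):
    DECOUPLED = 1
    GLOBAL_HYBRID = 2
    MANUAL = 3
    SEARCH = 4


@dataclasses.dataclass
class AllocationMode:
    type_: AllocationType          # illustrative field names
    spec: Optional[dict]

    @classmethod
    def from_str(cls, allocation_mode: str) -> "AllocationMode":
        # Literal modes short-circuit before any pattern parsing.
        if allocation_mode == "manual":
            return cls(AllocationType.MANUAL, None)
        if allocation_mode == "search":
            return cls(AllocationType.SEARCH, None)
        raise NotImplementedError(f"Parsing of {allocation_mode!r} is omitted in this sketch")


assert AllocationMode.from_str("search").type_ is AllocationType.SEARCH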

View File

@ -82,7 +82,6 @@ class PromptDataset(torch.utils.data.Dataset):
ids=[self.ids[idx]],
seqlens=[self.prompt_lengths[idx]],
data=dict(packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)),
metadata=dict(random_id=[uuid.uuid4()]),
)
@ -178,7 +177,6 @@ class MATHPromptDataset(torch.utils.data.Dataset):
[self.base_scores[idx]], dtype=torch.float32
),
),
metadata=dict(random_id=[uuid.uuid4()]),
)
else:
return data_api.SequenceSample.from_default(
@ -187,7 +185,6 @@ class MATHPromptDataset(torch.utils.data.Dataset):
data=dict(
packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
),
metadata=dict(random_id=[uuid.uuid4()]),
)
def filter(self, eval_scores: Dict[Hashable, float]):

View File

@ -1,344 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import itertools
from typing import *
import torch
import torch.distributed as dist
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, topology
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.impl.model.comm.param_realloc import pipeline_repartition_strategy
BCAST_CHUNK_SIZE_BYTES = int(4e6)
@dataclasses.dataclass(unsafe_hash=True)
class DataTransferPair:
src: ModelName
src_dp_rank: int
dst: ModelName
dst_dp_rank: int
@dataclasses.dataclass
class DataTransferInfo:
# Groups for data transfer among model workers.
data_transfer_groups: Dict[DataTransferPair, dist.ProcessGroup]
data_transfer_src_ranks: Dict[DataTransferPair, int]
data_transfer_dst_ranks: Dict[DataTransferPair, List[int]]
def setup_data_transfer(
model_topos: Optional[Dict[str, topology.PipeModelDataParallelTopology]] = None,
msid2mwid: Optional[Dict[ModelShardID, int]] = None,
data_transfer_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
) -> DataTransferInfo:
# Stores the ranks given a (model_name, dp_rank) pair.
# These workers correspond to a complete set of model parameters sharded by TP+PP.
mw_dp_ranks: Dict[Tuple[ModelName, int], List[int]] = {}
# Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
mw_dp_head_ranks: Dict[ModelName, List[int]] = {}
if model_topos is not None:
assert msid2mwid is not None
for model_name, topo in model_topos.items():
mw_dp_head_ranks[model_name] = filter_match_mwids(
model_name,
topo,
msid2mwid,
pipe=topo.get_dim("pipe") - 1,
model=0,
)
dp_size = topo.get_dim("data")
for dp_i in range(dp_size):
mw_dp_ranks[model_name, dp_i] = filter_match_mwids(
model_name,
topo,
msid2mwid,
data=dp_i,
)
data_transfer_groups, data_transfer_src_ranks = {}, {}
data_transfer_dst_ranks = {}
if data_transfer_pairs is not None:
for src, dst in data_transfer_pairs:
src_topo = model_topos[src]
dst_topo = model_topos[dst]
# Construct all src-dst pairs, from any src dp rank to any dst dp rank.
# Note that a dp rank corresponds to multiple parameter shards (TP+PP),
# so each pair is a group-to-group communication.
# Since the models in the source group have duplicate data (TP+PP),
# we just use its "head" as the broadcast source,
# and broadcast to all the ranks in the destination group.
for src_dp, dst_dp in itertools.product(
range(src_topo.get_dim("data")), range(dst_topo.get_dim("data"))
):
key = DataTransferPair(
src=src, src_dp_rank=src_dp, dst=dst, dst_dp_rank=dst_dp
)
src_mw_rank = mw_dp_head_ranks[src][src_dp]
dst_mw_ranks = mw_dp_ranks[dst, dst_dp]
data_transfer_dst_ranks[key] = dst_mw_ranks
# The src and dst groups can be disjoint or overlapped.
# If they are disjoint, we need to include the src_mw_rank in the group.
# Otherwise, we only need to include the dst_mw_ranks.
if src_mw_rank not in dst_mw_ranks:
_ranks = [src_mw_rank] + dst_mw_ranks
else:
_ranks = dst_mw_ranks
data_transfer_groups[key] = topology.new_or_get_group(
_ranks, backend="nccl" if constants.use_cuda() else "gloo"
)
data_transfer_src_ranks[key] = src_mw_rank
return DataTransferInfo(
data_transfer_groups=data_transfer_groups,
data_transfer_src_ranks=data_transfer_src_ranks,
data_transfer_dst_ranks=data_transfer_dst_ranks,
)
@dataclasses.dataclass
class DataTransferSenderStep:
rank: int
dst_ranks: List[int]
group: dist.ProcessGroup
key: str
ids: List[int]
@dataclasses.dataclass
class DataTransferReceiverStep:
rank: int
src: int
dst_ranks: List[int]
group: dist.ProcessGroup
key: str
ids: List[int]
def chunked_bcast(buf, src_rank, group, chunk_size_bytes=BCAST_CHUNK_SIZE_BYTES):
shape = buf.shape
chunk_size = chunk_size_bytes // buf.dtype.itemsize
buf = buf.flatten()
for i in range(0, buf.numel(), chunk_size):
s = slice(i, i + chunk_size)
dist.broadcast(buf[s], src=src_rank, group=group)
return buf.view(*shape)
def derive_data_transfer_plan(
keys: List[str],
global_ids: List[int],
consumer_name: ModelName,
consumer_mapping: Dict[int, List[int]],
producer_names: Dict[str, ModelName],
producer_mappings: Dict[Tuple[ModelName, str], Dict[int, List[int]]],
data_transfer_info: DataTransferInfo,
) -> List[DataTransferReceiverStep | DataTransferSenderStep]:
comm_plan = []
for k in keys:
producer_name = producer_names[k]
producer_mapping = producer_mappings[(producer_name, k)]
# partition mapping starts from zero, which is different from buffer indices
repart_strat = pipeline_repartition_strategy(producer_mapping, consumer_mapping)
for (dp_i, dp_j), comm_slots in repart_strat.items():
if len(comm_slots) == 0:
continue
group_key = DataTransferPair(
src=producer_name,
src_dp_rank=dp_i,
dst=consumer_name,
dst_dp_rank=dp_j,
)
bcast_src = data_transfer_info.data_transfer_src_ranks[group_key]
group = data_transfer_info.data_transfer_groups[group_key]
dst_ranks = data_transfer_info.data_transfer_dst_ranks[group_key]
ids = [global_ids[_i] for _i in comm_slots]
for dst_rank in dst_ranks:
comm_plan.append(
DataTransferReceiverStep(
rank=dst_rank,
src=bcast_src,
dst_ranks=dst_ranks,
group=group,
key=k,
ids=ids,
)
)
comm_plan.append(
DataTransferSenderStep(
rank=bcast_src,
dst_ranks=dst_ranks,
group=group,
key=k,
ids=ids,
)
)
return comm_plan
def run_data_transfer(
comm_plan: List[DataTransferReceiverStep | DataTransferSenderStep],
meta_samples: Dict[int, SequenceSample],
storage: Dict[int, SequenceSample],
sent_worker_idx_table: Dict[int, Dict[str, Set[int]]],
received_worker_idx_table: Dict[int, Dict[str, Set[int]]],
) -> Tuple[Set[int], Set[str]]:
device = "cuda" if constants.use_cuda() else "cpu"
for step in comm_plan:
if isinstance(step, DataTransferReceiverStep) and step.rank == dist.get_rank():
ids = step.ids
if step.src == dist.get_rank():
# The receiver is also a sender.
# We can directly use the data without comm.
for _id in step.ids:
if storage[_id].data[step.key] is not None:
storage[_id].data[step.key] = (
storage[_id].data[step.key].to(device)
)
else:
# If we have to receive remote data, we first check whether
# the data has been sent here in previous function calls.
# If so, just fetch it from the cache.
cached = all(
[
set(step.dst_ranks).issubset(
set(received_worker_idx_table[_id][step.key])
)
for _id in ids
]
)
metadata_cached = all(
[
set(step.dst_ranks).issubset(
set(received_worker_idx_table[_id]["__metadata__"])
)
for _id in ids
]
)
if cached:
pass
else:
dtype = meta_samples[ids[0]].dtypes[step.key]
total_len = sum(
sum(meta_samples[_id].seqlens[step.key][0]) for _id in ids
)
trailing_shape = meta_samples[ids[0]].trailing_shapes[step.key]
# Receive data if it is not None.
if trailing_shape is not None:
buf = torch.zeros(
(total_len, *trailing_shape),
dtype=dtype,
device=constants.current_device(),
)
chunked_bcast(buf, step.src, step.group)
else:
buf = None
# Receive metadata if not cached.
if not metadata_cached:
metadatas = [{} for _ in step.ids]
dist.broadcast_object_list(
metadatas, src=step.src, group=step.group
)
# Mark that the data has been received.
for _id in ids:
received_worker_idx_table[_id][step.key].union(step.dst_ranks)
received_worker_idx_table[_id]["__metadata__"].union(
step.dst_ranks
)
# Split the received data and put it into the storage.
offset = 0
for _id, metadata in zip(ids, metadatas):
seqlens = meta_samples[_id].seqlens[step.key]
assert len(seqlens) == 1
seqlen = sum(seqlens[0])
if buf is not None:
vs = buf[offset : offset + seqlen]
else:
vs = None
offset = offset + seqlen
with SequenceSample.disable_validation():
s = SequenceSample(
keys=[step.key],
dtypes={step.key: vs.dtype if vs is not None else None},
trailing_shapes={
step.key: vs.shape[1:] if vs is not None else None
},
ids=[_id],
seqlens={step.key: seqlens},
data={step.key: vs},
metadata=metadata,
)
if _id in storage:
storage[_id].update_(s)
else:
storage[_id] = s
if isinstance(step, DataTransferSenderStep) and step.rank == dist.get_rank():
# Similar to the receiver, we first check whether the data has been sent to all destinations.
cached = all(
[
set(step.dst_ranks).issubset(
set(sent_worker_idx_table[_id][step.key])
)
for _id in step.ids
]
)
metadata_cached = all(
[
set(step.dst_ranks).issubset(
set(sent_worker_idx_table[_id]["__metadata__"])
)
for _id in step.ids
]
)
if cached:
pass
else:
# If not cached, we fetch the data from the storage and send it to all destinations.
for _id in step.ids:
if storage[_id].data[step.key] is not None:
storage[_id].data[step.key] = (
storage[_id].data[step.key].to(device)
)
if all([storage[_id].data[step.key] is not None for _id in step.ids]):
vs = torch.cat(
[storage[_id].data[step.key] for _id in step.ids],
dim=0,
)
chunked_bcast(vs, step.rank, step.group)
if not metadata_cached:
dist.broadcast_object_list(
[storage[_id].metadata for _id in step.ids],
src=step.rank,
group=step.group,
)
for _id in step.ids:
sent_worker_idx_table[_id][step.key].union(step.dst_ranks)
sent_worker_idx_table[_id]["__metadata__"].union(step.dst_ranks)

View File

@ -15,7 +15,6 @@ logger = logging.getLogger("system")
# NOTE: Workers are configured in the following order.
# Take special care when adding a new worker type.
WORKER_TYPES = ["model_worker", "master_worker"]
USE_V2_WORKER = os.getenv("REAL_USE_V2_WORKER", "0") == "1"
def load_worker(worker_type: str) -> Type:
@ -26,8 +25,6 @@ def load_worker(worker_type: str) -> Type:
def worker_type_to_module(worker_type: str):
if worker_type == "master_worker" and USE_V2_WORKER:
return "realhf.system.v2." + worker_type
return "realhf.system." + worker_type

View File

@ -17,32 +17,6 @@ from realhf.api.core.data_api import SequenceSample
logger = logging.getLogger("buffer")
def _extract_intervals(arr):
if len(arr) == 0:
return []
# Initialize the list to hold the intervals
intervals = []
# Start of the first interval
start = arr[0]
for i in range(1, len(arr)):
# Check if the current element is not contiguous with the previous one
if arr[i] != arr[i - 1] + 1:
# End of the current interval
end = arr[i - 1]
# Add the interval as a tuple
intervals.append((start, end + 1))
# Start a new interval
start = arr[i]
# Add the last interval
intervals.append((start, arr[-1] + 1))
return intervals
class BufferFull(Exception):
pass
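For reference, the removed helper collapsed a sorted index array into half-open contiguous intervals; an equivalent compact version with its behavior:

def extract_intervals(arr):
    # Equivalent to the removed helper: collapse a sorted index array
    # into half-open (start, end) runs of consecutive integers.
    if len(arr) == 0:
        return []
    intervals, start = [], arr[0]
    for prev, cur in zip(arr, arr[1:]):
        if cur != prev + 1:                 # a gap closes the current run
            intervals.append((start, prev + 1))
            start = cur
    intervals.append((start, arr[-1] + 1))
    return intervals


assert extract_intervals([1, 2, 3, 7, 8, 10]) == [(1, 4), (7, 9), (10, 11)]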

View File

@ -620,7 +620,6 @@ class RayController:
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
)
runtime_env = {
"env_vars": env_vars,

View File

@ -0,0 +1,395 @@
# Copyright 2025 Ant Group Inc.
import bisect
import dataclasses
import itertools
from collections import defaultdict
from typing import *
import numpy as np
import torch
import torch.distributed as dist
from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants
from realhf.base.topology import PipeModelDataParallelTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.system.redistributor import RedistribStep
BCAST_GROUPS = {}
GATHER_GROUPS = {}
SCATTER_GROUPS = {}
class DataManager:
def __init__(
self,
model_topos: Dict[ModelName, PipeModelDataParallelTopology],
msid2mwid: Optional[Dict[ModelShardID, int]] = None,
data_transfer_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
):
self.model_topos = model_topos
self.msid2mwid = msid2mwid
self.data_transfer_pairs = data_transfer_pairs
self.storage: Dict[Hashable, SequenceSample] = {}
def setup_process_groups(self):
if self.msid2mwid is None or self.data_transfer_pairs is None:
return
model_topos = self.model_topos
msid2mwid = self.msid2mwid
data_transfer_pairs = self.data_transfer_pairs
# Stores the ranks given a (model_name, dp_rank) pair.
# These workers correspond to a complete set of model parameters sharded by TP+PP.
mw_dp_ranks: Dict[Tuple[ModelName, int], List[int]] = {}
mw_ranks: Dict[ModelName, List[int]] = {}
# Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
mw_dp_head_ranks: Dict[ModelName, List[int]] = defaultdict(list)
assert msid2mwid is not None
for model_name, topo in model_topos.items():
mw_ranks[model_name] = filter_match_mwids(
model_name,
topo,
msid2mwid,
)
mw_dp_head_ranks[model_name] = filter_match_mwids(
model_name,
topo,
msid2mwid,
pipe=topo.get_dim("pipe") - 1,
model=0,
)
dp_size = topo.get_dim("data")
for dp_i in range(dp_size):
mw_dp_ranks[model_name, dp_i] = filter_match_mwids(
model_name,
topo,
msid2mwid,
data=dp_i,
)
for src, dst in data_transfer_pairs:
src_topo = model_topos[src]
dst_topo = model_topos[dst]
ranks = tuple(sorted(mw_dp_head_ranks[src]))
GATHER_GROUPS[ranks] = new_or_get_group(
list(ranks), backend="nccl" if constants.use_cuda() else "gloo"
)
scatter_ranks = tuple(sorted(set([ranks[0]] + mw_ranks[dst])))
SCATTER_GROUPS[scatter_ranks] = new_or_get_group(
list(scatter_ranks),
backend="nccl" if constants.use_cuda() else "gloo",
)
# Construct all src-dst pairs, from any src dp rank to any dst dp rank.
# Note that a dp rank corresponds to multiple parameter shards (TP+PP),
# so each pair is a group-to-group communication.
# Since the models in the source group have duplicate data (TP+PP),
# we just use its "head" as the broadcast source,
# and broadcast to all the ranks in the destination group.
for src_dp, dst_dp in itertools.product(
range(src_topo.get_dim("data")), range(dst_topo.get_dim("data"))
):
src_mw_rank = mw_dp_head_ranks[src][src_dp]
dst_mw_ranks = mw_dp_ranks[dst, dst_dp]
# The src and dst groups can be disjoint or overlapped.
# If they are disjoint, we need to include the src_mw_rank in the group.
# Otherwise, we only need to include the dst_mw_ranks.
if src_mw_rank not in dst_mw_ranks:
_ranks = [src_mw_rank] + dst_mw_ranks
else:
_ranks = dst_mw_ranks
key = tuple(sorted(_ranks))
BCAST_GROUPS[key] = new_or_get_group(
_ranks, backend="nccl" if constants.use_cuda() else "gloo"
)
def storage_size(self):
return len(self.storage)
def store(self, x: SequenceSample):
assert len(x.ids) == 1
assert x.ids[0] not in self.storage
self.storage[x.ids[0]] = x
def update(self, x: SequenceSample):
self.storage[x.ids[0]].update_(x)
def get(self, data_id: Hashable):
return self.storage[data_id]
def has_data(self, data_id: Hashable):
return data_id in self.storage
def remove(self, ids: List[Hashable]):
for data_id in ids:
if data_id in self.storage:
del self.storage[data_id]
def clear_data(self):
self.storage.clear()
def _bcast_recv(
self,
step: RedistribStep,
data_infos: Dict[Hashable, SequenceSample],
):
assert len(step.keys) == 1
ids = step.ids
key = step.keys[0]
dtype = data_infos[ids[0]].dtypes[key]
total_len = sum(sum(data_infos[_id].seqlens[key][0]) for _id in ids)
trailing_shape = data_infos[ids[0]].trailing_shapes[key]
buf = torch.zeros(
(total_len, *trailing_shape),
dtype=dtype,
device=constants.current_device(),
)
if len(step.dsts) == 1:
dist.recv(buf, src=step.root)
else:
global BCAST_GROUPS
group = BCAST_GROUPS[tuple(sorted([step.root] + list(step.dsts)))]
dist.broadcast(buf, src=step.root, group=group)
# Split the received data and put it into the storage.
offset = 0
for _id in ids:
seqlens = data_infos[_id].seqlens[key]
assert len(seqlens) == 1
seqlen = sum(seqlens[0])
if buf is not None:
vs = buf[offset : offset + seqlen]
else:
vs = None
offset = offset + seqlen
with SequenceSample.disable_validation():
s = SequenceSample(
keys=[key],
dtypes={key: vs.dtype if vs is not None else None},
trailing_shapes={key: vs.shape[1:] if vs is not None else None},
ids=[_id],
seqlens={key: seqlens},
data={key: vs},
metadata={},
)
if _id in self.storage:
self.storage[_id].update_(s)
else:
self.storage[_id] = s
def _bcast_send(self, step: RedistribStep):
ids = step.ids
for _id in ids:
self.storage[_id].to_device(constants.current_device())
assert len(step.keys) == 1
key = step.keys[0]
vs = torch.cat(
[self.storage[_id].data[key] for _id in ids],
dim=0,
)
if len(step.dsts) == 1:
dist.send(vs, dst=step.dsts[0])
else:
global BCAST_GROUPS
group = BCAST_GROUPS[tuple(sorted([step.root] + list(step.dsts)))]
dist.broadcast(vs, src=step.root, group=group)
def _run_bcast(
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
):
if dist.get_rank() in step.dsts:
self._bcast_recv(step, data_infos=data_infos)
if dist.get_rank() == step.root:
self._bcast_send(step)
def _pad_data(self, x: torch.Tensor, maxlen: int):
assert x.dtype == torch.float32
assert len(x.shape) == 1
if maxlen > x.numel():
return torch.nn.functional.pad(x, (0, maxlen - x.numel()), value=0.0)
return x
def _run_gather(
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
):
if dist.get_rank() not in step.srcs:
return
maxlen = 0
for ids in step.ids:
infos = [data_infos[i] for i in ids]
maxlen = max(
maxlen,
sum(
[
sum([sum(info.seqlens[key][0]) for info in infos])
for key in step.keys
]
),
)
if dist.get_rank() == step.root:
gather_list = [
torch.empty(
maxlen, device=constants.current_device(), dtype=torch.float32
)
for _ in range(len(step.srcs))
]
else:
gather_list = None
local_gather_idx = step.srcs.index(dist.get_rank())
ids = step.ids[local_gather_idx]
for i in ids:
self.storage[i].to_device(constants.current_device())
samples = [self.storage[i] for i in ids]
data = torch.cat(
[
sample.data[key].float().flatten()
for sample in samples
for key in step.keys
]
)
data = self._pad_data(data, maxlen)
dist.gather(
data,
gather_list,
dst=step.root,
group=GATHER_GROUPS[tuple(sorted(step.srcs))],
)
if dist.get_rank() != step.root:
return
for ids, buf in zip(step.ids, gather_list):
offset = 0
for i in ids:
for key in step.keys:
seqlen = data_infos[i].seqlens[key][0]
dtype = data_infos[i].dtypes[key]
trailing_shape = data_infos[i].trailing_shapes[key]
size = int(np.prod(trailing_shape) * sum(seqlen))
data = buf[offset : offset + size].to(dtype)
offset += size
# with SequenceSample.disable_validation():
s = SequenceSample(
keys=[key],
dtypes={key: dtype},
trailing_shapes={key: trailing_shape},
ids=[i],
seqlens={key: [seqlen]},
data={key: data},
metadata={},
)
if i in self.storage:
self.storage[i].update_(s)
else:
self.storage[i] = s
def _run_scatter(
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
):
if dist.get_rank() != step.root and dist.get_rank() not in step.dsts:
return
maxlen = 0
for ids in step.ids:
infos = [data_infos[i] for i in ids]
maxlen = max(
maxlen,
sum(
[
sum([sum(info.seqlens[key][0]) for info in infos])
for key in step.keys
]
),
)
buf = torch.empty(
maxlen, device=constants.current_device(), dtype=torch.float32
)
if dist.get_rank() == step.root:
scatter_list = []
for ids in step.ids:
for i in ids:
self.storage[i].to_device(constants.current_device())
samples = [self.storage[i] for i in ids]
data = torch.cat(
[
sample.data[key].float().flatten()
for sample in samples
for key in step.keys
]
)
scatter_list.append(data)
maxlen = max([x.shape[0] for x in scatter_list])
scatter_list = [self._pad_data(x, maxlen) for x in scatter_list]
if step.root not in step.dsts:
idx = bisect.bisect(step.dsts, step.root)
scatter_list.insert(idx, buf)
else:
scatter_list = None
key = tuple(sorted(set([step.root] + step.dsts)))
dist.scatter(buf, scatter_list, src=step.root, group=SCATTER_GROUPS[key])
if dist.get_rank() not in step.dsts:
return
local_dst_idx = step.dsts.index(dist.get_rank())
ids = step.ids[local_dst_idx]
offset = 0
for i in ids:
for key in step.keys:
seqlen = data_infos[i].seqlens[key][0]
dtype = data_infos[i].dtypes[key]
trailing_shape = data_infos[i].trailing_shapes[key]
size = int(np.prod(trailing_shape) * sum(seqlen))
data = buf[offset : offset + size].to(dtype)
offset += size
# with SequenceSample.disable_validation():
s = SequenceSample(
keys=[key],
dtypes={key: dtype},
trailing_shapes={key: trailing_shape},
ids=[i],
seqlens={key: [seqlen]},
data={key: data},
metadata={},
)
if i in self.storage:
self.storage[i].update_(s)
else:
self.storage[i] = s
def redistribute(
self,
data_info: SequenceSample,
plan: List[RedistribStep],
):
data_infos = {x.ids[0]: x for x in data_info.unpack()}
for step in plan:
if step.comm_type == "bcast":
self._run_bcast(step, data_infos)
elif step.comm_type == "gather":
self._run_gather(step, data_infos)
elif step.comm_type == "scatter":
self._run_scatter(step, data_infos)
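For orientation: the groups built in `setup_process_groups` are cached in module-level dicts keyed by the sorted tuple of participating ranks, so the sender and every receiver can look up the same group independently, and a broadcast with a single destination falls back to point-to-point `dist.send`/`dist.recv`. A minimal illustration of the key convention, without touching `torch.distributed`:

def bcast_group_key(root, dsts):
    # Key used for BCAST_GROUPS: root plus all destinations, sorted.
    return tuple(sorted([root] + list(dsts)))


def scatter_group_key(root, dsts):
    # Key used for SCATTER_GROUPS: deduplicated, since the root may also receive.
    return tuple(sorted(set([root] + list(dsts))))


assert bcast_group_key(3, [0, 5]) == (0, 3, 5)
assert scatter_group_key(2, [2, 4]) == (2, 4)
# With exactly one destination, the DataManager skips the group and uses
# dist.send/dist.recv instead.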

View File

@ -4,6 +4,7 @@ import random
from typing import *
import networkx as nx
from tensorboardX import SummaryWriter
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import DataBatchMeta, SequenceSample
@ -12,8 +13,9 @@ from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
from realhf.system.request_reply_stream import NameResolvingRequestClient
from realhf.system.v2.function_call import FunctionCall, RPCCorountineControl
logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")
@ -29,12 +31,17 @@ class FunctionExecutor:
model_topos: Dict[str, PipeModelDataParallelTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
summary_writer: SummaryWriter | None,
):
self.func_calls: Dict[str, FunctionCall] = {}
self.func_calls: Dict[str, ModelFunctionCall] = {}
self.ctrl = ctrl
self.n_model_workers = len(set(msid2mwid.values()))
self.msid2mwid = msid2mwid
self.storage_tracker = GlobalStorageTracker(self.n_model_workers)
self.redistrib_planner = RedistribPlanner(self.storage_tracker)
self.rpcs = rpcs
self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0]
@ -42,7 +49,7 @@ class FunctionExecutor:
# Create model function calls.
for rpc in self.rpcs:
func_call = FunctionCall(
func_call = ModelFunctionCall(
rpc=rpc,
src_rpc=self.src_rpc,
stream=stream,
@ -51,6 +58,8 @@ class FunctionExecutor:
model_configs=model_configs,
ctrl=ctrl,
buffer=buffer,
redistrib_planner=self.redistrib_planner,
summary_writer=summary_writer,
)
self.func_calls[rpc.name] = func_call
@ -92,8 +101,6 @@ class FunctionExecutor:
while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
all_data = []
dp_idx += 1
dp_idx %= self.src_dp_size
@ -110,21 +117,13 @@ class FunctionExecutor:
if x.meta_sample is None:
continue
# Store the owner information of the data.
# RPCs corountines will use this information to
# determine the src and dst of data transfer.
for xx in x.meta_sample.unpack():
if xx.ids[0] in received_ids:
raise ValueError(
f"Duplicate data id {xx.ids[0]}. Is the final batch? {is_final_batch}."
)
raise ValueError(f"Duplicate data id {xx.ids[0]}.")
received_ids.add(xx.ids[0])
for k in xx.keys:
self.ctrl.data_owner[(xx.ids[0], k)] = (
src_rpc_model_name,
dp_idx,
)
all_data += x.meta_sample.unpack()
all_data = x.meta_sample.unpack()
filtered_data = []
for xx in x.meta_sample.unpack():
@ -139,6 +138,16 @@ class FunctionExecutor:
# so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data)
# Update resource tracker for planning data redistribution.
gpu_id = self.stream.route_to(f"__data{dp_idx}__")
for k in all_data[0].keys:
self.storage_tracker.add_data(
gpu_id,
[x.ids[0] for x in all_data],
k,
is_owner=True,
)
# Store into buffer!
buffer_indices = await buffer.put_batch(all_data)
assert len(buffer_indices) == len(all_data)
@ -156,13 +165,23 @@ class FunctionExecutor:
tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
loop.create_task(self.flush_calls()),
loop.create_task(self.load_data()),
loop.create_task(self.finish_traverse()),
]
completion_future = loop.create_task(self.finish_traverse())
loop.run_until_complete(completion_future)
for task in tasks:
loop.run_until_complete(task)
loop.run_until_complete(asyncio.gather(*tasks))
logger.info("Execution finished!")
self.clear_gpu_cache()
def clear_gpu_cache(self):
self.stream.request(
handlers=list(range(self.n_model_workers)),
handle_type="clear_data_cache",
datas=[self.ctrl.ids_to_clear for _ in list(range(self.n_model_workers))],
no_syn=True,
)
# Clear resource tracker as well.
self.storage_tracker.clear_data(self.ctrl.ids_to_clear)
self.ctrl.ids_to_clear.clear()
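The executor now creates all coroutine tasks (the per-RPC function calls, request flushing, data loading, and traversal) and drives them with a single `asyncio.gather`, so the first exception in any task propagates instead of being hidden behind a serial `run_until_complete` per task. A minimal sketch of that pattern:

import asyncio


async def work(name, delay):
    await asyncio.sleep(delay)
    return name


loop = asyncio.new_event_loop()
tasks = [loop.create_task(work(n, d)) for n, d in [("a", 0.2), ("b", 0.1)]]
# One gather drives every task to completion and re-raises the first failure,
# instead of calling run_until_complete once per task.
print(loop.run_until_complete(asyncio.gather(*tasks)))   # ['a', 'b']
loop.close()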

File diff suppressed because it is too large

View File

@ -2,6 +2,8 @@
import asyncio
import dataclasses
import itertools
import json
import os
import time
import uuid
@ -9,6 +11,7 @@ from collections import defaultdict
from typing import Dict, Hashable, List, Set, Tuple
import wandb
from tensorboardX import SummaryWriter
import realhf.api.core.config as config_api
import realhf.api.core.data_api as data_api
@ -16,11 +19,13 @@ import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
from realhf import ModelShardID
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
from realhf.system.redistributor import RedistribPlanner, RedistribStep
logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")
@ -38,10 +43,6 @@ class RPCCorountineControl:
ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]] = dataclasses.field(
default_factory=dict
)
should_save: bool = False
should_eval: bool = False
should_ckpt: bool = False
@ -52,7 +53,7 @@ class RPCCorountineControl:
hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)
class FunctionCall:
class ModelFunctionCall:
def __init__(
self,
rpc: dfg.MFCDef,
@ -63,16 +64,24 @@ class FunctionCall:
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
buffer: AsyncIOSequenceBuffer,
redistrib_planner: RedistribPlanner,
summary_writer: SummaryWriter | None,
):
self.rpc = rpc
self.src_rpc = src_rpc
self.stream = stream
self.n_model_workers = len(set(msid2mwid.values()))
self.msid2mwid = msid2mwid
self.model_topos = model_topos
self.model_configs = model_configs
self.mwid2msids = defaultdict(list)
for msid, mwid in msid2mwid.items():
self.mwid2msids[mwid].append(msid)
self.model_save_root = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
@ -80,8 +89,10 @@ class FunctionCall:
)
self.rpc_ctrl = ctrl
self.buffer = buffer
self.redistrib_planner = redistrib_planner
self.summary_writer = summary_writer
@property
def dp_size(self):
@ -172,10 +183,8 @@ class FunctionCall:
def request(
self,
producer_names: Dict[str, str],
producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]],
producer_mappings: Dict[str, Dict[str, List[int]]],
target_mapping: Dict[str, List[int]],
data_transfer_plan: List[RedistribStep],
partitioned_ids: List[Hashable],
meta_sample: data_api.SequenceSample,
handlers: List[config_pkg.ModelShardID],
) -> Tuple[List[uuid.UUID], List[uuid.UUID]]:
@ -184,14 +193,12 @@ class FunctionCall:
ctrl = self.rpc_ctrl
dt_data = {
"keys": rpc.input_keys,
"target": rpc.model_name,
"producer_names": producer_names,
"producer_mappings": producer_mappings,
"target_mapping": target_mapping,
"plan": [json.dumps(dataclasses.asdict(x)) for x in data_transfer_plan],
"partitioned_ids": partitioned_ids,
"handle_name": rpc.interface_type.value,
"rpc_name": rpc.name,
"meta_sample": meta_sample,
"partitioned_ids": partitioned_ids,
}
payloads = {
@ -230,16 +237,27 @@ class FunctionCall:
mwids = [self.msid2mwid[h] for h in handlers]
assert len(mwids) == len(set(mwids))
for producer_name in producer_names.values():
for h in producer_name2producer_handlers[producer_name]:
if self.msid2mwid[h] not in mwids:
payloads[h] = request_reply_stream.Payload(
handler=h,
handle_name="empty",
pre_hooks=["data_transfer"],
pre_hook_data=[dt_data],
)
mwids.append(self.msid2mwid[h])
for step in data_transfer_plan:
if step.root not in mwids:
handler = self.mwid2msids[step.root][0]
payloads[handler] = request_reply_stream.Payload(
handler=handler,
handle_name="empty",
pre_hooks=["data_transfer"],
pre_hook_data=[dt_data],
)
mwids.append(step.root)
if step.comm_type == "gather":
for src in step.srcs:
if src not in mwids:
handler = self.mwid2msids[src][0]
payloads[handler] = request_reply_stream.Payload(
handler=handler,
handle_name="empty",
pre_hooks=["data_transfer"],
pre_hook_data=[dt_data],
)
mwids.append(src)
payloads, mwids = self.attach_payloads_with_hooks(
payloads,
@ -266,9 +284,10 @@ class FunctionCall:
# Dispatch data to different data parallel ranks.
if self.rpc.is_generate():
# The workload of generation is decided by batch size, instead of the generated length.
lens = [1 for _ in range(sample.bs)]
samples, forward_indices, _ = sample.split_with_lengths(
mb_spec=data_api.MicroBatchSpec(n_mbs=self.dp_size),
lens=[1 for _ in range(sample.bs)],
lens=lens,
)
else:
samples, forward_indices, _ = sample.split(
@ -297,26 +316,6 @@ class FunctionCall:
for j in range(topo.world_size())
]
producer_names = {} # data key -> model name
for k in rpc.input_keys:
if k in rpc.data_producers:
producer_names[k] = rpc.data_producers[k]
else:
producer_names[k] = self.src_rpc.model_name
keys_to_send = defaultdict(list) # model name -> List[keys] to send
for k in producer_names:
keys_to_send[producer_names[k]].append(k)
# convert producer model name to ModelShardID
producer_name2producer_handlers = {}
for producer_name in keys_to_send:
producer_name2producer_handlers[producer_name] = [
config_pkg.ModelShardID.from_parallelism_rank(
producer_name, self.model_topos[producer_name], j
)
for j in range(self.model_topos[producer_name].world_size())
]
dp_head_indices = [
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
for i in range(self.dp_size)
@ -330,40 +329,72 @@ class FunctionCall:
buf_indices, sample, partitions = self.data_parallel_dispatch(
buf_indices, sample
)
target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}
# Set data owner of produced data by this RPC, such that downstream RPCs can know
# where to fetch these data.
for dp_idx, (st, ed) in enumerate(partitions):
for i in range(st, ed):
for k in rpc.output_keys:
self.rpc_ctrl.data_owner[sample.ids[i], k] = (
rpc.model_name,
dp_idx,
# Build data destinations: GPU id -> List[data ids]
partitioned_ids = []
dests = {}
for dp_rank, (st, ed) in enumerate(partitions):
ranks = topo.filter_match(data=dp_rank)
for rank in ranks:
h = config_pkg.ModelShardID.from_parallelism_rank(
model_name=rpc.model_name, topo=topo, parallelism_rank=rank
)
gpu_id = self.msid2mwid[h]
assert gpu_id not in dests
dests[gpu_id] = sample.ids[st:ed]
partitioned_ids.append(sample.ids[st:ed])
for i in range(self.n_model_workers):
if i not in dests:
dests[i] = []
# NOTE: The data loaded from the dataset may be unevenly distributed across DP ranks.
# Only bcast works in this case.
if rpc.is_src:
pattern = "bcast"
else:
pattern = "gather-scatter"
data_transfer_plan = self.redistrib_planner.derive_plan(
dests,
keys=rpc.input_keys,
pattern=pattern,
)
blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.")
# Update storage tracker for transferred data.
if rpc.is_src:
# NOTE: since the data we loaded may be unevenly distributed across DP ranks,
# we should change the owner of the data to the src RPC.
for i in range(topo.world_size()):
h = ModelShardID.from_parallelism_rank(
model_name=rpc.model_name, topo=topo, parallelism_rank=i
)
is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
gpu_id = self.msid2mwid[h]
for key in rpc.input_keys:
self.redistrib_planner.storage_tracker.add_data(
gpu_id, partitioned_ids[h.dp_rank], key=key, is_owner=is_dp_head
)
# Get the data owner of this RPC's input data.
# We use it to determine the source of data transfer.
producer_mappings = {}
for k in rpc.input_keys:
names, dp_indices = [], []
for sample_id in sample.ids:
owner_name, dp_idx = self.rpc_ctrl.data_owner[(sample_id, k)]
names.append(owner_name)
dp_indices.append(dp_idx)
assert len(set(names)) == 1
producer_mapping = defaultdict(list)
for i, dp_idx in enumerate(dp_indices):
producer_mapping[dp_idx].append(i)
producer_mapping = {k: sorted(v) for k, v in producer_mapping.items()}
producer_mappings[names[0], k] = producer_mapping
else:
for step in data_transfer_plan:
if step.comm_type == "scatter":
for gpu_id, ids in zip(step.dsts, step.ids):
for key in step.keys:
self.redistrib_planner.storage_tracker.add_data(
gpu_id, ids, key=key, is_owner=False
)
elif step.comm_type == "gather":
for key in step.keys:
self.redistrib_planner.storage_tracker.add_data(
step.root,
list(itertools.chain.from_iterable(step.ids)),
key=key,
is_owner=False,
)
# send partitioned data to model workers
req_ids, other_req_ids = self.request(
producer_names=producer_names,
producer_name2producer_handlers=producer_name2producer_handlers,
producer_mappings=producer_mappings,
target_mapping=target_mapping,
data_transfer_plan=data_transfer_plan,
partitioned_ids=partitioned_ids,
meta_sample=sample,
handlers=handlers,
)
@ -384,6 +415,22 @@ class FunctionCall:
# model function calls. The data should be amended into the buffer.
# Otherwise, it's the train statistics and should be reduced and logged.
if isinstance(responses[-1], data_api.SequenceSample):
# Update storage tracker for generated data.
for dp_rank, x in enumerate(responses):
pp_size = topo.get_dim("pipe")
ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, model=0)
for rank in ranks:
h = config_pkg.ModelShardID.from_parallelism_rank(
model_name=rpc.model_name, topo=topo, parallelism_rank=rank
)
gpu_id = self.msid2mwid[h]
for k in rpc.output_keys:
self.redistrib_planner.storage_tracker.add_data(
gpu_id,
x.ids,
key=k,
is_owner=True,
)
res = data_api.SequenceSample.gather(responses)
else:
res = data_api.gather_stat(responses)
@ -393,6 +440,11 @@ class FunctionCall:
if isinstance(res, Dict):
wandb.log(res, step=ctrl.step_info.global_step)
if self.summary_writer is not None:
for key, val in res.items():
self.summary_writer.add_scalar(
f"{key}", val, ctrl.step_info.global_step
)
logger.info(
f"Model rpc {rpc.name} finished. "
@ -418,7 +470,6 @@ class FunctionCall:
async def run(self):
rpc = self.rpc
topo = self.model_topos[rpc.model_name]
ctrl = self.rpc_ctrl
logger.info(
f"Running Model RPC, interface_type=#{rpc.interface_type}# "

View File

@ -30,7 +30,6 @@ import torch.utils.data
import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as system_api
import realhf.impl.model.comm.data_transfer as data_transfer_comm
import realhf.impl.model.comm.global_comm as global_comm
import realhf.impl.model.comm.param_realloc as param_realloc_comm
from realhf.api.core.config import ModelName
@ -49,11 +48,12 @@ from realhf.base.monitor import (
cuda_tmark,
cuda_tmarked,
dump_tmark_db,
gpu_utilization_monitor,
)
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils import cuda_graph
from realhf.system import request_reply_stream, worker_base
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import RedistribStep
# NOTE: Register all implemented datasets and models.
import realhf.api.core.data_api as data_api # isort:skip
@ -257,11 +257,12 @@ class ModelWorker(worker_base.Worker):
msid2mwid=self.config.msid2mwid,
)
self.__data_transfer_info = data_transfer_comm.setup_data_transfer(
self.data_manager = DataManager(
model_topos=self.config.model_topos,
msid2mwid=self.config.msid2mwid,
data_transfer_pairs=self.config.data_transfer_pairs,
)
self.data_manager.setup_process_groups()
self.__param_realloc_info = param_realloc_comm.setup_param_realloc(
model_topos=self.config.model_topos,
@ -466,17 +467,6 @@ class ModelWorker(worker_base.Worker):
self.__reply_queue = queue.Queue(maxsize=8)
self.__request_sample_size = dict()
# Storing data loaded from the dataset and outputs of the
# model function call.
self.__data_storage: Dict[int, data_api.SequenceSample] = {}
self.__data_sent_worker_indices: Dict[int, Dict[str, Set]] = (
collections.defaultdict(lambda: collections.defaultdict(set))
)
self.__data_received_worker_indices: Dict[int, Dict[str, Set]] = (
collections.defaultdict(lambda: collections.defaultdict(set))
)
self.__compute_input_queues = {
model_name: dict(
train_step=queue.Queue(4),
@ -629,10 +619,10 @@ class ModelWorker(worker_base.Worker):
# Defer data that has not been used in the previous epoch.
data_loaded = []
for x in cur_sample.unpack():
if x.ids[0] in self.__data_storage:
if self.data_manager.has_data(x.ids[0]):
continue
data_loaded.append(x)
self.__data_storage[x.ids[0]] = x
self.data_manager.store(x)
assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded)
if len(data_loaded) > 0:
@ -653,13 +643,7 @@ class ModelWorker(worker_base.Worker):
elif request.handle_name == "clear_data_cache":
with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
ids = request.data
for _id in ids:
if _id in self.__data_storage:
del self.__data_storage[_id]
if _id in self.__data_sent_worker_indices:
del self.__data_sent_worker_indices[_id]
if _id in self.__data_received_worker_indices:
del self.__data_received_worker_indices[_id]
self.data_manager.remove(ids)
gc.collect()
if (
self.config.cuda_cache_cleanliness
@ -673,7 +657,7 @@ class ModelWorker(worker_base.Worker):
)
logger.info(
"Get clear_data_cache, dump cuda tmark. "
f"Remaining data in local storage: {len(self.__data_storage)}. "
f"Remaining data in local storage: {self.data_manager.storage_size()}. "
)
dump_tmark_db(self.__worker_index)
res = request_reply_stream.NoResponse()
@ -940,7 +924,7 @@ class ModelWorker(worker_base.Worker):
for x in res.unpack():
# The input data must exist in the storage, otherwise
# the model function call will not run.
self.__data_storage[x.ids[0]].update_(x)
self.data_manager.update(x)
# Only return meta data back to the master worker.
if isinstance(res, data_api.SequenceSample):
@ -967,32 +951,18 @@ class ModelWorker(worker_base.Worker):
@cuda_tmark("data_transfer", CUDATimeMarkType.comm)
def __data_transfer_among_workers(self, hook_data: Dict[str, Any]):
meta_sample = hook_data["meta_sample"]
comm_plan = data_transfer_comm.derive_data_transfer_plan(
keys=hook_data["keys"],
global_ids=meta_sample.ids,
consumer_name=hook_data["target"],
consumer_mapping=hook_data["target_mapping"],
producer_names=hook_data["producer_names"],
producer_mappings=hook_data["producer_mappings"],
data_transfer_info=self.__data_transfer_info,
)
data_transfer_comm.run_data_transfer(
comm_plan=comm_plan,
meta_samples={x.ids[0]: x for x in meta_sample.unpack()},
storage=self.__data_storage,
sent_worker_idx_table=self.__data_sent_worker_indices,
received_worker_idx_table=self.__data_received_worker_indices,
)
plan = [RedistribStep(**json.loads(x)) for x in hook_data["plan"]]
self.data_manager.redistribute(meta_sample, plan=plan)
if hook_data["target"] in self.__models:
with constants.model_scope(hook_data["target"]):
local_ids = [
meta_sample.ids[i]
for i in hook_data["target_mapping"][self._dp_rank]
]
local_ids = hook_data["partitioned_ids"][self._dp_rank]
r = data_api.SequenceSample.gather(
[self.__data_storage[_id] for _id in local_ids],
[
self.data_manager.get(_id).to_device(constants.current_device())
for _id in local_ids
],
keys=meta_sample.keys,
)
self.__compute_input_queues[hook_data["target"]][
@ -1349,7 +1319,6 @@ class ModelWorker(worker_base.Worker):
self.__models.clear()
self.__backends.clear()
self.__interfaces.clear()
self.__data_storage.clear()
# Reset model worker states.
self.__dist_env_resolved = False

View File

@ -0,0 +1,343 @@
# Copyright 2025 Ant Group Inc.
import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *
from realhf.base.cluster import spec as cluster_spec
class GlobalStorageTracker:
def __init__(self, world_size: int):
self.storages: List[Dict[Hashable, List[str]]]
self.storages = [{} for _ in range(world_size)]
self.data_owner: Dict[Tuple[Hashable, str], int]
self.data_owner = {}
def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
for data_id in ids:
if data_id not in self.storages[rank]:
self.storages[rank][data_id] = [key]
else:
if key not in self.storages[rank][data_id]:
self.storages[rank][data_id].append(key)
if is_owner:
self.data_owner[(data_id, key)] = rank
def clear_data(self, ids: List[Hashable]):
for storage in self.storages:
for i in ids:
if i in storage:
storage.pop(i)
keys = list(self.data_owner.keys())
for i, k in keys:
if i in ids:
self.data_owner.pop((i, k))
@dataclasses.dataclass
class RedistribStep:
comm_type: str
root: int | None
srcs: List[int] | None
dsts: List[int] | None
ids: List[List[Hashable]]
keys: List[str]
def __repr__(self) -> str:
if self.comm_type == "gather":
return f"Gather {self.keys} to {self.root} from {self.srcs}."
if self.comm_type == "scatter":
return f"Scatter {self.keys} from {self.root} to {self.dsts}."
if self.comm_type == "bcast":
return f"Bcast {self.keys} from {self.root} to {self.dsts}."
raise NotImplementedError()
class RedistribPlanner:
def __init__(self, storage_tracker: GlobalStorageTracker):
self.storage_tracker = storage_tracker
def derive_plan(
self,
dests: Dict[int, List[Hashable]],
keys: List[str],
pattern: str = "gather-scatter",
) -> List[RedistribStep]:
if pattern == "gather-scatter":
return self.derive_plan_gather_scatter(dests, keys)
elif pattern == "bcast":
return self.derive_plan_bcast(dests, keys)
raise NotImplementedError(f"Unknown data redistribution pattern: {pattern}")
def derive_plan_gather_scatter(
self, dests: Dict[int, List[Hashable]], keys: List[str]
) -> List[RedistribStep]:
self.dests = dests
all_data_ids = set()
for all_samples in dests.values():
for data_id in all_samples:
all_data_ids.add(data_id)
transfer_plan = []
for key in keys:
owners = sorted(
list(
set(
[
self.storage_tracker.data_owner[(i, key)]
for i in all_data_ids
]
)
)
)
gather_ids = []
for owner in owners:
this_owner_ids = []
for i in all_data_ids:
if (
i in self.storage_tracker.storages[owner]
and key in self.storage_tracker.storages[owner][i]
):
this_owner_ids.append(i)
gather_ids.append(sorted(this_owner_ids))
gather_step = RedistribStep(
comm_type="gather",
root=owners[0],
srcs=owners,
dsts=None,
ids=gather_ids,
keys=[key],
)
scatter_dsts = sorted([i for i in dests if dests[i]])
scatter_ids = [sorted(dests[i]) for i in scatter_dsts]
scatter_step = RedistribStep(
comm_type="scatter",
root=owners[0],
dsts=scatter_dsts,
srcs=None,
ids=scatter_ids,
keys=[key],
)
transfer_plan += [gather_step, scatter_step]
# Prune the plan.
pop_indices = []
for idx, step in enumerate(transfer_plan):
# 1. Omit the gather step if data has already been gathered before.
if step.comm_type == "gather":
all_gather_ids = list(itertools.chain.from_iterable(step.ids))
key = step.keys[0]
if any(
i not in self.storage_tracker.storages[step.root]
for i in all_gather_ids
):
continue
if any(
key not in self.storage_tracker.storages[step.root][i]
for i in all_gather_ids
):
continue
pop_indices.append(idx)
# 2. Omit the gather + scatter steps if the data already exists on all dst GPUs.
if step.comm_type == "scatter":
key = step.keys[0]
all_exists = True
for dst, ids in zip(step.dsts, step.ids):
if any(i not in self.storage_tracker.storages[dst] for i in ids):
all_exists = False
break
if any(
key not in self.storage_tracker.storages[dst][i] for i in ids
):
all_exists = False
break
if all_exists:
pop_indices.append(idx)
pop_indices.append(idx - 1)
for pop_idx in reversed(sorted(set(pop_indices))):
transfer_plan.pop(pop_idx)
# Merge the gather/scatter steps of different keys.
gather_plan = {}
scatter_plan = {}
for step in transfer_plan:
if step.comm_type == "gather":
plan_key = (
step.root,
tuple(sorted(step.srcs)),
tuple([tuple(sorted(ids)) for ids in step.ids]),
)
if plan_key not in gather_plan:
gather_plan[plan_key] = step
else:
assert all(
k not in gather_plan[plan_key].keys for k in step.keys
), (
gather_plan[plan_key],
step,
plan_key,
)
gather_plan[plan_key].keys += step.keys
if step.comm_type == "scatter":
plan_key = (
step.root,
tuple(sorted(step.dsts)),
tuple([tuple(sorted(ids)) for ids in step.ids]),
)
if plan_key not in scatter_plan:
scatter_plan[plan_key] = step
else:
assert all(
k not in scatter_plan[plan_key].keys for k in step.keys
), (
scatter_plan[plan_key],
step,
plan_key,
)
scatter_plan[plan_key].keys += step.keys
# Prioritize gather over scatter
return list(gather_plan.values()) + list(scatter_plan.values())
def derive_plan_bcast(
self, dests: Dict[int, List[Hashable]], keys: List[str]
) -> List[RedistribStep]:
assert isinstance(keys, list), type(keys)
self.dests = dests
# Get all required data IDs.
all_data_ids = set()
for all_samples in self.dests.values():
for data_id in all_samples:
all_data_ids.add(data_id)
# The producers for each required data.
id2gpu_src = {}
for data_id in all_data_ids:
for key in keys:
id2gpu_src[(data_id, key)] = []
for gpu_id, ids2keys in enumerate(self.storage_tracker.storages):
if data_id in ids2keys and key in ids2keys[data_id]:
id2gpu_src[(data_id, key)].append(gpu_id)
# The consumers for each required data.
id2gpu_dst = {}
for data_id in all_data_ids:
for key in keys:
id2gpu_dst[(data_id, key)] = []
for gpu_id, ids in self.dests.items():
if data_id in ids:
id2gpu_dst[(data_id, key)].append(gpu_id)
self.transfer_plan = {}
for data_id, key in itertools.product(all_data_ids, keys):
source_gpus = id2gpu_src[(data_id, key)]
target_gpus = id2gpu_dst[(data_id, key)]
assert len(source_gpus) > 0, (data_id, key, id2gpu_src, id2gpu_dst)
# Omit the transfer for target GPUs that already hold the data.
target_gpus = [gpu for gpu in target_gpus if gpu not in source_gpus]
if not target_gpus:
continue
# Find the "nearest" GPU for data transfer.
best_src = self._select_best_bcast_source(source_gpus, target_gpus)
self.transfer_plan[(data_id, key)] = {"src": best_src, "dsts": target_gpus}
return self._group_bcast_transfers()
def _on_same_node(self, i, j) -> bool:
return (i // cluster_spec.n_gpus_per_node) == (
j // cluster_spec.n_gpus_per_node
)
def _select_best_bcast_source(self, source_gpus, target_gpus):
same_node_counts = {}
for src in source_gpus:
same_node_count = sum(
1 for dst in target_gpus if self._on_same_node(src, dst)
)
same_node_counts[src] = same_node_count
# Find the source that maximizes locality.
max_same_node = max(same_node_counts.values())
best_sources = [
src for src, count in same_node_counts.items() if count == max_same_node
]
# Find the source with the smallest workload.
src_load = defaultdict(int)
for plan in self.transfer_plan.values():
src_gpu = plan["src"]
src_load[src_gpu] += len(plan["dsts"])
return min(best_sources, key=lambda src: src_load[src])
def _group_bcast_transfers(self) -> List[RedistribStep]:
# Group data ids that should be transferred from "src" to "dsts"
src_to_transfers = defaultdict(lambda: defaultdict(list))
for (data_id, key), plan in self.transfer_plan.items():
src_to_transfers[(plan["src"], key)][tuple(sorted(plan["dsts"]))].append(
data_id
)
stages = []
while any(src_to_transfers.values()):
stage = []
used_dsts = set()
used_srcs = set()
for (src, key), transfers in list(src_to_transfers.items()):
if src in used_srcs:
continue
if not transfers:
continue
# Find a transfer that can be executed concurrently.
pop_key = None
for i, dsts in enumerate(transfers):
if not any(dst in used_dsts for dst in dsts):
pop_key = dsts
break
if pop_key is not None:
data_ids = transfers.pop(pop_key)
stage.append(
RedistribStep(
comm_type="bcast",
root=src,
srcs=[src],
keys=[key],
dsts=pop_key,
ids=data_ids,
)
)
used_dsts.update(dsts)
used_srcs.add(src)
if stage:
stages += stage
else:
for (src, key), transfers in list(src_to_transfers.items()):
if transfers:
dsts, data_ids = transfers.pop(0)
stages.append(
RedistribStep(
comm_type="bcast",
srcs=[src],
root=src,
dsts=dsts,
ids=data_ids,
keys=[key],
)
)
break
return stages
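A stand-alone sketch of the locality heuristic in `_select_best_bcast_source`: prefer the source GPU that shares a node with the most destinations, then break ties by the lightest accumulated send load. `GPUS_PER_NODE` stands in for `cluster_spec.n_gpus_per_node`, and the load table is passed in explicitly here rather than derived from the accumulated plan:

from collections import defaultdict

GPUS_PER_NODE = 8   # assumption standing in for cluster_spec.n_gpus_per_node


def on_same_node(i, j):
    return i // GPUS_PER_NODE == j // GPUS_PER_NODE


def select_bcast_source(source_gpus, target_gpus, src_load):
    # Prefer the source sharing a node with the most destinations...
    locality = {
        src: sum(1 for dst in target_gpus if on_same_node(src, dst))
        for src in source_gpus
    }
    best = max(locality.values())
    candidates = [s for s, c in locality.items() if c == best]
    # ...then break ties by the lightest current send load.
    return min(candidates, key=lambda s: src_load[s])


load = defaultdict(int)
# GPU 9 shares a node with destinations 10 and 11; GPU 0 does not.
assert select_bcast_source([0, 9], [10, 11, 16], load) == 9
load[9] = 4
# GPUs 8 and 9 both sit on the destinations' node; the less-loaded 8 wins the tie.
assert select_bcast_source([8, 9], [10, 11], load) == 8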

View File

@ -138,6 +138,9 @@ class NameResolvingRequestClient:
f"subscribers: {name_resolve.get_subtree(names.request_reply_stream(experiment_name, trial_name, PUBSUB_BARRIER_NAME))}."
)
def route_to(self, handler) -> int:
return self._handler_routing[handler]
def close(self):
self.recv_socket.close()
for send_socket in self.send_sockets:

View File

@ -1,501 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import asyncio
import copy
import gc
import os
import time
from typing import Dict
import colorama
import networkx as nx
import numpy as np
import uvloop
import wandb
import realhf.api.core.dfg as dfg
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, name_resolve, names, timeutil, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.v2.function_call import RPCCorountineControl
from realhf.system.v2.function_executor import FunctionExecutor
logger = logging.getLogger("master worker", "system")
blogger = logging.getLogger("benchmark")
uvloop.install()
class MasterWorker(worker_base.Worker):
global_exp_tik = time.perf_counter()
def _configure(self, config: config_pkg.MasterWorker):
self.config = config
self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
config.model_topos
)
# Build execution graph and initialize concurrency utilities.
self.__model_rpcs = config.model_rpcs
# Sort all MFCs in topological order and calculate the width of each level.
# These numbers determine when to flush MFC requests.
self.__topo_widths = []
for generation in nx.topological_generations(self.__model_rpcs[0]._G):
self.__topo_widths.append(len(generation))
logger.info("Topological widths: " + str(self.__topo_widths))
self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs))
self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs))
# Save and eval control.
self.__total_train_epochs = config.exp_ctrl.total_train_epochs
self.__save_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.save_freq_epochs,
freq_step=config.exp_ctrl.save_freq_steps,
freq_sec=config.exp_ctrl.save_freq_secs,
)
if (
config.exp_ctrl.ckpt_freq_epochs is None
and config.exp_ctrl.ckpt_freq_steps is None
and config.exp_ctrl.ckpt_freq_secs is None
):
self.__ckpt_ctl = self.__save_ctl
else:
self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.ckpt_freq_epochs,
freq_step=config.exp_ctrl.ckpt_freq_steps,
freq_sec=config.exp_ctrl.ckpt_freq_secs,
)
self.__eval_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=config.exp_ctrl.eval_freq_epochs,
freq_step=config.exp_ctrl.eval_freq_steps,
freq_sec=config.exp_ctrl.eval_freq_secs,
)
self.MODEL_SAVE_ROOT = os.path.join(
constants.MODEL_SAVE_ROOT,
config.worker_info.experiment_name,
config.worker_info.trial_name,
)
os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True)
self.__initialized = False
self.__recover_run, self.__recover_info = recover.load_recover_info()
if self.__recover_info is not None:
logger.info(
f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
f"last_step_info={self.__recover_info.last_step_info}."
)
logger.info(
f"Number of used data in recover info: {len(self.__recover_info.hash_vals_to_ignore)}. "
f"The previous experiment probably ran for {len(self.__recover_info.hash_vals_to_ignore) // self.__rpc_srcs[0].n_seqs} steps in the epoch."
)
# Create coroutine control objects for running the dataflow graph.
self.__rpc_ctrl = RPCCorountineControl(
train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
# NOTE: Accumulate the data hashes used within the current epoch
# to avoid reloading data that has already been consumed.
used_hash_vals_this_epoch=(
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
if self.__recover_run
else list()
),
hash_vals_to_ignore_in_recover=(
copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
if self.__recover_run
else list()
),
)
if self.__recover_run:
self.__rpc_ctrl.step_info = copy.deepcopy(self.__recover_info.recover_start)
self.__eval_ctl.load_state_dict(self.__recover_info.eval_ctl_info)
self.__save_ctl.load_state_dict(self.__recover_info.save_ctl_info)
self.__ckpt_ctl.load_state_dict(self.__recover_info.ckpt_ctl_info)
logger.info(
f"Recovering from previous run. "
f"Epoch: {self.__rpc_ctrl.step_info.epoch + 1}, "
f"Epoch Step: {self.__rpc_ctrl.step_info.epoch_step + 1} "
f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}."
)
# for benchmark
self.e2e_time_history = []
self.__benchmark_steps = config.exp_ctrl.benchmark_steps
return config.worker_info
def initialize_models(self):
# Initialize model backends.
model_names = list(self.__model_topos.keys())
self.logger.info(f"Initialize model backends with order: {model_names}.")
train_rpcs = list(
filter(
lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
self.__model_rpcs,
)
)
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
if len(train_rpcs) > 0:
ft_spec = model_api.FinetuneSpec(
total_train_epochs=self.config.exp_ctrl.total_train_epochs,
dataset_size=self._dataset_size,
train_batch_size=train_rpcs[0].n_seqs,
)
else:
ft_spec = model_api.FinetuneSpec(
total_train_epochs=self.config.exp_ctrl.total_train_epochs,
dataset_size=self._dataset_size,
train_batch_size=self.__src_rpc.n_seqs,
)
_initialized_roles = []
for model_name in model_names:
topo = self.config.model_topos[model_name]
# Build FinetuneSpec, which is required to initialize backends.
_handlers = [
config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
for j in range(topo.world_size())
]
init_payloads = [
request_reply_stream.Payload(
handler=_h,
handle_name="initialize",
data=ft_spec,
)
for _h in _handlers
]
# Send initialization requests then immediately flush them.
self.__stream.request(
payloads=init_payloads,
)
self.__stream.request(
handlers=_handlers,
handle_type="flush",
no_syn=True,
)
_initialized_roles.append(model_name.role)
self._ft_spec = ft_spec
logger.info("Initializations of models and backends complete.")
def get_dataset_model_info(self):
src_rpc = self.__rpc_srcs[0]
src_rpc_topo = self.config.model_topos[src_rpc.model_name]
src_rpc_dp_size = src_rpc_topo.get_dim("data")
# Request training specification from data workers.
all_data = sum(
self.__stream.call(
handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
datas=[None for i in range(src_rpc_dp_size)],
handle_type="spec",
),
[],
)
# NOTE: For dynamic datasets, we still count epochs according to the initial dataset size,
# such that the learning rate decay is not affected.
seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data]
self._dataset_size = len(all_data)
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch
self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data]
# Request model configs from model workers.
# Return None if the model is not a ReaLModel.
self.__model_configs: Dict[ModelName, None | ReaLModelConfig] = {}
for model_name, topo in self.config.model_topos.items():
h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, 0)
self.__model_configs[model_name] = self.__stream.call(
handlers=[h],
datas=[None],
handle_type="model_config",
)[0]
def __lazy_init(self):
# Set up streams.
handler_routing = copy.deepcopy(self.config.msid2mwid)
src_rpc = self.__rpc_srcs[0]
src_rpc_topo = self.config.model_topos[src_rpc.model_name]
src_rpc_dp_size = src_rpc_topo.get_dim("data")
src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
for i in range(src_rpc_dp_size):
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
handler_routing[f"__data{i}__"] = self.config.msid2mwid[
config_pkg.ModelShardID.from_parallelism_rank(
model_name=src_rpc.model_name,
topo=src_rpc_topo,
parallelism_rank=rank,
)
]
handler_routing.update({i: i for i in range(self.config.n_model_workers)})
self.__stream = request_reply_stream.make_master_stream(
self.config.worker_info,
n_subscribers=self.config.n_model_workers,
handler_routing=handler_routing,
)
self.__stream: request_reply_stream.NameResolvingRequestClient
self.__src_rpc = src_rpc = [
rpc for rpc in self.config.model_rpcs if rpc.is_src
][0]
self.get_dataset_model_info()
self.initialize_models()
self.__seqbuffer = AsyncIOSequenceBuffer(
self.__model_rpcs,
max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
)
# Create coroutines for model RPCs.
logger.info(f"Creating asyncio coroutines...")
self.func_executor = FunctionExecutor(
rpcs=self.__model_rpcs,
msid2mwid=self.config.msid2mwid,
stream=self.__stream,
buffer=self.__seqbuffer,
model_topos=self.__model_topos,
model_configs=self.__model_configs,
ctrl=self.__rpc_ctrl,
)
logger.info(f"Coroutines created. The master worker is ready to run.")
# wandb init, connect to remote wandb host
wandb.login()
wandb.init(
mode=self.wandb_config.mode,
entity=self.wandb_config.entity,
project=self.wandb_config.project or constants.experiment_name(),
name=self.wandb_config.name or constants.trial_name(),
job_type=self.wandb_config.job_type,
group=self.wandb_config.group,
notes=self.wandb_config.notes,
tags=self.wandb_config.tags,
config=self.wandb_config.config,
dir=os.path.join(
constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
),
force=True,
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
self.__initialized = True
self._train_start_time = time.perf_counter()
self.__last_step_info = recover.StepInfo(
epoch=-1,
epoch_step=-1,
global_step=-1,
)
def _poll(self):
is_new_epoch = False
if not self.__initialized:
self.__lazy_init()
# Main execution steps. The graph runs under the hood in RPC & stream threads.
# Wait for the finish of the traversal of the execution graph.
execution_start = time.perf_counter()
if self.__rpc_ctrl.ids_to_clear:
# Send clear cache requests to model workers.
# This clears the data used in the last step.
self._clear_gpu_cache()
is_new_epoch = self._ft_spec.is_new_epoch(self.__rpc_ctrl.step_info)
is_epoch_last_step = self._ft_spec.is_epoch_last_step(self.__rpc_ctrl.step_info)
# Check whether we should evaluate or save models.
self.__rpc_ctrl.should_eval = self.__eval_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
self.__rpc_ctrl.should_save = self.__save_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
self.__rpc_ctrl.should_ckpt = self.__ckpt_ctl.check(
epochs=int(is_epoch_last_step), steps=1
)
# Traverse the dataflow graph once.
self.func_executor.execute_step()
# Post-process.
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)
if is_epoch_last_step:
self.__rpc_ctrl.used_hash_vals_this_epoch = (
self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
)
if is_new_epoch:
self.__rpc_ctrl.step_info.epoch += 1
self.__rpc_ctrl.step_info.epoch_step = 0
# Logging.
time_since_configure = time.perf_counter() - self._train_start_time
e2e_time = time.perf_counter() - execution_start
self.e2e_time_history.append(e2e_time)
self._log_training_stats(e2e_time, time_since_configure)
# Update counters.
self.__rpc_ctrl.step_info.epoch_step += 1
self.__rpc_ctrl.step_info.global_step += 1
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__recover_save()
# Pause the worker if experiment or system-wise benchmark completes.
if (
self.__benchmark_steps is not None
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
) or (
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
>= self.__total_train_epochs * self._dataset_size
):
# We don't know whether it is the last step of the current epoch,
# so we exit at the first step of the next epoch.
if self.__benchmark_steps is not None:
logger.info(
f"Finished benchmark {self.__benchmark_steps}. "
f"Time consumption of this setup: {time_since_configure:.3f}"
)
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
return self.experiment_complete_exit()
return worker_base.PollResult(sample_count=1, batch_count=1)
def _log_training_stats(self, e2e_time: float, time_since_configure: float):
# calculate flops
#########################################
if not all(
isinstance(v, ReaLModelConfig) for v in self.__model_configs.values()
):
logger.warning(
f"Not all models are ReaLModels. Unable to calculate FLOP/s."
)
flops = None
tflops_per_gpu = float("inf")
else:
flops = self.__rpc_ctrl.flops_counter.get_flops()
tflops = flops / (e2e_time * (10**12))
tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12))
self.__rpc_ctrl.flops_counter.clear()
#########################################
epoch = self.__rpc_ctrl.step_info.epoch + 1
epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
global_step = self.__rpc_ctrl.step_info.global_step + 1
s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
s += f"step {epoch_step}/{self._steps_per_epoch} "
s += f"(global step {global_step}) finishes. "
s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. "
if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch
avg_t = np.mean(self.e2e_time_history[2:])
remain_t = avg_t * remaining_steps
remain_t += avg_t * self._steps_per_epoch * remaining_epochs
s += f"Estimated remaining time: {remain_t:.3f}s. "
if flops is not None:
s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
logger.info(s)
logger.info(
f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s"
)
def _clear_gpu_cache(self):
self.__stream.request(
handlers=list(range(self.config.n_model_workers)),
handle_type="clear_data_cache",
datas=[
self.__rpc_ctrl.ids_to_clear
for _ in list(range(self.config.n_model_workers))
],
no_syn=True,
)
self.__rpc_ctrl.ids_to_clear.clear()
def experiment_complete_exit(self):
logger.info(
colorama.Style.RESET_ALL
+ colorama.Fore.YELLOW
+ colorama.Style.BRIGHT
+ "\033[1m"
+ "Experiment Completes! Yeah!!!!!!!!"
+ colorama.Style.RESET_ALL
)
# Send requests to pause model workers.
# Model workers will not respond to this message.
self.__stream.request(
handlers=list(range(self.config.n_model_workers)),
handle_type="reset",
datas=[None for _ in list(range(self.config.n_model_workers))],
)
self.__stream.close()
constants.reset_run()
# Reset names used for distributed training.
# The next round of training will set up a new distributed environment.
name_resolve.clear_subtree(
names.distributed_root(constants.experiment_name(), constants.trial_name())
)
name_resolve.clear_subtree(
names.request_reply_stream_root(
constants.experiment_name(), constants.trial_name()
)
)
wandb.finish()
gc.collect()
self.__initialized = False
self.pause()
return worker_base.PollResult(0, 0)
def __recover_save(self):
# save step info for recover
if os.getenv("REAL_SAVE_RECOVER_STATES", "0") != "1":
return
this_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
recover_info = recover.RecoverInfo(
recover_start=this_step_info,
last_step_info=self.__last_step_info,
save_ctl_info=self.__save_ctl.state_dict(),
ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
eval_ctl_info=self.__eval_ctl.state_dict(),
hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
)
recover.dump_recover_info(recover_info)
logger.info("Dumped recover info to file.")
logger.info(f"Will recover from: {recover_info.recover_start}")
logger.info(
f"Number of data used in this epoch: {len(recover_info.hash_vals_to_ignore)}"
)
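
For context on the save/ckpt/eval gating in _poll above: each controller fires when any of its configured epoch, step, or wall-clock thresholds is crossed since the last trigger, and the ckpt controller falls back to the save controller when no ckpt frequency is configured. The snippet below is an assumption-level illustration of that semantics, not the realhf EpochStepTimeFreqCtl implementation.

import time

class FreqCtlSketch:
    def __init__(self, freq_epoch=None, freq_step=None, freq_sec=None):
        self.freq_epoch, self.freq_step, self.freq_sec = freq_epoch, freq_step, freq_sec
        self._epochs = self._steps = 0
        self._last_time = time.monotonic()

    def check(self, epochs: int, steps: int) -> bool:
        # Accumulate progress and fire if any configured threshold is crossed.
        self._epochs += epochs
        self._steps += steps
        hit = (
            (self.freq_epoch is not None and self._epochs >= self.freq_epoch)
            or (self.freq_step is not None and self._steps >= self.freq_step)
            or (self.freq_sec is not None
                and time.monotonic() - self._last_time >= self.freq_sec)
        )
        if hit:
            self._epochs = self._steps = 0
            self._last_time = time.monotonic()
        return hit

# E.g. a controller with freq_step=100 fires every 100 training steps,
# regardless of epoch boundaries.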

View File

@ -0,0 +1,288 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import os
import pathlib
import pickle
import uuid
from typing import *
import numpy as np
import pytest
import torch
import torch.distributed as dist
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, testing, topology
from realhf.base.testing import (
LocalMultiProcessTest,
PipeModelDataParallelTopology,
init_global_constants,
)
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
def get_data_manager(
from_model_name,
to_model_name,
from_pp_dp_mp,
to_pp_dp_mp,
):
from_num_pp, from_num_dp, from_num_mp = from_pp_dp_mp
to_num_pp, to_num_dp, to_num_mp = to_pp_dp_mp
from_world_size = from_num_dp * from_num_mp * from_num_pp
to_world_size = to_num_dp * to_num_mp * to_num_pp
from_topo = topology.PipeModelDataParallelTopology(
num_dp=from_num_dp,
num_mp=from_num_mp,
num_pp=from_num_pp,
sequence_parallel=False,
gradient_checkpointing=False,
max_prompt_len=None,
gradient_accumulation_fusion=False,
)
to_topo = topology.PipeModelDataParallelTopology(
num_dp=to_num_dp,
num_mp=to_num_mp,
num_pp=to_num_pp,
sequence_parallel=False,
gradient_checkpointing=False,
max_prompt_len=None,
gradient_accumulation_fusion=False,
)
model_topos = {from_model_name: from_topo, to_model_name: to_topo}
msid2mwid = {}
for i in range(dist.get_world_size()):
# We assume the `from_model` occupies the first several GPUs,
# while the `to_model` occupies GPUs from the last one.
# For example, when the world size of `from_model` is 6 and
# the world size of `to_model` is 4, the GPU layout is:
# GPU 0-3: from_model (shard 0-3)
# GPU 4-5: from_model (shard 4-5) + to_model (shard 0-1)
# GPU 6-7: to_model (shard 2-3)
_model_names = []
if i < from_world_size:
_model_names.append(from_model_name)
if i >= dist.get_world_size() - to_world_size:
_model_names.append(to_model_name)
for _model_name in _model_names:
if _model_name == from_model_name:
coord = model_topos[_model_name].get_coord(i)
else:
coord = model_topos[_model_name].get_coord(
i + to_world_size - dist.get_world_size()
)
k = ModelShardID(
_model_name,
dp_rank=coord.data,
mp_rank=coord.model,
pp_rank=coord.pipe,
topo=model_topos[_model_name],
)
msid2mwid[k] = i
init_global_constants(
num_dp=from_num_dp,
num_mp=from_num_mp,
num_pp=from_num_pp,
topo=from_topo,
model_name=from_model_name,
sequence_parallel=False,
msid2mwid=msid2mwid,
)
init_global_constants(
num_dp=to_num_dp,
num_mp=to_num_mp,
num_pp=to_num_pp,
model_name=to_model_name,
sequence_parallel=False,
msid2mwid=msid2mwid,
)
return DataManager(
model_topos=model_topos,
msid2mwid=msid2mwid,
data_transfer_pairs=[(from_model_name, to_model_name)],
)
def recursive_assert_equal(x1, x2):
if type(x1) != type(x2):
raise AssertionError(f"{type(x1)} != {type(x2)}")
if isinstance(x1, dict):
assert set(x1.keys()) == set(x2.keys())
for k in x1.keys():
recursive_assert_equal(x1[k], x2[k])
elif dataclasses.is_dataclass(x1):
for f in dataclasses.fields(x1):
recursive_assert_equal(getattr(x1, f.name), getattr(x2, f.name))
elif isinstance(x1, torch.Tensor):
assert torch.allclose(x1, x2), (x1, x2)
elif isinstance(x1, list):
assert len(x1) == len(x2)
for a, b in zip(x1, x2):
recursive_assert_equal(a, b)
else:
assert x1 == x2
def _test_data_transfer(
tmp_path,
from_pp_dp_mp: Tuple,
to_pp_dp_mp: Tuple,
):
from_model_name = ModelName("data_transfer_test", 0)
from_topo = PipeModelDataParallelTopology(
num_pp=from_pp_dp_mp[0],
num_mp=from_pp_dp_mp[-1],
num_dp=from_pp_dp_mp[1],
sequence_parallel=True,
gradient_checkpointing=True,
gradient_accumulation_fusion=True,
)
to_model_name = ModelName("data_transfer_test", 1)
to_topo = PipeModelDataParallelTopology(
num_pp=to_pp_dp_mp[0],
num_mp=to_pp_dp_mp[-1],
num_dp=to_pp_dp_mp[1],
sequence_parallel=True,
gradient_checkpointing=True,
gradient_accumulation_fusion=True,
)
data_manager = get_data_manager(
from_model_name,
to_model_name,
from_pp_dp_mp,
to_pp_dp_mp,
)
data_manager.setup_process_groups()
storage_tracker = GlobalStorageTracker(dist.get_world_size())
planner = RedistribPlanner(storage_tracker)
key = "input_ids"
world_size = dist.get_world_size()
samples = []
for dp_rank in range(from_pp_dp_mp[1]):
gpu_id = data_manager.msid2mwid[
ModelShardID(
from_model_name,
dp_rank=dp_rank,
mp_rank=0,
pp_rank=from_pp_dp_mp[0] - 1,
topo=from_topo,
)
]
storage_tracker.add_data(
gpu_id,
ids=[i + dp_rank * world_size for i in range(world_size)],
key=key,
is_owner=True,
)
seqlens = torch.randint(10, 1000, size=(world_size,))
dist.all_reduce(seqlens)
input_ids = torch.randint(
0,
10000,
size=(int(sum(seqlens)),),
)
dist.all_reduce(input_ids)
s = SequenceSample.from_default(
ids=[i + dp_rank * world_size for i in range(world_size)],
seqlens=seqlens.numpy().tolist(),
data=dict(input_ids=input_ids),
)
if dist.get_rank() == 0:
for ss in s.unpack():
with open(os.path.join(tmp_path, f"{ss.ids[0]}.pkl"), "wb") as f:
pickle.dump(ss, f)
samples.append(s)
if dist.get_rank() == gpu_id:
for ss in s.unpack():
data_manager.store(ss)
dist.barrier()
all_ids = list(range(world_size * from_topo.get_dim("data")))
np.random.shuffle(all_ids)
_all_ids = [all_ids]
dist.broadcast_object_list(_all_ids, src=0)
all_ids = _all_ids[0]
dests = {}
for rank in range(to_topo.world_size()):
coord = to_topo.get_coord(rank)
dp_size = to_topo.get_dim("data")
gpu_id = data_manager.msid2mwid[
ModelShardID(
to_model_name,
dp_rank=coord.data,
mp_rank=coord.model,
pp_rank=coord.pipe,
topo=to_topo,
)
]
size_per_dp = len(all_ids) // dp_size
dests[gpu_id] = [coord.data * size_per_dp + i for i in range(size_per_dp)]
for gpu_id in range(world_size):
if gpu_id not in dests:
dests[gpu_id] = []
plan = planner.derive_plan(dests, keys=[key])
data_manager.redistribute(SequenceSample.gather(samples), plan)
dist.barrier()
for i, s in data_manager.storage.items():
with open(os.path.join(tmp_path, f"{i}.pkl"), "rb") as f:
ss = pickle.load(f)
recursive_assert_equal(ss, s)
print("success")
parallelism = [(1, 4, 2), (1, 8, 1)]
@pytest.mark.skipif(
os.cpu_count() < 32 or testing.get_free_mem_gb() < 50,
reason="The parameter reallocation test requires at least 32 CPUs and 50GB memory.",
)
@pytest.mark.parametrize("from_pp_dp_mp", [(1, 4, 2)])
@pytest.mark.parametrize("to_pp_dp_mp", [(1, 8, 1)])
@pytest.mark.distributed
def test_data_transfer(
tmp_path,
from_pp_dp_mp: Tuple,
to_pp_dp_mp: Tuple,
):
expr_name = uuid.uuid4()
trial_name = uuid.uuid4()
constants.set_force_cpu(True)
test_impl = LocalMultiProcessTest(
world_size=16,
func=_test_data_transfer,
expr_name=expr_name,
trial_name=trial_name,
timeout_secs=300,
tmp_path=tmp_path,
from_pp_dp_mp=from_pp_dp_mp,
to_pp_dp_mp=to_pp_dp_mp,
)
test_impl.launch()
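
The overlapping placement described in the comment inside get_data_manager can be made concrete with a small helper: the from_model occupies the first from_world_size GPUs and the to_model the last to_world_size GPUs, so GPUs in the overlap hold one shard of each. This is a hypothetical illustration only, not part of the test code.

def layout(world_size: int, from_world_size: int, to_world_size: int):
    placement = {}
    for gpu in range(world_size):
        models = []
        if gpu < from_world_size:
            models.append(("from_model", gpu))  # shard index equals GPU index
        if gpu >= world_size - to_world_size:
            models.append(("to_model", gpu - (world_size - to_world_size)))
        placement[gpu] = models
    return placement

# With 8 GPUs, from_world_size=6, to_world_size=4:
# GPU 0-3 hold from_model shards 0-3, GPU 4-5 hold from_model shards 4-5 plus
# to_model shards 0-1, and GPU 6-7 hold to_model shards 2-3.
print(layout(8, 6, 4))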

View File

@ -47,13 +47,11 @@ def math_dataset(request, save_path):
return dataset
# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
"dp,pp,mp",
[
(1, 1, 1),
(2, 1, 1),
(2, 1, 2),
(1, 2, 1),
(1, 1, 2),
],
@ -68,7 +66,6 @@ def test_ppo_symm(
dp,
pp,
mp,
use_v2_worker,
):
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("ppo")
@ -120,10 +117,8 @@ def test_ppo_symm(
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)
run_test_exp(exp_cfg)
# The global resharding strategy, where all MFCs
@ -138,6 +133,7 @@ def test_ppo_symm(
def test_ppo_global_reshard(
tmp_path_factory,
tokenizer,
math_dataset,
save_path,
cpu_hf_model,
mconfig,
@ -243,19 +239,17 @@ def test_ppo_global_reshard(
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg)
# Actor/critic train and ref_inf/rew_inf are on disjoint
# device meshes and executed concurrently.
@pytest.mark.parametrize("actor_gen", [(1, 2, 1)])
@pytest.mark.parametrize("critic_inf", [(1, 1, 2)])
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
@pytest.mark.parametrize("critic_inf", [(2, 1, 2)])
def test_ppo_param_realloc_sub_device_mesh(
tmp_path_factory,
tokenizer,
math_dataset,
save_path,
cpu_hf_model,
mconfig,
@ -276,7 +270,7 @@ def test_ppo_param_realloc_sub_device_mesh(
mode="local",
allocation_mode="manual",
n_nodes=1,
n_gpus_per_node=2,
n_gpus_per_node=8,
actor=ModelTrainEvalConfig(
path=str(save_path),
init_from_scratch=True,
@ -312,54 +306,54 @@ def test_ppo_param_realloc_sub_device_mesh(
),
),
actor_gen=MFCConfig(
device_mesh="NODE01:0,1,2,3",
parallel=ParallelismConfig(
data_parallel_size=actor_gen[0],
model_parallel_size=actor_gen[1],
pipeline_parallel_size=actor_gen[2],
)
),
),
actor_train=MFCConfig(
device_mesh="NODE01:0",
device_mesh="NODE01:4,5,6,7",
parallel=ParallelismConfig(
data_parallel_size=1,
data_parallel_size=4,
model_parallel_size=1,
pipeline_parallel_size=1,
),
),
critic_inf=MFCConfig(
device_mesh="NODE01:4,5,6,7",
parallel=ParallelismConfig(
data_parallel_size=critic_inf[0],
model_parallel_size=critic_inf[1],
pipeline_parallel_size=critic_inf[2],
)
),
),
rew_inf=MFCConfig(
device_mesh="NODE01:1",
device_mesh="NODE01:4,5,6,7",
parallel=ParallelismConfig(
data_parallel_size=1,
data_parallel_size=4,
model_parallel_size=1,
pipeline_parallel_size=1,
),
),
ref_inf=MFCConfig(
device_mesh="NODE01:0",
device_mesh="NODE01:4,5,6,7",
parallel=ParallelismConfig(
data_parallel_size=1,
model_parallel_size=1,
pipeline_parallel_size=1,
model_parallel_size=2,
pipeline_parallel_size=2,
),
),
critic_train=MFCConfig(
device_mesh="NODE01:1",
device_mesh="NODE01:4,5,6,7",
parallel=ParallelismConfig(
data_parallel_size=1,
data_parallel_size=2,
model_parallel_size=1,
pipeline_parallel_size=1,
pipeline_parallel_size=2,
),
),
)
exp_cfg.actor.vllm.hybrid_train = True
exp_cfg.actor.vllm.enforce_eager = True
run_test_exp(exp_cfg)
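
A quick consistency check implied by the manual allocation above: each MFC's data/model/pipeline parallel degrees must multiply to the number of GPUs in its device-mesh string. The sketch below uses hypothetical helpers (mesh_size, check_allocation) that are not part of the test suite.

def mesh_size(device_mesh: str) -> int:
    # "NODE01:4,5,6,7" -> 4 GPUs on node NODE01
    _, gpu_list = device_mesh.split(":")
    return len(gpu_list.split(","))

def check_allocation(device_mesh: str, dp: int, mp: int, pp: int) -> None:
    assert dp * mp * pp == mesh_size(device_mesh), (device_mesh, dp, mp, pp)

check_allocation("NODE01:0,1,2,3", dp=2, mp=2, pp=1)  # actor_gen
check_allocation("NODE01:4,5,6,7", dp=2, mp=1, pp=2)  # critic_inf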

View File

@ -35,11 +35,17 @@ def model_class(request):
],
)
def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp)
test_sft(
tmp_path_factory,
tokenizer,
save_path,
cpu_hf_model,
dp,
pp,
tp,
)
# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
"dp,pp,tp",
[
@ -49,9 +55,7 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
(1, 1, 2),
],
)
def test_sft(
tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker
):
def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
# Setup experiment env. Should be done before any other operations.
log_root = tmp_path_factory.mktemp("sft")
@ -83,4 +87,4 @@ def test_sft(
),
)
run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)
run_test_exp(exp_cfg)

View File

@ -18,7 +18,11 @@ def model_class(request):
return request.param
def run_model_worker(cfg, mw, barrier):
def run_model_worker(
cfg,
mw,
barrier,
):
constants.set_force_cpu(True)
# Register all datasets and models
import realhf.impl.dataset # isort: skip
@ -48,18 +52,13 @@ def run_test_exp(
exp_cfg: Experiment,
expr_name=None,
trial_name=None,
use_v2_worker: bool = False,
):
constants.set_force_cpu(True)
# Register all datasets and models
import realhf.impl.dataset # isort: skip
import realhf.impl.model # isort: skip
from realhf.api.core import system_api
if not use_v2_worker:
from realhf.system.master_worker import MasterWorker
else:
from realhf.system.v2.master_worker import MasterWorker
from realhf.system.master_worker import MasterWorker
system_api.ALL_EXPERIMENT_CLASSES = {}
register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)
@ -83,7 +82,12 @@ def run_test_exp(
testcase = testing.LocalMultiProcessTest(
world_size=len(exp_setup.model_worker),
func=[
functools.partial(run_model_worker, cfg=exp_cfg, mw=mw, barrier=barrier)
functools.partial(
run_model_worker,
cfg=exp_cfg,
mw=mw,
barrier=barrier,
)
for mw in exp_setup.model_worker
],
expr_name=expr_name or testing._DEFAULT_EXPR_NAME,