mirror of https://github.com/inclusionAI/AReaL
PullRequest: 9 Refactoring data transfer for v2 workers.
Merge branch fw/datatransfer-v2 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/9
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Squashed commits: * . * fw/fix-dataloading-not-shuffle * . * . * . * . * . * add v2 master worker * cpu test pass * ppo run * . * pass sft test * pass ppo dp test * format * fix * run * . * cleanup * . * format * run * merge and format * refactor * sft pass * . * format * format * format * . * .
This commit is contained in:
parent b3bedd7b9d
commit ca42e43638
@@ -39,7 +39,6 @@ RUN cd /vllm && \
    python3 use_existing_torch.py && \
    pip3 install -r requirements-build.txt && \
    MAX_JOBS=64 pip3 install -e . --no-build-isolation
RUN yes | pip3 uninstall uvloop
RUN pip3 install opencv-python-headless==4.5.4.58

RUN apt-get update && apt-get install -y python3.10-venv
@@ -63,14 +63,7 @@ def load_problems_with_testcase_batch(path, debug=False, test_case_batch_size=No
    return problem_map


global_problems = load_problems_with_testcase_batch(
    os.getenv(
        "REAL_CODE_METADATA_PATH",
        "/storage/datasets/codeparrot-apps-test.jsonl",
    ),
    debug=True,
    test_case_batch_size=20,
)
global_problems = None


def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_testcase=6):
@@ -81,6 +74,15 @@ def code_verify(generateds, query_ids, debug=False, timeout=20, timeout_for_test
    payload_list = []

    global global_problems
    if global_problems is None:
        global_problems = load_problems_with_testcase_batch(
            os.getenv(
                "REAL_CODE_METADATA_PATH",
                "/storage/datasets/codeparrot-apps-test.jsonl",
            ),
            debug=True,
            test_case_batch_size=20,
        )
    for idx, query_id in enumerate(query_ids):
        if query_id not in global_problems:
            payload_list.append(None)
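The two hunks above move metadata loading from import time into the verifier itself: `global_problems` starts as `None` and is populated on the first call to `code_verify`, so processes that never verify code skip the load entirely. A minimal, self-contained sketch of the lazy-initialization pattern (all names below are illustrative, not from the repository):

# Lazy module-level cache: load on first use, not at import time.
import json
import os

_METADATA = None  # filled on first call


def _load(path: str) -> dict:
    with open(path) as f:
        return {row["id"]: row for row in map(json.loads, f)}


def get_metadata() -> dict:
    global _METADATA
    if _METADATA is None:
        _METADATA = _load(os.getenv("METADATA_PATH", "/storage/datasets/meta.jsonl"))
    return _METADATA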
@@ -19,17 +19,19 @@ def loadJson(dataDir):
    return samples


id2info = loadJson(
    os.getenv(
        "REAL_MATH_MEATADATA_PATH",
        "/storage/datasets/id2info.json",
    )
)
id2info = None


def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) -> List:
    start_time = time.time()
    global id2info
    if id2info is None:
        id2info = loadJson(
            os.getenv(
                "REAL_MATH_MEATADATA_PATH",
                "/storage/datasets/id2info.json",
            )
        )
    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),
@@ -396,6 +396,7 @@ class SequenceSample:
        group_indices = datapack.ffd_allocate(
            lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
        )
        group_indices = sorted([sorted(g) for g in group_indices])

        forward_indices = datapack.flat2d(group_indices)
        sample = SequenceSample.reorder(self, forward_indices)
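For context, `datapack.ffd_allocate` packs sequences into micro-batches: each group stays within `max_tokens_per_mb` tokens and at least `min_groups` groups are opened, and sorting each group plus the group list makes the ordering deterministic across workers. A hedged sketch of a first-fit-decreasing (FFD) packer in that spirit; the repository's implementation may differ in details:

# First-fit-decreasing bin packing over sequence lengths (illustrative).
from typing import List


def ffd_allocate(lens: List[int], cap: int, min_groups: int = 1) -> List[List[int]]:
    order = sorted(range(len(lens)), key=lambda i: -lens[i])  # longest first
    groups: List[List[int]] = [[] for _ in range(min_groups)]
    loads = [0] * min_groups
    for i in order:
        # Place into the first group with room; otherwise open a new group.
        for g, load in enumerate(loads):
            if load + lens[i] <= cap:
                groups[g].append(i)
                loads[g] += lens[i]
                break
        else:
            groups.append([i])
            loads.append(lens[i])
    return groups


# e.g. ffd_allocate([7, 3, 5, 2], cap=8, min_groups=2) -> [[0], [2, 1], [3]]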
@@ -463,13 +463,12 @@ class ExperimentConfig:
                data_transfer_pairs.append((mn1, mn2))
        src_rpcs = [rpc for rpc in self.model_rpcs if rpc.is_src]
        data_src_rpc = src_rpcs[0]
        for r in src_rpcs[1:]:
        for r in src_rpcs:
            if (
                data_src_rpc.model_name,
                r.model_name,
            ) not in data_transfer_pairs:
                data_transfer_pairs.append((data_src_rpc.model_name, r.model_name))
        data_transfer_pairs += [(mn, mn) for mn in model_names]
        return data_transfer_pairs

    def _resolve_param_realloc_pairs(
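In the hunk above, the source-RPC loop now covers all src RPCs instead of skipping the first via `src_rpcs[1:]`, so the pair from `data_src_rpc` to itself is also registered; together with the `(mn, mn)` self-pairs, every model name ends up with a transfer pair for its own data, which the new gather/scatter planner relies on.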
@@ -162,7 +162,6 @@ def main_start(args, recover_count: int = 0):
        REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
        REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
        REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
        REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
    )
    for k, v in BASE_ENVIRONS.items():
        os.environ[k] = v
@@ -1,80 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import dataclasses
import sys
import threading
from asyncio.base_events import _run_until_complete_cb


@dataclasses.dataclass
class AsyncRunUntilCompleteContext:
    loop: asyncio.BaseEventLoop
    future: asyncio.Future
    new_task: bool


def setup_run_until_complete(
    loop: asyncio.BaseEventLoop,
    future: asyncio.Future,
) -> AsyncRunUntilCompleteContext:
    loop._check_closed()
    loop._check_running()

    new_task = not asyncio.futures.isfuture(future)
    future = asyncio.tasks.ensure_future(future, loop=loop)
    if new_task:
        # An exception is raised if the future didn't complete, so there
        # is no need to log the "destroy pending task" message
        future._log_destroy_pending = False

    future.add_done_callback(_run_until_complete_cb)

    # set up run forever
    loop._set_coroutine_origin_tracking(loop._debug)

    loop._old_agen_hooks = sys.get_asyncgen_hooks()
    loop._thread_id = threading.get_ident()
    sys.set_asyncgen_hooks(
        firstiter=loop._asyncgen_firstiter_hook,
        finalizer=loop._asyncgen_finalizer_hook,
    )
    asyncio.events._set_running_loop(loop)
    return AsyncRunUntilCompleteContext(loop=loop, future=future, new_task=new_task)


def teardown_run_util_complete(ctx: AsyncRunUntilCompleteContext):
    ctx.loop._stopping = False
    ctx.loop._thread_id = None
    asyncio.events._set_running_loop(None)
    ctx.loop._set_coroutine_origin_tracking(False)
    # Restore any pre-existing async generator hooks.
    if ctx.loop._old_agen_hooks is not None:
        sys.set_asyncgen_hooks(*ctx.loop._old_agen_hooks)
    ctx.loop._old_agen_hooks = None

    ctx.future.remove_done_callback(_run_until_complete_cb)

    if not ctx.future.done():
        raise RuntimeError("Event loop stopped before Future completed.")


def raise_asyncio_exception(
    ctx: AsyncRunUntilCompleteContext, raise_error: bool = True
):
    if ctx.new_task and ctx.future.done() and not ctx.future.cancelled():
        # The coroutine raised a BaseException. Consume the exception
        # to not log a warning, the caller doesn't have access to the
        # local task.
        ctx.future.exception()

    try:
        teardown_run_util_complete(ctx)
    except RuntimeError as e:
        if raise_error:
            raise e

    if raise_error:
        raise
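The deleted module above unbundled `loop.run_until_complete` into explicit setup and teardown steps (touching private loop attributes such as `_thread_id` and `_old_agen_hooks`) so a caller could pump the loop manually in between. With the v2 workers the standard call suffices; for reference, a sketch of the plain-asyncio equivalent that setup followed by teardown reproduces:

# What the removed helpers did in one call: drive a coroutine to completion.
import asyncio


async def main():
    await asyncio.sleep(0.1)


loop = asyncio.new_event_loop()
try:
    loop.run_until_complete(main())
finally:
    loop.close()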
@@ -229,6 +229,8 @@ def resolve_rpc_hooks(
class AllocationType(enum.Enum):
    DECOUPLED = 1
    GLOBAL_HYBRID = 2
    MANUAL = 3
    SEARCH = 4


@dataclasses.dataclass
@@ -244,6 +246,10 @@ class AllocationMode:

    @classmethod
    def from_str(cls, allocation_mode: str):
        if allocation_mode == "manual":
            return cls(AllocationType.MANUAL, None)
        if allocation_mode == "search":
            return cls(AllocationType.SEARCH, None)
        alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
        alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
        alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode)
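Besides the literal `manual` and `search` modes, `from_str` recognizes a plain 3D parallelism spec, a key-value hybrid spec, and a decoupled spec that carves out a separate generation backend. Judging by the extractor names, these read like `d4m2p1` (DP=4, TP=2, PP=1) for the 3D form and a backend-prefixed compound such as `vllm.d4m1p1+d2m2p1` for the decoupled form; these readings are illustrative, and the exact grammar lives in `extract_3d_alloc`, `extract_key_value_alloc`, and `extract_decoupled_alloc`.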
@@ -82,7 +82,6 @@ class PromptDataset(torch.utils.data.Dataset):
            ids=[self.ids[idx]],
            seqlens=[self.prompt_lengths[idx]],
            data=dict(packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)),
            metadata=dict(random_id=[uuid.uuid4()]),
        )

@@ -178,7 +177,6 @@ class MATHPromptDataset(torch.utils.data.Dataset):
                        [self.base_scores[idx]], dtype=torch.float32
                    ),
                ),
                metadata=dict(random_id=[uuid.uuid4()]),
            )
        else:
            return data_api.SequenceSample.from_default(
@@ -187,7 +185,6 @@ class MATHPromptDataset(torch.utils.data.Dataset):
                data=dict(
                    packed_prompts=torch.tensor(self.prompts[idx], dtype=torch.long)
                ),
                metadata=dict(random_id=[uuid.uuid4()]),
            )

    def filter(self, eval_scores: Dict[Hashable, float]):
@@ -1,344 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
import itertools
from typing import *

import torch
import torch.distributed as dist

from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, topology
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.impl.model.comm.param_realloc import pipeline_repartition_strategy

BCAST_CHUNK_SIZE_BYTES = int(4e6)


@dataclasses.dataclass(unsafe_hash=True)
class DataTransferPair:
    src: ModelName
    src_dp_rank: int
    dst: ModelName
    dst_dp_rank: int


@dataclasses.dataclass
class DataTransferInfo:
    # Groups for data transfer among model workers.
    data_transfer_groups: Dict[DataTransferPair, dist.ProcessGroup]
    data_transfer_src_ranks: Dict[DataTransferPair, int]
    data_transfer_dst_ranks: Dict[DataTransferPair, List[int]]


def setup_data_transfer(
    model_topos: Optional[Dict[str, topology.PipeModelDataParallelTopology]] = None,
    msid2mwid: Optional[Dict[ModelShardID, int]] = None,
    data_transfer_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
) -> DataTransferInfo:

    # Stores the ranks given a (model_name, dp_rank) pair.
    # These workers correspond to a complete set of model parameters sharded by TP+PP.
    mw_dp_ranks: Dict[Tuple[ModelName, int], List[int]] = {}

    # Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
    mw_dp_head_ranks: Dict[ModelName, List[int]] = {}

    if model_topos is not None:
        assert msid2mwid is not None
        for model_name, topo in model_topos.items():
            mw_dp_head_ranks[model_name] = filter_match_mwids(
                model_name,
                topo,
                msid2mwid,
                pipe=topo.get_dim("pipe") - 1,
                model=0,
            )
            dp_size = topo.get_dim("data")
            for dp_i in range(dp_size):
                mw_dp_ranks[model_name, dp_i] = filter_match_mwids(
                    model_name,
                    topo,
                    msid2mwid,
                    data=dp_i,
                )

    data_transfer_groups, data_transfer_src_ranks = {}, {}
    data_transfer_dst_ranks = {}
    if data_transfer_pairs is not None:
        for src, dst in data_transfer_pairs:
            src_topo = model_topos[src]
            dst_topo = model_topos[dst]

            # Construct all src-dst pairs, from any src dp rank to any dst dp rank.
            # Note that a dp rank corresponds to multiple parameter shards (TP+PP),
            # so each pair is a group-to-group communication.
            # Since the models in the source group have duplicate data (TP+PP),
            # we just use its "head" as the broadcast source,
            # and broadcast to all the ranks in the destination group.
            for src_dp, dst_dp in itertools.product(
                range(src_topo.get_dim("data")), range(dst_topo.get_dim("data"))
            ):
                key = DataTransferPair(
                    src=src, src_dp_rank=src_dp, dst=dst, dst_dp_rank=dst_dp
                )
                src_mw_rank = mw_dp_head_ranks[src][src_dp]
                dst_mw_ranks = mw_dp_ranks[dst, dst_dp]
                data_transfer_dst_ranks[key] = dst_mw_ranks
                # The src and dst groups can be disjoint or overlapped.
                # If they are disjoint, we need to include the src_mw_rank in the group.
                # Otherwise, we only need to include the dst_mw_ranks.
                if src_mw_rank not in dst_mw_ranks:
                    _ranks = [src_mw_rank] + dst_mw_ranks
                else:
                    _ranks = dst_mw_ranks
                data_transfer_groups[key] = topology.new_or_get_group(
                    _ranks, backend="nccl" if constants.use_cuda() else "gloo"
                )
                data_transfer_src_ranks[key] = src_mw_rank

    return DataTransferInfo(
        data_transfer_groups=data_transfer_groups,
        data_transfer_src_ranks=data_transfer_src_ranks,
        data_transfer_dst_ranks=data_transfer_dst_ranks,
    )


@dataclasses.dataclass
class DataTransferSenderStep:
    rank: int
    dst_ranks: List[int]
    group: dist.ProcessGroup
    key: str
    ids: List[int]


@dataclasses.dataclass
class DataTransferReceiverStep:
    rank: int
    src: int
    dst_ranks: List[int]
    group: dist.ProcessGroup
    key: str
    ids: List[int]


def chunked_bcast(buf, src_rank, group, chunk_size_bytes=BCAST_CHUNK_SIZE_BYTES):
    shape = buf.shape
    chunk_size = chunk_size_bytes // buf.dtype.itemsize
    buf = buf.flatten()
    for i in range(0, buf.numel(), chunk_size):
        s = slice(i, i + chunk_size)
        dist.broadcast(buf[s], src=src_rank, group=group)
    return buf.view(*shape)


def derive_data_transfer_plan(
    keys: List[str],
    global_ids: List[int],
    consumer_name: ModelName,
    consumer_mapping: Dict[int, List[int]],
    producer_names: Dict[str, ModelName],
    producer_mappings: Dict[Tuple[ModelName, str], Dict[int, List[int]]],
    data_transfer_info: DataTransferInfo,
) -> List[DataTransferReceiverStep | DataTransferSenderStep]:
    comm_plan = []

    for k in keys:
        producer_name = producer_names[k]
        producer_mapping = producer_mappings[(producer_name, k)]

        # partition mapping starts from zero, which is different from buffer indices
        repart_strat = pipeline_repartition_strategy(producer_mapping, consumer_mapping)

        for (dp_i, dp_j), comm_slots in repart_strat.items():
            if len(comm_slots) == 0:
                continue

            group_key = DataTransferPair(
                src=producer_name,
                src_dp_rank=dp_i,
                dst=consumer_name,
                dst_dp_rank=dp_j,
            )
            bcast_src = data_transfer_info.data_transfer_src_ranks[group_key]
            group = data_transfer_info.data_transfer_groups[group_key]
            dst_ranks = data_transfer_info.data_transfer_dst_ranks[group_key]

            ids = [global_ids[_i] for _i in comm_slots]

            for dst_rank in dst_ranks:
                comm_plan.append(
                    DataTransferReceiverStep(
                        rank=dst_rank,
                        src=bcast_src,
                        dst_ranks=dst_ranks,
                        group=group,
                        key=k,
                        ids=ids,
                    )
                )

            comm_plan.append(
                DataTransferSenderStep(
                    rank=bcast_src,
                    dst_ranks=dst_ranks,
                    group=group,
                    key=k,
                    ids=ids,
                )
            )

    return comm_plan


def run_data_transfer(
    comm_plan: List[DataTransferReceiverStep | DataTransferSenderStep],
    meta_samples: Dict[int, SequenceSample],
    storage: Dict[int, SequenceSample],
    sent_worker_idx_table: Dict[int, Dict[str, Set[int]]],
    received_worker_idx_table: Dict[int, Dict[str, Set[int]]],
) -> Tuple[Set[int], Set[str]]:
    device = "cuda" if constants.use_cuda() else "cpu"
    for step in comm_plan:

        if isinstance(step, DataTransferReceiverStep) and step.rank == dist.get_rank():
            ids = step.ids
            if step.src == dist.get_rank():
                # The receiver is also a sender.
                # We can directly use the data without comm.
                for _id in step.ids:
                    if storage[_id].data[step.key] is not None:
                        storage[_id].data[step.key] = (
                            storage[_id].data[step.key].to(device)
                        )
            else:
                # If we have to receive remote data, we first check whether
                # the data has been sent here in previous function calls.
                # If so, just fetch it from the cache.
                cached = all(
                    [
                        set(step.dst_ranks).issubset(
                            set(received_worker_idx_table[_id][step.key])
                        )
                        for _id in ids
                    ]
                )

                metadata_cached = all(
                    [
                        set(step.dst_ranks).issubset(
                            set(received_worker_idx_table[_id]["__metadata__"])
                        )
                        for _id in ids
                    ]
                )
                if cached:
                    pass
                else:
                    dtype = meta_samples[ids[0]].dtypes[step.key]
                    total_len = sum(
                        sum(meta_samples[_id].seqlens[step.key][0]) for _id in ids
                    )
                    trailing_shape = meta_samples[ids[0]].trailing_shapes[step.key]

                    # Receive data if it is not None.
                    if trailing_shape is not None:
                        buf = torch.zeros(
                            (total_len, *trailing_shape),
                            dtype=dtype,
                            device=constants.current_device(),
                        )
                        chunked_bcast(buf, step.src, step.group)
                    else:
                        buf = None

                    # Receive metadata if not cached.
                    if not metadata_cached:
                        metadatas = [{} for _ in step.ids]
                        dist.broadcast_object_list(
                            metadatas, src=step.src, group=step.group
                        )

                    # Mark that the data has been received.
                    for _id in ids:
                        received_worker_idx_table[_id][step.key].union(step.dst_ranks)
                        received_worker_idx_table[_id]["__metadata__"].union(
                            step.dst_ranks
                        )

                    # Split the received data and put it into the storage.
                    offset = 0
                    for _id, metadata in zip(ids, metadatas):
                        seqlens = meta_samples[_id].seqlens[step.key]
                        assert len(seqlens) == 1
                        seqlen = sum(seqlens[0])
                        if buf is not None:
                            vs = buf[offset : offset + seqlen]
                        else:
                            vs = None
                        offset = offset + seqlen
                        with SequenceSample.disable_validation():
                            s = SequenceSample(
                                keys=[step.key],
                                dtypes={step.key: vs.dtype if vs is not None else None},
                                trailing_shapes={
                                    step.key: vs.shape[1:] if vs is not None else None
                                },
                                ids=[_id],
                                seqlens={step.key: seqlens},
                                data={step.key: vs},
                                metadata=metadata,
                            )
                        if _id in storage:
                            storage[_id].update_(s)
                        else:
                            storage[_id] = s

        if isinstance(step, DataTransferSenderStep) and step.rank == dist.get_rank():
            # Similar to the receiver, we first check whether the data has been sent to all destinations.
            cached = all(
                [
                    set(step.dst_ranks).issubset(
                        set(sent_worker_idx_table[_id][step.key])
                    )
                    for _id in step.ids
                ]
            )
            metadata_cached = all(
                [
                    set(step.dst_ranks).issubset(
                        set(sent_worker_idx_table[_id]["__metadata__"])
                    )
                    for _id in step.ids
                ]
            )
            if cached:
                pass
            else:
                # If not cached, we fetch the data from the storage and send it to all destinations.
                for _id in step.ids:
                    if storage[_id].data[step.key] is not None:
                        storage[_id].data[step.key] = (
                            storage[_id].data[step.key].to(device)
                        )
                if all([storage[_id].data[step.key] is not None for _id in step.ids]):
                    vs = torch.cat(
                        [storage[_id].data[step.key] for _id in step.ids],
                        dim=0,
                    )
                    chunked_bcast(vs, step.rank, step.group)

                if not metadata_cached:
                    dist.broadcast_object_list(
                        [storage[_id].metadata for _id in step.ids],
                        src=step.rank,
                        group=step.group,
                    )

                for _id in step.ids:
                    sent_worker_idx_table[_id][step.key].union(step.dst_ranks)
                    sent_worker_idx_table[_id]["__metadata__"].union(step.dst_ranks)
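This whole module is deleted by the PR; its per-(src dp, dst dp) broadcast groups and cached send/receive tables are superseded by the `DataManager` and `RedistribPlanner` introduced below. One detail worth noting is `chunked_bcast`, which capped each broadcast message at `BCAST_CHUNK_SIZE_BYTES` (4 MB). A worked example of that arithmetic:

# Chunking arithmetic from chunked_bcast: 4 MB budget over an fp16 tensor.
import torch

chunk_size_bytes = int(4e6)                            # BCAST_CHUNK_SIZE_BYTES
buf = torch.empty(10_000_000, dtype=torch.float16)     # 20 MB payload
chunk_size = chunk_size_bytes // buf.dtype.itemsize    # 2,000,000 elements/chunk
n_chunks = -(-buf.numel() // chunk_size)               # ceil division -> 5 broadcasts
print(chunk_size, n_chunks)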
@@ -15,7 +15,6 @@ logger = logging.getLogger("system")
# NOTE: Workers are configured in the following order.
# Take special care when adding a new worker type.
WORKER_TYPES = ["model_worker", "master_worker"]
USE_V2_WORKER = os.getenv("REAL_USE_V2_WORKER", "0") == "1"


def load_worker(worker_type: str) -> Type:
@@ -26,8 +25,6 @@ def load_worker(worker_type: str) -> Type:


def worker_type_to_module(worker_type: str):
    if worker_type == "master_worker" and USE_V2_WORKER:
        return "realhf.system.v2." + worker_type
    return "realhf.system." + worker_type

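With this gate, setting `REAL_USE_V2_WORKER=1` makes `load_worker` resolve the master worker from the `realhf.system.v2` package while every other worker type keeps loading from `realhf.system`; the same variable is forwarded to launched processes through the `BASE_ENVIRONS` hunk above and the Ray `runtime_env` hunk below.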
@@ -17,32 +17,6 @@ from realhf.api.core.data_api import SequenceSample
logger = logging.getLogger("buffer")


def _extract_intervals(arr):
    if len(arr) == 0:
        return []

    # Initialize the list to hold the intervals
    intervals = []

    # Start of the first interval
    start = arr[0]

    for i in range(1, len(arr)):
        # Check if the current element is not contiguous with the previous one
        if arr[i] != arr[i - 1] + 1:
            # End of the current interval
            end = arr[i - 1]
            # Add the interval as a tuple
            intervals.append((start, end + 1))
            # Start a new interval
            start = arr[i]

    # Add the last interval
    intervals.append((start, arr[-1] + 1))

    return intervals


class BufferFull(Exception):
    pass

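For reference, the removed `_extract_intervals` helper converted a sorted index array into half-open contiguous ranges, e.g. `[0, 1, 2, 5, 6, 9]` -> `[(0, 3), (5, 7), (9, 10)]`; this PR deletes it along with the rest of the old transfer bookkeeping.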
|
@ -620,7 +620,6 @@ class RayController:
|
|||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
||||
REAL_USE_V2_WORKER=os.getenv("REAL_USE_V2_WORKER", "0"),
|
||||
)
|
||||
runtime_env = {
|
||||
"env_vars": env_vars,
|
||||
|
|
|
@@ -0,0 +1,395 @@
# Copyright 2025 Ant Group Inc.

import bisect
import dataclasses
import itertools
from collections import defaultdict
from typing import *

import numpy as np
import torch
import torch.distributed as dist

from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants
from realhf.base.topology import PipeModelDataParallelTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.system.redistributor import RedistribStep

BCAST_GROUPS = {}
GATHER_GROUPS = {}
SCATTER_GROUPS = {}


class DataManager:

    def __init__(
        self,
        model_topos: Dict[ModelName, PipeModelDataParallelTopology],
        msid2mwid: Optional[Dict[ModelShardID, int]] = None,
        data_transfer_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
    ):
        self.model_topos = model_topos
        self.msid2mwid = msid2mwid
        self.data_transfer_pairs = data_transfer_pairs

        self.storage: Dict[Hashable, SequenceSample] = {}

    def setup_process_groups(self):
        if self.msid2mwid is None or self.data_transfer_pairs is None:
            return

        model_topos = self.model_topos
        msid2mwid = self.msid2mwid
        data_transfer_pairs = self.data_transfer_pairs

        # Stores the ranks given a (model_name, dp_rank) pair.
        # These workers correspond to a complete set of model parameters sharded by TP+PP.
        mw_dp_ranks: Dict[Tuple[ModelName, int], List[int]] = {}

        mw_ranks: Dict[ModelName, List[int]] = {}

        # Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
        mw_dp_head_ranks: Dict[ModelName, List[int]] = defaultdict(list)

        assert msid2mwid is not None
        for model_name, topo in model_topos.items():
            mw_ranks[model_name] = filter_match_mwids(
                model_name,
                topo,
                msid2mwid,
            )
            mw_dp_head_ranks[model_name] = filter_match_mwids(
                model_name,
                topo,
                msid2mwid,
                pipe=topo.get_dim("pipe") - 1,
                model=0,
            )
            dp_size = topo.get_dim("data")
            for dp_i in range(dp_size):
                mw_dp_ranks[model_name, dp_i] = filter_match_mwids(
                    model_name,
                    topo,
                    msid2mwid,
                    data=dp_i,
                )

        for src, dst in data_transfer_pairs:
            src_topo = model_topos[src]
            dst_topo = model_topos[dst]

            ranks = tuple(sorted(mw_dp_head_ranks[src]))
            GATHER_GROUPS[ranks] = new_or_get_group(
                list(ranks), backend="nccl" if constants.use_cuda() else "gloo"
            )

            scatter_ranks = tuple(sorted(set([ranks[0]] + mw_ranks[dst])))
            SCATTER_GROUPS[scatter_ranks] = new_or_get_group(
                list(scatter_ranks),
                backend="nccl" if constants.use_cuda() else "gloo",
            )

            # Construct all src-dst pairs, from any src dp rank to any dst dp rank.
            # Note that a dp rank corresponds to multiple parameter shards (TP+PP),
            # so each pair is a group-to-group communication.
            # Since the models in the source group have duplicate data (TP+PP),
            # we just use its "head" as the broadcast source,
            # and broadcast to all the ranks in the destination group.
            for src_dp, dst_dp in itertools.product(
                range(src_topo.get_dim("data")), range(dst_topo.get_dim("data"))
            ):
                src_mw_rank = mw_dp_head_ranks[src][src_dp]
                dst_mw_ranks = mw_dp_ranks[dst, dst_dp]
                # The src and dst groups can be disjoint or overlapped.
                # If they are disjoint, we need to include the src_mw_rank in the group.
                # Otherwise, we only need to include the dst_mw_ranks.
                if src_mw_rank not in dst_mw_ranks:
                    _ranks = [src_mw_rank] + dst_mw_ranks
                else:
                    _ranks = dst_mw_ranks
                key = tuple(sorted(_ranks))
                BCAST_GROUPS[key] = new_or_get_group(
                    _ranks, backend="nccl" if constants.use_cuda() else "gloo"
                )

    def storage_size(self):
        return len(self.storage)

    def store(self, x: SequenceSample):
        assert len(x.ids) == 1
        assert x.ids[0] not in self.storage
        self.storage[x.ids[0]] = x

    def update(self, x: SequenceSample):
        self.storage[x.ids[0]].update_(x)

    def get(self, data_id: Hashable):
        return self.storage[data_id]

    def has_data(self, data_id: Hashable):
        return data_id in self.storage

    def remove(self, ids: List[Hashable]):
        for data_id in ids:
            if data_id in self.storage:
                del self.storage[data_id]

    def clear_data(self):
        self.storage.clear()

    def _bcast_recv(
        self,
        step: RedistribStep,
        data_infos: Dict[Hashable, SequenceSample],
    ):
        assert len(step.keys) == 1
        ids = step.ids
        key = step.keys[0]
        dtype = data_infos[ids[0]].dtypes[key]
        total_len = sum(sum(data_infos[_id].seqlens[key][0]) for _id in ids)
        trailing_shape = data_infos[ids[0]].trailing_shapes[key]

        buf = torch.zeros(
            (total_len, *trailing_shape),
            dtype=dtype,
            device=constants.current_device(),
        )

        if len(step.dsts) == 1:
            dist.recv(buf, src=step.root)
        else:
            global BCAST_GROUPS
            group = BCAST_GROUPS[tuple(sorted([step.root] + list(step.dsts)))]
            dist.broadcast(buf, src=step.root, group=group)

        # Split the received data and put it into the storage.
        offset = 0
        for _id in ids:
            seqlens = data_infos[_id].seqlens[key]
            assert len(seqlens) == 1
            seqlen = sum(seqlens[0])
            if buf is not None:
                vs = buf[offset : offset + seqlen]
            else:
                vs = None
            offset = offset + seqlen
            with SequenceSample.disable_validation():
                s = SequenceSample(
                    keys=[key],
                    dtypes={key: vs.dtype if vs is not None else None},
                    trailing_shapes={key: vs.shape[1:] if vs is not None else None},
                    ids=[_id],
                    seqlens={key: seqlens},
                    data={key: vs},
                    metadata={},
                )
            if _id in self.storage:
                self.storage[_id].update_(s)
            else:
                self.storage[_id] = s

    def _bcast_send(self, step: RedistribStep):
        ids = step.ids
        for _id in ids:
            self.storage[_id].to_device(constants.current_device())
        assert len(step.keys) == 1
        key = step.keys[0]
        vs = torch.cat(
            [self.storage[_id].data[key] for _id in ids],
            dim=0,
        )
        if len(step.dsts) == 1:
            dist.send(vs, dst=step.dsts[0])
        else:
            global BCAST_GROUPS
            group = BCAST_GROUPS[tuple(sorted([step.root] + list(step.dsts)))]
            dist.broadcast(vs, src=step.root, group=group)

    def _run_bcast(
        self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
    ):
        if dist.get_rank() in step.dsts:
            self._bcast_recv(step, data_infos=data_infos)

        if dist.get_rank() == step.root:
            self._bcast_send(step)

    def _pad_data(self, x: torch.Tensor, maxlen: int):
        assert x.dtype == torch.float32
        assert len(x.shape) == 1
        if maxlen > x.numel():
            return torch.nn.functional.pad(x, (0, maxlen - x.numel()), value=0.0)
        return x

    def _run_gather(
        self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
    ):
        if dist.get_rank() not in step.srcs:
            return

        maxlen = 0
        for ids in step.ids:
            infos = [data_infos[i] for i in ids]
            maxlen = max(
                maxlen,
                sum(
                    [
                        sum([sum(info.seqlens[key][0]) for info in infos])
                        for key in step.keys
                    ]
                ),
            )

        if dist.get_rank() == step.root:
            gather_list = [
                torch.empty(
                    maxlen, device=constants.current_device(), dtype=torch.float32
                )
                for _ in range(len(step.srcs))
            ]
        else:
            gather_list = None

        local_gather_idx = step.srcs.index(dist.get_rank())
        ids = step.ids[local_gather_idx]
        for i in ids:
            self.storage[i].to_device(constants.current_device())
        samples = [self.storage[i] for i in ids]
        data = torch.cat(
            [
                sample.data[key].float().flatten()
                for sample in samples
                for key in step.keys
            ]
        )
        data = self._pad_data(data, maxlen)

        dist.gather(
            data,
            gather_list,
            dst=step.root,
            group=GATHER_GROUPS[tuple(sorted(step.srcs))],
        )

        if dist.get_rank() != step.root:
            return

        for ids, buf in zip(step.ids, gather_list):
            offset = 0
            for i in ids:
                for key in step.keys:
                    seqlen = data_infos[i].seqlens[key][0]
                    dtype = data_infos[i].dtypes[key]
                    trailing_shape = data_infos[i].trailing_shapes[key]
                    size = int(np.prod(trailing_shape) * sum(seqlen))
                    data = buf[offset : offset + size].to(dtype)
                    offset += size
                    # with SequenceSample.disable_validation():
                    s = SequenceSample(
                        keys=[key],
                        dtypes={key: dtype},
                        trailing_shapes={key: trailing_shape},
                        ids=[i],
                        seqlens={key: [seqlen]},
                        data={key: data},
                        metadata={},
                    )
                    if i in self.storage:
                        self.storage[i].update_(s)
                    else:
                        self.storage[i] = s

    def _run_scatter(
        self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
    ):
        if dist.get_rank() != step.root and dist.get_rank() not in step.dsts:
            return

        maxlen = 0
        for ids in step.ids:
            infos = [data_infos[i] for i in ids]
            maxlen = max(
                maxlen,
                sum(
                    [
                        sum([sum(info.seqlens[key][0]) for info in infos])
                        for key in step.keys
                    ]
                ),
            )

        buf = torch.empty(
            maxlen, device=constants.current_device(), dtype=torch.float32
        )

        if dist.get_rank() == step.root:
            scatter_list = []
            for ids in step.ids:
                for i in ids:
                    self.storage[i].to_device(constants.current_device())
                samples = [self.storage[i] for i in ids]
                data = torch.cat(
                    [
                        sample.data[key].float().flatten()
                        for sample in samples
                        for key in step.keys
                    ]
                )
                scatter_list.append(data)
            maxlen = max([x.shape[0] for x in scatter_list])
            scatter_list = [self._pad_data(x, maxlen) for x in scatter_list]
            if step.root not in step.dsts:
                idx = bisect.bisect(step.dsts, step.root)
                scatter_list.insert(idx, buf)
        else:
            scatter_list = None

        key = tuple(sorted(set([step.root] + step.dsts)))
        dist.scatter(buf, scatter_list, src=step.root, group=SCATTER_GROUPS[key])

        if dist.get_rank() not in step.dsts:
            return

        local_dst_idx = step.dsts.index(dist.get_rank())
        ids = step.ids[local_dst_idx]
        offset = 0
        for i in ids:
            for key in step.keys:
                seqlen = data_infos[i].seqlens[key][0]
                dtype = data_infos[i].dtypes[key]
                trailing_shape = data_infos[i].trailing_shapes[key]
                size = int(np.prod(trailing_shape) * sum(seqlen))
                data = buf[offset : offset + size].to(dtype)
                offset += size
                # with SequenceSample.disable_validation():
                s = SequenceSample(
                    keys=[key],
                    dtypes={key: dtype},
                    trailing_shapes={key: trailing_shape},
                    ids=[i],
                    seqlens={key: [seqlen]},
                    data={key: data},
                    metadata={},
                )
                if i in self.storage:
                    self.storage[i].update_(s)
                else:
                    self.storage[i] = s

    def redistribute(
        self,
        data_info: SequenceSample,
        plan: List[RedistribStep],
    ):
        data_infos = {x.ids[0]: x for x in data_info.unpack()}

        for step in plan:
            if step.comm_type == "bcast":
                self._run_bcast(step, data_infos)
            elif step.comm_type == "gather":
                self._run_gather(step, data_infos)
            elif step.comm_type == "scatter":
                self._run_scatter(step, data_infos)
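`_run_gather` and `_run_scatter` above flatten each sample to `float32` and pad to a common `maxlen` before the collective, because `dist.gather`/`dist.scatter` require equal-sized tensors on every participant; receivers then slice the buffer back apart using the sequence lengths recorded in `data_infos`. A small self-contained illustration of that pad-then-slice round trip:

# Why padding is needed: collectives want equal shapes; lengths restore the data.
import torch

shards = [torch.arange(3.0), torch.arange(5.0)]  # uneven payloads
maxlen = max(s.numel() for s in shards)
padded = [torch.nn.functional.pad(s, (0, maxlen - s.numel())) for s in shards]
# After the collective, each receiver keeps only the first numel() elements.
restored = [p[: s.numel()] for p, s in zip(padded, shards)]
assert all(torch.equal(r, s) for r, s in zip(restored, shards))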
@@ -4,6 +4,7 @@ import random
from typing import *

import networkx as nx
from tensorboardX import SummaryWriter

from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import DataBatchMeta, SequenceSample
@@ -12,8 +13,9 @@ from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
from realhf.system.request_reply_stream import NameResolvingRequestClient
from realhf.system.v2.function_call import FunctionCall, RPCCorountineControl

logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")
@@ -29,12 +31,17 @@ class FunctionExecutor:
        model_topos: Dict[str, PipeModelDataParallelTopology],
        model_configs: Dict[str, None | ReaLModelConfig],
        ctrl: RPCCorountineControl,
        summary_writer: SummaryWriter | None,
    ):

        self.func_calls: Dict[str, FunctionCall] = {}
        self.func_calls: Dict[str, ModelFunctionCall] = {}
        self.ctrl = ctrl

        self.n_model_workers = len(set(msid2mwid.values()))
        self.msid2mwid = msid2mwid

        self.storage_tracker = GlobalStorageTracker(self.n_model_workers)
        self.redistrib_planner = RedistribPlanner(self.storage_tracker)

        self.rpcs = rpcs
        self.src_rpc = list(filter(lambda rpc: rpc.is_src, rpcs))[0]
@@ -42,7 +49,7 @@ class FunctionExecutor:

        # Create model function calls.
        for rpc in self.rpcs:
            func_call = FunctionCall(
            func_call = ModelFunctionCall(
                rpc=rpc,
                src_rpc=self.src_rpc,
                stream=stream,
@@ -51,6 +58,8 @@ class FunctionExecutor:
                model_configs=model_configs,
                ctrl=ctrl,
                buffer=buffer,
                redistrib_planner=self.redistrib_planner,
                summary_writer=summary_writer,
            )
            self.func_calls[rpc.name] = func_call

@@ -92,8 +101,6 @@ class FunctionExecutor:

        while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):

            all_data = []

            dp_idx += 1
            dp_idx %= self.src_dp_size

@@ -110,21 +117,13 @@ class FunctionExecutor:
            if x.meta_sample is None:
                continue

            # Store the owner information of the data.
            # RPCs corountines will use this information to
            # determine the src and dst of data transfer.
            for xx in x.meta_sample.unpack():
                if xx.ids[0] in received_ids:
                    raise ValueError(
                        f"Duplicate data id {xx.ids[0]}. Is the final batch? {is_final_batch}."
                    )
                    raise ValueError(f"Duplicate data id {xx.ids[0]}.")
                received_ids.add(xx.ids[0])
                for k in xx.keys:
                    self.ctrl.data_owner[(xx.ids[0], k)] = (
                        src_rpc_model_name,
                        dp_idx,
                    )
            all_data += x.meta_sample.unpack()
            all_data = x.meta_sample.unpack()

            filtered_data = []
            for xx in x.meta_sample.unpack():
@@ -139,6 +138,16 @@ class FunctionExecutor:
            # so we also need to shuffle the data to fuse different dataset splits.
            random.shuffle(all_data)

            # Update resource tracker for planning data redistribution.
            gpu_id = self.stream.route_to(f"__data{dp_idx}__")
            for k in all_data[0].keys:
                self.storage_tracker.add_data(
                    gpu_id,
                    [x.ids[0] for x in all_data],
                    k,
                    is_owner=True,
                )

            # Store into buffer!
            buffer_indices = await buffer.put_batch(all_data)
            assert len(buffer_indices) == len(all_data)
@@ -156,13 +165,23 @@ class FunctionExecutor:
        tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
            loop.create_task(self.flush_calls()),
            loop.create_task(self.load_data()),
            loop.create_task(self.finish_traverse()),
        ]

        completion_future = loop.create_task(self.finish_traverse())

        loop.run_until_complete(completion_future)

        for task in tasks:
            loop.run_until_complete(task)
        loop.run_until_complete(asyncio.gather(*tasks))

        logger.info("Execution finished!")

        self.clear_gpu_cache()

    def clear_gpu_cache(self):
        self.stream.request(
            handlers=list(range(self.n_model_workers)),
            handle_type="clear_data_cache",
            datas=[self.ctrl.ids_to_clear for _ in list(range(self.n_model_workers))],
            no_syn=True,
        )
        # Clear resource tracker as well.
        self.storage_tracker.clear_data(self.ctrl.ids_to_clear)

        self.ctrl.ids_to_clear.clear()

File diff suppressed because it is too large
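The scheduling change above first awaits the traversal-completion future, then drains the remaining worker coroutines with a single `asyncio.gather` rather than one `run_until_complete` per task, so all coroutines are awaited together and a failure in any of them surfaces from the one awaiting call instead of being deferred until its turn in the loop.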
@@ -2,6 +2,8 @@

import asyncio
import dataclasses
import itertools
import json
import os
import time
import uuid
@@ -9,6 +11,7 @@ from collections import defaultdict
from typing import Dict, Hashable, List, Set, Tuple

import wandb
from tensorboardX import SummaryWriter

import realhf.api.core.config as config_api
import realhf.api.core.data_api as data_api
@@ -16,11 +19,13 @@ import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
from realhf import ModelShardID
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
from realhf.system.redistributor import RedistribPlanner, RedistribStep

logger = logging.getLogger(__name__, "system")
blogger = logging.getLogger("benchmark")
@@ -38,10 +43,6 @@ class RPCCorountineControl:
    ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
    flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)

    data_owner: Dict[Tuple[int, str], Tuple[ModelName, int]] = dataclasses.field(
        default_factory=dict
    )

    should_save: bool = False
    should_eval: bool = False
    should_ckpt: bool = False
@@ -52,7 +53,7 @@ class RPCCorountineControl:
    hash_vals_to_ignore_in_recover: List[int] = dataclasses.field(default_factory=list)


class FunctionCall:
class ModelFunctionCall:
    def __init__(
        self,
        rpc: dfg.MFCDef,
@@ -63,16 +64,24 @@ class FunctionCall:
        model_configs: Dict[str, None | ReaLModelConfig],
        ctrl: RPCCorountineControl,
        buffer: AsyncIOSequenceBuffer,
        redistrib_planner: RedistribPlanner,
        summary_writer: SummaryWriter | None,
    ):

        self.rpc = rpc
        self.src_rpc = src_rpc
        self.stream = stream

        self.n_model_workers = len(set(msid2mwid.values()))

        self.msid2mwid = msid2mwid
        self.model_topos = model_topos
        self.model_configs = model_configs

        self.mwid2msids = defaultdict(list)
        for msid, mwid in msid2mwid.items():
            self.mwid2msids[mwid].append(msid)

        self.model_save_root = os.path.join(
            constants.MODEL_SAVE_ROOT,
            constants.experiment_name(),
@@ -80,8 +89,10 @@ class FunctionCall:
        )

        self.rpc_ctrl = ctrl

        self.buffer = buffer
        self.redistrib_planner = redistrib_planner

        self.summary_writer = summary_writer

    @property
    def dp_size(self):
@@ -172,10 +183,8 @@ class FunctionCall:

    def request(
        self,
        producer_names: Dict[str, str],
        producer_name2producer_handlers: Dict[str, List[config_pkg.ModelShardID]],
        producer_mappings: Dict[str, Dict[str, List[int]]],
        target_mapping: Dict[str, List[int]],
        data_transfer_plan: List[RedistribStep],
        partitioned_ids: List[Hashable],
        meta_sample: data_api.SequenceSample,
        handlers: List[config_pkg.ModelShardID],
    ) -> Tuple[List[uuid.UUID], List[uuid.UUID]]:
@@ -184,14 +193,12 @@ class FunctionCall:
        ctrl = self.rpc_ctrl

        dt_data = {
            "keys": rpc.input_keys,
            "target": rpc.model_name,
            "producer_names": producer_names,
            "producer_mappings": producer_mappings,
            "target_mapping": target_mapping,
            "plan": [json.dumps(dataclasses.asdict(x)) for x in data_transfer_plan],
            "partitioned_ids": partitioned_ids,
            "handle_name": rpc.interface_type.value,
            "rpc_name": rpc.name,
            "meta_sample": meta_sample,
            "partitioned_ids": partitioned_ids,
        }

        payloads = {
@@ -230,16 +237,27 @@ class FunctionCall:
        mwids = [self.msid2mwid[h] for h in handlers]
        assert len(mwids) == len(set(mwids))

        for producer_name in producer_names.values():
            for h in producer_name2producer_handlers[producer_name]:
                if self.msid2mwid[h] not in mwids:
                    payloads[h] = request_reply_stream.Payload(
                        handler=h,
                        handle_name="empty",
                        pre_hooks=["data_transfer"],
                        pre_hook_data=[dt_data],
                    )
                    mwids.append(self.msid2mwid[h])
        for step in data_transfer_plan:
            if step.root not in mwids:
                handler = self.mwid2msids[step.root][0]
                payloads[handler] = request_reply_stream.Payload(
                    handler=handler,
                    handle_name="empty",
                    pre_hooks=["data_transfer"],
                    pre_hook_data=[dt_data],
                )
                mwids.append(step.root)
            if step.comm_type == "gather":
                for src in step.srcs:
                    if src not in mwids:
                        handler = self.mwid2msids[src][0]
                        payloads[handler] = request_reply_stream.Payload(
                            handler=handler,
                            handle_name="empty",
                            pre_hooks=["data_transfer"],
                            pre_hook_data=[dt_data],
                        )
                        mwids.append(src)

        payloads, mwids = self.attach_payloads_with_hooks(
            payloads,
@@ -266,9 +284,10 @@ class FunctionCall:
        # Dispatch data to different data parallel ranks.
        if self.rpc.is_generate():
            # The workload of generation is decided by batch size, instead of the generated length.
            lens = [1 for _ in range(sample.bs)]
            samples, forward_indices, _ = sample.split_with_lengths(
                mb_spec=data_api.MicroBatchSpec(n_mbs=self.dp_size),
                lens=[1 for _ in range(sample.bs)],
                lens=lens,
            )
        else:
            samples, forward_indices, _ = sample.split(
@@ -297,26 +316,6 @@ class FunctionCall:
            for j in range(topo.world_size())
        ]

        producer_names = {}  # data key -> model name
        for k in rpc.input_keys:
            if k in rpc.data_producers:
                producer_names[k] = rpc.data_producers[k]
            else:
                producer_names[k] = self.src_rpc.model_name
        keys_to_send = defaultdict(list)  # model name -> List[keys] to send
        for k in producer_names:
            keys_to_send[producer_names[k]].append(k)

        # convert producer model name to ModelShardID
        producer_name2producer_handlers = {}
        for producer_name in keys_to_send:
            producer_name2producer_handlers[producer_name] = [
                config_pkg.ModelShardID.from_parallelism_rank(
                    producer_name, self.model_topos[producer_name], j
                )
                for j in range(self.model_topos[producer_name].world_size())
            ]

        dp_head_indices = [
            topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
            for i in range(self.dp_size)
@@ -330,40 +329,72 @@ class FunctionCall:
        buf_indices, sample, partitions = self.data_parallel_dispatch(
            buf_indices, sample
        )
        target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}

        # Set data owner of produced data by this RPC, such that downstream RPCs can know
        # where to fetch these data.
        for dp_idx, (st, ed) in enumerate(partitions):
            for i in range(st, ed):
                for k in rpc.output_keys:
                    self.rpc_ctrl.data_owner[sample.ids[i], k] = (
                        rpc.model_name,
                        dp_idx,
        # Build data destinations: GPU id -> List[data ids]
        partitioned_ids = []
        dests = {}
        for dp_rank, (st, ed) in enumerate(partitions):
            ranks = topo.filter_match(data=dp_rank)
            for rank in ranks:
                h = config_pkg.ModelShardID.from_parallelism_rank(
                    model_name=rpc.model_name, topo=topo, parallelism_rank=rank
                )
                gpu_id = self.msid2mwid[h]
                assert gpu_id not in dests
                dests[gpu_id] = sample.ids[st:ed]
            partitioned_ids.append(sample.ids[st:ed])
        for i in range(self.n_model_workers):
            if i not in dests:
                dests[i] = []

        # NOTE: The data loaded from the dataset may be unevenly distributed across DP ranks.
        # Only bcast works in this case.
        if rpc.is_src:
            pattern = "bcast"
        else:
            pattern = "gather-scatter"
        data_transfer_plan = self.redistrib_planner.derive_plan(
            dests,
            keys=rpc.input_keys,
            pattern=pattern,
        )
        blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.")

        # Update storage tracker for transferred data.
        if rpc.is_src:
            # NOTE: since the data we loaded may be unevenly distributed across DP ranks,
            # we should change the owner of the data to the src RPC.
            for i in range(topo.world_size()):
                h = ModelShardID.from_parallelism_rank(
                    model_name=rpc.model_name, topo=topo, parallelism_rank=i
                )
                is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
                gpu_id = self.msid2mwid[h]
                for key in rpc.input_keys:
                    self.redistrib_planner.storage_tracker.add_data(
                        gpu_id, partitioned_ids[h.dp_rank], key=key, is_owner=is_dp_head
                    )

        # Get the data owner of this RPC's input data.
        # We use it to determine the source of data transfer.
        producer_mappings = {}
        for k in rpc.input_keys:
            names, dp_indices = [], []
            for sample_id in sample.ids:
                owner_name, dp_idx = self.rpc_ctrl.data_owner[(sample_id, k)]
                names.append(owner_name)
                dp_indices.append(dp_idx)
            assert len(set(names)) == 1
            producer_mapping = defaultdict(list)
            for i, dp_idx in enumerate(dp_indices):
                producer_mapping[dp_idx].append(i)
            producer_mapping = {k: sorted(v) for k, v in producer_mapping.items()}
            producer_mappings[names[0], k] = producer_mapping
        else:
            for step in data_transfer_plan:
                if step.comm_type == "scatter":
                    for gpu_id, ids in zip(step.dsts, step.ids):
                        for key in step.keys:
                            self.redistrib_planner.storage_tracker.add_data(
                                gpu_id, ids, key=key, is_owner=False
                            )
                elif step.comm_type == "gather":
                    for key in step.keys:
                        self.redistrib_planner.storage_tracker.add_data(
                            step.root,
                            list(itertools.chain.from_iterable(step.ids)),
                            key=key,
                            is_owner=False,
                        )

        # send partitioned data to model workers
        req_ids, other_req_ids = self.request(
            producer_names=producer_names,
            producer_name2producer_handlers=producer_name2producer_handlers,
            producer_mappings=producer_mappings,
            target_mapping=target_mapping,
            data_transfer_plan=data_transfer_plan,
            partitioned_ids=partitioned_ids,
            meta_sample=sample,
            handlers=handlers,
        )
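The pattern choice above follows the NOTE comments: for source RPCs the freshly loaded data may sit unevenly across DP ranks, so only a broadcast from each owner works, while downstream RPCs funnel each key through a root with gather-scatter. An illustrative shape for the `dests` mapping fed to `derive_plan` (values are made up for the example):

# dests: GPU id -> the sample ids that worker's DP rank consumes.
dests = {
    0: ["seq-0", "seq-1"],
    1: ["seq-2", "seq-3"],
    2: [],  # workers that need nothing still appear with an empty list
}
# plan = self.redistrib_planner.derive_plan(dests, keys=rpc.input_keys, pattern="gather-scatter")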
@@ -384,6 +415,22 @@ class FunctionCall:
        # model function calls. The data shoulbe be amended into buffer.
        # Otherwise, it's the train statistics and should be reduced and logged.
        if isinstance(responses[-1], data_api.SequenceSample):
            # Update storage tracker for generated data.
            for dp_rank, x in enumerate(responses):
                pp_size = topo.get_dim("pipe")
                ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, model=0)
                for rank in ranks:
                    h = config_pkg.ModelShardID.from_parallelism_rank(
                        model_name=rpc.model_name, topo=topo, parallelism_rank=rank
                    )
                    gpu_id = self.msid2mwid[h]
                    for k in rpc.output_keys:
                        self.redistrib_planner.storage_tracker.add_data(
                            gpu_id,
                            x.ids,
                            key=k,
                            is_owner=True,
                        )
            res = data_api.SequenceSample.gather(responses)
        else:
            res = data_api.gather_stat(responses)
@@ -393,6 +440,11 @@ class FunctionCall:

        if isinstance(res, Dict):
            wandb.log(res, step=ctrl.step_info.global_step)
            if self.summary_writer is not None:
                for key, val in res.items():
                    self.summary_writer.add_scalar(
                        f"{key}", val, ctrl.step_info.global_step
                    )

        logger.info(
            f"Model rpc {rpc.name} finished. "
@@ -418,7 +470,6 @@ class FunctionCall:
    async def run(self):
        rpc = self.rpc
        topo = self.model_topos[rpc.model_name]
        ctrl = self.rpc_ctrl

        logger.info(
            f"Running Model RPC, interface_type=#{rpc.interface_type}# "
@@ -30,7 +30,6 @@ import torch.utils.data

import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as system_api
import realhf.impl.model.comm.data_transfer as data_transfer_comm
import realhf.impl.model.comm.global_comm as global_comm
import realhf.impl.model.comm.param_realloc as param_realloc_comm
from realhf.api.core.config import ModelName
@@ -49,11 +48,12 @@ from realhf.base.monitor import (
    cuda_tmark,
    cuda_tmarked,
    dump_tmark_db,
    gpu_utilization_monitor,
)
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.utils import cuda_graph
from realhf.system import request_reply_stream, worker_base
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import RedistribStep

# NOTE: Register all implemented datasets and models.
import realhf.api.core.data_api as data_api  # isort:skip
@@ -257,11 +257,12 @@ class ModelWorker(worker_base.Worker):
            msid2mwid=self.config.msid2mwid,
        )

        self.__data_transfer_info = data_transfer_comm.setup_data_transfer(
        self.data_manager = DataManager(
            model_topos=self.config.model_topos,
            msid2mwid=self.config.msid2mwid,
            data_transfer_pairs=self.config.data_transfer_pairs,
        )
        self.data_manager.setup_process_groups()

        self.__param_realloc_info = param_realloc_comm.setup_param_realloc(
            model_topos=self.config.model_topos,
@@ -466,17 +467,6 @@ class ModelWorker(worker_base.Worker):
        self.__reply_queue = queue.Queue(maxsize=8)
        self.__request_sample_size = dict()

        # Storing data loaded from the dataset and outputs of the
        # model function call.
        self.__data_storage: Dict[int, data_api.SequenceSample] = {}

        self.__data_sent_worker_indices: Dict[int, Dict[str, Set]] = (
            collections.defaultdict(lambda: collections.defaultdict(set))
        )
        self.__data_received_worker_indices: Dict[int, Dict[str, Set]] = (
            collections.defaultdict(lambda: collections.defaultdict(set))
        )

        self.__compute_input_queues = {
            model_name: dict(
                train_step=queue.Queue(4),
@@ -629,10 +619,10 @@ class ModelWorker(worker_base.Worker):
            # Defer data that has not been used in the previous epoch.
            data_loaded = []
            for x in cur_sample.unpack():
                if x.ids[0] in self.__data_storage:
                if self.data_manager.has_data(x.ids[0]):
                    continue
                data_loaded.append(x)
                self.__data_storage[x.ids[0]] = x
                self.data_manager.store(x)
            assert len(set([x.ids[0] for x in data_loaded])) == len(data_loaded)

            if len(data_loaded) > 0:
@@ -653,13 +643,7 @@ class ModelWorker(worker_base.Worker):
        elif request.handle_name == "clear_data_cache":
            with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
                ids = request.data
                for _id in ids:
                    if _id in self.__data_storage:
                        del self.__data_storage[_id]
                    if _id in self.__data_sent_worker_indices:
                        del self.__data_sent_worker_indices[_id]
                    if _id in self.__data_received_worker_indices:
                        del self.__data_received_worker_indices[_id]
                self.data_manager.remove(ids)
                gc.collect()
                if (
                    self.config.cuda_cache_cleanliness
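Across these model-worker hunks, the three parallel dicts (`__data_storage`, `__data_sent_worker_indices`, `__data_received_worker_indices`) collapse into the single `DataManager`: `clear_data_cache` becomes one `remove(ids)` call, and the sent/received bookkeeping disappears along with the broadcast-cache scheme it supported.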
@@ -673,7 +657,7 @@ class ModelWorker(worker_base.Worker):
                )
            logger.info(
                "Get clear_data_cache, dump cuda tmark. "
                f"Remaining data in local storage: {len(self.__data_storage)}. "
                f"Remaining data in local storage: {self.data_manager.storage_size()}. "
            )
            dump_tmark_db(self.__worker_index)
            res = request_reply_stream.NoResponse()
@@ -940,7 +924,7 @@ class ModelWorker(worker_base.Worker):
            for x in res.unpack():
                # The input data must exist in the storage, otherwise
                # the model function call will not run.
                self.__data_storage[x.ids[0]].update_(x)
                self.data_manager.update(x)

        # Only return meta data back to the master worker.
        if isinstance(res, data_api.SequenceSample):
@ -967,32 +951,18 @@ class ModelWorker(worker_base.Worker):
    @cuda_tmark("data_transfer", CUDATimeMarkType.comm)
    def __data_transfer_among_workers(self, hook_data: Dict[str, Any]):
        meta_sample = hook_data["meta_sample"]
        comm_plan = data_transfer_comm.derive_data_transfer_plan(
            keys=hook_data["keys"],
            global_ids=meta_sample.ids,
            consumer_name=hook_data["target"],
            consumer_mapping=hook_data["target_mapping"],
            producer_names=hook_data["producer_names"],
            producer_mappings=hook_data["producer_mappings"],
            data_transfer_info=self.__data_transfer_info,
        )

        data_transfer_comm.run_data_transfer(
            comm_plan=comm_plan,
            meta_samples={x.ids[0]: x for x in meta_sample.unpack()},
            storage=self.__data_storage,
            sent_worker_idx_table=self.__data_sent_worker_indices,
            received_worker_idx_table=self.__data_received_worker_indices,
        )
        plan = [RedistribStep(**json.loads(x)) for x in hook_data["plan"]]
        self.data_manager.redistribute(meta_sample, plan=plan)

        if hook_data["target"] in self.__models:
            with constants.model_scope(hook_data["target"]):
                local_ids = [
                    meta_sample.ids[i]
                    for i in hook_data["target_mapping"][self._dp_rank]
                ]
                local_ids = hook_data["partitioned_ids"][self._dp_rank]
                r = data_api.SequenceSample.gather(
                    [self.__data_storage[_id] for _id in local_ids],
                    [
                        self.data_manager.get(_id).to_device(constants.current_device())
                        for _id in local_ids
                    ],
                    keys=meta_sample.keys,
                )
                self.__compute_input_queues[hook_data["target"]][
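The master now ships the redistribution plan to model workers as JSON strings; each entry reconstructs one RedistribStep (defined in the new redistributor module below). A minimal sketch of the round-trip, with hypothetical field values:

    import dataclasses, json

    step = RedistribStep(
        comm_type="bcast", root=0, srcs=[0], dsts=[4, 5],
        ids=[[101, 102]], keys=["input_ids"],
    )
    payload = json.dumps(dataclasses.asdict(step))   # what would be placed in hook_data["plan"]
    assert RedistribStep(**json.loads(payload)) == step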
@ -1349,7 +1319,6 @@ class ModelWorker(worker_base.Worker):
        self.__models.clear()
        self.__backends.clear()
        self.__interfaces.clear()
        self.__data_storage.clear()

        # Reset model worker states.
        self.__dist_env_resolved = False

@ -0,0 +1,343 @@
# Copyright 2025 Ant Group Inc.

import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *

from realhf.base.cluster import spec as cluster_spec


class GlobalStorageTracker:
    def __init__(self, world_size: int):
        self.storages: List[Dict[Hashable, List[str]]]
        self.storages = [{} for _ in range(world_size)]
        self.data_owner: Dict[Tuple[Hashable, str], int]
        self.data_owner = {}

    def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
        for data_id in ids:
            if data_id not in self.storages[rank]:
                self.storages[rank][data_id] = [key]
            else:
                if key not in self.storages[rank][data_id]:
                    self.storages[rank][data_id].append(key)
            if is_owner:
                self.data_owner[(data_id, key)] = rank

    def clear_data(self, ids: List[Hashable]):
        for storage in self.storages:
            for i in ids:
                if i in storage:
                    storage.pop(i)
        keys = list(self.data_owner.keys())
        for i, k in keys:
            if i in ids:
                self.data_owner.pop((i, k))
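For orientation, a minimal sketch of how the tracker is driven (values hypothetical): each rank's cache is mirrored per data id and key, and ownership is recorded per (id, key) pair so the planner can pick gather sources later.

    # Rank 0 owns "input_ids" for ids 0 and 1; rank 1 holds a replica of id 0.
    tracker = GlobalStorageTracker(world_size=2)
    tracker.add_data(0, ids=[0, 1], key="input_ids", is_owner=True)
    tracker.add_data(1, ids=[0], key="input_ids", is_owner=False)
    assert tracker.data_owner[(0, "input_ids")] == 0
    tracker.clear_data([0])  # drops id 0 from every rank and from data_owner
    assert 0 not in tracker.storages[1]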
@dataclasses.dataclass
class RedistribStep:
    comm_type: str
    root: int | None
    srcs: List[int] | None
    dsts: List[int] | None
    ids: List[List[Hashable]]
    keys: List[str]

    def __repr__(self) -> str:
        if self.comm_type == "gather":
            return f"Gather {self.keys} to {self.root} from {self.srcs}."
        if self.comm_type == "scatter":
            return f"Scatter {self.keys} from {self.root} to {self.dsts}."
        if self.comm_type == "bcast":
            return f"Bcast {self.keys} from {self.root} to {self.dsts}."
        raise NotImplementedError()


class RedistribPlanner:
    def __init__(self, storage_tracker: GlobalStorageTracker):
        self.storage_tracker = storage_tracker

    def derive_plan(
        self,
        dests: Dict[int, List[Hashable]],
        keys: List[str],
        pattern: str = "gather-scatter",
    ) -> List[RedistribStep]:
        if pattern == "gather-scatter":
            return self.derive_plan_gather_scatter(dests, keys)
        elif pattern == "bcast":
            return self.derive_plan_bcast(dests, keys)
        raise NotImplementedError(f"Unknown data redistribution pattern: {pattern}")

    def derive_plan_gather_scatter(
        self, dests: Dict[int, List[Hashable]], keys: List[str]
    ) -> List[RedistribStep]:
        self.dests = dests

        all_data_ids = set()
        for all_samples in dests.values():
            for data_id in all_samples:
                all_data_ids.add(data_id)

        transfer_plan = []
        for key in keys:
            owners = sorted(
                list(
                    set(
                        [
                            self.storage_tracker.data_owner[(i, key)]
                            for i in all_data_ids
                        ]
                    )
                )
            )
            gather_ids = []
            for owner in owners:
                this_owner_ids = []
                for i in all_data_ids:
                    if (
                        i in self.storage_tracker.storages[owner]
                        and key in self.storage_tracker.storages[owner][i]
                    ):
                        this_owner_ids.append(i)
                gather_ids.append(sorted(this_owner_ids))
            gather_step = RedistribStep(
                comm_type="gather",
                root=owners[0],
                srcs=owners,
                dsts=None,
                ids=gather_ids,
                keys=[key],
            )

            scatter_dsts = sorted([i for i in dests if dests[i]])
            scatter_ids = [sorted(dests[i]) for i in scatter_dsts]
            scatter_step = RedistribStep(
                comm_type="scatter",
                root=owners[0],
                dsts=scatter_dsts,
                srcs=None,
                ids=scatter_ids,
                keys=[key],
            )
            transfer_plan += [gather_step, scatter_step]

        # Prune the plan.
        pop_indices = []
        for idx, step in enumerate(transfer_plan):
            # 1. Omit the gather step if the data has already been gathered before.
            if step.comm_type == "gather":
                all_gather_ids = list(itertools.chain.from_iterable(step.ids))
                key = step.keys[0]
                if any(
                    i not in self.storage_tracker.storages[step.root]
                    for i in all_gather_ids
                ):
                    continue
                if any(
                    key not in self.storage_tracker.storages[step.root][i]
                    for i in all_gather_ids
                ):
                    continue
                pop_indices.append(idx)
            # 2. Omit the gather + scatter steps if the data already exists on all dst GPUs.
            if step.comm_type == "scatter":
                key = step.keys[0]
                all_exists = True
                for dst, ids in zip(step.dsts, step.ids):
                    if any(i not in self.storage_tracker.storages[dst] for i in ids):
                        all_exists = False
                        break
                    if any(
                        key not in self.storage_tracker.storages[dst][i] for i in ids
                    ):
                        all_exists = False
                        break
                if all_exists:
                    pop_indices.append(idx)
                    pop_indices.append(idx - 1)
        for pop_idx in reversed(sorted(set(pop_indices))):
            transfer_plan.pop(pop_idx)

        # Merge the gather/scatter steps of different keys.
        gather_plan = {}
        scatter_plan = {}
        for step in transfer_plan:
            if step.comm_type == "gather":
                plan_key = (
                    step.root,
                    tuple(sorted(step.srcs)),
                    tuple([tuple(sorted(ids)) for ids in step.ids]),
                )
                if plan_key not in gather_plan:
                    gather_plan[plan_key] = step
                else:
                    assert all(
                        k not in gather_plan[plan_key].keys for k in step.keys
                    ), (
                        gather_plan[plan_key],
                        step,
                        plan_key,
                    )
                    gather_plan[plan_key].keys += step.keys
            if step.comm_type == "scatter":
                plan_key = (
                    step.root,
                    tuple(sorted(step.dsts)),
                    tuple([tuple(sorted(ids)) for ids in step.ids]),
                )
                if plan_key not in scatter_plan:
                    scatter_plan[plan_key] = step
                else:
                    assert all(
                        k not in scatter_plan[plan_key].keys for k in step.keys
                    ), (
                        scatter_plan[plan_key],
                        step,
                        plan_key,
                    )
                    scatter_plan[plan_key].keys += step.keys

        # Prioritize gathers over scatters.
        return list(gather_plan.values()) + list(scatter_plan.values())
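A minimal sketch of the gather-scatter pattern (the module path is taken from the new test file below; all ids and keys are hypothetical): two owner GPUs each hold half of the ids, the first owner gathers everything, then scatters the requested shards to the two consumers.

    from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner

    tracker = GlobalStorageTracker(world_size=4)
    tracker.add_data(0, ids=[0, 1], key="input_ids", is_owner=True)
    tracker.add_data(1, ids=[2, 3], key="input_ids", is_owner=True)

    planner = RedistribPlanner(tracker)
    plan = planner.derive_plan(
        dests={0: [], 1: [], 2: [0, 1], 3: [2, 3]},
        keys=["input_ids"],
        pattern="gather-scatter",
    )
    for step in plan:
        print(step)
    # Gather ['input_ids'] to 0 from [0, 1].
    # Scatter ['input_ids'] from 0 to [2, 3].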
    def derive_plan_bcast(
        self, dests: Dict[int, List[Hashable]], keys: List[str]
    ) -> List[RedistribStep]:
        assert isinstance(keys, list), type(keys)
        self.dests = dests

        # Get all required data IDs.
        all_data_ids = set()
        for all_samples in self.dests.values():
            for data_id in all_samples:
                all_data_ids.add(data_id)

        # The producers for each required data ID.
        id2gpu_src = {}
        for data_id in all_data_ids:
            for key in keys:
                id2gpu_src[(data_id, key)] = []
                for gpu_id, ids2keys in enumerate(self.storage_tracker.storages):
                    if data_id in ids2keys and key in ids2keys[data_id]:
                        id2gpu_src[(data_id, key)].append(gpu_id)

        # The consumers for each required data ID.
        id2gpu_dst = {}
        for data_id in all_data_ids:
            for key in keys:
                id2gpu_dst[(data_id, key)] = []
                for gpu_id, ids in self.dests.items():
                    if data_id in ids:
                        id2gpu_dst[(data_id, key)].append(gpu_id)

        self.transfer_plan = {}

        for data_id, key in itertools.product(all_data_ids, keys):
            source_gpus = id2gpu_src[(data_id, key)]
            target_gpus = id2gpu_dst[(data_id, key)]

            assert len(source_gpus) > 0, (data_id, key, id2gpu_src, id2gpu_dst)

            # Omit the transfer if the data already exists on the target GPU.
            target_gpus = [gpu for gpu in target_gpus if gpu not in source_gpus]
            if not target_gpus:
                continue

            # Find the "nearest" GPU for data transfer.
            best_src = self._select_best_bcast_source(source_gpus, target_gpus)

            self.transfer_plan[(data_id, key)] = {"src": best_src, "dsts": target_gpus}

        return self._group_bcast_transfers()

    def _on_same_node(self, i, j) -> bool:
        return (i // cluster_spec.n_gpus_per_node) == (
            j // cluster_spec.n_gpus_per_node
        )

    def _select_best_bcast_source(self, source_gpus, target_gpus):
        same_node_counts = {}
        for src in source_gpus:
            same_node_count = sum(
                1 for dst in target_gpus if self._on_same_node(src, dst)
            )
            same_node_counts[src] = same_node_count

        # Find the sources that maximize locality.
        max_same_node = max(same_node_counts.values())
        best_sources = [
            src for src, count in same_node_counts.items() if count == max_same_node
        ]

        # Among those, pick the source with the smallest workload so far.
        src_load = defaultdict(int)
        for plan in self.transfer_plan.values():
            src_gpu = plan["src"]
            src_load[src_gpu] += len(plan["dsts"])
        return min(best_sources, key=lambda src: src_load[src])

    def _group_bcast_transfers(self) -> List[RedistribStep]:
        # Group data ids that should be transferred from "src" to "dsts".
        src_to_transfers = defaultdict(lambda: defaultdict(list))
        for (data_id, key), plan in self.transfer_plan.items():
            src_to_transfers[(plan["src"], key)][tuple(sorted(plan["dsts"]))].append(
                data_id
            )

        stages = []
        while any(src_to_transfers.values()):
            stage = []
            used_dsts = set()
            used_srcs = set()

            for (src, key), transfers in list(src_to_transfers.items()):
                if src in used_srcs:
                    continue
                if not transfers:
                    continue

                # Find a transfer that can be executed concurrently.
                pop_key = None
                for i, dsts in enumerate(transfers):
                    if not any(dst in used_dsts for dst in dsts):
                        pop_key = dsts
                        break

                if pop_key is not None:
                    data_ids = transfers.pop(pop_key)
                    stage.append(
                        RedistribStep(
                            comm_type="bcast",
                            root=src,
                            srcs=[src],
                            keys=[key],
                            dsts=pop_key,
                            ids=data_ids,
                        )
                    )
                    used_dsts.update(dsts)
                    used_srcs.add(src)

            if stage:
                stages += stage
            else:
                for (src, key), transfers in list(src_to_transfers.items()):
                    if transfers:
                        # Pop an arbitrary pending transfer to guarantee progress.
                        dsts, data_ids = transfers.popitem()
                        stages.append(
                            RedistribStep(
                                comm_type="bcast",
                                srcs=[src],
                                root=src,
                                dsts=dsts,
                                ids=data_ids,
                                keys=[key],
                            )
                        )
                        break

        return stages
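A minimal sketch of the broadcast pattern (values hypothetical; note that _on_same_node reads the global cluster spec, so this assumes it has been initialized): GPU 0 owns one sample that GPUs 1 and 2 both need, so the planner emits a single one-stage bcast covering both consumers.

    tracker = GlobalStorageTracker(world_size=4)
    tracker.add_data(0, ids=[0], key="input_ids", is_owner=True)
    planner = RedistribPlanner(tracker)
    plan = planner.derive_plan(
        dests={0: [], 1: [0], 2: [0], 3: []},
        keys=["input_ids"],
        pattern="bcast",
    )
    print(plan)  # [Bcast ['input_ids'] from 0 to (1, 2).]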
@ -138,6 +138,9 @@ class NameResolvingRequestClient:
            f"subscribers: {name_resolve.get_subtree(names.request_reply_stream(experiment_name, trial_name, PUBSUB_BARRIER_NAME))}."
        )

    def route_to(self, handler) -> int:
        return self._handler_routing[handler]

    def close(self):
        self.recv_socket.close()
        for send_socket in self.send_sockets:

@ -1,501 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import asyncio
import copy
import gc
import os
import time
from typing import Dict

import colorama
import networkx as nx
import numpy as np
import uvloop
import wandb

import realhf.api.core.dfg as dfg
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_pkg
import realhf.base.recover as recover
import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, name_resolve, names, timeutil, topology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.v2.function_call import RPCCorountineControl
from realhf.system.v2.function_executor import FunctionExecutor

logger = logging.getLogger("master worker", "system")
blogger = logging.getLogger("benchmark")

uvloop.install()


class MasterWorker(worker_base.Worker):
    global_exp_tik = time.perf_counter()

    def _configure(self, config: config_pkg.MasterWorker):
        self.config = config

        self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
            config.model_topos
        )

        # Build execution graph and initialize concurrency utilities.
        self.__model_rpcs = config.model_rpcs

        # Sort all MFCs in topological order and
        # calculate the width of each level.
        # These numbers determine when to flush MFC requests.
        self.__topo_widths = []
        for generation in nx.topological_generations(self.__model_rpcs[0]._G):
            self.__topo_widths.append(len(generation))
        logger.info("Topological widths: " + str(self.__topo_widths))

        self.__rpc_srcs = list(filter(lambda rpc: rpc.is_src, self.__model_rpcs))
        self.__rpc_dsts = list(filter(lambda rpc: rpc.is_dst, self.__model_rpcs))

        # Save and eval control.
        self.__total_train_epochs = config.exp_ctrl.total_train_epochs
        self.__save_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.exp_ctrl.save_freq_epochs,
            freq_step=config.exp_ctrl.save_freq_steps,
            freq_sec=config.exp_ctrl.save_freq_secs,
        )
        if (
            config.exp_ctrl.ckpt_freq_epochs is None
            and config.exp_ctrl.ckpt_freq_steps is None
            and config.exp_ctrl.ckpt_freq_secs is None
        ):
            self.__ckpt_ctl = self.__save_ctl
        else:
            self.__ckpt_ctl = timeutil.EpochStepTimeFreqCtl(
                freq_epoch=config.exp_ctrl.ckpt_freq_epochs,
                freq_step=config.exp_ctrl.ckpt_freq_steps,
                freq_sec=config.exp_ctrl.ckpt_freq_secs,
            )
        self.__eval_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.exp_ctrl.eval_freq_epochs,
            freq_step=config.exp_ctrl.eval_freq_steps,
            freq_sec=config.exp_ctrl.eval_freq_secs,
        )

        self.MODEL_SAVE_ROOT = os.path.join(
            constants.MODEL_SAVE_ROOT,
            config.worker_info.experiment_name,
            config.worker_info.trial_name,
        )
        os.makedirs(self.MODEL_SAVE_ROOT, exist_ok=True)

        self.__initialized = False
        self.__recover_run, self.__recover_info = recover.load_recover_info()
        if self.__recover_info is not None:
            logger.info(
                f"Loaded recover info: recover_start={self.__recover_info.recover_start}, "
                f"last_step_info={self.__recover_info.last_step_info}."
            )
            logger.info(
                f"Number of used data in recover info: {len(self.__recover_info.hash_vals_to_ignore)}. "
                f"The previous experiment probably ran for {len(self.__recover_info.hash_vals_to_ignore) // self.__rpc_srcs[0].n_seqs} steps in the epoch."
            )

        # Create coroutine control objects for running the dataflow graph.
        self.__rpc_ctrl = RPCCorountineControl(
            train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
            topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
            # NOTE: We should accumulate the used data hashes in the same epoch
            # to prevent loading data used before.
            used_hash_vals_this_epoch=(
                copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
                if self.__recover_run
                else list()
            ),
            hash_vals_to_ignore_in_recover=(
                copy.deepcopy(self.__recover_info.hash_vals_to_ignore)
                if self.__recover_run
                else list()
            ),
        )

        if self.__recover_run:
            self.__rpc_ctrl.step_info = copy.deepcopy(self.__recover_info.recover_start)

            self.__eval_ctl.load_state_dict(self.__recover_info.eval_ctl_info)
            self.__save_ctl.load_state_dict(self.__recover_info.save_ctl_info)
            self.__ckpt_ctl.load_state_dict(self.__recover_info.ckpt_ctl_info)

            logger.info(
                f"Recovering from previous run. "
                f"Epoch: {self.__rpc_ctrl.step_info.epoch + 1}, "
                f"Epoch Step: {self.__rpc_ctrl.step_info.epoch_step + 1} "
                f"Global Step: {self.__rpc_ctrl.step_info.global_step + 1}."
            )

        # For benchmarking.
        self.e2e_time_history = []
        self.__benchmark_steps = config.exp_ctrl.benchmark_steps

        return config.worker_info

    def initialize_models(self):
        # Initialize model backends.
        model_names = list(self.__model_topos.keys())
        self.logger.info(f"Initialize model backends with order: {model_names}.")
        train_rpcs = list(
            filter(
                lambda rpc: rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP,
                self.__model_rpcs,
            )
        )
        assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
        if len(train_rpcs) > 0:
            ft_spec = model_api.FinetuneSpec(
                total_train_epochs=self.config.exp_ctrl.total_train_epochs,
                dataset_size=self._dataset_size,
                train_batch_size=train_rpcs[0].n_seqs,
            )
        else:
            ft_spec = model_api.FinetuneSpec(
                total_train_epochs=self.config.exp_ctrl.total_train_epochs,
                dataset_size=self._dataset_size,
                train_batch_size=self.__src_rpc.n_seqs,
            )
        _initialized_roles = []
        for model_name in model_names:
            topo = self.config.model_topos[model_name]
            # Build FinetuneSpec, which is required to initialize backends.
            _handlers = [
                config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
                for j in range(topo.world_size())
            ]

            init_payloads = [
                request_reply_stream.Payload(
                    handler=_h,
                    handle_name="initialize",
                    data=ft_spec,
                )
                for _h in _handlers
            ]

            # Send initialization requests, then immediately flush them.
            self.__stream.request(
                payloads=init_payloads,
            )
            self.__stream.request(
                handlers=_handlers,
                handle_type="flush",
                no_syn=True,
            )

            _initialized_roles.append(model_name.role)

        self._ft_spec = ft_spec
        logger.info("Initializations of models and backends complete.")

    def get_dataset_model_info(self):
        src_rpc = self.__rpc_srcs[0]
        src_rpc_topo = self.config.model_topos[src_rpc.model_name]
        src_rpc_dp_size = src_rpc_topo.get_dim("data")

        # Request the training specification from data workers.
        all_data = sum(
            self.__stream.call(
                handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
                datas=[None for i in range(src_rpc_dp_size)],
                handle_type="spec",
            ),
            [],
        )

        # NOTE: For dynamic datasets, we still count epochs according to the initial number of data,
        # such that the learning rate decay is not affected.
        seqlens = [max(sum(v[0]) for v in x.seqlens.values()) for x in all_data]
        self._dataset_size = len(all_data)
        self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
        self._avg_tokens_per_batch = sum(seqlens) / self._steps_per_epoch
        self._dataset_ids = [copy.deepcopy(x.ids[0]) for x in all_data]

        # Request model configs from model workers.
        # Return None if the model is not a ReaLModel.
        self.__model_configs: Dict[ModelName, None | ReaLModelConfig] = {}
        for model_name, topo in self.config.model_topos.items():
            h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, 0)
            self.__model_configs[model_name] = self.__stream.call(
                handlers=[h],
                datas=[None],
                handle_type="model_config",
            )[0]

    def __lazy_init(self):
        # Set up streams.
        handler_routing = copy.deepcopy(self.config.msid2mwid)
        src_rpc = self.__rpc_srcs[0]
        src_rpc_topo = self.config.model_topos[src_rpc.model_name]
        src_rpc_dp_size = src_rpc_topo.get_dim("data")
        src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
        for i in range(src_rpc_dp_size):
            rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
            handler_routing[f"__data{i}__"] = self.config.msid2mwid[
                config_pkg.ModelShardID.from_parallelism_rank(
                    model_name=src_rpc.model_name,
                    topo=src_rpc_topo,
                    parallelism_rank=rank,
                )
            ]
        handler_routing.update({i: i for i in range(self.config.n_model_workers)})
        self.__stream = request_reply_stream.make_master_stream(
            self.config.worker_info,
            n_subscribers=self.config.n_model_workers,
            handler_routing=handler_routing,
        )
        self.__stream: request_reply_stream.NameResolvingRequestClient

        self.__src_rpc = src_rpc = [
            rpc for rpc in self.config.model_rpcs if rpc.is_src
        ][0]

        self.get_dataset_model_info()

        self.initialize_models()

        self.__seqbuffer = AsyncIOSequenceBuffer(
            self.__model_rpcs,
            max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
        )

        # Create coroutines for model RPCs.
        logger.info(f"Creating asyncio coroutines...")
        self.func_executor = FunctionExecutor(
            rpcs=self.__model_rpcs,
            msid2mwid=self.config.msid2mwid,
            stream=self.__stream,
            buffer=self.__seqbuffer,
            model_topos=self.__model_topos,
            model_configs=self.__model_configs,
            ctrl=self.__rpc_ctrl,
        )
        logger.info(f"Coroutines created. The master worker is ready to run.")

        # wandb init, connect to the remote wandb host
        wandb.login()
        wandb.init(
            mode=self.wandb_config.mode,
            entity=self.wandb_config.entity,
            project=self.wandb_config.project or constants.experiment_name(),
            name=self.wandb_config.name or constants.trial_name(),
            job_type=self.wandb_config.job_type,
            group=self.wandb_config.group,
            notes=self.wandb_config.notes,
            tags=self.wandb_config.tags,
            config=self.wandb_config.config,
            dir=os.path.join(
                constants.LOG_ROOT, constants.experiment_name(), constants.trial_name()
            ),
            force=True,
            resume="allow",
            settings=wandb.Settings(start_method="fork"),
        )

        self.__initialized = True
        self._train_start_time = time.perf_counter()

        self.__last_step_info = recover.StepInfo(
            epoch=-1,
            epoch_step=-1,
            global_step=-1,
        )

    def _poll(self):
        is_new_epoch = False

        if not self.__initialized:
            self.__lazy_init()

        # Main execution steps. The graph runs under the hood in RPC & stream threads.
        # Wait for the traversal of the execution graph to finish.
        execution_start = time.perf_counter()
        if self.__rpc_ctrl.ids_to_clear:
            # Send clear-cache requests to model workers,
            # clearing the data used in the last step.
            self._clear_gpu_cache()

        is_new_epoch = self._ft_spec.is_new_epoch(self.__rpc_ctrl.step_info)
        is_epoch_last_step = self._ft_spec.is_epoch_last_step(self.__rpc_ctrl.step_info)

        # Check whether we should evaluate or save models.
        self.__rpc_ctrl.should_eval = self.__eval_ctl.check(
            epochs=int(is_epoch_last_step), steps=1
        )
        self.__rpc_ctrl.should_save = self.__save_ctl.check(
            epochs=int(is_epoch_last_step), steps=1
        )
        self.__rpc_ctrl.should_ckpt = self.__ckpt_ctl.check(
            epochs=int(is_epoch_last_step), steps=1
        )

        # Traverse the dataflow graph once.
        self.func_executor.execute_step()

        # Post-process.
        if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
            self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)

        self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)

        if is_epoch_last_step:
            self.__rpc_ctrl.used_hash_vals_this_epoch = (
                self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
            )

        if is_new_epoch:
            self.__rpc_ctrl.step_info.epoch += 1
            self.__rpc_ctrl.step_info.epoch_step = 0

        # Logging.
        time_since_configure = time.perf_counter() - self._train_start_time
        e2e_time = time.perf_counter() - execution_start
        self.e2e_time_history.append(e2e_time)

        self._log_training_stats(e2e_time, time_since_configure)

        # Update counters.
        self.__rpc_ctrl.step_info.epoch_step += 1
        self.__rpc_ctrl.step_info.global_step += 1

        if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
            self.__recover_save()

        # Pause the worker if the experiment or a system-wide benchmark completes.
        if (
            self.__benchmark_steps is not None
            and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
        ) or (
            self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
            >= self.__total_train_epochs * self._dataset_size
        ):
            # We don't know whether it is the last step of the current epoch,
            # so we exit at the first step of the next epoch.
            if self.__benchmark_steps is not None:
                logger.info(
                    f"Finished benchmark {self.__benchmark_steps}. "
                    f"Time consumption of this setup: {time_since_configure:.3f}"
                )
                logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
            return self.experiment_complete_exit()

        return worker_base.PollResult(sample_count=1, batch_count=1)
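A worked check of the completion condition in _poll above (all numbers hypothetical):

    global_step, n_seqs = 250, 256               # finished steps, sequences per step
    total_train_epochs, dataset_size = 2, 32000
    done = global_step * n_seqs >= total_train_epochs * dataset_size
    print(done)  # 250 * 256 = 64000 >= 2 * 32000 -> True; exit on the next poll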
    def _log_training_stats(self, e2e_time: float, time_since_configure: float):
        # Calculate FLOPs.
        #########################################
        if not all(
            isinstance(v, ReaLModelConfig) for v in self.__model_configs.values()
        ):
            logger.warning(
                f"Not all models are ReaLModels. Unable to calculate FLOP/s."
            )
            flops = None
            tflops_per_gpu = float("inf")
        else:
            flops = self.__rpc_ctrl.flops_counter.get_flops()
            tflops = flops / (e2e_time * (10**12))
            tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12))
            self.__rpc_ctrl.flops_counter.clear()
        #########################################

        epoch = self.__rpc_ctrl.step_info.epoch + 1
        epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
        global_step = self.__rpc_ctrl.step_info.global_step + 1
        s = f"Epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
        s += f"step {epoch_step}/{self._steps_per_epoch} "
        s += f"(global step {global_step}) finishes. "
        s += f"Average #tokens per batch is {self._avg_tokens_per_batch:.0f}. "
        s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
        s += f"Total time consumption: {time_since_configure:.3f}s. "
        if len(self.e2e_time_history) > 2:
            remaining_steps = self._steps_per_epoch - epoch_step
            remaining_epochs = self.__total_train_epochs - epoch
            avg_t = np.mean(self.e2e_time_history[2:])
            remain_t = avg_t * remaining_steps
            remain_t += avg_t * self._steps_per_epoch * remaining_epochs
            s += f"Estimated remaining time: {remain_t:.3f}s. "
        if flops is not None:
            s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
        logger.info(s)
        logger.info(
            f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s"
        )
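A worked example of the TFLOP/s numbers logged above (inputs hypothetical):

    flops = 3.2e15            # FLOPs accumulated by flops_counter for one step
    e2e_time = 10.0           # seconds for the step
    n_model_workers = 8
    tflops = flops / (e2e_time * 10**12)                            # 320.0 total TFLOP/s
    tflops_per_gpu = flops / (e2e_time * n_model_workers * 10**12)  # 40.0 per GPU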
    def _clear_gpu_cache(self):
        self.__stream.request(
            handlers=list(range(self.config.n_model_workers)),
            handle_type="clear_data_cache",
            datas=[
                self.__rpc_ctrl.ids_to_clear
                for _ in list(range(self.config.n_model_workers))
            ],
            no_syn=True,
        )
        self.__rpc_ctrl.ids_to_clear.clear()

    def experiment_complete_exit(self):
        logger.info(
            colorama.Style.RESET_ALL
            + colorama.Fore.YELLOW
            + colorama.Style.BRIGHT
            + "\033[1m"
            + "Experiment Completes! Yeah!!!!!!!!"
            + colorama.Style.RESET_ALL
        )

        # Send requests to pause model workers.
        # Model workers will not respond to this message.
        self.__stream.request(
            handlers=list(range(self.config.n_model_workers)),
            handle_type="reset",
            datas=[None for _ in list(range(self.config.n_model_workers))],
        )
        self.__stream.close()
        constants.reset_run()
        # Reset names used for distributed training.
        # The next round of training will set up a new distributed environment.
        name_resolve.clear_subtree(
            names.distributed_root(constants.experiment_name(), constants.trial_name())
        )
        name_resolve.clear_subtree(
            names.request_reply_stream_root(
                constants.experiment_name(), constants.trial_name()
            )
        )

        wandb.finish()
        gc.collect()
        self.__initialized = False
        self.pause()
        return worker_base.PollResult(0, 0)

    def __recover_save(self):
        # Save step info for recovery.
        if os.getenv("REAL_SAVE_RECOVER_STATES", "0") != "1":
            return
        this_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
        recover_info = recover.RecoverInfo(
            recover_start=this_step_info,
            last_step_info=self.__last_step_info,
            save_ctl_info=self.__save_ctl.state_dict(),
            ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
            eval_ctl_info=self.__eval_ctl.state_dict(),
            hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
        )

        recover.dump_recover_info(recover_info)
        logger.info("Dumped recover info to file.")
        logger.info(f"Will recover from: {recover_info.recover_start}")
        logger.info(
            f"Number of data used in this epoch: {len(recover_info.hash_vals_to_ignore)}"
        )

@ -0,0 +1,288 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
import os
import pathlib
import pickle
import uuid
from typing import *

import numpy as np
import pytest
import torch
import torch.distributed as dist

from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, testing, topology
from realhf.base.testing import (
    LocalMultiProcessTest,
    PipeModelDataParallelTopology,
    init_global_constants,
)
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner


def get_data_manager(
    from_model_name,
    to_model_name,
    from_pp_dp_mp,
    to_pp_dp_mp,
):

    from_num_pp, from_num_dp, from_num_mp = from_pp_dp_mp
    to_num_pp, to_num_dp, to_num_mp = to_pp_dp_mp

    from_world_size = from_num_dp * from_num_mp * from_num_pp
    to_world_size = to_num_dp * to_num_mp * to_num_pp

    from_topo = topology.PipeModelDataParallelTopology(
        num_dp=from_num_dp,
        num_mp=from_num_mp,
        num_pp=from_num_pp,
        sequence_parallel=False,
        gradient_checkpointing=False,
        max_prompt_len=None,
        gradient_accumulation_fusion=False,
    )
    to_topo = topology.PipeModelDataParallelTopology(
        num_dp=to_num_dp,
        num_mp=to_num_mp,
        num_pp=to_num_pp,
        sequence_parallel=False,
        gradient_checkpointing=False,
        max_prompt_len=None,
        gradient_accumulation_fusion=False,
    )

    model_topos = {from_model_name: from_topo, to_model_name: to_topo}

    msid2mwid = {}
    for i in range(dist.get_world_size()):
        # We assume the `from_model` occupies the first several GPUs,
        # while the `to_model` occupies the last several GPUs.
        # For example, when the world size of `from_model` is 6 and
        # the world size of `to_model` is 4, the GPU layout is:
        # GPU 0-3: from_model (shard 0-3)
        # GPU 4-5: from_model (shard 4-5) + to_model (shard 0-1)
        # GPU 6-7: to_model (shard 2-3)
        _model_names = []
        if i < from_world_size:
            _model_names.append(from_model_name)
        if i >= dist.get_world_size() - to_world_size:
            _model_names.append(to_model_name)
        for _model_name in _model_names:
            if _model_name == from_model_name:
                coord = model_topos[_model_name].get_coord(i)
            else:
                coord = model_topos[_model_name].get_coord(
                    i + to_world_size - dist.get_world_size()
                )
            k = ModelShardID(
                _model_name,
                dp_rank=coord.data,
                mp_rank=coord.model,
                pp_rank=coord.pipe,
                topo=model_topos[_model_name],
            )
            msid2mwid[k] = i
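The overlapping placement described in the comment above can be checked directly; this sketch reproduces the world_size=8, from=6, to=4 example from the comment:

    world_size, from_ws, to_ws = 8, 6, 4
    for i in range(world_size):
        placed = []
        if i < from_ws:
            placed.append("from_model")
        if i >= world_size - to_ws:
            placed.append("to_model")
        print(i, placed)
    # GPUs 0-3 hold from_model only, 4-5 hold both, 6-7 hold to_model only.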
    init_global_constants(
        num_dp=from_num_dp,
        num_mp=from_num_mp,
        num_pp=from_num_pp,
        topo=from_topo,
        model_name=from_model_name,
        sequence_parallel=False,
        msid2mwid=msid2mwid,
    )

    init_global_constants(
        num_dp=to_num_dp,
        num_mp=to_num_mp,
        num_pp=to_num_pp,
        model_name=to_model_name,
        sequence_parallel=False,
        msid2mwid=msid2mwid,
    )

    return DataManager(
        model_topos=model_topos,
        msid2mwid=msid2mwid,
        data_transfer_pairs=[(from_model_name, to_model_name)],
    )


def recursive_assert_equal(x1, x2):
    if type(x1) != type(x2):
        raise AssertionError(f"{type(x1)} != {type(x2)}")
    if isinstance(x1, dict):
        assert set(x1.keys()) == set(x2.keys())
        for k in x1.keys():
            recursive_assert_equal(x1[k], x2[k])
    elif dataclasses.is_dataclass(x1):
        for f in dataclasses.fields(x1):
            recursive_assert_equal(getattr(x1, f.name), getattr(x2, f.name))
    elif isinstance(x1, torch.Tensor):
        assert torch.allclose(x1, x2), (x1, x2)
    elif isinstance(x1, list):
        assert len(x1) == len(x2)
        for a, b in zip(x1, x2):
            recursive_assert_equal(a, b)
    else:
        assert x1 == x2


def _test_data_transfer(
    tmp_path,
    from_pp_dp_mp: Tuple,
    to_pp_dp_mp: Tuple,
):

    from_model_name = ModelName("data_transfer_test", 0)
    from_topo = PipeModelDataParallelTopology(
        num_pp=from_pp_dp_mp[0],
        num_mp=from_pp_dp_mp[-1],
        num_dp=from_pp_dp_mp[1],
        sequence_parallel=True,
        gradient_checkpointing=True,
        gradient_accumulation_fusion=True,
    )
    to_model_name = ModelName("data_transfer_test", 1)
    to_topo = PipeModelDataParallelTopology(
        num_pp=to_pp_dp_mp[0],
        num_mp=to_pp_dp_mp[-1],
        num_dp=to_pp_dp_mp[1],
        sequence_parallel=True,
        gradient_checkpointing=True,
        gradient_accumulation_fusion=True,
    )

    data_manager = get_data_manager(
        from_model_name,
        to_model_name,
        from_pp_dp_mp,
        to_pp_dp_mp,
    )
    data_manager.setup_process_groups()

    storage_tracker = GlobalStorageTracker(dist.get_world_size())
    planner = RedistribPlanner(storage_tracker)

    key = "input_ids"

    world_size = dist.get_world_size()
    samples = []
    for dp_rank in range(from_pp_dp_mp[1]):
        gpu_id = data_manager.msid2mwid[
            ModelShardID(
                from_model_name,
                dp_rank=dp_rank,
                mp_rank=0,
                pp_rank=from_pp_dp_mp[0] - 1,
                topo=from_topo,
            )
        ]
        storage_tracker.add_data(
            gpu_id,
            ids=[i + dp_rank * world_size for i in range(world_size)],
            key=key,
            is_owner=True,
        )

        seqlens = torch.randint(10, 1000, size=(world_size,))
        dist.all_reduce(seqlens)
        input_ids = torch.randint(
            0,
            10000,
            size=(int(sum(seqlens)),),
        )
        dist.all_reduce(input_ids)

        s = SequenceSample.from_default(
            ids=[i + dp_rank * world_size for i in range(world_size)],
            seqlens=seqlens.numpy().tolist(),
            data=dict(input_ids=input_ids),
        )

        if dist.get_rank() == 0:
            for ss in s.unpack():
                with open(os.path.join(tmp_path, f"{ss.ids[0]}.pkl"), "wb") as f:
                    pickle.dump(ss, f)

        samples.append(s)
        if dist.get_rank() == gpu_id:
            for ss in s.unpack():
                data_manager.store(ss)

    dist.barrier()

    all_ids = list(range(world_size * from_topo.get_dim("data")))
    np.random.shuffle(all_ids)
    _all_ids = [all_ids]
    dist.broadcast_object_list(_all_ids, src=0)
    all_ids = _all_ids[0]

    dests = {}
    for rank in range(to_topo.world_size()):
        coord = to_topo.get_coord(rank)
        dp_size = to_topo.get_dim("data")
        gpu_id = data_manager.msid2mwid[
            ModelShardID(
                to_model_name,
                dp_rank=coord.data,
                mp_rank=coord.model,
                pp_rank=coord.pipe,
                topo=to_topo,
            )
        ]
        size_per_dp = len(all_ids) // dp_size
        dests[gpu_id] = [coord.data * size_per_dp + i for i in range(size_per_dp)]

    for gpu_id in range(world_size):
        if gpu_id not in dests:
            dests[gpu_id] = []

    plan = planner.derive_plan(dests, keys=[key])
    data_manager.redistribute(SequenceSample.gather(samples), plan)
    dist.barrier()

    for i, s in data_manager.storage.items():
        with open(os.path.join(tmp_path, f"{i}.pkl"), "rb") as f:
            ss = pickle.load(f)
        recursive_assert_equal(ss, s)
    print("success")


parallelism = [(1, 4, 2), (1, 8, 1)]


@pytest.mark.skipif(
    os.cpu_count() < 32 or testing.get_free_mem_gb() < 50,
    reason="The data transfer test requires at least 32 CPUs and 50GB memory.",
)
@pytest.mark.parametrize("from_pp_dp_mp", [(1, 4, 2)])
@pytest.mark.parametrize("to_pp_dp_mp", [(1, 8, 1)])
@pytest.mark.distributed
def test_data_transfer(
    tmp_path,
    from_pp_dp_mp: Tuple,
    to_pp_dp_mp: Tuple,
):
    expr_name = uuid.uuid4()
    trial_name = uuid.uuid4()
    constants.set_force_cpu(True)
    test_impl = LocalMultiProcessTest(
        world_size=16,
        func=_test_data_transfer,
        expr_name=expr_name,
        trial_name=trial_name,
        timeout_secs=300,
        tmp_path=tmp_path,
        from_pp_dp_mp=from_pp_dp_mp,
        to_pp_dp_mp=to_pp_dp_mp,
    )
    test_impl.launch()

@ -47,13 +47,11 @@ def math_dataset(request, save_path):
    return dataset


# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
    "dp,pp,mp",
    [
        (1, 1, 1),
        (2, 1, 1),
        (2, 1, 2),
        (1, 2, 1),
        (1, 1, 2),
    ],
@ -68,7 +66,6 @@ def test_ppo_symm(
    dp,
    pp,
    mp,
    use_v2_worker,
):
    # Setup experiment env. Should be done before any other operations.
    log_root = tmp_path_factory.mktemp("ppo")

@ -120,10 +117,8 @@ def test_ppo_symm(
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)
    run_test_exp(exp_cfg)


# The global resharding strategy, where all MFCs

@ -138,6 +133,7 @@ def test_ppo_symm(
def test_ppo_global_reshard(
    tmp_path_factory,
    tokenizer,
    math_dataset,
    save_path,
    cpu_hf_model,
    mconfig,

@ -243,19 +239,17 @@ def test_ppo_global_reshard(
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg)


# Actor/critic train and ref_inf/rew_inf are on disjoint
# device meshes and executed concurrently.
@pytest.mark.parametrize("actor_gen", [(1, 2, 1)])
@pytest.mark.parametrize("critic_inf", [(1, 1, 2)])
@pytest.mark.parametrize("actor_gen", [(2, 2, 1)])
@pytest.mark.parametrize("critic_inf", [(2, 1, 2)])
def test_ppo_param_realloc_sub_device_mesh(
    tmp_path_factory,
    tokenizer,
    math_dataset,
    save_path,
    cpu_hf_model,
    mconfig,

@ -276,7 +270,7 @@ def test_ppo_param_realloc_sub_device_mesh(
        mode="local",
        allocation_mode="manual",
        n_nodes=1,
        n_gpus_per_node=2,
        n_gpus_per_node=8,
        actor=ModelTrainEvalConfig(
            path=str(save_path),
            init_from_scratch=True,

@ -312,54 +306,54 @@ def test_ppo_param_realloc_sub_device_mesh(
            ),
        ),
        actor_gen=MFCConfig(
            device_mesh="NODE01:0,1,2,3",
            parallel=ParallelismConfig(
                data_parallel_size=actor_gen[0],
                model_parallel_size=actor_gen[1],
                pipeline_parallel_size=actor_gen[2],
            )
        ),
        ),
        actor_train=MFCConfig(
            device_mesh="NODE01:0",
            device_mesh="NODE01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                data_parallel_size=4,
                model_parallel_size=1,
                pipeline_parallel_size=1,
            ),
        ),
        critic_inf=MFCConfig(
            device_mesh="NODE01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=critic_inf[0],
                model_parallel_size=critic_inf[1],
                pipeline_parallel_size=critic_inf[2],
            )
        ),
        ),
        rew_inf=MFCConfig(
            device_mesh="NODE01:1",
            device_mesh="NODE01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                data_parallel_size=4,
                model_parallel_size=1,
                pipeline_parallel_size=1,
            ),
        ),
        ref_inf=MFCConfig(
            device_mesh="NODE01:0",
            device_mesh="NODE01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                model_parallel_size=1,
                pipeline_parallel_size=1,
                model_parallel_size=2,
                pipeline_parallel_size=2,
            ),
        ),
        critic_train=MFCConfig(
            device_mesh="NODE01:1",
            device_mesh="NODE01:4,5,6,7",
            parallel=ParallelismConfig(
                data_parallel_size=1,
                data_parallel_size=2,
                model_parallel_size=1,
                pipeline_parallel_size=1,
                pipeline_parallel_size=2,
            ),
        ),
    )
    exp_cfg.actor.vllm.hybrid_train = True
    exp_cfg.actor.vllm.enforce_eager = True

    run_test_exp(exp_cfg)

@ -35,11 +35,17 @@ def model_class(request):
    ],
)
def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):
    test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp)
    test_sft(
        tmp_path_factory,
        tokenizer,
        save_path,
        cpu_hf_model,
        dp,
        pp,
        tp,
    )


# NOTE: we can't test v1 and v2 at the same time.
@pytest.mark.parametrize("use_v2_worker", [True])
@pytest.mark.parametrize(
    "dp,pp,tp",
    [

@ -49,9 +55,7 @@ def test_sft_xl(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp
        (1, 1, 2),
    ],
)
def test_sft(
    tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp, use_v2_worker
):
def test_sft(tmp_path_factory, tokenizer, save_path, cpu_hf_model, dp, pp, tp):

    # Setup experiment env. Should be done before any other operations.
    log_root = tmp_path_factory.mktemp("sft")

@ -83,4 +87,4 @@ def test_sft(
        ),
    )

    run_test_exp(exp_cfg, use_v2_worker=use_v2_worker)
    run_test_exp(exp_cfg)

@ -18,7 +18,11 @@ def model_class(request):
    return request.param


def run_model_worker(cfg, mw, barrier):
def run_model_worker(
    cfg,
    mw,
    barrier,
):
    constants.set_force_cpu(True)
    # Register all datasets and models
    import realhf.impl.dataset  # isort: skip

@ -48,18 +52,13 @@ def run_test_exp(
    exp_cfg: Experiment,
    expr_name=None,
    trial_name=None,
    use_v2_worker: bool = False,
):
    constants.set_force_cpu(True)
    # Register all datasets and models
    import realhf.impl.dataset  # isort: skip
    import realhf.impl.model  # isort: skip
    from realhf.api.core import system_api

    if not use_v2_worker:
        from realhf.system.master_worker import MasterWorker
    else:
        from realhf.system.v2.master_worker import MasterWorker
    from realhf.system.master_worker import MasterWorker

    system_api.ALL_EXPERIMENT_CLASSES = {}
    register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)

@ -83,7 +82,12 @@ def run_test_exp(
    testcase = testing.LocalMultiProcessTest(
        world_size=len(exp_setup.model_worker),
        func=[
            functools.partial(run_model_worker, cfg=exp_cfg, mw=mw, barrier=barrier)
            functools.partial(
                run_model_worker,
                cfg=exp_cfg,
                mw=mw,
                barrier=barrier,
            )
            for mw in exp_setup.model_worker
        ],
        expr_name=expr_name or testing._DEFAULT_EXPR_NAME,