PullRequest: 29 fix the data loading bug during recovery

Merge branch fw/fix-recover-dataloading of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/29

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix the data loading bug during recovery
博惟 2025-03-12 16:00:49 +08:00
parent fb23009e99
commit 26a48be73e
11 changed files with 152 additions and 108 deletions

View File

@ -91,8 +91,6 @@ def main_worker(args):
# NOTE: Importing these will initialize DeepSpeed/CUDA devices.
# profiler.import_profiler_registers()
import realhf.impl.dataset
import realhf.impl.model
import realhf.system
logger.debug(f"Run {args.worker_type} worker with args: %s", args)

View File

@ -35,6 +35,8 @@ class RecoverInfo:
ckpt_ctl_info: Dict
eval_ctl_info: Dict
data_loading_dp_idx: int
hash_vals_to_ignore: List[int] = dataclasses.field(default_factory=list)
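
RecoverInfo gains the state a restarted run needs to avoid re-training on data from the interrupted epoch: hash_vals_to_ignore carries the hashes of samples already consumed, next to the DP index at which round-robin loading resumes. Below is a minimal sketch of how such a field is consumed by a loader; the Sample class and skip_used_samples helper are illustrative, not part of this diff.

import dataclasses
from typing import Hashable, List, Set

@dataclasses.dataclass
class Sample:
    ids: List[Hashable]  # ids[0] is the unique hash of the sample, as in SequenceSample

def skip_used_samples(batch: List[Sample], hash_vals_to_ignore: Set[Hashable]) -> List[Sample]:
    # Drop every sample whose hash was consumed before the crash, so a
    # recovered run never revisits data from the interrupted epoch.
    kept = []
    for s in batch:
        if s.ids[0] in hash_vals_to_ignore:
            hash_vals_to_ignore.discard(s.ids[0])
        else:
            kept.append(s)
    return kept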

View File

@ -260,7 +260,7 @@ class PPOActorInterface(model_api.ModelInterface):
offset += x[0]
assert offset == sum(x[0] for x in input_.seqlens["packed_prompts"])
if model.backend_name != "vllm":
if model.backend_name not in ["vllm", "sglang"]:
# Replicate prompts
grouped_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.generation_size)),
@ -287,7 +287,7 @@ class PPOActorInterface(model_api.ModelInterface):
gconfig=self.gconfig,
mb_spec=mb_spec,
)
if res is None:
if res is None or res[0] is None:
return None
gen_tokens, logprobs, _ = res
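
The two one-line changes above are defensive in the same way: generation may now be served by either vLLM or SGLang, and the returned tuple's first element (the generated tokens) can itself be None, in which case there is nothing to unpack. A condensed sketch of the guard, assuming the tuple layout shown in the hunk:

from typing import Any, Optional, Tuple

def unpack_generation(res: Optional[Tuple[Any, ...]]) -> Optional[Tuple[Any, Any]]:
    # Bail out when the whole result is missing or its first element (the
    # generated tokens) is missing, instead of failing on tuple unpacking.
    if res is None or res[0] is None:
        return None
    gen_tokens, logprobs, _ = res
    return gen_tokens, logprobs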

View File

@ -13,7 +13,7 @@ import torch.distributed as dist
from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants
from realhf.base.topology import PipeModelDataParallelTopology, new_or_get_group
from realhf.base.topology import ProcessTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.system.redistributor import RedistribStep
@ -26,7 +26,7 @@ class DataManager:
def __init__(
self,
model_topos: Dict[ModelName, PipeModelDataParallelTopology],
model_topos: Dict[ModelName, ProcessTopology],
msid2mwid: Optional[Dict[ModelShardID, int]] = None,
data_transfer_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
):

View File

@ -11,7 +11,7 @@ from realhf.api.core.data_api import DataBatchMeta, SequenceSample
from realhf.api.core.dfg import MFCDef
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging
from realhf.base.topology import PipeModelDataParallelTopology
from realhf.base.topology import ProcessTopology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
@ -28,7 +28,7 @@ class FunctionExecutor:
msid2mwid: Dict[ModelShardID, int],
stream: NameResolvingRequestClient,
buffer: AsyncIOSequenceBuffer,
model_topos: Dict[str, PipeModelDataParallelTopology],
model_topos: Dict[str, ProcessTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
summary_writer: SummaryWriter | None,
@ -91,6 +91,23 @@ class FunctionExecutor:
async def finish_traverse(self):
for _ in range(len(self.get_leaf_tasks())):
await self.ctrl.train_count.get()
await self.clear_gpu_cache()
async def clear_gpu_cache(self):
async with self.ctrl.lock:
self.ctrl.used_hash_vals_this_epoch += list(self.ctrl.ids_to_clear)
self.stream.request(
handlers=list(range(self.n_model_workers)),
handle_type="clear_data_cache",
datas=[
self.ctrl.ids_to_clear for _ in list(range(self.n_model_workers))
],
no_syn=True,
)
# Clear resource tracker as well.
await self.storage_tracker.clear_data(self.ctrl.ids_to_clear)
self.ctrl.ids_to_clear.clear()
async def load_data(self):
src_rpc = self.src_rpc
@ -119,40 +136,50 @@ class FunctionExecutor:
if x.meta_sample is None:
continue
# RPCs corountines will use this information to
# determine the src and dst of data transfer.
for xx in x.meta_sample.unpack():
if xx.ids[0] in received_ids:
raise ValueError(f"Duplicate data id {xx.ids[0]}.")
received_ids.add(xx.ids[0])
all_data = x.meta_sample.unpack()
filtered_data = []
ids_to_ignore = []
for xx in x.meta_sample.unpack():
if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
ctrl.ids_to_clear.add(xx.ids[0])
else:
filtered_data.append(xx)
async with ctrl.lock:
if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
ids_to_ignore.append(xx.ids[0])
else:
if xx.ids[0] in received_ids:
raise ValueError(f"Duplicate data id {xx.ids[0]}.")
received_ids.add(xx.ids[0])
filtered_data.append(xx)
if ids_to_ignore:
# Clear ignored data.
self.stream.request(
handlers=list(range(self.n_model_workers)),
handle_type="clear_data_cache",
datas=[ids_to_ignore for _ in list(range(self.n_model_workers))],
no_syn=True,
)
all_data = filtered_data
# We load data in a round-robin manner across different DP ranks,
# so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data)
# Update resource tracker for planning data redistribution.
gpu_id = self.stream.route_to(f"__data{dp_idx}__")
for k in all_data[0].keys:
self.storage_tracker.add_data(
gpu_id,
[x.ids[0] for x in all_data],
k,
is_owner=True,
)
if len(all_data) > 0:
# Update resource tracker for planning data redistribution.
gpu_id = self.stream.route_to(f"__data{dp_idx}__")
for k in all_data[0].keys:
await self.storage_tracker.add_data(
gpu_id,
[x.ids[0] for x in all_data],
k,
is_owner=True,
)
# Store into buffer!
buffer_indices = await buffer.put_batch(all_data)
assert len(buffer_indices) == len(all_data)
# Store into buffer!
buffer_indices = await buffer.put_batch(all_data)
assert len(buffer_indices) == len(all_data)
blogger.info(
f"Master worker loaded {len(all_data)} pieces of data from DP rank {dp_idx}. "
@ -173,19 +200,3 @@ class FunctionExecutor:
]
loop.run_until_complete(asyncio.gather(*tasks))
logger.info("Execution finished!")
self.clear_gpu_cache()
def clear_gpu_cache(self):
self.stream.request(
handlers=list(range(self.n_model_workers)),
handle_type="clear_data_cache",
datas=[self.ctrl.ids_to_clear for _ in list(range(self.n_model_workers))],
no_syn=True,
)
# Clear resource tracker as well.
self.storage_tracker.clear_data(self.ctrl.ids_to_clear)
self.ctrl.ids_to_clear.clear()
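
Taken together, the FunctionExecutor changes make data loading recovery-aware and move cache clearing behind the shared lock: freshly loaded samples whose hashes appear in hash_vals_to_ignore_in_recover are dropped and immediately evicted from the model workers' caches, while everything else goes through the usual duplicate check. A condensed sketch of that filtering step; the sample objects and the clear_cache callback stand in for the real SequenceSample and clear_data_cache request.

import asyncio
from typing import Callable, Hashable, List, Set

async def filter_loaded_batch(
    samples: List,                                   # each element has .ids[0], like SequenceSample
    received_ids: Set[Hashable],
    hash_vals_to_ignore: Set[Hashable],
    lock: asyncio.Lock,
    clear_cache: Callable[[List[Hashable]], None],   # e.g. issues a clear_data_cache request
) -> List:
    kept, ignored = [], []
    for s in samples:
        sid = s.ids[0]
        async with lock:                             # shared with the MFC coroutines
            if sid in hash_vals_to_ignore:
                hash_vals_to_ignore.discard(sid)
                ignored.append(sid)
            elif sid in received_ids:
                raise ValueError(f"Duplicate data id {sid}.")
            else:
                received_ids.add(sid)
                kept.append(s)
    if ignored:
        clear_cache(ignored)                         # evict stale entries on every model worker
    return kept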

View File

@ -30,7 +30,15 @@ import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging, name_resolve, names, timeutil, topology
from realhf.base import (
constants,
logging,
name_resolve,
names,
seeding,
timeutil,
topology,
)
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.function_executor import FunctionExecutor
from realhf.system.model_function_call import RPCCorountineControl
@ -45,7 +53,9 @@ class MasterWorker(worker_base.Worker):
def _configure(self, config: config_pkg.MasterWorker):
self.config = config
self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
seeding.set_random_seed(self.config.base_seed + self.config.n_model_workers)
self.__model_topos: Dict[ModelName, topology.ProcessTopology] = (
config.model_topos
)
@ -111,6 +121,7 @@ class MasterWorker(worker_base.Worker):
self.__rpc_ctrl = RPCCorountineControl(
train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
lock=asyncio.Lock(),
# NOTE: We should accumulate the used data hashes in the same epoch
# to prevent loading data used before.
used_hash_vals_this_epoch=(
@ -308,6 +319,10 @@ class MasterWorker(worker_base.Worker):
ctrl=self.__rpc_ctrl,
summary_writer=self.__summary_writer,
)
if self.__recover_run:
self.func_executor.data_loading_dp_idx = (
self.__recover_info.data_loading_dp_idx
)
logger.info(f"Coroutines created. The master worker is ready to run.")
self.__initialized = True
@ -343,6 +358,18 @@ class MasterWorker(worker_base.Worker):
epochs=int(is_epoch_last_step), steps=1
)
# Log eval/save info.
epoch = self.__rpc_ctrl.step_info.epoch + 1
epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
global_step = self.__rpc_ctrl.step_info.global_step + 1
s = f"The next step is epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
s += f"step {epoch_step}/{self._steps_per_epoch} "
s += f"(global step {global_step}). "
s += f"Should save a checkpoint for recover? {self.__rpc_ctrl.should_ckpt}. "
s += f"Should save a persistent checkpoint for evaluation? {self.__rpc_ctrl.should_save}. "
s += f"Should run evaluation? {self.__rpc_ctrl.should_eval}. "
self.logger.info(s)
# Traverse over the dataflow graph for once.
self.func_executor.execute_step()
@ -350,8 +377,6 @@ class MasterWorker(worker_base.Worker):
if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)
self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)
if is_epoch_last_step:
self.__rpc_ctrl.used_hash_vals_this_epoch = (
self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
@ -431,8 +456,8 @@ class MasterWorker(worker_base.Worker):
s += f"Estimated remaining time: {remain_t:.3f}s. "
if flops is not None:
s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
logger.info(s)
logger.info(
self.logger.info(s)
self.logger.info(
f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s"
)
@ -486,6 +511,7 @@ class MasterWorker(worker_base.Worker):
save_ctl_info=self.__save_ctl.state_dict(),
ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
eval_ctl_info=self.__eval_ctl.state_dict(),
data_loading_dp_idx=self.func_executor.data_loading_dp_idx,
hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
)
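
On the master-worker side the new fields close the loop: when a recover checkpoint is written, the current data-loading DP index and the hashes used so far in the epoch go into RecoverInfo, and a recover run pushes the DP index back into the function executor (the corresponding wiring for hash_vals_to_ignore is outside the hunks shown here). A self-contained sketch of that round trip with simplified stand-in containers:

import dataclasses
from typing import Hashable, List

@dataclasses.dataclass
class _RecoverState:                    # simplified stand-in for RecoverInfo
    data_loading_dp_idx: int
    hash_vals_to_ignore: List[Hashable]

def snapshot(func_executor, ctrl) -> _RecoverState:
    # Called when a recover checkpoint is saved.
    return _RecoverState(
        data_loading_dp_idx=func_executor.data_loading_dp_idx,
        hash_vals_to_ignore=list(ctrl.used_hash_vals_this_epoch),
    )

def restore(state: _RecoverState, func_executor, ctrl) -> None:
    # Assumed wiring on a recover run; only the DP-index part appears in the shown hunks.
    func_executor.data_loading_dp_idx = state.data_loading_dp_idx
    ctrl.hash_vals_to_ignore_in_recover = set(state.hash_vals_to_ignore)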

View File

@ -39,6 +39,7 @@ class RPCCorountineControl:
# For flushing requests
topo_level_count: asyncio.Queue
lock: asyncio.Lock
# for training data management and data cleaning after each step
ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
@ -60,7 +61,7 @@ class ModelFunctionCall:
src_rpc: dfg.MFCDef,
stream: request_reply_stream.NameResolvingRequestClient,
msid2mwid: Dict[config_pkg.ModelShardID, int],
model_topos: Dict[str, topology.PipeModelDataParallelTopology],
model_topos: Dict[str, topology.ProcessTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
buffer: AsyncIOSequenceBuffer,
@ -321,7 +322,8 @@ class ModelFunctionCall:
for i in range(self.dp_size)
]
ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])
async with ctrl.lock:
ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])
# logger.info(f"Model rpc {rpc.name} requesting.")
@ -371,7 +373,7 @@ class ModelFunctionCall:
is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
gpu_id = self.msid2mwid[h]
for key in rpc.input_keys:
self.redistrib_planner.storage_tracker.add_data(
await self.redistrib_planner.storage_tracker.add_data(
gpu_id, partitioned_ids[h.dp_rank], key=key, is_owner=is_dp_head
)
else:
@ -379,18 +381,19 @@ class ModelFunctionCall:
if step.comm_type == "scatter":
for gpu_id, ids in zip(step.dsts, step.ids):
for key in step.keys:
self.redistrib_planner.storage_tracker.add_data(
await self.redistrib_planner.storage_tracker.add_data(
gpu_id, ids, key=key, is_owner=False
)
elif step.comm_type == "gather":
for key in step.keys:
self.redistrib_planner.storage_tracker.add_data(
await self.redistrib_planner.storage_tracker.add_data(
step.root,
list(itertools.chain.from_iterable(step.ids)),
key=key,
is_owner=False,
)
await asyncio.sleep(0)
# send partitioned data to model workers
req_ids, other_req_ids = self.request(
data_transfer_plan=data_transfer_plan,
@ -425,7 +428,7 @@ class ModelFunctionCall:
)
gpu_id = self.msid2mwid[h]
for k in rpc.output_keys:
self.redistrib_planner.storage_tracker.add_data(
await self.redistrib_planner.storage_tracker.add_data(
gpu_id,
x.ids,
key=k,
@ -455,7 +458,8 @@ class ModelFunctionCall:
# update the train counter.
# Otherwise, amend data in the buffer.
if rpc.is_dst:
ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
async with ctrl.lock:
ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
await ctrl.train_count.put(1)
else:
logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")

View File

@ -28,10 +28,9 @@ import torch.distributed
import torch.distributed as dist
import torch.utils.data
import realhf.api.core.dfg as dfg
import realhf.api.core.system_api as system_api
import realhf.impl.model.comm.global_comm as global_comm
import realhf.impl.model.comm.param_realloc as param_realloc_comm
from realhf.api.core import data_api, dfg, model_api, system_api
from realhf.api.core.config import ModelName
from realhf.base import (
constants,
@ -56,8 +55,8 @@ from realhf.system.data_manager import DataManager
from realhf.system.redistributor import RedistribStep
# NOTE: Register all implemented datasets and models.
import realhf.api.core.data_api as data_api # isort:skip
import realhf.api.core.model_api as model_api # isort:skip
import realhf.impl.dataset # isort:skip
import realhf.impl.model # isort:skip
logger = logging.getLogger("Model Worker", "colored")
blogger = logging.getLogger("benchmark")
@ -111,12 +110,6 @@ class ModelWorker(worker_base.Worker):
self.config = cfg
self.model_names = [s.id.model_name for s in cfg.shards]
self.shard_indices = [
cfg.model_topos[s.id.model_name].get_rank(
data=s.id.dp_rank, pipe=s.id.pp_rank, model=s.id.mp_rank
)
for s in cfg.shards
]
self.__experiment_name = self.config.worker_info.experiment_name
self.__trial_name = self.config.worker_info.trial_name
@ -516,7 +509,6 @@ class ModelWorker(worker_base.Worker):
f"RPC hook {hook} CPU time {time.perf_counter() - tik:.4f}s."
)
if constants.use_cuda():
# FIXME: temporary synchronize for debugging
torch.cuda.synchronize()
return ret
@ -933,16 +925,17 @@ class ModelWorker(worker_base.Worker):
if constants.use_cuda():
# Monitoring info. There's an all-gather and an all-reduce
# over the parallelism group in this function.
# FIXME: temporary synchronize for debugging
torch.cuda.synchronize()
if self._model.backend_name != "vllm":
# Since vLLM allocates GPU memory in advance, it is very
if (
self._model.backend_name != "vllm"
and self._model.backend_name != "sglang"
):
# Since vLLM/SGLang allocates GPU memory in advance, it is very
# easy to exceed the 0.95 threshold that triggers a kill.
# We omit GPU stats logging for vLLM.
# We omit GPU stats logging for vLLM/SGLang.
self.__log_gpu_stats(request)
self._clear_memory()
# FIXME: temporary synchronize for debugging
if constants.use_cuda():
torch.cuda.synchronize()
dist.barrier(group=constants.parallelism_group())
@ -973,8 +966,8 @@ class ModelWorker(worker_base.Worker):
from_model_name: ModelName = hook_data["from_model_name"]
to_model_name: ModelName = hook_data["to_model_name"]
from_topo: topology.PipeModelDataParallelTopology = hook_data["from_topo"]
to_topo: topology.PipeModelDataParallelTopology = hook_data["to_topo"]
from_topo: topology.ProcessTopology = hook_data["from_topo"]
to_topo: topology.ProcessTopology = hook_data["to_topo"]
# NOTE: For the convenience of future developement, we
# run parameter reallocation with disk save-load by default.
@ -1121,22 +1114,21 @@ class ModelWorker(worker_base.Worker):
def __load_model(self, hook_data: Dict):
tik = time.perf_counter()
with constants.model_scope(hook_data["model_name"]):
from realhf.impl.model.backend.vllm import (
vLLMGenerationBackend,
vLLMGenerationEngine,
)
if isinstance(self._model.module, torch.nn.Identity) and isinstance(
self._backend, vLLMGenerationBackend
self._backend,
(
model_api.ALL_BACKEND_CLASSES["sglang"],
model_api.ALL_BACKEND_CLASSES["vllm"],
),
):
# The uninitialized vLLM model. Since we create the model
# inside the vLLM backend, the initial param realloc before
# The uninitialized vLLM/SGLang model. Since we create the model
# inside the vLLM/SGLang backend, the initial param realloc before
# backend initialization can be ignored.
return
if self._model.backend_name == "vllm":
if self._model.backend_name in ["vllm", "sglang"]:
if constants.parallelism_rank() == 0:
logger.info("Updating vLLM model from disk.")
module: vLLMGenerationEngine = self._model.module
logger.info(f"Updating {self._model.backend_name} model from disk.")
module = self._model.module
module.update_weights_from_disk(hook_data["load_dir"])
else:
module: ReaLModel = self.__unwrapped_models[hook_data["model_name"]]
@ -1147,7 +1139,7 @@ class ModelWorker(worker_base.Worker):
t = torch.tensor(
float(time.perf_counter() - tik),
dtype=torch.float64,
device=module.device,
device=constants.current_device(),
)
dist.all_reduce(
t, op=dist.ReduceOp.MAX, group=constants.parallelism_group()
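
In the model worker, the vLLM-only import and isinstance check give way to lookups in model_api.ALL_BACKEND_CLASSES, so SGLang follows the same path: an uninitialized generation engine is skipped during the initial parameter realloc, and weight reloads go through update_weights_from_disk whichever engine is behind the module. A sketch of registry-based dispatch in that spirit; the registry and classes below are illustrative, not the realhf ones.

from typing import Dict, Type

# Illustrative registry in the spirit of model_api.ALL_BACKEND_CLASSES.
ALL_BACKEND_CLASSES: Dict[str, Type] = {}

def register_backend(name: str):
    def _wrap(cls: Type) -> Type:
        ALL_BACKEND_CLASSES[name] = cls
        return cls
    return _wrap

@register_backend("vllm")
class VLLMBackend: ...

@register_backend("sglang")
class SGLangBackend: ...

def is_generation_backend(backend: object) -> bool:
    # The worker only needs to know "is this one of the generation engines",
    # not which concrete class it is.
    return isinstance(
        backend,
        (ALL_BACKEND_CLASSES["vllm"], ALL_BACKEND_CLASSES["sglang"]),
    )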

View File

@ -1,5 +1,6 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import dataclasses
import itertools
import os
@ -11,12 +12,24 @@ from realhf.base.cluster import spec as cluster_spec
class GlobalStorageTracker:
def __init__(self, world_size: int):
self.lock = asyncio.Lock()
self.storages: List[Dict[Hashable, List[str]]]
self.storages = [{} for _ in range(world_size)]
self.data_owner: Dict[Tuple[Hashable, str], int]
self.data_owner = {}
def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
async def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
async with self.lock:
for data_id in ids:
if data_id not in self.storages[rank]:
self.storages[rank][data_id] = [key]
else:
if key not in self.storages[rank][data_id]:
self.storages[rank][data_id].append(key)
if is_owner:
self.data_owner[(data_id, key)] = rank
def add_data_synced(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
for data_id in ids:
if data_id not in self.storages[rank]:
self.storages[rank][data_id] = [key]
@ -26,15 +39,16 @@ class GlobalStorageTracker:
if is_owner:
self.data_owner[(data_id, key)] = rank
def clear_data(self, ids: List[Hashable]):
for storage in self.storages:
for i in ids:
if i in storage:
storage.pop(i)
keys = list(self.data_owner.keys())
for i, k in keys:
if i in ids:
self.data_owner.pop((i, k))
async def clear_data(self, ids: List[Hashable]):
async with self.lock:
for storage in self.storages:
for i in ids:
if i in storage:
storage.pop(i)
keys = list(self.data_owner.keys())
for i, k in keys:
if i in ids:
self.data_owner.pop((i, k))
@dataclasses.dataclass
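
GlobalStorageTracker is what the awaits above land on: add_data and clear_data now take the tracker's own asyncio.Lock before mutating the per-GPU storage maps, while add_data_synced keeps the old lock-free behaviour for single-threaded callers such as the data-transfer test below. A trimmed-down usage sketch of the async path; the demo wiring is illustrative.

import asyncio
from typing import Dict, Hashable, List, Tuple

class MiniTracker:
    # Trimmed-down version of GlobalStorageTracker showing only the locking.
    def __init__(self, world_size: int) -> None:
        self.lock = asyncio.Lock()
        self.storages: List[Dict[Hashable, List[str]]] = [{} for _ in range(world_size)]
        self.data_owner: Dict[Tuple[Hashable, str], int] = {}

    async def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool) -> None:
        async with self.lock:          # serialize concurrently running coroutines
            for data_id in ids:
                keys = self.storages[rank].setdefault(data_id, [])
                if key not in keys:
                    keys.append(key)
                if is_owner:
                    self.data_owner[(data_id, key)] = rank

    async def clear_data(self, ids: List[Hashable]) -> None:
        async with self.lock:
            for storage in self.storages:
                for i in ids:
                    storage.pop(i, None)
            for i, k in list(self.data_owner):
                if i in ids:
                    del self.data_owner[(i, k)]

async def _demo() -> None:
    tracker = MiniTracker(world_size=2)
    await asyncio.gather(
        tracker.add_data(0, [101, 102], key="packed_prompts", is_owner=True),
        tracker.add_data(1, [101], key="rewards", is_owner=False),
    )
    await tracker.clear_data([101])
    assert 101 not in tracker.storages[0] and 102 in tracker.storages[0]

if __name__ == "__main__":
    asyncio.run(_demo())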

View File

@ -186,7 +186,7 @@ def _test_data_transfer(
topo=from_topo,
)
]
storage_tracker.add_data(
storage_tracker.add_data_synced(
gpu_id,
ids=[i + dp_rank * world_size for i in range(world_size)],
key=key,

View File

@ -18,11 +18,7 @@ def model_class(request):
return request.param
def run_model_worker(
cfg,
mw,
barrier,
):
def run_model_worker(cfg, mw, barrier, expr_name=None):
constants.set_force_cpu(True)
# Register all datasets and models
import realhf.impl.dataset # isort: skip
@ -31,7 +27,7 @@ def run_model_worker(
from realhf.system.model_worker import ModelWorker
system_api.ALL_EXPERIMENT_CLASSES = {}
register_experiment(testing._DEFAULT_EXPR_NAME, lambda: cfg)
register_experiment(expr_name or testing._DEFAULT_EXPR_NAME, lambda: cfg)
worker = ModelWorker()
logger.info("Configuring model worker...")
@ -61,7 +57,7 @@ def run_test_exp(
from realhf.system.master_worker import MasterWorker
system_api.ALL_EXPERIMENT_CLASSES = {}
register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)
register_experiment(expr_name or testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)
# Get worker configurations
exp_setup = exp_cfg.initial_setup()
@ -87,6 +83,7 @@ def run_test_exp(
cfg=exp_cfg,
mw=mw,
barrier=barrier,
expr_name=expr_name,
)
for mw in exp_setup.model_worker
],