mirror of https://github.com/inclusionAI/AReaL
PullRequest: 29 fix the dataloading bug during recover
Merge branch fw/fix-recover-dataloading of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/29
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* fix the dataloading bug during recover

parent fb23009e99
commit 26a48be73e
@@ -91,8 +91,6 @@ def main_worker(args):
    # NOTE: Importing these will initialize DeepSpeed/CUDA devices.
    # profiler.import_profiler_registers()
    import realhf.impl.dataset
    import realhf.impl.model
    import realhf.system

    logger.debug(f"Run {args.worker_type} worker with args: %s", args)
@@ -35,6 +35,8 @@ class RecoverInfo:
    ckpt_ctl_info: Dict
    eval_ctl_info: Dict

+    data_loading_dp_idx: int
+
    hash_vals_to_ignore: List[int] = dataclasses.field(default_factory=list)
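For orientation, a minimal, self-contained sketch (not part of the diff) of a recover record that carries the new data_loading_dp_idx field and is written to and read back from disk. The field names mirror the hunk above; the dump/load helpers and the pickle format are illustrative assumptions, not AReaL's actual recover API.

import dataclasses
import pickle
from typing import Dict, List


@dataclasses.dataclass
class RecoverInfo:
    ckpt_ctl_info: Dict
    eval_ctl_info: Dict
    # DP rank at which round-robin data loading should resume after a restart.
    data_loading_dp_idx: int
    # Data hashes already consumed in the interrupted epoch; skip them on reload.
    hash_vals_to_ignore: List[int] = dataclasses.field(default_factory=list)


def dump_recover_info(info: RecoverInfo, path: str) -> None:
    # Hypothetical helper: persist the record so a restarted master can resume.
    with open(path, "wb") as f:
        pickle.dump(info, f)


def load_recover_info(path: str) -> RecoverInfo:
    with open(path, "rb") as f:
        return pickle.load(f)


info = RecoverInfo(ckpt_ctl_info={}, eval_ctl_info={}, data_loading_dp_idx=3)
dump_recover_info(info, "/tmp/recover_info.pkl")
assert load_recover_info("/tmp/recover_info.pkl").data_loading_dp_idx == 3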
@@ -260,7 +260,7 @@ class PPOActorInterface(model_api.ModelInterface):
            offset += x[0]
        assert offset == sum(x[0] for x in input_.seqlens["packed_prompts"])

-        if model.backend_name != "vllm":
+        if model.backend_name not in ["vllm", "sglang"]:
            # Replicate prompts
            grouped_input = SequenceSample.from_default(
                ids=list(range(input_.bs * self.generation_size)),
@@ -287,7 +287,7 @@ class PPOActorInterface(model_api.ModelInterface):
            gconfig=self.gconfig,
            mb_spec=mb_spec,
        )
-        if res is None:
+        if res is None or res[0] is None:
            return None

        gen_tokens, logprobs, _ = res
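A short sketch (not the repository's implementation) of the two guards the hunks above introduce: treating vLLM and SGLang uniformly as preallocating backends that handle prompt replication themselves, and rejecting a generation result whose first element is missing. The run_generation helper and its callables are assumptions for illustration.

from typing import Callable, Optional

PREALLOCATING_BACKENDS = ("vllm", "sglang")


def run_generation(
    backend_name: str,
    generate_fn: Callable[[], Optional[tuple]],
    replicate_prompts: Callable[[], None],
) -> Optional[tuple]:
    # Backends that batch internally skip explicit prompt replication.
    if backend_name not in PREALLOCATING_BACKENDS:
        replicate_prompts()
    res = generate_fn()
    # Guard against both a missing result and a result with no tokens.
    if res is None or res[0] is None:
        return None
    return res


out = run_generation("sglang", lambda: ("tokens", "logprobs", None), lambda: None)
assert out == ("tokens", "logprobs", None)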
@@ -13,7 +13,7 @@ import torch.distributed as dist
from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants
-from realhf.base.topology import PipeModelDataParallelTopology, new_or_get_group
+from realhf.base.topology import ProcessTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.system.redistributor import RedistribStep
@@ -26,7 +26,7 @@ class DataManager:

    def __init__(
        self,
-        model_topos: Dict[ModelName, PipeModelDataParallelTopology],
+        model_topos: Dict[ModelName, ProcessTopology],
        msid2mwid: Optional[Dict[ModelShardID, int]] = None,
        data_transfer_pairs: Optional[List[Tuple[ModelName, ModelName]]] = None,
    ):
@@ -11,7 +11,7 @@ from realhf.api.core.data_api import DataBatchMeta, SequenceSample
from realhf.api.core.dfg import MFCDef
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import logging
-from realhf.base.topology import PipeModelDataParallelTopology
+from realhf.base.topology import ProcessTopology
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.model_function_call import ModelFunctionCall, RPCCorountineControl
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
@@ -28,7 +28,7 @@ class FunctionExecutor:
        msid2mwid: Dict[ModelShardID, int],
        stream: NameResolvingRequestClient,
        buffer: AsyncIOSequenceBuffer,
-        model_topos: Dict[str, PipeModelDataParallelTopology],
+        model_topos: Dict[str, ProcessTopology],
        model_configs: Dict[str, None | ReaLModelConfig],
        ctrl: RPCCorountineControl,
        summary_writer: SummaryWriter | None,
@@ -91,6 +91,23 @@ class FunctionExecutor:
    async def finish_traverse(self):
        for _ in range(len(self.get_leaf_tasks())):
            await self.ctrl.train_count.get()
+        await self.clear_gpu_cache()
+
+    async def clear_gpu_cache(self):
+        async with self.ctrl.lock:
+            self.ctrl.used_hash_vals_this_epoch += list(self.ctrl.ids_to_clear)
+            self.stream.request(
+                handlers=list(range(self.n_model_workers)),
+                handle_type="clear_data_cache",
+                datas=[
+                    self.ctrl.ids_to_clear for _ in list(range(self.n_model_workers))
+                ],
+                no_syn=True,
+            )
+            # Clear resource tracker as well.
+            await self.storage_tracker.clear_data(self.ctrl.ids_to_clear)
+
+            self.ctrl.ids_to_clear.clear()

    async def load_data(self):
        src_rpc = self.src_rpc
@@ -119,40 +136,50 @@ class FunctionExecutor:
            if x.meta_sample is None:
                continue

-            # RPCs corountines will use this information to
-            # determine the src and dst of data transfer.
-            for xx in x.meta_sample.unpack():
-                if xx.ids[0] in received_ids:
-                    raise ValueError(f"Duplicate data id {xx.ids[0]}.")
-                received_ids.add(xx.ids[0])
            all_data = x.meta_sample.unpack()

            filtered_data = []
+            ids_to_ignore = []
            for xx in x.meta_sample.unpack():
-                if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
-                    ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
-                    ctrl.ids_to_clear.add(xx.ids[0])
-                else:
-                    filtered_data.append(xx)
+                async with ctrl.lock:
+                    if xx.ids[0] in ctrl.hash_vals_to_ignore_in_recover:
+                        ctrl.hash_vals_to_ignore_in_recover.remove(xx.ids[0])
+                        ids_to_ignore.append(xx.ids[0])
+                    else:
+                        if xx.ids[0] in received_ids:
+                            raise ValueError(f"Duplicate data id {xx.ids[0]}.")
+                        received_ids.add(xx.ids[0])
+                        filtered_data.append(xx)
+
+            if ids_to_ignore:
+                # Clear ignored data.
+                self.stream.request(
+                    handlers=list(range(self.n_model_workers)),
+                    handle_type="clear_data_cache",
+                    datas=[ids_to_ignore for _ in list(range(self.n_model_workers))],
+                    no_syn=True,
+                )
+
            all_data = filtered_data

            # We load data in a round-robin manner across different DP ranks,
            # so we also need to shuffle the data to fuse different dataset splits.
            random.shuffle(all_data)

-            # Update resource tracker for planning data redistribution.
-            gpu_id = self.stream.route_to(f"__data{dp_idx}__")
-            for k in all_data[0].keys:
-                self.storage_tracker.add_data(
-                    gpu_id,
-                    [x.ids[0] for x in all_data],
-                    k,
-                    is_owner=True,
-                )
+            if len(all_data) > 0:
+                # Update resource tracker for planning data redistribution.
+                gpu_id = self.stream.route_to(f"__data{dp_idx}__")
+                for k in all_data[0].keys:
+                    await self.storage_tracker.add_data(
+                        gpu_id,
+                        [x.ids[0] for x in all_data],
+                        k,
+                        is_owner=True,
+                    )

-            # Store into buffer!
-            buffer_indices = await buffer.put_batch(all_data)
-            assert len(buffer_indices) == len(all_data)
+                # Store into buffer!
+                buffer_indices = await buffer.put_batch(all_data)
+                assert len(buffer_indices) == len(all_data)

            blogger.info(
                f"Master worker loaded {len(all_data)} pieces of data from DP rank {dp_idx}. "
@@ -173,19 +200,3 @@ class FunctionExecutor:
        ]

        loop.run_until_complete(asyncio.gather(*tasks))
-
-        logger.info("Execution finished!")
-
-        self.clear_gpu_cache()
-
-    def clear_gpu_cache(self):
-        self.stream.request(
-            handlers=list(range(self.n_model_workers)),
-            handle_type="clear_data_cache",
-            datas=[self.ctrl.ids_to_clear for _ in list(range(self.n_model_workers))],
-            no_syn=True,
-        )
-        # Clear resource tracker as well.
-        self.storage_tracker.clear_data(self.ctrl.ids_to_clear)
-
-        self.ctrl.ids_to_clear.clear()
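The load_data change above boils down to one pattern: under the shared asyncio lock, drop IDs that were already consumed before the crash and reject duplicates among the rest. A standalone sketch of that pattern, using hypothetical names (filter_recovered, ids_to_skip) rather than AReaL's own:

import asyncio
from typing import Hashable, List, Set


async def filter_recovered(
    batch_ids: List[Hashable],
    ids_to_skip: Set[Hashable],
    seen: Set[Hashable],
    lock: asyncio.Lock,
) -> List[Hashable]:
    # IDs consumed before the crash are skipped once; everything else must be new.
    kept = []
    async with lock:
        for i in batch_ids:
            if i in ids_to_skip:
                ids_to_skip.discard(i)
            else:
                if i in seen:
                    raise ValueError(f"Duplicate data id {i}.")
                seen.add(i)
                kept.append(i)
    return kept


async def _demo() -> None:
    lock = asyncio.Lock()
    kept = await filter_recovered([1, 2, 3], ids_to_skip={2}, seen=set(), lock=lock)
    assert kept == [1, 3]


if __name__ == "__main__":
    asyncio.run(_demo())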
@@ -30,7 +30,15 @@ import realhf.system.request_reply_stream as request_reply_stream
import realhf.system.worker_base as worker_base
from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
-from realhf.base import constants, logging, name_resolve, names, timeutil, topology
+from realhf.base import (
+    constants,
+    logging,
+    name_resolve,
+    names,
+    seeding,
+    timeutil,
+    topology,
+)
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.function_executor import FunctionExecutor
from realhf.system.model_function_call import RPCCorountineControl
@@ -45,7 +53,9 @@ class MasterWorker(worker_base.Worker):
    def _configure(self, config: config_pkg.MasterWorker):
        self.config = config

-        self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
+        seeding.set_random_seed(self.config.base_seed + self.config.n_model_workers)
+
+        self.__model_topos: Dict[ModelName, topology.ProcessTopology] = (
            config.model_topos
        )
@@ -111,6 +121,7 @@ class MasterWorker(worker_base.Worker):
        self.__rpc_ctrl = RPCCorountineControl(
            train_count=asyncio.Queue(maxsize=len(self.__rpc_dsts)),
            topo_level_count=asyncio.Queue(maxsize=sum(self.__topo_widths)),
+            lock=asyncio.Lock(),
            # NOTE: We should accumulate the used data hashes in the same epoch
            # to prevent loading data used before.
            used_hash_vals_this_epoch=(
@@ -308,6 +319,10 @@ class MasterWorker(worker_base.Worker):
            ctrl=self.__rpc_ctrl,
            summary_writer=self.__summary_writer,
        )
+        if self.__recover_run:
+            self.func_executor.data_loading_dp_idx = (
+                self.__recover_info.data_loading_dp_idx
+            )
        logger.info(f"Coroutines created. The master worker is ready to run.")

        self.__initialized = True
@@ -343,6 +358,18 @@ class MasterWorker(worker_base.Worker):
            epochs=int(is_epoch_last_step), steps=1
        )

+        # Log eval/save info.
+        epoch = self.__rpc_ctrl.step_info.epoch + 1
+        epoch_step = self.__rpc_ctrl.step_info.epoch_step + 1
+        global_step = self.__rpc_ctrl.step_info.global_step + 1
+        s = f"The next step is epoch {epoch}/{self.config.exp_ctrl.total_train_epochs} "
+        s += f"step {epoch_step}/{self._steps_per_epoch} "
+        s += f"(global step {global_step}). "
+        s += f"Should save a checkpoint for recover? {self.__rpc_ctrl.should_ckpt}. "
+        s += f"Should save a persistent checkpoint for evaluation? {self.__rpc_ctrl.should_save}. "
+        s += f"Should run evaluation? {self.__rpc_ctrl.should_eval}. "
+        self.logger.info(s)
+
        # Traverse over the dataflow graph for once.
        self.func_executor.execute_step()

@@ -350,8 +377,6 @@ class MasterWorker(worker_base.Worker):
        if self.__rpc_ctrl.should_save or self.__rpc_ctrl.should_ckpt:
            self.__last_step_info = copy.deepcopy(self.__rpc_ctrl.step_info)

-        self.__rpc_ctrl.used_hash_vals_this_epoch += list(self.__rpc_ctrl.ids_to_clear)
-
        if is_epoch_last_step:
            self.__rpc_ctrl.used_hash_vals_this_epoch = (
                self.__rpc_ctrl.used_hash_vals_this_epoch[self._dataset_size :]
@@ -431,8 +456,8 @@ class MasterWorker(worker_base.Worker):
        s += f"Estimated remaining time: {remain_t:.3f}s. "
        if flops is not None:
            s += f"TFLOP/s per GPU: {tflops_per_gpu:.2f}, total TFLOP/s: {tflops:.2f}."
-        logger.info(s)
-        logger.info(
+        self.logger.info(s)
+        self.logger.info(
            f"Time taken so far across all configurations: {time.perf_counter() - self.global_exp_tik:.2f}s"
        )

@@ -486,6 +511,7 @@ class MasterWorker(worker_base.Worker):
            save_ctl_info=self.__save_ctl.state_dict(),
            ckpt_ctl_info=self.__ckpt_ctl.state_dict(),
            eval_ctl_info=self.__eval_ctl.state_dict(),
+            data_loading_dp_idx=self.func_executor.data_loading_dp_idx,
            hash_vals_to_ignore=self.__rpc_ctrl.used_hash_vals_this_epoch,
        )
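Two recover-related ideas appear in the master worker hunks above: the master seeds its RNG at base_seed plus the number of model workers (presumably to stay clear of per-worker seeds), and the DP index used for round-robin data loading is saved in the recover info so a restarted run resumes where it left off. A toy sketch of both; the helper names are illustrative, not realhf's API.

import random


def set_master_seed(base_seed: int, n_model_workers: int) -> None:
    # Offsetting by the worker count keeps the master's seed distinct from
    # per-worker seeds derived from the same base.
    random.seed(base_seed + n_model_workers)


def next_dp_rank(current_idx: int, dp_size: int) -> int:
    # Round-robin over DP ranks; this index is what would be checkpointed
    # as data_loading_dp_idx.
    return (current_idx + 1) % dp_size


set_master_seed(base_seed=1, n_model_workers=8)
idx = 0
for _ in range(5):
    idx = next_dp_rank(idx, dp_size=4)
assert idx == 1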
@@ -39,6 +39,7 @@ class RPCCorountineControl:
    # For flushing requests
    topo_level_count: asyncio.Queue

+    lock: asyncio.Lock
    # for training data management and data cleaning after each step
    ids_to_clear: Set[Hashable] = dataclasses.field(default_factory=set)
    flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
@@ -60,7 +61,7 @@ class ModelFunctionCall:
        src_rpc: dfg.MFCDef,
        stream: request_reply_stream.NameResolvingRequestClient,
        msid2mwid: Dict[config_pkg.ModelShardID, int],
-        model_topos: Dict[str, topology.PipeModelDataParallelTopology],
+        model_topos: Dict[str, topology.ProcessTopology],
        model_configs: Dict[str, None | ReaLModelConfig],
        ctrl: RPCCorountineControl,
        buffer: AsyncIOSequenceBuffer,
@@ -321,7 +322,8 @@ class ModelFunctionCall:
            for i in range(self.dp_size)
        ]

-        ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])
+        async with ctrl.lock:
+            ctrl.flops_counter.add_rpc(rpc, sample, self.model_configs[rpc.model_name])

        # logger.info(f"Model rpc {rpc.name} requesting.")

@@ -371,7 +373,7 @@ class ModelFunctionCall:
                is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
                gpu_id = self.msid2mwid[h]
                for key in rpc.input_keys:
-                    self.redistrib_planner.storage_tracker.add_data(
+                    await self.redistrib_planner.storage_tracker.add_data(
                        gpu_id, partitioned_ids[h.dp_rank], key=key, is_owner=is_dp_head
                    )
        else:
@@ -379,18 +381,19 @@ class ModelFunctionCall:
                if step.comm_type == "scatter":
                    for gpu_id, ids in zip(step.dsts, step.ids):
                        for key in step.keys:
-                            self.redistrib_planner.storage_tracker.add_data(
+                            await self.redistrib_planner.storage_tracker.add_data(
                                gpu_id, ids, key=key, is_owner=False
                            )
                elif step.comm_type == "gather":
                    for key in step.keys:
-                        self.redistrib_planner.storage_tracker.add_data(
+                        await self.redistrib_planner.storage_tracker.add_data(
                            step.root,
                            list(itertools.chain.from_iterable(step.ids)),
                            key=key,
                            is_owner=False,
                        )

+        await asyncio.sleep(0)
        # send partitioned data to model workers
        req_ids, other_req_ids = self.request(
            data_transfer_plan=data_transfer_plan,
@@ -425,7 +428,7 @@ class ModelFunctionCall:
                )
                gpu_id = self.msid2mwid[h]
                for k in rpc.output_keys:
-                    self.redistrib_planner.storage_tracker.add_data(
+                    await self.redistrib_planner.storage_tracker.add_data(
                        gpu_id,
                        x.ids,
                        key=k,
@@ -455,7 +458,8 @@ class ModelFunctionCall:
        # update the train counter.
        # Otherwise, amend data in the buffer.
        if rpc.is_dst:
-            ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
+            async with ctrl.lock:
+                ctrl.ids_to_clear = ctrl.ids_to_clear.union(sample.ids)
            await ctrl.train_count.put(1)
        else:
            logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
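Several hunks above wrap mutations of RPCCorountineControl state (flops_counter, ids_to_clear) in async with ctrl.lock. A minimal, generic sketch of that pattern, coroutine-shared state serialized by an asyncio.Lock; the class and names here are illustrative only.

import asyncio
from typing import Set


class SharedIds:
    """Toy stand-in for coroutine-shared control state guarded by a lock."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self.ids_to_clear: Set[int] = set()

    async def mark_done(self, ids: Set[int]) -> None:
        # Concurrent coroutines are serialized here, so the set update is
        # atomic with respect to other lock holders.
        async with self.lock:
            self.ids_to_clear |= ids


async def main() -> None:
    state = SharedIds()
    await asyncio.gather(*(state.mark_done({i}) for i in range(8)))
    assert state.ids_to_clear == set(range(8))


if __name__ == "__main__":
    asyncio.run(main())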
@@ -28,10 +28,9 @@ import torch.distributed
import torch.distributed as dist
import torch.utils.data

-import realhf.api.core.dfg as dfg
-import realhf.api.core.system_api as system_api
import realhf.impl.model.comm.global_comm as global_comm
import realhf.impl.model.comm.param_realloc as param_realloc_comm
+from realhf.api.core import data_api, dfg, model_api, system_api
from realhf.api.core.config import ModelName
from realhf.base import (
    constants,
@@ -56,8 +55,8 @@ from realhf.system.data_manager import DataManager
from realhf.system.redistributor import RedistribStep

# NOTE: Register all implemented datasets and models.
-import realhf.api.core.data_api as data_api # isort:skip
-import realhf.api.core.model_api as model_api # isort:skip
import realhf.impl.dataset # isort:skip
import realhf.impl.model # isort:skip

logger = logging.getLogger("Model Worker", "colored")
blogger = logging.getLogger("benchmark")
@@ -111,12 +110,6 @@ class ModelWorker(worker_base.Worker):

        self.config = cfg
        self.model_names = [s.id.model_name for s in cfg.shards]
-        self.shard_indices = [
-            cfg.model_topos[s.id.model_name].get_rank(
-                data=s.id.dp_rank, pipe=s.id.pp_rank, model=s.id.mp_rank
-            )
-            for s in cfg.shards
-        ]

        self.__experiment_name = self.config.worker_info.experiment_name
        self.__trial_name = self.config.worker_info.trial_name
@@ -516,7 +509,6 @@ class ModelWorker(worker_base.Worker):
                    f"RPC hook {hook} CPU time {time.perf_counter() - tik:.4f}s."
                )
        if constants.use_cuda():
-            # FIXME: temporary synchronize for debugging
            torch.cuda.synchronize()
        return ret
@@ -933,16 +925,17 @@ class ModelWorker(worker_base.Worker):
        if constants.use_cuda():
            # Monitoring info. There's an all-gather and an all-reduce
            # over the parallelism group in this function.
            # FIXME: temporary synchronize for debugging
            torch.cuda.synchronize()
-            if self._model.backend_name != "vllm":
-                # Since vLLM allocates GPU memory in advance, it is very
+            if (
+                self._model.backend_name != "vllm"
+                and self._model.backend_name != "sglang"
+            ):
+                # Since vLLM/SGLang allocates GPU memory in advance, it is very
                # easy to exceed the 0.95 threshold that triggers a kill.
-                # We omit GPU stats logging for vLLM.
+                # We omit GPU stats logging for vLLM/SGLang.
                self.__log_gpu_stats(request)

        self._clear_memory()
        # FIXME: temporary synchronize for debugging
        if constants.use_cuda():
            torch.cuda.synchronize()
        dist.barrier(group=constants.parallelism_group())
@@ -973,8 +966,8 @@ class ModelWorker(worker_base.Worker):
        from_model_name: ModelName = hook_data["from_model_name"]
        to_model_name: ModelName = hook_data["to_model_name"]

-        from_topo: topology.PipeModelDataParallelTopology = hook_data["from_topo"]
-        to_topo: topology.PipeModelDataParallelTopology = hook_data["to_topo"]
+        from_topo: topology.ProcessTopology = hook_data["from_topo"]
+        to_topo: topology.ProcessTopology = hook_data["to_topo"]

        # NOTE: For the convenience of future developement, we
        # run parameter reallocation with disk save-load by default.
@@ -1121,22 +1114,21 @@ class ModelWorker(worker_base.Worker):
    def __load_model(self, hook_data: Dict):
        tik = time.perf_counter()
        with constants.model_scope(hook_data["model_name"]):
-            from realhf.impl.model.backend.vllm import (
-                vLLMGenerationBackend,
-                vLLMGenerationEngine,
-            )
-
            if isinstance(self._model.module, torch.nn.Identity) and isinstance(
-                self._backend, vLLMGenerationBackend
+                self._backend,
+                (
+                    model_api.ALL_BACKEND_CLASSES["sglang"],
+                    model_api.ALL_BACKEND_CLASSES["vllm"],
+                ),
            ):
-                # The uninitialized vLLM model. Since we create the model
-                # inside the vLLM backend, the initial param realloc before
+                # The uninitialized vLLM/SGLang model. Since we create the model
+                # inside the vLLM/SGLang backend, the initial param realloc before
                # backend initialization can be ignored.
                return
-            if self._model.backend_name == "vllm":
+            if self._model.backend_name in ["vllm", "sglang"]:
                if constants.parallelism_rank() == 0:
-                    logger.info("Updating vLLM model from disk.")
-                module: vLLMGenerationEngine = self._model.module
+                    logger.info(f"Updating {self._model.backend_name} model from disk.")
+                module = self._model.module
                module.update_weights_from_disk(hook_data["load_dir"])
            else:
                module: ReaLModel = self.__unwrapped_models[hook_data["model_name"]]
@@ -1147,7 +1139,7 @@ class ModelWorker(worker_base.Worker):
            t = torch.tensor(
                float(time.perf_counter() - tik),
                dtype=torch.float64,
-                device=module.device,
+                device=constants.current_device(),
            )
            dist.all_reduce(
                t, op=dist.ReduceOp.MAX, group=constants.parallelism_group()
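The __load_model hunk swaps a hard-coded vLLM class check for a lookup into a backend registry keyed by name, so SGLang gets the same treatment. A generic sketch of that registry-dispatch idea; the concrete classes and registry here are illustrative, not realhf's model_api.

from typing import Dict, Type


class GenerationBackend:
    # Base class for inference backends that own their weights in-process.
    def update_weights_from_disk(self, load_dir: str) -> None:
        raise NotImplementedError


class VLLMBackend(GenerationBackend):
    def update_weights_from_disk(self, load_dir: str) -> None:
        print(f"vLLM-style backend reloading weights from {load_dir}")


class SGLangBackend(GenerationBackend):
    def update_weights_from_disk(self, load_dir: str) -> None:
        print(f"SGLang-style backend reloading weights from {load_dir}")


ALL_BACKEND_CLASSES: Dict[str, Type[GenerationBackend]] = {
    "vllm": VLLMBackend,
    "sglang": SGLangBackend,
}


def reload_from_disk(backend_name: str, backend: GenerationBackend, load_dir: str) -> bool:
    # Registry membership plus an isinstance check against all registered
    # classes replaces a check against one concrete backend class.
    if backend_name in ALL_BACKEND_CLASSES and isinstance(
        backend, tuple(ALL_BACKEND_CLASSES.values())
    ):
        backend.update_weights_from_disk(load_dir)
        return True
    return False


assert reload_from_disk("sglang", SGLangBackend(), "/tmp/ckpt") is True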
@@ -1,5 +1,6 @@
# Copyright 2025 Ant Group Inc.

+import asyncio
import dataclasses
import itertools
import os
@@ -11,12 +12,24 @@ from realhf.base.cluster import spec as cluster_spec

class GlobalStorageTracker:
    def __init__(self, world_size: int):
+        self.lock = asyncio.Lock()
        self.storages: List[Dict[Hashable, List[str]]]
        self.storages = [{} for _ in range(world_size)]
        self.data_owner: Dict[Tuple[Hashable, str], int]
        self.data_owner = {}

-    def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
+    async def add_data(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
+        async with self.lock:
+            for data_id in ids:
+                if data_id not in self.storages[rank]:
+                    self.storages[rank][data_id] = [key]
+                else:
+                    if key not in self.storages[rank][data_id]:
+                        self.storages[rank][data_id].append(key)
+                if is_owner:
+                    self.data_owner[(data_id, key)] = rank
+
+    def add_data_synced(self, rank: int, ids: List[Hashable], key: str, is_owner: bool):
        for data_id in ids:
            if data_id not in self.storages[rank]:
                self.storages[rank][data_id] = [key]
@@ -26,15 +39,16 @@ class GlobalStorageTracker:
            if is_owner:
                self.data_owner[(data_id, key)] = rank

-    def clear_data(self, ids: List[Hashable]):
-        for storage in self.storages:
-            for i in ids:
-                if i in storage:
-                    storage.pop(i)
-        keys = list(self.data_owner.keys())
-        for i, k in keys:
-            if i in ids:
-                self.data_owner.pop((i, k))
+    async def clear_data(self, ids: List[Hashable]):
+        async with self.lock:
+            for storage in self.storages:
+                for i in ids:
+                    if i in storage:
+                        storage.pop(i)
+            keys = list(self.data_owner.keys())
+            for i, k in keys:
+                if i in ids:
+                    self.data_owner.pop((i, k))


@dataclasses.dataclass
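GlobalStorageTracker now exposes an async add_data that takes the tracker's lock, plus an add_data_synced variant for single-threaded callers such as the test changed below. A compact sketch of that async/sync pairing with a toy tracker; the names are illustrative.

import asyncio
from typing import Dict, Hashable, List


class TinyTracker:
    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self.owner: Dict[Hashable, int] = {}

    async def add(self, rank: int, ids: List[Hashable]) -> None:
        # Coroutine-safe path: concurrent callers are serialized by the lock.
        async with self.lock:
            self._add(rank, ids)

    def add_synced(self, rank: int, ids: List[Hashable]) -> None:
        # Lock-free path, safe only when no event loop mutates the tracker
        # concurrently (e.g. in synchronous test setup).
        self._add(rank, ids)

    def _add(self, rank: int, ids: List[Hashable]) -> None:
        for i in ids:
            self.owner[i] = rank


async def main() -> None:
    t = TinyTracker()
    await asyncio.gather(t.add(0, ["a"]), t.add(1, ["b"]))
    t.add_synced(2, ["c"])
    assert t.owner == {"a": 0, "b": 1, "c": 2}


if __name__ == "__main__":
    asyncio.run(main())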
@@ -186,7 +186,7 @@ def _test_data_transfer(
                topo=from_topo,
            )
        ]
-        storage_tracker.add_data(
+        storage_tracker.add_data_synced(
            gpu_id,
            ids=[i + dp_rank * world_size for i in range(world_size)],
            key=key,
@@ -18,11 +18,7 @@ def model_class(request):
    return request.param


-def run_model_worker(
-    cfg,
-    mw,
-    barrier,
-):
+def run_model_worker(cfg, mw, barrier, expr_name=None):
    constants.set_force_cpu(True)
    # Register all datasets and models
    import realhf.impl.dataset # isort: skip
@@ -31,7 +27,7 @@ def run_model_worker(
    from realhf.system.model_worker import ModelWorker

    system_api.ALL_EXPERIMENT_CLASSES = {}
-    register_experiment(testing._DEFAULT_EXPR_NAME, lambda: cfg)
+    register_experiment(expr_name or testing._DEFAULT_EXPR_NAME, lambda: cfg)

    worker = ModelWorker()
    logger.info("Configuring model worker...")
@@ -61,7 +57,7 @@ def run_test_exp(
    from realhf.system.master_worker import MasterWorker

    system_api.ALL_EXPERIMENT_CLASSES = {}
-    register_experiment(testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)
+    register_experiment(expr_name or testing._DEFAULT_EXPR_NAME, lambda: exp_cfg)

    # Get worker configurations
    exp_setup = exp_cfg.initial_setup()
@@ -87,6 +83,7 @@ def run_test_exp(
                cfg=exp_cfg,
                mw=mw,
                barrier=barrier,
+                expr_name=expr_name,
            )
            for mw in exp_setup.model_worker
        ],