PullRequest: 4 Fix the dataloader shuffle and random seed issues.

Merge branch fw/fix-dataloading-not-shuffle of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/4

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fw/fix-dataloading-not-shuffle
* .
* .
* .
This commit is contained in:
博惟 2025-02-28 14:56:47 +08:00
parent d04f17031a
commit e7c4a49adc
12 changed files with 221 additions and 271 deletions

View File

@@ -16,12 +16,6 @@ class DatasetAbstraction:
args: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class DataLoaderAbstraction:
type_: str = "default"
args: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class ModelWrapperAbstraction:
type_: str
@@ -187,10 +181,6 @@ class StandaloneModelShardAbstraction:
model: ModelAbstraction
backend: ModelBackendAbstraction
# evaluation
eval_datasets: Optional[List[DatasetAbstraction]] = None
eval_dataloader: Optional[DataLoaderAbstraction] = dataclasses.field(
default_factory=lambda: DataLoaderAbstraction(
"packed_eval", args=dict(batch_size=128)
)
)
eval_dataset: Optional[DatasetAbstraction] = None
eval_bs: int = 128
should_instantiate: bool = True
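
The evaluation plumbing is simplified here: instead of a list of eval datasets plus a "packed_eval" DataLoaderAbstraction, a shard now carries a single optional eval_dataset and a plain eval_bs. A minimal sketch of filling in the new fields, assuming the realhf package is installed; the prompt_answer arguments mirror the SFT config further down, while the path and length values are illustrative:

```python
from realhf.api.core.config import DatasetAbstraction

# Single eval dataset plus a plain batch size, replacing the old
# eval_datasets list and the "packed_eval" DataLoaderAbstraction.
eval_dataset = DatasetAbstraction(
    "prompt_answer",
    args=dict(max_length=1024, dataset_path="/path/to/valid.jsonl"),
)
eval_bs = 128
```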

View File

@@ -881,64 +881,3 @@ def make_dataset(
logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")
return dataset
ALL_DATALOADER_CLASSES = {}
def register_dataloader(name, dataloader_cls):
assert name not in ALL_DATALOADER_CLASSES
ALL_DATALOADER_CLASSES[name] = dataloader_cls
def make_dataloader(
cfg: Union[str, config_api.DataLoaderAbstraction], dataset: torch.utils.data.Dataset, seed_offset: Optional[int] = None
) -> torch.utils.data.DataLoader:
if isinstance(cfg, str):
cfg = config_api.DataLoaderAbstraction(type_=cfg)
dataloader_cls = ALL_DATALOADER_CLASSES[cfg.type_]
if seed_offset is None:
return dataloader_cls(dataset, **cfg.args)
else:
return dataloader_cls(dataset, **cfg.args, seed_offset=seed_offset)
def PackedDataLoader(dataset, *args, seed_offset: int = 0, **kwargs):
if not isinstance(getattr(dataset, "util", None), DatasetUtility):
raise ValueError("Dataset must have a `util` attribute of type DatasetUtility.")
g = torch.Generator()
g.manual_seed(dataset.util.seed + dist.get_rank() + seed_offset)
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
return torch.utils.data.DataLoader(
dataset,
*args,
collate_fn=SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=10240,
shuffle=True,
generator=g,
worker_init_fn=seed_worker,
**kwargs,
)
def PackedEvalDataLoader(dataset, *args, **kwargs):
if not isinstance(getattr(dataset, "util", None), DatasetUtility):
raise ValueError("Dataset must have a `util` attribute of type DatasetUtility.")
return torch.utils.data.DataLoader(
dataset,
*args,
collate_fn=SequenceSample.gather,
shuffle=False,
**kwargs,
)
register_dataloader("packed", PackedDataLoader)
register_dataloader("packed_eval", PackedEvalDataLoader)

View File

@@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import realhf.api.core.dfg as dfg
from realhf.api.core.config import (
DataLoaderAbstraction,
DatasetAbstraction,
ModelAbstraction,
ModelName,
@@ -114,17 +113,13 @@ class WorkerInformation:
@dataclasses.dataclass
class ModelWorker:
seed: int
base_seed: int
shards: List[StandaloneModelShardAbstraction]
# dataset, for source model workers
tokenizer_name_or_path: Optional[str] = None
datasets: Optional[List[Union[str, DatasetAbstraction]]] = None
dataloader: Union[str, DataLoaderAbstraction] = "packed"
use_dataset_cache: bool = False
dataset_cahce_root: str = constants.DATASET_CACHE_PATH
# cuda & cudnn config
cudnn_benchmark: bool = False
cudnn_deterministic: bool = False
cuda_cache_cleanliness: bool = True
cuda_cache_clear_freq: int = 10
torch_cache_mysophobia: bool = False
@@ -210,12 +205,13 @@ class ExperimentSaveEvalControl:
@dataclasses.dataclass
class MasterWorker:
base_seed: int
exp_ctrl: ExperimentSaveEvalControl
# main components
n_model_workers: int
model_rpcs: List[dfg.MFCDef] = None
model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = None
msid2mwid: Dict[ModelShardID, int] = None
msid2mwid: Dict[ModelShardID | str, int] = None
data_transfer_pairs: List[Tuple[str, str]] = None
sync_param_pairs: List[Tuple[str, str]] = None
worker_info: Optional[WorkerInformation] = None
@@ -263,7 +259,11 @@ class ExperimentConfig:
def __post_init__(self):
self.master_worker = [
MasterWorker(exp_ctrl=self.exp_ctrl, n_model_workers=len(self.model_worker))
MasterWorker(
base_seed=self.model_worker[0].base_seed,
exp_ctrl=self.exp_ctrl,
n_model_workers=len(self.model_worker),
)
]
def lazy_init(self):

View File

@@ -2,17 +2,29 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
import random
import numpy as np
import torch
import transformers
_SEED = None
def set_random_seed(seed):
global _SEED
_SEED = seed
os.environ["PYTHONHASHSEED"] = str(seed)
transformers.set_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_seed() -> int:
global _SEED
return _SEED
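
The seeding module now remembers the last seed it set, so components created later in the process (for example the vLLM backend below) can read it back instead of threading a seed field through every config. A usage sketch, assuming the realhf package is importable; the concrete base_seed and worker_index values are illustrative:

```python
from realhf.base import seeding

# Each process seeds everything once with a distinct value: base_seed + worker_index
# for model workers, base_seed + n_model_workers for the master (see the diffs below).
base_seed, worker_index = 42, 3
seeding.set_random_seed(base_seed + worker_index)  # seeds hashing, random, numpy, torch, transformers
assert seeding.get_seed() == base_seed + worker_index  # later consumers read it back
```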

View File

@@ -18,7 +18,6 @@ from omegaconf import MISSING, OmegaConf
import realhf.base.logging as logging
from realhf.api.core.config import (
DataLoaderAbstraction,
DatasetAbstraction,
ModelAbstraction,
ModelBackendAbstraction,
@@ -256,22 +255,17 @@ class CommonExperimentConfig(Experiment):
return NotImplementedError(f"datasets is not implemented in {self.__class__}")
@property
def eval_datasets(self) -> List[DatasetAbstraction]:
"""A list of dataset configurations used for evaluation.
def eval_dataset(self) -> DatasetAbstraction | None:
"""The dataset configuration used for evaluation.
Can be None if runtime evaluation is not needed.
"""
return None
@property
def eval_dataloader(self) -> DataLoaderAbstraction:
"""The dataloader configuration used for evaluation.
Reserved to changed the evaluation batch size. Training does not
require this property because the batch size is handled in MFC
definitions.
"""
return DataLoaderAbstraction("packed_eval", args=dict(batch_size=128))
def eval_bs(self) -> int:
"""The batch size for runtime evaluation."""
return 128
@property
def tokenizer_name_or_path(self) -> str:
@@ -553,7 +547,7 @@ class CommonExperimentConfig(Experiment):
for i, j in itertools.product(range(self.n_nodes), range(self.n_gpus_per_node)):
mw = ModelWorker(
seed=self.seed,
base_seed=self.seed,
shards=[],
datasets=self.datasets,
torch_cache_mysophobia=self.torch_cache_mysophobia,
@@ -611,7 +605,6 @@ class CommonExperimentConfig(Experiment):
backend=ModelBackendAbstraction(
"vllm",
args=dict(
seed=self.seed,
model_path=model_cfg.path,
**vllm_dict_args,
),
@@ -691,7 +684,6 @@ class CommonExperimentConfig(Experiment):
backend = ModelBackendAbstraction(
"vllm",
args=dict(
seed=self.seed,
model_path=model_cfg.path,
**vllm_dict_args,
),
@@ -713,8 +705,8 @@ class CommonExperimentConfig(Experiment):
),
model=model,
backend=backend,
eval_datasets=self.eval_datasets,
eval_dataloader=self.eval_dataloader,
eval_dataset=self.eval_dataset,
eval_bs=self.eval_bs,
)
)
shard_counter[model_name] += 1

View File

@@ -5,7 +5,6 @@
import dataclasses
from realhf.api.core.config import (
DataLoaderAbstraction,
DatasetAbstraction,
ModelInterfaceAbstraction,
ModelInterfaceType,
@@ -82,22 +81,18 @@ class SFTConfig(CommonExperimentConfig):
]
@property
def eval_datasets(self):
return [
DatasetAbstraction(
"prompt_answer",
args=dict(
max_length=self.dataset.max_seqlen,
dataset_path=self.dataset.valid_path,
),
)
]
def eval_dataset(self):
return DatasetAbstraction(
"prompt_answer",
args=dict(
max_length=self.dataset.max_seqlen,
dataset_path=self.dataset.valid_path,
),
)
@property
def eval_dataloader(self):
return DataLoaderAbstraction(
"packed_eval", args=dict(batch_size=self.dataset.valid_bs_n_seqs)
)
def eval_bs(self) -> int:
return self.dataset.valid_bs_n_seqs
@property
def tokenizer_name_or_path(self):

View File

@@ -31,7 +31,7 @@ except ModuleNotFoundError:
from realhf.api.core import data_api, model_api
from realhf.api.quickstart.model import vLLMConfig
from realhf.base import constants, logging
from realhf.base import constants, logging, seeding
logger = logging.getLogger("vLLM backend")
@@ -165,7 +165,6 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):
@dataclasses.dataclass
class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
seed: int = 0
model_path: str = ""
def _initialize(
@@ -187,7 +186,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
skip_tokenizer_init=False,
trust_remote_code=True,
max_model_len=self.max_model_len,
seed=self.seed,
seed=seeding.get_seed(),
dtype=torch.float16,
kv_cache_dtype=self.kv_cache_type,
device=constants.current_device(),

View File

@@ -152,10 +152,7 @@ class SFTInterface(model_api.ModelInterface):
res = module.eval_batch(
input_=x.to_device(device),
loss_fn=compute_packed_sft_loss,
mb_spec=MicroBatchSpec(
n_mbs=constants.pipe_parallel_world_size(),
balanced_seqs=True,
),
mb_spec=MicroBatchSpec(),
)
if res is not None:

View File

@@ -0,0 +1,117 @@
import dataclasses
from typing import *
from realhf.api.core.data_api import SequenceSample
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base.monitor import (
caculuate_llama_forward_flops,
calculate_llama_gen_flops,
calculate_llama_train_flops,
)
@dataclasses.dataclass
class FlopsCounter:
train_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
train_bs: List[int] = dataclasses.field(default_factory=list)
train_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
inf_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
inf_bs: List[int] = dataclasses.field(default_factory=list)
inf_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
gen_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
gen_bs: List[int] = dataclasses.field(default_factory=list)
prompt_lens: List[List[int]] = dataclasses.field(default_factory=list)
gen_len: List[int] = dataclasses.field(default_factory=list)
def clear(self):
self.train_bs.clear()
self.train_seqlens.clear()
self.inf_bs.clear()
self.inf_seqlens.clear()
self.gen_bs.clear()
self.prompt_lens.clear()
self.gen_len.clear()
self.train_configs.clear()
self.inf_configs.clear()
self.gen_configs.clear()
def add_rpc(
self, rpc: MFCDef, sample: SequenceSample, model_config: ReaLModelConfig
):
# Record the data amount for each interface to compute FLOPs.
# Since the user may arbitrarily specify input/output keys,
# we can only try to find the most probable key name for computing FLOPs.
# If such keys do not exist, we will use the key with the longest
# sequence length in this model function call.
acc_seqlens = {
k: sum(sum(x) for x in slens) for k, slens in sample.seqlens.items()
}
seqlen_key = max(sample.seqlens, key=acc_seqlens.get)
flops_seqlens = [sum(x) for x in sample.seqlens[seqlen_key]]
if rpc.interface_type == ModelInterfaceType.GENERATE:
self.gen_configs.append(model_config)
self.gen_bs.append(sample.bs)
self.gen_len.append(
rpc.interface_impl.args["generation_config"]["min_new_tokens"]
)
self.prompt_lens.append(flops_seqlens)
elif rpc.interface_type == ModelInterfaceType.TRAIN_STEP:
self.train_configs.append(model_config)
self.train_bs.append(sample.bs)
self.train_seqlens.append(flops_seqlens)
elif rpc.interface_type == ModelInterfaceType.INFERENCE:
self.inf_configs.append(model_config)
self.inf_bs.append(sample.bs)
self.inf_seqlens.append(flops_seqlens)
def get_flops(self) -> int:
flops = 0
for train_bs, train_seqlens, real_config in zip(
self.train_bs,
self.train_seqlens,
self.train_configs,
):
flops += calculate_llama_train_flops(
checkpoint_activations_factor=4,
batch_size=train_bs,
seqlens=train_seqlens,
num_layers=real_config.n_layers,
hidden_size=real_config.hidden_dim,
intermediate_size=real_config.intermediate_dim,
vocab_size=real_config.vocab_size,
)
for inf_bs, inf_seqlens, real_config in zip(
self.inf_bs,
self.inf_seqlens,
self.inf_configs,
):
flops += caculuate_llama_forward_flops(
batch_size=inf_bs,
seqlens=inf_seqlens,
num_layers=real_config.n_layers,
hidden_size=real_config.hidden_dim,
intermediate_size=real_config.intermediate_dim,
vocab_size=real_config.vocab_size,
)
for gen_bs, prompt_lens, gen_len, real_config in zip(
self.gen_bs,
self.prompt_lens,
self.gen_len,
self.gen_configs,
):
flops += calculate_llama_gen_flops(
batch_size=gen_bs,
prompt_lens=prompt_lens,
gen_len=gen_len,
num_layers=real_config.n_layers,
hidden_size=real_config.hidden_dim,
intermediate_size=real_config.intermediate_dim,
vocab_size=real_config.vocab_size,
)
return flops
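
A hedged usage sketch of the new counter, mirroring how the master worker drives it in the diff below; the timing and GPU-count numbers are illustrative only:

```python
counter = FlopsCounter()
# Inside the RPC request loop (see the master_worker diff below):
#     counter.add_rpc(rpc, sample, model_configs[rpc.model_name])
flops = counter.get_flops()
e2e_time, n_gpus = 12.5, 8            # illustrative numbers
tflops_per_gpu = flops / (e2e_time * n_gpus * 1e12)
counter.clear()                       # reset the accumulators before the next step
```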

View File

@@ -11,6 +11,7 @@ import getpass
import itertools
import os
import pprint
import random
import re
import time
import uuid
@@ -40,6 +41,7 @@ from realhf.base import (
logging,
name_resolve,
names,
seeding,
timeutil,
topology,
)
@@ -48,12 +50,8 @@ from realhf.base.asyncio_utils import (
setup_run_until_complete,
teardown_run_util_complete,
)
from realhf.base.monitor import (
caculuate_llama_forward_flops,
calculate_llama_gen_flops,
calculate_llama_train_flops,
)
from realhf.system.buffer import AsyncIOSequenceBuffer
from realhf.system.flops_counter import FlopsCounter
logger = logging.getLogger("master worker", "system")
blogger = logging.getLogger("benchmark")
@@ -107,37 +105,6 @@ def _attach_param_realloc_hooks(
return payload
@dataclasses.dataclass
class InterfaceDataAmount:
train_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
train_bs: List[int] = dataclasses.field(default_factory=list)
train_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
inf_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
inf_bs: List[int] = dataclasses.field(default_factory=list)
inf_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
gen_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
gen_bs: List[int] = dataclasses.field(default_factory=list)
prompt_lens: List[List[int]] = dataclasses.field(default_factory=list)
gen_len: List[int] = dataclasses.field(default_factory=list)
def clear(self):
self.train_bs.clear()
self.train_seqlens.clear()
self.inf_bs.clear()
self.inf_seqlens.clear()
self.gen_bs.clear()
self.prompt_lens.clear()
self.gen_len.clear()
self.train_configs.clear()
self.inf_configs.clear()
self.gen_configs.clear()
@dataclasses.dataclass
class RPCCorountineControl:
## Shared resources ##
@@ -157,9 +124,7 @@ class RPCCorountineControl:
# for training data management and data cleaning after each step
ids_to_clear: Set[int] = dataclasses.field(default_factory=set)
data_amount: InterfaceDataAmount = dataclasses.field(
default_factory=InterfaceDataAmount
)
flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)
should_save: bool = False
should_eval: bool = False
@@ -408,31 +373,7 @@ async def model_rpc_request_func(
buf_indices, sample = await buffer.get_batch_for_rpc(rpc)
# Record the data amount for each interface to compute FLOPs.
# Since the user may arbitrarily specify input/output keys,
# we can only try to find the most probable key name for computing FLOPs.
# If such keys do not exist, we will use the key with the longest
# sequence length in this model function call.
acc_seqlens = {
k: sum(sum(x) for x in slens) for k, slens in sample.seqlens.items()
}
seqlen_key = max(sample.seqlens, key=acc_seqlens.get)
flops_seqlens = [sum(x) for x in sample.seqlens[seqlen_key]]
if rpc.interface_type == dfg.ModelInterfaceType.GENERATE:
ctrl.data_amount.gen_configs.append(model_configs[rpc.model_name])
ctrl.data_amount.gen_bs.append(sample.bs)
ctrl.data_amount.gen_len.append(
rpc.interface_impl.args["generation_config"]["min_new_tokens"]
)
ctrl.data_amount.prompt_lens.append(flops_seqlens)
elif rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
ctrl.data_amount.train_configs.append(model_configs[rpc.model_name])
ctrl.data_amount.train_bs.append(sample.bs)
ctrl.data_amount.train_seqlens.append(flops_seqlens)
elif rpc.interface_type == dfg.ModelInterfaceType.INFERENCE:
ctrl.data_amount.inf_configs.append(model_configs[rpc.model_name])
ctrl.data_amount.inf_bs.append(sample.bs)
ctrl.data_amount.inf_seqlens.append(flops_seqlens)
ctrl.flops_counter.add_rpc(rpc, sample, model_configs[rpc.model_name])
this_rpc_consumed_seqs += sample.bs
@@ -599,6 +540,8 @@ class MasterWorker(worker_base.Worker):
def _configure(self, config: config_pkg.MasterWorker):
self.config = config
seeding.set_random_seed(self.config.base_seed + self.config.n_model_workers)
self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
config.model_topos
)
@@ -1248,12 +1191,9 @@ class MasterWorker(worker_base.Worker):
filtered_data.append(x)
all_data = filtered_data
# Reorder loaded (meta-)data and store them into the buffer.
# NOTE: The reordered indices prioritize longer sequences for detecting OOM errors early.
# reorder_indices, _ = datapack.reorder_to_balanced_batches(
# np.array(seqlens), src_rpc.n_seqs
# )
# all_data: List[data_api.SequenceSample] = [all_data[i] for i in reorder_indices]
# We load data in a round-robin manner across different DP ranks,
# so we also need to shuffle the data to fuse different dataset splits.
random.shuffle(all_data)
blogger.info(
f"Master worker loaded {len(all_data)} pieces of data. "
@@ -1287,52 +1227,10 @@ class MasterWorker(worker_base.Worker):
flops = None
tflops_per_gpu = float("inf")
else:
flops = 0
for train_bs, train_seqlens, real_config in zip(
self.__rpc_ctrl.data_amount.train_bs,
self.__rpc_ctrl.data_amount.train_seqlens,
self.__rpc_ctrl.data_amount.train_configs,
):
flops += calculate_llama_train_flops(
checkpoint_activations_factor=4,
batch_size=train_bs,
seqlens=train_seqlens,
num_layers=real_config.n_layers,
hidden_size=real_config.hidden_dim,
intermediate_size=real_config.intermediate_dim,
vocab_size=real_config.vocab_size,
)
for inf_bs, inf_seqlens, real_config in zip(
self.__rpc_ctrl.data_amount.inf_bs,
self.__rpc_ctrl.data_amount.inf_seqlens,
self.__rpc_ctrl.data_amount.inf_configs,
):
flops += caculuate_llama_forward_flops(
batch_size=inf_bs,
seqlens=inf_seqlens,
num_layers=real_config.n_layers,
hidden_size=real_config.hidden_dim,
intermediate_size=real_config.intermediate_dim,
vocab_size=real_config.vocab_size,
)
for gen_bs, prompt_lens, gen_len, real_config in zip(
self.__rpc_ctrl.data_amount.gen_bs,
self.__rpc_ctrl.data_amount.prompt_lens,
self.__rpc_ctrl.data_amount.gen_len,
self.__rpc_ctrl.data_amount.gen_configs,
):
flops += calculate_llama_gen_flops(
batch_size=gen_bs,
prompt_lens=prompt_lens,
gen_len=gen_len,
num_layers=real_config.n_layers,
hidden_size=real_config.hidden_dim,
intermediate_size=real_config.intermediate_dim,
vocab_size=real_config.vocab_size,
)
flops = self.__rpc_ctrl.flops_counter.get_flops()
tflops = flops / (e2e_time * (10**12))
tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12))
self.__rpc_ctrl.data_amount.clear()
self.__rpc_ctrl.flops_counter.clear()
#########################################
epoch = self.__rpc_ctrl.step_info.epoch + 1
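
A toy sketch of why the commented-out balanced reordering is replaced by a plain random.shuffle: samples arrive round-robin from the data-parallel ranks, so without shuffling, consecutive items in the buffer would come from the same rank and dataset split. The rank and sample names below are illustrative:

```python
import random

random.seed(42 + 4)  # base_seed + n_model_workers, as set in MasterWorker._configure
per_rank = [[f"rank{r}-sample{i}" for i in range(3)] for r in range(2)]
all_data = [x for step in zip(*per_rank) for x in step]  # round-robin arrival order
random.shuffle(all_data)                                 # fuse the per-rank streams
print(all_data)
```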

View File

@@ -105,7 +105,6 @@ class NoRequestToHandle(Exception):
class ModelWorker(worker_base.Worker):
_setup_counter = -1
_seed_offset = 0
def _configure(self, cfg: system_api.ModelWorker):
self._setup_counter += 1
@@ -126,10 +125,7 @@ class ModelWorker(worker_base.Worker):
self.__worker_index = cfg.worker_info.worker_index
torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
torch.backends.cudnn.deterministic = cfg.cudnn_deterministic
seeding.set_random_seed(cfg.seed)
seeding.set_random_seed(cfg.base_seed + self.__worker_index)
# Reveal process group identity of this worker to world.
gpu_utils.reveal_pg_identity(
@@ -308,7 +304,8 @@ class ModelWorker(worker_base.Worker):
datasets = [
data_api.make_dataset(
d,
self.config.seed,
# NOTE: we must use the same seed to ensure the same dataset split
self.config.base_seed,
self.__dataset_dp_rank,
self.__dataset_dp_size,
self.config.tokenizer_name_or_path,
@@ -327,10 +324,20 @@ class ModelWorker(worker_base.Worker):
else:
self.__dataset = torch.utils.data.ConcatDataset(datasets)
g = torch.Generator()
g.manual_seed(seeding.get_seed())
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=10240,
shuffle=True,
generator=g,
)
self.__raw_samples = []
for tmp_sample in data_api.make_dataloader(
self.config.dataloader, self.__dataset
):
for tmp_sample in self.__dataloader:
self.__raw_samples += tmp_sample.meta().unpack()
self.__models: Dict[ModelName, model_api.Model] = dict()
@@ -406,30 +413,27 @@ class ModelWorker(worker_base.Worker):
interface_impl[0]
)
if s.eval_datasets is not None and s.eval_dataloader is not None:
eval_datasets = [
data_api.make_dataset(
d,
self.config.seed,
s.id.dp_rank,
s.id.topo.get_dim("data"),
self.__models[s.id.model_name].tokenizer,
self.config.worker_info.experiment_name,
self.config.worker_info.trial_name,
cache_root=(
None
if not self.config.use_dataset_cache
else self.config.dataset_cahce_root
),
)
for d in s.eval_datasets
]
if len(eval_datasets) > 1:
eval_dataset = torch.utils.data.ConcatDataset(eval_datasets)
else:
eval_dataset = eval_datasets[0]
eval_dataloader = data_api.make_dataloader(
s.eval_dataloader, eval_dataset
if s.eval_dataset is not None:
eval_dataset = data_api.make_dataset(
s.eval_dataset,
# NOTE: we must use the same seed to ensure the same dataset split
self.config.base_seed,
s.id.dp_rank,
s.id.topo.get_dim("data"),
self.__models[s.id.model_name].tokenizer,
self.config.worker_info.experiment_name,
self.config.worker_info.trial_name,
cache_root=(
None
if not self.config.use_dataset_cache
else self.config.dataset_cahce_root
),
)
eval_dataloader = torch.utils.data.DataLoader(
eval_dataset,
batch_size=s.eval_bs,
collate_fn=data_api.SequenceSample.gather,
shuffle=False,
)
else:
eval_dataloader = None
@@ -601,10 +605,17 @@ class ModelWorker(worker_base.Worker):
dataset_indices_path,
self.__dataset.active_indices,
)
self.__dataloader = data_api.make_dataloader(
self.config.dataloader, self.__dataset, seed_offset=self._seed_offset,
g = torch.Generator()
g = g.set_state(self.__dataloader.generator.get_state())
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset,
collate_fn=data_api.SequenceSample.gather,
# NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers.
batch_size=10240,
shuffle=True,
generator=g,
)
self._seed_offset += 1
self.__data_generator = enumerate(self.__dataloader)
# Fetch.
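
A toy sketch of the epoch-boundary pattern above: rebuilding the DataLoader while carrying the torch.Generator state forward gives a fresh shuffle order each epoch that is still fully reproducible from the original seed. The dataset and batch size are stand-ins:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))
g = torch.Generator()
g.manual_seed(42)

orders = []
for _ in range(2):  # two "epochs"
    loader = DataLoader(dataset, batch_size=8, shuffle=True, generator=g)
    orders.append(next(iter(loader))[0].tolist())  # consuming the loader advances g
    new_g = torch.Generator()
    new_g.set_state(g.get_state())  # hand the advanced state to the next epoch's loader
    g = new_g

print(orders)  # two different permutations, both reproducible from seed 42
```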

View File

@@ -156,7 +156,7 @@ class NameResolvingRequestClient:
def request(
self,
handlers: List[str] | None = None,
handlers: List[str | int] | None = None,
handle_type: str | None = None,
datas: List[Any] | None = None,
payloads: List[Payload] | None = None,
@@ -222,7 +222,7 @@ class NameResolvingRequestClient:
def call(
self,
handlers: List[str] | None = None,
handlers: List[str | int] | None = None,
handle_type: str | None = None,
datas: List[Any] | None = None,
payloads: List[Payload] | None = None,