fix: dataloader shuffle

Use a different shuffle seed for every epoch and every rank.
This commit is contained in:
wanghuaijie.whj 2025-02-28 10:16:44 +08:00
parent 241185227d
commit 22f357b31c
2 changed files with 10 additions and 5 deletions

View File

@@ -892,19 +892,22 @@ def register_dataloader(name, dataloader_cls):
def make_dataloader(
    cfg: Union[str, config_api.DataLoaderAbstraction],
    dataset: torch.utils.data.Dataset,
    seed_offset: Optional[int] = None,
) -> torch.utils.data.DataLoader:
    """Construct a registered dataloader for ``dataset``.

    Args:
        cfg: Either the registered dataloader type name, or a full
            ``DataLoaderAbstraction`` config carrying ``type_`` and ``args``.
        dataset: The dataset the dataloader will iterate.
        seed_offset: When given, forwarded to the dataloader class so it can
            derive a different shuffle seed (e.g. per epoch). ``None`` omits
            the kwarg entirely, so dataloader classes that do not accept
            ``seed_offset`` keep working unchanged.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    if isinstance(cfg, str):
        cfg = config_api.DataLoaderAbstraction(type_=cfg)
    dataloader_cls = ALL_DATALOADER_CLASSES[cfg.type_]
    # Build the optional kwarg once instead of duplicating the constructor
    # call in an if/else branch; only pass `seed_offset` when explicitly
    # provided for backward compatibility with existing dataloader classes.
    extra = {} if seed_offset is None else {"seed_offset": seed_offset}
    return dataloader_cls(dataset, **cfg.args, **extra)
def PackedDataLoader(dataset, *args, **kwargs):
def PackedDataLoader(dataset, *args, seed_offset: int = 0, **kwargs):
if not isinstance(getattr(dataset, "util", None), DatasetUtility):
raise ValueError("Dataset must have a `util` attribute of type DatasetUtility.")
g = torch.Generator()
g.manual_seed(dataset.util.seed)
g.manual_seed(dataset.util.seed + dist.get_rank() + seed_offset)
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32

View File

@@ -105,6 +105,7 @@ class NoRequestToHandle(Exception):
class ModelWorker(worker_base.Worker):
_setup_counter = -1
_seed_offset = 0
def _configure(self, cfg: system_api.ModelWorker):
self._setup_counter += 1
@@ -601,8 +602,9 @@ class ModelWorker(worker_base.Worker):
self.__dataset.active_indices,
)
self.__dataloader = data_api.make_dataloader(
self.config.dataloader, self.__dataset
self.config.dataloader, self.__dataset, seed_offset=self._seed_offset,
)
self._seed_offset += 1
self.__data_generator = enumerate(self.__dataloader)
# Fetch.