mirror of https://github.com/inclusionAI/AReaL
fix: dataloader shuffle
different seeds for every epoch & every rank
This commit is contained in:
parent
241185227d
commit
22f357b31c
|
@ -892,19 +892,22 @@ def register_dataloader(name, dataloader_cls):
|
|||
|
||||
|
||||
def make_dataloader(
|
||||
cfg: Union[str, config_api.DataLoaderAbstraction], dataset: torch.utils.data.Dataset
|
||||
cfg: Union[str, config_api.DataLoaderAbstraction], dataset: torch.utils.data.Dataset, seed_offset: Optional[int] = None
|
||||
) -> torch.utils.data.DataLoader:
|
||||
if isinstance(cfg, str):
|
||||
cfg = config_api.DataLoaderAbstraction(type_=cfg)
|
||||
dataloader_cls = ALL_DATALOADER_CLASSES[cfg.type_]
|
||||
return dataloader_cls(dataset, **cfg.args)
|
||||
if seed_offset is None:
|
||||
return dataloader_cls(dataset, **cfg.args)
|
||||
else:
|
||||
return dataloader_cls(dataset, **cfg.args, seed_offset=seed_offset)
|
||||
|
||||
|
||||
def PackedDataLoader(dataset, *args, **kwargs):
|
||||
def PackedDataLoader(dataset, *args, seed_offset: int = 0, **kwargs):
|
||||
if not isinstance(getattr(dataset, "util", None), DatasetUtility):
|
||||
raise ValueError("Dataset must have a `util` attribute of type DatasetUtility.")
|
||||
g = torch.Generator()
|
||||
g.manual_seed(dataset.util.seed)
|
||||
g.manual_seed(dataset.util.seed + dist.get_rank() + seed_offset)
|
||||
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = torch.initial_seed() % 2**32
|
||||
|
|
|
@ -105,6 +105,7 @@ class NoRequestToHandle(Exception):
|
|||
|
||||
class ModelWorker(worker_base.Worker):
|
||||
_setup_counter = -1
|
||||
_seed_offset = 0
|
||||
|
||||
def _configure(self, cfg: system_api.ModelWorker):
|
||||
self._setup_counter += 1
|
||||
|
@ -601,8 +602,9 @@ class ModelWorker(worker_base.Worker):
|
|||
self.__dataset.active_indices,
|
||||
)
|
||||
self.__dataloader = data_api.make_dataloader(
|
||||
self.config.dataloader, self.__dataset
|
||||
self.config.dataloader, self.__dataset, seed_offset=self._seed_offset,
|
||||
)
|
||||
self._seed_offset += 1
|
||||
self.__data_generator = enumerate(self.__dataloader)
|
||||
|
||||
# Fetch.
|
||||
|
|
Loading…
Reference in New Issue