polish mpi dataset shuffle list

This commit is contained in:
Dun Liang 2021-08-24 17:22:39 +08:00
parent 4217359f86
commit 1f16af33a2
4 changed files with 46 additions and 3 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.92'
__version__ = '1.2.3.93'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -90,6 +90,7 @@ class Dataset(object):
self.epoch_id = 0
self.sampler = None
self._disable_workers = False
self._shuffle_rng = np.random.default_rng(1)
def __getitem__(self, index):
raise NotImplementedError
@ -389,7 +390,10 @@ Example::
elif self.shuffle == False:
index_list = get_order_list(self.total_len)
else:
index_list = get_random_list(self.total_len)
# using _shuffle_rng to generate multiprocess
# consist shuffle list
# index_list = get_random_list(self.total_len)
index_list = self._shuffle_rng.permutation(range(self.total_len))
# scatter index_list for all mpi process
# scatter rule:

View File

@ -217,7 +217,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
return;
}
if (signal == SIGCHLD) {
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM) {
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM && _pid == getpid()) {
LOGe << "Caught SIGCHLD"
<< "si_errno:" << si->si_errno
<< "si_code:" << si->si_code

View File

@ -254,6 +254,45 @@ for d in dataset:
assert "quick exit" in s
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
def test_dataset_shuffle_mpi(self):
src = """
import jittor as jt
from jittor.dataset import Dataset
import numpy as np
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160, shuffle=True)
def __getitem__(self, k):
return k
dataset = YourDataset()
dataset.set_attrs(num_workers=2)
for d in dataset:
for a in d:
print("CHECK: ", a.item())
"""
fname = os.path.join(jt.flags.cache_path, "test_dataset_shuffle_mpi.py")
with open(fname, 'w') as f:
f.write(src)
import subprocess as sp
import sys
cmd = sys.executable + " " + fname
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = mpirun_path + " -np 2 " + cmd
print(cmd)
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
s = r.stdout.decode()
# print(s)
st = set([ l for l in s.splitlines() if l.startswith("CHECK:") ])
assert r.returncode == 0
# print(st)
assert len(st) == 160, len(st)
def test_children_died2(self):
src = """
import jittor as jt