mirror of https://github.com/Jittor/Jittor
polish mpi dataset shuffle list
This commit is contained in:
parent
4217359f86
commit
1f16af33a2
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.92'
|
||||
__version__ = '1.2.3.93'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -90,6 +90,7 @@ class Dataset(object):
|
|||
self.epoch_id = 0
|
||||
self.sampler = None
|
||||
self._disable_workers = False
|
||||
self._shuffle_rng = np.random.default_rng(1)
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -389,7 +390,10 @@ Example::
|
|||
elif self.shuffle == False:
|
||||
index_list = get_order_list(self.total_len)
|
||||
else:
|
||||
index_list = get_random_list(self.total_len)
|
||||
# use _shuffle_rng so that multiple processes generate
|
||||
# a consistent shuffle list
|
||||
# index_list = get_random_list(self.total_len)
|
||||
index_list = self._shuffle_rng.permutation(range(self.total_len))
|
||||
|
||||
# scatter index_list for all mpi process
|
||||
# scatter rule:
|
||||
|
|
|
@ -217,7 +217,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
|||
return;
|
||||
}
|
||||
if (signal == SIGCHLD) {
|
||||
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM) {
|
||||
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM && _pid == getpid()) {
|
||||
LOGe << "Caught SIGCHLD"
|
||||
<< "si_errno:" << si->si_errno
|
||||
<< "si_code:" << si->si_code
|
||||
|
|
|
@ -254,6 +254,45 @@ for d in dataset:
|
|||
assert "quick exit" in s
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
def test_dataset_shuffle_mpi(self):
|
||||
src = """
|
||||
import jittor as jt
|
||||
from jittor.dataset import Dataset
|
||||
import numpy as np
|
||||
|
||||
class YourDataset(Dataset):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.set_attrs(total_len=160, shuffle=True)
|
||||
|
||||
def __getitem__(self, k):
|
||||
return k
|
||||
|
||||
dataset = YourDataset()
|
||||
dataset.set_attrs(num_workers=2)
|
||||
|
||||
for d in dataset:
|
||||
for a in d:
|
||||
print("CHECK: ", a.item())
|
||||
"""
|
||||
fname = os.path.join(jt.flags.cache_path, "test_dataset_shuffle_mpi.py")
|
||||
with open(fname, 'w') as f:
|
||||
f.write(src)
|
||||
import subprocess as sp
|
||||
import sys
|
||||
cmd = sys.executable + " " + fname
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
cmd = mpirun_path + " -np 2 " + cmd
|
||||
print(cmd)
|
||||
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
|
||||
s = r.stdout.decode()
|
||||
# print(s)
|
||||
st = set([ l for l in s.splitlines() if l.startswith("CHECK:") ])
|
||||
assert r.returncode == 0
|
||||
# print(st)
|
||||
assert len(st) == 160, len(st)
|
||||
|
||||
def test_children_died2(self):
|
||||
src = """
|
||||
import jittor as jt
|
||||
|
|
Loading…
Reference in New Issue