mirror of https://github.com/Jittor/Jittor
fix dataset
This commit is contained in:
parent
69e7d367cd
commit
a3eb40b211
|
@ -70,9 +70,11 @@ class Dataset(object):
|
|||
def __len__(self):
|
||||
assert self.total_len >= 0
|
||||
assert self.batch_size > 0
|
||||
real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \
|
||||
else self.total_len
|
||||
if self.drop_last:
|
||||
return self.total_len // self.batch_size
|
||||
return (self.total_len-1) // self.batch_size + 1
|
||||
return real_len // self.batch_size
|
||||
return (real_len-1) // self.batch_size + 1
|
||||
|
||||
def set_attrs(self, **kw):
|
||||
'''set attributes of dataset, equivalent to setattr
|
||||
|
@ -131,8 +133,8 @@ class Dataset(object):
|
|||
self.gidc.notify()
|
||||
batch = []
|
||||
if mp_log_v:
|
||||
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size))
|
||||
for i in range(cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size)):
|
||||
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size))
|
||||
for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)):
|
||||
batch.append(self[self.index_list[i]])
|
||||
batch = self.collate_batch(batch)
|
||||
if mp_log_v:
|
||||
|
@ -158,7 +160,7 @@ class Dataset(object):
|
|||
w.buffer.clear()
|
||||
|
||||
def _init_workers(self):
|
||||
self.index_list = mp.Array('i', self.total_len, lock=False)
|
||||
self.index_list = mp.Array('i', self.real_len, lock=False)
|
||||
workers = []
|
||||
# batch id to worker id
|
||||
self.idmap = mp.Array('i', self.batch_len, lock=False)
|
||||
|
@ -175,7 +177,7 @@ class Dataset(object):
|
|||
buffer_size=self.buffer_size)
|
||||
workers.append(w)
|
||||
self.workers = workers
|
||||
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.total_len, buffer=self.index_list)
|
||||
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
|
||||
|
||||
def __del__(self):
|
||||
if mp_log_v:
|
||||
|
@ -190,14 +192,22 @@ class Dataset(object):
|
|||
|
||||
# scatter index_list for all mpi process
|
||||
# scatter rule:
|
||||
# [000000000111111111222222222]
|
||||
# [........]
|
||||
# 000111
|
||||
# 222
|
||||
# make sure each process has the same len
|
||||
if mpi:
|
||||
index_list = np.int32(index_list)
|
||||
mpi.broadcast(index_list, 0)
|
||||
new_len = (self.total_len - 1) // mpi.world_size() + 1
|
||||
offset = mpi.world_rank() * new_len
|
||||
index_list = index_list[offset:offset+new_len]
|
||||
self.total_len = len(index_list)
|
||||
real_len = (self.total_len - 1) // mpi.world_size() + 1
|
||||
offset = mpi.world_rank() * real_len
|
||||
if offset + real_len > self.total_len:
|
||||
offset -= offset + real_len - self.total_len
|
||||
index_list = index_list[offset:offset+real_len]
|
||||
self.real_len = real_len
|
||||
assert real_len == len(index_list)
|
||||
else:
|
||||
self.real_len = self.total_len
|
||||
|
||||
self.batch_len = len(self)
|
||||
if "batch_len" in os.environ:
|
||||
|
|
|
@ -40,10 +40,11 @@ class TestMpi(unittest.TestCase):
|
|||
toy = ToyDataset()
|
||||
offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
|
||||
|
||||
for i,(a,b) in enumerate(toy):
|
||||
assert (a.data*a.data == b.data).all()
|
||||
c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
|
||||
assert (c==a.data).all()
|
||||
for _ in range(2):
|
||||
for i,(a,b) in enumerate(toy):
|
||||
assert (a.data*a.data == b.data).all()
|
||||
c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
|
||||
assert (c==a.data).all()
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestMpiEntry(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue