fix dataset

This commit is contained in:
Dun Liang 2020-04-10 13:21:34 +08:00
parent 69e7d367cd
commit a3eb40b211
2 changed files with 26 additions and 15 deletions

View File

@ -70,9 +70,11 @@ class Dataset(object):
# Number of batches this process yields per pass over the dataset.
def __len__(self):
# Guard: total_len must be non-negative and batch_size positive.
assert self.total_len >= 0
assert self.batch_size > 0
# With MPI active, each process gets ceil(total_len / world_size) samples;
# otherwise this process handles all total_len samples.
real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \
else self.total_len
if self.drop_last:
# NOTE(review): the next two returns are based on total_len and appear to be
# the pre-commit lines of this hunk (the diff +/- markers are not visible
# here) -- confirm against the repository before relying on them.
return self.total_len // self.batch_size
return (self.total_len-1) // self.batch_size + 1
# Returns based on the per-process length: floor division under drop_last
# (a trailing partial batch is not counted), ceiling division otherwise.
return real_len // self.batch_size
return (real_len-1) // self.batch_size + 1
def set_attrs(self, **kw):
'''set attributes of dataset, equivalent to setattr
@ -131,8 +133,8 @@ class Dataset(object):
self.gidc.notify()
batch = []
if mp_log_v:
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size))
for i in range(cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size)):
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size))
for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)):
batch.append(self[self.index_list[i]])
batch = self.collate_batch(batch)
if mp_log_v:
@ -158,7 +160,7 @@ class Dataset(object):
w.buffer.clear()
def _init_workers(self):
self.index_list = mp.Array('i', self.total_len, lock=False)
self.index_list = mp.Array('i', self.real_len, lock=False)
workers = []
# batch id to worker id
self.idmap = mp.Array('i', self.batch_len, lock=False)
@ -175,7 +177,7 @@ class Dataset(object):
buffer_size=self.buffer_size)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.total_len, buffer=self.index_list)
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
def __del__(self):
if mp_log_v:
@ -190,14 +192,22 @@ class Dataset(object):
# scatter index_list across all MPI processes
# scatter rule:
# [000000000111111111222222222]
# [........]
# 000111
# 222
# make sure each process has the same len
if mpi:
index_list = np.int32(index_list)
mpi.broadcast(index_list, 0)
new_len = (self.total_len - 1) // mpi.world_size() + 1
offset = mpi.world_rank() * new_len
index_list = index_list[offset:offset+new_len]
self.total_len = len(index_list)
real_len = (self.total_len - 1) // mpi.world_size() + 1
offset = mpi.world_rank() * real_len
if offset + real_len > self.total_len:
offset -= offset + real_len - self.total_len
index_list = index_list[offset:offset+real_len]
self.real_len = real_len
assert real_len == len(index_list)
else:
self.real_len = self.total_len
self.batch_len = len(self)
if "batch_len" in os.environ:

View File

@ -40,10 +40,11 @@ class TestMpi(unittest.TestCase):
toy = ToyDataset()
offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
for i,(a,b) in enumerate(toy):
assert (a.data*a.data == b.data).all()
c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
assert (c==a.data).all()
for _ in range(2):
for i,(a,b) in enumerate(toy):
assert (a.data*a.data == b.data).all()
c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
assert (c==a.data).all()
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiEntry(unittest.TestCase):