polish mpi dataset, better batch size

Dun Liang 2020-04-21 15:48:49 +08:00
parent 2315928240
commit c94995e640
2 changed files with 64 additions and 24 deletions

View File

@@ -70,11 +70,9 @@ class Dataset(object):
     def __len__(self):
         assert self.total_len >= 0
         assert self.batch_size > 0
-        real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \
-            else self.total_len
         if self.drop_last:
-            return real_len // self.batch_size
-        return (real_len-1) // self.batch_size + 1
+            return self.total_len // self.batch_size
+        return (self.total_len-1) // self.batch_size + 1

     def set_attrs(self, **kw):
         '''set attributes of dataset, equivalent to setattr
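Note: with the per-batch scatter introduced below, every MPI process now yields the same number of batches, so __len__ can be computed from total_len alone. A minimal check of the ceil-division arithmetic, using the total_len=211, batch_size=21 values from the updated test in this commit:

total_len, batch_size = 211, 21
assert total_len // batch_size == 10            # drop_last=True: full batches only
assert (total_len - 1) // batch_size + 1 == 11  # drop_last=False: the partial batch counts too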
@@ -133,8 +131,8 @@ class Dataset(object):
                 self.gidc.notify()
             batch = []
             if mp_log_v:
-                print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size))
-            for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)):
+                print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
+            for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)):
                 batch.append(self[self.index_list[i]])
             batch = self.collate_batch(batch)
             if mp_log_v:
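Note: each worker now loads one batch of its rank-local index_list per call, sliced by real_batch_size instead of batch_size. A hedged sketch of the slice bounds, using the rank-0 values implied by the test below (real_len=111, real_batch_size=11):

real_len, real_batch_size = 111, 11
for cid in range(11):                                 # 11 batches, matching __len__ above
    lo = cid * real_batch_size
    hi = min(real_len, (cid + 1) * real_batch_size)
    assert hi - lo == (11 if cid < 10 else 1)         # last batch holds the single padded index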
@@ -192,26 +190,54 @@ class Dataset(object):
         # scatter index_list for all mpi process
         # scatter rule:
-        # [........]
-        # 000111
-        # 222
-        # make sure each process has the same len
+        #    batch 1      batch 2
+        # [........]   [........]   ...
+        #  00011122     00011122
+        # if last batch is smaller than world_size
+        # pad to world_size
+        #    last batch
+        #    [.] -> [012]
         if mpi:
+            world_size = mpi.world_size()
+            world_rank = mpi.world_rank()
             index_list = np.int32(index_list)
             mpi.broadcast(index_list, 0)
-            real_len = (self.total_len - 1) // mpi.world_size() + 1
-            offset = mpi.world_rank() * real_len
-            if offset + real_len > self.total_len:
-                offset -= offset + real_len - self.total_len
-            index_list = index_list[offset:offset+real_len]
-            self.real_len = real_len
-            assert real_len == len(index_list)
+            assert self.batch_size >= world_size, \
+                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
+            real_batch_size = (self.batch_size-1) // world_size + 1
+            if real_batch_size * world_size != self.batch_size:
+                LOG.w("Batch size is not divisible by MPI world size, "
+                      "The distributed version may be different from "
+                      "the single-process version.")
+            fix_batch = self.total_len // self.batch_size
+            last_batch = self.total_len - fix_batch * self.batch_size
+            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
+                .reshape(-1, self.batch_size)
+            fix_batch_l = fix_batch_l[
+                :, real_batch_size*world_rank:real_batch_size*(world_rank+1)]
+            real_batch_size = fix_batch_l.shape[1]
+            fix_batch_l = fix_batch_l.flatten()
+            if not self.drop_last and last_batch > 0:
+                last_batch_l = index_list[-last_batch:]
+                real_last_batch = (last_batch-1)//world_size+1
+                l = real_last_batch * world_rank
+                r = l + real_last_batch
+                if r > last_batch: r = last_batch
+                if l >= r: l = r-1
+                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
+            else:
+                index_list = fix_batch_l
+            self.real_len = len(index_list)
+            self.real_batch_size = real_batch_size
+            assert self.total_len // self.batch_size == \
+                self.real_len // self.real_batch_size
         else:
             self.real_len = self.total_len
+            self.real_batch_size = self.batch_size
         self.batch_len = len(self)
         if "batch_len" in os.environ:
             self.batch_len = int(os.environ["batch_len"])
         if not hasattr(self, "workers") and self.num_workers:
             self._init_workers()
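Note: the scatter rule above can be simulated without MPI. A minimal sketch, assuming world_size=2 and the test's total_len=211, batch_size=21 (both illustrative): each full batch is split column-wise into per-rank slices of up to real_batch_size, and the trailing partial batch is clamped so every rank still receives at least one index. Because 21 is not divisible by 2, rank 0 gets 11-sample slices while rank 1 gets 10, which is what the LOG.w warning refers to.

import numpy as np

total_len, batch_size, world_size = 211, 21, 2
index_list = np.arange(total_len, dtype=np.int32)
real_batch_size = (batch_size - 1) // world_size + 1   # 11
fix_batch = total_len // batch_size                    # 10 full batches
last_batch = total_len - fix_batch * batch_size        # 1 leftover index

for world_rank in range(world_size):
    full = index_list[:fix_batch * batch_size].reshape(-1, batch_size)
    mine = full[:, real_batch_size * world_rank : real_batch_size * (world_rank + 1)]
    rank_batch_size = mine.shape[1]                    # 11 for rank 0, 10 for rank 1
    mine = mine.flatten()
    if last_batch > 0:
        tail = index_list[-last_batch:]
        real_last = (last_batch - 1) // world_size + 1
        l = real_last * world_rank
        r = min(l + real_last, last_batch)
        if l >= r:
            l = r - 1                                  # pad: ranks past the end reuse an index
        mine = np.concatenate([mine, tail[l:r]])
    print(world_rank, rank_batch_size, len(mine))      # prints: 0 11 111 / 1 10 101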
@@ -245,7 +271,7 @@ class Dataset(object):
             batch_data = []
             for idx in index_list:
                 batch_data.append(self[int(idx)])
-                if len(batch_data) == self.batch_size:
+                if len(batch_data) == self.real_batch_size:
                     batch_data = self.collate_batch(batch_data)
                     batch_data = self.to_jittor(batch_data)
                     yield batch_data

View File

@@ -33,19 +33,33 @@ class TestMpi(unittest.TestCase):
         class ToyDataset(Dataset):
             def __init__(self):
                 super().__init__()
-                self.set_attrs(total_len=1024)
+                self.set_attrs(batch_size=21, total_len=211)

             def __getitem__(self, index):
                 return index, index*index

         toy = ToyDataset()
-        offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
+        offset = ((toy.batch_size-1) // mpi.world_size() + 1) * mpi.world_rank()
         for _ in range(2):
             for i,(a,b) in enumerate(toy):
                 assert (a.data*a.data == b.data).all()
-                c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
-                assert (c==a.data).all()
+                if mpi.world_rank() == 0:
+                    if i == len(toy)-1:
+                        assert a.shape[0] == 1
+                        c = np.array([210])
+                    else:
+                        assert toy.real_batch_size == 11
+                        c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size))
+                else:
+                    if i == len(toy)-1:
+                        assert a.shape[0] == 1
+                        c = np.array([210])
+                    else:
+                        assert toy.real_batch_size == 10
+                        c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size))
+                assert (c==a.data).all(), (c, a.data)

 def run_mpi_test(num_procs, name):
     if not jt.compile_extern.inside_mpi():
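Note on the asserts above: they encode the index layout below, assuming the test runs under two MPI processes (the world size consistent with the 11/10 split being checked). A hedged sketch:

# batches 0..9 are full; offset is 0 for rank 0 and 11 for rank 1
for i in range(10):
    rank0 = list(range(i*21, i*21 + 11))                 # rank 0: real_batch_size == 11
    rank1 = list(range(11 + i*21, 11 + i*21 + 10))       # rank 1: real_batch_size == 10
    assert rank0 + rank1 == list(range(i*21, (i+1)*21))  # together they cover batch i
# batch 10, the last one: both ranks receive the single padded index [210]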