mirror of https://github.com/Jittor/Jittor

polish mpi dataset, better batch size

commit c94995e640 (parent 2315928240)
@@ -70,11 +70,9 @@ class Dataset(object):
     def __len__(self):
         assert self.total_len >= 0
         assert self.batch_size > 0
-        real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \
-            else self.total_len
         if self.drop_last:
-            return real_len // self.batch_size
-        return (real_len-1) // self.batch_size + 1
+            return self.total_len // self.batch_size
+        return (self.total_len-1) // self.batch_size + 1

     def set_attrs(self, **kw):
         '''set attributes of dataset, equivalent to setattr
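With this change, __len__ no longer consults mpi: every process reports the same number of batches, and the per-rank split happens later in the index scatter. A minimal sketch of the new arithmetic (num_batches is an illustrative name, not the library API), using the total_len=211, batch_size=21 sizes from this commit's test:

    def num_batches(total_len, batch_size, drop_last):
        # drop_last: discard the trailing partial batch
        if drop_last:
            return total_len // batch_size
        # otherwise count the partial batch too (ceiling division)
        return (total_len - 1) // batch_size + 1

    assert num_batches(211, 21, drop_last=True) == 10
    assert num_batches(211, 21, drop_last=False) == 11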
@@ -133,8 +131,8 @@ class Dataset(object):
                 self.gidc.notify()
             batch = []
             if mp_log_v:
-                print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size))
-            for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)):
+                print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
+            for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)):
                 batch.append(self[self.index_list[i]])
             batch = self.collate_batch(batch)
             if mp_log_v:
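The worker loop now advances through index_list in strides of the per-rank real_batch_size rather than the global batch_size. A minimal sketch of the slices this produces, assuming rank 0 of a 2-process run with the test's batch_size=21 and total_len=211, so real_len=111 and real_batch_size=11 (these values come from the scatter logic in the next hunk):

    real_len, real_batch_size = 111, 11  # assumed rank-0 values, see scatter below
    for cid in range(11):                # 11 batches, matching len(dataset)
        lo = cid * real_batch_size
        hi = min(real_len, (cid + 1) * real_batch_size)
        print(cid, hi - lo)              # batches 0..9 load 11 samples, batch 10 loads 1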
@@ -192,26 +190,54 @@ class Dataset(object):

         # scatter index_list for all mpi process
         # scatter rule:
-        # [........]
-        # 000111
-        # 222
-        # make sure each process has the same len
+        #    batch 1      batch 2
+        # [........]   [........]   ...
+        #  00011122     00011122
+        # if last batch is smaller than world_size,
+        # pad to world_size:
+        #    last batch
+        #    [.] -> [012]
         if mpi:
+            world_size = mpi.world_size()
+            world_rank = mpi.world_rank()
             index_list = np.int32(index_list)
             mpi.broadcast(index_list, 0)
-            real_len = (self.total_len - 1) // mpi.world_size() + 1
-            offset = mpi.world_rank() * real_len
-            if offset + real_len > self.total_len:
-                offset -= offset + real_len - self.total_len
-            index_list = index_list[offset:offset+real_len]
-            self.real_len = real_len
-            assert real_len == len(index_list)
+            assert self.batch_size >= world_size, \
+                f"Batch size ({self.batch_size}) is smaller than MPI world_size ({world_size})"
+            real_batch_size = (self.batch_size-1) // world_size + 1
+            if real_batch_size * world_size != self.batch_size:
+                LOG.w("Batch size is not divisible by MPI world size, "
+                      "the distributed version may be different from "
+                      "the single-process version.")
+            fix_batch = self.total_len // self.batch_size
+            last_batch = self.total_len - fix_batch * self.batch_size
+            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
+                .reshape(-1, self.batch_size)
+            fix_batch_l = fix_batch_l[
+                :, real_batch_size*world_rank:real_batch_size*(world_rank+1)]
+            real_batch_size = fix_batch_l.shape[1]
+            fix_batch_l = fix_batch_l.flatten()
+            if not self.drop_last and last_batch > 0:
+                last_batch_l = index_list[-last_batch:]
+                real_last_batch = (last_batch-1)//world_size+1
+                l = real_last_batch * world_rank
+                r = l + real_last_batch
+                if r > last_batch: r = last_batch
+                if l >= r: l = r-1
+                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
+            else:
+                index_list = fix_batch_l
+
+            self.real_len = len(index_list)
+            self.real_batch_size = real_batch_size
+            assert self.total_len // self.batch_size == \
+                self.real_len // self.real_batch_size
         else:
             self.real_len = self.total_len
+            self.real_batch_size = self.batch_size

         self.batch_len = len(self)
-        if "batch_len" in os.environ:
-            self.batch_len = int(os.environ["batch_len"])

         if not hasattr(self, "workers") and self.num_workers:
             self._init_workers()
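The scatter rule above can be checked in isolation: every full batch is split column-wise across ranks, and the trailing partial batch is divided (and padded by reusing a sample) so that no rank comes up empty. A standalone NumPy sketch of the same arithmetic, not the library API, using this commit's test sizes:

    import numpy as np

    def scatter(index_list, batch_size, world_size, world_rank, drop_last=False):
        # per-rank batch size, rounded up
        real_batch_size = (batch_size - 1) // world_size + 1
        total_len = len(index_list)
        fix_batch = total_len // batch_size          # number of full batches
        last_batch = total_len - fix_batch * batch_size
        # each rank takes its own column block out of every full batch
        fix = index_list[:fix_batch * batch_size].reshape(-1, batch_size)
        fix = fix[:, real_batch_size * world_rank : real_batch_size * (world_rank + 1)]
        real_batch_size = fix.shape[1]               # the last rank may get fewer columns
        out = fix.flatten()
        if not drop_last and last_batch > 0:
            tail = index_list[-last_batch:]
            step = (last_batch - 1) // world_size + 1
            l, r = step * world_rank, step * world_rank + step
            if r > last_batch: r = last_batch
            if l >= r: l = r - 1                     # pad: reuse a sample so no rank is empty
            out = np.concatenate([out, tail[l:r]])
        return out, real_batch_size

    idx = np.arange(211)
    for rank in range(2):
        out, rbs = scatter(idx, 21, world_size=2, world_rank=rank)
        print(rank, len(out), rbs)                   # rank 0: 111, 11; rank 1: 101, 10

Note the invariant the diff asserts: both ranks see the same number of full batches (111 // 11 == 101 // 10 == 211 // 21 == 10), so iteration stays in lockstep across processes.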
@@ -245,7 +271,7 @@ class Dataset(object):
             batch_data = []
             for idx in index_list:
                 batch_data.append(self[int(idx)])
-                if len(batch_data) == self.batch_size:
+                if len(batch_data) == self.real_batch_size:
                     batch_data = self.collate_batch(batch_data)
                     batch_data = self.to_jittor(batch_data)
                     yield batch_data
@@ -33,19 +33,33 @@ class TestMpi(unittest.TestCase):
         class ToyDataset(Dataset):
             def __init__(self):
                 super().__init__()
-                self.set_attrs(total_len=1024)
+                self.set_attrs(batch_size=21, total_len=211)

             def __getitem__(self, index):
                 return index, index*index

         toy = ToyDataset()
-        offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
+        offset = ((toy.batch_size-1) // mpi.world_size() + 1) * mpi.world_rank()

         for _ in range(2):
             for i,(a,b) in enumerate(toy):
                 assert (a.data*a.data == b.data).all()
-                c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
-                assert (c==a.data).all()
+                if mpi.world_rank() == 0:
+                    if i == len(toy)-1:
+                        assert a.shape[0] == 1
+                        c = np.array([210])
+                    else:
+                        assert toy.real_batch_size == 11
+                        c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size))
+                else:
+                    if i == len(toy)-1:
+                        assert a.shape[0] == 1
+                        c = np.array([210])
+                    else:
+                        assert toy.real_batch_size == 10
+                        c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size))
+
+                assert (c==a.data).all(), (c, a.data)

 def run_mpi_test(num_procs, name):
     if not jt.compile_extern.inside_mpi():
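The body of run_mpi_test is cut off here; judging from the inside_mpi() guard, the helper presumably relaunches the named test under mpirun (or an equivalent launcher) with num_procs processes when invoked outside MPI, which is how the rank-0 and rank-1 assertions above (real_batch_size 11 vs 10, single-sample last batch) get exercised with a world size of 2.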