diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index f4ef8387..9569fc8c 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -70,11 +70,9 @@ class Dataset(object): def __len__(self): assert self.total_len >= 0 assert self.batch_size > 0 - real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \ - else self.total_len if self.drop_last: - return real_len // self.batch_size - return (real_len-1) // self.batch_size + 1 + return self.total_len // self.batch_size + return (self.total_len-1) // self.batch_size + 1 def set_attrs(self, **kw): '''set attributes of dataset, equivalent to setattr @@ -133,8 +131,8 @@ class Dataset(object): self.gidc.notify() batch = [] if mp_log_v: - print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)) - for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)): + print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)) + for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)): batch.append(self[self.index_list[i]]) batch = self.collate_batch(batch) if mp_log_v: @@ -192,26 +190,54 @@ class Dataset(object): # scatter index_list for all mpi process # scatter rule: - # [........] - # 000111 - # 222 - # make sure each process has the same len + # batch 1 batch 2 + # [........] [........] ... + # 00011122 00011122 + # if last batch is smaller than world_size + # pad to world_size + # last batch + # [.] -> [012] if mpi: + world_size = mpi.world_size() + world_rank = mpi.world_rank() index_list = np.int32(index_list) mpi.broadcast(index_list, 0) - real_len = (self.total_len - 1) // mpi.world_size() + 1 - offset = mpi.world_rank() * real_len - if offset + real_len > self.total_len: - offset -= offset + real_len - self.total_len - index_list = index_list[offset:offset+real_len] - self.real_len = real_len - assert real_len == len(index_list) + + assert self.batch_size >= world_size, \ + f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" + real_batch_size = (self.batch_size-1) // world_size + 1 + if real_batch_size * world_size != self.batch_size: + LOG.w("Batch size is not divisible by MPI world size, " + "The distributed version may be different from " + "the single-process version.") + fix_batch = self.total_len // self.batch_size + last_batch = self.total_len - fix_batch * self.batch_size + fix_batch_l = index_list[0:fix_batch*self.batch_size] \ + .reshape(-1,self.batch_size) + fix_batch_l = fix_batch_l[ + :,real_batch_size*world_rank:real_batch_size*(world_rank+1)] + real_batch_size = fix_batch_l.shape[1] + fix_batch_l = fix_batch_l.flatten() + if not self.drop_last and last_batch > 0: + last_batch_l = index_list[-last_batch:] + real_last_batch = (last_batch-1)//world_size+1 + l = real_last_batch * world_rank + r = l + real_last_batch + if r > last_batch: r = last_batch + if l >= r: l = r-1 + index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) + else: + index_list = fix_batch_l + + self.real_len = len(index_list) + self.real_batch_size = real_batch_size + assert self.total_len // self.batch_size == \ + self.real_len // self.real_batch_size else: self.real_len = self.total_len + self.real_batch_size = self.batch_size self.batch_len = len(self) - if "batch_len" in os.environ: - self.batch_len = int(os.environ["batch_len"]) if not hasattr(self, "workers") and self.num_workers: self._init_workers() @@ -245,7 +271,7 @@ class Dataset(object): batch_data = [] for idx in index_list: batch_data.append(self[int(idx)]) - if len(batch_data) == self.batch_size: + if len(batch_data) == self.real_batch_size: batch_data = self.collate_batch(batch_data) batch_data = self.to_jittor(batch_data) yield batch_data diff --git a/python/jittor/test/test_mpi.py b/python/jittor/test/test_mpi.py index 10be0f56..de8cf5f4 100644 --- a/python/jittor/test/test_mpi.py +++ b/python/jittor/test/test_mpi.py @@ -33,19 +33,33 @@ class TestMpi(unittest.TestCase): class ToyDataset(Dataset): def __init__(self): super().__init__() - self.set_attrs(total_len=1024) + self.set_attrs(batch_size=21, total_len=211) def __getitem__(self, index): return index, index*index toy = ToyDataset() - offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank() + offset = ((toy.batch_size-1) // mpi.world_size() + 1) * mpi.world_rank() for _ in range(2): for i,(a,b) in enumerate(toy): assert (a.data*a.data == b.data).all() - c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size)) - assert (c==a.data).all() + if mpi.world_rank() == 0: + if i == len(toy)-1: + assert a.shape[0] == 1 + c = np.array([210]) + else: + assert toy.real_batch_size == 11 + c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size)) + else: + if i == len(toy)-1: + assert a.shape[0] == 1 + c = np.array([210]) + else: + assert toy.real_batch_size == 10 + c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size)) + + assert (c==a.data).all(), (c, a.data) def run_mpi_test(num_procs, name): if not jt.compile_extern.inside_mpi():