diff --git a/extern/mpi/inc/mpi_warper.h b/extern/mpi/inc/mpi_warper.h index 6b41e5fa..b4b1ddd1 100644 --- a/extern/mpi/inc/mpi_warper.h +++ b/extern/mpi/inc/mpi_warper.h @@ -37,4 +37,9 @@ int _mpi_world_rank(); // @pyjt(local_rank) int _mpi_local_rank(); +struct ArrayArgs; + +// @pyjt(broadcast) +void _mpi_broadcast(ArrayArgs&& args, int i); + } // jittor diff --git a/extern/mpi/src/mpi_warper.cc b/extern/mpi/src/mpi_warper.cc index 9bde80cf..5a66ccd6 100644 --- a/extern/mpi/src/mpi_warper.cc +++ b/extern/mpi/src/mpi_warper.cc @@ -12,6 +12,7 @@ #include "mpi_warper.h" #include "common.h" +#include "ops/array_op.h" char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING]; @@ -44,7 +45,12 @@ int _mpi_local_rank() { return mpi_local_rank; } - +void _mpi_broadcast(ArrayArgs&& args, int i) { + int64 size = args.dtype.dsize(); + for (auto i : args.shape) + size *= i; + MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, 0, MPI_COMM_WORLD)); +} static uint64_t getHostHash(const char* string) { // Based on DJB2, result = result * 33 + char diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 5b168cec..6ce238db 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -23,6 +23,7 @@ import jittor as jt dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset") mp_log_v = os.environ.get("mp_log_v", 0) +mpi = jt.compile_extern.mpi class Worker: def __init__(self, target, args, buffer_size): @@ -186,6 +187,17 @@ class Dataset(object): index_list = get_order_list(self.total_len) else: index_list = get_random_list(self.total_len) + + # scatter index_list for all mpi process + # scatter rule: + # [000000000111111111222222222] + if mpi: + index_list = np.int32(index_list) + mpi.broadcast(index_list, 0) + new_len = (self.total_len - 1) // mpi.world_size() + 1 + offset = mpi.world_rank() * new_len + index_list = index_list[offset:offset+new_len] + self.total_len = len(index_list) self.batch_len = len(self) if "batch_len" in os.environ: diff --git a/python/jittor/test/test_mpi.py b/python/jittor/test/test_mpi.py index d4c3ad41..2f496639 100644 --- a/python/jittor/test/test_mpi.py +++ b/python/jittor/test/test_mpi.py @@ -10,24 +10,49 @@ import unittest import os, sys import jittor as jt import numpy as np +mpi = jt.compile_extern.mpi -def main(): - print("test mpi_test") - assert jt.compile_extern.mpi_ops.mpi_test("").data == 123 - if jt.compile_extern.nccl_ops: - print("test test_with_mpi") - with jt.flag_scope(use_cuda=1): - assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123 - -@unittest.skipIf(jt.compile_extern.has_mpi is None, "no mpi found") +@unittest.skipIf(mpi is None, "no inside mpirun") class TestMpi(unittest.TestCase): - def test(self): + def test_mpi_test_op(self): + assert jt.compile_extern.mpi_ops.mpi_test("").data == 123 + + @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no inccl") + @jt.flag_scope(use_cuda=1) + def test_nccl_with_mpi(self): + assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123 + + def test_mpi_broadcast(self): + a = np.zeros(100) + mpi.world_rank() + mpi.broadcast(a, 0) + assert (a == 0).all() + + def test_mpi_dataset(self): + from jittor.dataset.dataset import Dataset + class ToyDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=1024) + + def __getitem__(self, index): + return index, index*index + + toy = ToyDataset() + offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank() + + for i,(a,b) in enumerate(toy): + assert (a.data*a.data == b.data).all() + c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size)) + assert (c==a.data).all() + +@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") +class TestMpiEntry(unittest.TestCase): + def test_entry(self): if not jt.compile_extern.inside_mpi(): mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun") cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi" + print("run cmd:", cmd) assert os.system(cmd)==0, "run cmd failed: "+cmd - else: - main() if __name__ == "__main__": unittest.main() \ No newline at end of file