mpi broadcast fix

This commit is contained in:
Dun Liang 2020-04-10 16:37:09 +08:00
parent a3eb40b211
commit 0e506965e2
2 changed files with 8 additions and 7 deletions

View File

@ -47,9 +47,9 @@ int _mpi_local_rank() {
void _mpi_broadcast(ArrayArgs&& args, int i) {
int64 size = args.dtype.dsize();
for (auto i : args.shape)
size *= i;
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, 0, MPI_COMM_WORLD));
for (auto j : args.shape)
size *= j;
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
}
static uint64_t getHostHash(const char* string) {

View File

@ -23,9 +23,10 @@ class TestMpi(unittest.TestCase):
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
def test_mpi_broadcast(self):
a = np.zeros(100) + mpi.world_rank()
mpi.broadcast(a, 0)
assert (a == 0).all()
for i in range(mpi.world_size()):
a = np.zeros(100) + mpi.world_rank()
mpi.broadcast(a, i)
assert (a == i).all()
def test_mpi_dataset(self):
from jittor.dataset.dataset import Dataset
@ -51,7 +52,7 @@ class TestMpiEntry(unittest.TestCase):
def test_entry(self):
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi -v"
print("run cmd:", cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd