mirror of https://github.com/Jittor/Jittor
mpi broadcast fix
This commit is contained in:
parent
a3eb40b211
commit
0e506965e2
|
@ -47,9 +47,9 @@ int _mpi_local_rank() {
|
|||
|
||||
void _mpi_broadcast(ArrayArgs&& args, int i) {
|
||||
int64 size = args.dtype.dsize();
|
||||
for (auto i : args.shape)
|
||||
size *= i;
|
||||
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, 0, MPI_COMM_WORLD));
|
||||
for (auto j : args.shape)
|
||||
size *= j;
|
||||
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
|
||||
}
|
||||
|
||||
static uint64_t getHostHash(const char* string) {
|
||||
|
|
|
@ -23,9 +23,10 @@ class TestMpi(unittest.TestCase):
|
|||
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
|
||||
|
||||
def test_mpi_broadcast(self):
|
||||
a = np.zeros(100) + mpi.world_rank()
|
||||
mpi.broadcast(a, 0)
|
||||
assert (a == 0).all()
|
||||
for i in range(mpi.world_size()):
|
||||
a = np.zeros(100) + mpi.world_rank()
|
||||
mpi.broadcast(a, i)
|
||||
assert (a == i).all()
|
||||
|
||||
def test_mpi_dataset(self):
|
||||
from jittor.dataset.dataset import Dataset
|
||||
|
@ -51,7 +52,7 @@ class TestMpiEntry(unittest.TestCase):
|
|||
def test_entry(self):
|
||||
if not jt.compile_extern.inside_mpi():
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
|
||||
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi -v"
|
||||
print("run cmd:", cmd)
|
||||
assert os.system(cmd)==0, "run cmd failed: "+cmd
|
||||
|
||||
|
|
Loading…
Reference in New Issue