mpi_reduce init

This commit is contained in:
guowei yang 2020-04-17 13:23:14 +08:00
parent 4967e6ef8b
commit 6c9194bb89
2 changed files with 7 additions and 0 deletions

View File

@ -58,6 +58,9 @@ void MpiReduceOp::jit_run() {
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
for (index_t i=0; i<num; i++)
yp[i] = 0;
MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
}
#else

View File

@ -33,6 +33,8 @@ def test_broadcast():
g = jt.grad(y,x)
if mpi.world_rank() == 0:
assert np.allclose(g.data, np.ones([5,5])*3)
else:
assert np.allclose(g.data, np.zeros([5,5]))
def test_reduce():
print("test reduce")
@ -42,6 +44,8 @@ def test_reduce():
y.sync()
if mpi.world_rank() == 0:
assert np.allclose(y.data, (x*3).data)
else:
assert np.allclose(y.data, np.zeros([5,5]))
g = jt.grad(y,x)
assert np.allclose(g.data, np.ones([5,5]))