mirror of https://github.com/Jittor/Jittor
mpi_reduce init
This commit is contained in:
parent
4967e6ef8b
commit
6c9194bb89
|
@ -58,6 +58,9 @@ void MpiReduceOp::jit_run() {
|
|||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
index_t num = y->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
yp[i] = 0;
|
||||
MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
|
||||
}
|
||||
#else
|
||||
|
|
|
@ -33,6 +33,8 @@ def test_broadcast():
|
|||
g = jt.grad(y,x)
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(g.data, np.ones([5,5])*3)
|
||||
else:
|
||||
assert np.allclose(g.data, np.zeros([5,5]))
|
||||
|
||||
def test_reduce():
|
||||
print("test reduce")
|
||||
|
@ -42,6 +44,8 @@ def test_reduce():
|
|||
y.sync()
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
else:
|
||||
assert np.allclose(y.data, np.zeros([5,5]))
|
||||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5]))
|
||||
|
||||
|
|
Loading…
Reference in New Issue