mirror of https://github.com/Jittor/Jittor
mpi op backward
This commit is contained in:
parent
f4130cab28
commit
4967e6ef8b
|
@ -40,6 +40,12 @@ void MpiAllReduceOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*mpi_all_reduce)(Var*) =
|
||||
get_op_info("mpi_all_reduce").get_constructor<VarPtr, Var*>();
|
||||
return mpi_all_reduce(dout);
|
||||
}
|
||||
|
||||
void MpiAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
|
|
|
@ -18,6 +18,7 @@ struct MpiAllReduceOp : Op {
|
|||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_all_reduce"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -40,6 +40,12 @@ void MpiBroadcastOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*mpi_reduce)(Var*, int) =
|
||||
get_op_info("mpi_reduce").get_constructor<VarPtr, Var*, int>();
|
||||
return mpi_reduce(dout,root);
|
||||
}
|
||||
|
||||
void MpiBroadcastOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
|
|
|
@ -19,6 +19,7 @@ struct MpiBroadcastOp : Op {
|
|||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_broadcast"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -40,6 +40,12 @@ void MpiReduceOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*mpi_broadcast)(Var*, int) =
|
||||
get_op_info("mpi_broadcast").get_constructor<VarPtr, Var*, int>();
|
||||
return mpi_broadcast(dout,root);
|
||||
}
|
||||
|
||||
void MpiReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
|
|
|
@ -19,6 +19,7 @@ struct MpiReduceOp : Op {
|
|||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_reduce"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -13,9 +13,12 @@ import numpy as np
|
|||
|
||||
def test_all_reduce():
|
||||
print("test all_reduce")
|
||||
mpi = jt.compile_extern.mpi
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5])*3)
|
||||
|
||||
def test_broadcast():
|
||||
print("test broadcast")
|
||||
|
@ -27,6 +30,9 @@ def test_broadcast():
|
|||
x = jt.zeros([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
|
||||
assert np.allclose(y.data, data.data)
|
||||
g = jt.grad(y,x)
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(g.data, np.ones([5,5])*3)
|
||||
|
||||
def test_reduce():
|
||||
print("test reduce")
|
||||
|
@ -36,6 +42,8 @@ def test_reduce():
|
|||
y.sync()
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5]))
|
||||
|
||||
def main():
|
||||
np.random.seed(0)
|
||||
|
@ -45,11 +53,6 @@ def main():
|
|||
test_all_reduce()
|
||||
test_broadcast()
|
||||
test_reduce()
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
if jt.compile_extern.mpi_ops:
|
||||
test_all_reduce()
|
||||
test_broadcast()
|
||||
test_reduce()
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestMpiOps(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue