mpi op backward

guowei yang 2020-04-17 10:53:56 +08:00
parent f4130cab28
commit 4967e6ef8b
7 changed files with 29 additions and 5 deletions

View File

@@ -40,6 +40,12 @@ void MpiAllReduceOp::infer_shape() {
    y->set_shape(x->shape);
}

VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    static VarPtr(*mpi_all_reduce)(Var*) =
        get_op_info("mpi_all_reduce").get_constructor<VarPtr, Var*>();
    return mpi_all_reduce(dout);
}

void MpiAllReduceOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));

View File

@@ -18,6 +18,7 @@ struct MpiAllReduceOp : Op {
    void infer_shape() override;
    const char* name() const override { return "mpi_all_reduce"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

View File

@@ -40,6 +40,12 @@ void MpiBroadcastOp::infer_shape() {
    y->set_shape(x->shape);
}

VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    static VarPtr(*mpi_reduce)(Var*, int) =
        get_op_info("mpi_reduce").get_constructor<VarPtr, Var*, int>();
    return mpi_reduce(dout, root);
}

void MpiBroadcastOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));

View File

@@ -19,6 +19,7 @@ struct MpiBroadcastOp : Op {
    void infer_shape() override;
    const char* name() const override { return "mpi_broadcast"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

View File

@@ -40,6 +40,12 @@ void MpiReduceOp::infer_shape() {
    y->set_shape(x->shape);
}

VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    static VarPtr(*mpi_broadcast)(Var*, int) =
        get_op_info("mpi_broadcast").get_constructor<VarPtr, Var*, int>();
    return mpi_broadcast(dout, root);
}

void MpiReduceOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));

View File

@@ -19,6 +19,7 @@ struct MpiReduceOp : Op {
    void infer_shape() override;
    const char* name() const override { return "mpi_reduce"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

View File

@@ -13,9 +13,12 @@ import numpy as np

def test_all_reduce():
    print("test all_reduce")
    mpi = jt.compile_extern.mpi
    x = jt.random([5, 5])
    y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
    assert np.allclose(y.data, (x*3).data)
    g = jt.grad(y,x)
    assert np.allclose(g.data, np.ones([5,5])*3)

def test_broadcast():
    print("test broadcast")
@@ -27,6 +30,9 @@ def test_broadcast():
    x = jt.zeros([5, 5])
    y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
    assert np.allclose(y.data, data.data)
    g = jt.grad(y,x)
    if mpi.world_rank() == 0:
        assert np.allclose(g.data, np.ones([5,5])*3)

def test_reduce():
    print("test reduce")
@@ -36,6 +42,8 @@ def test_reduce():
    y.sync()
    if mpi.world_rank() == 0:
        assert np.allclose(y.data, (x*3).data)
    g = jt.grad(y,x)
    assert np.allclose(g.data, np.ones([5,5]))

def main():
    np.random.seed(0)
@@ -45,11 +53,6 @@
    test_all_reduce()
    test_broadcast()
    test_reduce()
    with jt.flag_scope(use_cuda=1):
        if jt.compile_extern.mpi_ops:
            test_all_reduce()
            test_broadcast()
            test_reduce()
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiOps(unittest.TestCase):
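Note that the asserts compare against (x*3).data and np.ones([5,5])*3, so these tests assume an MPI world size of exactly 3, i.e. a launch along the lines of mpirun -np 3 (the precise launch command is not part of this diff).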