add reduce mean

Dun Liang 2020-04-21 23:37:43 +08:00
parent c94995e640
commit eb15529b4d
7 changed files with 81 additions and 34 deletions
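In short: mpi_all_reduce and mpi_reduce gain an op argument (default add), and op == mean is lowered to an add reduction followed by a division by mpi_world_size. A rough usage sketch based on the updated tests below; the mpi.world_size() call and the exact attribute paths are assumptions, and the script is expected to run under mpirun with the MPI extern enabled:

import numpy as np
import jittor as jt
mpi = jt.compile_extern.mpi          # as used in the tests
mpi_ops = jt.compile_extern.mpi_ops

x = jt.random([5, 5])
y_sum = mpi_ops.mpi_all_reduce(x)            # default op="add": element-wise sum over ranks
y_mean = mpi_ops.mpi_all_reduce(x, "mean")   # sum over ranks divided by world size
z = mpi_ops.mpi_reduce(x, root=0)            # add-reduce; result meaningful on rank 0 only
if mpi.world_rank() == 0:
    print(np.allclose(y_mean.data, y_sum.data / mpi.world_size()))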

View File

@@ -16,16 +16,28 @@
namespace jittor {
#ifndef JIT
MpiAllReduceOp::MpiAllReduceOp(Var* x) : x(x) {
static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_mpi_all_reduce = get_op_info("mpi_all_reduce")
.get_constructor<VarPtr, Var*, NanoString>();
MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
if (op == ns_mean) {
auto var = make_mpi_all_reduce(x, ns_add);
var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);
forward(var);
return;
}
ASSERT(op == ns_add) << "Not supported MPI op" << op;
#ifdef HAS_CUDA
if (use_cuda) {
static VarPtr(*nccl_all_reduce)(Var*) = nullptr;
if (!nccl_all_reduce && has_op("nccl_all_reduce")) {
nccl_all_reduce = get_op_info("nccl_all_reduce")
.get_constructor<VarPtr, Var*>();
}
static VarPtr(*nccl_all_reduce)(Var*) = has_op("nccl_all_reduce")
? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
: nullptr;
if (nccl_all_reduce) {
LOGr << "nccl";
auto var = nccl_all_reduce(x);
forward(var);
return;
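The mean branch above adds no new kernel: it builds an add all-reduce and divides by mpi_world_size wrapped in a one-element int32 array. A minimal numpy sketch of the same arithmetic, purely illustrative and independent of Jittor:

import numpy as np

world_size = 3                                              # stand-in for mpi_world_size
ranks = [np.random.rand(5, 5) for _ in range(world_size)]   # one tensor per rank
summed = sum(ranks)                          # what all_reduce(x, add) yields on every rank
mean = summed / np.int32(world_size)         # the extra binary divide used for op == mean
assert np.allclose(mean, np.mean(ranks, axis=0))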
@@ -41,24 +53,32 @@ void MpiAllReduceOp::infer_shape() {
}
VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
static VarPtr(*mpi_all_reduce)(Var*) =
get_op_info("mpi_all_reduce").get_constructor<VarPtr, Var*>();
return mpi_all_reduce(dout);
static auto mpi_all_reduce =
get_op_info("mpi_all_reduce").get_constructor<VarPtr, Var*,NanoString>();
return mpi_all_reduce(dout, ns_add);
}
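The backward rule only has to cover add: each rank's output is the element-wise sum over all ranks, so the gradient reaching x is the sum of every rank's dout, i.e. another add all-reduce; the only change here is that ns_add is now passed explicitly because the constructor takes an op. The mean case never reaches this grad, since its forward was rewritten above into add plus divide and those ops supply their own gradients.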
void MpiAllReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
add_jit_define("OP", op.to_cstring());
}
#else // JIT
#ifdef JIT_cpu
void MpiAllReduceOp::jit_run() {
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
@define(T_MPI,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
@if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
@if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
)
@define(OP_MPI,
@if(@strcmp(@OP,add)==0, MPI_SUM)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
MPI_Allreduce(xp, yp, size, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
index_t num = y->num;
MPI_Allreduce(xp, yp, num, T_MPI, OP_MPI, MPI_COMM_WORLD);
}
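In jit_run the previously hard-coded MPI_Allreduce(xp, yp, size, MPI_FLOAT, MPI_SUM, ...) is now parameterized on the JIT defines: with Tx=float32 and OP=add, T_MPI and OP_MPI expand to MPI_FLOAT and MPI_SUM and the old call is reproduced, while int32 and float64 inputs now select MPI_INT and MPI_DOUBLE, and the element count is taken from y->num.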
#else
void MpiAllReduceOp::jit_run() {

View File

@@ -13,8 +13,9 @@ namespace jittor {
struct MpiAllReduceOp : Op {
Var* x, * y;
NanoString op;
MpiAllReduceOp(Var* x);
MpiAllReduceOp(Var* x, NanoString op=ns_add);
void infer_shape() override;
const char* name() const override { return "mpi_all_reduce"; }

View File

@@ -41,9 +41,9 @@ void MpiBroadcastOp::infer_shape() {
}
VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
static VarPtr(*mpi_reduce)(Var*, int) =
get_op_info("mpi_reduce").get_constructor<VarPtr, Var*, int>();
return mpi_reduce(dout,root);
static auto mpi_reduce =
get_op_info("mpi_reduce").get_constructor<VarPtr, Var*, NanoString, int>();
return mpi_reduce(dout, ns_add, root);
}
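The broadcast backward keeps its meaning: the gradient of a broadcast is still a reduce back to root; the call site only spells out the new explicit ns_add argument.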
void MpiBroadcastOp::jit_prepare() {

View File

@@ -16,16 +16,28 @@
namespace jittor {
#ifndef JIT
MpiReduceOp::MpiReduceOp(Var* x, int root) : x(x), root(root) {
static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_mpi_reduce = get_op_info("mpi_reduce")
.get_constructor<VarPtr, Var*, NanoString, int>();
MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) {
if (op == ns_mean) {
auto var = make_mpi_reduce(x, ns_add, root);
var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);
forward(var);
return;
}
ASSERT(op == ns_add) << "Not supported MPI op" << op;
#ifdef HAS_CUDA
if (use_cuda) {
static VarPtr(*nccl_reduce)(Var*, int) = nullptr;
if (!nccl_reduce && has_op("nccl_reduce")) {
nccl_reduce = get_op_info("nccl_reduce")
.get_constructor<VarPtr, Var*, int>();
}
static VarPtr(*nccl_reduce)(Var*, int) = has_op("nccl_reduce")
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
: nullptr;
if (nccl_reduce) {
LOGr << "nccl";
auto var = nccl_reduce(x, root);
forward(var);
return;
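mpi_reduce receives the same treatment as mpi_all_reduce: op == mean is lowered to an add reduce to root followed by a division by mpi_world_size, so only the add path needs an actual kernel.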
@@ -48,20 +60,26 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
void MpiReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
add_jit_define("OP", op.to_cstring());
}
#else // JIT
#ifdef JIT_cpu
void MpiReduceOp::jit_run() {
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
@define(T_MPI,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
@if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
@if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
)
@define(OP_MPI,
@if(@strcmp(@OP,add)==0, MPI_SUM)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
for (index_t i=0; i<num; i++)
yp[i] = 0;
MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
for (index_t i=0; i<num; i++) yp[i] = 0;
MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD));
}
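As before, yp is zero-filled ahead of the call, presumably so that non-root ranks return a defined all-zero result (MPI_Reduce only writes the receive buffer on root); what is new is the MPI_CHECK wrapper and that count, datatype and op now come from y->num, T_MPI and OP_MPI instead of the hard-coded MPI_FLOAT and MPI_SUM.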
#else
void MpiReduceOp::jit_run() {

View File

@@ -13,9 +13,10 @@ namespace jittor {
struct MpiReduceOp : Op {
Var* x, * y;
NanoString op;
int root;
MpiReduceOp(Var* x, int root);
MpiReduceOp(Var* x, NanoString op=ns_add, int root=0);
void infer_shape() override;
const char* name() const override { return "mpi_reduce"; }
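Note that the new op parameter sits before root, so an old positional call such as mpi_reduce(x, 0) would now bind 0 to op rather than root; this is why the tests below switch to the keyword form mpi_reduce(x, root=0).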

View File

@@ -28,6 +28,13 @@ class TestMpiOps(unittest.TestCase):
g = jt.grad(y,x)
assert np.allclose(g.data, np.ones([5,5])*3)
def test_all_reduce_mean(self):
x = jt.random([5, 5])
y = jt.compile_extern.mpi_ops.mpi_all_reduce(x, "mean")
assert np.allclose(y.data, x.data)
g = jt.grad(y,x)
assert np.allclose(g.data, np.ones([5,5]))
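The expected values follow from the add test just above: g == 3*ones there means the suite apparently runs on three ranks that all draw the same random x (compare the (x*3) checks further down), so a mean all-reduce returns (3x)/3 == x and its gradient, all_reduce(ones, add)/3, is again all ones.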
def test_broadcast(self):
data = jt.random([5, 5])
if mpi.world_rank() == 0:
@@ -44,7 +51,7 @@ class TestMpiOps(unittest.TestCase):
def test_reduce(self):
x = jt.random([5, 5])
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
y.sync()
if mpi.world_rank() == 0:
assert np.allclose(y.data, (x*3).data)

View File

@@ -65,7 +65,7 @@ class TestNcclOps(unittest.TestCase):
log_v=1, log_vprefix="op.cc=100,exe=1000"
) as raw_log:
x = jt.random([5, 5])
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
y_ = y.data
x_ = (x*n).data
if mpi.world_rank() == 0: