mirror of https://github.com/Jittor/Jittor
add reduce mean
This commit is contained in:
parent
c94995e640
commit
eb15529b4d
|
@ -16,16 +16,28 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiAllReduceOp::MpiAllReduceOp(Var* x) : x(x) {
|
||||
|
||||
static auto make_array = get_op_info("array")
|
||||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
static auto make_binary = get_op_info("binary")
|
||||
.get_constructor<VarPtr, Var*, Var*, NanoString>();
|
||||
static auto make_mpi_all_reduce = get_op_info("mpi_all_reduce")
|
||||
.get_constructor<VarPtr, Var*, NanoString>();
|
||||
|
||||
MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
|
||||
if (op == ns_mean) {
|
||||
auto var = make_mpi_all_reduce(x, ns_add);
|
||||
var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
ASSERT(op == ns_add) << "Not supported MPI op" << op;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*nccl_all_reduce)(Var*) = nullptr;
|
||||
if (!nccl_all_reduce && has_op("nccl_all_reduce")) {
|
||||
nccl_all_reduce = get_op_info("nccl_all_reduce")
|
||||
.get_constructor<VarPtr, Var*>();
|
||||
}
|
||||
static VarPtr(*nccl_all_reduce)(Var*) = has_op("nccl_all_reduce")
|
||||
? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
|
||||
: nullptr;
|
||||
if (nccl_all_reduce) {
|
||||
LOGr << "nccl";
|
||||
auto var = nccl_all_reduce(x);
|
||||
forward(var);
|
||||
return;
|
||||
|
@ -41,24 +53,32 @@ void MpiAllReduceOp::infer_shape() {
|
|||
}
|
||||
|
||||
VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*mpi_all_reduce)(Var*) =
|
||||
get_op_info("mpi_all_reduce").get_constructor<VarPtr, Var*>();
|
||||
return mpi_all_reduce(dout);
|
||||
static auto mpi_all_reduce =
|
||||
get_op_info("mpi_all_reduce").get_constructor<VarPtr, Var*,NanoString>();
|
||||
return mpi_all_reduce(dout, ns_add);
|
||||
}
|
||||
|
||||
void MpiAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
add_jit_define("OP", op.to_cstring());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MpiAllReduceOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
@define(T_MPI,
|
||||
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT)
|
||||
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
|
||||
@if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
|
||||
@if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
|
||||
)
|
||||
@define(OP_MPI,
|
||||
@if(@strcmp(@OP,add)==0, MPI_SUM)
|
||||
)
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
MPI_Allreduce(xp, yp, size, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
|
||||
index_t num = y->num;
|
||||
MPI_Allreduce(xp, yp, num, T_MPI, OP_MPI, MPI_COMM_WORLD);
|
||||
}
|
||||
#else
|
||||
void MpiAllReduceOp::jit_run() {
|
||||
|
|
|
@ -13,8 +13,9 @@ namespace jittor {
|
|||
|
||||
struct MpiAllReduceOp : Op {
|
||||
Var* x, * y;
|
||||
NanoString op;
|
||||
|
||||
MpiAllReduceOp(Var* x);
|
||||
MpiAllReduceOp(Var* x, NanoString op=ns_add);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_all_reduce"; }
|
||||
|
|
|
@ -41,9 +41,9 @@ void MpiBroadcastOp::infer_shape() {
|
|||
}
|
||||
|
||||
VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*mpi_reduce)(Var*, int) =
|
||||
get_op_info("mpi_reduce").get_constructor<VarPtr, Var*, int>();
|
||||
return mpi_reduce(dout,root);
|
||||
static auto mpi_reduce =
|
||||
get_op_info("mpi_reduce").get_constructor<VarPtr, Var*, NanoString, int>();
|
||||
return mpi_reduce(dout, ns_add, root);
|
||||
}
|
||||
|
||||
void MpiBroadcastOp::jit_prepare() {
|
||||
|
|
|
@ -16,16 +16,28 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiReduceOp::MpiReduceOp(Var* x, int root) : x(x), root(root) {
|
||||
|
||||
static auto make_array = get_op_info("array")
|
||||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
static auto make_binary = get_op_info("binary")
|
||||
.get_constructor<VarPtr, Var*, Var*, NanoString>();
|
||||
static auto make_mpi_reduce = get_op_info("mpi_reduce")
|
||||
.get_constructor<VarPtr, Var*, NanoString, int>();
|
||||
|
||||
MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) {
|
||||
if (op == ns_mean) {
|
||||
auto var = make_mpi_reduce(x, ns_add, root);
|
||||
var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
ASSERT(op == ns_add) << "Not supported MPI op" << op;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*nccl_reduce)(Var*, int) = nullptr;
|
||||
if (!nccl_reduce && has_op("nccl_reduce")) {
|
||||
nccl_reduce = get_op_info("nccl_reduce")
|
||||
.get_constructor<VarPtr, Var*, int>();
|
||||
}
|
||||
static VarPtr(*nccl_reduce)(Var*, int) = has_op("nccl_reduce")
|
||||
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
|
||||
: nullptr;
|
||||
if (nccl_reduce) {
|
||||
LOGr << "nccl";
|
||||
auto var = nccl_reduce(x, root);
|
||||
forward(var);
|
||||
return;
|
||||
|
@ -48,20 +60,26 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
|
||||
void MpiReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
add_jit_define("OP", op.to_cstring());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MpiReduceOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
@define(T_MPI,
|
||||
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT)
|
||||
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
|
||||
@if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
|
||||
@if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
|
||||
)
|
||||
@define(OP_MPI,
|
||||
@if(@strcmp(@OP,add)==0, MPI_SUM)
|
||||
)
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
index_t num = y->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
yp[i] = 0;
|
||||
MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
|
||||
for (index_t i=0; i<num; i++) yp[i] = 0;
|
||||
MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD));
|
||||
}
|
||||
#else
|
||||
void MpiReduceOp::jit_run() {
|
||||
|
|
|
@ -13,9 +13,10 @@ namespace jittor {
|
|||
|
||||
struct MpiReduceOp : Op {
|
||||
Var* x, * y;
|
||||
NanoString op;
|
||||
int root;
|
||||
|
||||
MpiReduceOp(Var* x, int root);
|
||||
MpiReduceOp(Var* x, NanoString op=ns_add, int root=0);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_reduce"; }
|
||||
|
|
|
@ -28,6 +28,13 @@ class TestMpiOps(unittest.TestCase):
|
|||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5])*3)
|
||||
|
||||
def test_all_reduce_mean(self):
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_all_reduce(x, "mean")
|
||||
assert np.allclose(y.data, x.data)
|
||||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5]))
|
||||
|
||||
def test_broadcast(self):
|
||||
data = jt.random([5, 5])
|
||||
if mpi.world_rank() == 0:
|
||||
|
@ -44,7 +51,7 @@ class TestMpiOps(unittest.TestCase):
|
|||
|
||||
def test_reduce(self):
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
|
||||
y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
|
||||
y.sync()
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
|
|
|
@ -65,7 +65,7 @@ class TestNcclOps(unittest.TestCase):
|
|||
log_v=1, log_vprefix="op.cc=100,exe=1000"
|
||||
) as raw_log:
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
|
||||
y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
|
||||
y_ = y.data
|
||||
x_ = (x*n).data
|
||||
if mpi.world_rank() == 0:
|
||||
|
|
Loading…
Reference in New Issue