polish mpi code

This commit is contained in:
Dun Liang 2020-04-22 15:06:23 +08:00
parent 8da6ef720b
commit 3ecce8eb9e
7 changed files with 27 additions and 33 deletions

View File

@ -18,6 +18,10 @@
namespace jittor {
#ifndef JIT
static auto nccl_all_reduce =
get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>();
NcclAllReduceOp::NcclAllReduceOp(Var* x) : x(x) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
@ -29,14 +33,11 @@ void NcclAllReduceOp::infer_shape() {
}
VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
static VarPtr(*nccl_all_reduce)(Var*) =
get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>();
return nccl_all_reduce(dout);
}
void NcclAllReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
}
#else // JIT
@ -49,11 +50,9 @@ void NcclAllReduceOp::jit_run() {
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclAllReduce(xp, yp, size, @T_NCCL, ncclSum, comm, 0));
checkCudaErrors(ncclAllReduce(xp, yp, y->num, @T_NCCL, ncclSum, comm, 0));
}
#endif

View File

@ -29,14 +29,13 @@ void NcclBroadcastOp::infer_shape() {
}
VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
static VarPtr(*nccl_reduce)(Var*, int) =
static auto nccl_reduce =
get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>();
return nccl_reduce(dout,root);
}
void NcclBroadcastOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
}
#else // JIT
@ -49,11 +48,9 @@ void NcclBroadcastOp::jit_run() {
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclBroadcast(xp, yp, size, @T_NCCL, root, comm, 0));
checkCudaErrors(ncclBroadcast(xp, yp, y->num, @T_NCCL, root, comm, 0));
}
#endif

View File

@ -29,14 +29,13 @@ void NcclReduceOp::infer_shape() {
}
VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
static VarPtr(*nccl_broadcast)(Var*, int) =
static auto nccl_broadcast =
get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>();
return nccl_broadcast(dout,root);
}
void NcclReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
}
#else // JIT
@ -49,11 +48,11 @@ void NcclReduceOp::jit_run() {
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclReduce(xp, yp, size, @T_NCCL, ncclSum, root, comm, 0));
checkCudaErrors(ncclReduce(xp, yp, y->num, @T_NCCL, ncclSum, root, comm, 0));
if (root != mpi_world_rank)
checkCudaErrors(cudaMemsetAsync(yp, 0, y->size));
}
#endif

View File

@ -25,12 +25,11 @@ nccl_initer() {
if (mpi_world_rank == 0)
checkCudaErrors(ncclGetUniqueId(&id));
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
LOGv << "NCCL init in device" << mpi_local_rank;
checkCudaErrors(cudaSetDevice(mpi_local_rank));
#ifdef HAS_CUDA
event_queue.run_sync([]() {
checkCudaErrors(cudaSetDevice(mpi_local_rank));
});
#endif
checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
}

View File

@ -34,8 +34,8 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
ASSERT(op == ns_add) << "Not supported MPI op" << op;
#ifdef HAS_CUDA
if (use_cuda) {
static VarPtr(*nccl_all_reduce)(Var*) = has_op("nccl_all_reduce")
? ; get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
static auto nccl_all_reduce = has_op("nccl_all_reduce")
? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
: nullptr;
if (nccl_all_reduce) {
auto var = nccl_all_reduce(x);

View File

@ -19,13 +19,10 @@ namespace jittor {
MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
#ifdef HAS_CUDA
if (use_cuda) {
static VarPtr(*nccl_broadcast)(Var*, int) = nullptr;
if (!nccl_broadcast && has_op("nccl_broadcast")) {
nccl_broadcast = get_op_info("nccl_broadcast")
.get_constructor<VarPtr, Var*, int>();
}
static auto nccl_broadcast = has_op("nccl_broadcast")
? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
: nullptr;
if (nccl_broadcast) {
LOGr << "nccl";
auto var = nccl_broadcast(x, root);
forward(var);
return;
@ -50,16 +47,19 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
void MpiBroadcastOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
}
#else // JIT
#ifdef JIT_cpu
void MpiBroadcastOp::jit_run() {
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
@define(T_MPI,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
@if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
@if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
)
auto* __restrict__ yp = y->ptr<Tx>();
MPI_Bcast(yp, size, MPI_FLOAT, root, MPI_COMM_WORLD);
MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD);
}
#else
void MpiBroadcastOp::jit_run() {

View File

@ -34,8 +34,8 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
ASSERT(op == ns_add) << "Not supported MPI op" << op;
#ifdef HAS_CUDA
if (use_cuda) {
static VarPtr(*nccl_reduce)(Var*, int) = has_op("nccl_reduce")
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>();
static auto nccl_reduce = has_op("nccl_reduce")
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
: nullptr;
if (nccl_reduce) {
auto var = nccl_reduce(x, root);
@ -78,9 +78,9 @@ void MpiReduceOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD));
if (root != mpi_world_rank)
for (index_t i=0; i<num; i++) yp[i] = 0;
MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD));
}
#else
void MpiReduceOp::jit_run() {