mirror of https://github.com/Jittor/Jittor
polish mpi code
parent 8da6ef720b
commit 3ecce8eb9e

@@ -18,6 +18,10 @@
 namespace jittor {
 
 #ifndef JIT
+
+static auto nccl_all_reduce =
+    get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>();
+
 NcclAllReduceOp::NcclAllReduceOp(Var* x) : x(x) {
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);

@@ -29,14 +33,11 @@ void NcclAllReduceOp::infer_shape() {
 }
 
 VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
-    static VarPtr(*nccl_all_reduce)(Var*) =
-        get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>();
     return nccl_all_reduce(dout);
 }
 
 void NcclAllReduceOp::jit_prepare() {
     add_jit_define("Tx", x->dtype());
     add_jit_define("XDIM", JK::hex1(x->shape.size()));
 }
 
 #else // JIT

@@ -49,11 +50,9 @@ void NcclAllReduceOp::jit_run() {
     @if(@strcmp(@Tx,float64)==0, ncclFloat64)
     @if(@strcmp(@Tx,int64)==0, ncclInt64)
 )
-@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
-int size = 1 @for(i, 0, XDIM, * xshape@{i});
 auto* __restrict__ xp = x->ptr<Tx>();
 auto* __restrict__ yp = y->ptr<Tx>();
-checkCudaErrors(ncclAllReduce(xp, yp, size, @T_NCCL, ncclSum, comm, 0));
+checkCudaErrors(ncclAllReduce(xp, yp, y->num, @T_NCCL, ncclSum, comm, 0));
 }
 
 #endif
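
The polished call above passes y->num, Jittor's cached element count, instead of a shape product unrolled by the @for macro. A minimal standalone sketch of the same NCCL call, assuming an already initialized communicator and stream (the function name is illustrative, not from the commit):

    #include <nccl.h>
    #include <cuda_runtime.h>

    // Sum-all-reduce n floats in place across the communicator.
    // NCCL counts are in elements, not bytes, which is why y->num
    // is the right argument in the jit_run above.
    void allreduce_sum(float* buf, size_t n, ncclComm_t comm, cudaStream_t stream) {
        ncclAllReduce(/*sendbuff=*/buf, /*recvbuff=*/buf, n,
                      ncclFloat, ncclSum, comm, stream);
    }
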
@@ -29,14 +29,13 @@ void NcclBroadcastOp::infer_shape() {
 }
 
 VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
-    static VarPtr(*nccl_reduce)(Var*, int) =
+    static auto nccl_reduce =
         get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>();
     return nccl_reduce(dout,root);
 }
 
 void NcclBroadcastOp::jit_prepare() {
     add_jit_define("Tx", x->dtype());
     add_jit_define("XDIM", JK::hex1(x->shape.size()));
 }
 
 #else // JIT

@@ -49,11 +48,9 @@ void NcclBroadcastOp::jit_run() {
     @if(@strcmp(@Tx,float64)==0, ncclFloat64)
     @if(@strcmp(@Tx,int64)==0, ncclInt64)
 )
-@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
-int size = 1 @for(i, 0, XDIM, * xshape@{i});
 auto* __restrict__ xp = x->ptr<Tx>();
 auto* __restrict__ yp = y->ptr<Tx>();
-checkCudaErrors(ncclBroadcast(xp, yp, size, @T_NCCL, root, comm, 0));
+checkCudaErrors(ncclBroadcast(xp, yp, y->num, @T_NCCL, root, comm, 0));
 }
 
 #endif

@@ -29,14 +29,13 @@ void NcclReduceOp::infer_shape() {
 }
 
 VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
-    static VarPtr(*nccl_broadcast)(Var*, int) =
+    static auto nccl_broadcast =
         get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>();
     return nccl_broadcast(dout,root);
 }
 
 void NcclReduceOp::jit_prepare() {
     add_jit_define("Tx", x->dtype());
     add_jit_define("XDIM", JK::hex1(x->shape.size()));
 }
 
 #else // JIT

@@ -49,11 +48,11 @@ void NcclReduceOp::jit_run() {
     @if(@strcmp(@Tx,float64)==0, ncclFloat64)
     @if(@strcmp(@Tx,int64)==0, ncclInt64)
 )
-@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
-int size = 1 @for(i, 0, XDIM, * xshape@{i});
 auto* __restrict__ xp = x->ptr<Tx>();
 auto* __restrict__ yp = y->ptr<Tx>();
-checkCudaErrors(ncclReduce(xp, yp, size, @T_NCCL, ncclSum, root, comm, 0));
+checkCudaErrors(ncclReduce(xp, yp, y->num, @T_NCCL, ncclSum, root, comm, 0));
+if (root != mpi_world_rank)
+    checkCudaErrors(cudaMemsetAsync(yp, 0, y->size));
 }
 
 #endif
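
Besides switching to y->num, this hunk zeroes the output on non-root ranks, since ncclReduce only produces a defined result at root. A hedged standalone sketch of the same behavior (names illustrative; assumes an initialized communicator and stream):

    #include <nccl.h>
    #include <cuda_runtime.h>

    // Reduce n floats to root; every other rank ends up with a zeroed buffer.
    void reduce_sum_to_root(const float* x, float* y, size_t n, int root,
                            int rank, ncclComm_t comm, cudaStream_t stream) {
        ncclReduce(x, y, n, ncclFloat, ncclSum, root, comm, stream);
        if (rank != root)
            // cudaMemsetAsync takes a byte count, hence y->size (bytes)
            // rather than y->num (elements) in the jit_run above.
            cudaMemsetAsync(y, 0, n * sizeof(float), stream);
    }
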
@@ -25,12 +25,11 @@ nccl_initer() {
     if (mpi_world_rank == 0)
         checkCudaErrors(ncclGetUniqueId(&id));
     MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
     LOGv << "NCCL init in device" << mpi_local_rank;
+    checkCudaErrors(cudaSetDevice(mpi_local_rank));
-    #ifdef HAS_CUDA
     event_queue.run_sync([]() {
         checkCudaErrors(cudaSetDevice(mpi_local_rank));
     });
-    #endif
     checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
 }
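
This is the standard NCCL-over-MPI bootstrap: rank 0 creates the unique id, MPI broadcasts it, each process binds its GPU, and ncclCommInitRank joins the communicator collectively. A condensed sketch under the same assumptions (MPI already initialized; variable names illustrative):

    #include <mpi.h>
    #include <nccl.h>
    #include <cuda_runtime.h>

    ncclComm_t init_nccl(int rank, int local_rank, int world_size) {
        ncclUniqueId id;
        if (rank == 0) ncclGetUniqueId(&id);             // one id per job
        MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
        cudaSetDevice(local_rank);                       // one GPU per process
        ncclComm_t comm;
        ncclCommInitRank(&comm, world_size, id, rank);   // collective: all ranks join
        return comm;
    }
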
@@ -34,8 +34,8 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
     ASSERT(op == ns_add) << "Not supported MPI op" << op;
     #ifdef HAS_CUDA
     if (use_cuda) {
-        static VarPtr(*nccl_all_reduce)(Var*) = has_op("nccl_all_reduce")
-            ? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
+        static auto nccl_all_reduce = has_op("nccl_all_reduce")
+            ? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
             : nullptr;
         if (nccl_all_reduce) {
             auto var = nccl_all_reduce(x);
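
The commit standardizes on this one-shot ternary: a function-local static resolves the optional NCCL op the first time the constructor runs and falls back to nullptr when the op was not compiled in. In isolation the pattern looks like this (a sketch with a hypothetical op name):

    // Resolved exactly once, thread-safely, at first execution.
    static auto some_op = has_op("some_op")
        ? get_op_info("some_op").get_constructor<VarPtr, Var*>()
        : nullptr;
    if (some_op) {
        // CUDA path: forward to the NCCL-backed op.
    }
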
@@ -19,13 +19,10 @@ namespace jittor {
 MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
     #ifdef HAS_CUDA
     if (use_cuda) {
-        static VarPtr(*nccl_broadcast)(Var*, int) = nullptr;
-        if (!nccl_broadcast && has_op("nccl_broadcast")) {
-            nccl_broadcast = get_op_info("nccl_broadcast")
-                .get_constructor<VarPtr, Var*, int>();
-        }
+        static auto nccl_broadcast = has_op("nccl_broadcast")
+            ? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
+            : nullptr;
         if (nccl_broadcast) {
-            LOGr << "nccl";
             auto var = nccl_broadcast(x, root);
             forward(var);
             return;

@@ -50,16 +47,19 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
 
 void MpiBroadcastOp::jit_prepare() {
     add_jit_define("Tx", x->dtype());
     add_jit_define("XDIM", JK::hex1(x->shape.size()));
 }
 
 #else // JIT
 #ifdef JIT_cpu
 void MpiBroadcastOp::jit_run() {
-    @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
-    int size = 1 @for(i, 0, XDIM, * xshape@{i});
+    @define(T_MPI,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT)
+        @if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE)
+        @if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT)
+    )
     auto* __restrict__ yp = y->ptr<Tx>();
-    MPI_Bcast(yp, size, MPI_FLOAT, root, MPI_COMM_WORLD);
+    MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD);
 }
 #else
 void MpiBroadcastOp::jit_run() {
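
The broadcast fix replaces a hard-coded MPI_FLOAT with the @define'd T_MPI, since MPI_Bcast, like the NCCL calls, takes an element count plus an explicit datatype. A plain-MPI sketch of the corrected call (assumes MPI_Init has run; names illustrative):

    #include <mpi.h>

    // Broadcast n floats from root to every rank in MPI_COMM_WORLD.
    void bcast_floats(float* buf, int n, int root) {
        MPI_Bcast(buf, n, MPI_FLOAT, root, MPI_COMM_WORLD);
    }
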
@@ -34,8 +34,8 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
     ASSERT(op == ns_add) << "Not supported MPI op" << op;
     #ifdef HAS_CUDA
     if (use_cuda) {
-        static VarPtr(*nccl_reduce)(Var*, int) = has_op("nccl_reduce")
-            ? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
+        static auto nccl_reduce = has_op("nccl_reduce")
+            ? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
             : nullptr;
         if (nccl_reduce) {
             auto var = nccl_reduce(x, root);

@@ -78,9 +78,9 @@ void MpiReduceOp::jit_run() {
     auto* __restrict__ xp = x->ptr<Tx>();
     auto* __restrict__ yp = y->ptr<Tx>();
     index_t num = y->num;
-    MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD));
     if (root != mpi_world_rank)
         for (index_t i=0; i<num; i++) yp[i] = 0;
+    MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD));
 }
 #else
 void MpiReduceOp::jit_run() {
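
Mirroring the NCCL reduce, the MPI path zeroes yp on non-root ranks, because MPI_Reduce only defines the receive buffer at root. A hedged standalone equivalent (assumes MPI_Init has run; names illustrative):

    #include <mpi.h>

    // Sum-reduce n floats to root; non-root outputs are left zeroed.
    void reduce_sum_to_root(const float* x, float* y, int n, int root, int rank) {
        if (rank != root)
            for (int i = 0; i < n; i++) y[i] = 0;
        MPI_Reduce(x, y, n, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
    }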