fall back to mpi if nccl not available

Dun Liang 2021-05-14 13:45:38 +08:00
parent 21dad69e51
commit 8f2274059c
8 changed files with 20 additions and 4 deletions


@@ -39,6 +39,12 @@ nccl_initer() {
     event_queue.run_sync([]() {
         checkCudaErrors(cudaSetDevice(nccl_device_id));
     });
+    if (mpi_local_size > device_count) {
+        // NCCL does not support multiple processes on one GPU,
+        // so fall back to MPI
+        return;
+    }
+    use_device_mpi = true;
     if (mpi_world_rank == 0)
         checkCudaErrors(ncclGetUniqueId(&id));
     MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
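In words: NCCL cannot share one GPU between several processes, so when a node runs more MPI ranks than it has visible devices, `nccl_initer` now returns before creating the communicator, `use_device_mpi` stays false, and every collective stays on the plain MPI path. A minimal standalone sketch of that decision (assuming `device_count` holds the result of an earlier `cudaGetDeviceCount` call; `choose_device_mpi` is a hypothetical name used here for illustration):

```cpp
#include <cstdio>

// Standalone model of the new guard in nccl_initer: NCCL allows at most
// one process per GPU, so with more local MPI ranks than visible devices
// the initializer bails out and use_device_mpi stays false.
bool choose_device_mpi(int mpi_local_size, int device_count) {
    if (mpi_local_size > device_count)
        return false;  // fall back: collectives run through host MPI
    return true;       // safe to create the NCCL communicator
}

int main() {
    printf("%d\n", choose_device_mpi(4, 2)); // 4 ranks, 2 GPUs -> 0 (MPI)
    printf("%d\n", choose_device_mpi(2, 4)); // 2 ranks, 4 GPUs -> 1 (NCCL)
    return 0;
}
```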


@@ -27,9 +27,11 @@ namespace jittor {
 extern int mpi_world_size;
 extern int mpi_world_rank;
+extern int mpi_local_size;
 extern int mpi_local_rank;
 extern bool inside_mpi;
 extern bool mpi_enabled;
+extern bool use_device_mpi;
 
 /**
 Return number of MPI nodes.

@@ -37,7 +37,7 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
     }
     ASSERT(op == ns_add) << "Not supported MPI op" << op;
     #ifdef HAS_CUDA
-    if (use_cuda) {
+    if (use_device_mpi) {
         static auto nccl_all_reduce = has_op("nccl_all_reduce")
             ? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
             : nullptr;


@@ -22,7 +22,7 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
         return;
     }
     #ifdef HAS_CUDA
-    if (use_cuda) {
+    if (use_device_mpi) {
         static auto nccl_broadcast = has_op("nccl_broadcast")
             ? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
             : nullptr;


@@ -37,7 +37,7 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) {
     }
     ASSERT(op == ns_add) << "Not supported MPI op" << op;
     #ifdef HAS_CUDA
-    if (use_cuda) {
+    if (use_device_mpi) {
         static auto nccl_reduce = has_op("nccl_reduce")
             ? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
             : nullptr;
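The three op files above receive the identical one-line change: the NCCL path is now gated on `use_device_mpi` rather than `use_cuda`, so having CUDA enabled is no longer enough; the NCCL initializer must also have confirmed a one-process-per-GPU layout. A self-contained sketch of the cached-lookup-plus-gate pattern, with a toy registry standing in for Jittor's `has_op`/`get_op_info` (the registry and names here are illustrative, not Jittor's real internals):

```cpp
#include <cstdio>
#include <functional>
#include <map>
#include <string>

// Toy op registry; stands in for Jittor's has_op / get_op_info lookup.
std::map<std::string, std::function<void()>>& registry() {
    static std::map<std::string, std::function<void()>> r;
    return r;
}
bool has_op(const std::string& name) { return registry().count(name) > 0; }

bool use_device_mpi = false;  // set by the NCCL initializer when usable

void mpi_all_reduce() {
    // Resolve the accelerated op once and cache it, mirroring the
    // `static auto nccl_all_reduce = ...` line in the diff.
    static std::function<void()> nccl_all_reduce =
        has_op("nccl_all_reduce") ? registry()["nccl_all_reduce"] : nullptr;
    if (use_device_mpi && nccl_all_reduce) {
        nccl_all_reduce();          // device path: NCCL collective
        return;
    }
    puts("host MPI all_reduce");    // fallback path
}

int main() {
    registry()["nccl_all_reduce"] = [] { puts("NCCL all_reduce"); };
    mpi_all_reduce();               // use_device_mpi == false -> host MPI
    use_device_mpi = true;
    mpi_all_reduce();               // -> NCCL all_reduce
    return 0;
}
```

Note the `static` local: the constructor lookup runs once on first use and is reused on every later call, so the gate itself is just a boolean test.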


@@ -30,9 +30,11 @@ namespace jittor {
 int mpi_world_size = 1;
 int mpi_world_rank = 0;
+int mpi_local_size = 1;
 int mpi_local_rank = 0;
 bool inside_mpi = false;
 bool mpi_enabled = false;
+bool use_device_mpi = false;
 
 int _mpi_world_size() {
     return mpi_enabled ? mpi_world_size : 1;
@@ -96,7 +98,12 @@ mpi_initer() {
         if (p == mpi_world_rank) break;
         if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
     }
+    mpi_local_size = 0;
+    for (int p=0; p<mpi_world_size; p++) {
+        if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_size++;
+    }
     LOGv << "MPI init finished: local" << mpi_local_rank
+        << "size" << mpi_local_size
         << "global" << mpi_world_rank
         << "size" << mpi_world_size;
 }
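`mpi_local_size` is computed the same way `mpi_local_rank` already was: each rank hashes its hostname, the hashes are shared across ranks (the `hostHashs` array is assumed to be filled by an allgather earlier in `mpi_initer`), and each rank counts how many entries match its own. A standalone sketch with a hand-built hash array:

```cpp
#include <cstdio>

// Standalone model of the mpi_local_rank / mpi_local_size computation.
// hostHashs[p] is the hostname hash of rank p; ranks on the same node
// share a hash value.
void local_info(const long* hostHashs, int world_size, int rank,
                int* local_rank, int* local_size) {
    *local_rank = 0;
    for (int p = 0; p < world_size; p++) {
        if (p == rank) break;
        if (hostHashs[p] == hostHashs[rank]) (*local_rank)++;
    }
    *local_size = 0;
    for (int p = 0; p < world_size; p++)
        if (hostHashs[p] == hostHashs[rank]) (*local_size)++;
}

int main() {
    // 4 ranks over 2 nodes: ranks 0,1 share node A; ranks 2,3 share node B.
    long hashes[4] = {0xA, 0xA, 0xB, 0xB};
    for (int r = 0; r < 4; r++) {
        int lr, ls;
        local_info(hashes, 4, r, &lr, &ls);
        printf("rank %d: local_rank=%d local_size=%d\n", r, lr, ls);
    }
    return 0;
}
```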


@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.3.2'
+__version__ = '1.2.3.3'
 from . import lock
 with lock.lock_scope():
     ori_int = int


@@ -21,6 +21,7 @@ if has_cupy:
     device_num = 0
     if jt.mpi:
         device_num = jt.mpi.local_rank()
+        device_num = device_num % cp.cuda.runtime.getDeviceCount()
     cupy_device = cp.cuda.Device(device_num)
     cupy_device.__enter__()
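Finally, the cupy helper wraps the local rank by the visible device count, so in the fallback case (more ranks than GPUs) oversubscribed ranks share devices instead of passing an out-of-range index to `cp.cuda.Device`. A sketch of the mapping (`pick_device` is a hypothetical name for illustration):

```cpp
#include <cstdio>

// Map an MPI local rank onto one of the visible CUDA devices.
// With more ranks than devices, ranks wrap around and share GPUs.
int pick_device(int local_rank, int device_count) {
    return local_rank % device_count;
}

int main() {
    // 4 local ranks on a 2-GPU node -> devices 0, 1, 0, 1
    for (int r = 0; r < 4; r++)
        printf("local_rank %d -> device %d\n", r, pick_device(r, 2));
    return 0;
}
```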