mirror of https://github.com/Jittor/Jittor
Fall back to MPI if NCCL is not available
This commit is contained in:
parent
21dad69e51
commit
8f2274059c
|
@ -39,6 +39,12 @@ nccl_initer() {
|
|||
event_queue.run_sync([]() {
|
||||
checkCudaErrors(cudaSetDevice(nccl_device_id));
|
||||
});
|
||||
if (mpi_local_size > device_count) {
|
||||
// NCCL does not support multiple processes on one GPU,
|
||||
// so fall back to MPI
|
||||
return;
|
||||
}
|
||||
use_device_mpi = true;
|
||||
if (mpi_world_rank == 0)
|
||||
checkCudaErrors(ncclGetUniqueId(&id));
|
||||
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
|
||||
|
|
|
@ -27,9 +27,11 @@ namespace jittor {
|
|||
|
||||
extern int mpi_world_size;
|
||||
extern int mpi_world_rank;
|
||||
extern int mpi_local_size;
|
||||
extern int mpi_local_rank;
|
||||
extern bool inside_mpi;
|
||||
extern bool mpi_enabled;
|
||||
extern bool use_device_mpi;
|
||||
|
||||
/**
|
||||
Return number of MPI nodes.
|
||||
|
|
|
@ -37,7 +37,7 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
|
|||
}
|
||||
ASSERT(op == ns_add) << "Not supported MPI op" << op;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
if (use_device_mpi) {
|
||||
static auto nccl_all_reduce = has_op("nccl_all_reduce")
|
||||
? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
|
||||
: nullptr;
|
||||
|
|
|
@ -22,7 +22,7 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
|
|||
return;
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
if (use_device_mpi) {
|
||||
static auto nccl_broadcast = has_op("nccl_broadcast")
|
||||
? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
|
||||
: nullptr;
|
||||
|
|
|
@ -37,7 +37,7 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
|
|||
}
|
||||
ASSERT(op == ns_add) << "Not supported MPI op" << op;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
if (use_device_mpi) {
|
||||
static auto nccl_reduce = has_op("nccl_reduce")
|
||||
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
|
||||
: nullptr;
|
||||
|
|
|
@ -30,9 +30,11 @@ namespace jittor {
|
|||
|
||||
int mpi_world_size = 1;
|
||||
int mpi_world_rank = 0;
|
||||
int mpi_local_size = 1;
|
||||
int mpi_local_rank = 0;
|
||||
bool inside_mpi = false;
|
||||
bool mpi_enabled = false;
|
||||
bool use_device_mpi = false;
|
||||
|
||||
int _mpi_world_size() {
|
||||
return mpi_enabled ? mpi_world_size : 1;
|
||||
|
@ -96,7 +98,12 @@ mpi_initer() {
|
|||
if (p == mpi_world_rank) break;
|
||||
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
|
||||
}
|
||||
mpi_local_size = 0;
|
||||
for (int p=0; p<mpi_world_size; p++) {
|
||||
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_size++;
|
||||
}
|
||||
LOGv << "MPI init finished: local" << mpi_local_rank
|
||||
<< "size" << mpi_local_size
|
||||
<< "global" << mpi_world_rank
|
||||
<< "size" << mpi_world_size;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.2'
|
||||
__version__ = '1.2.3.3'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -21,6 +21,7 @@ if has_cupy:
|
|||
device_num = 0
|
||||
if jt.mpi:
|
||||
device_num = jt.mpi.local_rank()
|
||||
device_num = device_num % cp.cuda.runtime.getDeviceCount()
|
||||
cupy_device = cp.cuda.Device(device_num)
|
||||
cupy_device.__enter__()
|
||||
|
||||
|
|
Loading…
Reference in New Issue