nccl local rank can be larger than device count

This commit is contained in:
Dun Liang 2021-05-12 21:23:49 +08:00
parent f895d2f2e4
commit 25def1f399
2 changed files with 18 additions and 10 deletions

View File

@ -28,10 +28,12 @@ nccl_initer() {
int device_count = get_device_count();
if (!device_count) return;
if (!inside_mpi) return;
if (mpi_local_rank >= device_count)
LOGf << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
>>device_count>>")";
nccl_device_id = mpi_local_rank;
if (mpi_local_rank >= device_count) {
LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
>>device_count>>")";
nccl_device_id = nccl_device_id % device_count;
}
LOGv << "NCCL init in device" << nccl_device_id << "local_rank" << mpi_local_rank;
checkCudaErrors(cudaSetDevice(nccl_device_id));
event_queue.run_sync([]() {

View File

@ -9,12 +9,14 @@ from .compiler import *
from jittor_utils import run_cmd, get_version, get_int_version
from jittor.utils.misc import download_url_to_local
def search_file(dirs, name):
def search_file(dirs, name, prefer_version=()):
for d in dirs:
fname = os.path.join(d, name)
if os.path.isfile(fname):
LOG.i(f"found {fname}")
return fname
for i in range(len(prefer_version),-1,-1):
vname = ".".join((fname,)+prefer_version[:i])
if os.path.isfile(vname):
LOG.i(f"found {vname}")
return vname
LOG.f(f"file {name} not found in {dirs}")
def install_mkl(root_folder):
@ -162,15 +164,19 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", "targets/x86_64-linux/include"))
extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", "targets/x86_64-linux/lib"))
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so")
# cuda11 prefer cudnn 8
nvcc_version = get_int_version(nvcc_path)
prefer_version = ()
if nvcc_version[0] == 11:
prefer_version = ("8",)
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
if lib_name == "cudnn":
# cudnn cannot find libcudnn_cnn_train.so.8 by itself, so we link it manually.
nvcc_version = get_int_version(nvcc_path)
if nvcc_version >= (11,0,0):
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
for l in libs:
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l)
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l, prefer_version)
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
# dynamic link cuda library