mirror of https://github.com/Jittor/Jittor
nccl local rank can be larger than device count
This commit is contained in:
parent
f895d2f2e4
commit
25def1f399
|
@ -28,10 +28,12 @@ nccl_initer() {
|
|||
int device_count = get_device_count();
|
||||
if (!device_count) return;
|
||||
if (!inside_mpi) return;
|
||||
if (mpi_local_rank >= device_count)
|
||||
LOGf << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
|
||||
>>device_count>>")";
|
||||
nccl_device_id = mpi_local_rank;
|
||||
if (mpi_local_rank >= device_count) {
|
||||
LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
|
||||
>>device_count>>")";
|
||||
nccl_device_id = nccl_device_id % device_count;
|
||||
}
|
||||
LOGv << "NCCL init in device" << nccl_device_id << "local_rank" << mpi_local_rank;
|
||||
checkCudaErrors(cudaSetDevice(nccl_device_id));
|
||||
event_queue.run_sync([]() {
|
||||
|
|
|
@ -9,12 +9,14 @@ from .compiler import *
|
|||
from jittor_utils import run_cmd, get_version, get_int_version
|
||||
from jittor.utils.misc import download_url_to_local
|
||||
|
||||
def search_file(dirs, name):
|
||||
def search_file(dirs, name, prefer_version=()):
|
||||
for d in dirs:
|
||||
fname = os.path.join(d, name)
|
||||
if os.path.isfile(fname):
|
||||
LOG.i(f"found {fname}")
|
||||
return fname
|
||||
for i in range(len(prefer_version),-1,-1):
|
||||
vname = ".".join((fname,)+prefer_version[:i])
|
||||
if os.path.isfile(vname):
|
||||
LOG.i(f"found {vname}")
|
||||
return vname
|
||||
LOG.f(f"file {name} not found in {dirs}")
|
||||
|
||||
def install_mkl(root_folder):
|
||||
|
@ -162,15 +164,19 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", "targets/x86_64-linux/include"))
|
||||
extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", "targets/x86_64-linux/lib"))
|
||||
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
|
||||
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so")
|
||||
# CUDA 11 prefers cuDNN 8
|
||||
nvcc_version = get_int_version(nvcc_path)
|
||||
prefer_version = ()
|
||||
if nvcc_version[0] == 11:
|
||||
prefer_version = ("8",)
|
||||
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
|
||||
|
||||
if lib_name == "cudnn":
|
||||
# cudnn cannot find libcudnn_cnn_train.so.8, so we link it manually.
|
||||
nvcc_version = get_int_version(nvcc_path)
|
||||
if nvcc_version >= (11,0,0):
|
||||
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
|
||||
for l in libs:
|
||||
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l)
|
||||
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l, prefer_version)
|
||||
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
|
||||
|
||||
# dynamic link cuda library
|
||||
|
|
Loading…
Reference in New Issue