mirror of https://github.com/Jittor/Jittor
setup nccl
This commit is contained in:
parent
195f754cba
commit
8127107d87
|
@ -79,6 +79,11 @@ const char *_cudaGetErrorEnum(cusolverStatus_t error);
|
|||
const char *_cudaGetErrorEnum(curandStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef NCCL_H_
|
||||
// NCCL API errors
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error);
|
||||
#endif
|
||||
|
||||
#ifdef NV_NPPIDEFS_H
|
||||
// NPP API errors
|
||||
const char *_cudaGetErrorEnum(NppStatus error);
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
#ifndef JIT
|
||||
// Overload of _cudaGetErrorEnum for NCCL status codes: forwards the code to
// ncclGetErrorString and returns the text it produces.
const char *_cudaGetErrorEnum(ncclResult_t error) {
    const char *description = ncclGetErrorString(error);
    return description;
}
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Builds the NCCL test op: stores the command string, marks the node as
// CUDA-only, and creates its float32 output variable.
NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
    // Enable the CUDA flag and clear the CPU flag on this node.
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);

    // jit_run stores its result marker in this variable.
    output = create_output(1, ns_float32);
}
|
||||
|
||||
// Registers the JIT define "T" as float32; jit_run uses T as the element
// type when it writes to the output.
void NcclTestOp::jit_prepare() {
    const auto element_type = ns_float32;
    add_jit_define("T", element_type);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
void NcclTestOp::jit_run() {
|
||||
auto args = split(cmd, " ");
|
||||
if (!cmd.size()) args.clear();
|
||||
vector<char*> v(args.size());
|
||||
for (uint i=0; i<args.size(); i++)
|
||||
v[i] = &args[i][0];
|
||||
output->ptr<T>()[0] = 123;
|
||||
|
||||
|
||||
|
||||
//managing 4 devices
|
||||
int nDev;
|
||||
checkCudaErrors(cudaGetDeviceCount(&nDev));
|
||||
nDev = std::min(nDev, 2);
|
||||
|
||||
ncclComm_t comms[nDev];
|
||||
int size = 32*1024*1024;
|
||||
int devs[4] = { 0, 1, 2, 3 };
|
||||
|
||||
|
||||
//allocating and initializing device buffers
|
||||
float** sendbuff = (float**)malloc(nDev * sizeof(float*));
|
||||
float** recvbuff = (float**)malloc(nDev * sizeof(float*));
|
||||
cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);
|
||||
|
||||
|
||||
for (int i = 0; i < nDev; ++i) {
|
||||
checkCudaErrors(cudaSetDevice(i));
|
||||
checkCudaErrors(cudaMalloc(sendbuff + i, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMalloc(recvbuff + i, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMemset(sendbuff[i], 1, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMemset(recvbuff[i], 0, size * sizeof(float)));
|
||||
checkCudaErrors(cudaStreamCreate(s+i));
|
||||
}
|
||||
|
||||
|
||||
//initializing NCCL
|
||||
checkCudaErrors(ncclCommInitAll(comms, nDev, devs));
|
||||
|
||||
|
||||
//calling NCCL communication API. Group API is required when using
|
||||
//multiple devices per thread
|
||||
checkCudaErrors(ncclGroupStart());
|
||||
for (int i = 0; i < nDev; ++i)
|
||||
checkCudaErrors(ncclAllReduce((const void*)sendbuff[i], (void*)recvbuff[i], size, ncclFloat, ncclSum,
|
||||
comms[i], s[i]));
|
||||
checkCudaErrors(ncclGroupEnd());
|
||||
|
||||
|
||||
//synchronizing on CUDA streams to wait for completion of NCCL operation
|
||||
for (int i = 0; i < nDev; ++i) {
|
||||
checkCudaErrors(cudaSetDevice(i));
|
||||
checkCudaErrors(cudaStreamSynchronize(s[i]));
|
||||
}
|
||||
|
||||
|
||||
//free device buffers
|
||||
for (int i = 0; i < nDev; ++i) {
|
||||
checkCudaErrors(cudaSetDevice(i));
|
||||
checkCudaErrors(cudaFree(sendbuff[i]));
|
||||
checkCudaErrors(cudaFree(recvbuff[i]));
|
||||
}
|
||||
|
||||
|
||||
//finalizing NCCL
|
||||
for(int i = 0; i < nDev; ++i)
|
||||
ncclCommDestroy(comms[i]);
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,23 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct NcclTestOp : Op {
|
||||
Var* output;
|
||||
string cmd;
|
||||
|
||||
NcclTestOp(string cmd);
|
||||
|
||||
const char* name() const override { return "nccl_test"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -3,8 +3,9 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import os, sys
|
||||
import os, sys, shutil
|
||||
from .compiler import *
|
||||
from jittor_utils import run_cmd
|
||||
from jittor.dataset.utils import download_url_to_local
|
||||
|
||||
def search_file(dirs, name):
|
||||
|
@ -171,10 +172,10 @@ def install_cutt(root_folder):
|
|||
true_md5 = "a6f4f7f75310a69b131e21f1ebec768a"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = os.popen('md5sum ' + fullname).read().split()[0]
|
||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
||||
if md5 != true_md5:
|
||||
os.system('rm ' + fullname)
|
||||
os.system('rm -rf ' + dirname)
|
||||
os.remove(fullname)
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
|
||||
LOG.i("Downloading cutt...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
@ -186,9 +187,9 @@ def install_cutt(root_folder):
|
|||
zf.extractall(path=root_folder)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
raise
|
||||
zf.close()
|
||||
|
||||
from jittor_utils import run_cmd
|
||||
LOG.i("installing cutt...")
|
||||
run_cmd(f"make", cwd=dirname)
|
||||
return dirname
|
||||
|
@ -233,7 +234,73 @@ def setup_cutt():
|
|||
LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
|
||||
|
||||
|
||||
def install_nccl(root_folder):
    """Download, unpack and build NCCL v2.6.4-1 inside ``root_folder``.

    A cached tarball whose MD5 does not match is deleted together with any
    previously extracted source tree. The download/extract/build sequence
    only runs when ``build/lib/libnccl.so`` is not already present.

    Args:
        root_folder: directory that holds the tarball and the source tree.

    Returns:
        Path of the NCCL source directory (``<root_folder>/nccl-2.6.4-1``).
    """
    import hashlib
    import tarfile

    url = "https://github.com/NVIDIA/nccl/archive/v2.6.4-1.tar.gz"

    filename = "nccl.tgz"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, "nccl-2.6.4-1")
    true_md5 = "38d7a9e98d95a99df0a4f1ad6fb50fa7"

    if os.path.exists(fullname):
        # Hash in-process instead of shelling out to `md5sum`: no dependency
        # on an external tool and no breakage on paths containing spaces.
        hasher = hashlib.md5()
        with open(fullname, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                hasher.update(chunk)
        if hasher.hexdigest() != true_md5:
            os.remove(fullname)
            if os.path.isdir(dirname):
                shutil.rmtree(dirname)
    if not os.path.isfile(os.path.join(dirname, "build", "lib", "libnccl.so")):
        LOG.i("Downloading nccl...")
        download_url_to_local(url, filename, root_folder, true_md5)

        with tarfile.open(fullname, "r") as tar:
            tar.extractall(root_folder)

        LOG.i("installing nccl...")
        # nvcc's -arch option accepts a single architecture, so emit one
        # -gencode pair per requested arch (the form NCCL's own makefiles
        # use). With no archs configured, leave NVCC_GENCODE unset so the
        # NCCL makefile falls back to its defaults.
        gencode = " ".join(
            f"-gencode=arch=compute_{arch},code=sm_{arch}"
            for arch in flags.cuda_archs)
        make_cmd = f"make -j8 src.build CUDA_HOME='{cuda_home}'"
        if gencode:
            make_cmd += f" NVCC_GENCODE='{gencode}'"
        run_cmd(make_cmd, cwd=dirname)
    return dirname
|
||||
|
||||
def setup_nccl():
    """Locate (or build) NCCL, load it, and compile the NCCL custom ops.

    Sets two module globals:
        use_nccl: False when CUDA is unavailable or the ``use_nccl``
            environment variable is not "1" (it defaults to "1").
        nccl_ops: the compiled custom-op module, or None when NCCL is not
            in use. It is always defined after this call, so importers can
            test ``nccl_ops is None`` safely.

    If either ``nccl_include_path`` or ``nccl_lib_path`` is missing from the
    environment, NCCL is installed under ``~/.cache/jittor/nccl`` and the
    paths of that build are used instead.
    """
    global nccl_ops, use_nccl
    # Assign before any early return: leaving nccl_ops undefined on the
    # no-CUDA path makes `compile_extern.nccl_ops` raise AttributeError.
    nccl_ops = None
    if not has_cuda:
        use_nccl = False
        return
    use_nccl = os.environ.get("use_nccl", "1")=="1"
    if not use_nccl: return
    nccl_include_path = os.environ.get("nccl_include_path")
    nccl_lib_path = os.environ.get("nccl_lib_path")

    if nccl_lib_path is None or nccl_include_path is None:
        LOG.v("setup nccl...")
        # nccl_path decouple with cc_path
        from pathlib import Path
        nccl_path = os.path.join(str(Path.home()), ".cache", "jittor", "nccl")

        make_cache_dir(nccl_path)
        nccl_home = install_nccl(nccl_path)
        nccl_include_path = os.path.join(nccl_home, "build", "include")
        nccl_lib_path = os.path.join(nccl_home, "build", "lib")

    nccl_lib_name = os.path.join(nccl_lib_path, "libnccl.so")
    assert os.path.isdir(nccl_include_path)
    assert os.path.isdir(nccl_lib_path)
    assert os.path.isfile(nccl_lib_name), nccl_lib_name
    LOG.v(f"nccl_include_path: {nccl_include_path}")
    LOG.v(f"nccl_lib_path: {nccl_lib_path}")
    LOG.v(f"nccl_lib_name: {nccl_lib_name}")
    # We do not link manually; the library is loaded here and the custom ops
    # are compiled against its headers.
    ctypes.CDLL(nccl_lib_name, dlopen_flags)

    nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
    nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
    nccl_ops = compile_custom_ops(nccl_op_files,
        extra_flags=f" -I'{nccl_include_path}'")
    LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||
|
||||
setup_cutt()
|
||||
setup_mkl()
|
||||
setup_nccl()
|
||||
|
||||
setup_cuda_extern()
|
||||
|
|
|
@ -689,8 +689,9 @@ def compile_extern():
|
|||
def check_cuda():
|
||||
if nvcc_path == "":
|
||||
return
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home
|
||||
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
|
||||
cuda_home = os.path.abspath(os.path.join(cuda_dir, ".."))
|
||||
assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
|
||||
cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include"))
|
||||
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64"))
|
||||
|
|
|
@ -45,6 +45,7 @@ def check(jt_model, torch_model, shape, near_data):
|
|||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestArgPoolOp(unittest.TestCase):
|
||||
@unittest.skipIf(jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda(self):
|
||||
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import unittest
|
||||
|
||||
# getattr keeps the module importable (and the test skipped) even when
# compile_extern never defined nccl_ops.
@unittest.skipIf(getattr(jt.compile_extern, "nccl_ops", None) is None, "no nccl found")
class TestNccl(unittest.TestCase):
    """Smoke test for the NCCL custom ops."""

    @jt.flag_scope(use_cuda=1)
    def test_nccl(self):
        # nccl_test writes 123 to its output once it has run.
        result = jt.compile_extern.nccl_ops.nccl_test("").data
        # A unittest assertion survives `python -O` and reports the value.
        self.assertTrue((result == 123).all(),
            f"unexpected nccl_test output: {result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue