cub include for cuda11

This commit is contained in:
Dun Liang 2020-11-26 23:08:59 +08:00
parent cbbb3e7de4
commit 6e3daf8c8b
2 changed files with 7 additions and 3 deletions

View File

@ -178,7 +178,7 @@ void CudnnConvBackwardXOp::jit_run() {
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
};
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
int perf_count;
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
cudnnConvolutionBwdDataAlgo_t algo;

View File

@ -100,8 +100,12 @@ def install_cub(root_folder):
def setup_cub():
from pathlib import Path
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
cub_home = install_cub(cub_path)
setup_cuda_lib("cub", link=False, extra_flags=f"-I{cub_home}")
cuda_version = int(get_version(nvcc_path)[1:-1].split('.')[0])
extra_flags = ""
if cuda_version < 11:
cub_home = install_cub(cub_path)
extra_flags = f"-I{cub_home}"
setup_cuda_lib("cub", link=False, extra_flags=extra_flags)
def setup_cuda_extern():
if not has_cuda: return