mirror of https://github.com/Jittor/Jittor
cub include for cuda11
This commit is contained in:
parent
cbbb3e7de4
commit
6e3daf8c8b
|
@ -178,7 +178,7 @@ void CudnnConvBackwardXOp::jit_run() {
|
|||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
|
||||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
||||
cudnnConvolutionBwdDataAlgo_t algo;
|
||||
|
|
|
@ -100,8 +100,12 @@ def install_cub(root_folder):
|
|||
def setup_cub():
|
||||
from pathlib import Path
|
||||
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
|
||||
cub_home = install_cub(cub_path)
|
||||
setup_cuda_lib("cub", link=False, extra_flags=f"-I{cub_home}")
|
||||
cuda_version = int(get_version(nvcc_path)[1:-1].split('.')[0])
|
||||
extra_flags = ""
|
||||
if cuda_version < 11:
|
||||
cub_home = install_cub(cub_path)
|
||||
extra_flags = f"-I{cub_home}"
|
||||
setup_cuda_lib("cub", link=False, extra_flags=extra_flags)
|
||||
|
||||
def setup_cuda_extern():
|
||||
if not has_cuda: return
|
||||
|
|
Loading…
Reference in New Issue