fix cub home

This commit is contained in:
Dun Liang 2020-12-07 23:06:09 +08:00
parent c80e0168c0
commit 8d79e298e5
2 changed files with 3 additions and 1 deletions

View File

@ -99,6 +99,7 @@ def install_cub(root_folder):
def setup_cub():
global cub_home
cub_home = ""
from pathlib import Path
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
cuda_version = int(get_version(nvcc_path)[1:-1].split('.')[0])
@ -106,6 +107,7 @@ def setup_cub():
if cuda_version < 11:
cub_home = install_cub(cub_path)
extra_flags = f"-I{cub_home}"
cub_home += "/"
setup_cuda_lib("cub", link=False, extra_flags=extra_flags)
def setup_cuda_extern():

View File

@ -135,7 +135,7 @@ class DepthwiseConv(Function):
x, weight = self.save_vars
Kh, Kw = self.Khw
return jt.code([x.shape, weight.shape], [x.dtype, weight.dtype], [x, weight, grad],
cuda_header = f"#include <{jt.compile_extern.cub_home}/cub/cub.cuh>"+"""
cuda_header = f"#include <{jt.compile_extern.cub_home}cub/cub.cuh>"+"""
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
typedef cub::WarpReduce<T> WarpReduce;