mirror of https://github.com/Jittor/Jittor
fix cub home
This commit is contained in:
parent
c80e0168c0
commit
8d79e298e5
|
@ -99,6 +99,7 @@ def install_cub(root_folder):
|
|||
|
||||
def setup_cub():
|
||||
global cub_home
|
||||
cub_home = ""
|
||||
from pathlib import Path
|
||||
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
|
||||
cuda_version = int(get_version(nvcc_path)[1:-1].split('.')[0])
|
||||
|
@ -106,6 +107,7 @@ def setup_cub():
|
|||
if cuda_version < 11:
|
||||
cub_home = install_cub(cub_path)
|
||||
extra_flags = f"-I{cub_home}"
|
||||
cub_home += "/"
|
||||
setup_cuda_lib("cub", link=False, extra_flags=extra_flags)
|
||||
|
||||
def setup_cuda_extern():
|
||||
|
|
|
@ -135,7 +135,7 @@ class DepthwiseConv(Function):
|
|||
x, weight = self.save_vars
|
||||
Kh, Kw = self.Khw
|
||||
return jt.code([x.shape, weight.shape], [x.dtype, weight.dtype], [x, weight, grad],
|
||||
cuda_header = f"#include <{jt.compile_extern.cub_home}/cub/cub.cuh>"+"""
|
||||
cuda_header = f"#include <{jt.compile_extern.cub_home}cub/cub.cuh>"+"""
|
||||
template <typename T>
|
||||
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
|
||||
typedef cub::WarpReduce<T> WarpReduce;
|
||||
|
|
Loading…
Reference in New Issue