diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py
index 442f2bca..f1e2f2a0 100644
--- a/python/jittor/compiler.py
+++ b/python/jittor/compiler.py
@@ -1246,6 +1246,7 @@ jit_utils.add_backend(corex_compiler)
 for mod in jit_utils.backends:
     if mod.check(): break
+is_cuda = os.path.basename(nvcc_path) == "nvcc"
 
 # build core
 gen_jit_flags()
 
diff --git a/python/jittor/depthwise_conv.py b/python/jittor/depthwise_conv.py
index ed17111d..1dd19c42 100644
--- a/python/jittor/depthwise_conv.py
+++ b/python/jittor/depthwise_conv.py
@@ -20,6 +20,8 @@ class DepthwiseConv(Function):
         self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
 
     def execute(self, x, weight):
+        if not jt.flags.use_cuda or not jt.compiler.is_cuda:
+            return nn.conv2d(x, weight, None, self.stride, self.padding, self.dilation, x.shape[1])
         self.save_vars = x, weight
         N,C,H,W = x.shape
         o,i,Kh,Kw = weight.shape
diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index d7c255dc..7c72e014 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -494,7 +494,7 @@ class BCEWithLogitsLoss(Module):
 
 def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
-    if code_softmax.can_softmax_v1(x, dim) and not jt.compiler.has_rocm:
+    if code_softmax.can_softmax_v1(x, dim) and jt.compiler.is_cuda:
         return code_softmax.softmax_v1(x, log)
     if dim is None: dim = ()
     if log:
@@ -873,7 +873,7 @@ class Conv(Module):
         self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
         self.groups = groups
         self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
-        if self.is_depthwise_conv and jt.flags.use_cuda:
+        if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda:
             self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
         assert in_channels % groups == 0, 'in_channels must be divisible by groups'
         assert out_channels % groups == 0, 'out_channels must be divisible by groups'
@@ -2544,7 +2544,7 @@ class RNNBase(Module):
                 copy_to('weight' + param_name, offset_idx + idx, idx)
             return num_gates
 
-        if jt.flags.use_cuda and jt.cudnn and not jt.compiler.has_rocm:
+        if jt.flags.use_cuda and jt.cudnn and jt.compiler.is_cuda:
             if getattr(self, '_cudnn_weight_size', None) is None:
                 offset_array = jt.cudnn.cudnn_rnn_weight_offset(
                     cudnn_mode,
@@ -2630,7 +2630,7 @@ class RNNBase(Module):
                 hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype),
                       jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype))
 
-        if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0 and not jt.compiler.has_rocm:
+        if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0 and jt.compiler.is_cuda:
             return self._execute_cudnn_rnn(input, hx)
         else:
             hidden_n = []
diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h
index edfc0d35..8222e0c3 100644
--- a/python/jittor/src/type/fp16_compute.h
+++ b/python/jittor/src/type/fp16_compute.h
@@ -83,12 +83,12 @@ vfill(T* __restrict__ a) {
     if (nbyte<=0) return;
     if (nbyte>=16) {
         auto* __restrict__ aa = (int4* __restrict__)a;
-        aa[0] = {0};
+        aa[0].x = aa[0].y = aa[0].z = aa[0].w = 0;
        return vfill(aa+1);
     }
     if (nbyte>=8) {
         auto* __restrict__ aa = (int2* __restrict__)a;
-        aa[0] = {0};
+        aa[0].x = aa[0].y = 0;
         return vfill(aa+1);
     }
     if (nbyte>=4) {
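
Illustrative note, not part of the patch: the changes above converge on one dispatch rule. CUDA-only fast paths (code_softmax, DepthwiseConv, the cuDNN RNN path) now run only when jt.compiler.is_cuda is true, that is, when the selected backend compiler binary is literally named nvcc, so ROCm/corex builds take the generic fallback instead of a CUDA-specific kernel. Below is a minimal sketch of that pattern under stated assumptions: the shapes are hypothetical, and it assumes a standard jittor install where jt.has_cuda and jt.flags.use_cuda are available.

    import jittor as jt
    from jittor import nn
    from jittor.depthwise_conv import DepthwiseConv

    jt.flags.use_cuda = jt.has_cuda          # use the GPU only when one is present

    # Hypothetical depthwise setup: groups == in_channels == out_channels == 8.
    x = jt.random((2, 8, 16, 16))            # N, C, H, W
    w = jt.random((8, 1, 3, 3))              # one 3x3 filter per channel

    if jt.flags.use_cuda and jt.compiler.is_cuda:
        # Genuine nvcc backend: the specialized CUDA kernel applies.
        y = DepthwiseConv(stride=1, padding=1, dilation=1)(x, w)
    else:
        # ROCm/corex backend or CPU: generic grouped convolution, mirroring
        # the new fallback inside DepthwiseConv.execute.
        y = nn.conv2d(x, w, None, stride=1, padding=1, dilation=1, groups=8)

With the patch applied the explicit branch is not strictly required, since DepthwiseConv.execute itself falls back to nn.conv2d on non-nvcc backends; the sketch only spells out the check that Conv, softmax, and RNNBase now perform internally.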