1. Fix a compilation failure of fp16_compute.h under the ROCm platform

2. Add jt.compiler.is_cuda to identify whether the current device is CUDA
lzhengning, 2022-12-17 20:50:09 +08:00 (committed by Zheng-Ning Liu)
parent 2e74df517b
commit d50fe5e754
4 changed files with 9 additions and 6 deletions
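
As context for the second change, a minimal sketch of the gating semantics (only flags that appear in the diffs below are used): the old `not jt.compiler.has_rocm` check passes on any non-ROCm backend, while the new `jt.compiler.is_cuda` is True only when the selected device compiler is genuine nvcc, so CUDA-only fast paths stay disabled under ROCm as well as on other non-CUDA backends.

import jittor as jt

# old gate: passes on any non-ROCm backend, including non-CUDA ones
old_gate = jt.flags.use_cuda and not jt.compiler.has_rocm
# new gate: passes only when the selected compiler is genuine nvcc
new_gate = jt.flags.use_cuda and jt.compiler.is_cuda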

python/jittor/compiler.py

@@ -1246,6 +1246,7 @@ jit_utils.add_backend(corex_compiler)
 for mod in jit_utils.backends:
     if mod.check():
         break
+is_cuda = os.path.basename(nvcc_path) == "nvcc"
 # build core
 gen_jit_flags()
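
The detection itself is deliberately simple: once the backend loop has selected a compiler, the flag is derived from the executable's basename. An illustration of the idea (the ROCm path and binary name are only examples; the registered wrapper may be named differently, but in any case its basename is not "nvcc"):

import os

# a CUDA toolchain: the compiler binary is literally named "nvcc"
print(os.path.basename("/usr/local/cuda/bin/nvcc") == "nvcc")  # True
# a ROCm toolchain: a different basename, so is_cuda is False
print(os.path.basename("/opt/rocm/bin/hipcc") == "nvcc")       # False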

python/jittor/depthwise_conv.py

@@ -20,6 +20,8 @@ class DepthwiseConv(Function):
         self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
     def execute(self, x, weight):
+        if not jt.flags.use_cuda or not jt.compiler.is_cuda:
+            return nn.conv2d(x, weight, None, self.stride, self.padding, self.dilation, x.shape[1])
         self.save_vars = x, weight
         N,C,H,W = x.shape
         o,i,Kh,Kw = weight.shape
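
The new fallback works because a depthwise convolution is an ordinary grouped convolution with `groups` equal to the channel count, which is exactly what the added `nn.conv2d(..., x.shape[1])` call expresses. A quick sketch (the shapes are illustrative):

import jittor as jt
from jittor import nn

x = jt.random((2, 8, 16, 16))     # N, C, H, W
weight = jt.random((8, 1, 3, 3))  # one 3x3 filter per input channel

# depthwise conv == grouped conv2d with groups == number of channels
y = nn.conv2d(x, weight, None, 1, 1, 1, 8)  # stride, padding, dilation, groups
print(y.shape)  # [2, 8, 16, 16]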

python/jittor/nn.py

@@ -494,7 +494,7 @@ class BCEWithLogitsLoss(Module):
 def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
-    if code_softmax.can_softmax_v1(x, dim) and not jt.compiler.has_rocm:
+    if code_softmax.can_softmax_v1(x, dim) and jt.compiler.is_cuda:
         return code_softmax.softmax_v1(x, log)
     if dim is None: dim = ()
     if log:
@@ -873,7 +873,7 @@ class Conv(Module):
         self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
         self.groups = groups
         self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
-        if self.is_depthwise_conv and jt.flags.use_cuda:
+        if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda:
             self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
         assert in_channels % groups == 0, 'in_channels must be divisible by groups'
         assert out_channels % groups == 0, 'out_channels must be divisible by groups'
@@ -2544,7 +2544,7 @@ class RNNBase(Module):
                 copy_to('weight' + param_name, offset_idx + idx, idx)
             return num_gates
-        if jt.flags.use_cuda and jt.cudnn and not jt.compiler.has_rocm:
+        if jt.flags.use_cuda and jt.cudnn and jt.compiler.is_cuda:
             if getattr(self, '_cudnn_weight_size', None) is None:
                 offset_array = jt.cudnn.cudnn_rnn_weight_offset(
                     cudnn_mode,
@@ -2630,7 +2630,7 @@ class RNNBase(Module):
             hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype),
                   jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype))
-        if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0 and not jt.compiler.has_rocm:
+        if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0 and jt.compiler.is_cuda:
             return self._execute_cudnn_rnn(input, hx)
         else:
             hidden_n = []

fp16_compute.h

@@ -83,12 +83,12 @@ vfill(T* __restrict__ a) {
     if (nbyte<=0) return;
     if (nbyte>=16) {
         auto* __restrict__ aa = (int4* __restrict__)a;
-        aa[0] = {0};
+        aa[0].x = aa[0].y = aa[0].z = aa[0].w = 0;
         return vfill<nbyte-16>(aa+1);
     }
     if (nbyte>=8) {
         auto* __restrict__ aa = (int2* __restrict__)a;
-        aa[0] = {0};
+        aa[0].x = aa[0].y = 0;
         return vfill<nbyte-8>(aa+1);
     }
     if (nbyte>=4) {
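
For reference, a minimal compilable sketch of the pattern being fixed, in C++ to match this hunk. It is an assumption-laden simplification of vfill: plain structs stand in for the built-in int4/int2 vector types, and C++17 `if constexpr` terminates the template recursion (the real header may differ). The braced assignment `aa[0] = {0};` was rejected by the ROCm toolchain, presumably because HIP's vector types do not accept assignment from a braced list, while component-wise stores compile under both nvcc and hipcc:

// simplified stand-ins for the built-in int2/int4 vector types
struct vec2 { int x, y; };
struct vec4 { int x, y, z, w; };

// zero nbyte bytes at a, using the widest stores available
template <int nbyte, class T>
void vfill(T* __restrict__ a) {
    if constexpr (nbyte >= 16) {
        auto* __restrict__ aa = (vec4* __restrict__)a;
        // component-wise stores, unlike the braced `aa[0] = {0};`
        aa[0].x = aa[0].y = aa[0].z = aa[0].w = 0;
        vfill<nbyte-16>(aa+1);
    } else if constexpr (nbyte >= 8) {
        auto* __restrict__ aa = (vec2* __restrict__)a;
        aa[0].x = aa[0].y = 0;
        vfill<nbyte-8>(aa+1);
    } else if constexpr (nbyte > 0) {
        auto* __restrict__ aa = (char* __restrict__)a;
        for (int i = 0; i < nbyte; i++) aa[i] = 0;  // byte-sized tail
    }
}

int main() {
    float buf[8];
    vfill<sizeof(buf)>(buf);  // clears all 32 bytes
}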