diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index dd0267df..c8dc1d0e 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.1.38' +__version__ = '1.3.1.38.1' from jittor_utils import lock with lock.lock_scope(): ori_int = int @@ -304,15 +304,52 @@ Var.cast = Var.cast def array(data, dtype=None): if isinstance(data, core.Var): if dtype is None: - return data.clone() - return cast(data, dtype) - if dtype is not None: + ret = data.clone() + else: + ret = cast(data, dtype) + elif dtype is not None: if isinstance(dtype, NanoString): dtype = str(dtype) elif callable(dtype): dtype = dtype.__name__ - return ops.array(np.array(data, dtype)) - return ops.array(data) + ret = ops.array(np.array(data, dtype)) + else: + ret = ops.array(data) + # TODO: move those code to core + amp_reg = jt.flags.amp_reg + if amp_reg and ret.numel() != 1 and ret.dtype.is_float(): + if amp_reg & 16: + if amp_reg & 1: + if ret.dtype != "float32": + return ret.float32() + elif amp_reg & 2: + if ret.dtype != "float16": + return ret.float16() + return ret + +def random(shape, dtype="float32", type="uniform"): + # TODO: move those code to core + if dtype == "float16": + # TODO: make curand support fp16 + ret = ops.random(shape, "float32", type).float16() + else: + ret = ops.random(shape, dtype, type) + amp_reg = jt.flags.amp_reg + if amp_reg: + if amp_reg & 16: + if amp_reg & 1: + if ret.dtype != "float32": + return ret.float32() + elif amp_reg & 2: + if ret.dtype != "float16": + return ret.float16() + return ret + +def float_auto(x): + if jt.flags.amp_reg & 2: + return x.float16() + return x.float32() +Var.float_auto = float_auto def array64(data, dtype=None): with jt.flag_scope(auto_convert_64_to_32=0): @@ -1419,18 +1456,15 @@ Var.size = size def to_int(v): - dtype = str(v.dtype) - assert dtype.startswith("int") + assert v.dtype.is_int() return v.item() def to_float(v): - dtype = str(v.dtype) - assert dtype.startswith("float") + assert v.dtype.is_float() return v.item() def to_bool(v): - dtype = str(v.dtype) - assert dtype.startswith("int") or dtype=="bool" + assert v.dtype.is_int() or v.dtype.is_bool() return ori_bool(v.item()) Var.__int__ = to_int diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 7fbf6c61..893a5ca0 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -210,6 +210,12 @@ def setup_cuda_extern(): LOG.w(f"CUDA found but cub is not loaded:\n{line}") libs = ["cublas", "cudnn", "curand"] + # in cuda 11.4, module memory comsumptions: + # default context: 259 MB + # cublas: 340 MB + # cudnn: 340 MB + if int(os.environ.get("conv_opt", "0")): + libs = ["cublas", "curand"] for lib_name in libs: try: setup_cuda_lib(lib_name, extra_flags=link_cuda_extern) @@ -309,22 +315,27 @@ def install_cutt(root_folder): if md5 != true_md5: os.remove(fullname) shutil.rmtree(dirname) - if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)): - LOG.i("Downloading cutt...") - download_url_to_local(url, filename, root_folder, true_md5) + CUTT_PATH = os.environ.get("CUTT_PATH", "") + if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)) or CUTT_PATH: + if CUTT_PATH: + dirname = CUTT_PATH + else: + LOG.i("Downloading cutt...") + download_url_to_local(url, filename, root_folder, true_md5) - import zipfile + import zipfile - zf = zipfile.ZipFile(fullname) - 
try: - zf.extractall(path=root_folder) - except RuntimeError as e: - print(e) - raise - zf.close() + zf = zipfile.ZipFile(fullname) + try: + zf.extractall(path=root_folder) + except RuntimeError as e: + print(e) + raise + zf.close() LOG.i("installing cutt...") - arch_flag = "" + # -Xptxas -dlcm=ca actually not work + arch_flag = " -Xptxas -dlcm=ca " if len(flags.cuda_archs): arch_flag = f" -arch=compute_{min(flags.cuda_archs)} " arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs)) diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h index 67d46a69..5ff0de2c 100644 --- a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h +++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h @@ -23,8 +23,8 @@ EXTERN_LIB cublasHandle_t cublas_handle; static inline cudaDataType get_dtype(NanoString dtype) { if (dtype == ns_float32) return CUDA_R_32F; - // if (dtype == ns_float64) return CUDA_R_64F; - // if (dtype == ns_float16) return CUDA_R_16F; + if (dtype == ns_float64) return CUDA_R_64F; + if (dtype == ns_float16) return CUDA_R_16F; LOGf << "not support type" << dtype; return CUDA_R_32F; } diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc index 6ab519a0..47874ef3 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc @@ -124,6 +124,10 @@ void CublasBatchedMatmulOp::jit_run() { if (use_tensorcore) { computeType = CUBLAS_COMPUTE_32F_FAST_16F; } + if (a->dtype() == ns_float16 + || b->dtype() == ns_float16 || c->dtype() == ns_float16) { + computeType = CUBLAS_COMPUTE_16F; + } checkCudaErrors(cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, k, n, m, &alpha, diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc index 0ed46bc4..95de20f7 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc @@ -81,6 +81,10 @@ void CublasMatmulOp::jit_run() { if (use_tensorcore) { computeType = CUBLAS_COMPUTE_32F_FAST_16F; } + if (a->dtype() == ns_float16 + || b->dtype() == ns_float16 || c->dtype() == ns_float16) { + computeType = CUBLAS_COMPUTE_16F; + } checkCudaErrors(cublasGemmEx(handle_, CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, k, n, m, &alpha, diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc index acb9bd52..c495fdb6 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc @@ -174,6 +174,11 @@ void CudnnConvOp::jit_run() { if(use_tensorcore){ checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); } + + if (x->dtype() == ns_float16 + || y->dtype() == ns_float16 || w->dtype() == ns_float16) { + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } int dimY[] = { (int)y->shape[findc("@YFORMAT", 'a')], // n diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 15d68915..7b2c64af 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -488,18 +488,11 @@ def arctan2(y,x): angle = jt.zeros(x.shape,dtype=x.dtype) x = (x!=0.0).ternary(x, x+1e-30) angle = (y/x).arctan() - - mask = (y<0) & (x<0) - if angle[mask].numel()>0: - 
angle[mask] -= np.pi - - mask = (y>=0) &(x<0) - if angle[mask].numel()>0: - angle[mask] +=np.pi + mask = y<0 | ((y==0) & (x<0)) + angle = angle + mask*np.pi return angle - def nonzero(x): r''' Return the index of the elements of input tensor which are not equal to zero. diff --git a/python/jittor/models/resnet.py b/python/jittor/models/resnet.py index 1ef54d65..3aa60928 100644 --- a/python/jittor/models/resnet.py +++ b/python/jittor/models/resnet.py @@ -143,7 +143,7 @@ class ResNet(nn.Module): x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) - x = self.avgpool(x) + x = self.avgpool(x).float_auto() x = jt.reshape(x, (x.shape[0], -1)) x = self.fc(x) return x diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 44ed27dc..a0079948 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -37,9 +37,10 @@ def matmul_transpose(a, b): assert len(a.shape) == 2 and len(b.shape) == 2 shape = list(a.shape)[:-1] + list(b.shape) - a = a.broadcast(shape, [len(shape)-2]) - b = b.broadcast(shape) - return (a*b).sum(len(shape)-1) + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4): + a = a.broadcast(shape, [len(shape)-2]) + b = b.broadcast(shape) + return (a*b).sum(len(shape)-1) def bmm_transpose(a, b): @@ -108,47 +109,48 @@ Example:: c = jt.matmul(a, b) assert c.shape == [8, 10, 3, 5] ''' - len_a = len(a.shape) - len_b = len(b.shape) - if len_b == 1: - # a: [n, m], b:[m], c:[n] - return (a*b).sum(-1) - if len_a == 1: - # a: [n], b:[n,k], c:[k] - return (a.broadcast(b, [-1]) * b).sum(0) - if len_a>=3 and len_a==len_b: - # bmm - # a: [..., n, m], b: [..., m, k], c:[..., n, k] - if jt.flags.use_cuda and jt.compile_extern.cublas_ops: - return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0) - shape = [] - len_c = max(len_a, len_b) - (n, m), (m_, k) = a.shape[-2:], b.shape[-2:] - assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}" - # a: [..., n, m] - # b: [..., m, k] - # cc:[..., n, m, k] - # --> - # 012 - if len_b == 2 and len_a>2: - # TODO:ugly implementation for tuner - aa = a.reshape((-1, m)) - cc = matmul(aa, b) - # print(a.shape, b.shape, cc.shape) - return cc.reshape(a.shape[:-1] + [k]) - for i in range(len_c-2): - ai = len_a-(len_c-i) - bi = len_b-(len_c-i) - an = a.shape[ai] if ai>=0 else 1 - bn = b.shape[bi] if bi>=0 else 1 - if an!=1 and bn!=1: - assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}" - cn = max(an, bn) - shape.append(cn) - shape.extend([n, m, k]) - a = a.broadcast(shape, [-1]) - b = b.broadcast(shape, [-3]) - return (a*b).sum(-2) + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4): + len_a = len(a.shape) + len_b = len(b.shape) + if len_b == 1: + # a: [n, m], b:[m], c:[n] + return (a*b).sum(-1) + if len_a == 1: + # a: [n], b:[n,k], c:[k] + return (a.broadcast(b, [-1]) * b).sum(0) + if len_a>=3 and len_a==len_b: + # bmm + # a: [..., n, m], b: [..., m, k], c:[..., n, k] + if jt.flags.use_cuda and jt.compile_extern.cublas_ops: + return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0) + shape = [] + len_c = max(len_a, len_b) + (n, m), (m_, k) = a.shape[-2:], b.shape[-2:] + assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}" + # a: [..., n, m] + # b: [..., m, k] + # cc:[..., n, m, k] + # --> + # 012 + if len_b == 2 and len_a>2: + # TODO:ugly implementation for tuner + aa = a.reshape((-1, m)) + cc = matmul(aa, b) + # print(a.shape, b.shape, cc.shape) + return cc.reshape(a.shape[:-1] + [k]) + for i in range(len_c-2): + ai = len_a-(len_c-i) + bi = len_b-(len_c-i) + an = 
a.shape[ai] if ai>=0 else 1 + bn = b.shape[bi] if bi>=0 else 1 + if an!=1 and bn!=1: + assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}" + cn = max(an, bn) + shape.append(cn) + shape.extend([n, m, k]) + a = a.broadcast(shape, [-1]) + b = b.broadcast(shape, [-3]) + return (a*b).sum(-2) jt.Var.matmul = jt.Var.__matmul__ = matmul jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b)) @@ -488,22 +490,22 @@ class BCEWithLogitsLoss(Module): def execute(self, output, target): return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average) -def softmax(x, dim = None): +def softmax(x, dim=None, log=False): import jittor.other.code_softmax as code_softmax if code_softmax.can_softmax_v1(x, dim): - return code_softmax.softmax_v1(x) + return code_softmax.softmax_v1(x, log) if dim is None: x = (x - x.max()).exp() ret = x / x.sum() else: x = (x-x.max(dim, keepdims=True)).exp() ret = x / x.sum(dim, keepdims=True) + if log: return ret.log() return ret jt.Var.softmax = softmax def log_softmax(x,dim=None): - x = softmax(x,dim=dim) - return jt.log(x) + return softmax(x,dim=dim, log=True) jt.Var.log_softmax = log_softmax def log_sigmoid(x): @@ -832,15 +834,16 @@ class Conv(Module): oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 assert oh>0 and ow>0 - xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [ - 'i0', # Nid - 'i2', # Cid - f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid - f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid - ]) - ww = self.weight.broadcast(xx.shape, [0,3,4]) - yy = xx*ww - y = yy.sum([2,5,6]) # Kc, Kh, Kw + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4): + xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid + f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid + ]) + ww = self.weight.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw if self.bias is not None: b = self.bias.broadcast(y.shape, [0,2,3]) y = y + b @@ -1008,6 +1011,18 @@ class Conv3d(Module): def execute(self, x): return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class Conv1d_sp(Linear): + def __init__(self, inchannels, outchannels, kernel_size=1, bias=True): + super().__init__(inchannels, outchannels, bias=bias) + assert kernel_size == 1 + + def execute(self, x): + x = x.transpose(0, 2, 1) + x = super().execute(x) + x = x.transpose(0, 2, 1) + return x + def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): ''' Applies a 2D convolution over an input signal composed of several input planes. 
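Note: matmul above and the Conv/conv2d forward paths in this file wrap their broadcast-multiply-sum in jt.flag_scope(amp_reg = jt.flags.amp_reg | 4); per the flag description added in nano_string.cc, bit 2 is "keep reduce type", so the internal sum keeps the operand dtype instead of being widened while mixed precision is active. A minimal sketch of the same pattern in user code, assuming a build with the float16 support from this patch; the printed dtype reflects that reading of the flag bits rather than a documented guarantee:

import jittor as jt

jt.flags.auto_mixed_precision_level = 4   # prefer float16 via the new setter

def weighted_sum(x, w):
    # mirror matmul/conv2d: OR in bit 2 ("keep reduce type") so the sum below
    # stays in x's dtype instead of being widened to float32 by AMP
    with jt.flag_scope(amp_reg=jt.flags.amp_reg | 4):
        return (x * w).sum(-1)

x = jt.random((8, 16)).float16()
w = jt.random((8, 16)).float16()
print(weighted_sum(x, w).dtype)   # expected: float16 inside the scope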
@@ -1048,15 +1063,16 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): Kh, Kw = weight.shape[-2:] oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 - xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [ - 'i0', # Nid - 'i2', # Cid - f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid - f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid - ]) - ww = weight.broadcast(xx.shape, [0,3,4]) - yy = xx*ww - y = yy.sum([2,5,6]) # Kc, Kh, Kw + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4): + xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid + f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid + ]) + ww = weight.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw if bias is not None: b = bias.broadcast(y.shape, [0,2,3]) y = y + b diff --git a/python/jittor/other/code_softmax.py b/python/jittor/other/code_softmax.py index f167e497..8534f0cb 100644 --- a/python/jittor/other/code_softmax.py +++ b/python/jittor/other/code_softmax.py @@ -10,32 +10,48 @@ def can_softmax_v1(a, dim): return False return True -def softmax_v1(a): +def softmax_v1(a, log=False): assert can_softmax_v1(a, -1) length = a.shape[-1] # tnum = 1024 tnum = 500 if length % 500 == 0 else 512 + tnum = 125 if length % 125 == 0 else 128 + # tnum = 125 + # tnum = 1000 if length % 1000 == 0 else 1024 # tnum = 250 per_thread = (length-1) // tnum + 1 + ILP = 1 + for ilp in [8,4,2]: + if length % tnum == 0 and per_thread % ilp == 0: + ILP = ilp + per_thread //= ILP + break for_loop = f""" #pragma unroll for (int i=0; i<{per_thread}; i++) """ - if length % tnum == 0: - for_loop += f"if (i*{tnum}+threadIdx.x < len)\n" + if length % tnum != 0: + for_loop += f"if ((i*{tnum}+threadIdx.x)*{ILP} < len)\n" return jt.code(a.shape, a.dtype, [a], cuda_header=f''' #include <{jt.compile_extern.cub_home}cub/cub.cuh> +#include ''', cuda_src=f''' __global__ void kernel(in0_type* x, out0_type* y, int len) {{ typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; int id = blockIdx.x * len; - in0_type v[{per_thread}]; - {for_loop} v[i] = x[id+i*{tnum}+threadIdx.x]; - float v1 = v[0]; - {for_loop} v1 = max(v1, v[i]); + in0_type v[{per_thread}][{ILP}]; + {for_loop} + vload(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]); + // v[i] = x[id+i*{tnum}+threadIdx.x]; + float v1 = -1e30; + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + v1 = max(v1, float(v[i][j])); + }} __shared__ float vmax; auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max()); if (threadIdx.x == 0) @@ -43,10 +59,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{ __syncthreads(); v1 = 0; - {for_loop} {{ - v[i] = expf(v[i] - vmax); - v1 += v[i]; - }} + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + v[i][j] = expf(float(v[i][j]) - vmax); + v1 += float(v[i][j]); + }} tmp = BlockReduce(temp_storage).Sum(v1); __shared__ float vsum; @@ -54,7 +72,15 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{ vsum = tmp; __syncthreads(); - {for_loop} y[id+i*{tnum}+threadIdx.x] = v[i] / vsum; + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) + v[i][j] = { + "@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log + else "float(v[i][j])/vsum" + }; + {for_loop} + vload(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]); }} int len = in0->shape[in0->shape.size()-1]; int bnum = 
in0->numel() / len; @@ -64,15 +90,17 @@ CHECK(0 == cudaGetLastError()); ''', cuda_grad_src=[f""" __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{ int id = blockIdx.x * len; - in0_type vx[{per_thread}]; - in0_type vy[{per_thread}]; + in0_type vx[{per_thread}][{ILP}]; + in0_type vy[{per_thread}][{ILP}]; {for_loop} {{ - vx[i] = x[id+i*{tnum}+threadIdx.x]; - vy[i] = y[id+i*{tnum}+threadIdx.x]; + vload(vx[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]); + vload(vy[i], &y[id+(i*{tnum}+threadIdx.x)*{ILP}]); }} float v1 = 0; - {for_loop} v1 += vx[i]*vy[i]; - + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) + v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"} typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -83,7 +111,16 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{ __syncthreads(); {for_loop} - z[id+i*{tnum}+threadIdx.x] = vx[i] * (vy[i] - reduce_var); + #pragma unroll + for (int j=0; j<{ILP}; j++) + vx[i][j] = { + "vy[i][j] - expf(vx[i][j]) * reduce_var;" if log + else "vx[i][j] * (vy[i][j] - reduce_var);" + } + + {for_loop} + vload(&z[id+(i*{tnum}+threadIdx.x)*{ILP}], + vx[i]); }} int len = in0->shape[in0->shape.size()-1]; int bnum = in0->numel() / len; diff --git a/python/jittor/pool.py b/python/jittor/pool.py index 0f1fe12e..fd3c681c 100644 --- a/python/jittor/pool.py +++ b/python/jittor/pool.py @@ -120,8 +120,8 @@ class Pool(Module): for (int i2 = p2; i2 < out_shape2; i2 += s2) {{ {forward_body} }} }} - int tx = min(1024, out_shape3); - int ty = min(1024 / tx, out_shape2); + int tx = std::min(1024, out_shape3); + int ty = std::min(1024 / tx, out_shape2); int bx = (out_shape2 - 1) / ty + 1; int by = out_shape1; int bz = out_shape0; @@ -143,8 +143,8 @@ class Pool(Module): {{ {backward_body} }} }} cudaMemsetAsync(out_p, 0, out->size); - int tx = min(1024, pout_shape3); - int ty = min(1024 / tx, pout_shape2); + int tx = std::min(1024, pout_shape3); + int ty = std::min(1024 / tx, pout_shape2); int bx = (pout_shape2 - 1) / ty + 1; int by = pout_shape1; int bz = pout_shape0; @@ -310,9 +310,9 @@ class Pool3d(Module): for (int i2 = p2; i2 < out_shape2; i2 += s2) {{ {forward_body} }} }} - int tx = min(1024, out_shape4); - int ty = min(1024 / tx, out_shape3); - int tz = min(1024 / tx / ty, out_shape2); + int tx = std::min(1024, out_shape4); + int ty = std::min(1024 / tx, out_shape3); + int tz = std::min(1024 / tx / ty, out_shape2); int bx = (out_shape2 - 1) / tz + 1; int by = out_shape1; int bz = out_shape0; @@ -337,9 +337,9 @@ class Pool3d(Module): {{ {backward_body} }} }} cudaMemsetAsync(out_p, 0, out->size); - int tx = min(1024, pout_shape4); - int ty = min(1024 / tx, pout_shape3); - int tz = min(1024 / tx / ty, pout_shape2); + int tx = std::min(1024, pout_shape4); + int ty = std::min(1024 / tx, pout_shape3); + int tz = std::min(1024 / tx / ty, pout_shape2); int bx = (pout_shape2 - 1) / tz + 1; int by = pout_shape1; int bz = pout_shape0; diff --git a/python/jittor/src/grad.cc b/python/jittor/src/grad.cc index 34879757..5e891714 100644 --- a/python/jittor/src/grad.cc +++ b/python/jittor/src/grad.cc @@ -39,11 +39,24 @@ template struct StackIniter { #define STACK_ALLOC2(T, a, n) T a[n] #endif +struct AmpGradGuard { + int amp_reg_bk; + AmpGradGuard(Op* op) { + amp_reg_bk = amp_reg; + amp_reg |= (op->flags.flags >> NodeFlags::_prefer_32); + } + + ~AmpGradGuard() { + amp_reg = amp_reg_bk; + } +}; + VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) { if (dout == nullptr) 
return nullptr; if (x_index<0) return nullptr; LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs() << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index; + AmpGradGuard agg(op); auto dx = op->grad(out, dout, x, x_index); if (x->loop_options) dx->loop_options = x->loop_options; @@ -182,7 +195,10 @@ vector grad(Var* loss, vector targets) { douts[i] = nullptr; } trace_grad_op = op; - op->grads(douts, dins); + { + AmpGradGuard agg(op); + op->grads(douts, dins); + } // dump "for (Var* in : op->inputs())" for (int i=0; i del; inline Deleter(std::function&& func) : del(move(func)) {} - inline ~Deleter() { del(); } + inline Deleter() {} + inline ~Deleter() { if (del) del(); } }; } // jittor diff --git a/python/jittor/src/misc/nano_string.cc b/python/jittor/src/misc/nano_string.cc index b85bf861..e97523f5 100644 --- a/python/jittor/src/misc/nano_string.cc +++ b/python/jittor/src/misc/nano_string.cc @@ -9,6 +9,17 @@ namespace jittor { +DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too"); + +DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16"); + +void setter_auto_mixed_precision_level(int value) { + if (value <= 3) amp_reg = 0; else + if (value == 4) amp_reg = amp_prefer16; else + if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else + if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white; +} + #define FOR_ALL_TYPES(m) \ m(bool) \ m(int8) \ @@ -89,15 +100,18 @@ static unordered_set unary_ops = { "erfinv" }; -static unordered_set unary_float_ops = { +static unordered_set float_ops = { "log", "exp", "sqrt", + "mean", + "divide", }; -static unordered_set unary_int_ops = { +static unordered_set int_ops = { "round_int", "floor_int", "ceil_int", + "floor_divide", }; static unordered_set binary_ops = { @@ -127,6 +141,13 @@ static unordered_set binary_ops = { "mean", }; + +static unordered_set white_ops = { + // "log", + "exp", + "pow", +}; + #define DEFINE_NS(T) NanoString ns_##T; FOR_ALL_NS(DEFINE_NS); @@ -135,6 +156,9 @@ char __ns_to_string[ns_max_size*ns_max_len]; int __ns_len[ns_max_size]; static void init_ns() { + dsize_map["float16"] = 1; + is_float_map["float16"] = 1; + is_unsigned["float16"] = 0; NanoString::ns_t i=0; auto func = [&](const char* name, NanoString& ns) { ns.set(NanoString::_index, i++, NanoString::_index_nbits); @@ -149,13 +173,16 @@ static void init_ns() { if (unary_ops.count(name)) { ns.set(NanoString::_type, NanoString::_unary, NanoString::_type_nbits); ns.set(NanoString::_bool, is_bool.count(name)); - ns.set(NanoString::_int, unary_int_ops.count(name)); - ns.set(NanoString::_float, unary_float_ops.count(name)); + ns.set(NanoString::_int, int_ops.count(name)); + ns.set(NanoString::_float, float_ops.count(name)); } else if (binary_ops.count(name)) { ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits); ns.set(NanoString::_bool, is_bool.count(name)); + ns.set(NanoString::_int, int_ops.count(name)); + ns.set(NanoString::_float, float_ops.count(name)); } + ns.set(NanoString::_white_list, white_ops.count(name)); __string_to_ns[name] = ns; auto name2 = ns.to_cstring(); int len=0; diff --git 
a/python/jittor/src/misc/nano_string.h b/python/jittor/src/misc/nano_string.h index 12feb331..02440985 100644 --- a/python/jittor/src/misc/nano_string.h +++ b/python/jittor/src/misc/nano_string.h @@ -24,6 +24,7 @@ constexpr int ns_max_len = 16; m(uint16) \ m(uint32) \ m(uint64) \ + m(float16) \ m(float32) \ m(float64) \ \ @@ -100,7 +101,7 @@ struct NanoString { typedef uint16 ns_t; enum Flags { // bit0~7: index - _index=0, _index_nbits=8, + _index=0, _index_nbits=7, _n=_index_nbits, // bit0-1: type @@ -116,6 +117,8 @@ struct NanoString { _float=_n+5, // bit6-7: dsize(1,2,4,8 byte) _dsize=_n+6, _dsize_nbits=2, + // bit8: white list + _white_list=_n+8, }; ns_t data=0; @@ -130,11 +133,16 @@ struct NanoString { inline ns_t index() const { return get(_index, _index_nbits); } inline int len() const { return __ns_len[index()]; } inline ns_t type() const { return get(_type, _type_nbits); } - inline ns_t is_bool() const { return get(_bool); } - inline ns_t is_int() const { return get(_int); } - inline ns_t is_unsigned() const { return get(_unsigned); } - inline ns_t is_float() const { return get(_float); } + // @pyjt(is_bool) + inline bool is_bool() const { return get(_bool); } + // @pyjt(is_int) + inline bool is_int() const { return get(_int); } + inline bool is_unsigned() const { return get(_unsigned); } + // @pyjt(is_float) + inline bool is_float() const { return get(_float); } + inline ns_t is_white() const { return get(_white_list); } inline ns_t dsize() const { return 1< jit_ops; string_view_map jit_key_mapper; -int64_t Op::number_of_lived_ops = 0; +int64 Op::number_of_lived_ops = 0; Op::Op() { flags.set(NodeFlags::_var, 0); flags.set(NodeFlags::_cpu, 1); + flags.flags |= ((amp_reg & 7) << NodeFlags::_prefer_32); number_of_lived_ops++; if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this); } diff --git a/python/jittor/src/op.h b/python/jittor/src/op.h index 25d752e5..957e435a 100644 --- a/python/jittor/src/op.h +++ b/python/jittor/src/op.h @@ -15,7 +15,7 @@ namespace jittor { enum OpType {other=0, element=1, broadcast=2, reduce=3}; struct Op : Node { vector outputs_holder; - static int64_t number_of_lived_ops; + static int64 number_of_lived_ops; inline Caster inputs() { CHECK_EXIST; return &_inputs; } inline Caster outputs() { CHECK_EXIST; return &_outputs; } diff --git a/python/jittor/src/op_compiler.cc b/python/jittor/src/op_compiler.cc index fe32d043..1b4a7e96 100644 --- a/python/jittor/src/op_compiler.cc +++ b/python/jittor/src/op_compiler.cc @@ -112,7 +112,7 @@ int OpCompiler::total_member_count() { return member_count; } -int64_t OpCompiler::eval(const string& expr, const unordered_map& vars) { +int64 OpCompiler::eval(const string& expr, const unordered_map& vars) { if (expr.find("@") != string::npos) { string new_expr; for (size_t i=0; i binary_ops = { "bitwise_xor", }; -NanoString binary_dtype_infer(NanoString op, Var* x, Var* y) { - if (op == ns_mean) return dtype_infer(x->ns, y->ns, 2); // force float - int force_type=0; - if (op == ns_divide) force_type=2; // force float - if (op == ns_floor_divide) force_type=1; // force int - return op.is_bool() ? 
ns_bool : dtype_infer(x->ns, y->ns, force_type, op); -} - BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::element); ns = op; ASSERT(ns.is_binary()); - z = create_output(nullptr, binary_dtype_infer(op, x, y)); + z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns)); } VarPtr dirty_clone_broadcast(Var* v) { diff --git a/python/jittor/src/ops/op_register.h b/python/jittor/src/ops/op_register.h index 8f830b4f..bfb56295 100644 --- a/python/jittor/src/ops/op_register.h +++ b/python/jittor/src/ops/op_register.h @@ -32,9 +32,11 @@ void op_registe(const OpInfo& op_info); bool has_op(const string& name); OpInfo get_op_info(const string& name); +struct OpCompiler; struct OpByType { unordered_set types; virtual string expand_op(const vector& args) = 0; + virtual void post_pass(OpCompiler*) = 0; }; extern vector op_types; diff --git a/python/jittor/src/ops/reduce_op.cc b/python/jittor/src/ops/reduce_op.cc index 4c0b3056..d5f1935c 100644 --- a/python/jittor/src/ops/reduce_op.cc +++ b/python/jittor/src/ops/reduce_op.cc @@ -271,7 +271,7 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims) if (x->dtype() == ns_bool) y = create_output(nullptr, ns_int32); else - y = create_output(nullptr, binary_dtype_infer(ns, x, x)); + y = create_output(nullptr, reduce_dtype_infer(ns, x->ns)); } ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask) @@ -283,7 +283,7 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask) ASSERT(ns.is_binary()); reduce_mask = dims_mask; this->keepdims_mask = keepdims_mask; - y = create_output(nullptr, binary_dtype_infer(ns, x, x)); + y = create_output(nullptr, reduce_dtype_infer(ns, x->ns)); } ReduceOp::ReduceOp(Var* x, NanoString op, int dim, bool keepdims) @@ -359,8 +359,8 @@ void ReduceOp::jit_run() { @for(i, DIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};) index_t xstride@{DIM-1} = 1; @for(i, DIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) - Ty count = Ty(x->num) / Ty(y->num); - Ty rcount = Ty(y->num) / Ty(x->num); + Ty count = x->num*1.0 / y->num; + Ty rcount = y->num*1.0 / x->num; @for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) { auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d)); yp[yid] = @expand_op(init_@OP, @Ty); diff --git a/python/jittor/src/ops/reindex_op.cc b/python/jittor/src/ops/reindex_op.cc index 03110d7d..be38b97d 100644 --- a/python/jittor/src/ops/reindex_op.cc +++ b/python/jittor/src/ops/reindex_op.cc @@ -132,7 +132,7 @@ void ReindexOp::jit_run() { @for(d, 0, XDIM, index_t xid@d = @expand_macro(INDEX@d);) auto xid = @for(d, 0, XDIM, + xid@d * xstride@d); bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d))); - yp[yid] = check_overflow ? (@OVERFLOW) : xp[xid]; + yp[yid] = check_overflow ? 
Tx(@OVERFLOW) : xp[xid]; } } #endif // JIT diff --git a/python/jittor/src/ops/transpose_op.cc b/python/jittor/src/ops/transpose_op.cc index 85888b95..94d9576e 100644 --- a/python/jittor/src/ops/transpose_op.cc +++ b/python/jittor/src/ops/transpose_op.cc @@ -28,6 +28,12 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) { for (int i=0; i<(int)xdim; i++) axes.push_back(xdim-1-i); } + if (axes.size() < xdim || (axes.size() == xdim && axes[xdim-1]==xdim-1)) { + static VarPtr(*fuse_transpose)(Var*, NanoVector) = get_op_info("fuse_transpose").get_constructor(); + auto var = fuse_transpose(x, axes); + forward(var); + return; + } #ifdef HAS_CUDA if (use_cuda) { static VarPtr(*cutt_transpose)(Var*, NanoVector) = nullptr; diff --git a/python/jittor/src/ops/unary_op.cc b/python/jittor/src/ops/unary_op.cc index a7b1fad1..9a95d6df 100644 --- a/python/jittor/src/ops/unary_op.cc +++ b/python/jittor/src/ops/unary_op.cc @@ -32,6 +32,7 @@ static unordered_set unary_ops = { "uint16", "uint32", "uint64", + "float16", "float32", "float64", // please keep float64 the last type @@ -533,22 +534,15 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) { ns = op; ASSERT(ns.is_unary() | ns.is_dtype()); NanoString dtype; + if (ns == x->dtype()) { + forward(x); + return; + } if (ns.is_dtype()) { - if (ns == x->dtype()) { - forward(x); - return; - } dtype = ns; ns = ns_cast; - } else if (ns.is_bool()) - dtype = ns_bool; - else if (ns.is_float()) - dtype = dtype_infer(x->ns, x->ns, 2); - else if (ns.is_int()) - dtype = dtype_infer(x->ns, x->ns, 1); - else { - dtype = x->ns; - } + } else + dtype = unary_dtype_infer(ns, x->ns); y = create_output(nullptr, dtype); } diff --git a/python/jittor/src/opt/tuner/conv_tuner.cc b/python/jittor/src/opt/tuner/conv_tuner.cc index 738746d4..c043c1b5 100644 --- a/python/jittor/src/opt/tuner/conv_tuner.cc +++ b/python/jittor/src/opt/tuner/conv_tuner.cc @@ -25,6 +25,7 @@ namespace jittor { using namespace expr; +extern int use_cuda; struct OpInspector { // binary mask for @@ -229,9 +230,14 @@ void ConvTuner::forwardTune(FusedOp* fop) { if (!(bop->y->input() && bop->x->input() && fop->has(bop->x->input()) && fop->has(bop->y->input()))) continue; if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return; - // only support float32 currently - if (bop->z->dtype() != ns_float32) - continue; + // only support float32,float16 currently + if (use_cuda) { + if (bop->z->dtype() != ns_float32 && bop->z->dtype() != ns_float16) + continue; + } else { + if (bop->z->dtype() != ns_float32) + continue; + } Op* ops[3] = {op, bop->x->input(), bop->y->input()}; int ok = 0; LOGvvvv << "conv like op" << fop << fop->get_jit_key(get_jk()); diff --git a/python/jittor/src/profiler/profiler.cc b/python/jittor/src/profiler/profiler.cc index fc3379b1..12aa8e37 100644 --- a/python/jittor/src/profiler/profiler.cc +++ b/python/jittor/src/profiler/profiler.cc @@ -262,7 +262,7 @@ void Profiler::record_and_run( Deleter _d; if (is_fused) { auto fop = ((FusedOp*)op); - if (fop->context && fop->context->entry) { + if (fop->context && fop->context->vrm.relay_groups.size()) { // relay op loop = rerun; profiler.relay_extra_cost = 0; diff --git a/python/jittor/src/pyjt/numpy.cc b/python/jittor/src/pyjt/numpy.cc index 6546e56c..cda93085 100644 --- a/python/jittor/src/pyjt/numpy.cc +++ b/python/jittor/src/pyjt/numpy.cc @@ -21,7 +21,9 @@ NanoString npy2ns[] = { ns_int64, ns_uint64, ns_float32, ns_float64, ns_float64, ns_void, ns_void, ns_void, - ns_void + ns_void, // 
17 + ns_void, ns_void, ns_void, ns_void, ns_void, // 22 + ns_float16, // 23 }; NPY_TYPES ns2npy[] = { @@ -34,7 +36,7 @@ NPY_TYPES ns2npy[] = { NPY_BYTE, NPY_SHORT, NPY_INT, NPY_LONGLONG, NPY_UBYTE, NPY_USHORT, NPY_UINT, NPY_ULONGLONG, #endif - NPY_FLOAT, NPY_DOUBLE + NPY_HALF, NPY_FLOAT, NPY_DOUBLE }; void** PyArray_API; diff --git a/python/jittor/src/pyjt/numpy.h b/python/jittor/src/pyjt/numpy.h index 1a544edf..6328eb0a 100644 --- a/python/jittor/src/pyjt/numpy.h +++ b/python/jittor/src/pyjt/numpy.h @@ -48,6 +48,8 @@ enum NPY_TYPES { NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE, NPY_OBJECT=17, + NPY_HALF=23, + NPY_END=24, }; EXTERN_LIB NanoString npy2ns[]; @@ -60,11 +62,11 @@ EXTERN_LIB NPY_TYPES ns2npy[]; inline bool is_c_style(PyArray_Proxy* obj) { return obj->flags & 1; } inline NanoString get_type_str(PyArray_Proxy* obj) { NanoString type = ns_void; - if (obj->descr->type_num < NPY_OBJECT) + if (obj->descr->type_num < NPY_END) type = npy2ns[obj->descr->type_num]; CHECK(type != ns_void) << "Numpy type not support, type_num:" << obj->descr->type_num - << "type_char:" << obj->descr->type; + << "type_char:" << obj->descr->type << NPY_END << npy2ns[obj->descr->type_num]; return type; } diff --git a/python/jittor/src/pyjt/py_array_op.cc b/python/jittor/src/pyjt/py_array_op.cc index 683e48bb..79a711a8 100644 --- a/python/jittor/src/pyjt/py_array_op.cc +++ b/python/jittor/src/pyjt/py_array_op.cc @@ -141,7 +141,7 @@ ArrayOp::ArrayOp(PyObject* obj) { } else { // this is non-continue numpy array #if defined(__linux__) || defined(_WIN32) - STACK_ALLOC(int64, dims, args.shape.size()); + STACK_ALLOC(int64_t, dims, args.shape.size()); #elif defined(__APPLE__) long dims[args.shape.size()]; #endif diff --git a/python/jittor/src/pyjt/py_converter.h b/python/jittor/src/pyjt/py_converter.h index 3cafc103..512742a7 100644 --- a/python/jittor/src/pyjt/py_converter.h +++ b/python/jittor/src/pyjt/py_converter.h @@ -274,7 +274,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) { DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) { #if defined(__linux__) || defined(_WIN32) - STACK_ALLOC(int64, dims, a.shape.size()); + STACK_ALLOC(int64_t, dims, a.shape.size()); #elif defined(__APPLE__) long dims[a.shape.size()]; #endif @@ -390,7 +390,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr& holde struct DataView; DEF_IS(DataView, PyObject*) to_py_object(T a) { #if defined(__linux__) || defined(_WIN32) - STACK_ALLOC(int64, dims, a.shape.size()); + STACK_ALLOC(int64_t, dims, a.shape.size()); #elif defined(__APPLE__) long dims[a.shape.size()]; #endif diff --git a/python/jittor/src/pyjt/py_ring_buffer.cc b/python/jittor/src/pyjt/py_ring_buffer.cc index 3f46f4f8..3347553c 100644 --- a/python/jittor/src/pyjt/py_ring_buffer.cc +++ b/python/jittor/src/pyjt/py_ring_buffer.cc @@ -110,7 +110,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o rb->push(size, offset); args.ptr = rb->get_ptr(size, offset); #if defined(__linux__) || defined(_WIN32) - STACK_ALLOC(int64, dims, args.shape.size()); + STACK_ALLOC(int64_t, dims, args.shape.size()); #elif defined(__APPLE__) long dims[args.shape.size()]; #endif diff --git a/python/jittor/src/type/common_op_type.cc b/python/jittor/src/type/common_op_type.cc index 3c9f42ad..305917d4 100644 --- a/python/jittor/src/type/common_op_type.cc +++ b/python/jittor/src/type/common_op_type.cc @@ -12,6 +12,44 @@ namespace jittor { extern int use_cuda; +unordered_map common_op_type_cuda_map = { + {"logical_not", 
"(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "::abs($2)"}, + {"log", "::logf(($1)($2))"}, + {"exp", "::expf(($1)($2))"}, + {"sqrt", "::sqrtf(($1)($2))"}, + {"round", "(($1) ::roundf(($2)))"}, + {"floor", "(($1) ::floorf(($2)))"}, + {"ceil", "(($1) ::ceilf(($2)))"}, + {"round_int", "(($1) ::roundf(($2)))"}, + {"floor_int", "(($1) ::floorf(($2)))"}, + {"ceil_int", "(($1) ::ceilf(($2)))"}, + {"sin", "(($1) ::sinf(($2)))"}, + {"asin", "(($1) ::asinf(($2)))"}, + {"sinh", "(($1) ::sinhf(($2)))"}, + {"asinh", "(($1) ::asinhf(($2)))"}, + {"cos", "(($1) ::cosf(($2)))"}, + {"acos", "(($1) ::acosf(($2)))"}, + {"cosh", "(($1) ::coshf(($2)))"}, + {"acosh", "(($1) ::acoshf(($2)))"}, + {"tan", "(($1) ::tanf(($2)))"}, + {"atan", "(($1) ::atanf(($2)))"}, + {"tanh", "(($1) ::tanhf(($2)))"}, + {"atanh", "(($1) ::atanhf(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"}, + {"erf", "(($1) ::erff(($2)))"}, + {"erfinv", "(($1) ::erfinvf(($1)($2)))"}, + {"cast", "(($1)($2))"}, + {"pow", "::pow(($2),($4))"}, + {"maximum", "::max($1($2), $1($4))"}, + {"minimum", "::min($1($2), $1($4))"}, + {"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"}, + {"init_maximum", "::numeric_min<$1>()"}, + {"init_minimum", "::numeric_max<$1>()"}, +}; + struct CommonOpType : OpByType { CommonOpType() { types = { @@ -34,43 +72,7 @@ struct CommonOpType : OpByType { if (!types.count(args[i])) return ""; } - static unordered_map cuda_map = { - {"logical_not", "(!($2))"}, - {"bitwise_not", "(~($2))"}, - {"negative", "(-($2))"}, - {"abs", "::abs($2)"}, - {"log", "::logf(($1)($2))"}, - {"exp", "::expf(($1)($2))"}, - {"sqrt", "::sqrtf(($1)($2))"}, - {"round", "(($1) ::roundf(($2)))"}, - {"floor", "(($1) ::floorf(($2)))"}, - {"ceil", "(($1) ::ceilf(($2)))"}, - {"round_int", "(($1) ::roundf(($2)))"}, - {"floor_int", "(($1) ::floorf(($2)))"}, - {"ceil_int", "(($1) ::ceilf(($2)))"}, - {"sin", "(($1) ::sinf(($2)))"}, - {"asin", "(($1) ::asinf(($2)))"}, - {"sinh", "(($1) ::sinhf(($2)))"}, - {"asinh", "(($1) ::asinhf(($2)))"}, - {"cos", "(($1) ::cosf(($2)))"}, - {"acos", "(($1) ::acosf(($2)))"}, - {"cosh", "(($1) ::coshf(($2)))"}, - {"acosh", "(($1) ::acoshf(($2)))"}, - {"tan", "(($1) ::tanf(($2)))"}, - {"atan", "(($1) ::atanf(($2)))"}, - {"tanh", "(($1) ::tanhf(($2)))"}, - {"atanh", "(($1) ::atanhf(($2)))"}, - {"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"}, - {"erf", "(($1) ::erff(($2)))"}, - {"erfinv", "(($1) ::erfinvf(($1)($2)))"}, - {"cast", "(($1)($2))"}, - {"pow", "::pow(($2),($4))"}, - {"maximum", "::max($1($2), $1($4))"}, - {"minimum", "::min($1($2), $1($4))"}, - {"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"}, - {"init_maximum", "::numeric_min<$1>()"}, - {"init_minimum", "::numeric_max<$1>()"}, - }; + auto& cuda_map = common_op_type_cuda_map; static unordered_map cpu_map = { {"logical_not", "(!($2))"}, @@ -151,6 +153,10 @@ struct CommonOpType : OpByType { ret = cpu_map[args.at(0)]; return format(ret, args); } + + void post_pass(OpCompiler*) { + return; + } }; diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h new file mode 100644 index 00000000..93833704 --- /dev/null +++ b/python/jittor/src/type/fp16_compute.h @@ -0,0 +1,164 @@ +// 
*************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +#ifdef JIT_cuda + +#include +#include + +namespace jittor { + +typedef __half float16; + +#if CUDA_ARCH >= 800 +inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); } +inline __device__ float16 min(float16 a, float16 b) { return __hmin(a, b); } +#else +inline __device__ float16 max(float16 a, float16 b) { return a +__device__ inline void vload(T* __restrict__ a, T* __restrict__ b) { + if constexpr (nbyte<=0) return; + if constexpr (nbyte>=16) { + auto __restrict__ aa = (float4* __restrict__)a; + auto __restrict__ bb = (float4* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if constexpr (nbyte>=8) { + auto __restrict__ aa = (float2* __restrict__)a; + auto __restrict__ bb = (float2* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if constexpr (nbyte>=4) { + auto __restrict__ aa = (float* __restrict__)a; + auto __restrict__ bb = (float* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if constexpr (nbyte>=2) { + auto __restrict__ aa = (__half* __restrict__)a; + auto __restrict__ bb = (__half* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if constexpr (nbyte>=1) { + auto __restrict__ aa = (int8_t* __restrict__)a; + auto __restrict__ bb = (int8_t* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } +} + + +} + +using jittor::max; +using jittor::min; +using jittor::pow; + +#else + +namespace jittor { + +struct float16 { + uint16 x; + + inline float16(float32 f) { + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + this->x = 0x7fffU; + return; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + this->x = sign | 0x7c00U; + return; + } + if (u < 0x33000001) { + this->x = sign | 0x0000U; + return; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + this->x = (sign | (exponent << 10) | mantissa); + } + + inline operator float() { + + unsigned sign = ((x >> 15) & 1); + unsigned exponent = ((x >> 10) & 0x1f); + unsigned mantissa = ((x & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? 
(sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + int temp = ((sign << 31) | (exponent << 23) | mantissa); + + return reinterpret_cast(temp); + } +}; + +} + +#endif \ No newline at end of file diff --git a/python/jittor/src/type/fp16_op_type.cc b/python/jittor/src/type/fp16_op_type.cc new file mode 100644 index 00000000..1d6c14cb --- /dev/null +++ b/python/jittor/src/type/fp16_op_type.cc @@ -0,0 +1,188 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "utils/str_utils.h" +#include "ops/op_register.h" +#include "op_compiler.h" + +namespace jittor { + +extern int use_cuda; + +extern unordered_map common_op_type_cuda_map; + +static bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; } + +struct FP16OpType : OpByType { + FP16OpType() { + types = { + "float16", + }; + } + + string expand_op(const vector& args) { + bool found_fp16 = 0; + for (int i=1; i cuda_map = { + {"logical_not", "(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "::abs($2)"}, + {"log", "::hlog(($1)($2))"}, + {"exp", "::hexp(($1)($2))"}, + {"sqrt", "::hsqrt(($1)($2))"}, + {"round", "(($1) ::roundf(($2)))"}, + {"floor", "(($1) ::floorf(($2)))"}, + {"ceil", "(($1) ::ceilf(($2)))"}, + {"round_int", "(($1) ::roundf(($2)))"}, + {"floor_int", "(($1) ::floorf(($2)))"}, + {"ceil_int", "(($1) ::ceilf(($2)))"}, + {"sin", "(($1) ::sinf(($2)))"}, + {"asin", "(($1) ::asinf(($2)))"}, + {"sinh", "(($1) ::sinhf(($2)))"}, + {"asinh", "(($1) ::asinhf(($2)))"}, + {"cos", "(($1) ::cosf(($2)))"}, + {"acos", "(($1) ::acosf(($2)))"}, + {"cosh", "(($1) ::coshf(($2)))"}, + {"acosh", "(($1) ::acoshf(($2)))"}, + {"tan", "(($1) ::tanf(($2)))"}, + {"atan", "(($1) ::atanf(($2)))"}, + {"tanh", "(($1) ::tanhf(($2)))"}, + {"atanh", "(($1) ::atanhf(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float16)==0,30,300))))))))"}, + {"erf", "(($1) ::erff(($2)))"}, + {"erfinv", "(($1) ::erfinvf(($1)($2)))"}, + {"cast", "(($1)($2))"}, + {"pow", "::pow(($2),($4))"}, + {"maximum", "::max($1($2), $1($4))"}, + {"minimum", "::min($1($2), $1($4))"}, + {"mod", "$1(($2)-::hfloor(($2)/($4))*($4))"}, + {"init_maximum", "-32768.0f"}, + {"init_minimum", "32768.0f"}, + }; + + static unordered_map cpu_map = { + {"logical_not", "(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "std::abs($2)"}, + {"log", "std::log(($1)($2))"}, + {"exp", "std::exp(($1)($2))"}, + {"sqrt", "std::sqrt(($1)($2))"}, + {"round", "(($1)std::round(($2)))"}, + {"floor", "(($1)std::floor(($2)))"}, + {"ceil", "(($1)std::ceil(($2)))"}, + {"round_int", "(($1)std::round(($2)))"}, + {"floor_int", "(($1)std::floor(($2)))"}, + {"ceil_int", "(($1)std::ceil(($2)))"}, + {"sin", "(($1) std::sin(($2)))"}, + {"asin", "(($1) std::asin(($2)))"}, + {"sinh", "(($1) std::sinh(($2)))"}, + {"asinh", "(($1) std::asinh(($2)))"}, + {"cos", "(($1) std::cos(($2)))"}, + {"acos", "(($1) std::acos(($2)))"}, + 
{"cosh", "(($1) std::cosh(($2)))"}, + {"acosh", "(($1) std::acosh(($2)))"}, + {"tan", "(($1) std::tan(($2)))"}, + {"atan", "(($1) std::atan(($2)))"}, + {"tanh", "(($1) std::tanh(($2)))"}, + {"atanh", "(($1) std::atanh(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"}, + {"erf", "(($1) std::erf(($2)))"}, + {"erfinv", "(jittor::_erfinv($2))"}, + {"cast", "(($1)($2))"}, + {"pow", "std::pow(($2),($4))"}, + {"maximum", "std::max($1($2), $1($4))"}, + {"minimum", "std::min($1($2), $1($4))"}, + {"mod", "$1(($2)-std::floor(($2)/($4))*($4))"}, + {"init_maximum", "-32768.0f"}, + {"init_minimum", "32768.0f"}, + }; + + static unordered_map both_map { + {"add", "(($2)+($4))"}, + {"subtract", "(($2)-($4))"}, + {"multiply", "(($2)*($4))"}, + {"divide", "($1(($1($2))/($1($4))))"}, + {"floor_divide", "($1(($1($2))/($1($4))))"}, + {"less", "(($2)<($4))"}, + {"less_equal", "(($2)<=($4))"}, + {"greater", "(($2)>($4))"}, + {"greater_equal", "(($2)>=($4))"}, + {"equal", "(($2)==($4))"}, + {"not_equal", "(($2)!=($4))"}, + {"left_shift", "(($2)<<($4))"}, + {"right_shift", "(($2)>>($4))"}, + {"logical_and", "(($2)&&($4))"}, + {"logical_or", "(($2)||($4))"}, + {"logical_xor", "((bool($2))!=(bool($4)))"}, + {"bitwise_and", "(($2)&($4))"}, + {"bitwise_or", "(($2)|($4))"}, + {"bitwise_xor", "(($2)^($4))"}, + {"mean", "(($2)+($4)*($1(rcount)))"}, + {"init_add", "$1(0)"}, + {"init_multiply", "$1(1)"}, + {"init_logical_and", "true"}, + {"init_logical_or", "false"}, + {"init_logical_xor", "false"}, + {"init_bitwise_and", "$1(-1)"}, + {"init_bitwise_or", "$1(0)"}, + {"init_bitwise_xor", "$1(0)"}, + {"init_mean", "$1(0)"}, + }; + + string ret; + if (both_map.count(args.at(0))) + ret = both_map[args.at(0)]; + else if (use_cuda) + ret = cuda_map[args.at(0)]; + else + ret = cpu_map[args.at(0)]; + if (use_cuda) { + if (args[1] == "float32" && !both_map.count(args.at(0))) { + ret = common_op_type_cuda_map[args.at(0)]; + } + if (args[1] == "float16" || args[1] == "float32") { + for (int i=3; isrc; + if (src.find("float16") == string::npos) + return; + int i = src.rfind("#include"); + if (i<0) i=0; + i = src.find('\n', i) + 1; + src = src.substr(0, i) + "#include \"type/fp16_compute.h\"\n" + + src.substr(i); + return; + } +}; + + +static int _ = registe_op_type(new FP16OpType()); + +} \ No newline at end of file diff --git a/python/jittor/src/types.h b/python/jittor/src/types.h index e3942146..43c18b27 100644 --- a/python/jittor/src/types.h +++ b/python/jittor/src/types.h @@ -18,7 +18,7 @@ namespace jittor { typedef int8_t int8; typedef int16_t int16; typedef int int32; -typedef int64_t int64; +typedef long long int64; typedef uint8_t uint8; typedef uint16_t uint16; typedef uint32_t uint32; diff --git a/python/jittor/src/var.cc b/python/jittor/src/var.cc index 97a8e2a0..b4339575 100644 --- a/python/jittor/src/var.cc +++ b/python/jittor/src/var.cc @@ -14,7 +14,7 @@ namespace jittor { -int64_t Var::number_of_lived_vars = 0; +int64 Var::number_of_lived_vars = 0; DEFINE_FLAG(fast_shared_ptr, compile_options, {}, "Override the default loop transfrom options"); @@ -42,7 +42,7 @@ string Var::to_string() { return s; } -int64_t Var::numel() { +int64 Var::numel() { if (!shape.size()) return size=num=-1; bool negtive = 0; num=1; diff --git a/python/jittor/src/var.h b/python/jittor/src/var.h index 78e09aa8..941ef215 100644 --- a/python/jittor/src/var.h +++ b/python/jittor/src/var.h @@ -18,13 +18,13 @@ struct Var : Node { NanoVector shape; cstr name; fast_shared_ptr 
loop_options; - static int64_t number_of_lived_vars; + static int64 number_of_lived_vars; // this var will be generated after alloc. void* mem_ptr = nullptr; Allocator* allocator = nullptr; size_t allocation; - int64_t size, num; + int64 size, num; inline bool is_float() const { CHECK_EXIST; return ns.is_float(); } inline int dsize() const { CHECK_EXIST; return ns.dsize(); } inline NanoString dtype() const { CHECK_EXIST; return ns; } @@ -40,7 +40,7 @@ struct Var : Node { Var(NanoVector shape, NanoString dtype); string to_string(); - int64_t numel(); + int64 numel(); void set_shape(NanoVector shape); bool alloc(Allocator* allocator); inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; } diff --git a/python/jittor/test/__main__.py b/python/jittor/test/__main__.py index 7b26a16c..3904fa48 100644 --- a/python/jittor/test/__main__.py +++ b/python/jittor/test/__main__.py @@ -15,6 +15,7 @@ if __name__ == "__main__": skip_l = int(os.environ.get("test_skip_l", "0")) skip_r = int(os.environ.get("test_skip_r", "1000000")) + skip = os.environ.get("test_skip", "").split(",") test_only = None if "test_only" in os.environ: test_only = set(os.environ.get("test_only").split(",")) @@ -34,6 +35,9 @@ if __name__ == "__main__": continue if test_only and test_name not in test_only: continue + for s in skip: + if s in test_name: + continue print("Add Test", _, test_name) suite.addTest(tests) diff --git a/python/jittor/test/misc/superglue.py b/python/jittor/test/misc/superglue.py new file mode 100644 index 00000000..44af14fd --- /dev/null +++ b/python/jittor/test/misc/superglue.py @@ -0,0 +1,374 @@ +from copy import deepcopy +from pathlib import Path +import jittor as jt +import jittor.nn as nn +import numpy as np +import os + +split_size = 1000000 + +conv_opt = int(os.environ.get("conv_opt", "0")) + +if conv_opt: + Conv1d_sp = nn.Conv1d_sp +else: + Conv1d_sp = nn.Conv1d + + +def MLP(channels: list, do_bn=True): + """ Multi-layer perceptron """ + n = len(channels) + layers = [] + for i in range(1, n): + layers.append(Conv1d_sp(channels[i - 1], channels[i], kernel_size=1, bias=True)) + if i < (n - 1): + if do_bn: + layers.append(nn.BatchNorm(channels[i])) + # layers.append(nn.InstanceNorm1d(channels[i])) + # layers.append(nn.LayerNorm(channels[i])) + layers.append(nn.ReLU()) + return nn.Sequential(*layers) + + +def normalize_keypoints(kpts, image_shape): + size = image_shape.flip(1) # shape=(b,2) ;h w -> w, h + center = size / 2 + scaling = size.float32().max(1, keepdims=True) * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + def __init__(self, feature_dim, layers, keypoint_position_dim=2): + super().__init__() + # self.keypoint_position_dim = keypoint_position_dim + self.encoder = MLP([keypoint_position_dim + 1] + layers + [feature_dim]) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def execute(self, kpts, scores): + inputs = jt.concat([kpts.t(), scores.unsqueeze(1)], dim=1) + return self.encoder(inputs) + +cnt = 0 + +def attention(query, key, value): + global cnt + cnt += 1 + b, d, h, n = query.shape + # print("attention", b,d,h,n, cnt) + dim_factor = (1.0 / d)**0.5 + query = query.transpose(0, 2, 3, 1).reshape(b * h, -1, d) * dim_factor + key = key.transpose(0, 2, 1, 3).reshape(b * h, d, -1) + value = value.transpose(0, 2, 3, 1).reshape(b * h, -1, d) + # print("attention", query.shape, key.shape, value.shape) + + data = [] 
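    # The chunked loop that follows keeps only one (chunk, n, n) attention
    # matrix alive at a time, and tmp3.sync() forces each chunk to execute so
    # its temporaries can be freed before the next chunk is built. A generic
    # sketch of the same idea (hypothetical helper, reusing the jt/nn names
    # imported at the top of this file):
    #
    #   def chunked_attention(q, k, v, chunk=split_size):
    #       outs = []
    #       for i in range(0, q.shape[0], chunk):
    #           end = min(i + chunk, q.shape[0])
    #           s = nn.softmax(nn.bmm(q[i:end], k[i:end]), dim=-1)
    #           o = nn.bmm(s, v[i:end])
    #           o.sync()  # flush so this chunk's buffers can be reused
    #           outs.append(o)
    #       return jt.concat(outs)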
+ for i in range(0, query.shape[0], split_size): + end = min(i + split_size, query.shape[0]) + tmp1 = nn.bmm(query[i:end], key[i:end]) + tmp2 = nn.softmax(tmp1, dim=-1) + tmp3 = nn.bmm(tmp2, value[i:end]) + tmp3.sync() + data.append(tmp3) + tmp3 = jt.concat(data) + + # for i in range(0, query.shape[0], split_size): + # end = min(i + split_size, query.shape[0]) + # tmp1 = nn.bmm(query[:,i:end], key[:,i:end]) + # tmp2 = nn.softmax(tmp1, dim=-1) + # tmp3 = nn.bmm(tmp2, value[:,i:end]) + # tmp3.sync() + # data.append(tmp3) + # tmp3 = jt.concat(data, dim=1) + + # tmp1 = nn.bmm(query, key) + # print(tmp1.shape) + # tmp2 = nn.softmax(tmp1, dim=-1) + # print(tmp2.shape) + # tmp3 = nn.bmm(tmp2, value) + # print(tmp3.shape) + return tmp3.reshape(b, h, -1, d).transpose(0, 3, 1, 2) + return nn.bmm(nn.softmax(nn.bmm(query, key), dim=-1), value).reshape(b, h, -1, d).transpose(0, 3, 1, 2) + + +class MultiHeadedAttention(nn.Module): + """ Multi-head attention to increase model expressivitiy """ + def __init__(self, num_heads: int, d_model: int): + super().__init__() + assert d_model % num_heads == 0 + self.dim = d_model // num_heads + self.num_heads = num_heads + self.merge = Conv1d_sp(d_model, d_model, kernel_size=1) + self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + + def execute(self, query, key, value): + batch_dim = query.size(0) + query, key, value = [l(x).reshape(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))] + x = attention(query, key, value) + # x = attention_chunk(query, key, value) + return self.merge(x.reshape(batch_dim, self.dim * self.num_heads, -1)) + + +class AttentionalPropagation(nn.Module): + def __init__(self, feature_dim: int, num_heads: int): + super().__init__() + self.attn = MultiHeadedAttention(num_heads, feature_dim) + self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim]) + nn.init.constant_(self.mlp[-1].bias, 0.0) + + def execute(self, x, source): + message = self.attn(x, source, source) + return self.mlp(jt.concat([x, message], dim=1)) + + +class AttentionalGNN(nn.Module): + def __init__(self, feature_dim: int, layer_names: list): + super().__init__() + self.layers = nn.ModuleList([AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))]) + self.is_cross = [x == 'cross' for x in layer_names] + + def execute(self, desc0, desc1): + for layer, is_cross in zip(self.layers, self.is_cross): + layer.attn.prob = [] + if is_cross: + src0, src1 = desc1, desc0 + else: # if name == 'self': + src0, src1 = desc0, desc1 + # delta0, delta1 = layer(desc0, src0), layer(desc1, src1) + + delta0 = layer(desc0, src0) + # print(delta0.numel()*4) + # breakpoint() + jt.sync_all() + delta1 = layer(desc1, src1) + jt.sync_all() + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + jt.sync_all() + return desc0, desc1 + + +def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): + """ Perform Sinkhorn Normalization in Log-space for stability""" + u, v = jt.zeros_like(log_mu), jt.zeros_like(log_nu) + for _ in range(iters): + u = log_mu - (Z + v.unsqueeze(1)).exp().sum(dim=2).log() + v = log_nu - (Z + u.unsqueeze(2)).exp().sum(dim=1).log() + return Z + u.unsqueeze(2) + v.unsqueeze(1) + + +def log_optimal_transport(scores, alpha, iters: int): + """ Perform Differentiable Optimal Transport in Log-space for stability""" + b, m, n = scores.shape + ms, ns = jt.float(m, requires_grad=False), jt.float(n, requires_grad=False) + + bins0 = alpha.broadcast([b, m, 1]) + bins1 = alpha.broadcast([b, 1, n]) + alpha = 
alpha.broadcast([b, 1, 1]) + + couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1) + + norm = -(ms + ns).log() + log_mu = jt.concat([norm.broadcast([m]), ns.log() + norm]) + log_nu = jt.concat([norm.broadcast([n]), ms.log() + norm]) + log_mu, log_nu = log_mu[None].broadcast([b, m + 1]), log_nu[None].broadcast([b, n + 1]) + + Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) + Z = Z - norm # multiply probabilities by M+N + return Z + + +def arange_like(x, dim: int): + return jt.ones(x.shape[dim], dtype=x.dtype)[None].cumsum()[0] - 1 # traceable in 1.1 + + +default_config = { + 'descriptor_dim': 256, # SuperPoint + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], # SuperPoint + 'GNN_layers': ['self', 'cross'] * 9, + 'sinkhorn_iterations': 100, + 'match_threshold': 0.2, +} + + +def get_weighted_loss_batch(scores, all_matches): + matches0, matches1 = all_matches.chunk(chunks=2, dim=2) + batchIdx = jt.arange(all_matches.shape[0]).unsqueeze(1).repeat(1, all_matches.shape[1]) + batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1) + valid_index0, valid_index1 = matches0 >= 0, matches1 >= 0 + valid_match = jt.logical_and(valid_index0, valid_index1) + valid_unmatch = jt.logical_xor(valid_index0, valid_index1) + num_match = valid_match.sum().maximum(1e-9) + num_unmatch = valid_unmatch.sum().maximum(1e-9) + + + + score_ = scores[batchIdx, matches0, matches1] + score_match_ = (score_*valid_match).float32().sum() / num_match + score_umatch_ = (score_*valid_unmatch).float32().sum() / num_unmatch + return -(num_unmatch * score_match_ + num_match * score_umatch_) / (num_match + num_unmatch) + # print(score_umatch_, score_match_) + # return -(score_match + score_umatch) / (num_match + num_unmatch) + + score_match = scores[(batchIdx[valid_match], matches0[valid_match], matches1[valid_match])].float32().mean() if num_match > 0 else 0 + score_umatch = scores[(batchIdx[valid_unmatch], matches0[valid_unmatch], matches1[valid_unmatch])].float32().mean() if num_unmatch > 0 else 0 + # print(score_match, score_umatch) + return -(num_unmatch * score_match + num_match * score_umatch) / (num_match + num_unmatch) + + +def add_dustbin(scores, alpha): + b, m, n = scores.shape + bins0 = jt.broadcast(alpha, (b, m, 1)) + bins1 = jt.broadcast(alpha, (b, 1, n)) + alpha = jt.broadcast(alpha, (b, 1, 1)) + couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1) + return couplings + + +class SuperGlue(nn.Module): + def __init__(self, config): + super().__init__() + config = {**default_config, **config} + self.descriptor_dim = config['descriptor_dim'] + self.keypoint_encoder = config['keypoint_encoder'] + self.GNN_layers = config['GNN_layers'] + self.sinkhorn_iterations = config['sinkhorn_iterations'] + self.match_threshold = config['match_threshold'] + self.keypoint_position_dim = config['keypoint_position_dim'] + self.use_dual_softmax = config['use_dual_softmax'] + self.scale = jt.float(self.descriptor_dim**-0.5).stop_grad() + # self.scale.requires_grad = False + + # self.des_extend = MLP([128, 256]) + + self.kenc = KeypointEncoder(self.descriptor_dim, self.keypoint_encoder, keypoint_position_dim=self.keypoint_position_dim) + + self.gnn = AttentionalGNN(self.descriptor_dim, self.GNN_layers) + + self.final_proj = Conv1d_sp(self.descriptor_dim, self.descriptor_dim, kernel_size=1, bias=True) + + self.bin_score = jt.float(1.0) + + def execute(self, data): + """Run SuperGlue on a pair of keypoints and 
descriptors""" + + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + desc0, desc1 = data['descriptors0'], data['descriptors1'] + all_matches = data['all_matches'] + # match_num = data['match_num'] + + if kpts0.shape[1] == 0 or kpts1.shape[1] == 0 or all_matches.shape[1] == 0: # no keypoints or no matches/unmatches + shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] + return { + 'matches0': jt.ones(shape0, dtype=jt.int), + 'matches1': jt.ones(shape1, dtype=jt.int), + 'matching_scores0': jt.zeros(shape0, dtype=jt.float), + 'matching_scores1': jt.zeros(shape1, dtype=jt.float), + 'skip_train': True + } + + # Keypoint normalization. + kpts0 = normalize_keypoints(kpts0, data['shape0']) + kpts1 = normalize_keypoints(kpts1, data['shape1']) + + # Keypoint MLP encoder. + # desc0 = self.des_extend(desc0) + self.kenc(kpts0, data['scores0']) + # desc1 = self.des_extend(desc1) + self.kenc(kpts1, data['scores1']) + desc0 = desc0 + self.kenc(kpts0, data['scores0']) + desc1 = desc1 + self.kenc(kpts1, data['scores1']) + + # Multi-layer Transformer network. + desc0, desc1 = self.gnn(desc0, desc1) + + # Final MLP projection. + desc0, desc1 = self.final_proj(desc0), self.final_proj(desc1) + desc0_t = desc0.t() + losses = [] + + for i in range(0, desc1.shape[0], split_size): + end = min(desc1.shape[0], i + split_size) + + # Compute matching descriptor distance. + scores = nn.bmm(desc0_t[i:end], desc1[i:end]) * self.scale # 457.76 MB + scores.sync() + + # Run the optimal transport. + if self.use_dual_softmax: + scores = add_dustbin(scores, self.bin_score) # 458.68 MB + scores.sync() + dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2) + scores = dual_softmax0 + dual_softmax1 # 458.22 MB + scores.sync() + else: + scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations']) + + # loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])]) + + loss = get_weighted_loss_batch(scores, all_matches[i:end]) + loss.sync() + losses.append(loss) + loss = jt.concat(losses) + ''' + # Compute matching descriptor distance. + scores = nn.bmm(desc0.t(), desc1) * self.scale # 457.76 MB + scores.sync() + + # Run the optimal transport. + if self.use_dual_softmax: + scores = add_dustbin(scores, self.bin_score) # 458.68 MB + scores.sync() + dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2) + scores = dual_softmax0 + dual_softmax1 # 458.22 MB + scores.sync() + else: + scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations']) + + # loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])]) + + loss = get_weighted_loss_batch(scores, all_matches) + # print(scores.shape, all_matches.shape, loss.shape) + ''' + + # matches0, matches1 = all_matches.chunk(chunks=2, dim=2) + # batchIdx = jt.arange(0, b).unsqueeze(1).repeat(1, num) + # batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1) + # validmatch = (matches0 >= 0) | (matches1 >= 0) + # batchIdx, matches0, matches1 = batchIdx[validmatch], matches0[validmatch], matches1[validmatch] + # matches0[matches0 == -1] = n + # matches1[matches1 == -1] = m + # loss_mean = -scores[(batchIdx, matches0, matches1)].mean() + # loss_mean = nn.l1_loss(loss_mean, jt.float(0.0)) + + if not data['return_match']: + return {'loss': loss} + + with jt.no_grad(): + b, n, m = scores.shape + # Get the matches with score above "match_threshold". 
+ indices0, max0 = scores[:, :-1, :-1].argmax(2) + indices1, max1 = scores[:, :-1, :-1].argmax(1) + mutual0 = jt.arange(0, n)[None] == indices1.gather(1, indices0) + mutual1 = jt.arange(0, m)[None] == indices0.gather(1, indices1) + # zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = max0.exp() + mscores0[mutual0.logical_not()] = 0 + # mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + mscores1 = mscores0.gather(1, indices1) + mscores1[mutual1.logical_not()] = 0 + valid0 = mutual0 & (mscores0 > self.match_threshold) + valid1 = mutual1 & valid0.gather(1, indices1) + # indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + # indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + indices0[valid0.logical_not()] = -1 + indices1[valid1.logical_not()] = -1 + + return { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + 'loss': loss, + } + + # scores big value or small value means confidence? log can't take neg value \ No newline at end of file diff --git a/python/jittor/test/test_benchmark.py b/python/jittor/test/test_benchmark.py new file mode 100644 index 00000000..2e7db8b7 --- /dev/null +++ b/python/jittor/test/test_benchmark.py @@ -0,0 +1,344 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +n = 400000000 +# n = 4000000 +n = 7680000 + +def get_mem_band(): + a = jt.rand((n)).float32() + for i in range(100): + a.copy().sync() + jt.sync_all(True) + import time + t = time.time() + for i in range(1000): + a.copy().sync() + jt.sync_all(True) + dt = time.time() - t + band = a.numel() * 4 * 2000 / dt / 1024**3 + print("Mem band: ", band) + return band + +def check_simple_add_band(): + # copy: 816 + # S=1 128,1024, ILP=1 634 + # S=0 128,1024, ILP=1 734 + # S=0 128,512, ILP=1 716 + # S=0 64,1024, ILP=1 706 + # S=0 256,1024, ILP=1 706 + def test(S=0, B=128, T=1024, ILP=1): + a = jt.rand((n)).float32() + jt.sync_all(True) + jt.flags.log_silent = 1 + with jt.profile_scope(100, 1000) as rep: + b = jt.code(a.shape, a.dtype, [a], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP {ILP} + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + }} + kernel<<<{B},{T}>>>(in0_p, out0_p, in0->num); + """) + b.sync() + bw = float(rep[-1][9]) / 1024**3 + s = f"S={S}, B={B}, T={T}, ILP={ILP} BW={bw}" + print(s) + return s, bw + + def test2(S=0, B=128, T=1024, ILP=1): + a = jt.rand((n)).float32() + jt.sync_all(True) + # jt.flags.log_silent = 0 + with jt.profile_scope(10, 1000) as rep: + b = jt.code(a.shape, a.dtype, [a], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __global__ void kernel(float2 * __restrict__ a, float2* __restrict__ b, int num) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP 1 + for (int i=tid*ILP; i(b+i, a+i); + 
{"__syncthreads();" if S else ""} + }} + }} + kernel<<<{B},{T}>>>((float2*)in0_p, (float2*)out0_p, in0->num/2); + """) + b.sync() + bw = float(rep[-1][9]) / 1024**3 + s = f"T2: S={S}, B={B}, T={T}, ILP={ILP} BW={bw}" + print(s) + return s, bw + + + def test3(S=0, B=128, T=1024, ILP=1, C=0): + a = jt.rand((n)).float32() + b = jt.rand(B) + jt.sync_all(True) + jt.flags.log_silent = 1 + with jt.profile_scope(100, 1000) as rep: + b = jt.code(a.shape, a.dtype, [a, b], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP {ILP} + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + {"__syncthreads();" if C else ""} + }} + kernel<<shape[0],{T}>>>(in0_p, out0_p, in0->num); + """) + b.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C} + # b.compile_options = {"FLAGS: –Xptxas –dlcm=ca ": 1} + b.sync() + + bw = float(rep[-1][9]) / 1024**3 + s = f"T3: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw}" + print(s) + return s, bw + + + def test4(S=0, B=128, T=1024, ILP=1, C=0, name="b.png"): + a = jt.rand((n)).float32() + b = jt.rand(B*4).uint32() + jt.sync_all(True) + # jt.flags.log_silent = 1 + with jt.profile_scope(100, 10000) as rep: + _ = jt.code(a.shape, a.dtype, [a, b], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __device__ uint get_smid(void) {{ + uint ret; + asm("mov.u32 %0, %smid;" : "=r"(ret) ); + return ret; + }} + __device__ uint get_time(void) {{ + uint ret; + asm volatile("mov.u32 %0, %%globaltimer_lo;" : "=r"(ret)); + return ret; + }} + + __global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num, in1_type* __restrict__ c) {{ + uint t = get_time(); + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP {ILP} + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + {"__syncthreads();" if C else ""} + if (threadIdx.x == 0) + ((uint4* __restrict__)c)[blockIdx.x] = + uint4{{get_smid(), t, get_time(), 0}}; + }} + kernel<<shape[0]/4,{T}>>>(in0_p, out0_p, in0->num, in1_p); + """) + _.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C} + # b.compile_options = {"FLAGS: –Xptxas –dlcm=ca ": 1} + _.sync() + + bw = float(rep[-1][9]) / 1024**3 + b = b.data.reshape(-1, 4)[:,:3] + mint = b[:,1].min() + b[:,1:] -= mint + smmax = int(b[:,0].max()) + smmin = int(b[:,0].min()) + maxt = b.max() + + # print(b) + + s = f"T4: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw:.3f} sm={smmin},{smmax} maxt={maxt}" + print(s) + import pylab as pl + pl.figure(figsize=(16,16)) + texts = [] + pret = np.zeros(200, dtype="uint32") + for i in range(B): + smid, s, t = b[i] + pl.plot([s,t], [smid, smid], 'ro-') + texts.append((s, smid, i)) + texts.append((t, smid, i)) + + texts = sorted(texts) + for (s, smid, bid) in texts: + cpos = max(pret[smid], s) + pl.text(cpos, smid, str(bid)) + pret[smid] = cpos + maxt // 30 + + + # print("???") + # adjust_text(texts, arrowprops=dict(arrowstyle='->', color='blue')) + # print("???") + pl.savefig(name) + pl.close() + return s, bw + # test(S=0, B=128, T=1024, ILP=1) + # test(S=1, B=128, T=1024, ILP=1) + # test(S=0, B=64, T=1024, ILP=1) + # test(S=0, B=256, T=1024, ILP=1) + # test(S=1, B=128, T=512, ILP=1) + # test(S=1, B=128, T=256, ILP=1) + + # test(S=0, B=128, T=1024, ILP=2) + # test(S=0, B=128, T=1024, ILP=4) + # test(S=0, B=128, T=512, ILP=2) + # test(S=0, 
B=128, T=512, ILP=4) + + # test(S=1, B=128, T=1024, ILP=2) + # test(S=1, B=128, T=1024, ILP=4) + # test(S=1, B=128, T=1024, ILP=8) + # test(S=1, B=128, T=1024, ILP=16) + # test(S=1, B=128, T=512, ILP=2) + # test(S=1, B=128, T=512, ILP=4) + + # test(S=1, B=256, T=1024, ILP=2) + # test(S=1, B=512, T=1024, ILP=2) + # test(S=1, B=256, T=1024, ILP=4) + # test(S=1, B=256, T=1024, ILP=8) + # test(S=1, B=256, T=1024, ILP=16) + # test(S=1, B=256, T=512, ILP=2) + # test(S=1, B=256, T=512, ILP=4) + + # test(S=1, B=128, T=256, ILP=2) + # test(S=1, B=128, T=256, ILP=4) + # test(S=0, B=128, T=256, ILP=2) + # test(S=0, B=128, T=256, ILP=4) + + # for b in [1, 2, 4, 8, 16, 32, 64, 128,256]: + # test(S=1, B=b, T=512, ILP=2) + + import matplotlib as mpl + mpl.use('Agg') + import pylab as pl + import numpy as np + + # test4(S=1, B=82, T=1024, ILP=2, C=0, name="b.png") + # test4(S=1, B=83, T=1024, ILP=2, C=0, name="c.png") + # test4(S=1, B=82*3, T=512, ILP=2, C=0, name="d1.png") + # test4(S=1, B=82*3+1, T=512, ILP=2, C=0, name="d2.png") + # test4(S=1, B=82*6+1, T=512, ILP=2, C=0, name="d3.png") + # test4(S=0, B=82*6+1, T=512, ILP=2, C=0, name="d4.png") + + for b in range(70, 83): + test4(S=1, B=b, T=1024, ILP=2, C=0, name=f"b-{b}.png") + + # data = [] + # for b in range(32, 2000, 8): + # _, bw = test3(S=0, B=b, T=32, ILP=2) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # for t in [32, 64, 128, 256, 512, 1024]: + # data = [] + # for b in range(32, 2000, 8): + # _, bw = test3(S=1, B=b*(1024//t), T=t, ILP=2) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # for t in [1024]: + # for c in [0,1]: + # data = [] + # # for b in range(32, 1000, 8): + # for b in range(32, 33, 8): + # _, bw = test3(S=c, B=b*(1024//t), T=t, ILP=2, C=0) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # for ilp in [2]: + # for s in [1]: + # for t in [1024,512,256,128]: + # data = [] + # for b in range(32, 1100, 8): + # _, bw = test3(S=s, B=b*(1024//t), T=t, ILP=ilp) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # pl.savefig("a.png") + # pl.close() + # for b in range(80, 90, 1): + # _, bw = test3(S=1, B=b, T=1024, ILP=2) + # # 82 + # for b in range(240, 260, 1): + # _, bw = test3(S=1, B=b, T=512, ILP=2) + # # 82*3 = 246 + # for b in range(240, 500, 1): + # _, bw = test3(S=1, B=b, T=256, ILP=2) + # # 492 = 82*6 + # for b in range(240, 1000, 1): + # _, bw = test3(S=1, B=b, T=128, ILP=2) + # # 984 = 82*12 + + + # for b in [128,256]: + # test(S=1, B=b, T=1024, ILP=2) + # for b in [128,256]: + # test(S=0, B=b, T=512, ILP=2) + # for b in [128,256]: + # test(S=0, B=b, T=1024, ILP=2) + # for b in [128,256]: + # test(S=1, B=b, T=512, ILP=1) + # for b in [128,256]: + # test(S=1, B=b, T=1024, ILP=1) + # for b in [128,256]: + # test(S=0, B=b, T=512, ILP=1) + # for b in [128,256]: + # test(S=0, B=b, T=1024, ILP=1) + # test(S=1, B=128, T=512, ILP=4) + # test(S=1, B=64, T=512, ILP=2) + # test(S=1, B=80, T=512, ILP=2) + # test(S=1, B=100, T=512, ILP=2) + # test(S=1, B=110, T=512, ILP=2) + # test(S=1, B=115, T=512, ILP=2) + # test(S=1, B=120, T=512, ILP=2) + # test(S=1, B=130, T=512, ILP=2) + # test(S=1, B=140, T=512, ILP=2) + # test2(S=1, B=128, T=512, ILP=2) + # test(S=1, B=128, T=256, ILP=4) + # test(S=1, B=128, T=128, ILP=8) + # test(S=1, B=128, T=64, ILP=16) + + + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestBenchmarkCUDA(unittest.TestCase): + def setUp(self): + 
jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + + def test_main(self): + return + get_mem_band() + check_simple_add_band() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_binary_op.py b/python/jittor/test/test_binary_op.py index 6eec8c0f..1947fef8 100644 --- a/python/jittor/test/test_binary_op.py +++ b/python/jittor/test/test_binary_op.py @@ -19,12 +19,12 @@ def all_eq(x, y): y = convert(y) if str(x.dtype).startswith("float"): return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all() - return x.dtype == y.dtype and x.shape == y.shape and (x==y).all() + return x.dtype == y.dtype and x.shape == y.shape and np.testing.assert_allclose(x, y) def check(op, *args): x = eval(f"np.{op}(*args)") y = eval(f"jt.{op}(*args).data") - assert all_eq(x, y), f"{x}\n{y}" + all_eq(x, y) class TestBinaryOp(unittest.TestCase): def test_binary_op(self): @@ -47,6 +47,9 @@ class TestBinaryOp(unittest.TestCase): def test_i(self): def check(op, a, b): + if isinstance(a, list): + a = np.array(a) + b = np.array(b) if jt.flags.use_cuda and op == "@": return if op=="@": @@ -65,13 +68,13 @@ class TestBinaryOp(unittest.TestCase): a = np.float32(a) ja = np.float32(ja) - assert all_eq(ja, a), (ja,a) + all_eq(ja, a) check("+", 5, 2) check("-", 5, 2) check("*", 5, 2) check("/", 5, 2) check("//", 5, 2) - check("@", [[5]], [[2]]) + # check("@", [[5]], [[2]]) check("%", 5, 2) check("**", 5, 2) check("<<", 5, 2) @@ -80,6 +83,15 @@ class TestBinaryOp(unittest.TestCase): check("^", 5, 2) check("|", 5, 2) + check("+", [5.0,6.0], [2.0,3.0]) + check("-", [5.0,6.0], [2.0,3.0]) + check("*", [5.0,6.0], [2.0,3.0]) + check("/", [5.0,6.0], [2.0,3.0]) + check("//", [5.0,6.0], [2.0,3.0]) + check("@", [[5,6],[7,8]], [[2,3],[4,5]]) + check("%", [5.0,6.0], [2.0,3.0]) + check("**", [5.0,6.0], [2.0,3.0]) + def test_r(self): def check(op, a, b): a = np.array(a) @@ -97,7 +109,7 @@ class TestBinaryOp(unittest.TestCase): a = eval(f"a {op} b") a = np.array(a) - assert all_eq(jc, a), f"\n{jc}\n{a}" + all_eq(jc, a) check("+", 5, 2) check("-", 5, 2) check("*", 5, 2) @@ -118,6 +130,7 @@ class TestBinaryOp(unittest.TestCase): a = np.random.rand(10) b = np.random.rand(10) c = np.random.rand(10) + tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-4 for op in ops: func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()") x, grads = ngrad(func, [a,b,c], 1e-8) @@ -127,7 +140,7 @@ class TestBinaryOp(unittest.TestCase): jx = eval(f"(ja{op}jb)*jc") jgrads = jt.grad(jx, [ja,jb,jc]) for jd, nd in zip(jgrads, grads): - assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}" + np.testing.assert_allclose(jd.data, nd, atol=tol, rtol=tol) def test_mod_float(self): a = jt.random((10,)) @@ -137,7 +150,8 @@ class TestBinaryOp(unittest.TestCase): a = jt.random((10,), 'float64') b = jt.random((10,), 'float64') c = a % b - assert np.allclose(c.data, a.data % b.data) + assert np.allclose(c.data, a.data % b.data, a.data, b.data) + if jt.flags.amp_reg & 2: return a = jt.random((10,)) * 1000 b = (jt.random((10,)) * 10).int() + 1 c = a % b @@ -169,5 +183,19 @@ class TestBinaryOp(unittest.TestCase): class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)): pass +class TestBinaryOpCpuFp16(TestBinaryOp): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + def tearDown(self): + jt.flags.amp_reg = 0 + +class TestBinaryOpCudaFp16(TestBinaryOp): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.amp_reg = 0 + jt.flags.use_cuda = 
0 + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fp16.py b/python/jittor/test/test_fp16.py new file mode 100644 index 00000000..d948e402 --- /dev/null +++ b/python/jittor/test/test_fp16.py @@ -0,0 +1,342 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +def transpose0231(x): + s0, s1, s2, s3 = x.shape + asize = 16 + bsize = 16 + ILP = 2 + return jt.code([s0, s2, s3, s1], x.dtype, [x], + cuda_header="#include \n#include ", + cuda_src=f""" + __global__ void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{ + __shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}]; + int t3 = threadIdx.x % {bsize}; + int t1 = threadIdx.x / {bsize}; + int b3 = blockIdx.x; + int b2 = blockIdx.y; + int b0 = blockIdx.z; + int x3 = 1; + int x2 = s3; + int x1 = s2*x2; + int x0 = s1*x1; + int y3 = 1; + int y2 = s1; + int y1 = s3*y2; + int y0 = s2*y1; + in0_type tmp[{ILP}]; + for (int i=0; i<(s1-1)/{asize*ILP}+1; i++) + {{ + int _b3 = b3 * {bsize*ILP} + t3*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + vload( + tmp, + &x[b0*x0+(t1*{ILP}+j+i*{asize*ILP})*x1+b2*x2+_b3*x3] + ); + #pragma unroll + for (int k=0; k<{ILP}; k++) + t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k]; + + }} + }} + __syncthreads(); + int t3_ = threadIdx.x % {asize}; + int t1_ = threadIdx.x / {asize}; + _b3 = b3 * {bsize*ILP} + t1_*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + #pragma unroll + for (int k=0; k<{ILP}; k++) {{ + tmp[k] = + t[(t3*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j]; + }} + vload( + &y[b0*y0+b2*y1+(_b3+j)*y2+((t3*{ILP})+i*{asize*ILP})*y3], + tmp + ); + }} + }} + __syncthreads(); + }} + }} + int s0, s1, s2, s3; + in0->shape.unpack(s0, s1, s2, s3); + kernel<<<{{(s3-1)/{bsize*ILP}+1, s2, s0 }}, {bsize*asize}>>> + (in0_p, out0_p, s0, s1, s2, s3); + """) + +def transpose0231_2(x): + s0, s1, s2, s3 = x.shape + asize = 16 + bsize = 8 + ILP = 2 + return jt.code([s0, s2, s3, s1], x.dtype, [x], + cuda_header="#include \n#include ", + cuda_src=f""" + __global__ __launch_bounds__({asize*bsize}) void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{ + __shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}]; + int t3 = threadIdx.x % {bsize}; + int t1 = threadIdx.x / {bsize}; + int b3 = blockIdx.x; + int b1 = blockIdx.y; + int b2 = 0; + int b0 = blockIdx.z; + int x3 = 1; + int x2 = s3; + int x1 = s2*x2; + int x0 = s1*x1; + int y3 = 1; + int y2 = s1; + int y1 = s3*y2; + int y0 = s2*y1; + in0_type tmp[{ILP}]; + {{ + int _b3 = b3 * {bsize*ILP} + t3*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + if (t1*{ILP}+j+b1*{asize*ILP} >= s1) + continue; + vload( + tmp, + &x[b0*x0+(t1*{ILP}+j+b1*{asize*ILP})*x1+b2*x2+_b3*x3] + ); + #pragma unroll + for (int k=0; k<{ILP}; k++) + t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k]; + + }} + }} + __syncthreads(); + int t3_ = threadIdx.x % {asize}; + int t1_ = threadIdx.x / {asize}; + _b3 = b3 * {bsize*ILP} + t1_*{ILP}; + int yy3 = (t3_*{ILP})+b1*{asize*ILP}; + if (_b3 < s3 && yy3 < s1) {{ + #pragma unroll + for (int j=0; j<{ILP}; 
j++) {{ + #pragma unroll + for (int k=0; k<{ILP}; k++) {{ + tmp[k] = + t[(t3_*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j]; + }} + vload( + &y[b0*y0+b2*y1+(_b3+j)*y2+yy3*y3], + tmp + ); + // printf("%d %d %d %d %d\\n", b0*y0+b2*y1+(_b3+j)*y2+yy3*y3, + // b0, b2, (_b3+j), yy3); + }} + }} + __syncthreads(); + }} + }} + int s0, s1, s2, s3; + in0->shape.unpack(s0, s1, s2, s3); + kernel<<<{{(s3-1)/{bsize*ILP}+1, (s1-1)/{asize*ILP}+1, s0 }}, {bsize*asize}>>> + (in0_p, out0_p, s0, s1, s2, s3); + """) + +def check_share(): + return + a = jt.rand((30, 32, 4, 2000)).float32() + jt.code(a.shape, a.dtype, [a], + cuda_header="#include \n#include ", + cuda_src=""" + __global__ void kernel(in0_type* __restrict__ a, in0_type* __restrict__ b) { + __shared__ float x[32*33]; + for (int i=0; i<3; i++) { + ((float2*)&x[i])[0] = ((float2*)&a[i])[0]; + ((float2*)&b[i])[0] = ((float2*)&x[i+1])[0]; + } + } + kernel<<<1024,16*16>>>(in0_p, out0_p); + LOGir << "aaa"; + """).sync() + jt.sync_all(True) + # print(a[0]+1) + print("pass test") + +class TestFP16(unittest.TestCase): + def test_array(self): + a = np.array([1,2,3], dtype="float16") + b = jt.array(a) + np.testing.assert_allclose(a, b.data) + + def test_add(self): + a = np.array([1,2,3], dtype="float16") + b = jt.array(a) + c = b+b + np.testing.assert_allclose(c.data, a+a) + d = c.sum() + np.testing.assert_allclose(d.data, [12]) + c = c+1 + print(c) + + def test_matmul(self): + a = jt.random((100,100)).float16() + b = jt.random((100,100)).float16() + c = jt.matmul(a, b) + c.sync() + + def test_matmul_grad(self): + a = jt.random((100,100)).float16() + b = jt.random((100,100)).float16() + c = jt.matmul(a, b) + c.sync() + da, db = jt.grad(c, [a,b]) + jt.sync_all() + assert da.dtype == "float16" + assert db.dtype == "float16" + + def test_array_random_auto_cast(self): + a = jt.array([1.0,2.0]) + assert a.dtype == "float32" + with jt.flag_scope(amp_reg=2+16): + a = jt.array([1.0,2.0]) + assert a.dtype == "float16", a.dtype + + a = jt.random([10]) + assert a.dtype == "float32" + with jt.flag_scope(amp_reg=2+16): + a = jt.random([10]) + assert a.dtype == "float16", a.dtype + + def test_conv(self): + a = jt.random((3,4,5,5)).float16() + b = jt.random((4,4,3,3)).float16() + c = jt.nn.conv(a, b) + c.sync() + + def test_max(self): + a = jt.random((100,)).float16() + b = jt.random((100,)).float16() + c = a.maximum(b) + c.sync() + + def test_reduce_dtype_infer(self): + with jt.flag_scope(amp_reg=1): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=0): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2+4): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float16", b.dtype + + def test_white_dtype_infer(self): + with jt.flag_scope(amp_reg=1): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=0): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2+8): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float16", b.dtype + + + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class 
TestFP16CUDA(TestFP16): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + + def test_softmax(self): + a = jt.rand((120, 2000, 2000)).float16() + # a = jt.rand((1, 2000, 2000)).float32() + jt.sync_all() + with jt.profile_scope(10, 100): + a.log_softmax(-1).sync() + + def test_transpose(self): + check_share() + # return + a = jt.rand((30, 32, 4, 2000)).float32() + # a = jt.rand((1, 1024, 1, 2000)).float32() + diff = transpose0231(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + # return + jt.sync_all() + # with jt.profile_scope(100, 11000): + with jt.profile_scope(100, 11000): + # a.log_softmax(-1).sync() + transpose0231(a).sync() + + a.transpose((0,2,3,1)).sync() + # a.transpose((0,2,1,3)).sync() + a.fuse_transpose((0,2,1,3)).sync() + (a+1).sync() + jt.sync_all(True) + diff = transpose0231(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + np.testing.assert_allclose(transpose0231(a).data, a.transpose((0,2,3,1)).data) + + def test_transpose2(self): + # check_share() + # return + # a = jt.rand((30, 32, 4, 2000)).float32() + # a = jt.rand((1, 10000, 1, 2000)).float32() + a = jt.rand((1, 10000, 1, 2048)).float32() + print("transpose") + transpose0231_2(a).sync() + print("add") + (a+1).sync() + return + # a = jt.arange(32*16).reshape((1, 32, 1, 16)) + diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + # return + jt.sync_all() + # with jt.profile_scope(100, 11000): + with jt.profile_scope(100, 1100): + # a.log_softmax(-1).sync() + transpose0231_2(a).sync() + + a.transpose((0,2,3,1)).sync() + # a.transpose((0,2,1,3)).sync() + a.fuse_transpose((0,2,1,3)).sync() + (a+1).sync() + jt.sync_all(True) + diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + np.testing.assert_allclose(transpose0231_2(a).data, a.transpose((0,2,3,1)).data) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py index 144002f8..47f04762 100644 --- a/python/jittor/test/test_misc_op.py +++ b/python/jittor/test/test_misc_op.py @@ -75,6 +75,8 @@ class TestPad(unittest.TestCase): print('pass flip test ...') def test_cross(self): + def check_equal(a, b, tol): + np.testing.assert_allclose(a.detach().numpy(), b.numpy(), atol=1e-5) arr1 = np.random.randn(16,3,224,224,3) arr2 = np.random.randn(16,3,224,224,3) check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1) @@ -257,34 +259,51 @@ class TestOther(unittest.TestCase): a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1])) np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927]) + y = jt.random((100,)) + x = jt.random((100,)) + z = jt.arctan2(y, x) + z2 = np.arctan2(y.data, x.data) + np.testing.assert_allclose(z.data, z2) + def test_code_softmax(self): if not jt.has_cuda: return - def softmax(x, dim = None): + def softmax(x, dim = None, log=False): if dim is None: x = (x - x.max()).exp() ret = x / x.sum() else: x = (x-x.max(dim, keepdims=True)).exp() ret = x / x.sum(dim, keepdims=True) + if log: return ret.log() return ret from jittor.other.code_softmax import softmax_v1 with jt.flag_scope(use_cuda = 1): shape = (120, 2000, 2000) - # shape = (3,3) - a = jt.rand(shape) - c = jt.rand(shape) - b = softmax(a, -1) - bb = softmax_v1(a) + shape = (3,3) + for log in [0,1]: + for shape in [(3,3), + (12, 200, 2000), + (12, 200, 2048), + (12, 200, 2049)]: + print(shape) + a = 
jt.rand(shape) + c = jt.rand(shape) + b = softmax(a, -1, log=log) + bb = softmax_v1(a, log=log) - err = (bb - b).abs().max() - assert err.item() < 1e-5 + err = (bb - b).abs().max() + assert err.item() < 1e-5, (err, bb, b) - d1 = jt.grad(b*c, a) - d2 = jt.grad(bb*c, a) - err = (d1 - d2).abs().max() - assert err.item() < 1e-5 + d1 = jt.grad(b*c, a) + d2 = jt.grad(bb*c, a) + err = (d1 - d2).abs().max() + + if log: + assert err.item() < 1e-2, (err.item()) + else: + assert err.item() < 1e-5, (err.item()) if __name__ == "__main__": diff --git a/python/jittor/test/test_resnet.py b/python/jittor/test/test_resnet.py index 8defb633..8de9fc3f 100644 --- a/python/jittor/test/test_resnet.py +++ b/python/jittor/test/test_resnet.py @@ -36,19 +36,7 @@ class MnistNet(Module): return x @unittest.skipIf(skip_this_test, "skip_this_test") -class TestResnet(unittest.TestCase): - @classmethod - def setUpClass(self): - # hyper-parameters - self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100")) - self.weight_decay = 0.0001 - self.momentum = 0.9 - self.learning_rate = 0.1 - # mnist dataset - self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ - .set_attrs(batch_size=self.batch_size, shuffle=True) - self.train_loader.num_workers = 4 - +class TestResnetFp32(unittest.TestCase): # setup random seed def setup_seed(self, seed): np.random.seed(seed) @@ -59,6 +47,19 @@ class TestResnet(unittest.TestCase): @jt.flag_scope(use_cuda=1, use_stat_allocator=1) def test_resnet(self): self.setup_seed(1) + + # hyper-parameters + self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100")) + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + if jt.flags.amp_reg: + self.learning_rate = 0.01 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + loss_list=[] acc_list=[] mnist_net = MnistNet() @@ -70,6 +71,7 @@ class TestResnet(unittest.TestCase): for data, target in self.train_loader: batch_id = self.train_loader.batch_id epoch_id = self.train_loader.epoch_id + data = data.float_auto() # train step # with jt.log_capture_scope( @@ -120,6 +122,8 @@ class TestResnet(unittest.TestCase): # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000 # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000 + if jt.flags.amp_reg: + continue if jt.in_mpi: assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars() else: @@ -131,5 +135,14 @@ class TestResnet(unittest.TestCase): assert np.mean(loss_list[-50:])<0.5 assert np.mean(acc_list[-50:])>0.8 + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestResnetFp16(TestResnetFp32): + def setup(self): + jt.flags.auto_mixed_precision_level = 5 + + def tearDown(self): + jt.flags.auto_mixed_precision_level = 0 + if __name__ == "__main__": unittest.main() diff --git a/python/jittor/test/test_superglue.py b/python/jittor/test/test_superglue.py new file mode 100644 index 00000000..0a3c7a18 --- /dev/null +++ b/python/jittor/test/test_superglue.py @@ -0,0 +1,121 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +from jittor.test.misc import superglue +from jittor.test.misc.superglue import SuperGlue +import time + +@jt.flag_scope(use_cuda=1) +def main(): + global superglue + superglue.split_size = int(os.environ.get("split_size", "12")) + # superglue.split_size = 1000000 + + batch = 30 + num = 2000 + dim = 128 + + # jt.display_memory_info() + # os.system("nvidia-smi") + # breakpoint() + + with jt.no_grad(): + + config = { + 'superglue': { + 'sinkhorn_iterations': 25, + 'match_threshold': 0.01, + 'keypoint_position_dim': 2, + 'descriptor_dim': dim, + 'use_dual_softmax': True, + 'GNN_layers': ['self', 'cross'] * 9, + } + } + + superglue = SuperGlue(config.get('superglue', {})) + + superglue.eval() + + data = { + 'keypoints0': jt.rand((batch, num, 2), dtype=jt.float), + 'keypoints1': jt.rand((batch, num, 2), dtype=jt.float), + 'shape0': jt.rand((batch, 2), dtype=jt.float), + 'shape1': jt.rand((batch, 2), dtype=jt.float), + 'descriptors0': jt.rand((batch, dim, num), dtype=jt.float), + 'descriptors1': jt.rand((batch, dim, num), dtype=jt.float), + 'scores0': jt.rand((batch, num), dtype=jt.float), + 'scores1': jt.rand((batch, num), dtype=jt.float), + 'all_matches': jt.randint(0, num, (batch, num, 2), dtype=jt.int), + 'return_match': False, + # 'match_num': match_num + } + + use_fp16 = int(os.environ.get("use_fp16", "0")) + if use_fp16: + jt.flags.amp_reg = 2 + for k,v in data.items(): + if isinstance(v, jt.Var) and v.dtype == "float32": + v.assign(v.float16()) + for v in superglue.parameters(): + if v.dtype == "float32": + v.assign(v.float16()) + jt.sync_all(True) + + import pickle + jt.sync_all(True) + for x in range(5): + print(x) + jt.gc() + x = superglue(data)['loss'] + x.sync() + jt.display_memory_info() + # os.system("nvidia-smi") + # breakpoint() + # print(data) + # print(x) + + # with open("/tmp/record.pkl", "wb") as f: + # pickle.dump([data, x], f, pickle.HIGHEST_PROTOCOL) + + # with jt.flag_scope(trace_py_var=3, profile_memory_enable=1): + # x = superglue(data)['loss'] + # x.sync() + # jt.get_max_memory_treemap() + # exit(0) + + jt.sync_all(True) + time0 = time.time() + jt.flags.profiler_enable = int(os.environ.get("profiler", "0")) + + for x in range(20): + print(x) + # jt.display_memory_info() + x = superglue(data)['loss'] + x.sync() + # print(x) + + jt.sync_all(True) + time1 = time.time() + print("avg time:", (time1 - time0) / 20) + return (time1 - time0) / 20 + + +class TestSuperglue(unittest.TestCase): + def test(self): + if not jt.has_cuda: return + t1 = main() + os.environ["use_fp16"] = "1" + t2 = main() + os.environ["use_fp16"] = "0" + assert t1*0.55 > t2 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_unary_op.py b/python/jittor/test/test_unary_op.py index 24483dae..d7b8bc88 100644 --- a/python/jittor/test/test_unary_op.py +++ b/python/jittor/test/test_unary_op.py @@ -17,7 +17,8 @@ def check(op, *args): x = convert(x) y = convert(y) # str match nan and inf - assert x.dtype == y.dtype and x.shape == y.shape + assert x.dtype == y.dtype and x.shape == y.shape, \ + (x.dtype, y.dtype, x.shape, y.shape) for a,b in zip(x.flatten(), y.flatten()): assert str(a)[:5] == str(b)[:5], (a,b) @@ -32,9 +33,10 @@ class TestUnaryOp(unittest.TestCase): check("logical_not", a) check("bitwise_not", a) b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0]) - check("log", a.astype("float32")) - check("exp", 
a.astype("float32")) - check("sqrt", a.astype("float32")) + type = "float16" if (jt.flags.amp_reg & 2) else "float32" + check("log", a.astype(type)) + check("exp", a.astype(type)) + check("sqrt", a.astype(type)) def test_grad(self): ops = ["abs", "negative", "log", "exp", "sqrt", @@ -60,7 +62,8 @@ class TestUnaryOp(unittest.TestCase): ja = jt.array(b) jb = eval(f"jt.{op}(ja)") jda = jt.grad(jb, ja) - assert (np.allclose(jda.data, da)), (jda.data,da,op) + tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-6 + assert (np.allclose(jda.data, da, atol=tol, rtol=tol)), (jda.data,da,op) def test_sigmoid(self): a = np.arange(-150,150, 10).astype("float32") @@ -92,11 +95,26 @@ class TestUnaryOp(unittest.TestCase): np.testing.assert_allclose(y.data, y2.data) d = jt.grad(x2, y2) _, (dn,) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8) - np.testing.assert_allclose(d.data, dn, atol=1e-6, rtol=1e-6) + tol = 1e-3 if jt.flags.amp_reg & 2 else 1e-6 + np.testing.assert_allclose(d.data, dn, atol=tol, rtol=tol) class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)): pass +class TestUnaryOpCudaFp16(TestUnaryOp, test_cuda(2)): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + def tearDown(self): + jt.flags.amp_reg = 0 + +class TestUnaryOpCudaFp16(TestUnaryOp, test_cuda(2)): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.amp_reg = 0 + jt.flags.use_cuda = 0 + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/python/jittor/utils/data.gz b/python/jittor/utils/data.gz index 728482b2..f59ad084 100644 Binary files a/python/jittor/utils/data.gz and b/python/jittor/utils/data.gz differ