add fp16 support

Dun Liang 2022-03-15 17:45:39 +08:00
parent 7cf6165a10
commit 39ecdd84fd
49 changed files with 2124 additions and 290 deletions

View File

@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.3.1.38'
+__version__ = '1.3.1.38.1'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -304,15 +304,52 @@ Var.cast = Var.cast
 def array(data, dtype=None):
     if isinstance(data, core.Var):
         if dtype is None:
-            return data.clone()
-        return cast(data, dtype)
-    if dtype is not None:
+            ret = data.clone()
+        else:
+            ret = cast(data, dtype)
+    elif dtype is not None:
         if isinstance(dtype, NanoString):
             dtype = str(dtype)
         elif callable(dtype):
             dtype = dtype.__name__
-        return ops.array(np.array(data, dtype))
-    return ops.array(data)
+        ret = ops.array(np.array(data, dtype))
+    else:
+        ret = ops.array(data)
+    # TODO: move those code to core
+    amp_reg = jt.flags.amp_reg
+    if amp_reg and ret.numel() != 1 and ret.dtype.is_float():
+        if amp_reg & 16:
+            if amp_reg & 1:
+                if ret.dtype != "float32":
+                    return ret.float32()
+            elif amp_reg & 2:
+                if ret.dtype != "float16":
+                    return ret.float16()
+    return ret
+
+def random(shape, dtype="float32", type="uniform"):
+    # TODO: move those code to core
+    if dtype == "float16":
+        # TODO: make curand support fp16
+        ret = ops.random(shape, "float32", type).float16()
+    else:
+        ret = ops.random(shape, dtype, type)
+    amp_reg = jt.flags.amp_reg
+    if amp_reg:
+        if amp_reg & 16:
+            if amp_reg & 1:
+                if ret.dtype != "float32":
+                    return ret.float32()
+            elif amp_reg & 2:
+                if ret.dtype != "float16":
+                    return ret.float16()
+    return ret
+
+def float_auto(x):
+    if jt.flags.amp_reg & 2:
+        return x.float16()
+    return x.float32()
+Var.float_auto = float_auto
 def array64(data, dtype=None):
     with jt.flag_scope(auto_convert_64_to_32=0):
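Note: the amp_reg bits tested above (1: prefer fp32, 2: prefer fp16, 16: array-like ops prefer too) mean that array creation and random sampling only change dtype when the array-prefer bit is set together with a prefer bit. A minimal sketch of the intended behaviour, assuming the auto_mixed_precision_level setter added later in this commit (level 5 sets amp_prefer16 | amp_array_prefer, level 4 only amp_prefer16):

import numpy as np
import jittor as jt

with jt.flag_scope(auto_mixed_precision_level=5):
    x = jt.array(np.random.rand(4, 4).astype("float32"))
    r = jt.random((4, 4))
    print(x.dtype, r.dtype)   # expected: float16 float16 (bits 16 and 2 set)

with jt.flag_scope(auto_mixed_precision_level=4):
    y = jt.array(np.random.rand(4, 4).astype("float32"))
    print(y.dtype)            # expected: float32 (bit 16 not set at level 4)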
@@ -1419,18 +1456,15 @@ Var.size = size
 def to_int(v):
-    dtype = str(v.dtype)
-    assert dtype.startswith("int")
+    assert v.dtype.is_int()
     return v.item()
 def to_float(v):
-    dtype = str(v.dtype)
-    assert dtype.startswith("float")
+    assert v.dtype.is_float()
     return v.item()
 def to_bool(v):
-    dtype = str(v.dtype)
-    assert dtype.startswith("int") or dtype=="bool"
+    assert v.dtype.is_int() or v.dtype.is_bool()
     return ori_bool(v.item())
 Var.__int__ = to_int

View File

@@ -210,6 +210,12 @@ def setup_cuda_extern():
         LOG.w(f"CUDA found but cub is not loaded:\n{line}")
     libs = ["cublas", "cudnn", "curand"]
+    # in cuda 11.4, module memory comsumptions:
+    # default context: 259 MB
+    # cublas: 340 MB
+    # cudnn: 340 MB
+    if int(os.environ.get("conv_opt", "0")):
+        libs = ["cublas", "curand"]
     for lib_name in libs:
         try:
             setup_cuda_lib(lib_name, extra_flags=link_cuda_extern)
@@ -309,22 +315,27 @@ def install_cutt(root_folder):
         if md5 != true_md5:
             os.remove(fullname)
             shutil.rmtree(dirname)
-    if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)):
-        LOG.i("Downloading cutt...")
-        download_url_to_local(url, filename, root_folder, true_md5)
+    CUTT_PATH = os.environ.get("CUTT_PATH", "")
+    if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)) or CUTT_PATH:
+        if CUTT_PATH:
+            dirname = CUTT_PATH
+        else:
+            LOG.i("Downloading cutt...")
+            download_url_to_local(url, filename, root_folder, true_md5)
         import zipfile
         zf = zipfile.ZipFile(fullname)
         try:
            zf.extractall(path=root_folder)
         except RuntimeError as e:
            print(e)
            raise
         zf.close()
         LOG.i("installing cutt...")
-    arch_flag = ""
+    # -Xptxas -dlcm=ca actually not work
+    arch_flag = " -Xptxas -dlcm=ca "
     if len(flags.cuda_archs):
         arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
         arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))

View File

@@ -23,8 +23,8 @@ EXTERN_LIB cublasHandle_t cublas_handle;
 static inline cudaDataType get_dtype(NanoString dtype) {
     if (dtype == ns_float32) return CUDA_R_32F;
-    // if (dtype == ns_float64) return CUDA_R_64F;
-    // if (dtype == ns_float16) return CUDA_R_16F;
+    if (dtype == ns_float64) return CUDA_R_64F;
+    if (dtype == ns_float16) return CUDA_R_16F;
     LOGf << "not support type" << dtype;
     return CUDA_R_32F;
 }

View File

@@ -124,6 +124,10 @@ void CublasBatchedMatmulOp::jit_run() {
     if (use_tensorcore) {
         computeType = CUBLAS_COMPUTE_32F_FAST_16F;
     }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUBLAS_COMPUTE_16F;
+    }
     checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
         CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
         k, n, m, &alpha,

View File

@@ -81,6 +81,10 @@ void CublasMatmulOp::jit_run() {
     if (use_tensorcore) {
         computeType = CUBLAS_COMPUTE_32F_FAST_16F;
     }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUBLAS_COMPUTE_16F;
+    }
     checkCudaErrors(cublasGemmEx(handle_,
         CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
         k, n, m, &alpha,

View File

@@ -174,6 +174,11 @@ void CudnnConvOp::jit_run() {
     if(use_tensorcore){
         checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
     }
+    if (x->dtype() == ns_float16
+        || y->dtype() == ns_float16 || w->dtype() == ns_float16) {
+        checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+    }
     int dimY[] = {
         (int)y->shape[findc("@YFORMAT", 'a')], // n

View File

@@ -488,18 +488,11 @@ def arctan2(y,x):
     angle = jt.zeros(x.shape,dtype=x.dtype)
     x = (x!=0.0).ternary(x, x+1e-30)
     angle = (y/x).arctan()
-    mask = (y<0) & (x<0)
-    if angle[mask].numel()>0:
-        angle[mask] -= np.pi
-    mask = (y>=0) &(x<0)
-    if angle[mask].numel()>0:
-        angle[mask] +=np.pi
+    mask = y<0 | ((y==0) & (x<0))
+    angle = angle + mask*np.pi
     return angle
 
 def nonzero(x):
     r'''
     Return the index of the elements of input tensor which are not equal to zero.

View File

@@ -143,7 +143,7 @@ class ResNet(nn.Module):
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
-        x = self.avgpool(x)
+        x = self.avgpool(x).float_auto()
         x = jt.reshape(x, (x.shape[0], -1))
         x = self.fc(x)
         return x
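Note: float_auto (added in __init__.py above) casts a Var to whichever float type the current amp setting prefers, so the pooled features always match the dtype of the fc weights. A small sketch of the expected behaviour, assuming the amp_reg flag introduced by this commit:

import jittor as jt

x = jt.random((2, 512))
with jt.flag_scope(amp_reg=0):
    print(x.float_auto().dtype)   # float32 when mixed precision is off
with jt.flag_scope(amp_reg=2):    # amp_prefer16
    print(x.float_auto().dtype)   # float16 when fp16 is preferred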

View File

@@ -37,9 +37,10 @@ def matmul_transpose(a, b):
     assert len(a.shape) == 2 and len(b.shape) == 2
     shape = list(a.shape)[:-1] + list(b.shape)
-    a = a.broadcast(shape, [len(shape)-2])
-    b = b.broadcast(shape)
-    return (a*b).sum(len(shape)-1)
+    with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
+        a = a.broadcast(shape, [len(shape)-2])
+        b = b.broadcast(shape)
+        return (a*b).sum(len(shape)-1)
 
 def bmm_transpose(a, b):
@@ -108,47 +109,48 @@ Example::
        c = jt.matmul(a, b)
        assert c.shape == [8, 10, 3, 5]
    '''
    with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
        len_a = len(a.shape)
        len_b = len(b.shape)
        if len_b == 1:
            # a: [n, m], b:[m], c:[n]
            return (a*b).sum(-1)
        if len_a == 1:
            # a: [n], b:[n,k], c:[k]
            return (a.broadcast(b, [-1]) * b).sum(0)
        if len_a>=3 and len_a==len_b:
            # bmm
            # a: [..., n, m], b: [..., m, k], c:[..., n, k]
            if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
                return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
        shape = []
        len_c = max(len_a, len_b)
        (n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
        assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
        # a: [..., n, m]
        # b: [..., m, k]
        # cc:[..., n, m, k]
        # -->
        # 012
        if len_b == 2 and len_a>2:
            # TODO:ugly implementation for tuner
            aa = a.reshape((-1, m))
            cc = matmul(aa, b)
            # print(a.shape, b.shape, cc.shape)
            return cc.reshape(a.shape[:-1] + [k])
        for i in range(len_c-2):
            ai = len_a-(len_c-i)
            bi = len_b-(len_c-i)
            an = a.shape[ai] if ai>=0 else 1
            bn = b.shape[bi] if bi>=0 else 1
            if an!=1 and bn!=1:
                assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
            cn = max(an, bn)
            shape.append(cn)
        shape.extend([n, m, k])
        a = a.broadcast(shape, [-1])
        b = b.broadcast(shape, [-3])
        return (a*b).sum(-2)
jt.Var.matmul = jt.Var.__matmul__ = matmul
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
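Note: matmul now runs its broadcast-and-sum fallback under amp_reg | 4 (amp_keep_reduce), so under mixed precision the internal reduction stays in fp16 instead of being widened back to fp32. A rough sketch of the intended effect, assuming the fallback path is taken:

import jittor as jt

a = jt.random((8, 16)).float16()
b = jt.random((16, 4)).float16()
with jt.flag_scope(auto_mixed_precision_level=4):
    c = jt.matmul(a, b)
print(c.dtype)   # expected: float16, since the sum runs with amp_keep_reduce set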
@@ -488,22 +490,22 @@ class BCEWithLogitsLoss(Module):
     def execute(self, output, target):
         return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average)
 
-def softmax(x, dim = None):
+def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
     if code_softmax.can_softmax_v1(x, dim):
-        return code_softmax.softmax_v1(x)
+        return code_softmax.softmax_v1(x, log)
     if dim is None:
         x = (x - x.max()).exp()
         ret = x / x.sum()
     else:
         x = (x-x.max(dim, keepdims=True)).exp()
         ret = x / x.sum(dim, keepdims=True)
+    if log: return ret.log()
     return ret
 jt.Var.softmax = softmax
 
 def log_softmax(x,dim=None):
-    x = softmax(x,dim=dim)
-    return jt.log(x)
+    return softmax(x,dim=dim, log=True)
 jt.Var.log_softmax = log_softmax
 
 def log_sigmoid(x):
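Note: with the new log argument, log_softmax becomes a thin wrapper over softmax, and the fused CUDA kernel in code_softmax can produce the log result directly. A quick equivalence sketch (shapes are illustrative):

import jittor as jt
from jittor import nn

x = jt.random((4, 10))
y1 = nn.log_softmax(x, dim=-1)
y2 = nn.softmax(x, dim=-1, log=True)   # same computation as log_softmax now
print(jt.abs(y1 - y2).max())           # expected: 0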
@@ -832,15 +834,16 @@ class Conv(Module):
             oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
             ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
             assert oh>0 and ow>0
-            xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
-                'i0', # Nid
-                'i2', # Cid
-                f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
-                f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
-            ])
-            ww = self.weight.broadcast(xx.shape, [0,3,4])
-            yy = xx*ww
-            y = yy.sum([2,5,6]) # Kc, Kh, Kw
+            with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
+                xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
+                    'i0', # Nid
+                    'i2', # Cid
+                    f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
+                    f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
+                ])
+                ww = self.weight.broadcast(xx.shape, [0,3,4])
+                yy = xx*ww
+                y = yy.sum([2,5,6]) # Kc, Kh, Kw
             if self.bias is not None:
                 b = self.bias.broadcast(y.shape, [0,2,3])
                 y = y + b
@@ -1008,6 +1011,18 @@ class Conv3d(Module):
     def execute(self, x):
         return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
 
+class Conv1d_sp(Linear):
+    def __init__(self, inchannels, outchannels, kernel_size=1, bias=True):
+        super().__init__(inchannels, outchannels, bias=bias)
+        assert kernel_size == 1
+
+    def execute(self, x):
+        x = x.transpose(0, 2, 1)
+        x = super().execute(x)
+        x = x.transpose(0, 2, 1)
+        return x
+
 def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     ''' Applies a 2D convolution over an input signal composed of several input planes.
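Note: Conv1d_sp is a pointwise (kernel_size 1) 1-D convolution implemented as a Linear over the channel axis. A usage sketch with assumed shapes:

import jittor as jt
from jittor import nn

m = nn.Conv1d_sp(16, 32)        # kernel_size must be 1
x = jt.random((4, 16, 100))     # (N, C_in, L)
y = m(x)                        # Linear applied per position via two transposes
print(y.shape)                  # expected: [4, 32, 100]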
@@ -1048,15 +1063,16 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     Kh, Kw = weight.shape[-2:]
     oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
     ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
-    xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
-        'i0', # Nid
-        'i2', # Cid
-        f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
-        f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
-    ])
-    ww = weight.broadcast(xx.shape, [0,3,4])
-    yy = xx*ww
-    y = yy.sum([2,5,6]) # Kc, Kh, Kw
+    with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
+        xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
+            'i0', # Nid
+            'i2', # Cid
+            f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
+            f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
+        ])
+        ww = weight.broadcast(xx.shape, [0,3,4])
+        yy = xx*ww
+        y = yy.sum([2,5,6]) # Kc, Kh, Kw
     if bias is not None:
         b = bias.broadcast(y.shape, [0,2,3])
         y = y + b

View File

@@ -10,32 +10,48 @@ def can_softmax_v1(a, dim):
             return False
     return True
 
-def softmax_v1(a):
+def softmax_v1(a, log=False):
     assert can_softmax_v1(a, -1)
     length = a.shape[-1]
     # tnum = 1024
     tnum = 500 if length % 500 == 0 else 512
+    tnum = 125 if length % 125 == 0 else 128
+    # tnum = 125
+    # tnum = 1000 if length % 1000 == 0 else 1024
     # tnum = 250
     per_thread = (length-1) // tnum + 1
+    ILP = 1
+    for ilp in [8,4,2]:
+        if length % tnum == 0 and per_thread % ilp == 0:
+            ILP = ilp
+            per_thread //= ILP
+            break
     for_loop = f"""
    #pragma unroll
    for (int i=0; i<{per_thread}; i++)
    """
-    if length % tnum == 0:
-        for_loop += f"if (i*{tnum}+threadIdx.x < len)\n"
+    if length % tnum != 0:
+        for_loop += f"if ((i*{tnum}+threadIdx.x)*{ILP} < len)\n"
     return jt.code(a.shape, a.dtype, [a], cuda_header=f'''
 #include <{jt.compile_extern.cub_home}cub/cub.cuh>
+#include <type/fp16_compute.h>
 ''', cuda_src=f'''
 __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
     int id = blockIdx.x * len;
-    in0_type v[{per_thread}];
-    {for_loop} v[i] = x[id+i*{tnum}+threadIdx.x];
-    float v1 = v[0];
-    {for_loop} v1 = max(v1, v[i]);
+    in0_type v[{per_thread}][{ILP}];
+    {for_loop}
+        vload<sizeof(in0_type)*{ILP}>(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
+        // v[i] = x[id+i*{tnum}+threadIdx.x];
+    float v1 = -1e30;
+    {for_loop}
+        #pragma unroll
+        for (int j=0; j<{ILP}; j++) {{
+            v1 = max(v1, float(v[i][j]));
+        }}
     __shared__ float vmax;
     auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
     if (threadIdx.x == 0)
@@ -43,10 +59,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     __syncthreads();
     v1 = 0;
-    {for_loop} {{
-        v[i] = expf(v[i] - vmax);
-        v1 += v[i];
-    }}
+    {for_loop}
+        #pragma unroll
+        for (int j=0; j<{ILP}; j++) {{
+            v[i][j] = expf(float(v[i][j]) - vmax);
+            v1 += float(v[i][j]);
+        }}
     tmp = BlockReduce(temp_storage).Sum(v1);
     __shared__ float vsum;
@@ -54,7 +72,15 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
         vsum = tmp;
     __syncthreads();
-    {for_loop} y[id+i*{tnum}+threadIdx.x] = v[i] / vsum;
+    {for_loop}
+        #pragma unroll
+        for (int j=0; j<{ILP}; j++)
+            v[i][j] = {
+                "@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
+                else "float(v[i][j])/vsum"
+            };
+    {for_loop}
+        vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
 }}
 int len = in0->shape[in0->shape.size()-1];
 int bnum = in0->numel() / len;
@@ -64,15 +90,17 @@ CHECK(0 == cudaGetLastError());
 ''', cuda_grad_src=[f"""
 __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
     int id = blockIdx.x * len;
-    in0_type vx[{per_thread}];
-    in0_type vy[{per_thread}];
+    in0_type vx[{per_thread}][{ILP}];
+    in0_type vy[{per_thread}][{ILP}];
     {for_loop} {{
-        vx[i] = x[id+i*{tnum}+threadIdx.x];
-        vy[i] = y[id+i*{tnum}+threadIdx.x];
+        vload<sizeof(in0_type)*{ILP}>(vx[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
+        vload<sizeof(in0_type)*{ILP}>(vy[i], &y[id+(i*{tnum}+threadIdx.x)*{ILP}]);
     }}
     float v1 = 0;
-    {for_loop} v1 += vx[i]*vy[i];
+    {for_loop}
+        #pragma unroll
+        for (int j=0; j<{ILP}; j++)
+            v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"}
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -83,7 +111,16 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
     __syncthreads();
     {for_loop}
-        z[id+i*{tnum}+threadIdx.x] = vx[i] * (vy[i] - reduce_var);
+        #pragma unroll
+        for (int j=0; j<{ILP}; j++)
+            vx[i][j] = {
+                "vy[i][j] - expf(vx[i][j]) * reduce_var;" if log
+                else "vx[i][j] * (vy[i][j] - reduce_var);"
+            }
+    {for_loop}
+        vload<sizeof(in0_type)*{ILP}>(&z[id+(i*{tnum}+threadIdx.x)*{ILP}],
+            vx[i]);
 }}
 int len = in0->shape[in0->shape.size()-1];
 int bnum = in0->numel() / len;

View File

@@ -120,8 +120,8 @@ class Pool(Module):
                 for (int i2 = p2; i2 < out_shape2; i2 += s2)
                     {{ {forward_body} }}
         }}
-        int tx = min(1024, out_shape3);
-        int ty = min(1024 / tx, out_shape2);
+        int tx = std::min(1024, out_shape3);
+        int ty = std::min(1024 / tx, out_shape2);
         int bx = (out_shape2 - 1) / ty + 1;
         int by = out_shape1;
         int bz = out_shape0;
@@ -143,8 +143,8 @@ class Pool(Module):
                     {{ {backward_body} }}
         }}
         cudaMemsetAsync(out_p, 0, out->size);
-        int tx = min(1024, pout_shape3);
-        int ty = min(1024 / tx, pout_shape2);
+        int tx = std::min(1024, pout_shape3);
+        int ty = std::min(1024 / tx, pout_shape2);
         int bx = (pout_shape2 - 1) / ty + 1;
         int by = pout_shape1;
         int bz = pout_shape0;
@@ -310,9 +310,9 @@ class Pool3d(Module):
                 for (int i2 = p2; i2 < out_shape2; i2 += s2)
                     {{ {forward_body} }}
         }}
-        int tx = min(1024, out_shape4);
-        int ty = min(1024 / tx, out_shape3);
-        int tz = min(1024 / tx / ty, out_shape2);
+        int tx = std::min(1024, out_shape4);
+        int ty = std::min(1024 / tx, out_shape3);
+        int tz = std::min(1024 / tx / ty, out_shape2);
         int bx = (out_shape2 - 1) / tz + 1;
         int by = out_shape1;
         int bz = out_shape0;
@@ -337,9 +337,9 @@ class Pool3d(Module):
                     {{ {backward_body} }}
         }}
        cudaMemsetAsync(out_p, 0, out->size);
-        int tx = min(1024, pout_shape4);
-        int ty = min(1024 / tx, pout_shape3);
-        int tz = min(1024 / tx / ty, pout_shape2);
+        int tx = std::min(1024, pout_shape4);
+        int ty = std::min(1024 / tx, pout_shape3);
+        int tz = std::min(1024 / tx / ty, pout_shape2);
         int bx = (pout_shape2 - 1) / tz + 1;
         int by = pout_shape1;
         int bz = pout_shape0;

View File

@@ -39,11 +39,24 @@ template<class T> struct StackIniter {
 #define STACK_ALLOC2(T, a, n) T a[n]
 #endif
 
+struct AmpGradGuard {
+    int amp_reg_bk;
+    AmpGradGuard(Op* op) {
+        amp_reg_bk = amp_reg;
+        amp_reg |= (op->flags.flags >> NodeFlags::_prefer_32);
+    }
+
+    ~AmpGradGuard() {
+        amp_reg = amp_reg_bk;
+    }
+};
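Note: AmpGradGuard temporarily restores the amp bits that were recorded on the forward op when it was created (Op::Op() in op.cc below stores amp_reg & 7 in the node flags), so gradient ops are built under the same precision policy as the forward pass even if the flags changed in between. A rough Python-level sketch, assuming the dtype-inference rules introduced in this commit:

import jittor as jt

with jt.flag_scope(auto_mixed_precision_level=4):
    x = jt.random((4, 4)).float16()
    y = (x * x).sum()
# the amp level is back to 0 here, but the grad ops for the fp16 multiply
# are still constructed under the amp bits recorded on that op
g = jt.grad(y, x)
print(g.dtype)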
 VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
     if (dout == nullptr) return nullptr;
     if (x_index<0) return nullptr;
     LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
         << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
+    AmpGradGuard agg(op);
     auto dx = op->grad(out, dout, x, x_index);
     if (x->loop_options)
         dx->loop_options = x->loop_options;
@@ -182,7 +195,10 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
             douts[i] = nullptr;
         }
         trace_grad_op = op;
-        op->grads(douts, dins);
+        {
+            AmpGradGuard agg(op);
+            op->grads(douts, dins);
+        }
         // dump "for (Var* in : op->inputs())"
         for (int i=0; i<n_i; i++,j++) {
             auto id = id_buffer[j].second;

View File

@@ -167,7 +167,7 @@ inline JK& operator<<(JK& jk, int64 c) {
 }
 
 #ifdef __linux__
-inline JK& operator<<(JK& jk, long long int c) {
+inline JK& operator<<(JK& jk, int64_t c) {
     return jk << (int64)c;
 }
 #endif

View File

@@ -13,7 +13,8 @@ namespace jittor {
 struct Deleter {
     std::function<void()> del;
     inline Deleter(std::function<void()>&& func) : del(move(func)) {}
-    inline ~Deleter() { del(); }
+    inline Deleter() {}
+    inline ~Deleter() { if (del) del(); }
 };
 
 } // jittor

View File

@@ -9,6 +9,17 @@
 namespace jittor {
 
+DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
+DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
+
+void setter_auto_mixed_precision_level(int value) {
+    if (value <= 3) amp_reg = 0; else
+    if (value == 4) amp_reg = amp_prefer16; else
+    if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
+    if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
+}
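Note: the setter maps the user-facing optimization level onto the amp_reg bit mask defined in nano_string.h. A small sketch of driving it from Python, assuming the flag is exposed as jt.flags.auto_mixed_precision_level:

import jittor as jt

jt.flags.auto_mixed_precision_level = 4
print(jt.flags.amp_reg)    # expected: 2  (amp_prefer16)
jt.flags.auto_mixed_precision_level = 5
print(jt.flags.amp_reg)    # expected: 18 (amp_prefer16 | amp_array_prefer)
jt.flags.auto_mixed_precision_level = 0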
 #define FOR_ALL_TYPES(m) \
     m(bool) \
     m(int8) \
@@ -89,15 +100,18 @@ static unordered_set<string> unary_ops = {
     "erfinv"
 };
 
-static unordered_set<string> unary_float_ops = {
+static unordered_set<string> float_ops = {
     "log",
     "exp",
     "sqrt",
+    "mean",
+    "divide",
 };
-static unordered_set<string> unary_int_ops = {
+static unordered_set<string> int_ops = {
     "round_int",
     "floor_int",
     "ceil_int",
+    "floor_divide",
 };
 
 static unordered_set<string> binary_ops = {
@@ -127,6 +141,13 @@ static unordered_set<string> binary_ops = {
     "mean",
 };
 
+static unordered_set<string> white_ops = {
+    // "log",
+    "exp",
+    "pow",
+};
+
 #define DEFINE_NS(T) NanoString ns_##T;
 FOR_ALL_NS(DEFINE_NS);
@@ -135,6 +156,9 @@ char __ns_to_string[ns_max_size*ns_max_len];
 int __ns_len[ns_max_size];
 
 static void init_ns() {
+    dsize_map["float16"] = 1;
+    is_float_map["float16"] = 1;
+    is_unsigned["float16"] = 0;
     NanoString::ns_t i=0;
     auto func = [&](const char* name, NanoString& ns) {
         ns.set(NanoString::_index, i++, NanoString::_index_nbits);
@@ -149,13 +173,16 @@ static void init_ns() {
         if (unary_ops.count(name)) {
             ns.set(NanoString::_type, NanoString::_unary, NanoString::_type_nbits);
             ns.set(NanoString::_bool, is_bool.count(name));
-            ns.set(NanoString::_int, unary_int_ops.count(name));
-            ns.set(NanoString::_float, unary_float_ops.count(name));
+            ns.set(NanoString::_int, int_ops.count(name));
+            ns.set(NanoString::_float, float_ops.count(name));
         } else
         if (binary_ops.count(name)) {
             ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits);
             ns.set(NanoString::_bool, is_bool.count(name));
+            ns.set(NanoString::_int, int_ops.count(name));
+            ns.set(NanoString::_float, float_ops.count(name));
         }
+        ns.set(NanoString::_white_list, white_ops.count(name));
         __string_to_ns[name] = ns;
         auto name2 = ns.to_cstring();
         int len=0;

View File

@@ -24,6 +24,7 @@ constexpr int ns_max_len = 16;
     m(uint16) \
     m(uint32) \
     m(uint64) \
+    m(float16) \
     m(float32) \
     m(float64) \
     \
@@ -100,7 +101,7 @@ struct NanoString {
     typedef uint16 ns_t;
     enum Flags {
         // bit0~7: index
-        _index=0, _index_nbits=8,
+        _index=0, _index_nbits=7,
         _n=_index_nbits,
         // bit0-1: type
@@ -116,6 +117,8 @@ struct NanoString {
         _float=_n+5,
         // bit6-7: dsize(1,2,4,8 byte)
         _dsize=_n+6, _dsize_nbits=2,
+        // bit8: white list
+        _white_list=_n+8,
     };
     ns_t data=0;
@@ -130,11 +133,16 @@ struct NanoString {
     inline ns_t index() const { return get(_index, _index_nbits); }
     inline int len() const { return __ns_len[index()]; }
     inline ns_t type() const { return get(_type, _type_nbits); }
-    inline ns_t is_bool() const { return get(_bool); }
-    inline ns_t is_int() const { return get(_int); }
-    inline ns_t is_unsigned() const { return get(_unsigned); }
-    inline ns_t is_float() const { return get(_float); }
+    // @pyjt(is_bool)
+    inline bool is_bool() const { return get(_bool); }
+    // @pyjt(is_int)
+    inline bool is_int() const { return get(_int); }
+    inline bool is_unsigned() const { return get(_unsigned); }
+    // @pyjt(is_float)
+    inline bool is_float() const { return get(_float); }
+    inline ns_t is_white() const { return get(_white_list); }
     inline ns_t dsize() const { return 1<<get(_dsize, _dsize_nbits); }
+    inline ns_t dsize_() const { return get(_dsize, _dsize_nbits); }
     inline ns_t is_dtype() const { return get(_type, _type_nbits)==_dtype; }
     inline ns_t is_binary() const { return get(_type, _type_nbits)==_binary; }
     inline ns_t is_unary() const { return get(_type, _type_nbits)==_unary; }
@@ -156,28 +164,6 @@ struct NanoString {
         { return __ns_to_string+index()*ns_max_len; }
 };
 
-// force_type = 1 for int, 2 for float
-inline
-NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0, NanoString op=ns_void) {
-    bool is_float = v1.is_float() || v2.is_float();
-    int dsize = std::max(v1.dsize(), v2.dsize());
-    if (force_type == 1)
-        is_float = false;
-    else if (force_type == 2)
-        is_float = true;
-    if (is_float) {
-        if (dsize==4) return ns_float32;
-        return ns_float64;
-    } else {
-        if (dsize==8) return ns_int64;
-        if (dsize==4) return ns_int32;
-        if (dsize==2) return ns_int16;
-        if (op.data == ns_add.data || op.data == ns_subtract.data)
-            return ns_int8;
-        return v1;
-    }
-}
-
 // @pyjt(NanoString.__eq__)
 inline bool eq(const NanoString& a, const NanoString& b) {
     return a.data == b.data;
EXTERN_LIB int amp_reg;
constexpr int amp_prefer32 = 1;
constexpr int amp_prefer16 = 2;
constexpr int amp_keep_reduce = 4;
constexpr int amp_keep_white = 8;
constexpr int amp_array_prefer = 16;
inline NanoString float_dtype(int dsize_) {
if (amp_reg & amp_prefer32) return ns_float32;
if (amp_reg & amp_prefer16) return ns_float16;
return (dsize_ == 3) ? ns_float64 :
(dsize_ == 2 ) ? ns_float32 : ns_float16;
}
inline NanoString int_dtype(int dsize_) {
return (dsize_ == 3) ? ns_int64 :
(dsize_ == 2) ? ns_int32 :
(dsize_ == 1) ? ns_int16 : ns_int8;
}
inline NanoString dtype_infer(NanoString x, NanoString y) {
int dsize_ = std::max(x.dsize_(), y.dsize_());
bool is_float = x.is_float() || y.is_float();
if (is_float)
return float_dtype(dsize_);
else {
return int_dtype(dsize_);
}
}
inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y) {
if (op.is_bool()) return ns_bool;
int dsize_ = std::max(x.dsize_(), y.dsize_());
bool is_float = !op.is_int() &&
(x.is_float() || y.is_float() || op.is_float());
if (is_float) {
if (op.is_white() && !(amp_reg & amp_keep_white))
return (dsize_ == 3) ? ns_float64 : ns_float32;
return float_dtype(dsize_);
} else {
return int_dtype(dsize_);
}
}
inline NanoString unary_dtype_infer(NanoString op, NanoString x) {
if (op.is_bool()) return ns_bool;
int dsize_ = x.dsize_();
if (op.is_float()) {
if (op.is_white() && !(amp_reg & amp_keep_white))
return (dsize_ == 3) ? ns_float64 : ns_float32;
return float_dtype(dsize_);
}
if (op.is_int()) return int_dtype(dsize_);
return x;
}
inline NanoString reduce_dtype_infer(NanoString op, NanoString x) {
bool is_float = x.is_float() || op.is_float();
int dsize_ = x.dsize_();
if (is_float) {
if (amp_reg & amp_keep_reduce)
return float_dtype(dsize_);
return (dsize_ == 3) ? ns_float64 : ns_float32;
} else {
return x;
}
}
}
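Note: the inference helpers above decide when an op may stay in fp16 and when it is widened: white-listed ops (exp, pow) and reductions fall back to fp32 unless amp_keep_white / amp_keep_reduce are set. A sketch of the expected dtypes from the Python side, assuming levels 4 and 6 as defined by the setter in nano_string.cc:

import jittor as jt

a = jt.random((3, 3)).float16()
with jt.flag_scope(auto_mixed_precision_level=4):
    print(a.exp().dtype, a.sum().dtype)   # expected: float32 float32
with jt.flag_scope(auto_mixed_precision_level=6):
    print(a.exp().dtype, a.sum().dtype)   # expected: float16 float16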

View File

@@ -51,8 +51,14 @@ struct NodeFlags {
         _grads=_n+6,
         // bit7: has graph optimize
         _has_gopt=_n+7,
-        // bit7: has vary input
+        // bit8: has vary input
         _has_vary_input=_n+8,
+        // bit9: prefer 32 bit
+        _prefer_32=_n+9,
+        // bit10: force 16 bit
+        _prefer_16=_n+10,
+        // bit11: reduce keep type unchange
+        _reduce_keep=_n+11,
     };
 
     inline void set(Flags f, int a=1, int nbits=1) {
@@ -90,7 +96,7 @@ struct Node {
         operator Var*() { return (Var*)node; }
         operator var_output_t() { return {(Op*)node, index}; }
     };
-    static int64_t tflag_count;
+    static int64 tflag_count;
     NodeFlags flags;
     NanoString ns;
     inline bool is_var() const { return flags.get(NodeFlags::_var); }

View File

@@ -25,11 +25,12 @@ DEFINE_FLAG(int, try_use_32bit_index, 0,
 string_view_map<jit_op_entry_t> jit_ops;
 string_view_map<string> jit_key_mapper;
 
-int64_t Op::number_of_lived_ops = 0;
+int64 Op::number_of_lived_ops = 0;
 
 Op::Op() {
     flags.set(NodeFlags::_var, 0);
     flags.set(NodeFlags::_cpu, 1);
+    flags.flags |= ((amp_reg & 7) << NodeFlags::_prefer_32);
     number_of_lived_ops++;
     if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
 }

View File

@@ -15,7 +15,7 @@ namespace jittor {
 enum OpType {other=0, element=1, broadcast=2, reduce=3};
 struct Op : Node {
     vector<VarPtr> outputs_holder;
-    static int64_t number_of_lived_ops;
+    static int64 number_of_lived_ops;
     inline Caster<Var*, Node::input_t> inputs() { CHECK_EXIST; return &_inputs; }
     inline Caster<Var*, Node::output_t> outputs() { CHECK_EXIST; return &_outputs; }

View File

@@ -112,7 +112,7 @@ int OpCompiler::total_member_count() {
     return member_count;
 }
 
-int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
+int64 OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
     if (expr.find("@") != string::npos) {
         string new_expr;
         for (size_t i=0; i<expr.size(); i++) {

View File

@@ -418,21 +418,13 @@ unordered_set<string> binary_ops = {
     "bitwise_xor",
 };
 
-NanoString binary_dtype_infer(NanoString op, Var* x, Var* y) {
-    if (op == ns_mean) return dtype_infer(x->ns, y->ns, 2); // force float
-    int force_type=0;
-    if (op == ns_divide) force_type=2; // force float
-    if (op == ns_floor_divide) force_type=1; // force int
-    return op.is_bool() ? ns_bool : dtype_infer(x->ns, y->ns, force_type, op);
-}
-
 BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::element);
     ns = op;
     ASSERT(ns.is_binary());
-    z = create_output(nullptr, binary_dtype_infer(op, x, y));
+    z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns));
 }
 
 VarPtr dirty_clone_broadcast(Var* v) {

View File

@@ -32,9 +32,11 @@ void op_registe(const OpInfo& op_info);
 bool has_op(const string& name);
 OpInfo get_op_info(const string& name);
 
+struct OpCompiler;
 struct OpByType {
     unordered_set<string> types;
     virtual string expand_op(const vector<string>& args) = 0;
+    virtual void post_pass(OpCompiler*) = 0;
 };
 
 extern vector<OpByType*> op_types;

View File

@@ -271,7 +271,7 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
     if (x->dtype() == ns_bool)
         y = create_output(nullptr, ns_int32);
     else
-        y = create_output(nullptr, binary_dtype_infer(ns, x, x));
+        y = create_output(nullptr, reduce_dtype_infer(ns, x->ns));
 }
 
 ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
@@ -283,7 +283,7 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
     ASSERT(ns.is_binary());
     reduce_mask = dims_mask;
     this->keepdims_mask = keepdims_mask;
-    y = create_output(nullptr, binary_dtype_infer(ns, x, x));
+    y = create_output(nullptr, reduce_dtype_infer(ns, x->ns));
 }
 
 ReduceOp::ReduceOp(Var* x, NanoString op, int dim, bool keepdims)
@@ -359,8 +359,8 @@ void ReduceOp::jit_run() {
     @for(i, DIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};)
     index_t xstride@{DIM-1} = 1;
     @for(i, DIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
-    Ty count = Ty(x->num) / Ty(y->num);
-    Ty rcount = Ty(y->num) / Ty(x->num);
+    Ty count = x->num*1.0 / y->num;
+    Ty rcount = y->num*1.0 / x->num;
     @for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
         auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
         yp[yid] = @expand_op(init_@OP, @Ty);

View File

@@ -132,7 +132,7 @@ void ReindexOp::jit_run() {
     @for(d, 0, XDIM, index_t xid@d = @expand_macro(INDEX@d);)
     auto xid = @for(d, 0, XDIM, + xid@d * xstride@d);
     bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d)));
-    yp[yid] = check_overflow ? (@OVERFLOW) : xp[xid];
+    yp[yid] = check_overflow ? Tx(@OVERFLOW) : xp[xid];
 }
 }
 #endif // JIT

View File

@@ -28,6 +28,12 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
         for (int i=0; i<(int)xdim; i++)
             axes.push_back(xdim-1-i);
     }
+    if (axes.size() < xdim || (axes.size() == xdim && axes[xdim-1]==xdim-1)) {
+        static VarPtr(*fuse_transpose)(Var*, NanoVector) = get_op_info("fuse_transpose").get_constructor<VarPtr, Var*, NanoVector>();
+        auto var = fuse_transpose(x, axes);
+        forward(var);
+        return;
+    }
 #ifdef HAS_CUDA
     if (use_cuda) {
         static VarPtr(*cutt_transpose)(Var*, NanoVector) = nullptr;

View File

@@ -32,6 +32,7 @@ static unordered_set<string> unary_ops = {
     "uint16",
     "uint32",
     "uint64",
+    "float16",
     "float32",
     "float64",
     // please keep float64 the last type
@@ -533,22 +534,15 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
     ns = op;
     ASSERT(ns.is_unary() | ns.is_dtype());
     NanoString dtype;
+    if (ns == x->dtype()) {
+        forward(x);
+        return;
+    }
     if (ns.is_dtype()) {
-        if (ns == x->dtype()) {
-            forward(x);
-            return;
-        }
         dtype = ns;
         ns = ns_cast;
-    } else if (ns.is_bool())
-        dtype = ns_bool;
-    else if (ns.is_float())
-        dtype = dtype_infer(x->ns, x->ns, 2);
-    else if (ns.is_int())
-        dtype = dtype_infer(x->ns, x->ns, 1);
-    else {
-        dtype = x->ns;
-    }
+    } else
+        dtype = unary_dtype_infer(ns, x->ns);
     y = create_output(nullptr, dtype);
 }

View File

@@ -25,6 +25,7 @@
 namespace jittor {
 using namespace expr;
 
+extern int use_cuda;
 
 struct OpInspector {
     // binary mask for
@@ -229,9 +230,14 @@ void ConvTuner::forwardTune(FusedOp* fop) {
         if (!(bop->y->input() && bop->x->input() && fop->has(bop->x->input()) && fop->has(bop->y->input()))) continue;
         if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;
-        // only support float32 currently
-        if (bop->z->dtype() != ns_float32)
-            continue;
+        // only support float32,float16 currently
+        if (use_cuda) {
+            if (bop->z->dtype() != ns_float32 && bop->z->dtype() != ns_float16)
+                continue;
+        } else {
+            if (bop->z->dtype() != ns_float32)
+                continue;
+        }
         Op* ops[3] = {op, bop->x->input(), bop->y->input()};
         int ok = 0;
         LOGvvvv << "conv like op" << fop << fop->get_jit_key(get_jk());

View File

@@ -262,7 +262,7 @@ void Profiler::record_and_run(
     Deleter _d;
     if (is_fused) {
         auto fop = ((FusedOp*)op);
-        if (fop->context && fop->context->entry) {
+        if (fop->context && fop->context->vrm.relay_groups.size()) {
             // relay op
             loop = rerun;
             profiler.relay_extra_cost = 0;

View File

@@ -21,7 +21,9 @@ NanoString npy2ns[] = {
     ns_int64, ns_uint64,
     ns_float32, ns_float64, ns_float64,
     ns_void, ns_void, ns_void, ns_void, ns_void, ns_void,
-    ns_void
+    ns_void, // 17
+    ns_void, ns_void, ns_void, ns_void, ns_void, // 22
+    ns_float16, // 23
 };
 
 NPY_TYPES ns2npy[] = {
@@ -34,7 +36,7 @@ NPY_TYPES ns2npy[] = {
     NPY_BYTE, NPY_SHORT, NPY_INT, NPY_LONGLONG,
     NPY_UBYTE, NPY_USHORT, NPY_UINT, NPY_ULONGLONG,
 #endif
-    NPY_FLOAT, NPY_DOUBLE
+    NPY_HALF, NPY_FLOAT, NPY_DOUBLE
 };
 
 void** PyArray_API;

View File

@@ -48,6 +48,8 @@ enum NPY_TYPES {
     NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
     NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
     NPY_OBJECT=17,
+    NPY_HALF=23,
+    NPY_END=24,
 };
 
 EXTERN_LIB NanoString npy2ns[];
@@ -60,11 +62,11 @@ EXTERN_LIB NPY_TYPES ns2npy[];
 inline bool is_c_style(PyArray_Proxy* obj) { return obj->flags & 1; }
 
 inline NanoString get_type_str(PyArray_Proxy* obj) {
     NanoString type = ns_void;
-    if (obj->descr->type_num < NPY_OBJECT)
+    if (obj->descr->type_num < NPY_END)
         type = npy2ns[obj->descr->type_num];
     CHECK(type != ns_void) << "Numpy type not support, type_num:"
         << obj->descr->type_num
-        << "type_char:" << obj->descr->type;
+        << "type_char:" << obj->descr->type << NPY_END << npy2ns[obj->descr->type_num];
     return type;
 }

View File

@@ -141,7 +141,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
     } else {
         // this is non-continue numpy array
 #if defined(__linux__) || defined(_WIN32)
-        STACK_ALLOC(int64, dims, args.shape.size());
+        STACK_ALLOC(int64_t, dims, args.shape.size());
 #elif defined(__APPLE__)
         long dims[args.shape.size()];
 #endif

View File

@@ -274,7 +274,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
 
 DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
 #if defined(__linux__) || defined(_WIN32)
-    STACK_ALLOC(int64, dims, a.shape.size());
+    STACK_ALLOC(int64_t, dims, a.shape.size());
 #elif defined(__APPLE__)
     long dims[a.shape.size()];
 #endif
@@ -390,7 +390,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
 
 struct DataView;
 DEF_IS(DataView, PyObject*) to_py_object(T a) {
 #if defined(__linux__) || defined(_WIN32)
-    STACK_ALLOC(int64, dims, a.shape.size());
+    STACK_ALLOC(int64_t, dims, a.shape.size());
 #elif defined(__APPLE__)
     long dims[a.shape.size()];
 #endif

View File

@@ -110,7 +110,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o
     rb->push(size, offset);
     args.ptr = rb->get_ptr(size, offset);
 #if defined(__linux__) || defined(_WIN32)
-    STACK_ALLOC(int64, dims, args.shape.size());
+    STACK_ALLOC(int64_t, dims, args.shape.size());
 #elif defined(__APPLE__)
     long dims[args.shape.size()];
 #endif

View File

@@ -12,6 +12,44 @@ namespace jittor {
 
 extern int use_cuda;
unordered_map<string,string> common_op_type_cuda_map = {
{"logical_not", "(!($2))"},
{"bitwise_not", "(~($2))"},
{"negative", "(-($2))"},
{"abs", "::abs($2)"},
{"log", "::logf(($1)($2))"},
{"exp", "::expf(($1)($2))"},
{"sqrt", "::sqrtf(($1)($2))"},
{"round", "(($1) ::roundf(($2)))"},
{"floor", "(($1) ::floorf(($2)))"},
{"ceil", "(($1) ::ceilf(($2)))"},
{"round_int", "(($1) ::roundf(($2)))"},
{"floor_int", "(($1) ::floorf(($2)))"},
{"ceil_int", "(($1) ::ceilf(($2)))"},
{"sin", "(($1) ::sinf(($2)))"},
{"asin", "(($1) ::asinf(($2)))"},
{"sinh", "(($1) ::sinhf(($2)))"},
{"asinh", "(($1) ::asinhf(($2)))"},
{"cos", "(($1) ::cosf(($2)))"},
{"acos", "(($1) ::acosf(($2)))"},
{"cosh", "(($1) ::coshf(($2)))"},
{"acosh", "(($1) ::acoshf(($2)))"},
{"tan", "(($1) ::tanf(($2)))"},
{"atan", "(($1) ::atanf(($2)))"},
{"tanh", "(($1) ::tanhf(($2)))"},
{"atanh", "(($1) ::atanhf(($2)))"},
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
{"erf", "(($1) ::erff(($2)))"},
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
{"cast", "(($1)($2))"},
{"pow", "::pow(($2),($4))"},
{"maximum", "::max($1($2), $1($4))"},
{"minimum", "::min($1($2), $1($4))"},
{"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
{"init_maximum", "::numeric_min<$1>()"},
{"init_minimum", "::numeric_max<$1>()"},
};
 struct CommonOpType : OpByType {
     CommonOpType() {
         types = {
@@ -34,43 +72,7 @@ struct CommonOpType : OpByType {
             if (!types.count(args[i]))
                 return "";
         }
-        static unordered_map<string,string> cuda_map = {
+        auto& cuda_map = common_op_type_cuda_map;
{"logical_not", "(!($2))"},
{"bitwise_not", "(~($2))"},
{"negative", "(-($2))"},
{"abs", "::abs($2)"},
{"log", "::logf(($1)($2))"},
{"exp", "::expf(($1)($2))"},
{"sqrt", "::sqrtf(($1)($2))"},
{"round", "(($1) ::roundf(($2)))"},
{"floor", "(($1) ::floorf(($2)))"},
{"ceil", "(($1) ::ceilf(($2)))"},
{"round_int", "(($1) ::roundf(($2)))"},
{"floor_int", "(($1) ::floorf(($2)))"},
{"ceil_int", "(($1) ::ceilf(($2)))"},
{"sin", "(($1) ::sinf(($2)))"},
{"asin", "(($1) ::asinf(($2)))"},
{"sinh", "(($1) ::sinhf(($2)))"},
{"asinh", "(($1) ::asinhf(($2)))"},
{"cos", "(($1) ::cosf(($2)))"},
{"acos", "(($1) ::acosf(($2)))"},
{"cosh", "(($1) ::coshf(($2)))"},
{"acosh", "(($1) ::acoshf(($2)))"},
{"tan", "(($1) ::tanf(($2)))"},
{"atan", "(($1) ::atanf(($2)))"},
{"tanh", "(($1) ::tanhf(($2)))"},
{"atanh", "(($1) ::atanhf(($2)))"},
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
{"erf", "(($1) ::erff(($2)))"},
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
{"cast", "(($1)($2))"},
{"pow", "::pow(($2),($4))"},
{"maximum", "::max($1($2), $1($4))"},
{"minimum", "::min($1($2), $1($4))"},
{"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
{"init_maximum", "::numeric_min<$1>()"},
{"init_minimum", "::numeric_max<$1>()"},
};
         static unordered_map<string,string> cpu_map = {
             {"logical_not", "(!($2))"},
@@ -151,6 +153,10 @@ struct CommonOpType : OpByType {
             ret = cpu_map[args.at(0)];
         return format(ret, args);
     }
+
+    void post_pass(OpCompiler*) {
+        return;
+    }
}; };

View File

@@ -0,0 +1,164 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#ifdef JIT_cuda
#include <driver_types.h>
#include <cuda_fp16.h>
namespace jittor {
typedef __half float16;
#if CUDA_ARCH >= 800
inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); }
inline __device__ float16 min(float16 a, float16 b) { return __hmin(a, b); }
#else
inline __device__ float16 max(float16 a, float16 b) { return a<b?b:a; }
inline __device__ float16 min(float16 a, float16 b) { return a<b?a:b; }
#endif
inline __device__ float16 pow(float16 a, float16 b) { return ::pow(float32(a), float32(b)); }
template<int nbyte, class T>
__device__ inline void vload(T* __restrict__ a, T* __restrict__ b) {
if constexpr (nbyte<=0) return;
if constexpr (nbyte>=16) {
auto __restrict__ aa = (float4* __restrict__)a;
auto __restrict__ bb = (float4* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-16>(aa+1, bb+1);
}
if constexpr (nbyte>=8) {
auto __restrict__ aa = (float2* __restrict__)a;
auto __restrict__ bb = (float2* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-8>(aa+1, bb+1);
}
if constexpr (nbyte>=4) {
auto __restrict__ aa = (float* __restrict__)a;
auto __restrict__ bb = (float* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-4>(aa+1, bb+1);
}
if constexpr (nbyte>=2) {
auto __restrict__ aa = (__half* __restrict__)a;
auto __restrict__ bb = (__half* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-2>(aa+1, bb+1);
}
if constexpr (nbyte>=1) {
auto __restrict__ aa = (int8_t* __restrict__)a;
auto __restrict__ bb = (int8_t* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-1>(aa+1, bb+1);
}
}
}
using jittor::max;
using jittor::min;
using jittor::pow;
#else
namespace jittor {
struct float16 {
uint16 x;
inline float16(float32 f) {
unsigned x = *((int*)(void*)(&f));
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
this->x = 0x7fffU;
return;
}
sign = ((x >> 16) & 0x8000);
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
this->x = sign | 0x7c00U;
return;
}
if (u < 0x33000001) {
this->x = sign | 0x0000U;
return;
}
exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);
if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
} else {
shift = 0x7e - exponent;
exponent = 0;
mantissa |= 0x800000;
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);
// Round to nearest even.
remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
++exponent;
mantissa = 0;
}
}
this->x = (sign | (exponent << 10) | mantissa);
}
inline operator float() {
unsigned sign = ((x >> 15) & 1);
unsigned exponent = ((x >> 10) & 0x1f);
unsigned mantissa = ((x & 0x3ff) << 13);
if (exponent == 0x1f) { /* NaN or Inf */
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
exponent = 0xff;
} else if (!exponent) { /* Denorm or Zero */
if (mantissa) {
unsigned int msb;
exponent = 0x71;
do {
msb = (mantissa & 0x400000);
mantissa <<= 1; /* normalize */
--exponent;
} while (!msb);
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
}
} else {
exponent += 0x70;
}
int temp = ((sign << 31) | (exponent << 23) | mantissa);
return reinterpret_cast<float&>(temp);
}
};
}
#endif
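Note: the CPU fallback above implements the standard IEEE 754 binary16 conversion with round-to-nearest-even. A quick illustration of the precision this gives, using numpy's half type (which follows the same rules) rather than the struct itself:

import numpy as np

h = np.float16(np.float32(0.1))       # fp32 -> fp16, round to nearest even
print(float(h))                       # ~0.0999755859375
print(abs(float(h) - 0.1) < 1e-3)     # half precision keeps roughly 3 decimal digits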

View File

@@ -0,0 +1,188 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "common.h"
#include "utils/str_utils.h"
#include "ops/op_register.h"
#include "op_compiler.h"
namespace jittor {
extern int use_cuda;
extern unordered_map<string,string> common_op_type_cuda_map;
static bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; }
struct FP16OpType : OpByType {
FP16OpType() {
types = {
"float16",
};
}
string expand_op(const vector<string>& args) {
bool found_fp16 = 0;
for (int i=1; i<args.size(); i+=2) {
if (types.count(args[i]))
found_fp16 = 1;
}
if (!found_fp16) return "";
static unordered_map<string,string> cuda_map = {
{"logical_not", "(!($2))"},
{"bitwise_not", "(~($2))"},
{"negative", "(-($2))"},
{"abs", "::abs($2)"},
{"log", "::hlog(($1)($2))"},
{"exp", "::hexp(($1)($2))"},
{"sqrt", "::hsqrt(($1)($2))"},
{"round", "(($1) ::roundf(($2)))"},
{"floor", "(($1) ::floorf(($2)))"},
{"ceil", "(($1) ::ceilf(($2)))"},
{"round_int", "(($1) ::roundf(($2)))"},
{"floor_int", "(($1) ::floorf(($2)))"},
{"ceil_int", "(($1) ::ceilf(($2)))"},
{"sin", "(($1) ::sinf(($2)))"},
{"asin", "(($1) ::asinf(($2)))"},
{"sinh", "(($1) ::sinhf(($2)))"},
{"asinh", "(($1) ::asinhf(($2)))"},
{"cos", "(($1) ::cosf(($2)))"},
{"acos", "(($1) ::acosf(($2)))"},
{"cosh", "(($1) ::coshf(($2)))"},
{"acosh", "(($1) ::acoshf(($2)))"},
{"tan", "(($1) ::tanf(($2)))"},
{"atan", "(($1) ::atanf(($2)))"},
{"tanh", "(($1) ::tanhf(($2)))"},
{"atanh", "(($1) ::atanhf(($2)))"},
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float16)==0,30,300))))))))"},
{"erf", "(($1) ::erff(($2)))"},
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
{"cast", "(($1)($2))"},
{"pow", "::pow(($2),($4))"},
{"maximum", "::max($1($2), $1($4))"},
{"minimum", "::min($1($2), $1($4))"},
{"mod", "$1(($2)-::hfloor(($2)/($4))*($4))"},
{"init_maximum", "-32768.0f"},
{"init_minimum", "32768.0f"},
};
static unordered_map<string,string> cpu_map = {
{"logical_not", "(!($2))"},
{"bitwise_not", "(~($2))"},
{"negative", "(-($2))"},
{"abs", "std::abs($2)"},
{"log", "std::log(($1)($2))"},
{"exp", "std::exp(($1)($2))"},
{"sqrt", "std::sqrt(($1)($2))"},
{"round", "(($1)std::round(($2)))"},
{"floor", "(($1)std::floor(($2)))"},
{"ceil", "(($1)std::ceil(($2)))"},
{"round_int", "(($1)std::round(($2)))"},
{"floor_int", "(($1)std::floor(($2)))"},
{"ceil_int", "(($1)std::ceil(($2)))"},
{"sin", "(($1) std::sin(($2)))"},
{"asin", "(($1) std::asin(($2)))"},
{"sinh", "(($1) std::sinh(($2)))"},
{"asinh", "(($1) std::asinh(($2)))"},
{"cos", "(($1) std::cos(($2)))"},
{"acos", "(($1) std::acos(($2)))"},
{"cosh", "(($1) std::cosh(($2)))"},
{"acosh", "(($1) std::acosh(($2)))"},
{"tan", "(($1) std::tan(($2)))"},
{"atan", "(($1) std::atan(($2)))"},
{"tanh", "(($1) std::tanh(($2)))"},
{"atanh", "(($1) std::atanh(($2)))"},
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
{"erf", "(($1) std::erf(($2)))"},
{"erfinv", "(jittor::_erfinv($2))"},
{"cast", "(($1)($2))"},
{"pow", "std::pow(($2),($4))"},
{"maximum", "std::max($1($2), $1($4))"},
{"minimum", "std::min($1($2), $1($4))"},
{"mod", "$1(($2)-std::floor(($2)/($4))*($4))"},
{"init_maximum", "-32768.0f"},
{"init_minimum", "32768.0f"},
};
static unordered_map<string,string> both_map {
{"add", "(($2)+($4))"},
{"subtract", "(($2)-($4))"},
{"multiply", "(($2)*($4))"},
{"divide", "($1(($1($2))/($1($4))))"},
{"floor_divide", "($1(($1($2))/($1($4))))"},
{"less", "(($2)<($4))"},
{"less_equal", "(($2)<=($4))"},
{"greater", "(($2)>($4))"},
{"greater_equal", "(($2)>=($4))"},
{"equal", "(($2)==($4))"},
{"not_equal", "(($2)!=($4))"},
{"left_shift", "(($2)<<($4))"},
{"right_shift", "(($2)>>($4))"},
{"logical_and", "(($2)&&($4))"},
{"logical_or", "(($2)||($4))"},
{"logical_xor", "((bool($2))!=(bool($4)))"},
{"bitwise_and", "(($2)&($4))"},
{"bitwise_or", "(($2)|($4))"},
{"bitwise_xor", "(($2)^($4))"},
{"mean", "(($2)+($4)*($1(rcount)))"},
{"init_add", "$1(0)"},
{"init_multiply", "$1(1)"},
{"init_logical_and", "true"},
{"init_logical_or", "false"},
{"init_logical_xor", "false"},
{"init_bitwise_and", "$1(-1)"},
{"init_bitwise_or", "$1(0)"},
{"init_bitwise_xor", "$1(0)"},
{"init_mean", "$1(0)"},
};
string ret;
if (both_map.count(args.at(0)))
ret = both_map[args.at(0)];
else if (use_cuda)
ret = cuda_map[args.at(0)];
else
ret = cpu_map[args.at(0)];
if (use_cuda) {
if (args[1] == "float32" && !both_map.count(args.at(0))) {
ret = common_op_type_cuda_map[args.at(0)];
}
if (args[1] == "float16" || args[1] == "float32") {
for (int i=3; i<args.size(); i+=2) {
if (args[i] != args[1]) {
ret = replace(ret, "$"+S(i-1),
args[1]+"($"+S(i-1)+")");
}
}
} else {
for (int i=3; i<args.size(); i+=2) {
if (args[i] != "float16") {
ret = replace(ret, "$"+S(i-1),
"float16($"+S(i-1)+")");
}
}
}
}
return format(ret, args);
}
void post_pass(OpCompiler* oc) {
string& src = oc->src;
if (src.find("float16") == string::npos)
return;
int i = src.rfind("#include");
if (i<0) i=0;
i = src.find('\n', i) + 1;
src = src.substr(0, i) + "#include \"type/fp16_compute.h\"\n" +
src.substr(i);
return;
}
};
static int _ = registe_op_type(new FP16OpType());
}
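Judging from the substitution loops above, args appears to be laid out as (op_name, out_type, x_expr, x_type, y_expr, y_type): $1 is the output dtype, $2/$4 are the operand expressions, and an operand whose dtype differs from the working precision gets wrapped in a cast before format() fills in the placeholders. A hypothetical trace in Python (illustration only; the arg layout is inferred, not documented):
template = "::max($1($2), $1($4))"                    # the CUDA "maximum" entry
args = ["maximum", "float16", "a", "float16", "b", "float32"]
expr = template.replace("$4", "float16($4)")          # wrap the non-fp16 operand
for idx, val in ((1, args[1]), (2, args[2]), (4, args[4])):
    expr = expr.replace("$" + str(idx), val)
print(expr)  # ::max(float16(a), float16(float16(b)))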

View File

@ -18,7 +18,7 @@ namespace jittor {
typedef int8_t int8; typedef int8_t int8;
typedef int16_t int16; typedef int16_t int16;
typedef int int32; typedef int int32;
typedef int64_t int64; typedef long long int64;
typedef uint8_t uint8; typedef uint8_t uint8;
typedef uint16_t uint16; typedef uint16_t uint16;
typedef uint32_t uint32; typedef uint32_t uint32;

View File

@ -14,7 +14,7 @@
namespace jittor { namespace jittor {
int64_t Var::number_of_lived_vars = 0; int64 Var::number_of_lived_vars = 0;
DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {}, DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {},
"Override the default loop transfrom options"); "Override the default loop transfrom options");
@ -42,7 +42,7 @@ string Var::to_string() {
return s; return s;
} }
int64_t Var::numel() { int64 Var::numel() {
if (!shape.size()) return size=num=-1; if (!shape.size()) return size=num=-1;
bool negtive = 0; bool negtive = 0;
num=1; num=1;

View File

@ -18,13 +18,13 @@ struct Var : Node {
NanoVector shape; NanoVector shape;
cstr name; cstr name;
fast_shared_ptr<loop_options_t> loop_options; fast_shared_ptr<loop_options_t> loop_options;
static int64_t number_of_lived_vars; static int64 number_of_lived_vars;
// this var will be generated after alloc. // this var will be generated after alloc.
void* mem_ptr = nullptr; void* mem_ptr = nullptr;
Allocator* allocator = nullptr; Allocator* allocator = nullptr;
size_t allocation; size_t allocation;
int64_t size, num; int64 size, num;
inline bool is_float() const { CHECK_EXIST; return ns.is_float(); } inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
inline int dsize() const { CHECK_EXIST; return ns.dsize(); } inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
inline NanoString dtype() const { CHECK_EXIST; return ns; } inline NanoString dtype() const { CHECK_EXIST; return ns; }
@ -40,7 +40,7 @@ struct Var : Node {
Var(NanoVector shape, NanoString dtype); Var(NanoVector shape, NanoString dtype);
string to_string(); string to_string();
int64_t numel(); int64 numel();
void set_shape(NanoVector shape); void set_shape(NanoVector shape);
bool alloc(Allocator* allocator); bool alloc(Allocator* allocator);
inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; } inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; }

View File

@ -15,6 +15,7 @@ if __name__ == "__main__":
skip_l = int(os.environ.get("test_skip_l", "0")) skip_l = int(os.environ.get("test_skip_l", "0"))
skip_r = int(os.environ.get("test_skip_r", "1000000")) skip_r = int(os.environ.get("test_skip_r", "1000000"))
skip = os.environ.get("test_skip", "").split(",")
test_only = None test_only = None
if "test_only" in os.environ: if "test_only" in os.environ:
test_only = set(os.environ.get("test_only").split(",")) test_only = set(os.environ.get("test_only").split(","))
@ -34,6 +35,9 @@ if __name__ == "__main__":
continue continue
if test_only and test_name not in test_only: if test_only and test_name not in test_only:
continue continue
for s in skip:
if s in test_name:
continue
print("Add Test", _, test_name) print("Add Test", _, test_name)
suite.addTest(tests) suite.addTest(tests)

View File

@ -0,0 +1,374 @@
from copy import deepcopy
from pathlib import Path
import jittor as jt
import jittor.nn as nn
import numpy as np
import os
split_size = 1000000
conv_opt = int(os.environ.get("conv_opt", "0"))
if conv_opt:
Conv1d_sp = nn.Conv1d_sp
else:
Conv1d_sp = nn.Conv1d
def MLP(channels: list, do_bn=True):
""" Multi-layer perceptron """
n = len(channels)
layers = []
for i in range(1, n):
layers.append(Conv1d_sp(channels[i - 1], channels[i], kernel_size=1, bias=True))
if i < (n - 1):
if do_bn:
layers.append(nn.BatchNorm(channels[i]))
# layers.append(nn.InstanceNorm1d(channels[i]))
# layers.append(nn.LayerNorm(channels[i]))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def normalize_keypoints(kpts, image_shape):
size = image_shape.flip(1) # shape=(b,2); (h, w) -> (w, h)
center = size / 2
scaling = size.float32().max(1, keepdims=True) * 0.7
return (kpts - center[:, None, :]) / scaling[:, None, :]
class KeypointEncoder(nn.Module):
""" Joint encoding of visual appearance and location using MLPs"""
def __init__(self, feature_dim, layers, keypoint_position_dim=2):
super().__init__()
# self.keypoint_position_dim = keypoint_position_dim
self.encoder = MLP([keypoint_position_dim + 1] + layers + [feature_dim])
nn.init.constant_(self.encoder[-1].bias, 0.0)
def execute(self, kpts, scores):
inputs = jt.concat([kpts.t(), scores.unsqueeze(1)], dim=1)
return self.encoder(inputs)
cnt = 0
def attention(query, key, value):
global cnt
cnt += 1
b, d, h, n = query.shape
# print("attention", b,d,h,n, cnt)
dim_factor = (1.0 / d)**0.5
query = query.transpose(0, 2, 3, 1).reshape(b * h, -1, d) * dim_factor
key = key.transpose(0, 2, 1, 3).reshape(b * h, d, -1)
value = value.transpose(0, 2, 3, 1).reshape(b * h, -1, d)
# print("attention", query.shape, key.shape, value.shape)
data = []
for i in range(0, query.shape[0], split_size):
end = min(i + split_size, query.shape[0])
tmp1 = nn.bmm(query[i:end], key[i:end])
tmp2 = nn.softmax(tmp1, dim=-1)
tmp3 = nn.bmm(tmp2, value[i:end])
tmp3.sync()
data.append(tmp3)
tmp3 = jt.concat(data)
# for i in range(0, query.shape[0], split_size):
# end = min(i + split_size, query.shape[0])
# tmp1 = nn.bmm(query[:,i:end], key[:,i:end])
# tmp2 = nn.softmax(tmp1, dim=-1)
# tmp3 = nn.bmm(tmp2, value[:,i:end])
# tmp3.sync()
# data.append(tmp3)
# tmp3 = jt.concat(data, dim=1)
# tmp1 = nn.bmm(query, key)
# print(tmp1.shape)
# tmp2 = nn.softmax(tmp1, dim=-1)
# print(tmp2.shape)
# tmp3 = nn.bmm(tmp2, value)
# print(tmp3.shape)
return tmp3.reshape(b, h, -1, d).transpose(0, 3, 1, 2)
return nn.bmm(nn.softmax(nn.bmm(query, key), dim=-1), value).reshape(b, h, -1, d).transpose(0, 3, 1, 2)
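The chunked loop above is plain scaled dot-product attention, split along the fused batch-and-head axis to cap peak memory; per chunk it evaluates Attn(Q, K, V) = softmax(Q K^T / sqrt(d)) V, with the 1/sqrt(d) factor folded into query before the first bmm.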
class MultiHeadedAttention(nn.Module):
""" Multi-head attention to increase model expressivitiy """
def __init__(self, num_heads: int, d_model: int):
super().__init__()
assert d_model % num_heads == 0
self.dim = d_model // num_heads
self.num_heads = num_heads
self.merge = Conv1d_sp(d_model, d_model, kernel_size=1)
self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
def execute(self, query, key, value):
batch_dim = query.size(0)
query, key, value = [l(x).reshape(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))]
x = attention(query, key, value)
# x = attention_chunk(query, key, value)
return self.merge(x.reshape(batch_dim, self.dim * self.num_heads, -1))
class AttentionalPropagation(nn.Module):
def __init__(self, feature_dim: int, num_heads: int):
super().__init__()
self.attn = MultiHeadedAttention(num_heads, feature_dim)
self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim])
nn.init.constant_(self.mlp[-1].bias, 0.0)
def execute(self, x, source):
message = self.attn(x, source, source)
return self.mlp(jt.concat([x, message], dim=1))
class AttentionalGNN(nn.Module):
def __init__(self, feature_dim: int, layer_names: list):
super().__init__()
self.layers = nn.ModuleList([AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))])
self.is_cross = [x == 'cross' for x in layer_names]
def execute(self, desc0, desc1):
for layer, is_cross in zip(self.layers, self.is_cross):
layer.attn.prob = []
if is_cross:
src0, src1 = desc1, desc0
else: # if name == 'self':
src0, src1 = desc0, desc1
# delta0, delta1 = layer(desc0, src0), layer(desc1, src1)
delta0 = layer(desc0, src0)
# print(delta0.numel()*4)
# breakpoint()
jt.sync_all()
delta1 = layer(desc1, src1)
jt.sync_all()
desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
jt.sync_all()
return desc0, desc1
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
""" Perform Sinkhorn Normalization in Log-space for stability"""
u, v = jt.zeros_like(log_mu), jt.zeros_like(log_nu)
for _ in range(iters):
u = log_mu - (Z + v.unsqueeze(1)).exp().sum(dim=2).log()
v = log_nu - (Z + u.unsqueeze(2)).exp().sum(dim=1).log()
return Z + u.unsqueeze(2) + v.unsqueeze(1)
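For reference, each pass of the loop above is the standard log-domain Sinkhorn update:
    u_i <- log(mu_i) - log(sum_j exp(Z_ij + v_j))
    v_j <- log(nu_j) - log(sum_i exp(Z_ij + u_i))
and the returned matrix Z_ij + u_i + v_j is the log of the resulting transport plan (the inner log-sum-exp is computed directly as exp().sum().log()).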
def log_optimal_transport(scores, alpha, iters: int):
""" Perform Differentiable Optimal Transport in Log-space for stability"""
b, m, n = scores.shape
ms, ns = jt.float(m, requires_grad=False), jt.float(n, requires_grad=False)
bins0 = alpha.broadcast([b, m, 1])
bins1 = alpha.broadcast([b, 1, n])
alpha = alpha.broadcast([b, 1, 1])
couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1)
norm = -(ms + ns).log()
log_mu = jt.concat([norm.broadcast([m]), ns.log() + norm])
log_nu = jt.concat([norm.broadcast([n]), ms.log() + norm])
log_mu, log_nu = log_mu[None].broadcast([b, m + 1]), log_nu[None].broadcast([b, n + 1])
Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
Z = Z - norm # multiply probabilities by M+N
return Z
def arange_like(x, dim: int):
return jt.ones(x.shape[dim], dtype=x.dtype)[None].cumsum()[0] - 1 # traceable in 1.1
default_config = {
'descriptor_dim': 256, # SuperPoint
'weights': 'indoor',
'keypoint_encoder': [32, 64, 128, 256], # SuperPoint
'GNN_layers': ['self', 'cross'] * 9,
'sinkhorn_iterations': 100,
'match_threshold': 0.2,
}
def get_weighted_loss_batch(scores, all_matches):
matches0, matches1 = all_matches.chunk(chunks=2, dim=2)
batchIdx = jt.arange(all_matches.shape[0]).unsqueeze(1).repeat(1, all_matches.shape[1])
batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1)
valid_index0, valid_index1 = matches0 >= 0, matches1 >= 0
valid_match = jt.logical_and(valid_index0, valid_index1)
valid_unmatch = jt.logical_xor(valid_index0, valid_index1)
num_match = valid_match.sum().maximum(1e-9)
num_unmatch = valid_unmatch.sum().maximum(1e-9)
score_ = scores[batchIdx, matches0, matches1]
score_match_ = (score_*valid_match).float32().sum() / num_match
score_umatch_ = (score_*valid_unmatch).float32().sum() / num_unmatch
return -(num_unmatch * score_match_ + num_match * score_umatch_) / (num_match + num_unmatch)
# print(score_umatch_, score_match_)
# return -(score_match + score_umatch) / (num_match + num_unmatch)
score_match = scores[(batchIdx[valid_match], matches0[valid_match], matches1[valid_match])].float32().mean() if num_match > 0 else 0
score_umatch = scores[(batchIdx[valid_unmatch], matches0[valid_unmatch], matches1[valid_unmatch])].float32().mean() if num_unmatch > 0 else 0
# print(score_match, score_umatch)
return -(num_unmatch * score_match + num_match * score_umatch) / (num_match + num_unmatch)
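In formula form, the value returned by get_weighted_loss_batch (the lines after its first return are unreachable) is
    L = -(N_u * s_match + N_m * s_unmatch) / (N_m + N_u)
where s_match is the mean log-score over fully matched pairs, s_unmatch the mean over half-matched pairs, and N_m, N_u their counts clamped to at least 1e-9; each class is weighted by the size of the other, which balances the matched and unmatched terms.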
def add_dustbin(scores, alpha):
b, m, n = scores.shape
bins0 = jt.broadcast(alpha, (b, m, 1))
bins1 = jt.broadcast(alpha, (b, 1, n))
alpha = jt.broadcast(alpha, (b, 1, 1))
couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1)
return couplings
class SuperGlue(nn.Module):
def __init__(self, config):
super().__init__()
config = {**default_config, **config}
self.descriptor_dim = config['descriptor_dim']
self.keypoint_encoder = config['keypoint_encoder']
self.GNN_layers = config['GNN_layers']
self.sinkhorn_iterations = config['sinkhorn_iterations']
self.match_threshold = config['match_threshold']
self.keypoint_position_dim = config['keypoint_position_dim']
self.use_dual_softmax = config['use_dual_softmax']
self.scale = jt.float(self.descriptor_dim**-0.5).stop_grad()
# self.scale.requires_grad = False
# self.des_extend = MLP([128, 256])
self.kenc = KeypointEncoder(self.descriptor_dim, self.keypoint_encoder, keypoint_position_dim=self.keypoint_position_dim)
self.gnn = AttentionalGNN(self.descriptor_dim, self.GNN_layers)
self.final_proj = Conv1d_sp(self.descriptor_dim, self.descriptor_dim, kernel_size=1, bias=True)
self.bin_score = jt.float(1.0)
def execute(self, data):
"""Run SuperGlue on a pair of keypoints and descriptors"""
kpts0, kpts1 = data['keypoints0'], data['keypoints1']
desc0, desc1 = data['descriptors0'], data['descriptors1']
all_matches = data['all_matches']
# match_num = data['match_num']
if kpts0.shape[1] == 0 or kpts1.shape[1] == 0 or all_matches.shape[1] == 0: # no keypoints or no matches/unmatches
shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
return {
'matches0': jt.ones(shape0, dtype=jt.int),
'matches1': jt.ones(shape1, dtype=jt.int),
'matching_scores0': jt.zeros(shape0, dtype=jt.float),
'matching_scores1': jt.zeros(shape1, dtype=jt.float),
'skip_train': True
}
# Keypoint normalization.
kpts0 = normalize_keypoints(kpts0, data['shape0'])
kpts1 = normalize_keypoints(kpts1, data['shape1'])
# Keypoint MLP encoder.
# desc0 = self.des_extend(desc0) + self.kenc(kpts0, data['scores0'])
# desc1 = self.des_extend(desc1) + self.kenc(kpts1, data['scores1'])
desc0 = desc0 + self.kenc(kpts0, data['scores0'])
desc1 = desc1 + self.kenc(kpts1, data['scores1'])
# Multi-layer Transformer network.
desc0, desc1 = self.gnn(desc0, desc1)
# Final MLP projection.
desc0, desc1 = self.final_proj(desc0), self.final_proj(desc1)
desc0_t = desc0.t()
losses = []
for i in range(0, desc1.shape[0], split_size):
end = min(desc1.shape[0], i + split_size)
# Compute matching descriptor distance.
scores = nn.bmm(desc0_t[i:end], desc1[i:end]) * self.scale # 457.76 MB
scores.sync()
# Run the optimal transport.
if self.use_dual_softmax:
scores = add_dustbin(scores, self.bin_score) # 458.68 MB
scores.sync()
dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2)
scores = dual_softmax0 + dual_softmax1 # 458.22 MB
scores.sync()
else:
scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations'])
# loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])])
loss = get_weighted_loss_batch(scores, all_matches[i:end])
loss.sync()
losses.append(loss)
loss = jt.concat(losses)
'''
# Compute matching descriptor distance.
scores = nn.bmm(desc0.t(), desc1) * self.scale # 457.76 MB
scores.sync()
# Run the optimal transport.
if self.use_dual_softmax:
scores = add_dustbin(scores, self.bin_score) # 458.68 MB
scores.sync()
dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2)
scores = dual_softmax0 + dual_softmax1 # 458.22 MB
scores.sync()
else:
scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations'])
# loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])])
loss = get_weighted_loss_batch(scores, all_matches)
# print(scores.shape, all_matches.shape, loss.shape)
'''
# matches0, matches1 = all_matches.chunk(chunks=2, dim=2)
# batchIdx = jt.arange(0, b).unsqueeze(1).repeat(1, num)
# batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1)
# validmatch = (matches0 >= 0) | (matches1 >= 0)
# batchIdx, matches0, matches1 = batchIdx[validmatch], matches0[validmatch], matches1[validmatch]
# matches0[matches0 == -1] = n
# matches1[matches1 == -1] = m
# loss_mean = -scores[(batchIdx, matches0, matches1)].mean()
# loss_mean = nn.l1_loss(loss_mean, jt.float(0.0))
if not data['return_match']:
return {'loss': loss}
with jt.no_grad():
b, n, m = scores.shape
# Get the matches with score above "match_threshold".
indices0, max0 = scores[:, :-1, :-1].argmax(2)
indices1, max1 = scores[:, :-1, :-1].argmax(1)
mutual0 = jt.arange(0, n)[None] == indices1.gather(1, indices0)
mutual1 = jt.arange(0, m)[None] == indices0.gather(1, indices1)
# zero = scores.new_tensor(0)
# mscores0 = torch.where(mutual0, max0.values.exp(), zero)
mscores0 = max0.exp()
mscores0[mutual0.logical_not()] = 0
# mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
mscores1 = mscores0.gather(1, indices1)
mscores1[mutual1.logical_not()] = 0
valid0 = mutual0 & (mscores0 > self.match_threshold)
valid1 = mutual1 & valid0.gather(1, indices1)
# indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
# indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
indices0[valid0.logical_not()] = -1
indices1[valid1.logical_not()] = -1
return {
'matches0': indices0, # use -1 for invalid match
'matches1': indices1, # use -1 for invalid match
'matching_scores0': mscores0,
'matching_scores1': mscores1,
'loss': loss,
}
# open question: does a larger or a smaller score mean higher confidence? log cannot take a negative value

View File

@ -0,0 +1,344 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import os
n = 400000000
# n = 4000000
n = 7680000
def get_mem_band():
a = jt.rand((n)).float32()
for i in range(100):
a.copy().sync()
jt.sync_all(True)
import time
t = time.time()
for i in range(1000):
a.copy().sync()
jt.sync_all(True)
dt = time.time() - t
band = a.numel() * 4 * 2000 / dt / 1024**3
print("Mem band: ", band)
return band
def check_simple_add_band():
# copy: 816
# S=1 128,1024, ILP=1 634
# S=0 128,1024, ILP=1 734
# S=0 128,512, ILP=1 716
# S=0 64,1024, ILP=1 706
# S=0 256,1024, ILP=1 706
def test(S=0, B=128, T=1024, ILP=1):
a = jt.rand((n)).float32()
jt.sync_all(True)
jt.flags.log_silent = 1
with jt.profile_scope(100, 1000) as rep:
b = jt.code(a.shape, a.dtype, [a],
cuda_header="#include \"type/fp16_compute.h\"",
cuda_src=f"""
__global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int tnum = blockDim.x * gridDim.x;
#define ILP {ILP}
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
// b[i] = a[i];
vload<ILP*sizeof(in0_type)>(b+i, a+i);
{"__syncthreads();" if S else ""}
}}
}}
kernel<<<{B},{T}>>>(in0_p, out0_p, in0->num);
""")
b.sync()
bw = float(rep[-1][9]) / 1024**3
s = f"S={S}, B={B}, T={T}, ILP={ILP} BW={bw}"
print(s)
return s, bw
def test2(S=0, B=128, T=1024, ILP=1):
a = jt.rand((n)).float32()
jt.sync_all(True)
# jt.flags.log_silent = 0
with jt.profile_scope(10, 1000) as rep:
b = jt.code(a.shape, a.dtype, [a],
cuda_header="#include \"type/fp16_compute.h\"",
cuda_src=f"""
__global__ void kernel(float2 * __restrict__ a, float2* __restrict__ b, int num) {{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int tnum = blockDim.x * gridDim.x;
#define ILP 1
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
b[i] = a[i];
// b[i+1] = a[i+1];
// vload<ILP*sizeof(in0_type)>(b+i, a+i);
{"__syncthreads();" if S else ""}
}}
}}
kernel<<<{B},{T}>>>((float2*)in0_p, (float2*)out0_p, in0->num/2);
""")
b.sync()
bw = float(rep[-1][9]) / 1024**3
s = f"T2: S={S}, B={B}, T={T}, ILP={ILP} BW={bw}"
print(s)
return s, bw
def test3(S=0, B=128, T=1024, ILP=1, C=0):
a = jt.rand((n)).float32()
b = jt.rand(B)
jt.sync_all(True)
jt.flags.log_silent = 1
with jt.profile_scope(100, 1000) as rep:
b = jt.code(a.shape, a.dtype, [a, b],
cuda_header="#include \"type/fp16_compute.h\"",
cuda_src=f"""
__global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int tnum = blockDim.x * gridDim.x;
#define ILP {ILP}
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
// b[i] = a[i];
vload<ILP*sizeof(in0_type)>(b+i, a+i);
{"__syncthreads();" if S else ""}
}}
{"__syncthreads();" if C else ""}
}}
kernel<<<in1->shape[0],{T}>>>(in0_p, out0_p, in0->num);
""")
b.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C}
# b.compile_options = {"FLAGS: Xptxas dlcm=ca ": 1}
b.sync()
bw = float(rep[-1][9]) / 1024**3
s = f"T3: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw}"
print(s)
return s, bw
def test4(S=0, B=128, T=1024, ILP=1, C=0, name="b.png"):
a = jt.rand((n)).float32()
b = jt.rand(B*4).uint32()
jt.sync_all(True)
# jt.flags.log_silent = 1
with jt.profile_scope(100, 10000) as rep:
_ = jt.code(a.shape, a.dtype, [a, b],
cuda_header="#include \"type/fp16_compute.h\"",
cuda_src=f"""
__device__ uint get_smid(void) {{
uint ret;
asm("mov.u32 %0, %smid;" : "=r"(ret) );
return ret;
}}
__device__ uint get_time(void) {{
uint ret;
asm volatile("mov.u32 %0, %%globaltimer_lo;" : "=r"(ret));
return ret;
}}
__global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num, in1_type* __restrict__ c) {{
uint t = get_time();
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int tnum = blockDim.x * gridDim.x;
#define ILP {ILP}
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
// b[i] = a[i];
vload<ILP*sizeof(in0_type)>(b+i, a+i);
{"__syncthreads();" if S else ""}
}}
{"__syncthreads();" if C else ""}
if (threadIdx.x == 0)
((uint4* __restrict__)c)[blockIdx.x] =
uint4{{get_smid(), t, get_time(), 0}};
}}
kernel<<<in1->shape[0]/4,{T}>>>(in0_p, out0_p, in0->num, in1_p);
""")
_.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C}
# b.compile_options = {"FLAGS: Xptxas dlcm=ca ": 1}
_.sync()
bw = float(rep[-1][9]) / 1024**3
b = b.data.reshape(-1, 4)[:,:3]
mint = b[:,1].min()
b[:,1:] -= mint
smmax = int(b[:,0].max())
smmin = int(b[:,0].min())
maxt = b.max()
# print(b)
s = f"T4: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw:.3f} sm={smmin},{smmax} maxt={maxt}"
print(s)
import pylab as pl
pl.figure(figsize=(16,16))
texts = []
pret = np.zeros(200, dtype="uint32")
for i in range(B):
smid, s, t = b[i]
pl.plot([s,t], [smid, smid], 'ro-')
texts.append((s, smid, i))
texts.append((t, smid, i))
texts = sorted(texts)
for (s, smid, bid) in texts:
cpos = max(pret[smid], s)
pl.text(cpos, smid, str(bid))
pret[smid] = cpos + maxt // 30
# print("???")
# adjust_text(texts, arrowprops=dict(arrowstyle='->', color='blue'))
# print("???")
pl.savefig(name)
pl.close()
return s, bw
# test(S=0, B=128, T=1024, ILP=1)
# test(S=1, B=128, T=1024, ILP=1)
# test(S=0, B=64, T=1024, ILP=1)
# test(S=0, B=256, T=1024, ILP=1)
# test(S=1, B=128, T=512, ILP=1)
# test(S=1, B=128, T=256, ILP=1)
# test(S=0, B=128, T=1024, ILP=2)
# test(S=0, B=128, T=1024, ILP=4)
# test(S=0, B=128, T=512, ILP=2)
# test(S=0, B=128, T=512, ILP=4)
# test(S=1, B=128, T=1024, ILP=2)
# test(S=1, B=128, T=1024, ILP=4)
# test(S=1, B=128, T=1024, ILP=8)
# test(S=1, B=128, T=1024, ILP=16)
# test(S=1, B=128, T=512, ILP=2)
# test(S=1, B=128, T=512, ILP=4)
# test(S=1, B=256, T=1024, ILP=2)
# test(S=1, B=512, T=1024, ILP=2)
# test(S=1, B=256, T=1024, ILP=4)
# test(S=1, B=256, T=1024, ILP=8)
# test(S=1, B=256, T=1024, ILP=16)
# test(S=1, B=256, T=512, ILP=2)
# test(S=1, B=256, T=512, ILP=4)
# test(S=1, B=128, T=256, ILP=2)
# test(S=1, B=128, T=256, ILP=4)
# test(S=0, B=128, T=256, ILP=2)
# test(S=0, B=128, T=256, ILP=4)
# for b in [1, 2, 4, 8, 16, 32, 64, 128,256]:
# test(S=1, B=b, T=512, ILP=2)
import matplotlib as mpl
mpl.use('Agg')
import pylab as pl
import numpy as np
# test4(S=1, B=82, T=1024, ILP=2, C=0, name="b.png")
# test4(S=1, B=83, T=1024, ILP=2, C=0, name="c.png")
# test4(S=1, B=82*3, T=512, ILP=2, C=0, name="d1.png")
# test4(S=1, B=82*3+1, T=512, ILP=2, C=0, name="d2.png")
# test4(S=1, B=82*6+1, T=512, ILP=2, C=0, name="d3.png")
# test4(S=0, B=82*6+1, T=512, ILP=2, C=0, name="d4.png")
for b in range(70, 83):
test4(S=1, B=b, T=1024, ILP=2, C=0, name=f"b-{b}.png")
# data = []
# for b in range(32, 2000, 8):
# _, bw = test3(S=0, B=b, T=32, ILP=2)
# data.append([b, bw])
# data = np.array(data)
# pl.plot(data[:,0], data[:,1])
# for t in [32, 64, 128, 256, 512, 1024]:
# data = []
# for b in range(32, 2000, 8):
# _, bw = test3(S=1, B=b*(1024//t), T=t, ILP=2)
# data.append([b, bw])
# data = np.array(data)
# pl.plot(data[:,0], data[:,1])
# for t in [1024]:
# for c in [0,1]:
# data = []
# # for b in range(32, 1000, 8):
# for b in range(32, 33, 8):
# _, bw = test3(S=c, B=b*(1024//t), T=t, ILP=2, C=0)
# data.append([b, bw])
# data = np.array(data)
# pl.plot(data[:,0], data[:,1])
# for ilp in [2]:
# for s in [1]:
# for t in [1024,512,256,128]:
# data = []
# for b in range(32, 1100, 8):
# _, bw = test3(S=s, B=b*(1024//t), T=t, ILP=ilp)
# data.append([b, bw])
# data = np.array(data)
# pl.plot(data[:,0], data[:,1])
# pl.savefig("a.png")
# pl.close()
# for b in range(80, 90, 1):
# _, bw = test3(S=1, B=b, T=1024, ILP=2)
# # 82
# for b in range(240, 260, 1):
# _, bw = test3(S=1, B=b, T=512, ILP=2)
# # 82*3 = 246
# for b in range(240, 500, 1):
# _, bw = test3(S=1, B=b, T=256, ILP=2)
# # 492 = 82*6
# for b in range(240, 1000, 1):
# _, bw = test3(S=1, B=b, T=128, ILP=2)
# # 984 = 82*12
# for b in [128,256]:
# test(S=1, B=b, T=1024, ILP=2)
# for b in [128,256]:
# test(S=0, B=b, T=512, ILP=2)
# for b in [128,256]:
# test(S=0, B=b, T=1024, ILP=2)
# for b in [128,256]:
# test(S=1, B=b, T=512, ILP=1)
# for b in [128,256]:
# test(S=1, B=b, T=1024, ILP=1)
# for b in [128,256]:
# test(S=0, B=b, T=512, ILP=1)
# for b in [128,256]:
# test(S=0, B=b, T=1024, ILP=1)
# test(S=1, B=128, T=512, ILP=4)
# test(S=1, B=64, T=512, ILP=2)
# test(S=1, B=80, T=512, ILP=2)
# test(S=1, B=100, T=512, ILP=2)
# test(S=1, B=110, T=512, ILP=2)
# test(S=1, B=115, T=512, ILP=2)
# test(S=1, B=120, T=512, ILP=2)
# test(S=1, B=130, T=512, ILP=2)
# test(S=1, B=140, T=512, ILP=2)
# test2(S=1, B=128, T=512, ILP=2)
# test(S=1, B=128, T=256, ILP=4)
# test(S=1, B=128, T=128, ILP=8)
# test(S=1, B=128, T=64, ILP=16)
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestBenchmarkCUDA(unittest.TestCase):
def setUp(self):
jt.flags.use_cuda = 1
def tearDown(self):
jt.flags.use_cuda = 0
def test_main(self):
return
get_mem_band()
check_simple_add_band()
if __name__ == "__main__":
unittest.main()
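As a worked example of the bandwidth formula in get_mem_band above: each copy reads and writes numel float32 values and the timed loop runs 1000 iterations, hence the factor 4 * 2000. With a made-up timing (a sketch, not a measurement):
n = 7_680_000                 # elements copied per iteration, as above
bytes_per_iter = n * 4 * 2    # read + write of float32 data, ~61.4 MB
dt = 0.16                     # hypothetical seconds for the 1000 timed copies
band = bytes_per_iter * 1000 / dt / 1024**3
print(round(band, 1))         # ~357.6 GB/s with these made-up numbers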

View File

@ -19,12 +19,12 @@ def all_eq(x, y):
y = convert(y) y = convert(y)
if str(x.dtype).startswith("float"): if str(x.dtype).startswith("float"):
return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all() return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all()
return x.dtype == y.dtype and x.shape == y.shape and (x==y).all() return x.dtype == y.dtype and x.shape == y.shape and np.testing.assert_allclose(x, y)
def check(op, *args): def check(op, *args):
x = eval(f"np.{op}(*args)") x = eval(f"np.{op}(*args)")
y = eval(f"jt.{op}(*args).data") y = eval(f"jt.{op}(*args).data")
assert all_eq(x, y), f"{x}\n{y}" all_eq(x, y)
class TestBinaryOp(unittest.TestCase): class TestBinaryOp(unittest.TestCase):
def test_binary_op(self): def test_binary_op(self):
@ -47,6 +47,9 @@ class TestBinaryOp(unittest.TestCase):
def test_i(self): def test_i(self):
def check(op, a, b): def check(op, a, b):
if isinstance(a, list):
a = np.array(a)
b = np.array(b)
if jt.flags.use_cuda and op == "@": if jt.flags.use_cuda and op == "@":
return return
if op=="@": if op=="@":
@ -65,13 +68,13 @@ class TestBinaryOp(unittest.TestCase):
a = np.float32(a) a = np.float32(a)
ja = np.float32(ja) ja = np.float32(ja)
assert all_eq(ja, a), (ja,a) all_eq(ja, a)
check("+", 5, 2) check("+", 5, 2)
check("-", 5, 2) check("-", 5, 2)
check("*", 5, 2) check("*", 5, 2)
check("/", 5, 2) check("/", 5, 2)
check("//", 5, 2) check("//", 5, 2)
check("@", [[5]], [[2]]) # check("@", [[5]], [[2]])
check("%", 5, 2) check("%", 5, 2)
check("**", 5, 2) check("**", 5, 2)
check("<<", 5, 2) check("<<", 5, 2)
@ -80,6 +83,15 @@ class TestBinaryOp(unittest.TestCase):
check("^", 5, 2) check("^", 5, 2)
check("|", 5, 2) check("|", 5, 2)
check("+", [5.0,6.0], [2.0,3.0])
check("-", [5.0,6.0], [2.0,3.0])
check("*", [5.0,6.0], [2.0,3.0])
check("/", [5.0,6.0], [2.0,3.0])
check("//", [5.0,6.0], [2.0,3.0])
check("@", [[5,6],[7,8]], [[2,3],[4,5]])
check("%", [5.0,6.0], [2.0,3.0])
check("**", [5.0,6.0], [2.0,3.0])
def test_r(self): def test_r(self):
def check(op, a, b): def check(op, a, b):
a = np.array(a) a = np.array(a)
@ -97,7 +109,7 @@ class TestBinaryOp(unittest.TestCase):
a = eval(f"a {op} b") a = eval(f"a {op} b")
a = np.array(a) a = np.array(a)
assert all_eq(jc, a), f"\n{jc}\n{a}" all_eq(jc, a)
check("+", 5, 2) check("+", 5, 2)
check("-", 5, 2) check("-", 5, 2)
check("*", 5, 2) check("*", 5, 2)
@ -118,6 +130,7 @@ class TestBinaryOp(unittest.TestCase):
a = np.random.rand(10) a = np.random.rand(10)
b = np.random.rand(10) b = np.random.rand(10)
c = np.random.rand(10) c = np.random.rand(10)
tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-4
for op in ops: for op in ops:
func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()") func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()")
x, grads = ngrad(func, [a,b,c], 1e-8) x, grads = ngrad(func, [a,b,c], 1e-8)
@ -127,7 +140,7 @@ class TestBinaryOp(unittest.TestCase):
jx = eval(f"(ja{op}jb)*jc") jx = eval(f"(ja{op}jb)*jc")
jgrads = jt.grad(jx, [ja,jb,jc]) jgrads = jt.grad(jx, [ja,jb,jc])
for jd, nd in zip(jgrads, grads): for jd, nd in zip(jgrads, grads):
assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}" np.testing.assert_allclose(jd.data, nd, atol=tol, rtol=tol)
def test_mod_float(self): def test_mod_float(self):
a = jt.random((10,)) a = jt.random((10,))
@ -137,7 +150,8 @@ class TestBinaryOp(unittest.TestCase):
a = jt.random((10,), 'float64') a = jt.random((10,), 'float64')
b = jt.random((10,), 'float64') b = jt.random((10,), 'float64')
c = a % b c = a % b
assert np.allclose(c.data, a.data % b.data) assert np.allclose(c.data, a.data % b.data, a.data, b.data)
if jt.flags.amp_reg & 2: return
a = jt.random((10,)) * 1000 a = jt.random((10,)) * 1000
b = (jt.random((10,)) * 10).int() + 1 b = (jt.random((10,)) * 10).int() + 1
c = a % b c = a % b
@ -169,5 +183,19 @@ class TestBinaryOp(unittest.TestCase):
class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)): class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)):
pass pass
class TestBinaryOpCpuFp16(TestBinaryOp):
def setUp(self):
jt.flags.amp_reg = 2 | 4 | 8 | 16
def tearDown(self):
jt.flags.amp_reg = 0
class TestBinaryOpCudaFp16(TestBinaryOp):
def setUp(self):
jt.flags.amp_reg = 2 | 4 | 8 | 16
jt.flags.use_cuda = 1
def tearDown(self):
jt.flags.amp_reg = 0
jt.flags.use_cuda = 0
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -0,0 +1,342 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import os
def transpose0231(x):
s0, s1, s2, s3 = x.shape
asize = 16
bsize = 16
ILP = 2
return jt.code([s0, s2, s3, s1], x.dtype, [x],
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
cuda_src=f"""
__global__ void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{
__shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}];
int t3 = threadIdx.x % {bsize};
int t1 = threadIdx.x / {bsize};
int b3 = blockIdx.x;
int b2 = blockIdx.y;
int b0 = blockIdx.z;
int x3 = 1;
int x2 = s3;
int x1 = s2*x2;
int x0 = s1*x1;
int y3 = 1;
int y2 = s1;
int y1 = s3*y2;
int y0 = s2*y1;
in0_type tmp[{ILP}];
for (int i=0; i<(s1-1)/{asize*ILP}+1; i++)
{{
int _b3 = b3 * {bsize*ILP} + t3*{ILP};
if (_b3 < s3) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
vload<sizeof(in0_type)*{ILP}>(
tmp,
&x[b0*x0+(t1*{ILP}+j+i*{asize*ILP})*x1+b2*x2+_b3*x3]
);
#pragma unroll
for (int k=0; k<{ILP}; k++)
t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k];
}}
}}
__syncthreads();
int t3_ = threadIdx.x % {asize};
int t1_ = threadIdx.x / {asize};
_b3 = b3 * {bsize*ILP} + t1_*{ILP};
if (_b3 < s3) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
#pragma unroll
for (int k=0; k<{ILP}; k++) {{
tmp[k] =
t[(t3*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j];
}}
vload<sizeof(in0_type)*{ILP}>(
&y[b0*y0+b2*y1+(_b3+j)*y2+((t3*{ILP})+i*{asize*ILP})*y3],
tmp
);
}}
}}
__syncthreads();
}}
}}
int s0, s1, s2, s3;
in0->shape.unpack(s0, s1, s2, s3);
kernel<<<{{(s3-1)/{bsize*ILP}+1, s2, s0 }}, {bsize*asize}>>>
(in0_p, out0_p, s0, s1, s2, s3);
""")
def transpose0231_2(x):
s0, s1, s2, s3 = x.shape
asize = 16
bsize = 8
ILP = 2
return jt.code([s0, s2, s3, s1], x.dtype, [x],
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
cuda_src=f"""
__global__ __launch_bounds__({asize*bsize}) void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{
__shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}];
int t3 = threadIdx.x % {bsize};
int t1 = threadIdx.x / {bsize};
int b3 = blockIdx.x;
int b1 = blockIdx.y;
int b2 = 0;
int b0 = blockIdx.z;
int x3 = 1;
int x2 = s3;
int x1 = s2*x2;
int x0 = s1*x1;
int y3 = 1;
int y2 = s1;
int y1 = s3*y2;
int y0 = s2*y1;
in0_type tmp[{ILP}];
{{
int _b3 = b3 * {bsize*ILP} + t3*{ILP};
if (_b3 < s3) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
if (t1*{ILP}+j+b1*{asize*ILP} >= s1)
continue;
vload<sizeof(in0_type)*{ILP}>(
tmp,
&x[b0*x0+(t1*{ILP}+j+b1*{asize*ILP})*x1+b2*x2+_b3*x3]
);
#pragma unroll
for (int k=0; k<{ILP}; k++)
t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k];
}}
}}
__syncthreads();
int t3_ = threadIdx.x % {asize};
int t1_ = threadIdx.x / {asize};
_b3 = b3 * {bsize*ILP} + t1_*{ILP};
int yy3 = (t3_*{ILP})+b1*{asize*ILP};
if (_b3 < s3 && yy3 < s1) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
#pragma unroll
for (int k=0; k<{ILP}; k++) {{
tmp[k] =
t[(t3_*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j];
}}
vload<sizeof(in0_type)*{ILP}>(
&y[b0*y0+b2*y1+(_b3+j)*y2+yy3*y3],
tmp
);
// printf("%d %d %d %d %d\\n", b0*y0+b2*y1+(_b3+j)*y2+yy3*y3,
// b0, b2, (_b3+j), yy3);
}}
}}
__syncthreads();
}}
}}
int s0, s1, s2, s3;
in0->shape.unpack(s0, s1, s2, s3);
kernel<<<{{(s3-1)/{bsize*ILP}+1, (s1-1)/{asize*ILP}+1, s0 }}, {bsize*asize}>>>
(in0_p, out0_p, s0, s1, s2, s3);
""")
def check_share():
return
a = jt.rand((30, 32, 4, 2000)).float32()
jt.code(a.shape, a.dtype, [a],
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
cuda_src="""
__global__ void kernel(in0_type* __restrict__ a, in0_type* __restrict__ b) {
__shared__ float x[32*33];
for (int i=0; i<3; i++) {
((float2*)&x[i])[0] = ((float2*)&a[i])[0];
((float2*)&b[i])[0] = ((float2*)&x[i+1])[0];
}
}
kernel<<<1024,16*16>>>(in0_p, out0_p);
LOGir << "aaa";
""").sync()
jt.sync_all(True)
# print(a[0]+1)
print("pass test")
class TestFP16(unittest.TestCase):
def test_array(self):
a = np.array([1,2,3], dtype="float16")
b = jt.array(a)
np.testing.assert_allclose(a, b.data)
def test_add(self):
a = np.array([1,2,3], dtype="float16")
b = jt.array(a)
c = b+b
np.testing.assert_allclose(c.data, a+a)
d = c.sum()
np.testing.assert_allclose(d.data, [12])
c = c+1
print(c)
def test_matmul(self):
a = jt.random((100,100)).float16()
b = jt.random((100,100)).float16()
c = jt.matmul(a, b)
c.sync()
def test_matmul_grad(self):
a = jt.random((100,100)).float16()
b = jt.random((100,100)).float16()
c = jt.matmul(a, b)
c.sync()
da, db = jt.grad(c, [a,b])
jt.sync_all()
assert da.dtype == "float16"
assert db.dtype == "float16"
def test_array_random_auto_cast(self):
a = jt.array([1.0,2.0])
assert a.dtype == "float32"
with jt.flag_scope(amp_reg=2+16):
a = jt.array([1.0,2.0])
assert a.dtype == "float16", a.dtype
a = jt.random([10])
assert a.dtype == "float32"
with jt.flag_scope(amp_reg=2+16):
a = jt.random([10])
assert a.dtype == "float16", a.dtype
def test_conv(self):
a = jt.random((3,4,5,5)).float16()
b = jt.random((4,4,3,3)).float16()
c = jt.nn.conv(a, b)
c.sync()
def test_max(self):
a = jt.random((100,)).float16()
b = jt.random((100,)).float16()
c = a.maximum(b)
c.sync()
def test_reduce_dtype_infer(self):
with jt.flag_scope(amp_reg=1):
a = jt.random((3,4,5,5)).float16()
b = a.sum()
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2):
a = jt.random((3,4,5,5)).float16()
b = a.sum()
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=0):
a = jt.random((3,4,5,5)).float16()
b = a.sum()
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2+4):
a = jt.random((3,4,5,5)).float16()
b = a.sum()
b.sync()
assert b.dtype == "float16", b.dtype
def test_white_dtype_infer(self):
with jt.flag_scope(amp_reg=1):
a = jt.random((3,4,5,5)).float16()
b = a**a
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2):
a = jt.random((3,4,5,5)).float16()
b = a**a
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=0):
a = jt.random((3,4,5,5)).float16()
b = a**a
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2+8):
a = jt.random((3,4,5,5)).float16()
b = a**a
b.sync()
assert b.dtype == "float16", b.dtype
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestFP16CUDA(TestFP16):
def setUp(self):
jt.flags.use_cuda = 1
def tearDown(self):
jt.flags.use_cuda = 0
def test_softmax(self):
a = jt.rand((120, 2000, 2000)).float16()
# a = jt.rand((1, 2000, 2000)).float32()
jt.sync_all()
with jt.profile_scope(10, 100):
a.log_softmax(-1).sync()
def test_transpose(self):
check_share()
# return
a = jt.rand((30, 32, 4, 2000)).float32()
# a = jt.rand((1, 1024, 1, 2000)).float32()
diff = transpose0231(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
# return
jt.sync_all()
# with jt.profile_scope(100, 11000):
with jt.profile_scope(100, 11000):
# a.log_softmax(-1).sync()
transpose0231(a).sync()
a.transpose((0,2,3,1)).sync()
# a.transpose((0,2,1,3)).sync()
a.fuse_transpose((0,2,1,3)).sync()
(a+1).sync()
jt.sync_all(True)
diff = transpose0231(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
np.testing.assert_allclose(transpose0231(a).data, a.transpose((0,2,3,1)).data)
def test_transpose2(self):
# check_share()
# return
# a = jt.rand((30, 32, 4, 2000)).float32()
# a = jt.rand((1, 10000, 1, 2000)).float32()
a = jt.rand((1, 10000, 1, 2048)).float32()
print("transpose")
transpose0231_2(a).sync()
print("add")
(a+1).sync()
return
# a = jt.arange(32*16).reshape((1, 32, 1, 16))
diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
# return
jt.sync_all()
# with jt.profile_scope(100, 11000):
with jt.profile_scope(100, 1100):
# a.log_softmax(-1).sync()
transpose0231_2(a).sync()
a.transpose((0,2,3,1)).sync()
# a.transpose((0,2,1,3)).sync()
a.fuse_transpose((0,2,1,3)).sync()
(a+1).sync()
jt.sync_all(True)
diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
np.testing.assert_allclose(transpose0231_2(a).data, a.transpose((0,2,3,1)).data)
if __name__ == "__main__":
unittest.main()
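Taken together, test_array_random_auto_cast, test_reduce_dtype_infer and test_white_dtype_infer above suggest the following reading of the amp_reg bits; this is inferred from the tests, not an official table, and the combined usage below is a sketch under that assumption:
import jittor as jt
# assumed meaning of the amp_reg bits, based on the tests above:
#   1  -> prefer float32 results        2  -> prefer float16 results
#   4  -> reductions may stay float16   8  -> white-listed ops (e.g. pow) may stay float16
#   16 -> also apply the preference to array()/random() outputs
with jt.flag_scope(amp_reg=2 + 4 + 8 + 16):
    x = jt.random((3, 4))               # created directly as float16 (bits 2 + 16)
    assert x.dtype == "float16"
    assert x.sum().dtype == "float16"   # bit 4: reduction keeps fp16
    assert (x ** x).dtype == "float16"  # bit 8: white-listed op keeps fp16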

View File

@ -75,6 +75,8 @@ class TestPad(unittest.TestCase):
print('pass flip test ...') print('pass flip test ...')
def test_cross(self): def test_cross(self):
def check_equal(a, b, tol):
np.testing.assert_allclose(a.detach().numpy(), b.numpy(), atol=1e-5)
arr1 = np.random.randn(16,3,224,224,3) arr1 = np.random.randn(16,3,224,224,3)
arr2 = np.random.randn(16,3,224,224,3) arr2 = np.random.randn(16,3,224,224,3)
check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1) check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1)
@ -257,34 +259,51 @@ class TestOther(unittest.TestCase):
a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1])) a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1]))
np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927]) np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927])
y = jt.random((100,))
x = jt.random((100,))
z = jt.arctan2(y, x)
z2 = np.arctan2(y.data, x.data)
np.testing.assert_allclose(z.data, z2)
def test_code_softmax(self): def test_code_softmax(self):
if not jt.has_cuda: return if not jt.has_cuda: return
def softmax(x, dim = None): def softmax(x, dim = None, log=False):
if dim is None: if dim is None:
x = (x - x.max()).exp() x = (x - x.max()).exp()
ret = x / x.sum() ret = x / x.sum()
else: else:
x = (x-x.max(dim, keepdims=True)).exp() x = (x-x.max(dim, keepdims=True)).exp()
ret = x / x.sum(dim, keepdims=True) ret = x / x.sum(dim, keepdims=True)
if log: return ret.log()
return ret return ret
from jittor.other.code_softmax import softmax_v1 from jittor.other.code_softmax import softmax_v1
with jt.flag_scope(use_cuda = 1): with jt.flag_scope(use_cuda = 1):
shape = (120, 2000, 2000) shape = (120, 2000, 2000)
# shape = (3,3) shape = (3,3)
a = jt.rand(shape) for log in [0,1]:
c = jt.rand(shape) for shape in [(3,3),
b = softmax(a, -1) (12, 200, 2000),
bb = softmax_v1(a) (12, 200, 2048),
(12, 200, 2049)]:
print(shape)
a = jt.rand(shape)
c = jt.rand(shape)
b = softmax(a, -1, log=log)
bb = softmax_v1(a, log=log)
err = (bb - b).abs().max() err = (bb - b).abs().max()
assert err.item() < 1e-5 assert err.item() < 1e-5, (err, bb, b)
d1 = jt.grad(b*c, a) d1 = jt.grad(b*c, a)
d2 = jt.grad(bb*c, a) d2 = jt.grad(bb*c, a)
err = (d1 - d2).abs().max() err = (d1 - d2).abs().max()
assert err.item() < 1e-5
if log:
assert err.item() < 1e-2, (err.item())
else:
assert err.item() < 1e-5, (err.item())
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -36,19 +36,7 @@ class MnistNet(Module):
return x return x
@unittest.skipIf(skip_this_test, "skip_this_test") @unittest.skipIf(skip_this_test, "skip_this_test")
class TestResnet(unittest.TestCase): class TestResnetFp32(unittest.TestCase):
@classmethod
def setUpClass(self):
# hyper-parameters
self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
self.weight_decay = 0.0001
self.momentum = 0.9
self.learning_rate = 0.1
# mnist dataset
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
.set_attrs(batch_size=self.batch_size, shuffle=True)
self.train_loader.num_workers = 4
# setup random seed # setup random seed
def setup_seed(self, seed): def setup_seed(self, seed):
np.random.seed(seed) np.random.seed(seed)
@ -59,6 +47,19 @@ class TestResnet(unittest.TestCase):
@jt.flag_scope(use_cuda=1, use_stat_allocator=1) @jt.flag_scope(use_cuda=1, use_stat_allocator=1)
def test_resnet(self): def test_resnet(self):
self.setup_seed(1) self.setup_seed(1)
# hyper-parameters
self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
self.weight_decay = 0.0001
self.momentum = 0.9
self.learning_rate = 0.1
if jt.flags.amp_reg:
self.learning_rate = 0.01
# mnist dataset
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
.set_attrs(batch_size=self.batch_size, shuffle=True)
self.train_loader.num_workers = 4
loss_list=[] loss_list=[]
acc_list=[] acc_list=[]
mnist_net = MnistNet() mnist_net = MnistNet()
@ -70,6 +71,7 @@ class TestResnet(unittest.TestCase):
for data, target in self.train_loader: for data, target in self.train_loader:
batch_id = self.train_loader.batch_id batch_id = self.train_loader.batch_id
epoch_id = self.train_loader.epoch_id epoch_id = self.train_loader.epoch_id
data = data.float_auto()
# train step # train step
# with jt.log_capture_scope( # with jt.log_capture_scope(
@ -120,6 +122,8 @@ class TestResnet(unittest.TestCase):
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000 # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000 # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
if jt.flags.amp_reg:
continue
if jt.in_mpi: if jt.in_mpi:
assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars() assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
else: else:
@ -131,5 +135,14 @@ class TestResnet(unittest.TestCase):
assert np.mean(loss_list[-50:])<0.5 assert np.mean(loss_list[-50:])<0.5
assert np.mean(acc_list[-50:])>0.8 assert np.mean(acc_list[-50:])>0.8
@unittest.skipIf(skip_this_test, "skip_this_test")
class TestResnetFp16(TestResnetFp32):
def setup(self):
jt.flags.auto_mixed_precision_level = 5
def tearDown(self):
jt.flags.auto_mixed_precision_level = 0
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -0,0 +1,121 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import os
from jittor.test.misc import superglue
from jittor.test.misc.superglue import SuperGlue
import time
@jt.flag_scope(use_cuda=1)
def main():
global superglue
superglue.split_size = int(os.environ.get("split_size", "12"))
# superglue.split_size = 1000000
batch = 30
num = 2000
dim = 128
# jt.display_memory_info()
# os.system("nvidia-smi")
# breakpoint()
with jt.no_grad():
config = {
'superglue': {
'sinkhorn_iterations': 25,
'match_threshold': 0.01,
'keypoint_position_dim': 2,
'descriptor_dim': dim,
'use_dual_softmax': True,
'GNN_layers': ['self', 'cross'] * 9,
}
}
superglue = SuperGlue(config.get('superglue', {}))
superglue.eval()
data = {
'keypoints0': jt.rand((batch, num, 2), dtype=jt.float),
'keypoints1': jt.rand((batch, num, 2), dtype=jt.float),
'shape0': jt.rand((batch, 2), dtype=jt.float),
'shape1': jt.rand((batch, 2), dtype=jt.float),
'descriptors0': jt.rand((batch, dim, num), dtype=jt.float),
'descriptors1': jt.rand((batch, dim, num), dtype=jt.float),
'scores0': jt.rand((batch, num), dtype=jt.float),
'scores1': jt.rand((batch, num), dtype=jt.float),
'all_matches': jt.randint(0, num, (batch, num, 2), dtype=jt.int),
'return_match': False,
# 'match_num': match_num
}
use_fp16 = int(os.environ.get("use_fp16", "0"))
if use_fp16:
jt.flags.amp_reg = 2
for k,v in data.items():
if isinstance(v, jt.Var) and v.dtype == "float32":
v.assign(v.float16())
for v in superglue.parameters():
if v.dtype == "float32":
v.assign(v.float16())
jt.sync_all(True)
import pickle
jt.sync_all(True)
for x in range(5):
print(x)
jt.gc()
x = superglue(data)['loss']
x.sync()
jt.display_memory_info()
# os.system("nvidia-smi")
# breakpoint()
# print(data)
# print(x)
# with open("/tmp/record.pkl", "wb") as f:
# pickle.dump([data, x], f, pickle.HIGHEST_PROTOCOL)
# with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
# x = superglue(data)['loss']
# x.sync()
# jt.get_max_memory_treemap()
# exit(0)
jt.sync_all(True)
time0 = time.time()
jt.flags.profiler_enable = int(os.environ.get("profiler", "0"))
for x in range(20):
print(x)
# jt.display_memory_info()
x = superglue(data)['loss']
x.sync()
# print(x)
jt.sync_all(True)
time1 = time.time()
print("avg time:", (time1 - time0) / 20)
return (time1 - time0) / 20
class TestSuperglue(unittest.TestCase):
def test(self):
if not jt.has_cuda: return
t1 = main()
os.environ["use_fp16"] = "1"
t2 = main()
os.environ["use_fp16"] = "0"
assert t1*0.55 > t2
if __name__ == "__main__":
unittest.main()

View File

@ -17,7 +17,8 @@ def check(op, *args):
x = convert(x) x = convert(x)
y = convert(y) y = convert(y)
# str match nan and inf # str match nan and inf
assert x.dtype == y.dtype and x.shape == y.shape assert x.dtype == y.dtype and x.shape == y.shape, \
(x.dtype, y.dtype, x.shape, y.shape)
for a,b in zip(x.flatten(), y.flatten()): for a,b in zip(x.flatten(), y.flatten()):
assert str(a)[:5] == str(b)[:5], (a,b) assert str(a)[:5] == str(b)[:5], (a,b)
@ -32,9 +33,10 @@ class TestUnaryOp(unittest.TestCase):
check("logical_not", a) check("logical_not", a)
check("bitwise_not", a) check("bitwise_not", a)
b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0]) b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
check("log", a.astype("float32")) type = "float16" if (jt.flags.amp_reg & 2) else "float32"
check("exp", a.astype("float32")) check("log", a.astype(type))
check("sqrt", a.astype("float32")) check("exp", a.astype(type))
check("sqrt", a.astype(type))
def test_grad(self): def test_grad(self):
ops = ["abs", "negative", "log", "exp", "sqrt", ops = ["abs", "negative", "log", "exp", "sqrt",
@ -60,7 +62,8 @@ class TestUnaryOp(unittest.TestCase):
ja = jt.array(b) ja = jt.array(b)
jb = eval(f"jt.{op}(ja)") jb = eval(f"jt.{op}(ja)")
jda = jt.grad(jb, ja) jda = jt.grad(jb, ja)
assert (np.allclose(jda.data, da)), (jda.data,da,op) tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-6
assert (np.allclose(jda.data, da, atol=tol, rtol=tol)), (jda.data,da,op)
def test_sigmoid(self): def test_sigmoid(self):
a = np.arange(-150,150, 10).astype("float32") a = np.arange(-150,150, 10).astype("float32")
@ -92,11 +95,26 @@ class TestUnaryOp(unittest.TestCase):
np.testing.assert_allclose(y.data, y2.data) np.testing.assert_allclose(y.data, y2.data)
d = jt.grad(x2, y2) d = jt.grad(x2, y2)
_, (dn,) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8) _, (dn,) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8)
np.testing.assert_allclose(d.data, dn, atol=1e-6, rtol=1e-6) tol = 1e-3 if jt.flags.amp_reg & 2 else 1e-6
np.testing.assert_allclose(d.data, dn, atol=tol, rtol=tol)
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)): class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
pass pass
class TestUnaryOpCpuFp16(TestUnaryOp):
def setUp(self):
jt.flags.amp_reg = 2 | 4 | 8 | 16
def tearDown(self):
jt.flags.amp_reg = 0
class TestUnaryOpCudaFp16(TestUnaryOp, test_cuda(2)):
def setUp(self):
jt.flags.amp_reg = 2 | 4 | 8 | 16
jt.flags.use_cuda = 1
def tearDown(self):
jt.flags.amp_reg = 0
jt.flags.use_cuda = 0
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

Binary file not shown.