mirror of https://github.com/Jittor/Jittor
add fp16 support
This commit is contained in:
parent
7cf6165a10
commit
39ecdd84fd
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.38'
|
||||
__version__ = '1.3.1.38.1'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -304,15 +304,52 @@ Var.cast = Var.cast
|
|||
def array(data, dtype=None):
|
||||
if isinstance(data, core.Var):
|
||||
if dtype is None:
|
||||
return data.clone()
|
||||
return cast(data, dtype)
|
||||
if dtype is not None:
|
||||
ret = data.clone()
|
||||
else:
|
||||
ret = cast(data, dtype)
|
||||
elif dtype is not None:
|
||||
if isinstance(dtype, NanoString):
|
||||
dtype = str(dtype)
|
||||
elif callable(dtype):
|
||||
dtype = dtype.__name__
|
||||
return ops.array(np.array(data, dtype))
|
||||
return ops.array(data)
|
||||
ret = ops.array(np.array(data, dtype))
|
||||
else:
|
||||
ret = ops.array(data)
|
||||
# TODO: move those code to core
|
||||
amp_reg = jt.flags.amp_reg
|
||||
if amp_reg and ret.numel() != 1 and ret.dtype.is_float():
|
||||
if amp_reg & 16:
|
||||
if amp_reg & 1:
|
||||
if ret.dtype != "float32":
|
||||
return ret.float32()
|
||||
elif amp_reg & 2:
|
||||
if ret.dtype != "float16":
|
||||
return ret.float16()
|
||||
return ret
|
||||
|
||||
def random(shape, dtype="float32", type="uniform"):
|
||||
# TODO: move those code to core
|
||||
if dtype == "float16":
|
||||
# TODO: make curand support fp16
|
||||
ret = ops.random(shape, "float32", type).float16()
|
||||
else:
|
||||
ret = ops.random(shape, dtype, type)
|
||||
amp_reg = jt.flags.amp_reg
|
||||
if amp_reg:
|
||||
if amp_reg & 16:
|
||||
if amp_reg & 1:
|
||||
if ret.dtype != "float32":
|
||||
return ret.float32()
|
||||
elif amp_reg & 2:
|
||||
if ret.dtype != "float16":
|
||||
return ret.float16()
|
||||
return ret
|
||||
|
||||
def float_auto(x):
|
||||
if jt.flags.amp_reg & 2:
|
||||
return x.float16()
|
||||
return x.float32()
|
||||
Var.float_auto = float_auto
|
||||
|
||||
def array64(data, dtype=None):
|
||||
with jt.flag_scope(auto_convert_64_to_32=0):
|
||||
|
@ -1419,18 +1456,15 @@ Var.size = size
|
|||
|
||||
|
||||
def to_int(v):
|
||||
dtype = str(v.dtype)
|
||||
assert dtype.startswith("int")
|
||||
assert v.dtype.is_int()
|
||||
return v.item()
|
||||
|
||||
def to_float(v):
|
||||
dtype = str(v.dtype)
|
||||
assert dtype.startswith("float")
|
||||
assert v.dtype.is_float()
|
||||
return v.item()
|
||||
|
||||
def to_bool(v):
|
||||
dtype = str(v.dtype)
|
||||
assert dtype.startswith("int") or dtype=="bool"
|
||||
assert v.dtype.is_int() or v.dtype.is_bool()
|
||||
return ori_bool(v.item())
|
||||
|
||||
Var.__int__ = to_int
|
||||
|
|
|
@ -210,6 +210,12 @@ def setup_cuda_extern():
|
|||
LOG.w(f"CUDA found but cub is not loaded:\n{line}")
|
||||
|
||||
libs = ["cublas", "cudnn", "curand"]
|
||||
# in cuda 11.4, module memory comsumptions:
|
||||
# default context: 259 MB
|
||||
# cublas: 340 MB
|
||||
# cudnn: 340 MB
|
||||
if int(os.environ.get("conv_opt", "0")):
|
||||
libs = ["cublas", "curand"]
|
||||
for lib_name in libs:
|
||||
try:
|
||||
setup_cuda_lib(lib_name, extra_flags=link_cuda_extern)
|
||||
|
@ -309,22 +315,27 @@ def install_cutt(root_folder):
|
|||
if md5 != true_md5:
|
||||
os.remove(fullname)
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)):
|
||||
LOG.i("Downloading cutt...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
CUTT_PATH = os.environ.get("CUTT_PATH", "")
|
||||
if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)) or CUTT_PATH:
|
||||
if CUTT_PATH:
|
||||
dirname = CUTT_PATH
|
||||
else:
|
||||
LOG.i("Downloading cutt...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
||||
import zipfile
|
||||
import zipfile
|
||||
|
||||
zf = zipfile.ZipFile(fullname)
|
||||
try:
|
||||
zf.extractall(path=root_folder)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
raise
|
||||
zf.close()
|
||||
zf = zipfile.ZipFile(fullname)
|
||||
try:
|
||||
zf.extractall(path=root_folder)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
raise
|
||||
zf.close()
|
||||
|
||||
LOG.i("installing cutt...")
|
||||
arch_flag = ""
|
||||
# -Xptxas -dlcm=ca actually not work
|
||||
arch_flag = " -Xptxas -dlcm=ca "
|
||||
if len(flags.cuda_archs):
|
||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
|
|
|
@ -23,8 +23,8 @@ EXTERN_LIB cublasHandle_t cublas_handle;
|
|||
|
||||
static inline cudaDataType get_dtype(NanoString dtype) {
|
||||
if (dtype == ns_float32) return CUDA_R_32F;
|
||||
// if (dtype == ns_float64) return CUDA_R_64F;
|
||||
// if (dtype == ns_float16) return CUDA_R_16F;
|
||||
if (dtype == ns_float64) return CUDA_R_64F;
|
||||
if (dtype == ns_float16) return CUDA_R_16F;
|
||||
LOGf << "not support type" << dtype;
|
||||
return CUDA_R_32F;
|
||||
}
|
||||
|
|
|
@ -124,6 +124,10 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
if (use_tensorcore) {
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
}
|
||||
if (a->dtype() == ns_float16
|
||||
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
|
|
|
@ -81,6 +81,10 @@ void CublasMatmulOp::jit_run() {
|
|||
if (use_tensorcore) {
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
}
|
||||
if (a->dtype() == ns_float16
|
||||
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
checkCudaErrors(cublasGemmEx(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
|
|
|
@ -175,6 +175,11 @@ void CudnnConvOp::jit_run() {
|
|||
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
|
||||
}
|
||||
|
||||
if (x->dtype() == ns_float16
|
||||
|| y->dtype() == ns_float16 || w->dtype() == ns_float16) {
|
||||
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
|
||||
}
|
||||
|
||||
int dimY[] = {
|
||||
(int)y->shape[findc("@YFORMAT", 'a')], // n
|
||||
(int)y->shape[findc("@YFORMAT", 'b')], // c
|
||||
|
|
|
@ -488,18 +488,11 @@ def arctan2(y,x):
|
|||
angle = jt.zeros(x.shape,dtype=x.dtype)
|
||||
x = (x!=0.0).ternary(x, x+1e-30)
|
||||
angle = (y/x).arctan()
|
||||
|
||||
mask = (y<0) & (x<0)
|
||||
if angle[mask].numel()>0:
|
||||
angle[mask] -= np.pi
|
||||
|
||||
mask = (y>=0) &(x<0)
|
||||
if angle[mask].numel()>0:
|
||||
angle[mask] +=np.pi
|
||||
mask = y<0 | ((y==0) & (x<0))
|
||||
angle = angle + mask*np.pi
|
||||
return angle
|
||||
|
||||
|
||||
|
||||
def nonzero(x):
|
||||
r'''
|
||||
Return the index of the elements of input tensor which are not equal to zero.
|
||||
|
|
|
@ -143,7 +143,7 @@ class ResNet(nn.Module):
|
|||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.avgpool(x)
|
||||
x = self.avgpool(x).float_auto()
|
||||
x = jt.reshape(x, (x.shape[0], -1))
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
|
|
@ -37,9 +37,10 @@ def matmul_transpose(a, b):
|
|||
assert len(a.shape) == 2 and len(b.shape) == 2
|
||||
|
||||
shape = list(a.shape)[:-1] + list(b.shape)
|
||||
a = a.broadcast(shape, [len(shape)-2])
|
||||
b = b.broadcast(shape)
|
||||
return (a*b).sum(len(shape)-1)
|
||||
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
||||
a = a.broadcast(shape, [len(shape)-2])
|
||||
b = b.broadcast(shape)
|
||||
return (a*b).sum(len(shape)-1)
|
||||
|
||||
|
||||
def bmm_transpose(a, b):
|
||||
|
@ -108,47 +109,48 @@ Example::
|
|||
c = jt.matmul(a, b)
|
||||
assert c.shape == [8, 10, 3, 5]
|
||||
'''
|
||||
len_a = len(a.shape)
|
||||
len_b = len(b.shape)
|
||||
if len_b == 1:
|
||||
# a: [n, m], b:[m], c:[n]
|
||||
return (a*b).sum(-1)
|
||||
if len_a == 1:
|
||||
# a: [n], b:[n,k], c:[k]
|
||||
return (a.broadcast(b, [-1]) * b).sum(0)
|
||||
if len_a>=3 and len_a==len_b:
|
||||
# bmm
|
||||
# a: [..., n, m], b: [..., m, k], c:[..., n, k]
|
||||
if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
|
||||
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
||||
shape = []
|
||||
len_c = max(len_a, len_b)
|
||||
(n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
|
||||
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
||||
# a: [..., n, m]
|
||||
# b: [..., m, k]
|
||||
# cc:[..., n, m, k]
|
||||
# -->
|
||||
# 012
|
||||
if len_b == 2 and len_a>2:
|
||||
# TODO:ugly implementation for tuner
|
||||
aa = a.reshape((-1, m))
|
||||
cc = matmul(aa, b)
|
||||
# print(a.shape, b.shape, cc.shape)
|
||||
return cc.reshape(a.shape[:-1] + [k])
|
||||
for i in range(len_c-2):
|
||||
ai = len_a-(len_c-i)
|
||||
bi = len_b-(len_c-i)
|
||||
an = a.shape[ai] if ai>=0 else 1
|
||||
bn = b.shape[bi] if bi>=0 else 1
|
||||
if an!=1 and bn!=1:
|
||||
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
||||
cn = max(an, bn)
|
||||
shape.append(cn)
|
||||
shape.extend([n, m, k])
|
||||
a = a.broadcast(shape, [-1])
|
||||
b = b.broadcast(shape, [-3])
|
||||
return (a*b).sum(-2)
|
||||
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
||||
len_a = len(a.shape)
|
||||
len_b = len(b.shape)
|
||||
if len_b == 1:
|
||||
# a: [n, m], b:[m], c:[n]
|
||||
return (a*b).sum(-1)
|
||||
if len_a == 1:
|
||||
# a: [n], b:[n,k], c:[k]
|
||||
return (a.broadcast(b, [-1]) * b).sum(0)
|
||||
if len_a>=3 and len_a==len_b:
|
||||
# bmm
|
||||
# a: [..., n, m], b: [..., m, k], c:[..., n, k]
|
||||
if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
|
||||
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
||||
shape = []
|
||||
len_c = max(len_a, len_b)
|
||||
(n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
|
||||
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
||||
# a: [..., n, m]
|
||||
# b: [..., m, k]
|
||||
# cc:[..., n, m, k]
|
||||
# -->
|
||||
# 012
|
||||
if len_b == 2 and len_a>2:
|
||||
# TODO:ugly implementation for tuner
|
||||
aa = a.reshape((-1, m))
|
||||
cc = matmul(aa, b)
|
||||
# print(a.shape, b.shape, cc.shape)
|
||||
return cc.reshape(a.shape[:-1] + [k])
|
||||
for i in range(len_c-2):
|
||||
ai = len_a-(len_c-i)
|
||||
bi = len_b-(len_c-i)
|
||||
an = a.shape[ai] if ai>=0 else 1
|
||||
bn = b.shape[bi] if bi>=0 else 1
|
||||
if an!=1 and bn!=1:
|
||||
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
||||
cn = max(an, bn)
|
||||
shape.append(cn)
|
||||
shape.extend([n, m, k])
|
||||
a = a.broadcast(shape, [-1])
|
||||
b = b.broadcast(shape, [-3])
|
||||
return (a*b).sum(-2)
|
||||
jt.Var.matmul = jt.Var.__matmul__ = matmul
|
||||
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
|
||||
|
||||
|
@ -488,22 +490,22 @@ class BCEWithLogitsLoss(Module):
|
|||
def execute(self, output, target):
|
||||
return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average)
|
||||
|
||||
def softmax(x, dim = None):
|
||||
def softmax(x, dim=None, log=False):
|
||||
import jittor.other.code_softmax as code_softmax
|
||||
if code_softmax.can_softmax_v1(x, dim):
|
||||
return code_softmax.softmax_v1(x)
|
||||
return code_softmax.softmax_v1(x, log)
|
||||
if dim is None:
|
||||
x = (x - x.max()).exp()
|
||||
ret = x / x.sum()
|
||||
else:
|
||||
x = (x-x.max(dim, keepdims=True)).exp()
|
||||
ret = x / x.sum(dim, keepdims=True)
|
||||
if log: return ret.log()
|
||||
return ret
|
||||
jt.Var.softmax = softmax
|
||||
|
||||
def log_softmax(x,dim=None):
|
||||
x = softmax(x,dim=dim)
|
||||
return jt.log(x)
|
||||
return softmax(x,dim=dim, log=True)
|
||||
jt.Var.log_softmax = log_softmax
|
||||
|
||||
def log_sigmoid(x):
|
||||
|
@ -832,15 +834,16 @@ class Conv(Module):
|
|||
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
||||
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
||||
assert oh>0 and ow>0
|
||||
xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
|
||||
'i0', # Nid
|
||||
'i2', # Cid
|
||||
f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
|
||||
f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
|
||||
])
|
||||
ww = self.weight.broadcast(xx.shape, [0,3,4])
|
||||
yy = xx*ww
|
||||
y = yy.sum([2,5,6]) # Kc, Kh, Kw
|
||||
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
||||
xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
|
||||
'i0', # Nid
|
||||
'i2', # Cid
|
||||
f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
|
||||
f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
|
||||
])
|
||||
ww = self.weight.broadcast(xx.shape, [0,3,4])
|
||||
yy = xx*ww
|
||||
y = yy.sum([2,5,6]) # Kc, Kh, Kw
|
||||
if self.bias is not None:
|
||||
b = self.bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
|
@ -1008,6 +1011,18 @@ class Conv3d(Module):
|
|||
def execute(self, x):
|
||||
return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class Conv1d_sp(Linear):
|
||||
def __init__(self, inchannels, outchannels, kernel_size=1, bias=True):
|
||||
super().__init__(inchannels, outchannels, bias=bias)
|
||||
assert kernel_size == 1
|
||||
|
||||
def execute(self, x):
|
||||
x = x.transpose(0, 2, 1)
|
||||
x = super().execute(x)
|
||||
x = x.transpose(0, 2, 1)
|
||||
return x
|
||||
|
||||
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
''' Applies a 2D convolution over an input signal composed of several input planes.
|
||||
|
||||
|
@ -1048,15 +1063,16 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|||
Kh, Kw = weight.shape[-2:]
|
||||
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||
xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
|
||||
'i0', # Nid
|
||||
'i2', # Cid
|
||||
f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
|
||||
f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
|
||||
])
|
||||
ww = weight.broadcast(xx.shape, [0,3,4])
|
||||
yy = xx*ww
|
||||
y = yy.sum([2,5,6]) # Kc, Kh, Kw
|
||||
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
||||
xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
|
||||
'i0', # Nid
|
||||
'i2', # Cid
|
||||
f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
|
||||
f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
|
||||
])
|
||||
ww = weight.broadcast(xx.shape, [0,3,4])
|
||||
yy = xx*ww
|
||||
y = yy.sum([2,5,6]) # Kc, Kh, Kw
|
||||
if bias is not None:
|
||||
b = bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
|
|
|
@ -10,32 +10,48 @@ def can_softmax_v1(a, dim):
|
|||
return False
|
||||
return True
|
||||
|
||||
def softmax_v1(a):
|
||||
def softmax_v1(a, log=False):
|
||||
assert can_softmax_v1(a, -1)
|
||||
length = a.shape[-1]
|
||||
# tnum = 1024
|
||||
tnum = 500 if length % 500 == 0 else 512
|
||||
tnum = 125 if length % 125 == 0 else 128
|
||||
# tnum = 125
|
||||
# tnum = 1000 if length % 1000 == 0 else 1024
|
||||
# tnum = 250
|
||||
per_thread = (length-1) // tnum + 1
|
||||
ILP = 1
|
||||
for ilp in [8,4,2]:
|
||||
if length % tnum == 0 and per_thread % ilp == 0:
|
||||
ILP = ilp
|
||||
per_thread //= ILP
|
||||
break
|
||||
for_loop = f"""
|
||||
#pragma unroll
|
||||
for (int i=0; i<{per_thread}; i++)
|
||||
"""
|
||||
if length % tnum == 0:
|
||||
for_loop += f"if (i*{tnum}+threadIdx.x < len)\n"
|
||||
if length % tnum != 0:
|
||||
for_loop += f"if ((i*{tnum}+threadIdx.x)*{ILP} < len)\n"
|
||||
|
||||
return jt.code(a.shape, a.dtype, [a], cuda_header=f'''
|
||||
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
|
||||
#include <type/fp16_compute.h>
|
||||
''', cuda_src=f'''
|
||||
__global__ void kernel(in0_type* x, out0_type* y, int len) {{
|
||||
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
int id = blockIdx.x * len;
|
||||
in0_type v[{per_thread}];
|
||||
{for_loop} v[i] = x[id+i*{tnum}+threadIdx.x];
|
||||
float v1 = v[0];
|
||||
{for_loop} v1 = max(v1, v[i]);
|
||||
in0_type v[{per_thread}][{ILP}];
|
||||
{for_loop}
|
||||
vload<sizeof(in0_type)*{ILP}>(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
|
||||
// v[i] = x[id+i*{tnum}+threadIdx.x];
|
||||
float v1 = -1e30;
|
||||
{for_loop}
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++) {{
|
||||
v1 = max(v1, float(v[i][j]));
|
||||
}}
|
||||
__shared__ float vmax;
|
||||
auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
|
||||
if (threadIdx.x == 0)
|
||||
|
@ -43,10 +59,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
|
|||
__syncthreads();
|
||||
|
||||
v1 = 0;
|
||||
{for_loop} {{
|
||||
v[i] = expf(v[i] - vmax);
|
||||
v1 += v[i];
|
||||
}}
|
||||
{for_loop}
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++) {{
|
||||
v[i][j] = expf(float(v[i][j]) - vmax);
|
||||
v1 += float(v[i][j]);
|
||||
}}
|
||||
|
||||
tmp = BlockReduce(temp_storage).Sum(v1);
|
||||
__shared__ float vsum;
|
||||
|
@ -54,7 +72,15 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
|
|||
vsum = tmp;
|
||||
__syncthreads();
|
||||
|
||||
{for_loop} y[id+i*{tnum}+threadIdx.x] = v[i] / vsum;
|
||||
{for_loop}
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++)
|
||||
v[i][j] = {
|
||||
"@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
|
||||
else "float(v[i][j])/vsum"
|
||||
};
|
||||
{for_loop}
|
||||
vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
|
||||
}}
|
||||
int len = in0->shape[in0->shape.size()-1];
|
||||
int bnum = in0->numel() / len;
|
||||
|
@ -64,15 +90,17 @@ CHECK(0 == cudaGetLastError());
|
|||
''', cuda_grad_src=[f"""
|
||||
__global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
|
||||
int id = blockIdx.x * len;
|
||||
in0_type vx[{per_thread}];
|
||||
in0_type vy[{per_thread}];
|
||||
in0_type vx[{per_thread}][{ILP}];
|
||||
in0_type vy[{per_thread}][{ILP}];
|
||||
{for_loop} {{
|
||||
vx[i] = x[id+i*{tnum}+threadIdx.x];
|
||||
vy[i] = y[id+i*{tnum}+threadIdx.x];
|
||||
vload<sizeof(in0_type)*{ILP}>(vx[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
|
||||
vload<sizeof(in0_type)*{ILP}>(vy[i], &y[id+(i*{tnum}+threadIdx.x)*{ILP}]);
|
||||
}}
|
||||
float v1 = 0;
|
||||
{for_loop} v1 += vx[i]*vy[i];
|
||||
|
||||
{for_loop}
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++)
|
||||
v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"}
|
||||
|
||||
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
@ -83,7 +111,16 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
|
|||
__syncthreads();
|
||||
|
||||
{for_loop}
|
||||
z[id+i*{tnum}+threadIdx.x] = vx[i] * (vy[i] - reduce_var);
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++)
|
||||
vx[i][j] = {
|
||||
"vy[i][j] - expf(vx[i][j]) * reduce_var;" if log
|
||||
else "vx[i][j] * (vy[i][j] - reduce_var);"
|
||||
}
|
||||
|
||||
{for_loop}
|
||||
vload<sizeof(in0_type)*{ILP}>(&z[id+(i*{tnum}+threadIdx.x)*{ILP}],
|
||||
vx[i]);
|
||||
}}
|
||||
int len = in0->shape[in0->shape.size()-1];
|
||||
int bnum = in0->numel() / len;
|
||||
|
|
|
@ -120,8 +120,8 @@ class Pool(Module):
|
|||
for (int i2 = p2; i2 < out_shape2; i2 += s2)
|
||||
{{ {forward_body} }}
|
||||
}}
|
||||
int tx = min(1024, out_shape3);
|
||||
int ty = min(1024 / tx, out_shape2);
|
||||
int tx = std::min(1024, out_shape3);
|
||||
int ty = std::min(1024 / tx, out_shape2);
|
||||
int bx = (out_shape2 - 1) / ty + 1;
|
||||
int by = out_shape1;
|
||||
int bz = out_shape0;
|
||||
|
@ -143,8 +143,8 @@ class Pool(Module):
|
|||
{{ {backward_body} }}
|
||||
}}
|
||||
cudaMemsetAsync(out_p, 0, out->size);
|
||||
int tx = min(1024, pout_shape3);
|
||||
int ty = min(1024 / tx, pout_shape2);
|
||||
int tx = std::min(1024, pout_shape3);
|
||||
int ty = std::min(1024 / tx, pout_shape2);
|
||||
int bx = (pout_shape2 - 1) / ty + 1;
|
||||
int by = pout_shape1;
|
||||
int bz = pout_shape0;
|
||||
|
@ -310,9 +310,9 @@ class Pool3d(Module):
|
|||
for (int i2 = p2; i2 < out_shape2; i2 += s2)
|
||||
{{ {forward_body} }}
|
||||
}}
|
||||
int tx = min(1024, out_shape4);
|
||||
int ty = min(1024 / tx, out_shape3);
|
||||
int tz = min(1024 / tx / ty, out_shape2);
|
||||
int tx = std::min(1024, out_shape4);
|
||||
int ty = std::min(1024 / tx, out_shape3);
|
||||
int tz = std::min(1024 / tx / ty, out_shape2);
|
||||
int bx = (out_shape2 - 1) / tz + 1;
|
||||
int by = out_shape1;
|
||||
int bz = out_shape0;
|
||||
|
@ -337,9 +337,9 @@ class Pool3d(Module):
|
|||
{{ {backward_body} }}
|
||||
}}
|
||||
cudaMemsetAsync(out_p, 0, out->size);
|
||||
int tx = min(1024, pout_shape4);
|
||||
int ty = min(1024 / tx, pout_shape3);
|
||||
int tz = min(1024 / tx / ty, pout_shape2);
|
||||
int tx = std::min(1024, pout_shape4);
|
||||
int ty = std::min(1024 / tx, pout_shape3);
|
||||
int tz = std::min(1024 / tx / ty, pout_shape2);
|
||||
int bx = (pout_shape2 - 1) / tz + 1;
|
||||
int by = pout_shape1;
|
||||
int bz = pout_shape0;
|
||||
|
|
|
@ -39,11 +39,24 @@ template<class T> struct StackIniter {
|
|||
#define STACK_ALLOC2(T, a, n) T a[n]
|
||||
#endif
|
||||
|
||||
struct AmpGradGuard {
|
||||
int amp_reg_bk;
|
||||
AmpGradGuard(Op* op) {
|
||||
amp_reg_bk = amp_reg;
|
||||
amp_reg |= (op->flags.flags >> NodeFlags::_prefer_32);
|
||||
}
|
||||
|
||||
~AmpGradGuard() {
|
||||
amp_reg = amp_reg_bk;
|
||||
}
|
||||
};
|
||||
|
||||
VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
|
||||
if (dout == nullptr) return nullptr;
|
||||
if (x_index<0) return nullptr;
|
||||
LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
|
||||
<< "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
|
||||
AmpGradGuard agg(op);
|
||||
auto dx = op->grad(out, dout, x, x_index);
|
||||
if (x->loop_options)
|
||||
dx->loop_options = x->loop_options;
|
||||
|
@ -182,7 +195,10 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
douts[i] = nullptr;
|
||||
}
|
||||
trace_grad_op = op;
|
||||
op->grads(douts, dins);
|
||||
{
|
||||
AmpGradGuard agg(op);
|
||||
op->grads(douts, dins);
|
||||
}
|
||||
// dump "for (Var* in : op->inputs())"
|
||||
for (int i=0; i<n_i; i++,j++) {
|
||||
auto id = id_buffer[j].second;
|
||||
|
|
|
@ -167,7 +167,7 @@ inline JK& operator<<(JK& jk, int64 c) {
|
|||
}
|
||||
|
||||
#ifdef __linux__
|
||||
inline JK& operator<<(JK& jk, long long int c) {
|
||||
inline JK& operator<<(JK& jk, int64_t c) {
|
||||
return jk << (int64)c;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -13,7 +13,8 @@ namespace jittor {
|
|||
struct Deleter {
|
||||
std::function<void()> del;
|
||||
inline Deleter(std::function<void()>&& func) : del(move(func)) {}
|
||||
inline ~Deleter() { del(); }
|
||||
inline Deleter() {}
|
||||
inline ~Deleter() { if (del) del(); }
|
||||
};
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -9,6 +9,17 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
|
||||
|
||||
DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
|
||||
|
||||
void setter_auto_mixed_precision_level(int value) {
|
||||
if (value <= 3) amp_reg = 0; else
|
||||
if (value == 4) amp_reg = amp_prefer16; else
|
||||
if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
|
||||
if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
|
||||
}
|
||||
|
||||
#define FOR_ALL_TYPES(m) \
|
||||
m(bool) \
|
||||
m(int8) \
|
||||
|
@ -89,15 +100,18 @@ static unordered_set<string> unary_ops = {
|
|||
"erfinv"
|
||||
};
|
||||
|
||||
static unordered_set<string> unary_float_ops = {
|
||||
static unordered_set<string> float_ops = {
|
||||
"log",
|
||||
"exp",
|
||||
"sqrt",
|
||||
"mean",
|
||||
"divide",
|
||||
};
|
||||
static unordered_set<string> unary_int_ops = {
|
||||
static unordered_set<string> int_ops = {
|
||||
"round_int",
|
||||
"floor_int",
|
||||
"ceil_int",
|
||||
"floor_divide",
|
||||
};
|
||||
|
||||
static unordered_set<string> binary_ops = {
|
||||
|
@ -127,6 +141,13 @@ static unordered_set<string> binary_ops = {
|
|||
"mean",
|
||||
};
|
||||
|
||||
|
||||
static unordered_set<string> white_ops = {
|
||||
// "log",
|
||||
"exp",
|
||||
"pow",
|
||||
};
|
||||
|
||||
#define DEFINE_NS(T) NanoString ns_##T;
|
||||
FOR_ALL_NS(DEFINE_NS);
|
||||
|
||||
|
@ -135,6 +156,9 @@ char __ns_to_string[ns_max_size*ns_max_len];
|
|||
int __ns_len[ns_max_size];
|
||||
|
||||
static void init_ns() {
|
||||
dsize_map["float16"] = 1;
|
||||
is_float_map["float16"] = 1;
|
||||
is_unsigned["float16"] = 0;
|
||||
NanoString::ns_t i=0;
|
||||
auto func = [&](const char* name, NanoString& ns) {
|
||||
ns.set(NanoString::_index, i++, NanoString::_index_nbits);
|
||||
|
@ -149,13 +173,16 @@ static void init_ns() {
|
|||
if (unary_ops.count(name)) {
|
||||
ns.set(NanoString::_type, NanoString::_unary, NanoString::_type_nbits);
|
||||
ns.set(NanoString::_bool, is_bool.count(name));
|
||||
ns.set(NanoString::_int, unary_int_ops.count(name));
|
||||
ns.set(NanoString::_float, unary_float_ops.count(name));
|
||||
ns.set(NanoString::_int, int_ops.count(name));
|
||||
ns.set(NanoString::_float, float_ops.count(name));
|
||||
} else
|
||||
if (binary_ops.count(name)) {
|
||||
ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits);
|
||||
ns.set(NanoString::_bool, is_bool.count(name));
|
||||
ns.set(NanoString::_int, int_ops.count(name));
|
||||
ns.set(NanoString::_float, float_ops.count(name));
|
||||
}
|
||||
ns.set(NanoString::_white_list, white_ops.count(name));
|
||||
__string_to_ns[name] = ns;
|
||||
auto name2 = ns.to_cstring();
|
||||
int len=0;
|
||||
|
|
|
@ -24,6 +24,7 @@ constexpr int ns_max_len = 16;
|
|||
m(uint16) \
|
||||
m(uint32) \
|
||||
m(uint64) \
|
||||
m(float16) \
|
||||
m(float32) \
|
||||
m(float64) \
|
||||
\
|
||||
|
@ -100,7 +101,7 @@ struct NanoString {
|
|||
typedef uint16 ns_t;
|
||||
enum Flags {
|
||||
// bit0~7: index
|
||||
_index=0, _index_nbits=8,
|
||||
_index=0, _index_nbits=7,
|
||||
_n=_index_nbits,
|
||||
|
||||
// bit0-1: type
|
||||
|
@ -116,6 +117,8 @@ struct NanoString {
|
|||
_float=_n+5,
|
||||
// bit6-7: dsize(1,2,4,8 byte)
|
||||
_dsize=_n+6, _dsize_nbits=2,
|
||||
// bit8: white list
|
||||
_white_list=_n+8,
|
||||
};
|
||||
ns_t data=0;
|
||||
|
||||
|
@ -130,11 +133,16 @@ struct NanoString {
|
|||
inline ns_t index() const { return get(_index, _index_nbits); }
|
||||
inline int len() const { return __ns_len[index()]; }
|
||||
inline ns_t type() const { return get(_type, _type_nbits); }
|
||||
inline ns_t is_bool() const { return get(_bool); }
|
||||
inline ns_t is_int() const { return get(_int); }
|
||||
inline ns_t is_unsigned() const { return get(_unsigned); }
|
||||
inline ns_t is_float() const { return get(_float); }
|
||||
// @pyjt(is_bool)
|
||||
inline bool is_bool() const { return get(_bool); }
|
||||
// @pyjt(is_int)
|
||||
inline bool is_int() const { return get(_int); }
|
||||
inline bool is_unsigned() const { return get(_unsigned); }
|
||||
// @pyjt(is_float)
|
||||
inline bool is_float() const { return get(_float); }
|
||||
inline ns_t is_white() const { return get(_white_list); }
|
||||
inline ns_t dsize() const { return 1<<get(_dsize, _dsize_nbits); }
|
||||
inline ns_t dsize_() const { return get(_dsize, _dsize_nbits); }
|
||||
inline ns_t is_dtype() const { return get(_type, _type_nbits)==_dtype; }
|
||||
inline ns_t is_binary() const { return get(_type, _type_nbits)==_binary; }
|
||||
inline ns_t is_unary() const { return get(_type, _type_nbits)==_unary; }
|
||||
|
@ -156,28 +164,6 @@ struct NanoString {
|
|||
{ return __ns_to_string+index()*ns_max_len; }
|
||||
};
|
||||
|
||||
// force_type = 1 for int, 2 for float
|
||||
inline
|
||||
NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0, NanoString op=ns_void) {
|
||||
bool is_float = v1.is_float() || v2.is_float();
|
||||
int dsize = std::max(v1.dsize(), v2.dsize());
|
||||
if (force_type == 1)
|
||||
is_float = false;
|
||||
else if (force_type == 2)
|
||||
is_float = true;
|
||||
if (is_float) {
|
||||
if (dsize==4) return ns_float32;
|
||||
return ns_float64;
|
||||
} else {
|
||||
if (dsize==8) return ns_int64;
|
||||
if (dsize==4) return ns_int32;
|
||||
if (dsize==2) return ns_int16;
|
||||
if (op.data == ns_add.data || op.data == ns_subtract.data)
|
||||
return ns_int8;
|
||||
return v1;
|
||||
}
|
||||
}
|
||||
|
||||
// @pyjt(NanoString.__eq__)
|
||||
inline bool eq(const NanoString& a, const NanoString& b) {
|
||||
return a.data == b.data;
|
||||
|
@ -199,4 +185,72 @@ inline std::ostream& operator<<(std::ostream& os, const NanoString& v) {
|
|||
return os << v.to_cstring();
|
||||
}
|
||||
|
||||
EXTERN_LIB int amp_reg;
|
||||
constexpr int amp_prefer32 = 1;
|
||||
constexpr int amp_prefer16 = 2;
|
||||
constexpr int amp_keep_reduce = 4;
|
||||
constexpr int amp_keep_white = 8;
|
||||
constexpr int amp_array_prefer = 16;
|
||||
|
||||
inline NanoString float_dtype(int dsize_) {
|
||||
if (amp_reg & amp_prefer32) return ns_float32;
|
||||
if (amp_reg & amp_prefer16) return ns_float16;
|
||||
return (dsize_ == 3) ? ns_float64 :
|
||||
(dsize_ == 2 ) ? ns_float32 : ns_float16;
|
||||
}
|
||||
|
||||
inline NanoString int_dtype(int dsize_) {
|
||||
return (dsize_ == 3) ? ns_int64 :
|
||||
(dsize_ == 2) ? ns_int32 :
|
||||
(dsize_ == 1) ? ns_int16 : ns_int8;
|
||||
}
|
||||
|
||||
inline NanoString dtype_infer(NanoString x, NanoString y) {
|
||||
int dsize_ = std::max(x.dsize_(), y.dsize_());
|
||||
bool is_float = x.is_float() || y.is_float();
|
||||
if (is_float)
|
||||
return float_dtype(dsize_);
|
||||
else {
|
||||
return int_dtype(dsize_);
|
||||
}
|
||||
}
|
||||
|
||||
inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y) {
|
||||
if (op.is_bool()) return ns_bool;
|
||||
int dsize_ = std::max(x.dsize_(), y.dsize_());
|
||||
bool is_float = !op.is_int() &&
|
||||
(x.is_float() || y.is_float() || op.is_float());
|
||||
if (is_float) {
|
||||
if (op.is_white() && !(amp_reg & amp_keep_white))
|
||||
return (dsize_ == 3) ? ns_float64 : ns_float32;
|
||||
return float_dtype(dsize_);
|
||||
} else {
|
||||
return int_dtype(dsize_);
|
||||
}
|
||||
}
|
||||
|
||||
inline NanoString unary_dtype_infer(NanoString op, NanoString x) {
|
||||
if (op.is_bool()) return ns_bool;
|
||||
int dsize_ = x.dsize_();
|
||||
if (op.is_float()) {
|
||||
if (op.is_white() && !(amp_reg & amp_keep_white))
|
||||
return (dsize_ == 3) ? ns_float64 : ns_float32;
|
||||
return float_dtype(dsize_);
|
||||
}
|
||||
if (op.is_int()) return int_dtype(dsize_);
|
||||
return x;
|
||||
}
|
||||
|
||||
inline NanoString reduce_dtype_infer(NanoString op, NanoString x) {
|
||||
bool is_float = x.is_float() || op.is_float();
|
||||
int dsize_ = x.dsize_();
|
||||
if (is_float) {
|
||||
if (amp_reg & amp_keep_reduce)
|
||||
return float_dtype(dsize_);
|
||||
return (dsize_ == 3) ? ns_float64 : ns_float32;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -51,8 +51,14 @@ struct NodeFlags {
|
|||
_grads=_n+6,
|
||||
// bit7: has graph optimize
|
||||
_has_gopt=_n+7,
|
||||
// bit7: has vary input
|
||||
// bit8: has vary input
|
||||
_has_vary_input=_n+8,
|
||||
// bit9: prefer 32 bit
|
||||
_prefer_32=_n+9,
|
||||
// bit10: force 16 bit
|
||||
_prefer_16=_n+10,
|
||||
// bit11: reduce keep type unchange
|
||||
_reduce_keep=_n+11,
|
||||
};
|
||||
|
||||
inline void set(Flags f, int a=1, int nbits=1) {
|
||||
|
@ -90,7 +96,7 @@ struct Node {
|
|||
operator Var*() { return (Var*)node; }
|
||||
operator var_output_t() { return {(Op*)node, index}; }
|
||||
};
|
||||
static int64_t tflag_count;
|
||||
static int64 tflag_count;
|
||||
NodeFlags flags;
|
||||
NanoString ns;
|
||||
inline bool is_var() const { return flags.get(NodeFlags::_var); }
|
||||
|
|
|
@ -25,11 +25,12 @@ DEFINE_FLAG(int, try_use_32bit_index, 0,
|
|||
string_view_map<jit_op_entry_t> jit_ops;
|
||||
string_view_map<string> jit_key_mapper;
|
||||
|
||||
int64_t Op::number_of_lived_ops = 0;
|
||||
int64 Op::number_of_lived_ops = 0;
|
||||
|
||||
Op::Op() {
|
||||
flags.set(NodeFlags::_var, 0);
|
||||
flags.set(NodeFlags::_cpu, 1);
|
||||
flags.flags |= ((amp_reg & 7) << NodeFlags::_prefer_32);
|
||||
number_of_lived_ops++;
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace jittor {
|
|||
enum OpType {other=0, element=1, broadcast=2, reduce=3};
|
||||
struct Op : Node {
|
||||
vector<VarPtr> outputs_holder;
|
||||
static int64_t number_of_lived_ops;
|
||||
static int64 number_of_lived_ops;
|
||||
|
||||
inline Caster<Var*, Node::input_t> inputs() { CHECK_EXIST; return &_inputs; }
|
||||
inline Caster<Var*, Node::output_t> outputs() { CHECK_EXIST; return &_outputs; }
|
||||
|
|
|
@ -112,7 +112,7 @@ int OpCompiler::total_member_count() {
|
|||
return member_count;
|
||||
}
|
||||
|
||||
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
|
||||
int64 OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
|
||||
if (expr.find("@") != string::npos) {
|
||||
string new_expr;
|
||||
for (size_t i=0; i<expr.size(); i++) {
|
||||
|
|
|
@ -418,21 +418,13 @@ unordered_set<string> binary_ops = {
|
|||
"bitwise_xor",
|
||||
};
|
||||
|
||||
NanoString binary_dtype_infer(NanoString op, Var* x, Var* y) {
|
||||
if (op == ns_mean) return dtype_infer(x->ns, y->ns, 2); // force float
|
||||
int force_type=0;
|
||||
if (op == ns_divide) force_type=2; // force float
|
||||
if (op == ns_floor_divide) force_type=1; // force int
|
||||
return op.is_bool() ? ns_bool : dtype_infer(x->ns, y->ns, force_type, op);
|
||||
}
|
||||
|
||||
BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
set_type(OpType::element);
|
||||
ns = op;
|
||||
ASSERT(ns.is_binary());
|
||||
z = create_output(nullptr, binary_dtype_infer(op, x, y));
|
||||
z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns));
|
||||
}
|
||||
|
||||
VarPtr dirty_clone_broadcast(Var* v) {
|
||||
|
|
|
@ -32,9 +32,11 @@ void op_registe(const OpInfo& op_info);
|
|||
bool has_op(const string& name);
|
||||
OpInfo get_op_info(const string& name);
|
||||
|
||||
struct OpCompiler;
|
||||
struct OpByType {
|
||||
unordered_set<string> types;
|
||||
virtual string expand_op(const vector<string>& args) = 0;
|
||||
virtual void post_pass(OpCompiler*) = 0;
|
||||
};
|
||||
|
||||
extern vector<OpByType*> op_types;
|
||||
|
|
|
@ -271,7 +271,7 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
|
|||
if (x->dtype() == ns_bool)
|
||||
y = create_output(nullptr, ns_int32);
|
||||
else
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
y = create_output(nullptr, reduce_dtype_infer(ns, x->ns));
|
||||
}
|
||||
|
||||
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
|
||||
|
@ -283,7 +283,7 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
|
|||
ASSERT(ns.is_binary());
|
||||
reduce_mask = dims_mask;
|
||||
this->keepdims_mask = keepdims_mask;
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
y = create_output(nullptr, reduce_dtype_infer(ns, x->ns));
|
||||
}
|
||||
|
||||
ReduceOp::ReduceOp(Var* x, NanoString op, int dim, bool keepdims)
|
||||
|
@ -359,8 +359,8 @@ void ReduceOp::jit_run() {
|
|||
@for(i, DIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};)
|
||||
index_t xstride@{DIM-1} = 1;
|
||||
@for(i, DIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
|
||||
Ty count = Ty(x->num) / Ty(y->num);
|
||||
Ty rcount = Ty(y->num) / Ty(x->num);
|
||||
Ty count = x->num*1.0 / y->num;
|
||||
Ty rcount = y->num*1.0 / x->num;
|
||||
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
||||
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
||||
yp[yid] = @expand_op(init_@OP, @Ty);
|
||||
|
|
|
@ -132,7 +132,7 @@ void ReindexOp::jit_run() {
|
|||
@for(d, 0, XDIM, index_t xid@d = @expand_macro(INDEX@d);)
|
||||
auto xid = @for(d, 0, XDIM, + xid@d * xstride@d);
|
||||
bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d)));
|
||||
yp[yid] = check_overflow ? (@OVERFLOW) : xp[xid];
|
||||
yp[yid] = check_overflow ? Tx(@OVERFLOW) : xp[xid];
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
|
|
@ -28,6 +28,12 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
|
|||
for (int i=0; i<(int)xdim; i++)
|
||||
axes.push_back(xdim-1-i);
|
||||
}
|
||||
if (axes.size() < xdim || (axes.size() == xdim && axes[xdim-1]==xdim-1)) {
|
||||
static VarPtr(*fuse_transpose)(Var*, NanoVector) = get_op_info("fuse_transpose").get_constructor<VarPtr, Var*, NanoVector>();
|
||||
auto var = fuse_transpose(x, axes);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*cutt_transpose)(Var*, NanoVector) = nullptr;
|
||||
|
|
|
@ -32,6 +32,7 @@ static unordered_set<string> unary_ops = {
|
|||
"uint16",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
// please keep float64 the last type
|
||||
|
@ -533,22 +534,15 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
|
|||
ns = op;
|
||||
ASSERT(ns.is_unary() | ns.is_dtype());
|
||||
NanoString dtype;
|
||||
if (ns == x->dtype()) {
|
||||
forward(x);
|
||||
return;
|
||||
}
|
||||
if (ns.is_dtype()) {
|
||||
if (ns == x->dtype()) {
|
||||
forward(x);
|
||||
return;
|
||||
}
|
||||
dtype = ns;
|
||||
ns = ns_cast;
|
||||
} else if (ns.is_bool())
|
||||
dtype = ns_bool;
|
||||
else if (ns.is_float())
|
||||
dtype = dtype_infer(x->ns, x->ns, 2);
|
||||
else if (ns.is_int())
|
||||
dtype = dtype_infer(x->ns, x->ns, 1);
|
||||
else {
|
||||
dtype = x->ns;
|
||||
}
|
||||
} else
|
||||
dtype = unary_dtype_infer(ns, x->ns);
|
||||
y = create_output(nullptr, dtype);
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
namespace jittor {
|
||||
|
||||
using namespace expr;
|
||||
extern int use_cuda;
|
||||
|
||||
struct OpInspector {
|
||||
// binary mask for
|
||||
|
@ -229,9 +230,14 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
if (!(bop->y->input() && bop->x->input() && fop->has(bop->x->input()) && fop->has(bop->y->input()))) continue;
|
||||
if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;
|
||||
|
||||
// only support float32 currently
|
||||
if (bop->z->dtype() != ns_float32)
|
||||
continue;
|
||||
// only support float32,float16 currently
|
||||
if (use_cuda) {
|
||||
if (bop->z->dtype() != ns_float32 && bop->z->dtype() != ns_float16)
|
||||
continue;
|
||||
} else {
|
||||
if (bop->z->dtype() != ns_float32)
|
||||
continue;
|
||||
}
|
||||
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
||||
int ok = 0;
|
||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key(get_jk());
|
||||
|
|
|
@ -262,7 +262,7 @@ void Profiler::record_and_run(
|
|||
Deleter _d;
|
||||
if (is_fused) {
|
||||
auto fop = ((FusedOp*)op);
|
||||
if (fop->context && fop->context->entry) {
|
||||
if (fop->context && fop->context->vrm.relay_groups.size()) {
|
||||
// relay op
|
||||
loop = rerun;
|
||||
profiler.relay_extra_cost = 0;
|
||||
|
|
|
@ -21,7 +21,9 @@ NanoString npy2ns[] = {
|
|||
ns_int64, ns_uint64,
|
||||
ns_float32, ns_float64, ns_float64,
|
||||
ns_void, ns_void, ns_void,
|
||||
ns_void
|
||||
ns_void, // 17
|
||||
ns_void, ns_void, ns_void, ns_void, ns_void, // 22
|
||||
ns_float16, // 23
|
||||
};
|
||||
|
||||
NPY_TYPES ns2npy[] = {
|
||||
|
@ -34,7 +36,7 @@ NPY_TYPES ns2npy[] = {
|
|||
NPY_BYTE, NPY_SHORT, NPY_INT, NPY_LONGLONG,
|
||||
NPY_UBYTE, NPY_USHORT, NPY_UINT, NPY_ULONGLONG,
|
||||
#endif
|
||||
NPY_FLOAT, NPY_DOUBLE
|
||||
NPY_HALF, NPY_FLOAT, NPY_DOUBLE
|
||||
};
|
||||
|
||||
void** PyArray_API;
|
||||
|
|
|
@ -48,6 +48,8 @@ enum NPY_TYPES {
|
|||
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
|
||||
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
|
||||
NPY_OBJECT=17,
|
||||
NPY_HALF=23,
|
||||
NPY_END=24,
|
||||
};
|
||||
|
||||
EXTERN_LIB NanoString npy2ns[];
|
||||
|
@ -60,11 +62,11 @@ EXTERN_LIB NPY_TYPES ns2npy[];
|
|||
inline bool is_c_style(PyArray_Proxy* obj) { return obj->flags & 1; }
|
||||
inline NanoString get_type_str(PyArray_Proxy* obj) {
|
||||
NanoString type = ns_void;
|
||||
if (obj->descr->type_num < NPY_OBJECT)
|
||||
if (obj->descr->type_num < NPY_END)
|
||||
type = npy2ns[obj->descr->type_num];
|
||||
CHECK(type != ns_void) << "Numpy type not support, type_num:"
|
||||
<< obj->descr->type_num
|
||||
<< "type_char:" << obj->descr->type;
|
||||
<< "type_char:" << obj->descr->type << NPY_END << npy2ns[obj->descr->type_num];
|
||||
return type;
|
||||
}
|
||||
|
||||
|
|
|
@ -141,7 +141,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
|||
} else {
|
||||
// this is non-continue numpy array
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
STACK_ALLOC(int64, dims, args.shape.size());
|
||||
STACK_ALLOC(int64_t, dims, args.shape.size());
|
||||
#elif defined(__APPLE__)
|
||||
long dims[args.shape.size()];
|
||||
#endif
|
||||
|
|
|
@ -274,7 +274,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
|||
|
||||
DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
STACK_ALLOC(int64, dims, a.shape.size());
|
||||
STACK_ALLOC(int64_t, dims, a.shape.size());
|
||||
#elif defined(__APPLE__)
|
||||
long dims[a.shape.size()];
|
||||
#endif
|
||||
|
@ -390,7 +390,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
|
|||
struct DataView;
|
||||
DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
STACK_ALLOC(int64, dims, a.shape.size());
|
||||
STACK_ALLOC(int64_t, dims, a.shape.size());
|
||||
#elif defined(__APPLE__)
|
||||
long dims[a.shape.size()];
|
||||
#endif
|
||||
|
|
|
@ -110,7 +110,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o
|
|||
rb->push(size, offset);
|
||||
args.ptr = rb->get_ptr(size, offset);
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
STACK_ALLOC(int64, dims, args.shape.size());
|
||||
STACK_ALLOC(int64_t, dims, args.shape.size());
|
||||
#elif defined(__APPLE__)
|
||||
long dims[args.shape.size()];
|
||||
#endif
|
||||
|
|
|
@ -12,6 +12,44 @@ namespace jittor {
|
|||
|
||||
extern int use_cuda;
|
||||
|
||||
unordered_map<string,string> common_op_type_cuda_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
{"bitwise_not", "(~($2))"},
|
||||
{"negative", "(-($2))"},
|
||||
{"abs", "::abs($2)"},
|
||||
{"log", "::logf(($1)($2))"},
|
||||
{"exp", "::expf(($1)($2))"},
|
||||
{"sqrt", "::sqrtf(($1)($2))"},
|
||||
{"round", "(($1) ::roundf(($2)))"},
|
||||
{"floor", "(($1) ::floorf(($2)))"},
|
||||
{"ceil", "(($1) ::ceilf(($2)))"},
|
||||
{"round_int", "(($1) ::roundf(($2)))"},
|
||||
{"floor_int", "(($1) ::floorf(($2)))"},
|
||||
{"ceil_int", "(($1) ::ceilf(($2)))"},
|
||||
{"sin", "(($1) ::sinf(($2)))"},
|
||||
{"asin", "(($1) ::asinf(($2)))"},
|
||||
{"sinh", "(($1) ::sinhf(($2)))"},
|
||||
{"asinh", "(($1) ::asinhf(($2)))"},
|
||||
{"cos", "(($1) ::cosf(($2)))"},
|
||||
{"acos", "(($1) ::acosf(($2)))"},
|
||||
{"cosh", "(($1) ::coshf(($2)))"},
|
||||
{"acosh", "(($1) ::acoshf(($2)))"},
|
||||
{"tan", "(($1) ::tanf(($2)))"},
|
||||
{"atan", "(($1) ::atanf(($2)))"},
|
||||
{"tanh", "(($1) ::tanhf(($2)))"},
|
||||
{"atanh", "(($1) ::atanhf(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
|
||||
{"erf", "(($1) ::erff(($2)))"},
|
||||
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "::pow(($2),($4))"},
|
||||
{"maximum", "::max($1($2), $1($4))"},
|
||||
{"minimum", "::min($1($2), $1($4))"},
|
||||
{"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
|
||||
{"init_maximum", "::numeric_min<$1>()"},
|
||||
{"init_minimum", "::numeric_max<$1>()"},
|
||||
};
|
||||
|
||||
struct CommonOpType : OpByType {
|
||||
CommonOpType() {
|
||||
types = {
|
||||
|
@ -34,43 +72,7 @@ struct CommonOpType : OpByType {
|
|||
if (!types.count(args[i]))
|
||||
return "";
|
||||
}
|
||||
static unordered_map<string,string> cuda_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
{"bitwise_not", "(~($2))"},
|
||||
{"negative", "(-($2))"},
|
||||
{"abs", "::abs($2)"},
|
||||
{"log", "::logf(($1)($2))"},
|
||||
{"exp", "::expf(($1)($2))"},
|
||||
{"sqrt", "::sqrtf(($1)($2))"},
|
||||
{"round", "(($1) ::roundf(($2)))"},
|
||||
{"floor", "(($1) ::floorf(($2)))"},
|
||||
{"ceil", "(($1) ::ceilf(($2)))"},
|
||||
{"round_int", "(($1) ::roundf(($2)))"},
|
||||
{"floor_int", "(($1) ::floorf(($2)))"},
|
||||
{"ceil_int", "(($1) ::ceilf(($2)))"},
|
||||
{"sin", "(($1) ::sinf(($2)))"},
|
||||
{"asin", "(($1) ::asinf(($2)))"},
|
||||
{"sinh", "(($1) ::sinhf(($2)))"},
|
||||
{"asinh", "(($1) ::asinhf(($2)))"},
|
||||
{"cos", "(($1) ::cosf(($2)))"},
|
||||
{"acos", "(($1) ::acosf(($2)))"},
|
||||
{"cosh", "(($1) ::coshf(($2)))"},
|
||||
{"acosh", "(($1) ::acoshf(($2)))"},
|
||||
{"tan", "(($1) ::tanf(($2)))"},
|
||||
{"atan", "(($1) ::atanf(($2)))"},
|
||||
{"tanh", "(($1) ::tanhf(($2)))"},
|
||||
{"atanh", "(($1) ::atanhf(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
|
||||
{"erf", "(($1) ::erff(($2)))"},
|
||||
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "::pow(($2),($4))"},
|
||||
{"maximum", "::max($1($2), $1($4))"},
|
||||
{"minimum", "::min($1($2), $1($4))"},
|
||||
{"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
|
||||
{"init_maximum", "::numeric_min<$1>()"},
|
||||
{"init_minimum", "::numeric_max<$1>()"},
|
||||
};
|
||||
auto& cuda_map = common_op_type_cuda_map;
|
||||
|
||||
static unordered_map<string,string> cpu_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
|
@ -151,6 +153,10 @@ struct CommonOpType : OpByType {
|
|||
ret = cpu_map[args.at(0)];
|
||||
return format(ret, args);
|
||||
}
|
||||
|
||||
void post_pass(OpCompiler*) {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#include <driver_types.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
typedef __half float16;
|
||||
|
||||
#if CUDA_ARCH >= 800
|
||||
inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); }
|
||||
inline __device__ float16 min(float16 a, float16 b) { return __hmin(a, b); }
|
||||
#else
|
||||
inline __device__ float16 max(float16 a, float16 b) { return a<b?b:a; }
|
||||
inline __device__ float16 min(float16 a, float16 b) { return a<b?a:b; }
|
||||
#endif
|
||||
|
||||
inline __device__ float16 pow(float16 a, float16 b) { return ::pow(float32(a), float32(b)); }
|
||||
|
||||
|
||||
template<int nbyte, class T>
|
||||
__device__ inline void vload(T* __restrict__ a, T* __restrict__ b) {
|
||||
if constexpr (nbyte<=0) return;
|
||||
if constexpr (nbyte>=16) {
|
||||
auto __restrict__ aa = (float4* __restrict__)a;
|
||||
auto __restrict__ bb = (float4* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-16>(aa+1, bb+1);
|
||||
}
|
||||
if constexpr (nbyte>=8) {
|
||||
auto __restrict__ aa = (float2* __restrict__)a;
|
||||
auto __restrict__ bb = (float2* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-8>(aa+1, bb+1);
|
||||
}
|
||||
if constexpr (nbyte>=4) {
|
||||
auto __restrict__ aa = (float* __restrict__)a;
|
||||
auto __restrict__ bb = (float* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-4>(aa+1, bb+1);
|
||||
}
|
||||
if constexpr (nbyte>=2) {
|
||||
auto __restrict__ aa = (__half* __restrict__)a;
|
||||
auto __restrict__ bb = (__half* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-2>(aa+1, bb+1);
|
||||
}
|
||||
if constexpr (nbyte>=1) {
|
||||
auto __restrict__ aa = (int8_t* __restrict__)a;
|
||||
auto __restrict__ bb = (int8_t* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-1>(aa+1, bb+1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
using jittor::max;
|
||||
using jittor::min;
|
||||
using jittor::pow;
|
||||
|
||||
#else
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct float16 {
|
||||
uint16 x;
|
||||
|
||||
inline float16(float32 f) {
|
||||
unsigned x = *((int*)(void*)(&f));
|
||||
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
|
||||
unsigned sign, exponent, mantissa;
|
||||
|
||||
|
||||
// Get rid of +NaN/-NaN case first.
|
||||
if (u > 0x7f800000) {
|
||||
this->x = 0x7fffU;
|
||||
return;
|
||||
}
|
||||
|
||||
sign = ((x >> 16) & 0x8000);
|
||||
|
||||
// Get rid of +Inf/-Inf, +0/-0.
|
||||
if (u > 0x477fefff) {
|
||||
this->x = sign | 0x7c00U;
|
||||
return;
|
||||
}
|
||||
if (u < 0x33000001) {
|
||||
this->x = sign | 0x0000U;
|
||||
return;
|
||||
}
|
||||
|
||||
exponent = ((u >> 23) & 0xff);
|
||||
mantissa = (u & 0x7fffff);
|
||||
|
||||
if (exponent > 0x70) {
|
||||
shift = 13;
|
||||
exponent -= 0x70;
|
||||
} else {
|
||||
shift = 0x7e - exponent;
|
||||
exponent = 0;
|
||||
mantissa |= 0x800000;
|
||||
}
|
||||
lsb = (1 << shift);
|
||||
lsb_s1 = (lsb >> 1);
|
||||
lsb_m1 = (lsb - 1);
|
||||
|
||||
// Round to nearest even.
|
||||
remainder = (mantissa & lsb_m1);
|
||||
mantissa >>= shift;
|
||||
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
|
||||
++mantissa;
|
||||
if (!(mantissa & 0x3ff)) {
|
||||
++exponent;
|
||||
mantissa = 0;
|
||||
}
|
||||
}
|
||||
|
||||
this->x = (sign | (exponent << 10) | mantissa);
|
||||
}
|
||||
|
||||
inline operator float() {
|
||||
|
||||
unsigned sign = ((x >> 15) & 1);
|
||||
unsigned exponent = ((x >> 10) & 0x1f);
|
||||
unsigned mantissa = ((x & 0x3ff) << 13);
|
||||
|
||||
if (exponent == 0x1f) { /* NaN or Inf */
|
||||
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
|
||||
exponent = 0xff;
|
||||
} else if (!exponent) { /* Denorm or Zero */
|
||||
if (mantissa) {
|
||||
unsigned int msb;
|
||||
exponent = 0x71;
|
||||
do {
|
||||
msb = (mantissa & 0x400000);
|
||||
mantissa <<= 1; /* normalize */
|
||||
--exponent;
|
||||
} while (!msb);
|
||||
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
|
||||
}
|
||||
} else {
|
||||
exponent += 0x70;
|
||||
}
|
||||
|
||||
int temp = ((sign << 31) | (exponent << 23) | mantissa);
|
||||
|
||||
return reinterpret_cast<float&>(temp);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,188 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "common.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "op_compiler.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern int use_cuda;
|
||||
|
||||
extern unordered_map<string,string> common_op_type_cuda_map;
|
||||
|
||||
static bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; }
|
||||
|
||||
struct FP16OpType : OpByType {
|
||||
FP16OpType() {
|
||||
types = {
|
||||
"float16",
|
||||
};
|
||||
}
|
||||
|
||||
string expand_op(const vector<string>& args) {
|
||||
bool found_fp16 = 0;
|
||||
for (int i=1; i<args.size(); i+=2) {
|
||||
if (types.count(args[i]))
|
||||
found_fp16 = 1;
|
||||
}
|
||||
if (!found_fp16) return "";
|
||||
static unordered_map<string,string> cuda_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
{"bitwise_not", "(~($2))"},
|
||||
{"negative", "(-($2))"},
|
||||
{"abs", "::abs($2)"},
|
||||
{"log", "::hlog(($1)($2))"},
|
||||
{"exp", "::hexp(($1)($2))"},
|
||||
{"sqrt", "::hsqrt(($1)($2))"},
|
||||
{"round", "(($1) ::roundf(($2)))"},
|
||||
{"floor", "(($1) ::floorf(($2)))"},
|
||||
{"ceil", "(($1) ::ceilf(($2)))"},
|
||||
{"round_int", "(($1) ::roundf(($2)))"},
|
||||
{"floor_int", "(($1) ::floorf(($2)))"},
|
||||
{"ceil_int", "(($1) ::ceilf(($2)))"},
|
||||
{"sin", "(($1) ::sinf(($2)))"},
|
||||
{"asin", "(($1) ::asinf(($2)))"},
|
||||
{"sinh", "(($1) ::sinhf(($2)))"},
|
||||
{"asinh", "(($1) ::asinhf(($2)))"},
|
||||
{"cos", "(($1) ::cosf(($2)))"},
|
||||
{"acos", "(($1) ::acosf(($2)))"},
|
||||
{"cosh", "(($1) ::coshf(($2)))"},
|
||||
{"acosh", "(($1) ::acoshf(($2)))"},
|
||||
{"tan", "(($1) ::tanf(($2)))"},
|
||||
{"atan", "(($1) ::atanf(($2)))"},
|
||||
{"tanh", "(($1) ::tanhf(($2)))"},
|
||||
{"atanh", "(($1) ::atanhf(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float16)==0,30,300))))))))"},
|
||||
{"erf", "(($1) ::erff(($2)))"},
|
||||
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "::pow(($2),($4))"},
|
||||
{"maximum", "::max($1($2), $1($4))"},
|
||||
{"minimum", "::min($1($2), $1($4))"},
|
||||
{"mod", "$1(($2)-::hfloor(($2)/($4))*($4))"},
|
||||
{"init_maximum", "-32768.0f"},
|
||||
{"init_minimum", "32768.0f"},
|
||||
};
|
||||
|
||||
static unordered_map<string,string> cpu_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
{"bitwise_not", "(~($2))"},
|
||||
{"negative", "(-($2))"},
|
||||
{"abs", "std::abs($2)"},
|
||||
{"log", "std::log(($1)($2))"},
|
||||
{"exp", "std::exp(($1)($2))"},
|
||||
{"sqrt", "std::sqrt(($1)($2))"},
|
||||
{"round", "(($1)std::round(($2)))"},
|
||||
{"floor", "(($1)std::floor(($2)))"},
|
||||
{"ceil", "(($1)std::ceil(($2)))"},
|
||||
{"round_int", "(($1)std::round(($2)))"},
|
||||
{"floor_int", "(($1)std::floor(($2)))"},
|
||||
{"ceil_int", "(($1)std::ceil(($2)))"},
|
||||
{"sin", "(($1) std::sin(($2)))"},
|
||||
{"asin", "(($1) std::asin(($2)))"},
|
||||
{"sinh", "(($1) std::sinh(($2)))"},
|
||||
{"asinh", "(($1) std::asinh(($2)))"},
|
||||
{"cos", "(($1) std::cos(($2)))"},
|
||||
{"acos", "(($1) std::acos(($2)))"},
|
||||
{"cosh", "(($1) std::cosh(($2)))"},
|
||||
{"acosh", "(($1) std::acosh(($2)))"},
|
||||
{"tan", "(($1) std::tan(($2)))"},
|
||||
{"atan", "(($1) std::atan(($2)))"},
|
||||
{"tanh", "(($1) std::tanh(($2)))"},
|
||||
{"atanh", "(($1) std::atanh(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
|
||||
{"erf", "(($1) std::erf(($2)))"},
|
||||
{"erfinv", "(jittor::_erfinv($2))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "std::pow(($2),($4))"},
|
||||
{"maximum", "std::max($1($2), $1($4))"},
|
||||
{"minimum", "std::min($1($2), $1($4))"},
|
||||
{"mod", "$1(($2)-std::floor(($2)/($4))*($4))"},
|
||||
{"init_maximum", "-32768.0f"},
|
||||
{"init_minimum", "32768.0f"},
|
||||
};
|
||||
|
||||
static unordered_map<string,string> both_map {
|
||||
{"add", "(($2)+($4))"},
|
||||
{"subtract", "(($2)-($4))"},
|
||||
{"multiply", "(($2)*($4))"},
|
||||
{"divide", "($1(($1($2))/($1($4))))"},
|
||||
{"floor_divide", "($1(($1($2))/($1($4))))"},
|
||||
{"less", "(($2)<($4))"},
|
||||
{"less_equal", "(($2)<=($4))"},
|
||||
{"greater", "(($2)>($4))"},
|
||||
{"greater_equal", "(($2)>=($4))"},
|
||||
{"equal", "(($2)==($4))"},
|
||||
{"not_equal", "(($2)!=($4))"},
|
||||
{"left_shift", "(($2)<<($4))"},
|
||||
{"right_shift", "(($2)>>($4))"},
|
||||
{"logical_and", "(($2)&&($4))"},
|
||||
{"logical_or", "(($2)||($4))"},
|
||||
{"logical_xor", "((bool($2))!=(bool($4)))"},
|
||||
{"bitwise_and", "(($2)&($4))"},
|
||||
{"bitwise_or", "(($2)|($4))"},
|
||||
{"bitwise_xor", "(($2)^($4))"},
|
||||
{"mean", "(($2)+($4)*($1(rcount)))"},
|
||||
{"init_add", "$1(0)"},
|
||||
{"init_multiply", "$1(1)"},
|
||||
{"init_logical_and", "true"},
|
||||
{"init_logical_or", "false"},
|
||||
{"init_logical_xor", "false"},
|
||||
{"init_bitwise_and", "$1(-1)"},
|
||||
{"init_bitwise_or", "$1(0)"},
|
||||
{"init_bitwise_xor", "$1(0)"},
|
||||
{"init_mean", "$1(0)"},
|
||||
};
|
||||
|
||||
string ret;
|
||||
if (both_map.count(args.at(0)))
|
||||
ret = both_map[args.at(0)];
|
||||
else if (use_cuda)
|
||||
ret = cuda_map[args.at(0)];
|
||||
else
|
||||
ret = cpu_map[args.at(0)];
|
||||
if (use_cuda) {
|
||||
if (args[1] == "float32" && !both_map.count(args.at(0))) {
|
||||
ret = common_op_type_cuda_map[args.at(0)];
|
||||
}
|
||||
if (args[1] == "float16" || args[1] == "float32") {
|
||||
for (int i=3; i<args.size(); i+=2) {
|
||||
if (args[i] != args[1]) {
|
||||
ret = replace(ret, "$"+S(i-1),
|
||||
args[1]+"($"+S(i-1)+")");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i=3; i<args.size(); i+=2) {
|
||||
if (args[i] != "float16") {
|
||||
ret = replace(ret, "$"+S(i-1),
|
||||
"float16($"+S(i-1)+")");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return format(ret, args);
|
||||
}
|
||||
|
||||
void post_pass(OpCompiler* oc) {
|
||||
string& src = oc->src;
|
||||
if (src.find("float16") == string::npos)
|
||||
return;
|
||||
int i = src.rfind("#include");
|
||||
if (i<0) i=0;
|
||||
i = src.find('\n', i) + 1;
|
||||
src = src.substr(0, i) + "#include \"type/fp16_compute.h\"\n" +
|
||||
src.substr(i);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static int _ = registe_op_type(new FP16OpType());
|
||||
|
||||
}
|
|
@ -18,7 +18,7 @@ namespace jittor {
|
|||
typedef int8_t int8;
|
||||
typedef int16_t int16;
|
||||
typedef int int32;
|
||||
typedef int64_t int64;
|
||||
typedef long long int64;
|
||||
typedef uint8_t uint8;
|
||||
typedef uint16_t uint16;
|
||||
typedef uint32_t uint32;
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
int64_t Var::number_of_lived_vars = 0;
|
||||
int64 Var::number_of_lived_vars = 0;
|
||||
|
||||
DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {},
|
||||
"Override the default loop transfrom options");
|
||||
|
@ -42,7 +42,7 @@ string Var::to_string() {
|
|||
return s;
|
||||
}
|
||||
|
||||
int64_t Var::numel() {
|
||||
int64 Var::numel() {
|
||||
if (!shape.size()) return size=num=-1;
|
||||
bool negtive = 0;
|
||||
num=1;
|
||||
|
|
|
@ -18,13 +18,13 @@ struct Var : Node {
|
|||
NanoVector shape;
|
||||
cstr name;
|
||||
fast_shared_ptr<loop_options_t> loop_options;
|
||||
static int64_t number_of_lived_vars;
|
||||
static int64 number_of_lived_vars;
|
||||
|
||||
// this var will be generated after alloc.
|
||||
void* mem_ptr = nullptr;
|
||||
Allocator* allocator = nullptr;
|
||||
size_t allocation;
|
||||
int64_t size, num;
|
||||
int64 size, num;
|
||||
inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
|
||||
inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
|
||||
inline NanoString dtype() const { CHECK_EXIST; return ns; }
|
||||
|
@ -40,7 +40,7 @@ struct Var : Node {
|
|||
Var(NanoVector shape, NanoString dtype);
|
||||
|
||||
string to_string();
|
||||
int64_t numel();
|
||||
int64 numel();
|
||||
void set_shape(NanoVector shape);
|
||||
bool alloc(Allocator* allocator);
|
||||
inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; }
|
||||
|
|
|
@ -15,6 +15,7 @@ if __name__ == "__main__":
|
|||
|
||||
skip_l = int(os.environ.get("test_skip_l", "0"))
|
||||
skip_r = int(os.environ.get("test_skip_r", "1000000"))
|
||||
skip = os.environ.get("test_skip", "").split(",")
|
||||
test_only = None
|
||||
if "test_only" in os.environ:
|
||||
test_only = set(os.environ.get("test_only").split(","))
|
||||
|
@ -34,6 +35,9 @@ if __name__ == "__main__":
|
|||
continue
|
||||
if test_only and test_name not in test_only:
|
||||
continue
|
||||
for s in skip:
|
||||
if s in test_name:
|
||||
continue
|
||||
|
||||
print("Add Test", _, test_name)
|
||||
suite.addTest(tests)
|
||||
|
|
|
@ -0,0 +1,374 @@
|
|||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import jittor as jt
|
||||
import jittor.nn as nn
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
split_size = 1000000
|
||||
|
||||
conv_opt = int(os.environ.get("conv_opt", "0"))
|
||||
|
||||
if conv_opt:
|
||||
Conv1d_sp = nn.Conv1d_sp
|
||||
else:
|
||||
Conv1d_sp = nn.Conv1d
|
||||
|
||||
|
||||
def MLP(channels: list, do_bn=True):
|
||||
""" Multi-layer perceptron """
|
||||
n = len(channels)
|
||||
layers = []
|
||||
for i in range(1, n):
|
||||
layers.append(Conv1d_sp(channels[i - 1], channels[i], kernel_size=1, bias=True))
|
||||
if i < (n - 1):
|
||||
if do_bn:
|
||||
layers.append(nn.BatchNorm(channels[i]))
|
||||
# layers.append(nn.InstanceNorm1d(channels[i]))
|
||||
# layers.append(nn.LayerNorm(channels[i]))
|
||||
layers.append(nn.ReLU())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def normalize_keypoints(kpts, image_shape):
|
||||
size = image_shape.flip(1) # shape=(b,2) ;h w -> w, h
|
||||
center = size / 2
|
||||
scaling = size.float32().max(1, keepdims=True) * 0.7
|
||||
return (kpts - center[:, None, :]) / scaling[:, None, :]
|
||||
|
||||
|
||||
class KeypointEncoder(nn.Module):
|
||||
""" Joint encoding of visual appearance and location using MLPs"""
|
||||
def __init__(self, feature_dim, layers, keypoint_position_dim=2):
|
||||
super().__init__()
|
||||
# self.keypoint_position_dim = keypoint_position_dim
|
||||
self.encoder = MLP([keypoint_position_dim + 1] + layers + [feature_dim])
|
||||
nn.init.constant_(self.encoder[-1].bias, 0.0)
|
||||
|
||||
def execute(self, kpts, scores):
|
||||
inputs = jt.concat([kpts.t(), scores.unsqueeze(1)], dim=1)
|
||||
return self.encoder(inputs)
|
||||
|
||||
cnt = 0
|
||||
|
||||
def attention(query, key, value):
|
||||
global cnt
|
||||
cnt += 1
|
||||
b, d, h, n = query.shape
|
||||
# print("attention", b,d,h,n, cnt)
|
||||
dim_factor = (1.0 / d)**0.5
|
||||
query = query.transpose(0, 2, 3, 1).reshape(b * h, -1, d) * dim_factor
|
||||
key = key.transpose(0, 2, 1, 3).reshape(b * h, d, -1)
|
||||
value = value.transpose(0, 2, 3, 1).reshape(b * h, -1, d)
|
||||
# print("attention", query.shape, key.shape, value.shape)
|
||||
|
||||
data = []
|
||||
for i in range(0, query.shape[0], split_size):
|
||||
end = min(i + split_size, query.shape[0])
|
||||
tmp1 = nn.bmm(query[i:end], key[i:end])
|
||||
tmp2 = nn.softmax(tmp1, dim=-1)
|
||||
tmp3 = nn.bmm(tmp2, value[i:end])
|
||||
tmp3.sync()
|
||||
data.append(tmp3)
|
||||
tmp3 = jt.concat(data)
|
||||
|
||||
# for i in range(0, query.shape[0], split_size):
|
||||
# end = min(i + split_size, query.shape[0])
|
||||
# tmp1 = nn.bmm(query[:,i:end], key[:,i:end])
|
||||
# tmp2 = nn.softmax(tmp1, dim=-1)
|
||||
# tmp3 = nn.bmm(tmp2, value[:,i:end])
|
||||
# tmp3.sync()
|
||||
# data.append(tmp3)
|
||||
# tmp3 = jt.concat(data, dim=1)
|
||||
|
||||
# tmp1 = nn.bmm(query, key)
|
||||
# print(tmp1.shape)
|
||||
# tmp2 = nn.softmax(tmp1, dim=-1)
|
||||
# print(tmp2.shape)
|
||||
# tmp3 = nn.bmm(tmp2, value)
|
||||
# print(tmp3.shape)
|
||||
return tmp3.reshape(b, h, -1, d).transpose(0, 3, 1, 2)
|
||||
return nn.bmm(nn.softmax(nn.bmm(query, key), dim=-1), value).reshape(b, h, -1, d).transpose(0, 3, 1, 2)
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
""" Multi-head attention to increase model expressivitiy """
|
||||
def __init__(self, num_heads: int, d_model: int):
|
||||
super().__init__()
|
||||
assert d_model % num_heads == 0
|
||||
self.dim = d_model // num_heads
|
||||
self.num_heads = num_heads
|
||||
self.merge = Conv1d_sp(d_model, d_model, kernel_size=1)
|
||||
self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
|
||||
|
||||
def execute(self, query, key, value):
|
||||
batch_dim = query.size(0)
|
||||
query, key, value = [l(x).reshape(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))]
|
||||
x = attention(query, key, value)
|
||||
# x = attention_chunk(query, key, value)
|
||||
return self.merge(x.reshape(batch_dim, self.dim * self.num_heads, -1))
|
||||
|
||||
|
||||
class AttentionalPropagation(nn.Module):
|
||||
def __init__(self, feature_dim: int, num_heads: int):
|
||||
super().__init__()
|
||||
self.attn = MultiHeadedAttention(num_heads, feature_dim)
|
||||
self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim])
|
||||
nn.init.constant_(self.mlp[-1].bias, 0.0)
|
||||
|
||||
def execute(self, x, source):
|
||||
message = self.attn(x, source, source)
|
||||
return self.mlp(jt.concat([x, message], dim=1))
|
||||
|
||||
|
||||
class AttentionalGNN(nn.Module):
|
||||
def __init__(self, feature_dim: int, layer_names: list):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))])
|
||||
self.is_cross = [x == 'cross' for x in layer_names]
|
||||
|
||||
def execute(self, desc0, desc1):
|
||||
for layer, is_cross in zip(self.layers, self.is_cross):
|
||||
layer.attn.prob = []
|
||||
if is_cross:
|
||||
src0, src1 = desc1, desc0
|
||||
else: # if name == 'self':
|
||||
src0, src1 = desc0, desc1
|
||||
# delta0, delta1 = layer(desc0, src0), layer(desc1, src1)
|
||||
|
||||
delta0 = layer(desc0, src0)
|
||||
# print(delta0.numel()*4)
|
||||
# breakpoint()
|
||||
jt.sync_all()
|
||||
delta1 = layer(desc1, src1)
|
||||
jt.sync_all()
|
||||
desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
|
||||
jt.sync_all()
|
||||
return desc0, desc1
|
||||
|
||||
|
||||
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
|
||||
""" Perform Sinkhorn Normalization in Log-space for stability"""
|
||||
u, v = jt.zeros_like(log_mu), jt.zeros_like(log_nu)
|
||||
for _ in range(iters):
|
||||
u = log_mu - (Z + v.unsqueeze(1)).exp().sum(dim=2).log()
|
||||
v = log_nu - (Z + u.unsqueeze(2)).exp().sum(dim=1).log()
|
||||
return Z + u.unsqueeze(2) + v.unsqueeze(1)
|
||||
|
||||
|
||||
def log_optimal_transport(scores, alpha, iters: int):
|
||||
""" Perform Differentiable Optimal Transport in Log-space for stability"""
|
||||
b, m, n = scores.shape
|
||||
ms, ns = jt.float(m, requires_grad=False), jt.float(n, requires_grad=False)
|
||||
|
||||
bins0 = alpha.broadcast([b, m, 1])
|
||||
bins1 = alpha.broadcast([b, 1, n])
|
||||
alpha = alpha.broadcast([b, 1, 1])
|
||||
|
||||
couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1)
|
||||
|
||||
norm = -(ms + ns).log()
|
||||
log_mu = jt.concat([norm.broadcast([m]), ns.log() + norm])
|
||||
log_nu = jt.concat([norm.broadcast([n]), ms.log() + norm])
|
||||
log_mu, log_nu = log_mu[None].broadcast([b, m + 1]), log_nu[None].broadcast([b, n + 1])
|
||||
|
||||
Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
|
||||
Z = Z - norm # multiply probabilities by M+N
|
||||
return Z
|
||||
|
||||
|
||||
def arange_like(x, dim: int):
|
||||
return jt.ones(x.shape[dim], dtype=x.dtype)[None].cumsum()[0] - 1 # traceable in 1.1
|
||||
|
||||
|
||||
default_config = {
|
||||
'descriptor_dim': 256, # SuperPoint
|
||||
'weights': 'indoor',
|
||||
'keypoint_encoder': [32, 64, 128, 256], # SuperPoint
|
||||
'GNN_layers': ['self', 'cross'] * 9,
|
||||
'sinkhorn_iterations': 100,
|
||||
'match_threshold': 0.2,
|
||||
}
|
||||
|
||||
|
||||
def get_weighted_loss_batch(scores, all_matches):
|
||||
matches0, matches1 = all_matches.chunk(chunks=2, dim=2)
|
||||
batchIdx = jt.arange(all_matches.shape[0]).unsqueeze(1).repeat(1, all_matches.shape[1])
|
||||
batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1)
|
||||
valid_index0, valid_index1 = matches0 >= 0, matches1 >= 0
|
||||
valid_match = jt.logical_and(valid_index0, valid_index1)
|
||||
valid_unmatch = jt.logical_xor(valid_index0, valid_index1)
|
||||
num_match = valid_match.sum().maximum(1e-9)
|
||||
num_unmatch = valid_unmatch.sum().maximum(1e-9)
|
||||
|
||||
|
||||
|
||||
score_ = scores[batchIdx, matches0, matches1]
|
||||
score_match_ = (score_*valid_match).float32().sum() / num_match
|
||||
score_umatch_ = (score_*valid_unmatch).float32().sum() / num_unmatch
|
||||
return -(num_unmatch * score_match_ + num_match * score_umatch_) / (num_match + num_unmatch)
|
||||
# print(score_umatch_, score_match_)
|
||||
# return -(score_match + score_umatch) / (num_match + num_unmatch)
|
||||
|
||||
score_match = scores[(batchIdx[valid_match], matches0[valid_match], matches1[valid_match])].float32().mean() if num_match > 0 else 0
|
||||
score_umatch = scores[(batchIdx[valid_unmatch], matches0[valid_unmatch], matches1[valid_unmatch])].float32().mean() if num_unmatch > 0 else 0
|
||||
# print(score_match, score_umatch)
|
||||
return -(num_unmatch * score_match + num_match * score_umatch) / (num_match + num_unmatch)
|
||||
|
||||
|
||||
def add_dustbin(scores, alpha):
|
||||
b, m, n = scores.shape
|
||||
bins0 = jt.broadcast(alpha, (b, m, 1))
|
||||
bins1 = jt.broadcast(alpha, (b, 1, n))
|
||||
alpha = jt.broadcast(alpha, (b, 1, 1))
|
||||
couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1)
|
||||
return couplings
|
||||
|
||||
|
||||
class SuperGlue(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
config = {**default_config, **config}
|
||||
self.descriptor_dim = config['descriptor_dim']
|
||||
self.keypoint_encoder = config['keypoint_encoder']
|
||||
self.GNN_layers = config['GNN_layers']
|
||||
self.sinkhorn_iterations = config['sinkhorn_iterations']
|
||||
self.match_threshold = config['match_threshold']
|
||||
self.keypoint_position_dim = config['keypoint_position_dim']
|
||||
self.use_dual_softmax = config['use_dual_softmax']
|
||||
self.scale = jt.float(self.descriptor_dim**-0.5).stop_grad()
|
||||
# self.scale.requires_grad = False
|
||||
|
||||
# self.des_extend = MLP([128, 256])
|
||||
|
||||
self.kenc = KeypointEncoder(self.descriptor_dim, self.keypoint_encoder, keypoint_position_dim=self.keypoint_position_dim)
|
||||
|
||||
self.gnn = AttentionalGNN(self.descriptor_dim, self.GNN_layers)
|
||||
|
||||
self.final_proj = Conv1d_sp(self.descriptor_dim, self.descriptor_dim, kernel_size=1, bias=True)
|
||||
|
||||
self.bin_score = jt.float(1.0)
|
||||
|
||||
def execute(self, data):
|
||||
"""Run SuperGlue on a pair of keypoints and descriptors"""
|
||||
|
||||
kpts0, kpts1 = data['keypoints0'], data['keypoints1']
|
||||
desc0, desc1 = data['descriptors0'], data['descriptors1']
|
||||
all_matches = data['all_matches']
|
||||
# match_num = data['match_num']
|
||||
|
||||
if kpts0.shape[1] == 0 or kpts1.shape[1] == 0 or all_matches.shape[1] == 0: # no keypoints or no matches/unmatches
|
||||
shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
|
||||
return {
|
||||
'matches0': jt.ones(shape0, dtype=jt.int),
|
||||
'matches1': jt.ones(shape1, dtype=jt.int),
|
||||
'matching_scores0': jt.zeros(shape0, dtype=jt.float),
|
||||
'matching_scores1': jt.zeros(shape1, dtype=jt.float),
|
||||
'skip_train': True
|
||||
}
|
||||
|
||||
# Keypoint normalization.
|
||||
kpts0 = normalize_keypoints(kpts0, data['shape0'])
|
||||
kpts1 = normalize_keypoints(kpts1, data['shape1'])
|
||||
|
||||
# Keypoint MLP encoder.
|
||||
# desc0 = self.des_extend(desc0) + self.kenc(kpts0, data['scores0'])
|
||||
# desc1 = self.des_extend(desc1) + self.kenc(kpts1, data['scores1'])
|
||||
desc0 = desc0 + self.kenc(kpts0, data['scores0'])
|
||||
desc1 = desc1 + self.kenc(kpts1, data['scores1'])
|
||||
|
||||
# Multi-layer Transformer network.
|
||||
desc0, desc1 = self.gnn(desc0, desc1)
|
||||
|
||||
# Final MLP projection.
|
||||
desc0, desc1 = self.final_proj(desc0), self.final_proj(desc1)
|
||||
desc0_t = desc0.t()
|
||||
losses = []
|
||||
|
||||
for i in range(0, desc1.shape[0], split_size):
|
||||
end = min(desc1.shape[0], i + split_size)
|
||||
|
||||
# Compute matching descriptor distance.
|
||||
scores = nn.bmm(desc0_t[i:end], desc1[i:end]) * self.scale # 457.76 MB
|
||||
scores.sync()
|
||||
|
||||
# Run the optimal transport.
|
||||
if self.use_dual_softmax:
|
||||
scores = add_dustbin(scores, self.bin_score) # 458.68 MB
|
||||
scores.sync()
|
||||
dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2)
|
||||
scores = dual_softmax0 + dual_softmax1 # 458.22 MB
|
||||
scores.sync()
|
||||
else:
|
||||
scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations'])
|
||||
|
||||
# loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])])
|
||||
|
||||
loss = get_weighted_loss_batch(scores, all_matches[i:end])
|
||||
loss.sync()
|
||||
losses.append(loss)
|
||||
loss = jt.concat(losses)
|
||||
'''
|
||||
# Compute matching descriptor distance.
|
||||
scores = nn.bmm(desc0.t(), desc1) * self.scale # 457.76 MB
|
||||
scores.sync()
|
||||
|
||||
# Run the optimal transport.
|
||||
if self.use_dual_softmax:
|
||||
scores = add_dustbin(scores, self.bin_score) # 458.68 MB
|
||||
scores.sync()
|
||||
dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2)
|
||||
scores = dual_softmax0 + dual_softmax1 # 458.22 MB
|
||||
scores.sync()
|
||||
else:
|
||||
scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations'])
|
||||
|
||||
# loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])])
|
||||
|
||||
loss = get_weighted_loss_batch(scores, all_matches)
|
||||
# print(scores.shape, all_matches.shape, loss.shape)
|
||||
'''
|
||||
|
||||
# matches0, matches1 = all_matches.chunk(chunks=2, dim=2)
|
||||
# batchIdx = jt.arange(0, b).unsqueeze(1).repeat(1, num)
|
||||
# batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1)
|
||||
# validmatch = (matches0 >= 0) | (matches1 >= 0)
|
||||
# batchIdx, matches0, matches1 = batchIdx[validmatch], matches0[validmatch], matches1[validmatch]
|
||||
# matches0[matches0 == -1] = n
|
||||
# matches1[matches1 == -1] = m
|
||||
# loss_mean = -scores[(batchIdx, matches0, matches1)].mean()
|
||||
# loss_mean = nn.l1_loss(loss_mean, jt.float(0.0))
|
||||
|
||||
if not data['return_match']:
|
||||
return {'loss': loss}
|
||||
|
||||
with jt.no_grad():
|
||||
b, n, m = scores.shape
|
||||
# Get the matches with score above "match_threshold".
|
||||
indices0, max0 = scores[:, :-1, :-1].argmax(2)
|
||||
indices1, max1 = scores[:, :-1, :-1].argmax(1)
|
||||
mutual0 = jt.arange(0, n)[None] == indices1.gather(1, indices0)
|
||||
mutual1 = jt.arange(0, m)[None] == indices0.gather(1, indices1)
|
||||
# zero = scores.new_tensor(0)
|
||||
# mscores0 = torch.where(mutual0, max0.values.exp(), zero)
|
||||
mscores0 = max0.exp()
|
||||
mscores0[mutual0.logical_not()] = 0
|
||||
# mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
|
||||
mscores1 = mscores0.gather(1, indices1)
|
||||
mscores1[mutual1.logical_not()] = 0
|
||||
valid0 = mutual0 & (mscores0 > self.match_threshold)
|
||||
valid1 = mutual1 & valid0.gather(1, indices1)
|
||||
# indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
|
||||
# indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
|
||||
indices0[valid0.logical_not()] = -1
|
||||
indices1[valid1.logical_not()] = -1
|
||||
|
||||
return {
|
||||
'matches0': indices0, # use -1 for invalid match
|
||||
'matches1': indices1, # use -1 for invalid match
|
||||
'matching_scores0': mscores0,
|
||||
'matching_scores1': mscores1,
|
||||
'loss': loss,
|
||||
}
|
||||
|
||||
# scores big value or small value means confidence? log can't take neg value
|
|
@ -0,0 +1,344 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
n = 400000000
|
||||
# n = 4000000
|
||||
n = 7680000
|
||||
|
||||
def get_mem_band():
|
||||
a = jt.rand((n)).float32()
|
||||
for i in range(100):
|
||||
a.copy().sync()
|
||||
jt.sync_all(True)
|
||||
import time
|
||||
t = time.time()
|
||||
for i in range(1000):
|
||||
a.copy().sync()
|
||||
jt.sync_all(True)
|
||||
dt = time.time() - t
|
||||
band = a.numel() * 4 * 2000 / dt / 1024**3
|
||||
print("Mem band: ", band)
|
||||
return band
|
||||
|
||||
def check_simple_add_band():
|
||||
# copy: 816
|
||||
# S=1 128,1024, ILP=1 634
|
||||
# S=0 128,1024, ILP=1 734
|
||||
# S=0 128,512, ILP=1 716
|
||||
# S=0 64,1024, ILP=1 706
|
||||
# S=0 256,1024, ILP=1 706
|
||||
def test(S=0, B=128, T=1024, ILP=1):
|
||||
a = jt.rand((n)).float32()
|
||||
jt.sync_all(True)
|
||||
jt.flags.log_silent = 1
|
||||
with jt.profile_scope(100, 1000) as rep:
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
cuda_header="#include \"type/fp16_compute.h\"",
|
||||
cuda_src=f"""
|
||||
__global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int tnum = blockDim.x * gridDim.x;
|
||||
#define ILP {ILP}
|
||||
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
|
||||
// b[i] = a[i];
|
||||
vload<ILP*sizeof(in0_type)>(b+i, a+i);
|
||||
{"__syncthreads();" if S else ""}
|
||||
}}
|
||||
}}
|
||||
kernel<<<{B},{T}>>>(in0_p, out0_p, in0->num);
|
||||
""")
|
||||
b.sync()
|
||||
bw = float(rep[-1][9]) / 1024**3
|
||||
s = f"S={S}, B={B}, T={T}, ILP={ILP} BW={bw}"
|
||||
print(s)
|
||||
return s, bw
|
||||
|
||||
def test2(S=0, B=128, T=1024, ILP=1):
|
||||
a = jt.rand((n)).float32()
|
||||
jt.sync_all(True)
|
||||
# jt.flags.log_silent = 0
|
||||
with jt.profile_scope(10, 1000) as rep:
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
cuda_header="#include \"type/fp16_compute.h\"",
|
||||
cuda_src=f"""
|
||||
__global__ void kernel(float2 * __restrict__ a, float2* __restrict__ b, int num) {{
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int tnum = blockDim.x * gridDim.x;
|
||||
#define ILP 1
|
||||
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
|
||||
b[i] = a[i];
|
||||
// b[i+1] = a[i+1];
|
||||
// vload<ILP*sizeof(in0_type)>(b+i, a+i);
|
||||
{"__syncthreads();" if S else ""}
|
||||
}}
|
||||
}}
|
||||
kernel<<<{B},{T}>>>((float2*)in0_p, (float2*)out0_p, in0->num/2);
|
||||
""")
|
||||
b.sync()
|
||||
bw = float(rep[-1][9]) / 1024**3
|
||||
s = f"T2: S={S}, B={B}, T={T}, ILP={ILP} BW={bw}"
|
||||
print(s)
|
||||
return s, bw
|
||||
|
||||
|
||||
def test3(S=0, B=128, T=1024, ILP=1, C=0):
|
||||
a = jt.rand((n)).float32()
|
||||
b = jt.rand(B)
|
||||
jt.sync_all(True)
|
||||
jt.flags.log_silent = 1
|
||||
with jt.profile_scope(100, 1000) as rep:
|
||||
b = jt.code(a.shape, a.dtype, [a, b],
|
||||
cuda_header="#include \"type/fp16_compute.h\"",
|
||||
cuda_src=f"""
|
||||
__global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int tnum = blockDim.x * gridDim.x;
|
||||
#define ILP {ILP}
|
||||
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
|
||||
// b[i] = a[i];
|
||||
vload<ILP*sizeof(in0_type)>(b+i, a+i);
|
||||
{"__syncthreads();" if S else ""}
|
||||
}}
|
||||
{"__syncthreads();" if C else ""}
|
||||
}}
|
||||
kernel<<<in1->shape[0],{T}>>>(in0_p, out0_p, in0->num);
|
||||
""")
|
||||
b.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C}
|
||||
# b.compile_options = {"FLAGS: –Xptxas –dlcm=ca ": 1}
|
||||
b.sync()
|
||||
|
||||
bw = float(rep[-1][9]) / 1024**3
|
||||
s = f"T3: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw}"
|
||||
print(s)
|
||||
return s, bw
|
||||
|
||||
|
||||
def test4(S=0, B=128, T=1024, ILP=1, C=0, name="b.png"):
|
||||
a = jt.rand((n)).float32()
|
||||
b = jt.rand(B*4).uint32()
|
||||
jt.sync_all(True)
|
||||
# jt.flags.log_silent = 1
|
||||
with jt.profile_scope(100, 10000) as rep:
|
||||
_ = jt.code(a.shape, a.dtype, [a, b],
|
||||
cuda_header="#include \"type/fp16_compute.h\"",
|
||||
cuda_src=f"""
|
||||
__device__ uint get_smid(void) {{
|
||||
uint ret;
|
||||
asm("mov.u32 %0, %smid;" : "=r"(ret) );
|
||||
return ret;
|
||||
}}
|
||||
__device__ uint get_time(void) {{
|
||||
uint ret;
|
||||
asm volatile("mov.u32 %0, %%globaltimer_lo;" : "=r"(ret));
|
||||
return ret;
|
||||
}}
|
||||
|
||||
__global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num, in1_type* __restrict__ c) {{
|
||||
uint t = get_time();
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int tnum = blockDim.x * gridDim.x;
|
||||
#define ILP {ILP}
|
||||
for (int i=tid*ILP; i<num; i+=tnum*ILP) {{
|
||||
// b[i] = a[i];
|
||||
vload<ILP*sizeof(in0_type)>(b+i, a+i);
|
||||
{"__syncthreads();" if S else ""}
|
||||
}}
|
||||
{"__syncthreads();" if C else ""}
|
||||
if (threadIdx.x == 0)
|
||||
((uint4* __restrict__)c)[blockIdx.x] =
|
||||
uint4{{get_smid(), t, get_time(), 0}};
|
||||
}}
|
||||
kernel<<<in1->shape[0]/4,{T}>>>(in0_p, out0_p, in0->num, in1_p);
|
||||
""")
|
||||
_.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C}
|
||||
# b.compile_options = {"FLAGS: –Xptxas –dlcm=ca ": 1}
|
||||
_.sync()
|
||||
|
||||
bw = float(rep[-1][9]) / 1024**3
|
||||
b = b.data.reshape(-1, 4)[:,:3]
|
||||
mint = b[:,1].min()
|
||||
b[:,1:] -= mint
|
||||
smmax = int(b[:,0].max())
|
||||
smmin = int(b[:,0].min())
|
||||
maxt = b.max()
|
||||
|
||||
# print(b)
|
||||
|
||||
s = f"T4: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw:.3f} sm={smmin},{smmax} maxt={maxt}"
|
||||
print(s)
|
||||
import pylab as pl
|
||||
pl.figure(figsize=(16,16))
|
||||
texts = []
|
||||
pret = np.zeros(200, dtype="uint32")
|
||||
for i in range(B):
|
||||
smid, s, t = b[i]
|
||||
pl.plot([s,t], [smid, smid], 'ro-')
|
||||
texts.append((s, smid, i))
|
||||
texts.append((t, smid, i))
|
||||
|
||||
texts = sorted(texts)
|
||||
for (s, smid, bid) in texts:
|
||||
cpos = max(pret[smid], s)
|
||||
pl.text(cpos, smid, str(bid))
|
||||
pret[smid] = cpos + maxt // 30
|
||||
|
||||
|
||||
# print("???")
|
||||
# adjust_text(texts, arrowprops=dict(arrowstyle='->', color='blue'))
|
||||
# print("???")
|
||||
pl.savefig(name)
|
||||
pl.close()
|
||||
return s, bw
|
||||
# test(S=0, B=128, T=1024, ILP=1)
|
||||
# test(S=1, B=128, T=1024, ILP=1)
|
||||
# test(S=0, B=64, T=1024, ILP=1)
|
||||
# test(S=0, B=256, T=1024, ILP=1)
|
||||
# test(S=1, B=128, T=512, ILP=1)
|
||||
# test(S=1, B=128, T=256, ILP=1)
|
||||
|
||||
# test(S=0, B=128, T=1024, ILP=2)
|
||||
# test(S=0, B=128, T=1024, ILP=4)
|
||||
# test(S=0, B=128, T=512, ILP=2)
|
||||
# test(S=0, B=128, T=512, ILP=4)
|
||||
|
||||
# test(S=1, B=128, T=1024, ILP=2)
|
||||
# test(S=1, B=128, T=1024, ILP=4)
|
||||
# test(S=1, B=128, T=1024, ILP=8)
|
||||
# test(S=1, B=128, T=1024, ILP=16)
|
||||
# test(S=1, B=128, T=512, ILP=2)
|
||||
# test(S=1, B=128, T=512, ILP=4)
|
||||
|
||||
# test(S=1, B=256, T=1024, ILP=2)
|
||||
# test(S=1, B=512, T=1024, ILP=2)
|
||||
# test(S=1, B=256, T=1024, ILP=4)
|
||||
# test(S=1, B=256, T=1024, ILP=8)
|
||||
# test(S=1, B=256, T=1024, ILP=16)
|
||||
# test(S=1, B=256, T=512, ILP=2)
|
||||
# test(S=1, B=256, T=512, ILP=4)
|
||||
|
||||
# test(S=1, B=128, T=256, ILP=2)
|
||||
# test(S=1, B=128, T=256, ILP=4)
|
||||
# test(S=0, B=128, T=256, ILP=2)
|
||||
# test(S=0, B=128, T=256, ILP=4)
|
||||
|
||||
# for b in [1, 2, 4, 8, 16, 32, 64, 128,256]:
|
||||
# test(S=1, B=b, T=512, ILP=2)
|
||||
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
import pylab as pl
|
||||
import numpy as np
|
||||
|
||||
# test4(S=1, B=82, T=1024, ILP=2, C=0, name="b.png")
|
||||
# test4(S=1, B=83, T=1024, ILP=2, C=0, name="c.png")
|
||||
# test4(S=1, B=82*3, T=512, ILP=2, C=0, name="d1.png")
|
||||
# test4(S=1, B=82*3+1, T=512, ILP=2, C=0, name="d2.png")
|
||||
# test4(S=1, B=82*6+1, T=512, ILP=2, C=0, name="d3.png")
|
||||
# test4(S=0, B=82*6+1, T=512, ILP=2, C=0, name="d4.png")
|
||||
|
||||
for b in range(70, 83):
|
||||
test4(S=1, B=b, T=1024, ILP=2, C=0, name=f"b-{b}.png")
|
||||
|
||||
# data = []
|
||||
# for b in range(32, 2000, 8):
|
||||
# _, bw = test3(S=0, B=b, T=32, ILP=2)
|
||||
# data.append([b, bw])
|
||||
# data = np.array(data)
|
||||
# pl.plot(data[:,0], data[:,1])
|
||||
|
||||
# for t in [32, 64, 128, 256, 512, 1024]:
|
||||
# data = []
|
||||
# for b in range(32, 2000, 8):
|
||||
# _, bw = test3(S=1, B=b*(1024//t), T=t, ILP=2)
|
||||
# data.append([b, bw])
|
||||
# data = np.array(data)
|
||||
# pl.plot(data[:,0], data[:,1])
|
||||
|
||||
# for t in [1024]:
|
||||
# for c in [0,1]:
|
||||
# data = []
|
||||
# # for b in range(32, 1000, 8):
|
||||
# for b in range(32, 33, 8):
|
||||
# _, bw = test3(S=c, B=b*(1024//t), T=t, ILP=2, C=0)
|
||||
# data.append([b, bw])
|
||||
# data = np.array(data)
|
||||
# pl.plot(data[:,0], data[:,1])
|
||||
|
||||
# for ilp in [2]:
|
||||
# for s in [1]:
|
||||
# for t in [1024,512,256,128]:
|
||||
# data = []
|
||||
# for b in range(32, 1100, 8):
|
||||
# _, bw = test3(S=s, B=b*(1024//t), T=t, ILP=ilp)
|
||||
# data.append([b, bw])
|
||||
# data = np.array(data)
|
||||
# pl.plot(data[:,0], data[:,1])
|
||||
|
||||
# pl.savefig("a.png")
|
||||
# pl.close()
|
||||
# for b in range(80, 90, 1):
|
||||
# _, bw = test3(S=1, B=b, T=1024, ILP=2)
|
||||
# # 82
|
||||
# for b in range(240, 260, 1):
|
||||
# _, bw = test3(S=1, B=b, T=512, ILP=2)
|
||||
# # 82*3 = 246
|
||||
# for b in range(240, 500, 1):
|
||||
# _, bw = test3(S=1, B=b, T=256, ILP=2)
|
||||
# # 492 = 82*6
|
||||
# for b in range(240, 1000, 1):
|
||||
# _, bw = test3(S=1, B=b, T=128, ILP=2)
|
||||
# # 984 = 82*12
|
||||
|
||||
|
||||
# for b in [128,256]:
|
||||
# test(S=1, B=b, T=1024, ILP=2)
|
||||
# for b in [128,256]:
|
||||
# test(S=0, B=b, T=512, ILP=2)
|
||||
# for b in [128,256]:
|
||||
# test(S=0, B=b, T=1024, ILP=2)
|
||||
# for b in [128,256]:
|
||||
# test(S=1, B=b, T=512, ILP=1)
|
||||
# for b in [128,256]:
|
||||
# test(S=1, B=b, T=1024, ILP=1)
|
||||
# for b in [128,256]:
|
||||
# test(S=0, B=b, T=512, ILP=1)
|
||||
# for b in [128,256]:
|
||||
# test(S=0, B=b, T=1024, ILP=1)
|
||||
# test(S=1, B=128, T=512, ILP=4)
|
||||
# test(S=1, B=64, T=512, ILP=2)
|
||||
# test(S=1, B=80, T=512, ILP=2)
|
||||
# test(S=1, B=100, T=512, ILP=2)
|
||||
# test(S=1, B=110, T=512, ILP=2)
|
||||
# test(S=1, B=115, T=512, ILP=2)
|
||||
# test(S=1, B=120, T=512, ILP=2)
|
||||
# test(S=1, B=130, T=512, ILP=2)
|
||||
# test(S=1, B=140, T=512, ILP=2)
|
||||
# test2(S=1, B=128, T=512, ILP=2)
|
||||
# test(S=1, B=128, T=256, ILP=4)
|
||||
# test(S=1, B=128, T=128, ILP=8)
|
||||
# test(S=1, B=128, T=64, ILP=16)
|
||||
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
class TestBenchmarkCUDA(unittest.TestCase):
|
||||
def setUp(self):
|
||||
jt.flags.use_cuda = 1
|
||||
def tearDown(self):
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
def test_main(self):
|
||||
return
|
||||
get_mem_band()
|
||||
check_simple_add_band()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -19,12 +19,12 @@ def all_eq(x, y):
|
|||
y = convert(y)
|
||||
if str(x.dtype).startswith("float"):
|
||||
return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all()
|
||||
return x.dtype == y.dtype and x.shape == y.shape and (x==y).all()
|
||||
return x.dtype == y.dtype and x.shape == y.shape and np.testing.assert_allclose(x, y)
|
||||
|
||||
def check(op, *args):
|
||||
x = eval(f"np.{op}(*args)")
|
||||
y = eval(f"jt.{op}(*args).data")
|
||||
assert all_eq(x, y), f"{x}\n{y}"
|
||||
all_eq(x, y)
|
||||
|
||||
class TestBinaryOp(unittest.TestCase):
|
||||
def test_binary_op(self):
|
||||
|
@ -47,6 +47,9 @@ class TestBinaryOp(unittest.TestCase):
|
|||
|
||||
def test_i(self):
|
||||
def check(op, a, b):
|
||||
if isinstance(a, list):
|
||||
a = np.array(a)
|
||||
b = np.array(b)
|
||||
if jt.flags.use_cuda and op == "@":
|
||||
return
|
||||
if op=="@":
|
||||
|
@ -65,13 +68,13 @@ class TestBinaryOp(unittest.TestCase):
|
|||
a = np.float32(a)
|
||||
ja = np.float32(ja)
|
||||
|
||||
assert all_eq(ja, a), (ja,a)
|
||||
all_eq(ja, a)
|
||||
check("+", 5, 2)
|
||||
check("-", 5, 2)
|
||||
check("*", 5, 2)
|
||||
check("/", 5, 2)
|
||||
check("//", 5, 2)
|
||||
check("@", [[5]], [[2]])
|
||||
# check("@", [[5]], [[2]])
|
||||
check("%", 5, 2)
|
||||
check("**", 5, 2)
|
||||
check("<<", 5, 2)
|
||||
|
@ -80,6 +83,15 @@ class TestBinaryOp(unittest.TestCase):
|
|||
check("^", 5, 2)
|
||||
check("|", 5, 2)
|
||||
|
||||
check("+", [5.0,6.0], [2.0,3.0])
|
||||
check("-", [5.0,6.0], [2.0,3.0])
|
||||
check("*", [5.0,6.0], [2.0,3.0])
|
||||
check("/", [5.0,6.0], [2.0,3.0])
|
||||
check("//", [5.0,6.0], [2.0,3.0])
|
||||
check("@", [[5,6],[7,8]], [[2,3],[4,5]])
|
||||
check("%", [5.0,6.0], [2.0,3.0])
|
||||
check("**", [5.0,6.0], [2.0,3.0])
|
||||
|
||||
def test_r(self):
|
||||
def check(op, a, b):
|
||||
a = np.array(a)
|
||||
|
@ -97,7 +109,7 @@ class TestBinaryOp(unittest.TestCase):
|
|||
a = eval(f"a {op} b")
|
||||
a = np.array(a)
|
||||
|
||||
assert all_eq(jc, a), f"\n{jc}\n{a}"
|
||||
all_eq(jc, a)
|
||||
check("+", 5, 2)
|
||||
check("-", 5, 2)
|
||||
check("*", 5, 2)
|
||||
|
@ -118,6 +130,7 @@ class TestBinaryOp(unittest.TestCase):
|
|||
a = np.random.rand(10)
|
||||
b = np.random.rand(10)
|
||||
c = np.random.rand(10)
|
||||
tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-4
|
||||
for op in ops:
|
||||
func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()")
|
||||
x, grads = ngrad(func, [a,b,c], 1e-8)
|
||||
|
@ -127,7 +140,7 @@ class TestBinaryOp(unittest.TestCase):
|
|||
jx = eval(f"(ja{op}jb)*jc")
|
||||
jgrads = jt.grad(jx, [ja,jb,jc])
|
||||
for jd, nd in zip(jgrads, grads):
|
||||
assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}"
|
||||
np.testing.assert_allclose(jd.data, nd, atol=tol, rtol=tol)
|
||||
|
||||
def test_mod_float(self):
|
||||
a = jt.random((10,))
|
||||
|
@ -137,7 +150,8 @@ class TestBinaryOp(unittest.TestCase):
|
|||
a = jt.random((10,), 'float64')
|
||||
b = jt.random((10,), 'float64')
|
||||
c = a % b
|
||||
assert np.allclose(c.data, a.data % b.data)
|
||||
assert np.allclose(c.data, a.data % b.data, a.data, b.data)
|
||||
if jt.flags.amp_reg & 2: return
|
||||
a = jt.random((10,)) * 1000
|
||||
b = (jt.random((10,)) * 10).int() + 1
|
||||
c = a % b
|
||||
|
@ -169,5 +183,19 @@ class TestBinaryOp(unittest.TestCase):
|
|||
class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)):
|
||||
pass
|
||||
|
||||
class TestBinaryOpCpuFp16(TestBinaryOp):
|
||||
def setUp(self):
|
||||
jt.flags.amp_reg = 2 | 4 | 8 | 16
|
||||
def tearDown(self):
|
||||
jt.flags.amp_reg = 0
|
||||
|
||||
class TestBinaryOpCudaFp16(TestBinaryOp):
|
||||
def setUp(self):
|
||||
jt.flags.amp_reg = 2 | 4 | 8 | 16
|
||||
jt.flags.use_cuda = 1
|
||||
def tearDown(self):
|
||||
jt.flags.amp_reg = 0
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,342 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def transpose0231(x):
|
||||
s0, s1, s2, s3 = x.shape
|
||||
asize = 16
|
||||
bsize = 16
|
||||
ILP = 2
|
||||
return jt.code([s0, s2, s3, s1], x.dtype, [x],
|
||||
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
|
||||
cuda_src=f"""
|
||||
__global__ void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{
|
||||
__shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}];
|
||||
int t3 = threadIdx.x % {bsize};
|
||||
int t1 = threadIdx.x / {bsize};
|
||||
int b3 = blockIdx.x;
|
||||
int b2 = blockIdx.y;
|
||||
int b0 = blockIdx.z;
|
||||
int x3 = 1;
|
||||
int x2 = s3;
|
||||
int x1 = s2*x2;
|
||||
int x0 = s1*x1;
|
||||
int y3 = 1;
|
||||
int y2 = s1;
|
||||
int y1 = s3*y2;
|
||||
int y0 = s2*y1;
|
||||
in0_type tmp[{ILP}];
|
||||
for (int i=0; i<(s1-1)/{asize*ILP}+1; i++)
|
||||
{{
|
||||
int _b3 = b3 * {bsize*ILP} + t3*{ILP};
|
||||
if (_b3 < s3) {{
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++) {{
|
||||
vload<sizeof(in0_type)*{ILP}>(
|
||||
tmp,
|
||||
&x[b0*x0+(t1*{ILP}+j+i*{asize*ILP})*x1+b2*x2+_b3*x3]
|
||||
);
|
||||
#pragma unroll
|
||||
for (int k=0; k<{ILP}; k++)
|
||||
t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k];
|
||||
|
||||
}}
|
||||
}}
|
||||
__syncthreads();
|
||||
int t3_ = threadIdx.x % {asize};
|
||||
int t1_ = threadIdx.x / {asize};
|
||||
_b3 = b3 * {bsize*ILP} + t1_*{ILP};
|
||||
if (_b3 < s3) {{
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++) {{
|
||||
#pragma unroll
|
||||
for (int k=0; k<{ILP}; k++) {{
|
||||
tmp[k] =
|
||||
t[(t3*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j];
|
||||
}}
|
||||
vload<sizeof(in0_type)*{ILP}>(
|
||||
&y[b0*y0+b2*y1+(_b3+j)*y2+((t3*{ILP})+i*{asize*ILP})*y3],
|
||||
tmp
|
||||
);
|
||||
}}
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
}}
|
||||
int s0, s1, s2, s3;
|
||||
in0->shape.unpack(s0, s1, s2, s3);
|
||||
kernel<<<{{(s3-1)/{bsize*ILP}+1, s2, s0 }}, {bsize*asize}>>>
|
||||
(in0_p, out0_p, s0, s1, s2, s3);
|
||||
""")
|
||||
|
||||
def transpose0231_2(x):
|
||||
s0, s1, s2, s3 = x.shape
|
||||
asize = 16
|
||||
bsize = 8
|
||||
ILP = 2
|
||||
return jt.code([s0, s2, s3, s1], x.dtype, [x],
|
||||
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
|
||||
cuda_src=f"""
|
||||
__global__ __launch_bounds__({asize*bsize}) void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{
|
||||
__shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}];
|
||||
int t3 = threadIdx.x % {bsize};
|
||||
int t1 = threadIdx.x / {bsize};
|
||||
int b3 = blockIdx.x;
|
||||
int b1 = blockIdx.y;
|
||||
int b2 = 0;
|
||||
int b0 = blockIdx.z;
|
||||
int x3 = 1;
|
||||
int x2 = s3;
|
||||
int x1 = s2*x2;
|
||||
int x0 = s1*x1;
|
||||
int y3 = 1;
|
||||
int y2 = s1;
|
||||
int y1 = s3*y2;
|
||||
int y0 = s2*y1;
|
||||
in0_type tmp[{ILP}];
|
||||
{{
|
||||
int _b3 = b3 * {bsize*ILP} + t3*{ILP};
|
||||
if (_b3 < s3) {{
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++) {{
|
||||
if (t1*{ILP}+j+b1*{asize*ILP} >= s1)
|
||||
continue;
|
||||
vload<sizeof(in0_type)*{ILP}>(
|
||||
tmp,
|
||||
&x[b0*x0+(t1*{ILP}+j+b1*{asize*ILP})*x1+b2*x2+_b3*x3]
|
||||
);
|
||||
#pragma unroll
|
||||
for (int k=0; k<{ILP}; k++)
|
||||
t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k];
|
||||
|
||||
}}
|
||||
}}
|
||||
__syncthreads();
|
||||
int t3_ = threadIdx.x % {asize};
|
||||
int t1_ = threadIdx.x / {asize};
|
||||
_b3 = b3 * {bsize*ILP} + t1_*{ILP};
|
||||
int yy3 = (t3_*{ILP})+b1*{asize*ILP};
|
||||
if (_b3 < s3 && yy3 < s1) {{
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++) {{
|
||||
#pragma unroll
|
||||
for (int k=0; k<{ILP}; k++) {{
|
||||
tmp[k] =
|
||||
t[(t3_*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j];
|
||||
}}
|
||||
vload<sizeof(in0_type)*{ILP}>(
|
||||
&y[b0*y0+b2*y1+(_b3+j)*y2+yy3*y3],
|
||||
tmp
|
||||
);
|
||||
// printf("%d %d %d %d %d\\n", b0*y0+b2*y1+(_b3+j)*y2+yy3*y3,
|
||||
// b0, b2, (_b3+j), yy3);
|
||||
}}
|
||||
}}
|
||||
__syncthreads();
|
||||
}}
|
||||
}}
|
||||
int s0, s1, s2, s3;
|
||||
in0->shape.unpack(s0, s1, s2, s3);
|
||||
kernel<<<{{(s3-1)/{bsize*ILP}+1, (s1-1)/{asize*ILP}+1, s0 }}, {bsize*asize}>>>
|
||||
(in0_p, out0_p, s0, s1, s2, s3);
|
||||
""")
|
||||
|
||||
def check_share():
|
||||
return
|
||||
a = jt.rand((30, 32, 4, 2000)).float32()
|
||||
jt.code(a.shape, a.dtype, [a],
|
||||
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
|
||||
cuda_src="""
|
||||
__global__ void kernel(in0_type* __restrict__ a, in0_type* __restrict__ b) {
|
||||
__shared__ float x[32*33];
|
||||
for (int i=0; i<3; i++) {
|
||||
((float2*)&x[i])[0] = ((float2*)&a[i])[0];
|
||||
((float2*)&b[i])[0] = ((float2*)&x[i+1])[0];
|
||||
}
|
||||
}
|
||||
kernel<<<1024,16*16>>>(in0_p, out0_p);
|
||||
LOGir << "aaa";
|
||||
""").sync()
|
||||
jt.sync_all(True)
|
||||
# print(a[0]+1)
|
||||
print("pass test")
|
||||
|
||||
class TestFP16(unittest.TestCase):
|
||||
def test_array(self):
|
||||
a = np.array([1,2,3], dtype="float16")
|
||||
b = jt.array(a)
|
||||
np.testing.assert_allclose(a, b.data)
|
||||
|
||||
def test_add(self):
|
||||
a = np.array([1,2,3], dtype="float16")
|
||||
b = jt.array(a)
|
||||
c = b+b
|
||||
np.testing.assert_allclose(c.data, a+a)
|
||||
d = c.sum()
|
||||
np.testing.assert_allclose(d.data, [12])
|
||||
c = c+1
|
||||
print(c)
|
||||
|
||||
def test_matmul(self):
|
||||
a = jt.random((100,100)).float16()
|
||||
b = jt.random((100,100)).float16()
|
||||
c = jt.matmul(a, b)
|
||||
c.sync()
|
||||
|
||||
def test_matmul_grad(self):
|
||||
a = jt.random((100,100)).float16()
|
||||
b = jt.random((100,100)).float16()
|
||||
c = jt.matmul(a, b)
|
||||
c.sync()
|
||||
da, db = jt.grad(c, [a,b])
|
||||
jt.sync_all()
|
||||
assert da.dtype == "float16"
|
||||
assert db.dtype == "float16"
|
||||
|
||||
def test_array_random_auto_cast(self):
|
||||
a = jt.array([1.0,2.0])
|
||||
assert a.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=2+16):
|
||||
a = jt.array([1.0,2.0])
|
||||
assert a.dtype == "float16", a.dtype
|
||||
|
||||
a = jt.random([10])
|
||||
assert a.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=2+16):
|
||||
a = jt.random([10])
|
||||
assert a.dtype == "float16", a.dtype
|
||||
|
||||
def test_conv(self):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = jt.random((4,4,3,3)).float16()
|
||||
c = jt.nn.conv(a, b)
|
||||
c.sync()
|
||||
|
||||
def test_max(self):
|
||||
a = jt.random((100,)).float16()
|
||||
b = jt.random((100,)).float16()
|
||||
c = a.maximum(b)
|
||||
c.sync()
|
||||
|
||||
def test_reduce_dtype_infer(self):
|
||||
with jt.flag_scope(amp_reg=1):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a.sum()
|
||||
b.sync()
|
||||
assert b.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=2):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a.sum()
|
||||
b.sync()
|
||||
assert b.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=0):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a.sum()
|
||||
b.sync()
|
||||
assert b.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=2+4):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a.sum()
|
||||
b.sync()
|
||||
assert b.dtype == "float16", b.dtype
|
||||
|
||||
def test_white_dtype_infer(self):
|
||||
with jt.flag_scope(amp_reg=1):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a**a
|
||||
b.sync()
|
||||
assert b.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=2):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a**a
|
||||
b.sync()
|
||||
assert b.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=0):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a**a
|
||||
b.sync()
|
||||
assert b.dtype == "float32"
|
||||
with jt.flag_scope(amp_reg=2+8):
|
||||
a = jt.random((3,4,5,5)).float16()
|
||||
b = a**a
|
||||
b.sync()
|
||||
assert b.dtype == "float16", b.dtype
|
||||
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
class TestFP16CUDA(TestFP16):
|
||||
def setUp(self):
|
||||
jt.flags.use_cuda = 1
|
||||
def tearDown(self):
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
def test_softmax(self):
|
||||
a = jt.rand((120, 2000, 2000)).float16()
|
||||
# a = jt.rand((1, 2000, 2000)).float32()
|
||||
jt.sync_all()
|
||||
with jt.profile_scope(10, 100):
|
||||
a.log_softmax(-1).sync()
|
||||
|
||||
def test_transpose(self):
|
||||
check_share()
|
||||
# return
|
||||
a = jt.rand((30, 32, 4, 2000)).float32()
|
||||
# a = jt.rand((1, 1024, 1, 2000)).float32()
|
||||
diff = transpose0231(a).data != a.transpose((0,2,3,1)).data
|
||||
print(np.where(diff))
|
||||
# return
|
||||
jt.sync_all()
|
||||
# with jt.profile_scope(100, 11000):
|
||||
with jt.profile_scope(100, 11000):
|
||||
# a.log_softmax(-1).sync()
|
||||
transpose0231(a).sync()
|
||||
|
||||
a.transpose((0,2,3,1)).sync()
|
||||
# a.transpose((0,2,1,3)).sync()
|
||||
a.fuse_transpose((0,2,1,3)).sync()
|
||||
(a+1).sync()
|
||||
jt.sync_all(True)
|
||||
diff = transpose0231(a).data != a.transpose((0,2,3,1)).data
|
||||
print(np.where(diff))
|
||||
np.testing.assert_allclose(transpose0231(a).data, a.transpose((0,2,3,1)).data)
|
||||
|
||||
def test_transpose2(self):
|
||||
# check_share()
|
||||
# return
|
||||
# a = jt.rand((30, 32, 4, 2000)).float32()
|
||||
# a = jt.rand((1, 10000, 1, 2000)).float32()
|
||||
a = jt.rand((1, 10000, 1, 2048)).float32()
|
||||
print("transpose")
|
||||
transpose0231_2(a).sync()
|
||||
print("add")
|
||||
(a+1).sync()
|
||||
return
|
||||
# a = jt.arange(32*16).reshape((1, 32, 1, 16))
|
||||
diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data
|
||||
print(np.where(diff))
|
||||
# return
|
||||
jt.sync_all()
|
||||
# with jt.profile_scope(100, 11000):
|
||||
with jt.profile_scope(100, 1100):
|
||||
# a.log_softmax(-1).sync()
|
||||
transpose0231_2(a).sync()
|
||||
|
||||
a.transpose((0,2,3,1)).sync()
|
||||
# a.transpose((0,2,1,3)).sync()
|
||||
a.fuse_transpose((0,2,1,3)).sync()
|
||||
(a+1).sync()
|
||||
jt.sync_all(True)
|
||||
diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data
|
||||
print(np.where(diff))
|
||||
np.testing.assert_allclose(transpose0231_2(a).data, a.transpose((0,2,3,1)).data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -75,6 +75,8 @@ class TestPad(unittest.TestCase):
|
|||
print('pass flip test ...')
|
||||
|
||||
def test_cross(self):
|
||||
def check_equal(a, b, tol):
|
||||
np.testing.assert_allclose(a.detach().numpy(), b.numpy(), atol=1e-5)
|
||||
arr1 = np.random.randn(16,3,224,224,3)
|
||||
arr2 = np.random.randn(16,3,224,224,3)
|
||||
check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1)
|
||||
|
@ -257,34 +259,51 @@ class TestOther(unittest.TestCase):
|
|||
a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1]))
|
||||
np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927])
|
||||
|
||||
y = jt.random((100,))
|
||||
x = jt.random((100,))
|
||||
z = jt.arctan2(y, x)
|
||||
z2 = np.arctan2(y.data, x.data)
|
||||
np.testing.assert_allclose(z.data, z2)
|
||||
|
||||
def test_code_softmax(self):
|
||||
if not jt.has_cuda: return
|
||||
|
||||
def softmax(x, dim = None):
|
||||
def softmax(x, dim = None, log=False):
|
||||
if dim is None:
|
||||
x = (x - x.max()).exp()
|
||||
ret = x / x.sum()
|
||||
else:
|
||||
x = (x-x.max(dim, keepdims=True)).exp()
|
||||
ret = x / x.sum(dim, keepdims=True)
|
||||
if log: return ret.log()
|
||||
return ret
|
||||
from jittor.other.code_softmax import softmax_v1
|
||||
|
||||
with jt.flag_scope(use_cuda = 1):
|
||||
shape = (120, 2000, 2000)
|
||||
# shape = (3,3)
|
||||
a = jt.rand(shape)
|
||||
c = jt.rand(shape)
|
||||
b = softmax(a, -1)
|
||||
bb = softmax_v1(a)
|
||||
shape = (3,3)
|
||||
for log in [0,1]:
|
||||
for shape in [(3,3),
|
||||
(12, 200, 2000),
|
||||
(12, 200, 2048),
|
||||
(12, 200, 2049)]:
|
||||
print(shape)
|
||||
a = jt.rand(shape)
|
||||
c = jt.rand(shape)
|
||||
b = softmax(a, -1, log=log)
|
||||
bb = softmax_v1(a, log=log)
|
||||
|
||||
err = (bb - b).abs().max()
|
||||
assert err.item() < 1e-5
|
||||
err = (bb - b).abs().max()
|
||||
assert err.item() < 1e-5, (err, bb, b)
|
||||
|
||||
d1 = jt.grad(b*c, a)
|
||||
d2 = jt.grad(bb*c, a)
|
||||
err = (d1 - d2).abs().max()
|
||||
assert err.item() < 1e-5
|
||||
d1 = jt.grad(b*c, a)
|
||||
d2 = jt.grad(bb*c, a)
|
||||
err = (d1 - d2).abs().max()
|
||||
|
||||
if log:
|
||||
assert err.item() < 1e-2, (err.item())
|
||||
else:
|
||||
assert err.item() < 1e-5, (err.item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -36,19 +36,7 @@ class MnistNet(Module):
|
|||
return x
|
||||
|
||||
@unittest.skipIf(skip_this_test, "skip_this_test")
|
||||
class TestResnet(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
# hyper-parameters
|
||||
self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
|
||||
self.weight_decay = 0.0001
|
||||
self.momentum = 0.9
|
||||
self.learning_rate = 0.1
|
||||
# mnist dataset
|
||||
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
|
||||
.set_attrs(batch_size=self.batch_size, shuffle=True)
|
||||
self.train_loader.num_workers = 4
|
||||
|
||||
class TestResnetFp32(unittest.TestCase):
|
||||
# setup random seed
|
||||
def setup_seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
|
@ -59,6 +47,19 @@ class TestResnet(unittest.TestCase):
|
|||
@jt.flag_scope(use_cuda=1, use_stat_allocator=1)
|
||||
def test_resnet(self):
|
||||
self.setup_seed(1)
|
||||
|
||||
# hyper-parameters
|
||||
self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
|
||||
self.weight_decay = 0.0001
|
||||
self.momentum = 0.9
|
||||
self.learning_rate = 0.1
|
||||
if jt.flags.amp_reg:
|
||||
self.learning_rate = 0.01
|
||||
# mnist dataset
|
||||
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
|
||||
.set_attrs(batch_size=self.batch_size, shuffle=True)
|
||||
self.train_loader.num_workers = 4
|
||||
|
||||
loss_list=[]
|
||||
acc_list=[]
|
||||
mnist_net = MnistNet()
|
||||
|
@ -70,6 +71,7 @@ class TestResnet(unittest.TestCase):
|
|||
for data, target in self.train_loader:
|
||||
batch_id = self.train_loader.batch_id
|
||||
epoch_id = self.train_loader.epoch_id
|
||||
data = data.float_auto()
|
||||
|
||||
# train step
|
||||
# with jt.log_capture_scope(
|
||||
|
@ -120,6 +122,8 @@ class TestResnet(unittest.TestCase):
|
|||
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
||||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||
|
||||
if jt.flags.amp_reg:
|
||||
continue
|
||||
if jt.in_mpi:
|
||||
assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
|
||||
else:
|
||||
|
@ -131,5 +135,14 @@ class TestResnet(unittest.TestCase):
|
|||
assert np.mean(loss_list[-50:])<0.5
|
||||
assert np.mean(acc_list[-50:])>0.8
|
||||
|
||||
|
||||
@unittest.skipIf(skip_this_test, "skip_this_test")
|
||||
class TestResnetFp16(TestResnetFp32):
|
||||
def setup(self):
|
||||
jt.flags.auto_mixed_precision_level = 5
|
||||
|
||||
def tearDown(self):
|
||||
jt.flags.auto_mixed_precision_level = 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from jittor.test.misc import superglue
|
||||
from jittor.test.misc.superglue import SuperGlue
|
||||
import time
|
||||
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def main():
|
||||
global superglue
|
||||
superglue.split_size = int(os.environ.get("split_size", "12"))
|
||||
# superglue.split_size = 1000000
|
||||
|
||||
batch = 30
|
||||
num = 2000
|
||||
dim = 128
|
||||
|
||||
# jt.display_memory_info()
|
||||
# os.system("nvidia-smi")
|
||||
# breakpoint()
|
||||
|
||||
with jt.no_grad():
|
||||
|
||||
config = {
|
||||
'superglue': {
|
||||
'sinkhorn_iterations': 25,
|
||||
'match_threshold': 0.01,
|
||||
'keypoint_position_dim': 2,
|
||||
'descriptor_dim': dim,
|
||||
'use_dual_softmax': True,
|
||||
'GNN_layers': ['self', 'cross'] * 9,
|
||||
}
|
||||
}
|
||||
|
||||
superglue = SuperGlue(config.get('superglue', {}))
|
||||
|
||||
superglue.eval()
|
||||
|
||||
data = {
|
||||
'keypoints0': jt.rand((batch, num, 2), dtype=jt.float),
|
||||
'keypoints1': jt.rand((batch, num, 2), dtype=jt.float),
|
||||
'shape0': jt.rand((batch, 2), dtype=jt.float),
|
||||
'shape1': jt.rand((batch, 2), dtype=jt.float),
|
||||
'descriptors0': jt.rand((batch, dim, num), dtype=jt.float),
|
||||
'descriptors1': jt.rand((batch, dim, num), dtype=jt.float),
|
||||
'scores0': jt.rand((batch, num), dtype=jt.float),
|
||||
'scores1': jt.rand((batch, num), dtype=jt.float),
|
||||
'all_matches': jt.randint(0, num, (batch, num, 2), dtype=jt.int),
|
||||
'return_match': False,
|
||||
# 'match_num': match_num
|
||||
}
|
||||
|
||||
use_fp16 = int(os.environ.get("use_fp16", "0"))
|
||||
if use_fp16:
|
||||
jt.flags.amp_reg = 2
|
||||
for k,v in data.items():
|
||||
if isinstance(v, jt.Var) and v.dtype == "float32":
|
||||
v.assign(v.float16())
|
||||
for v in superglue.parameters():
|
||||
if v.dtype == "float32":
|
||||
v.assign(v.float16())
|
||||
jt.sync_all(True)
|
||||
|
||||
import pickle
|
||||
jt.sync_all(True)
|
||||
for x in range(5):
|
||||
print(x)
|
||||
jt.gc()
|
||||
x = superglue(data)['loss']
|
||||
x.sync()
|
||||
jt.display_memory_info()
|
||||
# os.system("nvidia-smi")
|
||||
# breakpoint()
|
||||
# print(data)
|
||||
# print(x)
|
||||
|
||||
# with open("/tmp/record.pkl", "wb") as f:
|
||||
# pickle.dump([data, x], f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
|
||||
# x = superglue(data)['loss']
|
||||
# x.sync()
|
||||
# jt.get_max_memory_treemap()
|
||||
# exit(0)
|
||||
|
||||
jt.sync_all(True)
|
||||
time0 = time.time()
|
||||
jt.flags.profiler_enable = int(os.environ.get("profiler", "0"))
|
||||
|
||||
for x in range(20):
|
||||
print(x)
|
||||
# jt.display_memory_info()
|
||||
x = superglue(data)['loss']
|
||||
x.sync()
|
||||
# print(x)
|
||||
|
||||
jt.sync_all(True)
|
||||
time1 = time.time()
|
||||
print("avg time:", (time1 - time0) / 20)
|
||||
return (time1 - time0) / 20
|
||||
|
||||
|
||||
class TestSuperglue(unittest.TestCase):
|
||||
def test(self):
|
||||
if not jt.has_cuda: return
|
||||
t1 = main()
|
||||
os.environ["use_fp16"] = "1"
|
||||
t2 = main()
|
||||
os.environ["use_fp16"] = "0"
|
||||
assert t1*0.55 > t2
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -17,7 +17,8 @@ def check(op, *args):
|
|||
x = convert(x)
|
||||
y = convert(y)
|
||||
# str match nan and inf
|
||||
assert x.dtype == y.dtype and x.shape == y.shape
|
||||
assert x.dtype == y.dtype and x.shape == y.shape, \
|
||||
(x.dtype, y.dtype, x.shape, y.shape)
|
||||
for a,b in zip(x.flatten(), y.flatten()):
|
||||
assert str(a)[:5] == str(b)[:5], (a,b)
|
||||
|
||||
|
@ -32,9 +33,10 @@ class TestUnaryOp(unittest.TestCase):
|
|||
check("logical_not", a)
|
||||
check("bitwise_not", a)
|
||||
b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
|
||||
check("log", a.astype("float32"))
|
||||
check("exp", a.astype("float32"))
|
||||
check("sqrt", a.astype("float32"))
|
||||
type = "float16" if (jt.flags.amp_reg & 2) else "float32"
|
||||
check("log", a.astype(type))
|
||||
check("exp", a.astype(type))
|
||||
check("sqrt", a.astype(type))
|
||||
|
||||
def test_grad(self):
|
||||
ops = ["abs", "negative", "log", "exp", "sqrt",
|
||||
|
@ -60,7 +62,8 @@ class TestUnaryOp(unittest.TestCase):
|
|||
ja = jt.array(b)
|
||||
jb = eval(f"jt.{op}(ja)")
|
||||
jda = jt.grad(jb, ja)
|
||||
assert (np.allclose(jda.data, da)), (jda.data,da,op)
|
||||
tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-6
|
||||
assert (np.allclose(jda.data, da, atol=tol, rtol=tol)), (jda.data,da,op)
|
||||
|
||||
def test_sigmoid(self):
|
||||
a = np.arange(-150,150, 10).astype("float32")
|
||||
|
@ -92,11 +95,26 @@ class TestUnaryOp(unittest.TestCase):
|
|||
np.testing.assert_allclose(y.data, y2.data)
|
||||
d = jt.grad(x2, y2)
|
||||
_, (dn,) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8)
|
||||
np.testing.assert_allclose(d.data, dn, atol=1e-6, rtol=1e-6)
|
||||
tol = 1e-3 if jt.flags.amp_reg & 2 else 1e-6
|
||||
np.testing.assert_allclose(d.data, dn, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
|
||||
pass
|
||||
|
||||
class TestUnaryOpCudaFp16(TestUnaryOp, test_cuda(2)):
|
||||
def setUp(self):
|
||||
jt.flags.amp_reg = 2 | 4 | 8 | 16
|
||||
def tearDown(self):
|
||||
jt.flags.amp_reg = 0
|
||||
|
||||
class TestUnaryOpCudaFp16(TestUnaryOp, test_cuda(2)):
|
||||
def setUp(self):
|
||||
jt.flags.amp_reg = 2 | 4 | 8 | 16
|
||||
jt.flags.use_cuda = 1
|
||||
def tearDown(self):
|
||||
jt.flags.amp_reg = 0
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Binary file not shown.
Loading…
Reference in New Issue