# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Guoye Yang <498731903@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import init, Module
import numpy as np
import math
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.optim import *

def matmul_transpose(a, b):
    '''
    returns a * b^T
    '''
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-1]

    shape = list(a.shape)[:-1] + list(b.shape)
    a = a.broadcast(shape, [len(shape)-2])
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-1)

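# Usage sketch (illustrative shapes): matmul_transpose contracts the last
# dimension of both inputs, so it computes a @ b.transpose() without an
# explicit transpose.
#   a = jt.random((4, 3))
#   b = jt.random((5, 3))
#   c = matmul_transpose(a, b)   # shape (4, 5)
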
def matmul(a, b):
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-2]

    shape = list(a.shape) + [b.shape[-1]]
    a = a.broadcast(shape, [len(shape)-1])
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-2)

jt.Var.matmul = jt.Var.__matmul__ = matmul
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))

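# Usage sketch (illustrative): after the registration above, `@` on jt.Var
# dispatches to matmul; the second operand must be 2-D.
#   x = jt.random((8, 16))
#   w = jt.random((16, 4))
#   y = x @ w                    # shape (8, 4), same as matmul(x, w)
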
def get_init_var_rand(shape, dtype):
    return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))

def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)

class PReLU(Module):
    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
        self.a = init.constant((num_parameters,), "float32", init_)

    def execute(self, x):
        if self.num_parameters != 1:
            assert self.num_parameters == x.size(1), "num_parameters does not match input channels in PReLU"
            return jt.maximum(0, x) + self.a.broadcast(x, [0,2,3]) * jt.minimum(0, x)
        else:
            return jt.maximum(0, x) + self.a * jt.minimum(0, x)

# TODO: 4-dim input is handled by reshaping and currently executes slowly
def cross_entropy_loss(output, target, ignore_index=None):
    if len(output.shape) == 4:
        c_dim = output.shape[1]
        output = output.transpose((0, 2, 3, 1))
        output = output.reshape((-1, c_dim))
    if ignore_index is not None:
        target = jt.ternary(target==ignore_index,
            jt.array(-1).broadcast(target), target)
        mask = jt.logical_and(target >= 0, target < output.shape[1])
    target = target.reshape((-1, ))
    target = target.broadcast(output, [1])
    target = target.index(1) == target

    output = output - output.max([1], keepdims=True)
    loss = output.exp().sum(1).log()
    loss = loss - (output*target).sum(1)
    if ignore_index is None:
        return loss.mean()
    else:
        return loss.sum() / jt.maximum(mask.int().sum(), 1)

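# Usage sketch (illustrative): `output` holds raw logits of shape (N, C) or
# (N, C, H, W); `target` holds integer class indices.
#   logits = jt.random((4, 10))
#   labels = jt.array([1, 0, 3, 9])
#   loss = cross_entropy_loss(logits, labels)   # scalar mean loss
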
def mse_loss(output, target):
    return (output-target).sqr().mean()

def bce_loss(output, target):
    return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()

def l1_loss(output, target):
    return (output-target).abs().mean()

class CrossEntropyLoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return cross_entropy_loss(output, target)

class MSELoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return mse_loss(output, target)

class BCELoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return bce_loss(output, target)

class L1Loss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return l1_loss(output, target)

class BCEWithLogitsLoss(Module):
    def __init__(self):
        self.sigmoid = Sigmoid()
        self.bce = BCELoss()
    def execute(self, output, target):
        output = self.sigmoid(output)
        output = self.bce(output, target)
        return output

def softmax(x, dim = None):
    if dim is None:
        x = (x - x.max()).exp()
        ret = x / x.sum()
    else:
        x = (x-x.max(dim, keepdims=True)).exp()
        ret = x / x.sum(dim, keepdims=True)
    return ret

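# Usage sketch (illustrative): with dim=None the whole tensor is normalized
# to sum to 1; with a dim, normalization runs along that axis only.
#   logits = jt.random((2, 5))
#   probs = softmax(logits, dim=1)   # each row sums to 1
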
class Dropout(Module):
    def __init__(self, p=0.5, is_train=False):
        assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
        self.p = p
        self.is_train = is_train
        #TODO: test model.train() to change self.is_train
    def execute(self, input):
        output = input
        if self.p > 0 and self.is_train:
            if self.p == 1:
                noise = jt.zeros(input.shape)
                output = output * noise
            else:
                noise = jt.random(input.shape)
                noise = (noise > self.p).int()
                output = output * noise / (1.0 - self.p) # div keep prob
        return output

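# Usage sketch (illustrative): in training mode surviving activations are
# scaled by 1/(1-p) so the expected value matches evaluation mode, where the
# input passes through unchanged.
#   drop = Dropout(p=0.5, is_train=True)
#   y = drop(jt.random((4, 16)))
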
class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = init.invariant_uniform((out_features, in_features), "float32")
        bound = 1.0/math.sqrt(in_features)
        self.bias = init.uniform((out_features,), "float32",-bound,bound) if bias else None

    def execute(self, x):
        x = matmul_transpose(x, self.weight)
        if self.bias is not None:
            return x + self.bias
        return x

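# Usage sketch (illustrative): the weight is stored as (out_features,
# in_features), so the forward pass is matmul_transpose(x, weight) + bias.
#   fc = Linear(128, 10)
#   y = fc(jt.random((32, 128)))   # shape (32, 10)
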
class BatchNorm(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None

        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:
            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
            if self.sync and jt.mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
        else:
            running_mean = self.running_mean.broadcast(x, [0,2,3])
            running_var = self.running_var.broadcast(x, [0,2,3])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

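# Usage sketch (illustrative): BatchNorm expects NCHW input; in training mode
# the running statistics are updated with momentum, in eval mode they are
# used directly.
#   bn = BatchNorm(16)
#   y = bn(jt.random((8, 16, 32, 32)))
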
class BatchNorm1d(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None
        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:
            xmean = jt.mean(x, dims=[0], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0], keepdims=1)

            if self.sync and jt.mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
            self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
        else:
            running_mean = self.running_mean.broadcast(x, [0])
            running_var = self.running_var.broadcast(x, [0])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0])
        b = self.bias.broadcast(x, [0])
        return norm_x * w + b

class InstanceNorm2d(Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None
        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)

    def execute(self, x):
        xmean = jt.mean(x, dims=[2,3], keepdims=1)
        x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
        if self.sync and jt.mpi:
            xmean = xmean.mpi_all_reduce("mean")
            x2mean = x2mean.mpi_all_reduce("mean")

        xvar = jt.maximum(x2mean-xmean*xmean, 0)
        norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 2)
LeakyReLU = Leaky_relu
ReLU6 = jt.make_module(relu6)
Softmax = jt.make_module(softmax, 2)

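# Usage sketch (illustrative): jt.make_module wraps a plain function as a
# Module, so the activations above can be used inside Sequential.
#   act = ReLU()
#   y = act(jt.random((4, 8)) - 0.5)
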
class Conv(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        self.groups = groups
        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'
        Kh, Kw = self.kernel_size

        self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
        if bias:
            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
        else:
            self.bias = None

    def execute(self, x):
        if self.groups == 1:
            N,C,H,W = x.shape
            Kh, Kw = self.kernel_size
            assert C==self.in_channels
            oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
            ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
            xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
                'i0', # Nid
                'i2', # Cid
                f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
                f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
            ])
            ww = self.weight.broadcast(xx.shape, [0,3,4])
            yy = xx*ww
            y = yy.sum([2,5,6]) # Kc, Kh, Kw
            if self.bias is not None:
                b = self.bias.broadcast(y.shape, [0,2,3])
                y = y + b
            return y
        else:
            N,C,H,W = x.shape
            Kh, Kw = self.kernel_size
            G = self.groups
            CpG = C // G # channels per group
            assert C==self.in_channels
            oc = self.out_channels
            oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
            ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
            xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
                'i0', # Nid
                f'i1*{CpG}+i3', # Gid
                f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
                f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
            ])
            xx.compile_options = {"G":G}
            # w: [oc, CpG, Kh, Kw]
            ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
                f'i1*{oc//G}+i2',
                'i3',
                'i6',
                'i7'
            ])
            yy = xx*ww
            y = yy.reindex_reduce('add', [N, oc, oh, ow], [
                'i0',
                f'i1*{oc//G}+i2',
                'i4',
                'i5'
            ])
            if self.bias is not None:
                b = self.bias.broadcast(y.shape, [0,2,3])
                y = y + b
            return y

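# Usage sketch (illustrative): the convolution is expressed with reindex +
# broadcast + sum; the output spatial size follows the usual formula
# oh = (H + 2*padding - dilation*(Kh-1) - 1)//stride + 1.
#   conv = Conv(3, 16, kernel_size=3, stride=1, padding=1)
#   y = conv(jt.random((2, 3, 32, 32)))   # shape (2, 16, 32, 32)
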
class ConvTranspose(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1):
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.group = groups
        assert groups==1, "Group conv not supported yet."

        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
            self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1])
        self.output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding)

        self.weight = init.relu_invariant_gauss((in_channels, out_channels) + self.kernel_size, dtype="float", mode="fan_out")
        if bias:
            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
        else:
            self.bias = None

    def execute(self, x):
        N,C,H,W = x.shape
        i,o,h,w = self.weight.shape
        assert C==i
        stride_h, stride_w = self.stride
        padding_h, padding_w = self.padding
        dilation_h, dilation_w = self.dilation

        h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
        w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
        out_shape = (N, o, h_out, w_out)
        shape = (N, i, o, H, W, h, w)
        xx = x.broadcast(shape, (2, 5, 6)) # broadcast x over o,h,w
        ww = self.weight.broadcast(shape, (0, 3, 4)) # broadcast weight over N,H,W
        y = (ww*xx).reindex_reduce("add", out_shape, [
            'i0', # N
            'i2', # o
            f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
            f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
        ])
        if self.bias is not None:
            b = self.bias.broadcast(y.shape, [0,2,3])
            y = y + b
        return y

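# Usage sketch (illustrative): the transposed convolution scatters each input
# pixel into the output via reindex_reduce; the output size is
# h_out = (H-1)*stride - 2*padding + dilation*(h-1) + output_padding + 1.
#   deconv = ConvTranspose(16, 3, kernel_size=4, stride=2, padding=1)
#   y = deconv(jt.random((2, 16, 16, 16)))   # shape (2, 3, 32, 32)
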
class ReflectionPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ReflectionPad2d padding must be an int or a tuple, but found {type(padding)}")

    def execute(self, x):
        n,c,h,w = x.shape
        assert (self.pl < w and self.pr < w), "padding_left and padding_right should be smaller than input width"
        assert (self.pt < h and self.pb < h), "padding_top and padding_bottom should be smaller than input height"
        oh=h+self.pt+self.pb
        ow=w+self.pl+self.pr
        l = self.pl
        r = self.pl + w - 1
        t = self.pt
        b = self.pt + h - 1
        return x.reindex([n,c,oh,ow], ["i0","i1",
            f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}",
            f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}",
        ])

class ZeroPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ZeroPad2d padding must be an int or a tuple, but found {type(padding)}")

    def execute(self, x):
        n,c,h,w = x.shape
        return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"])

class ConstantPad2d(Module):
    def __init__(self, padding, value):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ConstantPad2d padding must be an int or a tuple, but found {type(padding)}")
        self.value = value

    def execute(self, x):
        n,c,h,w = x.shape
        return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"], overflow_value=self.value)

class ReplicationPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ReplicationPad2d padding must be an int or a tuple, but found {type(padding)}")

    def execute(self, x):
        n,c,h,w = x.shape
        oh=h+self.pt+self.pb
        ow=w+self.pl+self.pr
        l = self.pl
        r = self.pl + w - 1
        t = self.pt
        b = self.pt + h - 1
        return x.reindex([n,c,oh,ow], ["i0","i1",
            f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}",
            f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
        ])

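# Usage sketch (illustrative): all four padding modules accept either a
# single int or a (left, right, top, bottom) tuple and pad NCHW input.
#   pad = ReflectionPad2d(2)
#   y = pad(jt.random((1, 3, 32, 32)))   # shape (1, 3, 36, 36)
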
class Embedding(Module):
    def __init__(self, num, dim):
        self.num = num
        self.dim = dim
        self.weight = jt.init.gauss([num,dim],'float32').stop_grad()

    def execute(self, x):
        res = self.weight[x].reshape([x.shape[0],self.dim])
        return res

class PixelShuffle(Module):
    def __init__(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def execute(self, x):
        n,c,h,w = x.shape
        r = self.upscale_factor
        assert c%(r*r)==0, "input channels must be divisible by the square of upscale_factor in PixelShuffle"
        return x.reindex([n,int(c/r**2),h*r,w*r], [
            "i0",
            f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
            f"i2/{r}",
            f"i3/{r}"
        ])

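# Usage sketch (illustrative): PixelShuffle rearranges (n, c*r*r, h, w) into
# (n, c, h*r, w*r), as used in sub-pixel upsampling.
#   ps = PixelShuffle(2)
#   y = ps(jt.random((1, 12, 8, 8)))   # shape (1, 3, 16, 16)
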
class Tanh(Module):
    def __init__(self):
        super().__init__()
    def execute(self, x):
        return x.tanh()

class Sigmoid(Module):
    def __init__(self):
        super().__init__()
    def execute(self, x):
        return x.sigmoid()

class Resize(Module):
    def __init__(self, size, mode="nearest", align_corners=False):
        super().__init__()
        self.size = size
        self.mode = mode
        self.align_corners = align_corners
    def execute(self, x):
        return resize(x, self.size, self.mode, self.align_corners)

def _interpolate(img, x, y, ids, mode):
    if mode=="nearest":
        return img.reindex([*ids, x.floor(), y.floor()])
    if mode=="bilinear":
        fx, fy = x.floor(), y.floor()
        cx, cy = fx+1, fy+1
        dx, dy = x-fx, y-fy
        a = img.reindex_var([*ids, fx, fy])
        b = img.reindex_var([*ids, cx, fy])
        c = img.reindex_var([*ids, fx, cy])
        d = img.reindex_var([*ids, cx, cy])
        dnx, dny = 1-dx, 1-dy
        ab = dx*b + dnx*a
        cd = dx*d + dnx*c
        o = ab*dny + cd*dy
        return o
    raise ValueError(f"Unsupported interpolation mode: {mode}")

def resize(img, size, mode="nearest", align_corners=False):
    n,c,h,w = img.shape
    H,W = size
    nid, cid, hid, wid = jt.index((n,c,H,W))
    if align_corners:
        x = hid * ((h-1) / max(1, H-1))
        y = wid * ((w-1) / max(1, W-1))
    else:
        x = hid * (h / H) + (h/H*0.5 - 0.5)
        if H>h: x = x.clamp(0, h-1)
        y = wid * (w / W) + (w/W*0.5 - 0.5)
        if W>w: y = y.clamp(0, w-1)
    return _interpolate(img, x, y, (nid,cid), mode)

def upsample(img, size, mode="nearest", align_corners=False):
    n,c,h,w = img.shape
    H,W = size
    nid, cid, hid, wid = jt.index((n,c,H,W))
    if align_corners:
        x = hid * ((h-1) / max(1, H-1))
        y = wid * ((w-1) / max(1, W-1))
    else:
        x = hid * (h / H)
        y = wid * (w / W)
    return _interpolate(img, x, y, (nid,cid), mode)

class Upsample(Module):
    def __init__(self, scale_factor=None, mode='nearest'):
        self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
        self.mode = mode

    def execute(self, x):
        return upsample(x,
            size=(
                int(x.shape[2]*self.scale_factor[0]),
                int(x.shape[3]*self.scale_factor[1])),
            mode=self.mode)

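# Usage sketch (illustrative): resize/upsample take an explicit (H, W) target
# size, while Upsample derives it from a scale factor.
#   up = Upsample(scale_factor=2, mode="nearest")
#   y = up(jt.random((1, 3, 16, 16)))   # shape (1, 3, 32, 32)
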
class Sequential(Module):
    def __init__(self, *args):
        self.layers = list(args)
    def __getitem__(self, idx):
        return self.layers[idx]
    def execute(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def dfs(self, parents, k, callback, callback_leave):
        n_children = len(self.layers)
        ret = callback(parents, k, self, n_children)
        if ret == False:
            return
        for k,v in enumerate(self.layers):
            parents.append(self)
            v.dfs(parents, k, callback, callback_leave)
            parents.pop()
        if callback_leave:
            callback_leave(parents, k, self, n_children)
    def append(self, mod):
        self.layers.append(mod)

ModuleList = Sequential
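
# Usage sketch (illustrative): Sequential chains modules in order; ModuleList
# is an alias for it in this file.
#   net = Sequential(Linear(784, 256), Relu(), Linear(256, 10))
#   logits = net(jt.random((32, 784)))   # shape (32, 10)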