# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#     Zheng-Ning Liu <lzhengning@gmail.com>
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from abc import abstractmethod
from sys import breakpointhook
import jittor as jt
from jittor import flatten, init, Module
import numpy as np
import collections
import math
from collections import OrderedDict
from jittor.pool import *
from jittor.optim import *
from jittor.misc import _pair, _triple
from jittor_utils import LOG

def matmul_transpose(a, b):
|
|
'''
|
|
returns a * b^T
|
|
'''
|
|
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
|
|
if len(a.shape) != 2:
|
|
aa = a.reshape((-1, a.shape[-1]))
|
|
cc = matmul_transpose(aa, b)
|
|
return cc.reshape(a.shape[:-1]+(-1,))
|
|
assert len(a.shape) == 2 and len(b.shape) == 2
|
|
|
|
shape = list(a.shape)[:-1] + list(b.shape)
|
|
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
|
a = a.broadcast(shape, [len(shape)-2])
|
|
b = b.broadcast(shape)
|
|
return (a*b).sum(len(shape)-1)
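# Usage sketch (illustrative only; shapes below are arbitrary examples):
# matmul_transpose computes a @ b^T by broadcasting, so the last dimensions
# of a and b must match.
#   >>> a = jt.random((5, 3))
#   >>> b = jt.random((7, 3))
#   >>> c = matmul_transpose(a, b)   # shape [5, 7], same result as matmul(a, b.transpose())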
|
|
|
|
|
|
def bmm_transpose(a, b):
|
|
'''
|
|
returns a * b^T
|
|
'''
|
|
if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
|
|
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 1)
|
|
t = list(range(b.ndim))
|
|
t[-1], t[-2] = t[-2], t[-1]
|
|
return bmm(a, b.transpose(t))
|
|
|
|
|
|
def bmm(a, b):
|
|
''' batch matrix multiply,
|
|
shape of input a is [batch, n, m],
|
|
shape of input b is [batch, m, k],
|
|
return shape is [batch, n, k]
|
|
|
|
Example::
|
|
|
|
import jittor as jt
|
|
from jittor import nn
|
|
|
|
batch, n, m, k = 100, 5, 6, 7
|
|
|
|
a = jt.random((batch, n, m))
|
|
b = jt.random((batch, m, k))
|
|
c = nn.bmm(a, b)
|
|
'''
|
|
assert len(a.shape) > 2 and len(b.shape) > 2
|
|
return matmul(a, b)
|
|
|
|
def matmul(a, b):
|
|
''' matrix multiply,
|
|
|
|
Example::
|
|
|
|
a = jt.random([3])
|
|
b = jt.random([3])
|
|
c = jt.matmul(a, b)
|
|
assert c.shape == [1]
|
|
|
|
a = jt.random([3, 4])
|
|
b = jt.random([4])
|
|
c = jt.matmul(a, b)
|
|
assert c.shape == [3]
|
|
|
|
a = jt.random([10, 3, 4])
|
|
b = jt.random([4])
|
|
c = jt.matmul(a, b)
|
|
assert c.shape == [10, 3]
|
|
|
|
a = jt.random([10, 3, 4])
|
|
b = jt.random([4, 5])
|
|
c = jt.matmul(a, b)
|
|
assert c.shape == [10, 3, 5]
|
|
|
|
a = jt.random([10, 3, 4])
|
|
b = jt.random([10, 4, 5])
|
|
c = jt.matmul(a, b)
|
|
assert c.shape == [10, 3, 5]
|
|
|
|
a = jt.random([8, 1, 3, 4])
|
|
b = jt.random([10, 4, 5])
|
|
c = jt.matmul(a, b)
|
|
assert c.shape == [8, 10, 3, 5]
|
|
'''
|
|
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
|
len_a = len(a.shape)
|
|
len_b = len(b.shape)
|
|
if len_b == 1:
|
|
# a: [n, m], b:[m], c:[n]
|
|
return (a*b).sum(-1)
|
|
if len_a == 1:
|
|
# a: [n], b:[n,k], c:[k]
|
|
return (a.broadcast(b, [-1]) * b).sum(0)
|
|
if len_a>=3 and len_a==len_b:
|
|
# bmm
|
|
# a: [..., n, m], b: [..., m, k], c:[..., n, k]
|
|
if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
|
|
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
|
shape = []
|
|
len_c = max(len_a, len_b)
|
|
(n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
|
|
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
|
# a: [..., n, m]
|
|
# b: [..., m, k]
|
|
# cc:[..., n, m, k]
|
|
# -->
|
|
# 012
|
|
if len_b == 2 and len_a>2:
|
|
# TODO:ugly implementation for tuner
|
|
aa = a.reshape((-1, m))
|
|
cc = matmul(aa, b)
|
|
# print(a.shape, b.shape, cc.shape)
|
|
return cc.reshape(a.shape[:-1] + [k])
|
|
for i in range(len_c-2):
|
|
ai = len_a-(len_c-i)
|
|
bi = len_b-(len_c-i)
|
|
an = a.shape[ai] if ai>=0 else 1
|
|
bn = b.shape[bi] if bi>=0 else 1
|
|
if an!=1 and bn!=1:
|
|
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
|
cn = max(an, bn)
|
|
shape.append(cn)
|
|
shape.extend([n, m, k])
|
|
a = a.broadcast(shape, [-1])
|
|
b = b.broadcast(shape, [-3])
|
|
return (a*b).sum(-2)
|
|
jt.Var.matmul = jt.Var.__matmul__ = matmul
|
|
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
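# Usage sketch (illustrative only): the registrations above expose matmul both as a
# jt.Var method and as the @ operator.
#   >>> x = jt.random((2, 3, 4))
#   >>> y = jt.random((4, 5))
#   >>> (x @ y).shape      # [2, 3, 5], same as x.matmul(y) or nn.matmul(x, y)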
|
|
|
|
def get_init_var_rand(shape, dtype):
|
|
return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
|
|
|
|
def relu(x):
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math::
|
|
        \text{ReLU}(x) = \max(0,x)
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> a
|
|
jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32)
|
|
>>> nn.relu(a)
|
|
jt.Var([0. 1.1338731 6.128115 ], dtype=float32)
|
|
'''
|
|
return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
|
|
|
|
|
|
def leaky_relu(x, scale=0.01):
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math::
|
|
\text{LeakyRELU}(x) =
|
|
\begin{cases}
|
|
x, & \text{ if } x \geq 0 \\
|
|
\text{scale} \times x, & \text{ otherwise }
|
|
\end{cases}
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
    :param scale: the scale value for the leaky relu formulation. Default: 0.01
    :type scale: float, optional
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> a
|
|
jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32)
|
|
>>> nn.leaky_relu(a)
|
|
jt.Var([-3.8380371e-03 1.1338731e+00 6.1281152e+00], dtype=float32)
|
|
'''
|
|
return jt.ternary(x>0, x, x*scale)
|
|
|
|
def relu6(x):
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math::
|
|
\text{ReLU6}(x) = \min(\max(0,x), 6)
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> a
|
|
jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32)
|
|
>>> nn.relu6(a)
|
|
jt.Var([0. 1.1338731 6. ], dtype=float32)
|
|
'''
|
|
return jt.minimum(jt.maximum(x, 0.0), 6.0)
|
|
|
|
def elu(x: jt.Var, alpha: float = 1.0) -> jt.Var:
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math::
|
|
\text{ELU}(x) = \begin{cases}
|
|
x, & \text{ if } x > 0\\
|
|
\alpha * (\exp(x) - 1), & \text{ if } x \leq 0
|
|
\end{cases}
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
:param alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
|
|
    :type alpha: float, optional
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> a
|
|
jt.Var([-0.38380373 -1.1338731 2.128115 ], dtype=float32)
|
|
>>> nn.elu(a)
|
|
jt.Var([-0.31873488 -0.6782155 2.128115 ], dtype=float32)
|
|
'''
|
|
return jt.ternary(x>0,x,alpha*(x.exp()-1))
|
|
|
|
def sign(x: jt.Var) -> jt.Var:
|
|
''' returns the signs of elements of x
|
|
|
|
:param x: the input Var
|
|
:type x: jt.Var
|
|
|
|
Example:
|
|
>>> a = jt.float32([0.99, 0, -0.99])
|
|
>>> nn.sign(a)
|
|
jt.Var([ 1. 0. -1.], dtype=float32)
|
|
'''
|
|
one = jt.ones(x.shape)
|
|
x = jt.ternary(x>0, one, x)
|
|
return jt.ternary(x<0, -one, x)
|
|
|
|
def gelu(x):
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math:: \text{GELU}(x) = x * \Phi(x)
|
|
|
|
where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> a
|
|
    jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32)
|
|
>>> nn.gelu(a)
|
|
jt.Var([-0.134547 0.9882567 6.128115 ], dtype=float32)
|
|
'''
|
|
_sqrt2 = 1.4142135623730951
|
|
erf = jt.erf(x/_sqrt2)+1
|
|
r = erf*x*.5
|
|
return r
|
|
|
|
class ELU(Module):
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math::
|
|
\text{ELU}(x) = \begin{cases}
|
|
x, & \text{ if } x > 0\\
|
|
\alpha * (\exp(x) - 1), & \text{ if } x \leq 0
|
|
\end{cases}
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
:param alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
|
|
    :type alpha: float, optional
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> a
|
|
jt.Var([-0.38380373 -1.1338731 2.128115 ], dtype=float32)
|
|
>>> nn.elu(a)
|
|
jt.Var([-0.31873488 -0.6782155 2.128115 ], dtype=float32)
|
|
'''
|
|
def __init__(self,alpha=1.0):
|
|
self.alpha=alpha
|
|
|
|
def execute(self,x):
|
|
return elu(x,self.alpha)
|
|
|
|
class PReLU(Module):
|
|
r''' Applies the element-wise function:
|
|
|
|
.. math::
|
|
\text{PReLU}(x) =
|
|
\begin{cases}
|
|
x, & \text{ if } x \geq 0 \\
|
|
ax, & \text{ otherwise }
|
|
\end{cases}
|
|
|
|
:param x: the input var
|
|
:type x: jt.Var
|
|
|
|
:param num_parameters: number of :math:`a` to learn, can be either 1 or the number of channels at input. Default: 1
|
|
:type num_parameters: int, optional
|
|
|
|
:param init: the initial value of :math:`a`. Default: 0.25
|
|
    :type init: float, optional
|
|
|
|
Example:
|
|
>>> a = jt.randn(3)
|
|
>>> prelu = nn.PReLU()
|
|
>>> prelu(a)
|
|
jt.Var([-0.09595093 1.1338731 6.128115 ], dtype=float32)
|
|
'''
|
|
|
|
def __init__(self, num_parameters=1, init_=0.25):
|
|
self.num_parameters = num_parameters
|
|
self.weight = init.constant((num_parameters,), "float32", init_)
|
|
|
|
def execute(self, x):
|
|
if self.num_parameters != 1:
|
|
assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
|
|
return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
|
|
else:
|
|
return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
|
|
|
|
#TODO: a 4-dim input causes slow execution
|
|
def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction='sum'):
|
|
if len(output.shape) == 4:
|
|
c_dim = output.shape[1]
|
|
output = output.transpose((0, 2, 3, 1))
|
|
output = output.reshape((-1, c_dim))
|
|
|
|
target = target.reshape((-1, ))
|
|
target_weight = jt.ones(target.shape[0], dtype='float32')
|
|
if weight is not None:
|
|
target_weight = weight[target]
|
|
if ignore_index is not None:
|
|
target_weight = jt.ternary(
|
|
target==ignore_index,
|
|
jt.array(0).broadcast(target_weight),
|
|
target_weight
|
|
)
|
|
|
|
target = target.broadcast(output, [1])
|
|
target = target.index(1) == target
|
|
|
|
output = output - output.max([1], keepdims=True)
|
|
logsum = output.exp().sum(1).log()
|
|
loss = (logsum - (output*target).sum(1)) * target_weight
|
|
if reduction == 'sum':
|
|
return loss.sum() / target_weight.sum()
|
|
elif reduction == 'mean':
|
|
return loss.mean() / target_weight.mean()
|
|
else:
|
|
return loss / target_weight
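# Usage sketch (illustrative only; shapes below are arbitrary examples):
# cross_entropy_loss takes unnormalized logits and integer class targets.
#   >>> logits = jt.randn(8, 10)                     # [batch, num_classes]
#   >>> target = jt.int32([0, 1, 2, 3, 4, 5, 6, 7])
#   >>> loss = cross_entropy_loss(logits, target)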
|
|
|
|
def mse_loss(output, target):
|
|
return (output-target).sqr().mean()
|
|
|
|
def bce_loss(output, target, weight=None, size_average=True):
|
|
loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
|
|
|
|
if weight is not None:
|
|
loss *= weight
|
|
|
|
if size_average:
|
|
return loss.mean()
|
|
else:
|
|
return loss.sum()
|
|
|
|
def l1_loss(output, target):
|
|
return (output-target).abs().mean()
|
|
|
|
|
|
def smooth_l1_loss(y_true, y_pred,reduction="mean"):
|
|
"""Implements Smooth-L1 loss.
|
|
y_true and y_pred are typically: [N, 4], but could be any shape.
|
|
|
|
Args:
|
|
y_true - ground truth
|
|
y_pred - predictions
|
|
        reduction - reduction mode of the loss; must be one of ['mean','sum','none']
|
|
"""
|
|
diff = jt.abs(y_true - y_pred)
|
|
less_than_one = (diff<1.0).float32()
|
|
loss = (less_than_one * 0.5 * diff.sqr()) + (1 - less_than_one) * (diff - 0.5)
|
|
if reduction=="mean":
|
|
return loss.mean()
|
|
elif reduction=="sum":
|
|
return loss.sum()
|
|
elif reduction=="none":
|
|
return loss
|
|
else:
|
|
raise ValueError(f'not support {reduction}')
|
|
|
|
def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'):
|
|
assert output.ndim<=2 and output.ndim>0 and target.ndim==1
|
|
n_classes = output.shape[-1]
|
|
assert weight is None or weight.numel()==n_classes
|
|
assert ignore_index<0 or ignore_index<n_classes
|
|
if weight is None:
|
|
weight = jt.ones((n_classes,))
|
|
if ignore_index>0:
|
|
weight[ignore_index]=0
|
|
if output.ndim==2:
|
|
index = jt.index((output.shape[0],),dim=0)
|
|
loss = -output[index,target]*weight[target]
|
|
else:
|
|
loss = -output[target[0]]*weight[target[0]]
|
|
if reduction=="mean":
|
|
total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum()
|
|
return loss.sum()/total_weight
|
|
elif reduction=="sum":
|
|
return loss.sum()
|
|
elif reduction=="none":
|
|
return loss
|
|
else:
|
|
raise ValueError(f'not support {reduction}')
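# Usage sketch (illustrative only): nll_loss expects log-probabilities, so combining it
# with log_softmax (defined later in this file) reproduces the usual cross entropy.
#   >>> logits = jt.randn(8, 10)
#   >>> target = jt.int32([0, 1, 2, 3, 4, 5, 6, 7])
#   >>> loss = nll_loss(log_softmax(logits, dim=1), target)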
|
|
|
|
class CrossEntropyLoss(Module):
|
|
def __init__(self, weight=None, ignore_index=None):
|
|
self.weight = weight
|
|
self.ignore_index = ignore_index
|
|
|
|
def execute(self, output, target):
|
|
return cross_entropy_loss(output, target, self.weight, self.ignore_index)
|
|
|
|
class MSELoss(Module):
|
|
def __init__(self):
|
|
pass
|
|
def execute(self, output, target):
|
|
return mse_loss(output, target)
|
|
|
|
class BCELoss(Module):
|
|
def __init__(self, weight=None, size_average=True):
|
|
self.weight = weight
|
|
self.size_average = size_average
|
|
def execute(self, output, target):
|
|
return bce_loss(output, target, self.weight, self.size_average)
|
|
|
|
class L1Loss(Module):
|
|
def __init__(self):
|
|
pass
|
|
def execute(self, output, target):
|
|
return l1_loss(output, target)
|
|
|
|
def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True):
|
|
max_val = jt.clamp(-output,min_v=0)
|
|
if pos_weight is not None:
|
|
log_weight = (pos_weight-1)*target + 1
|
|
loss = (1-target)*output+(log_weight*(((-max_val).exp()+(-output - max_val).exp()).log()+max_val))
|
|
else:
|
|
loss = (1-target)*output+max_val+((-max_val).exp()+(-output -max_val).exp()).log()
|
|
if weight is not None:
|
|
loss *=weight
|
|
|
|
if size_average:
|
|
return loss.mean()
|
|
else:
|
|
return loss.sum()
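# Usage sketch (illustrative only): bce_loss expects probabilities (e.g. after a sigmoid),
# while binary_cross_entropy_with_logits works directly on raw scores in a numerically
# stable way; both reduce with mean by default (size_average=True).
#   >>> scores = jt.randn(4)
#   >>> target = jt.float32([0, 1, 1, 0])
#   >>> l1 = bce_loss(jt.sigmoid(scores), target)
#   >>> l2 = binary_cross_entropy_with_logits(scores, target)   # same value as l1 up to float error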
|
|
|
|
class BCEWithLogitsLoss(Module):
|
|
def __init__(self, weight=None, pos_weight=None, size_average=True):
|
|
self.pos_weight = pos_weight
|
|
self.weight = weight
|
|
self.size_average = size_average
|
|
|
|
def execute(self, output, target):
|
|
return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average)
|
|
|
|
def softmax(x, dim=None, log=False):
|
|
import jittor.other.code_softmax as code_softmax
|
|
if code_softmax.can_softmax_v1(x, dim):
|
|
return code_softmax.softmax_v1(x, log)
|
|
if dim is None:
|
|
x = (x - x.max()).exp()
|
|
ret = x / x.sum()
|
|
else:
|
|
x = (x-x.max(dim, keepdims=True)).exp()
|
|
ret = x / x.sum(dim, keepdims=True)
|
|
if log: return ret.log()
|
|
return ret
|
|
jt.Var.softmax = softmax
|
|
|
|
def log_softmax(x,dim=None):
|
|
return softmax(x,dim=dim, log=True)
|
|
jt.Var.log_softmax = log_softmax
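# Usage sketch (illustrative only): softmax/log_softmax normalize over the whole tensor
# when dim is None, otherwise over the given dimension.
#   >>> x = jt.randn(2, 5)
#   >>> p = softmax(x, dim=1)        # each row sums to 1
#   >>> lp = x.log_softmax(dim=1)    # available as a jt.Var method via the registration above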
|
|
|
|
def log_sigmoid(x):
|
|
return jt.log(jt.sigmoid(x))
|
|
jt.Var.log_sigmoid = log_sigmoid
|
|
|
|
def logsumexp(x, dim, keepdim=False):
|
|
return x.exp().sum(dim, keepdim).log()
|
|
jt.Var.logsumexp = logsumexp
|
|
|
|
class Identity(Module):
|
|
def __init__(self, *args, **kwargs):
|
|
super(Identity, self).__init__()
|
|
|
|
def execute(self, input):
|
|
return input
|
|
|
|
def identity(input): return input
|
|
|
|
class Dropout(Module):
|
|
def __init__(self, p=0.5, is_train=False):
|
|
assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
|
|
self.p = p
|
|
self.is_train = is_train
|
|
#TODO: test model.train() to change self.is_train
|
|
def execute(self, input):
|
|
output = input
|
|
if self.p > 0 and self.is_train:
|
|
if self.p == 1:
|
|
noise = jt.zeros(input.shape)
|
|
output = output * noise
|
|
else:
|
|
noise = jt.random(input.shape)
|
|
noise = (noise > self.p).int()
|
|
output = output * noise / (1.0 - self.p) # div keep prob
|
|
return output
|
|
|
|
def dropout(x,p=0.5,is_train=False):
|
|
return Dropout(p=p,is_train=is_train)(x)
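# Usage sketch (illustrative only): dropout is applied only when is_train is True
# (for the Module form this is toggled by model.train()/model.eval()).
#   >>> x = jt.random((2, 4))
#   >>> y = dropout(x, p=0.5, is_train=True)   # randomly zeroed and rescaled by 1/(1-p)
#   >>> z = dropout(x, p=0.5)                  # is_train defaults to False: returns x unchanged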
|
|
|
|
class Linear(Module):
|
|
def __init__(self, in_features, out_features, bias=True):
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
self.weight = init.invariant_uniform((out_features, in_features), "float32")
|
|
bound = 1.0/math.sqrt(in_features)
|
|
self.bias = init.uniform((out_features,), "float32",-bound,bound) if bias else None
|
|
|
|
def execute(self, x):
|
|
x = matmul_transpose(x, self.weight)
|
|
if self.bias is not None:
|
|
return x + self.bias
|
|
return x
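# Usage sketch (illustrative only; shapes are arbitrary examples): Linear applies
# x @ weight^T + bias on the last dimension, so leading batch dimensions are preserved.
#   >>> fc = Linear(16, 4)
#   >>> y = fc(jt.random((8, 16)))       # shape [8, 4]
#   >>> y = fc(jt.random((2, 8, 16)))    # shape [2, 8, 4]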
|
|
|
|
class BatchNorm(Module):
|
|
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
|
|
self.sync = sync
|
|
self.num_features = num_features
|
|
self.is_train = is_train
|
|
self.eps = eps
|
|
self.momentum = momentum
|
|
self.affine = affine
|
|
self.weight = init.constant((num_features,), "float32", 1.0) if affine else 1.0
|
|
self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0
|
|
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
|
|
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
|
|
|
|
def execute(self, x):
|
|
dims = [0]+list(range(2,x.ndim))
|
|
if self.is_train:
|
|
xmean = jt.mean(x, dims=dims)
|
|
x2mean = jt.mean(x*x, dims=dims)
|
|
if self.sync and jt.in_mpi:
|
|
xmean = xmean.mpi_all_reduce("mean")
|
|
x2mean = x2mean.mpi_all_reduce("mean")
|
|
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
w = self.weight / jt.sqrt(xvar+self.eps)
|
|
b = self.bias - xmean * w
|
|
norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
|
|
|
self.running_mean.update(self.running_mean +
|
|
(xmean.reshape((-1,)) - self.running_mean) * self.momentum)
|
|
self.running_var.update(self.running_var +
|
|
(xvar.reshape((-1,))-self.running_var)*self.momentum)
|
|
return norm_x
|
|
else:
|
|
w = self.weight / jt.sqrt(self.running_var+self.eps)
|
|
b = self.bias - self.running_mean * w
|
|
norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
|
return norm_x
|
|
|
|
BatchNorm3d = BatchNorm2d = BatchNorm1d = BatchNorm
|
|
|
|
def batch_norm(x, running_mean, running_var, weight=1, bias=0, training=False, momentum=0.1, eps=1e-05):
|
|
dims = [0]+list(range(2,x.ndim))
|
|
assert not training
|
|
w = weight / jt.sqrt(running_var+eps)
|
|
b = bias - running_mean * w
|
|
norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
|
return norm_x
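# Usage sketch (illustrative only): the functional batch_norm above only supports
# inference (training=False) and normalizes with the supplied running statistics,
# while the BatchNorm module maintains running_mean/running_var itself.
#   >>> bn = BatchNorm(3)
#   >>> y = bn(jt.random((4, 3, 8, 8)))   # statistics over dims [0, 2, 3], one pair per channel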
|
|
|
|
|
|
class InstanceNorm(Module):
|
|
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True):
|
|
self.sync = sync
|
|
self.num_features = num_features
|
|
self.is_train = is_train
|
|
self.eps = eps
|
|
self.momentum = momentum
|
|
|
|
self.affine = affine
|
|
self.weight = init.constant((num_features,), "float32", 1.0) if affine else 1.0
|
|
self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0
|
|
|
|
def execute(self, x):
|
|
dims = list(range(2,x.ndim))
|
|
xmean = jt.mean(x, dims=dims)
|
|
x2mean = jt.mean(x*x, dims=dims)
|
|
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
w = self.weight / jt.sqrt(xvar+self.eps)
|
|
b = self.bias - xmean * w
|
|
return x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
|
|
|
InstanceNorm3d = InstanceNorm2d = InstanceNorm1d = InstanceNorm
|
|
|
|
def instance_norm(x,
|
|
running_mean = None,
|
|
running_var = None,
|
|
weight = 1,
|
|
bias = 0,
|
|
momentum = 0.1,
|
|
eps = 1e-5):
|
|
dims = list(range(2,x.ndim))
|
|
xmean = jt.mean(x, dims=dims)
|
|
x2mean = jt.mean(x*x, dims=dims)
|
|
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
w = weight / jt.sqrt(xvar+eps)
|
|
b = bias - xmean * w
|
|
return x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
|
|
|
class LayerNorm(Module):
|
|
def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
|
|
if isinstance(normalized_shape, int):
|
|
normalized_shape = (normalized_shape,)
|
|
self.normalized_shape = tuple(normalized_shape)
|
|
self.eps = eps
|
|
self.elementwise_affine = elementwise_affine
|
|
self.weight = init.constant(normalized_shape, "float32", 1.0) if elementwise_affine else 1.0
|
|
self.bias = init.constant(normalized_shape, "float32", 0.0) if elementwise_affine else 0.0
|
|
|
|
def execute(self, x):
|
|
dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
|
|
xmean = jt.mean(x, dims=dims, keepdims=1)
|
|
x2mean = jt.mean(x*x, dims=dims, keepdims=1)
|
|
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
w = self.weight / jt.sqrt(xvar+self.eps)
|
|
b = self.bias - xmean * w
|
|
return x * w + b
|
|
|
|
|
|
LayerNorm3d = LayerNorm2d = LayerNorm1d = LayerNorm
|
|
|
|
def layer_norm(x,
|
|
normalized_shape,
|
|
weight = 1,
|
|
bias = 0,
|
|
eps: float = 1e-5,
|
|
elementwise_affine: bool = True):
|
|
dims = [-i for i in range(len(normalized_shape), 0, -1)]
|
|
xmean = jt.mean(x, dims=dims, keepdims=1)
|
|
x2mean = jt.mean(x*x, dims=dims, keepdims=1)
|
|
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
w = weight / jt.sqrt(xvar+eps)
|
|
b = bias - xmean * w
|
|
return x * w + b
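# Usage sketch (illustrative only): layer_norm normalizes over the trailing
# normalized_shape dimensions of each sample.
#   >>> x = jt.random((4, 10, 32))
#   >>> ln = LayerNorm(32)
#   >>> y = ln(x)                      # per-(sample, position) normalization over the last dim
#   >>> y = layer_norm(x, (32,))       # functional form with weight=1, bias=0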
|
|
|
|
class GroupNorm(Module):
|
|
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True):
|
|
self.num_groups = num_groups
|
|
self.num_channels = num_channels
|
|
self.eps = eps
|
|
|
|
self.affine = affine
|
|
self.weight = init.constant((num_channels,), "float32", 1.0) if affine else 1.0
|
|
self.bias = init.constant((num_channels,), "float32", 0.0) if affine else 0.0
|
|
|
|
def execute(self, x):
|
|
N = x.shape[0]
|
|
C = self.num_channels
|
|
output_shape = (N,-1)
|
|
# TODO: 3d group norm
|
|
if x.ndim==4:
|
|
output_shape = x.shape
|
|
assert C % self.num_groups == 0
|
|
x = x.reshape((N, self.num_groups, C//self.num_groups, -1))
|
|
xmean = jt.mean(x, dims=[2,3]).reshape((N, self.num_groups, 1))
|
|
x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, self.num_groups, 1))
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
|
|
if self.affine:
|
|
w = self.weight.reshape((1, self.num_groups, -1))
|
|
b = self.bias.reshape((1, self.num_groups, -1))
|
|
else:
|
|
w = 1
|
|
b = 0
|
|
w = w / jt.sqrt(xvar+self.eps)
|
|
b = b - xmean * w
|
|
x = x * w.broadcast(x, [3]) + b.broadcast(x, [3])
|
|
return x.reshape(output_shape)
|
|
|
|
def group_norm(x,
|
|
num_groups,
|
|
weight = 1,
|
|
bias = 0,
|
|
eps=1e-05):
|
|
N = x.shape[0]
|
|
C = x.shape[1]
|
|
output_shape = (N,-1)
|
|
# TODO: 3d group norm
|
|
if x.ndim==4:
|
|
output_shape = x.shape
|
|
assert C % num_groups == 0
|
|
x = x.reshape((N, num_groups, C//num_groups, -1))
|
|
xmean = jt.mean(x, dims=[2,3]).reshape((N, num_groups, 1))
|
|
x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, num_groups, 1))
|
|
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
|
|
|
if isinstance(weight, jt.Var):
|
|
weight = weight.reshape((1, num_groups, -1))
|
|
if isinstance(bias, jt.Var):
|
|
bias = bias.reshape((1, num_groups, -1))
|
|
weight = weight / jt.sqrt(xvar+eps)
|
|
bias = bias - xmean * weight
|
|
x = x * weight.broadcast(x, [3]) + bias.broadcast(x, [3])
|
|
return x.reshape(output_shape)
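# Usage sketch (illustrative only): GroupNorm splits the C channels into num_groups
# groups and normalizes each group per sample, so it does not depend on the batch size.
#   >>> gn = GroupNorm(num_groups=4, num_channels=32)
#   >>> y = gn(jt.random((2, 32, 16, 16)))    # 4 groups of 8 channels each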
|
|
|
|
|
|
Relu = jt.make_module(relu)
|
|
ReLU = Relu
|
|
Leaky_relu = jt.make_module(leaky_relu, 2)
|
|
LeakyReLU = Leaky_relu
|
|
ReLU6 = jt.make_module(relu6)
|
|
Softmax = jt.make_module(softmax, 2)
|
|
GELU = jt.make_module(gelu)
|
|
|
|
from jittor.depthwise_conv import DepthwiseConv
|
|
|
|
class Conv(Module):
|
|
''' Applies a 2D convolution over an input signal composed of several input planes.
|
|
|
|
:param in_channels: Number of channels in the input feature map
|
|
:type in_channels: int
|
|
|
|
:param out_channels: Number of channels in the output feature map
|
|
:type out_channels: int
|
|
|
|
:param kernel_size: Size of the convolving kernel
|
|
:type kernel_size: int or tuple
|
|
|
|
:param stride: Stride of the convolution. Default: 1
|
|
:type stride: int or tuple, optional
|
|
|
|
:param padding: Padding added to all four sides of the input. Default: 0
|
|
:type padding: int or tuple, optional
|
|
|
|
:param dilation: Spacing between kernel elements. Default: 1
|
|
:type dilation: int or tuple, optional
|
|
|
|
:param groups: Number of blocked connections from input channels to output channels. Default: 1
|
|
:type groups: int, optional
|
|
|
|
:param bias: If True, adds a learnable bias to the output. Default: True
|
|
:type bias: bool, optional
|
|
|
|
Example:
|
|
|
|
>>> conv = nn.Conv2d(24, 32, 3)
|
|
>>> conv = nn.Conv2d(24, 32, (3,3))
|
|
>>> conv = nn.Conv2d(24, 32, 3, stride=2, padding=1)
|
|
>>> conv = nn.Conv2d(24, 32, 3, dilation=(3, 1))
|
|
>>> input = jt.randn(4, 24, 100, 100)
|
|
>>> output = conv(input)
|
|
'''
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
|
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
|
|
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
|
|
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
|
self.groups = groups
|
|
self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
|
|
if self.is_depthwise_conv and jt.flags.use_cuda:
|
|
self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
|
|
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
|
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
|
Kh, Kw = self.kernel_size
|
|
|
|
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
|
|
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
|
|
if bias:
|
|
fan=1
|
|
for i in self.weight.shape[1:]:
|
|
fan *= i
|
|
bound = 1 / math.sqrt(fan)
|
|
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
|
else:
|
|
self.bias = None
|
|
|
|
def execute(self, x):
|
|
if self.is_depthwise_conv and jt.flags.use_cuda:
|
|
y = self.depthwise_conv(x, self.weight)
|
|
if self.bias is not None:
|
|
b = self.bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
elif self.groups == 1:
|
|
N,C,H,W = x.shape
|
|
Kh, Kw = self.kernel_size
|
|
assert C==self.in_channels
|
|
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
|
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
|
assert oh>0 and ow>0
|
|
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
|
xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
|
|
'i0', # Nid
|
|
'i2', # Cid
|
|
f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
|
|
f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
|
|
])
|
|
ww = self.weight.broadcast(xx.shape, [0,3,4])
|
|
yy = xx*ww
|
|
y = yy.sum([2,5,6]) # Kc, Kh, Kw
|
|
if self.bias is not None:
|
|
b = self.bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
else:
|
|
N,C,H,W = x.shape
|
|
Kh, Kw = self.kernel_size
|
|
G = self.groups
|
|
CpG = C // G # channels per group
|
|
assert C==self.in_channels
|
|
oc = self.out_channels
|
|
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
|
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
|
assert oh>0 and ow>0
|
|
xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
|
|
'i0', # Nid
|
|
f'i1*{CpG}+i3', # Gid
|
|
f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
|
|
f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
|
|
])
|
|
# w: [oc, CpG, Kh, Kw]
|
|
ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
|
|
f'i1*{oc//G}+i2',
|
|
'i3',
|
|
'i6',
|
|
'i7'
|
|
])
|
|
ww.compile_options = xx.compile_options = {"G":G,"C":C}
|
|
yy = xx*ww
|
|
y = yy.reindex_reduce('add', [N, oc, oh, ow], [
|
|
'i0',
|
|
f'i1*{oc//G}+i2',
|
|
'i4',
|
|
'i5'
|
|
])
|
|
if self.bias is not None:
|
|
b = self.bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
|
|
Conv2d = Conv
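# Usage sketch (illustrative only): the output size computed by Conv/conv2d follows
#   oh = (H + 2*padding - dilation*(Kh-1) - 1) // stride + 1
# e.g. H=100, Kh=3, stride=2, padding=1, dilation=1 gives oh = (100+2-2-1)//2+1 = 50.
#   >>> conv = Conv2d(24, 32, 3, stride=2, padding=1)
#   >>> y = conv(jt.randn(4, 24, 100, 100))    # shape [4, 32, 50, 50]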
|
|
|
|
class Conv1d(Module):
|
|
''' Applies a 1D convolution over an input signal composed of several input planes.
|
|
|
|
:param in_channels: Number of channels in the input feature map
|
|
:type in_channels: int
|
|
|
|
:param out_channels: Number of channels in the output feature map
|
|
:type out_channels: int
|
|
|
|
:param kernel_size: Size of the convolving kernel
|
|
:type kernel_size: int or tuple
|
|
|
|
:param stride: Stride of the convolution. Default: 1
|
|
:type stride: int or tuple, optional
|
|
|
|
    :param padding: Padding added to both sides of the input. Default: 0
|
|
:type padding: int or tuple, optional
|
|
|
|
:param dilation: Spacing between kernel elements. Default: 1
|
|
:type dilation: int or tuple, optional
|
|
|
|
:param groups: Number of blocked connections from input channels to output channels. Default: 1
|
|
:type groups: int, optional
|
|
|
|
:param bias: If True, adds a learnable bias to the output. Default: True
|
|
:type bias: bool, optional
|
|
|
|
Example:
|
|
|
|
>>> conv = nn.Conv1d(24, 32, 3)
|
|
    >>> conv = nn.Conv1d(24, 32, 5)
    >>> conv = nn.Conv1d(24, 32, 3, stride=2, padding=1)
    >>> conv = nn.Conv1d(24, 32, 3, dilation=3)
|
|
>>> input = jt.randn(4, 24, 100)
|
|
>>> output = conv(input)
|
|
'''
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
self.kernel_size = (kernel_size, 1)
|
|
self.stride = (stride, 1)
|
|
self.padding = (padding, 0)
|
|
self.dilation = (dilation, 1)
|
|
self.groups = groups
|
|
self.bias = bias
|
|
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
|
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
|
# using list to escape module dfs
|
|
self._conv = [Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)]
|
|
self.weight = self._conv[0].weight.squeeze(-1)
|
|
self.bias = self._conv[0].bias
|
|
|
|
def execute(self, x):
|
|
N,C,D = x.shape
|
|
assert C==self.in_channels
|
|
self._conv[0].weight = self.weight.unsqueeze(-1)
|
|
x = x.unsqueeze(-1)
|
|
x = self._conv[0](x)
|
|
y = x.squeeze(-1)
|
|
return y
|
|
|
|
class Conv3d(Module):
|
|
''' Applies a 3D convolution over an input signal composed of several input planes.
|
|
|
|
:param in_channels: Number of channels in the input feature map
|
|
:type in_channels: int
|
|
|
|
:param out_channels: Number of channels in the output feature map
|
|
:type out_channels: int
|
|
|
|
:param kernel_size: Size of the convolving kernel
|
|
:type kernel_size: int or tuple
|
|
|
|
:param stride: Stride of the convolution. Default: 1
|
|
:type stride: int or tuple, optional
|
|
|
|
    :param padding: Padding added to all six sides of the input. Default: 0
|
|
:type padding: int or tuple, optional
|
|
|
|
:param dilation: Spacing between kernel elements. Default: 1
|
|
:type dilation: int or tuple, optional
|
|
|
|
:param groups: Number of blocked connections from input channels to output channels. Default: 1
|
|
:type groups: int, optional
|
|
|
|
:param bias: If True, adds a learnable bias to the output. Default: True
|
|
:type bias: bool, optional
|
|
|
|
Example:
|
|
|
|
>>> conv = nn.Conv3d(24, 32, 3)
|
|
    >>> conv = nn.Conv3d(24, 32, (3,3,3))
    >>> conv = nn.Conv3d(24, 32, 3, stride=2, padding=1)
    >>> conv = nn.Conv3d(24, 32, 3, dilation=(3, 1, 1))
|
|
>>> input = jt.randn(4, 24, 50, 50, 50)
|
|
>>> output = conv(input)
|
|
'''
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
|
|
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
|
self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
|
|
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
|
|
self.groups = groups
|
|
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
|
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
|
Kh, Kw, Kd = self.kernel_size
|
|
self.groups = groups
|
|
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
|
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
|
|
|
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw, Kd], dtype="float")
|
|
if bias:
|
|
fan=1
|
|
for i in self.weight.shape[1:]:
|
|
fan *= i
|
|
bound = 1 / math.sqrt(fan)
|
|
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
|
else:
|
|
self.bias = None
|
|
|
|
def execute(self, x):
|
|
return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
|
|
class Conv1d_sp(Linear):
|
|
def __init__(self, inchannels, outchannels, kernel_size=1, bias=True):
|
|
super().__init__(inchannels, outchannels, bias=bias)
|
|
assert kernel_size == 1
|
|
|
|
def execute(self, x):
|
|
x = x.transpose(0, 2, 1)
|
|
x = super().execute(x)
|
|
x = x.transpose(0, 2, 1)
|
|
return x
|
|
|
|
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|
''' Applies a 2D convolution over an input signal composed of several input planes.
|
|
|
|
:param x: the input image
|
|
:type x: jt.Var
|
|
|
|
:param weight: the convolution kernel
|
|
:type weight: jt.Var
|
|
|
|
:param bias: the bias after convolution
|
|
    :type bias: jt.Var, optional
|
|
|
|
:param stride: Stride of the convolution. Default: 1
|
|
:type stride: int or tuple, optional
|
|
|
|
:param padding: Padding added to all four sides of the input. Default: 0
|
|
:type padding: int or tuple, optional
|
|
|
|
:param dilation: Spacing between kernel elements. Default: 1
|
|
:type dilation: int or tuple, optional
|
|
|
|
:param groups: Number of blocked connections from input channels to output channels. Default: 1
|
|
:type groups: int, optional
|
|
|
|
Example:
|
|
|
|
>>> x = jt.randn(4, 24, 100, 100)
|
|
>>> w = jt.randn(32, 24, 3, 3)
|
|
>>> y = nn.conv2d(x, w)
|
|
'''
|
|
padding = _pair(padding)
|
|
stride = _pair(stride)
|
|
dilation = _pair(dilation)
|
|
out_channels = weight.shape[0]
|
|
|
|
if groups == 1:
|
|
N,C,H,W = x.shape
|
|
Kh, Kw = weight.shape[-2:]
|
|
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
|
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
|
with jt.flag_scope(amp_reg = jt.flags.amp_reg | 4):
|
|
xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
|
|
'i0', # Nid
|
|
'i2', # Cid
|
|
f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
|
|
f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
|
|
])
|
|
ww = weight.broadcast(xx.shape, [0,3,4])
|
|
yy = xx*ww
|
|
y = yy.sum([2,5,6]) # Kc, Kh, Kw
|
|
if bias is not None:
|
|
b = bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
else:
|
|
N,C,H,W = x.shape
|
|
Kh, Kw = weight.shape[-2:]
|
|
G = groups
|
|
CpG = C // G # channels per group
|
|
oc = out_channels
|
|
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
|
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
|
xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
|
|
'i0', # Nid
|
|
f'i1*{CpG}+i3', # Gid
|
|
f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
|
|
f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
|
|
])
|
|
xx.compile_options = {"G":G}
|
|
# w: [oc, CpG, Kh, Kw]
|
|
ww = weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
|
|
f'i1*{oc//G}+i2',
|
|
'i3',
|
|
'i6',
|
|
'i7'
|
|
])
|
|
yy = xx*ww
|
|
y = yy.reindex_reduce('add', [N, oc, oh, ow], [
|
|
'i0',
|
|
f'i1*{oc//G}+i2',
|
|
'i4',
|
|
'i5'
|
|
])
|
|
if bias is not None:
|
|
b = bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
conv = conv2d
|
|
|
|
def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|
''' Applies a 3D convolution over an input signal composed of several input planes.
|
|
|
|
:param x: the input volume
|
|
:type x: jt.Var
|
|
|
|
:param weight: the convolution kernel
|
|
:type weight: jt.Var
|
|
|
|
:param bias: the bias after convolution
|
|
    :type bias: jt.Var, optional
|
|
|
|
:param stride: Stride of the convolution. Default: 1
|
|
:type stride: int or tuple, optional
|
|
|
|
    :param padding: Padding added to all six sides of the input. Default: 0
|
|
:type padding: int or tuple, optional
|
|
|
|
:param dilation: Spacing between kernel elements. Default: 1
|
|
:type dilation: int or tuple, optional
|
|
|
|
:param groups: Number of blocked connections from input channels to output channels. Default: 1
|
|
:type groups: int, optional
|
|
|
|
Example:
|
|
|
|
>>> x = jt.randn(4, 24, 50, 50, 50)
|
|
>>> w = jt.randn(32, 24, 3, 3, 3)
|
|
    >>> y = nn.conv3d(x, w)
|
|
'''
|
|
padding = _triple(padding)
|
|
stride = _triple(stride)
|
|
dilation = _triple(dilation)
|
|
out_channels = weight.shape[0]
|
|
|
|
if jt.flags.use_cuda and jt.cudnn:
|
|
y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups)
|
|
elif groups == 1:
|
|
N,C,D,H,W = x.shape
|
|
Kd, Kh, Kw = weight.shape[-3:]
|
|
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
|
|
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
|
|
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
|
|
xx = x.reindex([N,out_channels,C,od,oh,ow,Kd,Kh,Kw], [
|
|
'i0', # Nid
|
|
'i2', # Cid
|
|
f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
|
|
f'i4*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
|
|
f'i5*{stride[2]}-{padding[2]}+i8*{dilation[2]}', # Did+KDid
|
|
])
|
|
ww = weight.broadcast(xx.shape, [0,3,4,5])
|
|
yy = xx*ww
|
|
y = yy.sum([2,6,7,8]) # Kc, Kh, Kw,Kd
|
|
else:
|
|
N,C,D,H,W = x.shape
|
|
Kd, Kh, Kw = weight.shape[-3:]
|
|
G = groups
|
|
CpG = C // G # channels per group
|
|
oc = out_channels
|
|
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
|
|
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
|
|
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
|
|
xx = x.reindex([N,G,oc//G,CpG,od,oh,ow,Kd,Kh,Kw], [
|
|
'i0', # Nid
|
|
f'i1*{CpG}+i3', # Gid
|
|
f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid
|
|
f'i5*{stride[1]}-{padding[1]}+i8*{dilation[1]}', # Wid+KWid
|
|
f'i6*{stride[2]}-{padding[2]}+i9*{dilation[2]}', # Did+KDid
|
|
])
|
|
xx.compile_options = {"G":G}
|
|
# w: [oc, CpG, Kh, Kw, Kd]
|
|
ww = weight.reindex([N, G, oc//G, CpG, oh, ow, od, Kh, Kw, Kd], [
|
|
f'i1*{oc//G}+i2',
|
|
'i3',
|
|
'i7',
|
|
'i8',
|
|
'i9'
|
|
])
|
|
yy = xx*ww
|
|
y = yy.reindex_reduce('add', [N, oc, od, oh, ow], [
|
|
'i0',
|
|
f'i1*{oc//G}+i2',
|
|
'i4',
|
|
'i5',
|
|
'i6'
|
|
])
|
|
|
|
if bias is not None:
|
|
b = bias.broadcast(y.shape, [0,2,3,4])
|
|
y = y + b
|
|
return y
|
|
|
|
class ConvTranspose(Module):
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
|
|
padding=0, output_padding=0, groups=1, bias=True, dilation=1):
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
|
|
# added
|
|
self.dilation = dilation
|
|
self.groups = groups
|
|
|
|
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
|
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
|
|
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
|
# added
|
|
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
|
|
self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
|
|
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1])
|
|
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
|
|
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
|
|
self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
|
|
"output padding must be smaller than max(stride, dilation)"
|
|
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
|
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
|
|
|
self.weight = init.invariant_uniform((in_channels, out_channels//groups) + self.kernel_size, dtype="float")
|
|
if bias:
|
|
fan=1
|
|
for i in self.weight.shape[1:]:
|
|
fan *= i
|
|
bound = 1 / math.sqrt(fan)
|
|
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
|
else:
|
|
self.bias = None
|
|
|
|
def execute(self, x):
|
|
if self.groups == 1:
|
|
N,C,H,W = x.shape
|
|
i,o,h,w = self.weight.shape
|
|
assert C==i
|
|
stride_h, stride_w = self.stride
|
|
padding_h, padding_w = self.padding
|
|
dilation_h, dilation_w = self.dilation
|
|
|
|
h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
|
|
w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
|
|
out_shape = (N, o, h_out, w_out)
|
|
shape = (N, i, o, H, W, h, w)
|
|
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
|
|
ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
|
|
y = (ww*xx).reindex_reduce("add", out_shape, [
|
|
'i0', # N
|
|
'i2', # o
|
|
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
|
|
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
|
|
])
|
|
if self.bias is not None:
|
|
b = self.bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
else:
|
|
N,C,H,W = x.shape
|
|
Kh, Kw = self.kernel_size
|
|
i,o,h,w = self.weight.shape
|
|
oc = self.out_channels
|
|
G = self.groups
|
|
CpG = C // G # channels per group
|
|
assert C==self.in_channels
|
|
stride_h, stride_w = self.stride
|
|
padding_h, padding_w = self.padding
|
|
dilation_h, dilation_w = self.dilation
|
|
|
|
oh = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
|
|
ow = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
|
|
out_shape = (N, oc, oh, ow)
|
|
shape = [N,G,oc//G,CpG,oh,ow,Kh,Kw]
|
|
xx = x.reindex(shape, [
|
|
'i0',
|
|
f'i1*{oc//G}+i2',
|
|
'i4',
|
|
'i5'
|
|
])
|
|
ww = self.weight.reindex(shape, [
|
|
f'i1*{oc//G}+i2',
|
|
'i3',
|
|
'i6',
|
|
'i7'
|
|
])
|
|
ww.compile_options = xx.compile_options = {"G":G,"C":C}
|
|
y = (ww*xx).reindex_reduce("add", out_shape, [
|
|
'i0', # Nid
|
|
f'i1*{CpG}+i3', # Gid
|
|
f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
|
|
f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
|
|
])
|
|
if self.bias is not None:
|
|
b = self.bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
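# Usage sketch (illustrative only): ConvTranspose produces
#   h_out = (H-1)*stride - 2*padding + dilation*(h-1) + output_padding + 1
# e.g. H=16, kernel 4, stride=2, padding=1 gives h_out = 15*2 - 2 + 3 + 0 + 1 = 32.
#   >>> up = ConvTranspose(64, 32, 4, stride=2, padding=1)
#   >>> y = up(jt.randn(1, 64, 16, 16))    # shape [1, 32, 32, 32]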
|
|
|
|
|
|
class ConvTranspose3d(Module):
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
|
|
padding=0, output_padding=0, groups=1, bias=True, dilation=1):
|
|
self.in_channels = in_channels
|
|
self.out_channels = out_channels
|
|
|
|
# added
|
|
self.dilation = dilation
|
|
self.group = groups
|
|
assert groups==1, "Group conv not supported yet."
|
|
|
|
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
|
|
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
|
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
|
|
# added
|
|
self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
|
|
self.real_padding = (
|
|
self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
|
|
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1],
|
|
self.dilation[2] * (self.kernel_size[2] - 1) - self.padding[2])
|
|
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
|
|
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
|
|
self.output_padding[1] < max(self.stride[1], self.dilation[1]) and \
|
|
self.output_padding[2] < max(self.stride[2], self.dilation[2]), \
|
|
"output padding must be smaller than max(stride, dilation)"
|
|
|
|
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
|
|
if bias:
|
|
fan=1
|
|
for i in self.weight.shape[1:]:
|
|
fan *= i
|
|
bound = 1 / math.sqrt(fan)
|
|
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
|
else:
|
|
self.bias = None
|
|
|
|
def execute(self, x):
|
|
return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)
|
|
|
|
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
|
if groups == 1:
|
|
x = input
|
|
N,C,H,W = x.shape
|
|
i,o,h,w = weight.shape
|
|
assert C==i
|
|
stride = stride if isinstance(stride, tuple) else (stride, stride)
|
|
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
|
# added
|
|
padding = padding if isinstance(padding, tuple) else (padding, padding)
|
|
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
|
|
assert output_padding[0] < max(stride[0], dilation[0]) and \
|
|
output_padding[1] < max(stride[1], dilation[1]), \
|
|
"output padding must be smaller than max(stride, dilation)"
|
|
|
|
stride_h, stride_w = stride
|
|
padding_h, padding_w = padding
|
|
dilation_h, dilation_w = dilation
|
|
|
|
h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
|
|
w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
|
|
out_shape = (N, o, h_out, w_out)
|
|
shape = (N, i, o, H, W, h, w)
|
|
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
|
|
ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W
|
|
y = (ww*xx).reindex_reduce("add", out_shape, [
|
|
'i0', # N
|
|
'i2', # o
|
|
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
|
|
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
|
|
])
|
|
if isinstance(bias, jt.Var):
|
|
b = bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
else:
|
|
assert not bias, "Bias should be none or jittor var"
|
|
return y
|
|
else:
|
|
N,C,H,W = input.shape
|
|
i,o,h,w = weight.shape
|
|
G = groups
|
|
oc = o * G
|
|
CpG = C // G # channels per group
|
|
assert C % G == 0
|
|
assert C==i, (C, i)
|
|
stride = stride if isinstance(stride, tuple) else (stride, stride)
|
|
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
|
# added
|
|
padding = padding if isinstance(padding, tuple) else (padding, padding)
|
|
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
|
|
assert output_padding[0] < max(stride[0], dilation[0]) and \
|
|
output_padding[1] < max(stride[1], dilation[1]), \
|
|
"output padding must be smaller than max(stride, dilation)"
|
|
|
|
stride_h, stride_w = stride
|
|
padding_h, padding_w = padding
|
|
dilation_h, dilation_w = dilation
|
|
|
|
oh = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
|
|
ow = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
|
|
out_shape = (N, oc, oh, ow)
|
|
shape = [N,G,oc//G,CpG,oh,ow,h,w]
|
|
xx = input.reindex(shape, [
|
|
'i0',
|
|
f'i1*{oc//G}+i2',
|
|
'i4',
|
|
'i5'
|
|
])
|
|
ww = weight.reindex(shape, [
|
|
f'i1*{oc//G}+i2',
|
|
'i3',
|
|
'i6',
|
|
'i7'
|
|
])
|
|
ww.compile_options = xx.compile_options = {"G":G,"C":C}
|
|
y = (ww*xx).reindex_reduce("add", out_shape, [
|
|
'i0', # Nid
|
|
f'i1*{CpG}+i3', # Gid
|
|
f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
|
|
f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
|
|
])
|
|
if bias is not None:
|
|
b = bias.broadcast(y.shape, [0,2,3])
|
|
y = y + b
|
|
return y
|
|
conv_transpose2d = conv_transpose
|
|
|
|
def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
|
x = input
|
|
N,C,D,H,W = x.shape
|
|
i,o,d,h,w = weight.shape
|
|
assert C==i
|
|
assert groups==1, "Group conv not supported yet."
|
|
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
|
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
|
|
# added
|
|
padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
|
|
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
|
|
assert output_padding[0] < max(stride[0], dilation[0]) and \
|
|
output_padding[1] < max(stride[1], dilation[1]) and \
|
|
output_padding[2] < max(stride[2], dilation[2]), \
|
|
"output padding must be smaller than max(stride, dilation)"
|
|
|
|
stride_d, stride_h, stride_w = stride
|
|
padding_d, padding_h, padding_w = padding
|
|
dilation_d, dilation_h, dilation_w = dilation
|
|
|
|
d_out = (D-1) * stride_d + output_padding[0] - 2*padding_d + 1 + (d-1)*dilation_d
|
|
h_out = (H-1) * stride_h + output_padding[1] - 2*padding_h + 1 + (h-1)*dilation_h
|
|
w_out = (W-1) * stride_w + output_padding[2] - 2*padding_w + 1 + (w-1)*dilation_w
|
|
out_shape = (N, o, d_out, h_out, w_out)
|
|
if jt.flags.use_cuda and jt.cudnn:
|
|
return jt.cudnn.ops.cudnn_conv3d_backward_x(weight, x, *out_shape[2:], *stride, *padding, *dilation, groups)
|
|
shape = (N, i, o, D, H, W, d, h, w)
|
|
xx = x.broadcast(shape, (2, 6, 7, 8)) # i,h,w
|
|
ww = weight.broadcast(shape, (0, 3, 4, 5)) # N,H,W
|
|
y = (ww*xx).reindex_reduce("add", out_shape, [
|
|
'i0', # N
|
|
'i2', # o
|
|
f'i3*{stride_d}-{padding_d}+i6*{dilation_d}', # Did+Kdid
|
|
f'i4*{stride_h}-{padding_h}+i7*{dilation_h}', # Hid+Khid
|
|
f'i5*{stride_w}-{padding_w}+i8*{dilation_w}', # Wid+KWid
|
|
])
|
|
if isinstance(bias, jt.Var):
|
|
b = bias.broadcast(y.shape, [0,2,3,4])
|
|
y = y + b
|
|
else:
|
|
assert not bias, "Bias should be none or jittor var"
|
|
return y
|
|
|
|
|
|
|
|
def pad(x,padding, mode='constant', value=0):
|
|
assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
|
|
assert len(padding)%2==0 and len(padding)//2<=x.ndim
|
|
|
|
padding = list(padding)
|
|
left = [0]*(x.ndim-len(padding)//2)+padding[::2][::-1]
|
|
right = [0]*(x.ndim-len(padding)//2)+padding[1::2][::-1]
|
|
|
|
out_dims = []
|
|
out_shape = []
|
|
for i,n,l,r in zip(range(x.ndim),x.shape,left,right):
|
|
out_shape.append(n+l+r)
|
|
if mode == 'constant':
|
|
out_dims.append(f'i{i}-{l}')
|
|
elif mode == 'replicate':
|
|
out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n+l-1} ? {n-1} : i{i}-{l}")
|
|
elif mode == 'reflect':
|
|
out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n+l-1} ? {2*(n-1)+l}-i{i} : i{i}-{l}")
|
|
elif mode == 'circular':
|
|
out_dims.append(f"i{i}<{l} ? {n-l}+i{i} : i{i} > {n+l-1} ? i{i}-{n+l} : i{i}-{l}")
|
|
|
|
return x.reindex(out_shape,out_dims,overflow_value=value)
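# Usage sketch (illustrative only): padding pairs are given from the last dimension
# backwards, i.e. (left, right, top, bottom) for a 4D input.
#   >>> x = jt.random((1, 3, 8, 8))
#   >>> y = pad(x, (1, 1, 2, 2))                 # shape [1, 3, 12, 10], zero padded
#   >>> y = pad(x, (1, 1), mode='replicate')     # pads only the last dimension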
|
|
|
|
|
|
class ReflectionPad2d(Module):
|
|
def __init__(self, padding):
|
|
self.padding = padding
|
|
if isinstance(self.padding, int):
|
|
self.pl = self.padding
|
|
self.pr = self.padding
|
|
self.pt = self.padding
|
|
self.pb = self.padding
|
|
elif isinstance(self.padding, tuple):
|
|
self.pl, self.pr, self.pt, self.pb = self.padding
|
|
else:
|
|
raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}")
|
|
|
|
def execute(self, x):
|
|
n,c,h,w = x.shape
|
|
assert (self.pl < w and self.pr < w), f"padding_left and padding_right should be smaller than input width"
|
|
assert (self.pt < h and self.pb < h), f"padding_top and padding_bottom should be smaller than input height"
|
|
oh=h+self.pt+self.pb
|
|
ow=w+self.pl+self.pr
|
|
l = self.pl
|
|
r = self.pl + w - 1
|
|
t = self.pt
|
|
b = self.pt + h - 1
|
|
return x.reindex([n,c,oh,ow], ["i0","i1",
|
|
f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}",
|
|
f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}",
|
|
])
|
|
|
|
class ZeroPad2d(Module):
|
|
def __init__(self, padding):
|
|
self.padding = padding
|
|
if isinstance(self.padding, int):
|
|
self.pl = self.padding
|
|
self.pr = self.padding
|
|
self.pt = self.padding
|
|
self.pb = self.padding
|
|
elif isinstance(self.padding, (tuple,list)):
|
|
self.pl, self.pr, self.pt, self.pb = self.padding
|
|
else:
|
|
raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}")
|
|
|
|
def execute(self, x):
|
|
n,c,h,w = x.shape
|
|
return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"])
|
|
|
|
class ConstantPad2d(Module):
|
|
def __init__(self, padding, value):
|
|
self.padding = padding
|
|
if isinstance(self.padding, int):
|
|
self.pl = self.padding
|
|
self.pr = self.padding
|
|
self.pt = self.padding
|
|
self.pb = self.padding
|
|
elif isinstance(self.padding, tuple):
|
|
self.pl, self.pr, self.pt, self.pb = self.padding
|
|
else:
|
|
raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}")
|
|
self.value = value
|
|
|
|
def execute(self, x):
|
|
assert len(x.shape) >= 2
|
|
shape = x.shape
|
|
tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr]
|
|
tar_dims = []
|
|
for i in range(len(shape)-2):
|
|
tar_dims.append(f"i{i}")
|
|
tar_dims.append(f"i{i+1}-{self.pt}")
|
|
tar_dims.append(f"i{i+2}-{self.pl}")
|
|
return x.reindex(tar_shape, tar_dims, overflow_value=self.value)
|
|
|
|
class ReplicationPad2d(Module):
|
|
def __init__(self, padding):
|
|
self.padding = padding
|
|
if isinstance(self.padding, int):
|
|
self.pl = self.padding
|
|
self.pr = self.padding
|
|
self.pt = self.padding
|
|
self.pb = self.padding
|
|
elif isinstance(self.padding, tuple):
|
|
self.pl, self.pr, self.pt, self.pb = self.padding
|
|
else:
|
|
raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}")
|
|
|
|
def execute(self, x):
|
|
n,c,h,w = x.shape
|
|
oh=h+self.pt+self.pb
|
|
ow=w+self.pl+self.pr
|
|
l = self.pl
|
|
r = self.pl + w - 1
|
|
t = self.pt
|
|
b = self.pt + h - 1
|
|
return x.reindex([n,c,oh,ow], ["i0","i1",
|
|
f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}",
|
|
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
|
|
])
|
|
|
|
class Embedding(Module):
|
|
''' A simple lookup table that stores embeddings of a fixed dictionary and size.
|
|
|
|
:param num: size of the dictionary of embeddings
|
|
:type num: int
|
|
|
|
:param dim: the size of each embedding vector
|
|
:type dim: int
|
|
|
|
Example:
|
|
>>> embedding = nn.Embedding(10, 3)
|
|
>>> x = jt.int32([1, 2, 3, 3])
|
|
>>> embedding(x)
|
|
jt.Var([[ 1.1128596 0.19169547 0.706642]
|
|
[ 1.2047412 1.9668795 0.9932192]
|
|
[ 0.14941819 0.57047683 -1.3217674]
|
|
[ 0.14941819 0.57047683 -1.3217674]], dtype=float32)
|
|
'''
|
|
def __init__(self, num, dim):
|
|
self.num = num
|
|
self.dim = dim
|
|
self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
|
|
|
|
def execute(self, x):
|
|
res = self.weight[x.flatten()].reshape(x.shape + [self.dim])
|
|
return res
|
|
|
|
class PixelShuffle(Module):
|
|
def __init__(self, upscale_factor):
|
|
self.upscale_factor = upscale_factor
|
|
|
|
def execute(self, x):
|
|
n,c,h,w = x.shape
|
|
r = self.upscale_factor
|
|
        assert c%(r*r)==0, "input channels must be divisible by the square of upscale_factor in PixelShuffle"
|
|
return x.reindex([n,int(c/r**2),h*r,w*r], [
|
|
"i0",
|
|
f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
|
|
f"i2/{r}",
|
|
f"i3/{r}"
|
|
])
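# Usage sketch (illustrative only): PixelShuffle rearranges a [N, C*r*r, H, W] tensor
# into [N, C, H*r, W*r]; the channel count must be divisible by r*r.
#   >>> ps = PixelShuffle(upscale_factor=2)
#   >>> y = ps(jt.random((1, 12, 8, 8)))    # shape [1, 3, 16, 16]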
|
|
|
|
class Tanh(Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
def execute(self, x) :
|
|
return x.tanh()
|
|
|
|
class Sigmoid(Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
def execute(self, x) :
|
|
return x.sigmoid()
|
|
|
|
def softplus(x,beta=1.0,threshold=20.0):
|
|
return 1 / beta * jt.log(1 + (beta * x).minimum(threshold).exp()) + \
|
|
(x - threshold/beta).maximum(0.0)
|
|
|
|
def hardtanh(x,min_val=-1,max_val=1):
|
|
return jt.clamp(x,min_v=min_val,max_v=max_val)
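# Usage sketch (illustrative only): softplus is a smooth approximation of ReLU that
# switches to a linear function above threshold/beta, and hardtanh clamps values
# into [min_val, max_val].
#   >>> x = jt.float32([-2.0, 0.0, 25.0])
#   >>> softplus(x)     # approx [0.127, 0.693, 25.0]
#   >>> hardtanh(x)     # [-1., 0., 1.]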
|
|
|
|
|
|
class Softplus(Module):
|
|
r'''
|
|
SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
|
|
|
|
Args:
|
|
|
|
[in] beta (float): the beta value for the Softplus formulation. Default: 1.
|
|
|
|
[in] threshold (float): values above this revert to a linear function. Default: 20.
|
|
'''
|
|
def __init__(self, beta=1, threshold=20):
|
|
self.beta = beta
|
|
self.threshold = threshold
|
|
|
|
def execute(self, x):
|
|
return softplus(x, self.beta, self.threshold)
|
|
|
|
class Resize(Module):
|
|
def __init__(self, size, mode="nearest", align_corners=False):
|
|
super().__init__()
|
|
self.size = size
|
|
self.mode = mode
|
|
self.align_corners = align_corners
|
|
def execute(self, x):
|
|
return resize(x, self.size, self.mode, self.align_corners)
|
|
|
|
|
|
def _bicubic(x, a, func):
|
|
# normal ver
|
|
if func == 1:
|
|
return (a+2)*(jt.abs(x)**3)-(a+3)*(x**2)+1
|
|
if func == 2:
|
|
return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a
|
|
return 0
|
|
|
|
|
|
def _interpolate(img, x, y, ids, mode):
|
|
if mode == "nearest":
|
|
return img.reindex([*ids, x.floor_int(), y.floor_int()])
|
|
if mode == "bilinear":
|
|
fx, fy = x.floor_int(), y.floor_int()
|
|
cx, cy = fx + 1, fy + 1
|
|
dx, dy = x - fx, y - fy
|
|
a = img.reindex_var([*ids, fx, fy])
|
|
b = img.reindex_var([*ids, cx, fy])
|
|
c = img.reindex_var([*ids, fx, cy])
|
|
d = img.reindex_var([*ids, cx, cy])
|
|
dnx, dny = 1 - dx, 1 - dy
|
|
ab = dx * b + dnx * a
|
|
cd = dx * d + dnx * c
|
|
o = ab * dny + cd * dy
|
|
return o
|
|
if mode=="bicubic": # ugly ver.
|
|
n,c,h,w = img.shape
|
|
fx, fy = x.floor_int(), y.floor_int()
|
|
dix, diy = x - fx, y - fy
|
|
ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2)
|
|
bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1)
|
|
cx, cy = _bicubic(1-dix,-0.75,1), _bicubic(1-diy,-0.75,1)
|
|
dx, dy = _bicubic(2-dix,-0.75,2), _bicubic(2-diy,-0.75,2)
|
|
afx, afy = jt.maximum(jt.minimum(fx-1,h-1),0), jt.maximum(jt.minimum(fy-1,w-1),0)
|
|
bfx, bfy = jt.maximum(jt.minimum(fx,h-1),0), jt.maximum(jt.minimum(fy,w-1),0)
|
|
cfx, cfy = jt.maximum(jt.minimum(fx+1,h-1),0), jt.maximum(jt.minimum(fy+1,w-1),0)
|
|
dfx, dfy = jt.maximum(jt.minimum(fx+2,h-1),0), jt.maximum(jt.minimum(fy+2,w-1),0)
|
|
a = ax*(img.reindex_var([*ids,afx,afy])*ay+img.reindex_var([*ids,afx,bfy])*by+img.reindex_var([*ids,afx,cfy])*cy+img.reindex_var([*ids,afx,dfy])*dy)
|
|
b = bx*(img.reindex_var([*ids,bfx,afy])*ay+img.reindex_var([*ids,bfx,bfy])*by+img.reindex_var([*ids,bfx,cfy])*cy+img.reindex_var([*ids,bfx,dfy])*dy)
|
|
c = cx*(img.reindex_var([*ids,cfx,afy])*ay+img.reindex_var([*ids,cfx,bfy])*by+img.reindex_var([*ids,cfx,cfy])*cy+img.reindex_var([*ids,cfx,dfy])*dy)
|
|
d = dx*(img.reindex_var([*ids,dfx,afy])*ay+img.reindex_var([*ids,dfx,bfy])*by+img.reindex_var([*ids,dfx,cfy])*cy+img.reindex_var([*ids,dfx,dfy])*dy)
|
|
o = a + b + c + d
|
|
return o
|
|
raise NotImplementedError(f"Unsupported interpolation mode: {mode}")
|
|
|
|
# TODO: tf_mode to another function
|
|
def resize(img, size, mode="nearest", align_corners=False, tf_mode=False):
|
|
n, c, h, w = img.shape
|
|
H, W = size
|
|
nid, cid, hid, wid = jt.index((n, c, H, W))
|
|
if align_corners:
|
|
x = hid * ((h - 1) / max(1, H - 1))
|
|
y = wid * ((w - 1) / max(1, W - 1))
|
|
elif mode == "bicubic":
|
|
x = (hid + 0.5) * (h / H) - 0.5
|
|
y = (wid + 0.5) * (w / W) - 0.5
|
|
elif mode == 'nearest':
|
|
x = hid * (h / H)
|
|
y = wid * (w / W)
|
|
else:
|
|
if (tf_mode):
|
|
x = hid * (h / H)
|
|
if H > h: x = x.clamp(0, h - 1)
|
|
y = wid * (w / W)
|
|
if W > w: y = y.clamp(0, w - 1)
|
|
else:
|
|
x = hid * (h / H) + (h / H * 0.5 - 0.5)
|
|
if H > h: x = x.clamp(0, h - 1)
|
|
y = wid * (w / W) + (w / W * 0.5 - 0.5)
|
|
if W > w: y = y.clamp(0, w - 1)
|
|
return _interpolate(img, x, y, (nid, cid), mode)
|
|
|
|
upsample = resize
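
# Usage sketch (illustration only; helper name hypothetical): resize (and its
# alias upsample) expects a 4-D NCHW input and an explicit (H, W) target size.
def _example_resize():
    x = jt.random((1, 3, 8, 8))
    y = resize(x, size=(16, 16), mode="bilinear", align_corners=False)
    assert y.shape == [1, 3, 16, 16]
    return y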
|
|
|
|
|
|
def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False, tf_mode=False):
|
|
if scale_factor is not None:
|
|
size = [int(X.shape[-2] * scale_factor), int(X.shape[-1] * scale_factor)]
|
|
if isinstance(size, int):
|
|
size = (size, size)
|
|
if scale_factor is not None and scale_factor > 1:
|
|
return upsample(X, size, mode, align_corners, tf_mode)
|
|
else:
|
|
return resize(X, size, mode, align_corners, tf_mode)
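
# Usage sketch (illustration only; helper name hypothetical): interpolate
# accepts either an explicit size or a scale_factor, mirroring the usual
# torch-style API of this module.
def _example_interpolate():
    x = jt.random((1, 3, 10, 10))
    y = interpolate(x, scale_factor=2, mode="bilinear")
    z = interpolate(x, size=(5, 5), mode="nearest")
    assert y.shape == [1, 3, 20, 20]
    assert z.shape == [1, 3, 5, 5]
    return y, z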
|
|
|
|
|
|
def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
|
|
r'''
|
|
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
|
|
|
|
grid specifies the sampling pixel locations normalized by the input spatial dimensions, so most of its values should lie in the range [-1, 1]. For example, x = -1, y = -1 is the top-left pixel of the input, and x = 1, y = 1 is the bottom-right pixel.
|
|
|
|
Args:
|
|
|
|
[in] input (var): the source input var, whose shape is (N, C, Hi, Wi)
|
|
|
|
[in] grid (var): the pixel locations, whose shape is (N, Ho, Wo, 2)
|
|
|
|
[in] mode (string): the interpolate way, default: bilinear.
|
|
|
|
[in] padding_mode (string): the padding way, default: zeros.
|
|
|
|
[out] output (var): the output var, whose shape is (N, C, Ho, Wo)
|
|
|
|
Example:
|
|
|
|
>>> x = jt.array([[[[1,2],[3,4]]]])
|
|
>>> print(x)
|
|
[[[[1 2]
|
|
[3 4]]]]
|
|
|
|
>>> grid = jt.array([[[[0.5, 0.5]]]])
|
|
>>> print(x.shape, grid.shape)
|
|
[1,1,2,2,], [1,1,1,2,]
|
|
|
|
>>> nn.grid_sample(x, grid)
|
|
[[[[3.25]]]]
|
|
'''
|
|
assert padding_mode == 'zeros'
|
|
Ni, Ci, Hi, Wi = input.shape
|
|
No, Ho, Wo, D = grid.shape
|
|
assert D == 2
|
|
assert Ni == No
|
|
assert len(input.shape) == 4 and len(grid.shape) == 4
|
|
|
|
nid, cid, hid, wid = jt.index((Ni, Ci, Ho, Wo))
|
|
x = ((grid[:, :, :, 1].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Hi - 1)
|
|
y = ((grid[:, :, :, 0].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Wi - 1)
|
|
return _interpolate(input, x, y, (nid, cid), mode)
|
|
|
|
|
|
def linspace_from_neg_one(grid,num_steps,align_corners):
|
|
if num_steps <= 1:
|
|
return jt.array([],dtype=grid.dtype)
|
|
# TODO: use jt.index
|
|
ra = np.linspace(-1,1,num_steps)
|
|
if not align_corners:
|
|
ra = ra*(num_steps-1)/num_steps
|
|
return jt.array(ra,dtype=grid.dtype)
|
|
|
|
def make_base_grid_4D(theta,N,C,H,W,align_corners):
|
|
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype)
|
|
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
|
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
|
base_grid[...,-1] = 1
|
|
return base_grid
|
|
|
|
def make_base_grid_5D(theta,N,C,D,H,W,align_corners):
|
|
base_grid = jt.zeros((N, D, H, W, 4), dtype=theta.dtype)
|
|
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
|
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
|
base_grid[...,2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners),-1),-1)
|
|
base_grid[...,-1] = 1
|
|
return base_grid
|
|
|
|
def affine_grid_generator_4D(theta,N,C,H,W,align_corners):
|
|
base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners)
|
|
grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3),theta.transpose(0,2,1))
|
|
return grid.reshape(N, H, W, 2)
|
|
|
|
def affine_grid_generator_5D(theta,N,C,D,H,W,align_corners):
|
|
base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners)
|
|
grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4),theta.transpose(0,2,1))
|
|
return grid.reshape(N, D, H, W, 3)
|
|
|
|
def affine_grid(theta, size, align_corners=False):
|
|
assert str(theta.dtype) in ['float','float32','float64']
|
|
assert min(size)>0
|
|
assert len(size) in [4,5]
|
|
if len(size)== 4:
|
|
assert theta.ndim == 3 and theta.shape[-2] == 2 and theta.shape[-1] == 3
|
|
return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3], align_corners)
|
|
elif len(size)==5:
|
|
assert theta.ndim == 3 and theta.shape[-2] == 3 and theta.shape[-1] == 4
|
|
return affine_grid_generator_5D(theta, size[0], size[1], size[2], size[3], size[4], align_corners)
|
|
|
|
|
|
def grid_sampler_unnormalize(coord,size,align_corners):
|
|
if align_corners:
|
|
#unnormalize coord from [-1, 1] to [0, size - 1]
|
|
return ((coord + 1) / 2) * (size - 1)
|
|
else:
|
|
#unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
|
return ((coord + 1) * size - 1) / 2
|
|
|
|
|
|
def clip_coordinates(x,clip_limit):
|
|
return jt.clamp(x,min_v=0,max_v=clip_limit-1)
|
|
|
|
def reflect_coordinates(x,twice_low,twice_high):
|
|
if twice_low == twice_high:
|
|
return jt.zeros_like(x)
|
|
m = twice_low / 2
|
|
span = (twice_high - twice_low) / 2
|
|
x = (x - m).abs()
|
|
#`fmod` returns same sign as `in`, which is positive after the `fabs` above.
|
|
extra = x.mod(span)
|
|
flips = (x / span).floor_int()
|
|
result1 = extra+m
|
|
result2 = span-extra+m
|
|
con = flips%2==0
|
|
not_con = flips%2!=0
|
|
result1[not_con]=0.0
|
|
result2[con]=0.0
|
|
return result1+result2
|
|
|
|
|
|
def grid_sampler_compute_source_index(coord,size,padding_mode,align_corners):
|
|
coord = grid_sampler_unnormalize(coord, size, align_corners)
|
|
if padding_mode == 'border':
|
|
#clip coordinates to image borders
|
|
coord = clip_coordinates(coord, size)
|
|
elif padding_mode == 'reflection':
|
|
#reflect coordinates by image borders
|
|
if align_corners:
|
|
coord = reflect_coordinates(coord, 0, 2*(size - 1))
|
|
else:
|
|
coord = reflect_coordinates(coord, -1, 2*size - 1)
|
|
#clip coordinates to image borders
|
|
coord = clip_coordinates(coord, size)
|
|
return coord
|
|
|
|
|
|
|
|
def grid_sampler_3d(X,grid,mode,padding_mode,align_corners):
|
|
N = X.shape[0]
|
|
C = X.shape[1]
|
|
inp_D = X.shape[2]
|
|
inp_H = X.shape[3]
|
|
inp_W = X.shape[4]
|
|
|
|
D = grid.shape[1]
|
|
H = grid.shape[2]
|
|
W = grid.shape[3]
|
|
x = grid[:,:,:,:,0]
|
|
y = grid[:,:,:,:,1]
|
|
z = grid[:,:,:,:,2]
|
|
shape = [N,C,D,H,W]
|
|
cid = jt.index(shape, dim=1)
|
|
nid = jt.index(shape, dim=0)
|
|
|
|
x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners)
|
|
y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners)
|
|
z = grid_sampler_compute_source_index(z,inp_D,padding_mode,align_corners)
|
|
xid = x.reindex(shape,['i0','i2','i3','i4'])
|
|
yid = y.reindex(shape,['i0','i2','i3','i4'])
|
|
zid = z.reindex(shape,['i0','i2','i3','i4'])
|
|
|
|
if mode=='nearest':
|
|
return X.reindex([nid,cid,zid.round_int(),yid.round_int(),xid.round_int()])
|
|
elif mode=='bilinear':
|
|
fx,fy,fz = xid.floor_int(),yid.floor_int(),zid.floor_int()
|
|
cx,cy,cz = fx+1,fy+1,fz+1
|
|
dx,dy,dz = xid-fx,yid-fy,zid-fz
|
|
dnx,dny,dnz = cx-xid,cy-yid,cz-zid
|
|
a = X.reindex([nid,cid,fz,fy,fx])
|
|
b = X.reindex([nid,cid,cz,fy,fx])
|
|
c = X.reindex([nid,cid,fz,cy,fx])
|
|
d = X.reindex([nid,cid,fz,fy,cx])
|
|
e = X.reindex([nid,cid,fz,cy,cx])
|
|
f = X.reindex([nid,cid,cz,fy,cx])
|
|
g = X.reindex([nid,cid,cz,cy,fx])
|
|
h = X.reindex([nid,cid,cz,cy,cx])
|
|
o = a*dnx*dny*dnz+b*dnx*dny*dz+c*dnx*dy*dnz+d*dx*dny*dnz+e*dx*dy*dnz+f*dx*dny*dz+g*dnx*dy*dz+h*dx*dy*dz
|
|
return o
|
|
|
|
def grid_sampler_2d(X,grid,mode,padding_mode,align_corners):
|
|
N = X.shape[0]
|
|
C = X.shape[1]
|
|
inp_H = X.shape[2]
|
|
inp_W = X.shape[3]
|
|
|
|
H = grid.shape[1]
|
|
W = grid.shape[2]
|
|
x = grid[:,:,:,0]
|
|
y = grid[:,:,:,1]
|
|
shape = [N,C,H,W]
|
|
cid = jt.index(shape, dim=1)
|
|
nid = jt.index(shape, dim=0)
|
|
|
|
x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners)
|
|
y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners)
|
|
xid = x.reindex(shape,['i0','i2','i3'])
|
|
yid = y.reindex(shape,['i0','i2','i3'])
|
|
|
|
if mode=='nearest':
|
|
return X.reindex([nid,cid,yid.round_int(),xid.round_int()])
|
|
elif mode=='bilinear':
|
|
#xid,yid = (xid+0.00001),(yid+0.00001)
|
|
fx,fy = (xid).floor_int(),(yid).floor_int()
|
|
cx,cy = fx+1,fy+1
|
|
dx,dy = xid-fx,yid-fy
|
|
dnx,dny = cx-xid,cy-yid
|
|
|
|
a = X.reindex([nid,cid,fy,fx],overflow_value=0.0)
|
|
b = X.reindex([nid,cid,cy,fx],overflow_value=0.0)
|
|
c = X.reindex([nid,cid,fy,cx],overflow_value=0.0)
|
|
d = X.reindex([nid,cid,cy,cx],overflow_value=0.0)
|
|
o = a*dnx*dny+b*dnx*dy+c*dx*dny+d*dx*dy
|
|
return o
|
|
|
|
|
|
def grid_sampler(X, grid, mode, padding_mode, align_corners):
|
|
assert X.dtype==grid.dtype
|
|
assert ((X.ndim==4 or X.ndim==5) and X.ndim==grid.ndim)
|
|
assert X.shape[0]==grid.shape[0] and grid.shape[-1]==X.ndim-2
|
|
assert X.numel()>0
|
|
if X.ndim == 4:
|
|
return grid_sampler_2d(X, grid, mode, padding_mode, align_corners)
|
|
else:
|
|
return grid_sampler_3d(X, grid, mode, padding_mode, align_corners)
|
|
|
|
|
|
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
|
|
assert mode in ['bilinear','nearest']
|
|
assert padding_mode in ['zeros','border','reflection']
|
|
return grid_sampler(input, grid, mode, padding_mode, align_corners)
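
# Usage sketch (illustration only; helper name hypothetical): the common
# spatial-transformer pattern combining affine_grid and grid_sample. With an
# identity affine matrix the sampled output reproduces the input (up to
# floating-point error).
def _example_affine_grid_sample():
    x = jt.random((2, 3, 8, 8))
    theta = jt.array([[[1., 0., 0.], [0., 1., 0.]]]).repeat([2, 1, 1])
    grid = affine_grid(theta, (2, 3, 8, 8), align_corners=False)   # [2, 8, 8, 2]
    y = grid_sample(x, grid, mode='bilinear', padding_mode='zeros',
                    align_corners=False)
    assert y.shape == [2, 3, 8, 8]
    return y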
|
|
|
|
|
|
class Upsample(Module):
|
|
def __init__(self, scale_factor=None, mode='nearest'):
|
|
self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
|
|
self.mode = mode
|
|
|
|
def execute(self, x):
|
|
return upsample(x,
|
|
size=(
|
|
int(x.shape[2]*self.scale_factor[0]),
|
|
int(x.shape[3]*self.scale_factor[1])),
|
|
mode=self.mode)
|
|
|
|
class UpsamplingBilinear2d(Upsample):
|
|
def __init__(self, scale_factor=None):
|
|
Upsample.__init__(self, scale_factor, 'bilinear')
|
|
|
|
class UpsamplingNearest2d(Upsample):
|
|
def __init__(self, scale_factor=None):
|
|
Upsample.__init__(self, scale_factor, 'nearest')
|
|
|
|
class Sequential(Module):
|
|
def __init__(self, *args):
|
|
self.layers = collections.OrderedDict()
|
|
for mod in args:
|
|
if isinstance(mod, collections.OrderedDict):
|
|
for k, m in mod.items():
|
|
self.add_module(k, m)
|
|
elif isinstance(mod,list):
|
|
for m in mod:
|
|
self.append(m)
|
|
else:
|
|
self.append(mod)
|
|
def __getitem__(self, idx):
|
|
if idx not in self.layers:
|
|
return list(self.layers.values())[idx]
|
|
|
|
return self.layers[idx]
|
|
def __iter__(self):
|
|
return self.layers.values().__iter__()
|
|
def keys(self):
|
|
return self.layers.keys()
|
|
def values(self):
|
|
return self.layers.values()
|
|
def items(self):
|
|
return self.layers.items()
|
|
def execute(self, x):
|
|
for k, layer in self.layers.items():
|
|
x = layer(x)
|
|
return x
|
|
def dfs(self, parents, k, callback, callback_leave):
|
|
n_children = len(self.layers)
|
|
ret = callback(parents, k, self, n_children)
|
|
if ret == False:
|
|
return
|
|
parents.append(self)
|
|
for k,v in self.layers.items():
|
|
if isinstance(v, Module):
|
|
v.dfs(parents, k, callback, callback_leave)
|
|
parents.pop()
|
|
if callback_leave:
|
|
callback_leave(parents, k, self, n_children)
|
|
def append(self, mod):
|
|
assert callable(mod), f"Module <{type(mod)}> is not callable"
|
|
assert not isinstance(mod, type), f"Expected a Module instance, but got the class {mod}; did you forget to instantiate it?"
|
|
self.layers[len(self.layers)]=mod
|
|
def add_module(self, name, mod):
|
|
assert callable(mod), f"Module <{type(mod)}> is not callable"
|
|
assert not isinstance(mod, type), f"Expected a Module instance, but got the class {mod}; did you forget to instantiate it?"
|
|
self.layers[name]=mod
|
|
|
|
def __len__(self):
|
|
return len(self.layers)
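
# Usage sketch (illustration only; helper name hypothetical): Sequential takes
# positional modules, lists, or an OrderedDict of named modules, and named
# layers can also be appended with add_module. It assumes the Linear and ReLU
# modules defined earlier in this file.
def _example_sequential():
    net = Sequential(
        Linear(10, 20),
        ReLU(),
        Linear(20, 5),
    )
    net.add_module("head", Tanh())
    y = net(jt.random((4, 10)))
    assert y.shape == [4, 5]
    return y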
|
|
|
|
|
|
class ParameterList(Module):
|
|
def __init__(self, *args):
|
|
self.params = collections.OrderedDict()
|
|
for var in args:
|
|
if isinstance(var, (collections.OrderedDict, dict)):
|
|
for k, v in var.items():
|
|
self.add_param(k, v)
|
|
elif isinstance(var, list):
|
|
for v in var:
|
|
self.append(v)
|
|
else:
|
|
self.append(var)
|
|
def __getitem__(self, idx):
|
|
if idx not in self.params:
|
|
return list(self.params.values())[idx]
|
|
|
|
return self.params[idx]
|
|
def __iter__(self):
|
|
return self.params.values().__iter__()
|
|
def keys(self):
|
|
return self.params.keys()
|
|
def values(self):
|
|
return self.params.values()
|
|
def items(self):
|
|
return self.params.items()
|
|
def execute(self, x):
|
|
raise NotImplementedError("ParameterList is not executable")
|
|
def append(self, var):
|
|
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not a jittor Var"
|
|
self.params[len(self.params)] = var
|
|
def add_param(self, name, var):
|
|
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not a jittor Var"
|
|
self.params[name]=var
|
|
def __setitem__(self, name, var):
|
|
self.add_param(name, var)
|
|
|
|
def __len__(self):
|
|
return len(self.params)
|
|
|
|
ParameterDict = ParameterList
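
# Usage sketch (illustration only; helper name hypothetical): ParameterList
# (also exposed as ParameterDict) just keeps an ordered collection of jt.Var
# parameters that can be indexed by position or by name.
def _example_parameter_list():
    params = ParameterList([jt.random((3, 3)), jt.random((3,))])
    params["scale"] = jt.ones((1,))
    assert len(params) == 3
    return list(params)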
|
|
|
|
def Parameter(data, requires_grad=True):
|
|
''' The `Parameter` interface isn't needed in Jittor; this function
does nothing and is only provided for compatibility.

A Jittor Var is treated as a Parameter
when it is a member of a Module. If you don't want a Jittor
Var member to be treated as a Parameter, just give it a name
starting with an underscore `_`.
'''
|
|
LOG.w(Parameter.__doc__)
|
|
data = data.clone()
|
|
data.requires_grad = requires_grad
|
|
return data
|
|
|
|
def backward(v, *args, **kw):
|
|
''' The `backward` variable interface doesn't exist in Jittor.
Please use `optimizer.backward(loss)` or
`optimizer.step(loss)` instead.
|
|
For example, if your code looks like this::
|
|
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
It can be changed to this::
|
|
|
|
optimizer.zero_grad()
|
|
optimizer.backward(loss)
|
|
optimizer.step()
|
|
|
|
Or more concise::
|
|
|
|
optimizer.step(loss)
|
|
|
|
The step function will automatically zero grad and backward.
|
|
'''
|
|
LOG.f(backward.__doc__)
|
|
|
|
jt.Var.backward = backward
|
|
|
|
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
|
|
assert X.ndim == 4
|
|
if not isinstance(kernel_size, tuple):
|
|
kernel_size = (kernel_size, kernel_size)
|
|
if not isinstance(dilation, tuple):
|
|
dilation = (dilation, dilation)
|
|
if not isinstance(padding, tuple):
|
|
padding = (padding, padding)
|
|
if not isinstance(stride, tuple):
|
|
stride = (stride, stride)
|
|
n, c, h, w = X.shape
|
|
shape = X.shape
|
|
area = kernel_size[0] * kernel_size[1]
|
|
block_nums = []
|
|
for i in range(2, 4):
|
|
block_nums.append(
|
|
(shape[i] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[i - 2] + 1)
|
|
if padding[0] != 0 or padding[1] != 0:
|
|
X = X.reindex([n, c, h + padding[0] * 2, w + padding[1] * 2],
|
|
["i0", "i1", f"i2-{padding[0]}", f"i3-{padding[1]}"])
|
|
output = X.reindex([n, c * area, block_nums[0] * block_nums[1]], ["i0", f"i1/{area}",
|
|
f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",
|
|
f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
|
|
return output
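
# Usage sketch (illustration only; helper name hypothetical): unfold extracts
# sliding local blocks, producing a [n, c * kh * kw, L] tensor where L is the
# number of block positions.
def _example_unfold():
    x = jt.random((1, 3, 4, 4))
    cols = unfold(x, kernel_size=2, stride=2)
    # 3 channels * 2 * 2 = 12 rows, (4/2) * (4/2) = 4 block positions
    assert cols.shape == [1, 12, 4]
    return cols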
|
|
|
|
|
|
def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
|
|
assert X.ndim==3
|
|
if not isinstance(kernel_size,tuple):
|
|
kernel_size = (kernel_size,kernel_size)
|
|
if not isinstance(dilation,tuple):
|
|
dilation = (dilation,dilation)
|
|
if not isinstance(padding,tuple):
|
|
padding = (padding,padding)
|
|
if not isinstance(stride,tuple):
|
|
stride = (stride,stride)
|
|
n,cl,num = X.shape
|
|
area = kernel_size[0] * kernel_size[1]
|
|
block_nums = []
|
|
for i in range(2,4):
|
|
block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1)
|
|
output = X.reindex_reduce("add",[n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1]],["i0",f"i1/{area}",f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
|
|
return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]
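
# Usage sketch (illustration only; helper name hypothetical): with
# non-overlapping blocks (stride == kernel_size, no padding) fold inverts
# unfold exactly; overlapping configurations instead sum the contributions
# of every block.
def _example_fold_roundtrip():
    x = jt.random((1, 3, 4, 4))
    cols = unfold(x, kernel_size=2, stride=2)                  # [1, 12, 4]
    y = fold(cols, output_size=(4, 4), kernel_size=2, stride=2)
    assert y.shape == [1, 3, 4, 4]
    return y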
|
|
|
|
ModuleList = Sequential
|
|
|
|
|
|
class LSTMCell(jt.Module):
|
|
''' A long short-term memory (LSTM) cell.
|
|
|
|
:param input_size: The number of expected features in the input
|
|
:type input_size: int
|
|
|
|
:param hidden_size: The number of features in the hidden state
|
|
:type hidden_size: int
|
|
|
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
|
:type bias: bool, optional
|
|
|
|
Example:
|
|
|
|
>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
|
|
>>> input = jt.randn(2, 3, 10) # (time_steps, batch, input_size)
|
|
>>> hx = jt.randn(3, 20) # (batch, hidden_size)
|
|
>>> cx = jt.randn(3, 20)
|
|
>>> output = []
|
|
>>> for i in range(input.shape[0]):
|
|
hx, cx = rnn(input[i], (hx, cx))
|
|
output.append(hx)
|
|
>>> output = jt.stack(output, dim=0)
|
|
'''
|
|
def __init__(self, input_size, hidden_size, bias=True):
|
|
super().__init__()
|
|
|
|
self.hidden_size = hidden_size
|
|
self.bias = bias
|
|
|
|
k = math.sqrt(1 / hidden_size)
|
|
self.weight_ih = init.uniform((4 * hidden_size, input_size), 'float32', -k, k)
|
|
self.weight_hh = init.uniform((4 * hidden_size, hidden_size), 'float32', -k, k)
|
|
|
|
if bias:
|
|
self.bias_ih = init.uniform((4 * hidden_size,), 'float32', -k, k)
|
|
self.bias_hh = init.uniform((4 * hidden_size,), 'float32', -k, k)
|
|
|
|
def execute(self, input, hx = None):
|
|
if hx is None:
|
|
zeros = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
|
|
h, c = zeros, zeros
|
|
else:
|
|
h, c = hx
|
|
|
|
y = matmul_transpose(input, self.weight_ih) + matmul_transpose(h, self.weight_hh)
|
|
|
|
if self.bias:
|
|
y = y + self.bias_ih + self.bias_hh
|
|
|
|
i = y[:, :self.hidden_size].sigmoid()
|
|
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
|
|
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
|
|
o = y[:, 3 * self.hidden_size:].sigmoid()
|
|
|
|
c = f * c + i * g
|
|
h = o * c.tanh()
|
|
|
|
return h, c
|
|
|
|
|
|
class RNNCell(jt.Module):
|
|
''' An Elman RNN cell with tanh or ReLU non-linearity.
|
|
|
|
:param input_size: The number of expected features in the input
|
|
:type input_size: int
|
|
|
|
:param hidden_size: The number of features in the hidden state
|
|
:type hidden_size: int
|
|
|
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
|
:type bias: bool, optional
|
|
|
|
:param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'.
|
|
:type nonlinearity: str, optional
|
|
|
|
Example:
|
|
|
|
>>> rnn = nn.RNNCell(10, 20)
|
|
>>> input = jt.randn((6, 3, 10))
|
|
>>> hx = jt.randn((3, 20))
|
|
>>> output = []
|
|
>>> for i in range(6):
|
|
hx = rnn(input[i], hx)
|
|
output.append(hx)
|
|
'''
|
|
def __init__(self, input_size, hidden_size, bias=True, nonlinearity = "tanh"):
|
|
super().__init__()
|
|
|
|
self.hidden_size = hidden_size
|
|
self.bias = bias
|
|
self.nonlinearity = nonlinearity
|
|
|
|
k = math.sqrt(1 / hidden_size)
|
|
self.weight_ih = init.uniform((hidden_size, input_size), 'float32', -k, k)
|
|
self.weight_hh = init.uniform((hidden_size, hidden_size), 'float32', -k, k)
|
|
|
|
if bias:
|
|
self.bias_ih = init.uniform((hidden_size,), 'float32', -k, k)
|
|
self.bias_hh = init.uniform((hidden_size,), 'float32', -k, k)
|
|
|
|
def execute(self, input, hx = None):
|
|
if hx is None:
|
|
hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
|
|
|
|
y = matmul_transpose(input, self.weight_ih)+matmul_transpose(hx, self.weight_hh)
|
|
|
|
if self.bias:
|
|
y= y + self.bias_ih + self.bias_hh
|
|
|
|
if self.nonlinearity == 'tanh':
|
|
y = y.tanh()
|
|
elif self.nonlinearity == 'relu':
|
|
y = relu(y)
|
|
else:
|
|
raise RuntimeError("Unknown nonlinearity: {}".format(self.nonlinearity))
|
|
|
|
return y
|
|
|
|
|
|
class GRUCell(jt.Module):
|
|
''' A gated recurrent unit (GRU) cell.
|
|
|
|
:param input_size: The number of expected features in the input
|
|
:type input_size: int
|
|
|
|
:param hidden_size: The number of features in the hidden state
|
|
:type hidden_size: int
|
|
|
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
|
:type bias: bool, optional
|
|
|
|
Example:
|
|
|
|
>>> rnn = nn.GRUCell(10, 20)
|
|
>>> input = jt.randn((6, 3, 10))
|
|
>>> hx = jt.randn((3, 20))
|
|
>>> output = []
|
|
>>> for i in range(6):
|
|
hx = rnn(input[i], hx)
|
|
output.append(hx)
|
|
'''
|
|
def __init__(self, input_size, hidden_size, bias=True):
|
|
super().__init__()
|
|
|
|
self.hidden_size = hidden_size
|
|
self.bias = bias
|
|
|
|
k = math.sqrt(1 / hidden_size)
|
|
self.weight_ih = init.uniform((3*hidden_size, input_size), 'float32', -k, k)
|
|
self.weight_hh = init.uniform((3*hidden_size, hidden_size), 'float32', -k, k)
|
|
|
|
if bias:
|
|
self.bias_ih = init.uniform((3*hidden_size,), 'float32', -k, k)
|
|
self.bias_hh = init.uniform((3*hidden_size,), 'float32', -k, k)
|
|
|
|
def execute(self, input, hx = None):
|
|
if hx is None:
|
|
hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
|
|
|
|
gi = matmul_transpose(input, self.weight_ih)
|
|
gh = matmul_transpose(hx, self.weight_hh)
|
|
|
|
if self.bias:
|
|
gi += self.bias_ih
|
|
gh += self.bias_hh
|
|
|
|
i_r, i_i, i_n = gi.chunk(3, 1)
|
|
h_r, h_i, h_n = gh.chunk(3, 1)
|
|
|
|
resetgate = jt.sigmoid(i_r + h_r)
|
|
inputgate = jt.sigmoid(i_i + h_i)
|
|
newgate = jt.tanh(i_n + resetgate * h_n)
|
|
hy = newgate + inputgate * (hx - newgate)
|
|
return hy
|
|
|
|
class RNNBase(Module):
|
|
def __init__(self, mode: str, input_size: int, hidden_size: int,
|
|
num_layers: int = 1, bias: bool = True, batch_first: bool = False,
|
|
dropout: float = 0, bidirectional: bool = False,
|
|
proj_size: int = 0, nonlinearity: str = None) -> None:
|
|
super().__init__()
|
|
|
|
self.mode = mode
|
|
self.input_size = input_size
|
|
self.hidden_size = hidden_size
|
|
self.num_layers = num_layers
|
|
self.bias = bias
|
|
self.batch_first = batch_first
|
|
self.dropout = dropout
|
|
self.bidirectional = bidirectional
|
|
self.proj_size = proj_size
|
|
self.nonlinearity = nonlinearity
|
|
|
|
if mode == 'LSTM':
|
|
gate_size = 4 * hidden_size
|
|
elif mode == 'GRU':
|
|
gate_size = 3 * hidden_size
|
|
elif mode == 'RNN':
|
|
gate_size = hidden_size
|
|
else:
|
|
raise ValueError("Unrecognized RNN mode: " + mode)
|
|
|
|
num_directions = 1 + bidirectional
|
|
k = math.sqrt(1 / hidden_size)
|
|
|
|
def build_unit(name, in_channels, out_channels=None):
|
|
if out_channels is not None:
|
|
shape = (in_channels, out_channels)
|
|
else:
|
|
shape = (in_channels,)
|
|
setattr(self, name, init.uniform(shape, 'float32', -k, k))
|
|
if self.bidirectional:
|
|
setattr(self, name + '_reverse', init.uniform(shape, 'float32', -k, k))
|
|
|
|
for layer in range(num_layers):
|
|
if layer == 0:
|
|
build_unit(f'weight_ih_l{layer}', gate_size, input_size)
|
|
else:
|
|
if proj_size > 0:
|
|
build_unit(f'weight_ih_l{layer}', gate_size, num_directions * proj_size)
|
|
else:
|
|
build_unit(f'weight_ih_l{layer}', gate_size, num_directions * hidden_size)
|
|
|
|
if proj_size > 0:
|
|
build_unit(f'weight_hh_l{layer}', gate_size, proj_size)
|
|
build_unit(f'weight_hr_l{layer}', proj_size, hidden_size)
|
|
else:
|
|
build_unit(f'weight_hh_l{layer}', gate_size, hidden_size)
|
|
|
|
if bias:
|
|
build_unit(f'bias_ih_l{layer}', gate_size)
|
|
build_unit(f'bias_hh_l{layer}', gate_size)
|
|
|
|
def _cudnn_flatten_weights(self, cudnn_mode):
|
|
def copy_to_flatten_weight(param_name, offset_idx, num_gates):
|
|
def copy_to(param_name, offset_idx, idx):
|
|
cur_offset = self._cudnn_weight_offset[offset_idx]
|
|
param = getattr(self, param_name)
|
|
param = param[self.hidden_size * idx: self.hidden_size * (idx + 1)]
|
|
ft_weight[cur_offset:cur_offset + param.numel()] = param.flatten()
|
|
|
|
if self.bias:
|
|
for idx in range(num_gates):
|
|
copy_to('weight' + param_name, offset_idx + idx * 2, idx)
|
|
copy_to('bias' + param_name, offset_idx + idx * 2 + 1, idx)
|
|
return num_gates * 2
|
|
else:
|
|
for idx in range(num_gates):
|
|
copy_to('weight' + param_name, offset_idx + idx, idx)
|
|
return num_gates
|
|
|
|
if jt.flags.use_cuda and jt.cudnn:
|
|
if getattr(self, '_cudnn_weight_size', None) is None:
|
|
offset_array = jt.cudnn.cudnn_rnn_weight_offset(
|
|
cudnn_mode,
|
|
self.input_size,
|
|
self.hidden_size,
|
|
self.num_layers,
|
|
self.proj_size,
|
|
self.bias,
|
|
self.bidirectional
|
|
)
|
|
self._cudnn_weight_size = offset_array[0]
|
|
self._cudnn_weight_offset = offset_array[1:]
|
|
|
|
num_gates = {
|
|
"RNN": 1, "LSTM": 4, "GRU": 3
|
|
}[self.mode]
|
|
ft_weight = jt.zeros(self._cudnn_weight_size, dtype=jt.float32)
|
|
|
|
cnt = 0
|
|
for layer in range(self.num_layers):
|
|
suffix = ''
|
|
cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates)
|
|
cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates)
|
|
if self.bidirectional:
|
|
suffix = '_reverse'
|
|
cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates)
|
|
cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates)
|
|
return ft_weight
|
|
else:
|
|
raise RuntimeError("cuDNN not found")
|
|
|
|
@abstractmethod
|
|
def call_rnn_cell(self, input, hidden, suffix):
|
|
pass
|
|
|
|
def call_rnn_sequence(self, input, hidden, suffix):
|
|
if 'reverse' in suffix:
|
|
input = input[::-1]
|
|
|
|
output = []
|
|
for s in range(input.shape[0]):
|
|
out, hidden = self.call_rnn_cell(input[s], hidden, suffix)
|
|
output.append(out)
|
|
|
|
if 'reverse' in suffix:
|
|
output = output[::-1]
|
|
output = jt.stack(output, dim=0)
|
|
|
|
return output, hidden
|
|
|
|
def _execute_cudnn_rnn(self, input, hx):
|
|
cudnn_mode = {
|
|
('RNN', 'tanh'): 'tanh',
|
|
('RNN', 'relu'): 'relu',
|
|
('LSTM', None): 'lstm',
|
|
('GRU', None): 'gru'
|
|
}[(self.mode, self.nonlinearity)]
|
|
ft_weight = self._cudnn_flatten_weights(cudnn_mode)
|
|
|
|
if self.mode == 'LSTM':
|
|
ret = jt.cudnn.ops.cudnn_rnn(input, hx[0], hx[1], ft_weight,
|
|
cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0,
|
|
self.dropout, self.bias, self.bidirectional, self.is_training()
|
|
)
|
|
return ret[0], (ret[1], ret[2])
|
|
else:
|
|
ret = jt.cudnn.ops.cudnn_rnn(input, hx, ft_weight,
|
|
cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0,
|
|
self.dropout, self.bias, self.bidirectional, self.is_training()
|
|
)
|
|
return ret[0], ret[1]
|
|
|
|
def execute(self, input, hx=None):
|
|
if self.batch_first:
|
|
input = input.permute(1, 0, 2)
|
|
|
|
num_directions = 2 if self.bidirectional else 1
|
|
|
|
if hx is None:
|
|
if self.mode in ['RNN', 'GRU']:
|
|
hx = jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype)
|
|
elif self.mode == 'LSTM':
|
|
hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype),
|
|
jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype))
|
|
|
|
if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0:
|
|
return self._execute_cudnn_rnn(input, hx)
|
|
else:
|
|
hidden_n = []
|
|
|
|
for l in range(self.num_layers):
|
|
output = []
|
|
|
|
if isinstance(hx, tuple):
|
|
hidden = [h[l * num_directions] for h in hx]
|
|
else:
|
|
hidden = hx[l * num_directions]
|
|
|
|
output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}')
|
|
hidden_n.append(_hidden)
|
|
|
|
if self.bidirectional:
|
|
if isinstance(hx, tuple):
|
|
hidden = [h[l * num_directions + 1] for h in hx]
|
|
else:
|
|
hidden = hx[l * num_directions + 1]
|
|
|
|
output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse')
|
|
output = jt.concat([output, output_b], dim=-1)
|
|
hidden_n.append(_hidden)
|
|
|
|
if self.dropout > 0:
|
|
input = dropout(output, p=self.dropout)
|
|
else:
|
|
input = output
|
|
|
|
if isinstance(hx, tuple):
|
|
hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n))
|
|
else:
|
|
hidden_n = jt.stack(hidden_n, dim=0)
|
|
|
|
return output, hidden_n
|
|
|
|
|
|
class RNN(RNNBase):
|
|
''' Applies a multi-layer Elman RNN with tanh or ReLU non-linearity to an input sequence.
|
|
|
|
:param input_size: The number of expected features in the input.
|
|
:type input_size: int
|
|
|
|
:param hidden_size: The number of features in the hidden state.
|
|
:type hidden_size: int
|
|
|
|
:param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
|
|
:type num_layers: int, optional
|
|
|
|
:param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
|
|
:type nonlinearity: str, optional
|
|
|
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
|
:type bias: bool, optional
|
|
|
|
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
|
|
:type batch_first: bool, optional
|
|
|
|
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
|
|
:type dropout: float, optional
|
|
|
|
:param bidirectional: If True, becomes a bidirectional RNN. Default: False
|
|
:type bidirectional: bool, optional
|
|
|
|
Example:
|
|
>>> rnn = nn.RNN(10, 20, 2)
|
|
>>> input = jt.randn(5, 3, 10)
|
|
>>> h0 = jt.randn(2, 3, 20)
|
|
>>> output, hn = rnn(input, h0)
|
|
'''
|
|
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
|
nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False,
|
|
dropout: float = 0, bidirectional: bool = False) -> None:
|
|
super().__init__('RNN', input_size, hidden_size, num_layers=num_layers,
|
|
bias=bias, batch_first=batch_first, dropout=dropout,
|
|
bidirectional=bidirectional)
|
|
|
|
if not nonlinearity in ['tanh', 'relu']:
|
|
raise ValueError('Unrecognized nonlinearity: ' + nonlinearity)
|
|
self.nonlinearity = nonlinearity
|
|
|
|
def call_rnn_cell(self, input, hidden, suffix):
|
|
y = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}'))
|
|
y = y + matmul_transpose(hidden, getattr(self, f'weight_hh_{suffix}'))
|
|
|
|
if self.bias:
|
|
y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
|
|
|
|
if self.nonlinearity == 'tanh':
|
|
h = jt.tanh(y)
|
|
else:
|
|
h = jt.nn.relu(y)
|
|
|
|
return h, h
|
|
|
|
|
|
class LSTM(RNNBase):
|
|
''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
|
|
|
|
:param input_size: The number of expected features in the input.
|
|
:type input_size: int
|
|
|
|
:param hidden_size: The number of features in the hidden state.
|
|
:type hidden_size: int
|
|
|
|
:param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1
|
|
:type num_layers: int, optional
|
|
|
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
|
:type bias: bool, optional
|
|
|
|
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
|
|
:type batch_first: bool, optional
|
|
|
|
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
|
|
:type dropout: float, optional
|
|
|
|
:param bidirectional: If True, becomes a bidirectional LSTM. Default: False
|
|
:type bidirectional: bool, optional
|
|
|
|
:param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0
|
|
:type proj_size: int, optional
|
|
|
|
Example:
|
|
>>> rnn = nn.LSTM(10, 20, 2)
|
|
>>> input = jt.randn(5, 3, 10)
|
|
>>> h0 = jt.randn(2, 3, 20)
|
|
>>> c0 = jt.randn(2, 3, 20)
|
|
>>> output, (hn, cn) = rnn(input, (h0, c0))
|
|
'''
|
|
|
|
def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
|
|
batch_first=False, dropout=0, bidirectional=False, proj_size=0):
|
|
super().__init__('LSTM', input_size, hidden_size, num_layers=num_layers,
|
|
bias=bias, batch_first=batch_first, dropout=dropout,
|
|
bidirectional=bidirectional, proj_size=proj_size)
|
|
|
|
def call_rnn_cell(self, input, hidden, suffix):
|
|
h, c = hidden
|
|
y = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}'))
|
|
y = y + matmul_transpose(h, getattr(self, f'weight_hh_{suffix}'))
|
|
|
|
if self.bias:
|
|
y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
|
|
|
|
i = y[:, :self.hidden_size].sigmoid()
|
|
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
|
|
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
|
|
o = y[:, 3 * self.hidden_size:].sigmoid()
|
|
c = f * c + i * g
|
|
h = o * c.tanh()
|
|
|
|
if self.proj_size > 0:
|
|
h = matmul_transpose(h, getattr(self, f'weight_hr_{suffix}'))
|
|
|
|
return h, (h, c)
|
|
|
|
|
|
class GRU(RNNBase):
|
|
''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
|
|
|
|
:param input_size: The number of expected features in the input.
|
|
:type input_size: int
|
|
|
|
:param hidden_size: The number of features in the hidden state.
|
|
:type hidden_size: int
|
|
|
|
:param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1
|
|
:type num_layers: int, optional
|
|
|
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
|
:type bias: bool, optional
|
|
|
|
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
|
|
:type batch_first: bool, optional
|
|
|
|
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0
|
|
:type dropout: float, optional
|
|
|
|
:param bidirectional: If True, becomes a bidirectional GRU. Default: False
|
|
:type bidirectional: bool, optional
|
|
|
|
Example:
|
|
>>> rnn = nn.GRU(10, 20, 2)
|
|
>>> input = jt.randn(5, 3, 10)
|
|
>>> h0 = jt.randn(2, 3, 20)
|
|
>>> output, hn = rnn(input, h0)
|
|
'''
|
|
|
|
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
|
bias: bool = True, batch_first: bool = False, dropout: float = 0,
|
|
bidirectional: bool = False) -> None:
|
|
super().__init__('GRU', input_size, hidden_size, num_layers=num_layers,
|
|
bias=bias, batch_first=batch_first, dropout=dropout,
|
|
bidirectional=bidirectional)
|
|
|
|
def call_rnn_cell(self, input, hidden, suffix):
|
|
ih = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}'))
|
|
hh = matmul_transpose(hidden, getattr(self, f'weight_hh_{suffix}'))
|
|
|
|
if self.bias:
|
|
ih = ih + getattr(self, f'bias_ih_{suffix}')
|
|
hh = hh + getattr(self, f'bias_hh_{suffix}')
|
|
|
|
hs = self.hidden_size
|
|
r = (ih[:, :hs] + hh[:, :hs]).sigmoid()
|
|
z = (ih[:, hs: 2 * hs] + hh[:, hs: 2 * hs]).sigmoid()
|
|
n = (ih[:, 2 * hs:] + r * hh[:, 2 * hs:]).tanh()
|
|
h = (1 - z) * n + z * hidden
|
|
|
|
return h, h
|
|
|
|
def bilinear(in1, in2, weight, bias):
|
|
w = weight.transpose((1,0,2))
|
|
w = w.reshape((w.shape[0], -1))
|
|
x = jt.matmul(in1, w)
|
|
x = x.reshape(x.shape[:-1]+[weight.shape[0], weight.shape[2]])
|
|
y = in2.broadcast(x, (-2,))
|
|
z = (x*y).sum(-1)
|
|
if bias is not None:
|
|
z += bias
|
|
return z
|
|
|
|
|
|
class Bilinear(Module):
|
|
''' Bilinear transformation $out = in1^T W in2 + bias$. Example::
|
|
|
|
m = nn.Bilinear(20, 30, 40)
|
|
input1 = jt.randn(128, 20)
|
|
input2 = jt.randn(128, 30)
|
|
output = m(input1, input2)
|
|
print(output.shape)
|
|
# [128, 40]
|
|
|
|
'''
|
|
def __init__(self, in1_features, in2_features, out_features, bias=True, dtype="float32"):
|
|
bound = 1 / math.sqrt(in1_features)
|
|
self.weight = jt.init.uniform([out_features, in1_features, in2_features], dtype, -bound, bound)
|
|
self.bias = bias
|
|
if bias:
|
|
self.bias = jt.init.uniform([out_features], dtype, -bound, bound)
|
|
|
|
def execute(self, in1, in2):
|
|
return bilinear(in1, in2, self.weight, self.bias)
|