mirror of https://github.com/Jittor/Jittor
repeat chunk stack flip Softplus GN grid_sample
This commit is contained in:
parent
f434f64cd2
commit
90a1422b3c

@@ -761,3 +761,4 @@ from . import nn
from .nn import matmul
from . import contrib
from .contrib import concat
from .misc import *
@@ -0,0 +1,144 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Dun Liang <randonlang@gmail.com>.
#     Wenyang Zhou <576825820@qq.com>
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
from collections.abc import Sequence

def repeat(x, *shape):
    r'''
    Repeats this var along the specified dimensions.

    Args:

        [in] x (var): jittor var.

        [in] shape (tuple): int or tuple. The number of times to repeat this var along each dimension.

    Example:

        >>> x = jt.array([1, 2, 3])
        >>> x.repeat(4, 2)
        [[ 1, 2, 3, 1, 2, 3],
         [ 1, 2, 3, 1, 2, 3],
         [ 1, 2, 3, 1, 2, 3],
         [ 1, 2, 3, 1, 2, 3]]

        >>> x.repeat(4, 2, 1).size()
        [4, 2, 3,]
    '''
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = shape[0]
    len_x_shape = len(x.shape)
    len_shape = len(shape)
    x_shape = x.shape
    rep_shape = shape
    if len_x_shape < len_shape:
        x_shape = (len_shape - len_x_shape) * [1] + x.shape
        x = x.broadcast(x_shape)
    elif len_x_shape > len_shape:
        rep_shape = (len_x_shape - len_shape) * [1] + list(shape)
    tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist()
    dims = []
    for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}")
    return x.reindex(tar_shape, dims)
jt.Var.repeat = repeat
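
# Illustrative usage sketch, assuming the reindex-based repeat above: the repeat
# counts may also be passed as a single tuple or list, which the Sequence check unpacks.
x = jt.array([1, 2, 3])
print(x.repeat((4, 2)).shape)     # [4,6,]  same as x.repeat(4, 2)
print(x.repeat([2, 1, 1]).shape)  # [2,1,3,]  missing leading dims are treated as 1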

def chunk(x, chunks, dim=0):
    r'''
    Splits a var into a specific number of chunks. Each chunk is a view of the input var.

    The last chunk will be smaller if the var's size along the given dimension dim is not divisible by chunks.

    Args:

        [in] x (var): the var to split.

        [in] chunks (int): number of chunks to return.

        [in] dim (int): dimension along which to split the var.

    Example:

        >>> x = jt.random((10,3,3))
        >>> res = jt.chunk(x, 2, 0)
        >>> print(res[0].shape, res[1].shape)
        [5,3,3,] [5,3,3,]
    '''
    l = x.shape[dim]
    res = []
    if l <= chunks:
        for i in range(l):
            res.append(x[(slice(None,),)*dim+([i,],)])
    else:
        nums = (l-1) // chunks + 1
        for i in range(chunks-1):
            res.append(x[(slice(None,),)*dim+(slice(i*nums,(i+1)*nums),)])
        if (i+1)*nums < l:
            res.append(x[(slice(None,),)*dim+(slice((i+1)*nums,None),)])
    return res
jt.Var.chunk = chunk
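
# Illustrative sketch: when the dimension size is not divisible by chunks,
# the last chunk returned above is smaller.
x = jt.random((10, 3, 3))
res = jt.chunk(x, 3, 0)
print([r.shape for r in res])     # [[4,3,3,], [4,3,3,], [2,3,3,]]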

def stack(x, dim=0):
    r'''
    Concatenates a sequence of vars along a new dimension.

    All vars need to be of the same size.

    Args:

        [in] x (sequence of vars): sequence of vars to concatenate.

        [in] dim (int): dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).

    Example:

        >>> a1 = jt.array([[1,2,3]])
        >>> a2 = jt.array([[4,5,6]])
        >>> jt.stack([a1, a2], 0)
        [[[1 2 3]]
         [[4 5 6]]]
    '''
    assert isinstance(x, list)
    assert len(x) >= 2
    res = [x_.unsqueeze(dim) for x_ in x]
    return jt.contrib.concat(res, dim=dim)
jt.Var.stack = stack
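
# Illustrative sketch: the new dimension is inserted at `dim`, so stacking two
# vars of shape [1,3] along dim=1 yields shape [1,2,3] (the asserts above
# require a list of at least two vars).
a1 = jt.array([[1, 2, 3]])
a2 = jt.array([[4, 5, 6]])
print(jt.stack([a1, a2], 1).shape)  # [1,2,3,]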

def flip(x, dim=0):
    r'''
    Reverses the order of an n-D var along the given dimension.

    Args:

        [in] x (var): the input var.

        [in] dim (int): the axis to flip on.

    Example:

        >>> x = jt.array([[1,2,3,4]])
        >>> x.flip(1)
        [[4 3 2 1]]
    '''
    assert isinstance(dim, int)
    tar_dims = []
    for i in range(len(x.shape)):
        if i == dim:
            tar_dims.append(f"{x.shape[dim]-1}-i{i}")
        else:
            tar_dims.append(f"i{i}")
    return x.reindex(x.shape, tar_dims)
jt.Var.flip = flip
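
# Illustrative sketch: flipping dim 0 reverses the rows, via the
# "{size-1}-i{dim}" reindex expression above.
x = jt.array([[1, 2], [3, 4]])
print(x.flip(0))                  # [[3 4]
                                  #  [1 2]]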

@@ -290,6 +290,36 @@ class InstanceNorm2d(Module):
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

class GroupNorm(Module):
    def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True, sync=True):
        assert affine == None
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.is_train = is_train
        self.sync = sync
        self.weight = init.constant((num_channels,), "float32", 1.0)
        self.bias = init.constant((num_channels,), "float32", 0.0)

    def execute(self, x):
        N,C,H,W = x.shape
        assert C == self.num_channels
        assert C % self.num_groups == 0
        x_ = x.reindex([N, int(C/self.num_groups), self.num_groups, H, W], [
            "i0", f"i2*{C/self.num_groups}+i1", "i3", "i4"
        ])
        xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
        x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
        if self.sync and jt.in_mpi:
            xmean = xmean.mpi_all_reduce("mean")
            x2mean = x2mean.mpi_all_reduce("mean")

        xvar = jt.maximum(x2mean-xmean*xmean, 0)
        norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b
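
# Illustrative usage sketch, assuming this class is exported as jt.nn.GroupNorm
# (as the test below does): the C channels are split into num_groups groups and
# each group is normalized over its (C/num_groups, H, W) statistics.
gn = jt.nn.GroupNorm(num_groups=2, num_channels=10)
x = jt.random((16, 10, 224, 224))
y = gn(x)              # calling the Module runs execute(x)
print(y.shape)         # [16,10,224,224,]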

Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 2)

@@ -558,6 +588,23 @@ class Sigmoid(Module):
    def execute(self, x) :
        return x.sigmoid()

class Softplus(Module):
    r'''
    Softplus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.

    Args:

        [in] beta (float): the beta value for the Softplus formulation. Default: 1.

        [in] threshold (float): values above this revert to a linear function. Default: 20.
    '''
    def __init__(self, beta=1, threshold=20):
        self.beta = beta
        self.threshold = threshold

    def execute(self, x):
        return 1 / self.beta * jt.log(1 + (self.beta * x).exp())
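
# Illustrative sketch: execute computes elementwise (1/beta) * log(1 + exp(beta*x)).
# Note the threshold argument is stored but not applied here, so large inputs are
# not switched to a linear branch by this implementation.
m = Softplus(beta=2)
x = jt.array([-1.0, 0.0, 1.0])
print(m(x))            # approximately [0.0635 0.3466 1.0634]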

class Resize(Module):
    def __init__(self, size, mode="nearest", align_corners=False):
        super().__init__()

@@ -611,6 +658,52 @@ def upsample(img, size, mode="nearest", align_corners=False):
    y = wid * (w / W)
    return _interpolate(img, x, y, (nid,cid), mode)

def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    r'''
    Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.

    grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 is the left-top pixel of input, and values x = 1, y = 1 is the right-bottom pixel of input.

    Args:

        [in] input (var): the source input var, whose shape is (N, C, Hi, Wi)

        [in] grid (var): the pixel locations, whose shape is (N, Ho, Wo, 2)

        [in] mode (string): the interpolation mode, default: bilinear.

        [in] padding_mode (string): the padding mode, default: zeros.

        [out] output (var): the output var, whose shape is (N, C, Ho, Wo)

    Example:

        >>> x = jt.array([[[[1,2],[3,4]]]])
        >>> print(x)
        [[[[1 2]
           [3 4]]]]
        >>> grid = jt.array([[[[0.5, 0.5]]]])
        >>> print(x.shape, grid.shape)
        [1,1,2,2,] [1,1,1,2,]
        >>> nn.grid_sample(x, grid)
        [[[[3.25]]]]
    '''
    assert padding_mode == 'zeros'
    Ni, Ci, Hi, Wi = input.shape
    No, Ho, Wo, D = grid.shape
    assert D == 2
    assert Ni == No
    assert len(input.shape) == 4 and len(grid.shape) == 4

    nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo))
    x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1)
    y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
    return _interpolate(input, x, y, (nid,cid), mode)
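
# Worked check of the docstring example above, assuming the bilinear path of
# _interpolate: the grid value (0.5, 0.5) is unnormalized to
#   row = ((0.5 + 1) / 2) * (Hi - 1) = 0.75
#   col = ((0.5 + 1) / 2) * (Wi - 1) = 0.75
# and the bilinear blend of [[1, 2], [3, 4]] at (0.75, 0.75) is
#   0.25*0.25*1 + 0.25*0.75*2 + 0.75*0.25*3 + 0.75*0.75*4 = 3.25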

class Upsample(Module):
    def __init__(self, scale_factor=None, mode='nearest'):
        self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)

@@ -23,23 +23,21 @@ except:
    tnn = None
    skip_this_test = True

-def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, threshold=1e-5):
+def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, has_running=True, threshold=1e-5):
    jittor_arr = jt.array(arr)
    pytorch_arr = torch.Tensor(arr)
-    if is_train:
-        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
-    else:
-        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
+    if has_running:
+        if is_train:
+            assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        else:
+            assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
-    if is_train:
-        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
-    else:
-        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
+    if has_running:
+        if is_train:
+            assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        else:
+            assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)

def check_equal_without_istrain(arr, j_layer, p_layer, threshold=1e-5):

@@ -100,5 +98,20 @@ class TestBatchNorm(unittest.TestCase):
        model.eval()
        check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=False), model, False)

        # ***************************************************************
        # Test GroupNorm Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)

        class Model(tnn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.layer = tnn.GroupNorm(2, 10)
            def forward(self, x):
                return self.layer(x)
        model = Model()
        model.eval()
        check_equal_with_istrain(arr, jnn.GroupNorm(2, 10, is_train=False), model, False, False)

if __name__ == "__main__":
    unittest.main()

@@ -0,0 +1,58 @@

# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Wenyang Zhou <576825820@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

def check_equal(res1, res2):
    assert np.allclose(res1.detach().numpy(), res2.numpy())

@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
    def test_repeat(self):
        arr = np.random.randn(16,3,224,224)
        check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))
        check_equal(torch.Tensor(arr).repeat(4,2,3,4), jt.array(arr).repeat(4,2,3,4))
        print('pass repeat test ...')

    def test_chunk(self):
        arr = np.random.randn(16,3,224,224)
        check_equal(torch.Tensor(arr).chunk(2,0)[0], jt.array(arr).chunk(2,0)[0])
        check_equal(torch.Tensor(arr).chunk(2,0)[1], jt.array(arr).chunk(2,0)[1])
        print('pass chunk test ...')

    def test_stack(self):
        arr1 = np.random.randn(16,3,224,224)
        arr2 = np.random.randn(16,3,224,224)
        check_equal(torch.stack([torch.Tensor(arr1), torch.Tensor(arr2)], 0), jt.stack([jt.array(arr1), jt.array(arr2)], 0))
        print('pass stack test ...')

    def test_flip(self):
        arr = np.random.randn(16,3,224,224)
        check_equal(torch.Tensor(arr).flip(0), jt.array(arr).flip(0))
        check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
        check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
        check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
        print('pass flip test ...')

if __name__ == "__main__":
    unittest.main()

@@ -61,6 +61,14 @@ class TestRelu(unittest.TestCase):
        check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
        check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))

        # ***************************************************************
        # Test Softplus Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.Softplus(), tnn.Softplus())
        check_equal(arr, jnn.Softplus(2), tnn.Softplus(2))
        check_equal(arr, jnn.Softplus(2, 99.9), tnn.Softplus(2, 99.9))

if __name__ == "__main__":
    unittest.main()