repeat chunk stack flip Softplus GN grid_sample

zhouwy19 2020-07-30 23:07:35 +08:00
parent f434f64cd2
commit 90a1422b3c
6 changed files with 330 additions and 13 deletions


@@ -761,3 +761,4 @@ from . import nn
from .nn import matmul
from . import contrib
from .contrib import concat
from .misc import *

python/jittor/misc.py (new file, 144 lines)

@@ -0,0 +1,144 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# Wenyang Zhou <576825820@qq.com>
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
from collections.abc import Sequence
def repeat(x, *shape):
r'''
Repeats this var along the specified dimensions.
Args:
[in] x (var): jittor var.
[in] shape (tuple): int or tuple. The number of times to repeat this var along each dimension.
Example:
>>> x = jt.array([1, 2, 3])
>>> x.repeat(4, 2)
[[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]]
>>> x.repeat(4, 2, 1).size()
[4, 2, 3,]
'''
if len(shape) == 1 and isinstance(shape[0], Sequence):
shape = shape[0]
len_x_shape = len(x.shape)
len_shape = len(shape)
x_shape = x.shape
rep_shape = shape
if len_x_shape < len_shape:
x_shape = (len_shape - len_x_shape) * [1] + x.shape
x = x.broadcast(x_shape)
elif len_x_shape > len_shape:
rep_shape = (len_x_shape - len_shape) * [1] + shape
tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist()
dims = []
for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}")
return x.reindex(tar_shape, dims)
jt.Var.repeat = repeat
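The reindex above builds the target shape as x_shape * rep_shape and reads every output element from input index i % x_shape[d] on each axis d. A minimal NumPy sketch of that rule, checked against np.tile (the array and repeat counts below are illustrative only):
import numpy as np
x_np = np.array([1, 2, 3])
rep = (4, 2)
x_shape = (1,) * (len(rep) - x_np.ndim) + x_np.shape      # left-pad shape with 1s
tar_shape = tuple(s * r for s, r in zip(x_shape, rep))    # (4, 6)
src = x_np.reshape(x_shape)
out = np.empty(tar_shape, dtype=x_np.dtype)
for idx in np.ndindex(*tar_shape):
    out[idx] = src[tuple(i % s for i, s in zip(idx, x_shape))]  # "i{d} % x_shape[d]"
assert (out == np.tile(x_np, rep)).all()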
def chunk(x, chunks, dim=0):
r'''
Splits a var into a specific number of chunks along the given dimension. Each chunk is a new var sliced from the input.
The last chunk will be smaller if the var size along dimension dim is not divisible by chunks.
Args:
[in] x (var): the var to split.
[in] chunks (int): number of chunks to return.
[in] dim (int): dimension along which to split the var.
Example:
>>> x = jt.random((10,3,3))
>>> res = jt.chunk(x, 2, 0)
>>> print(res[0].shape, res[1].shape)
[5,3,3,] [5,3,3,]
'''
l = x.shape[dim]
res = []
if l <= chunks:
for i in range(l):
res.append(x[(slice(None,),)*dim+([i,],)])
else:
nums = (l-1) // chunks + 1
for i in range(chunks-1):
res.append(x[(slice(None,),)*dim+(slice(i*nums,(i+1)*nums),)])
if (i+1)*nums < l:
res.append(x[(slice(None,),)*dim+(slice((i+1)*nums,None),)])
return res
jt.Var.chunk = chunk
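The chunk size above is the ceiling of l / chunks, and the final chunk keeps whatever remains. A short NumPy sketch of that split rule (array and chunk count are illustrative only):
import numpy as np
a = np.arange(10)
chunks = 3
size = (len(a) - 1) // chunks + 1                 # ceiling division -> 4
parts = [a[i:i + size] for i in range(0, len(a), size)]
print([p.tolist() for p in parts])                # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]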
def stack(x, dim=0):
r'''
Concatenates sequence of vars along a new dimension.
All vars need to be of the same size.
Args:
[in] x (sequence of vars) sequence of vars to concatenate.
[in] dim (int) dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).
Example:
>>> a1 = jt.array([[1,2,3]])
>>> a2 = jt.array([[4,5,6]])
>>> jt.stack([a1, a2], 0)
[[[1 2 3]]
[[4 5 6]]]
'''
assert isinstance(x, list)
assert len(x) >= 2
res = [x_.unsqueeze(dim) for x_ in x]
return jt.contrib.concat(res, dim=dim)
jt.Var.stack = stack
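stack is implemented as unsqueeze followed by concat along the new dimension; a NumPy sketch of that equivalence (inputs are illustrative only):
import numpy as np
a1, a2 = np.array([[1, 2, 3]]), np.array([[4, 5, 6]])
unsq = [np.expand_dims(a, 0) for a in (a1, a2)]           # add the new dimension
stacked = np.concatenate(unsq, axis=0)                    # shape (2, 1, 3)
assert (stacked == np.stack([a1, a2], axis=0)).all()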
def flip(x, dim=0):
r'''
Reverses the order of an n-D var along the given axis.
Args:
[in] x (var): the input var.
[in] dim (int): the axis to flip on.
Example:
>>> x = jt.array([[1,2,3,4]])
>>> x.flip(1)
[[4 3 2 1]]
'''
assert isinstance(dim, int)
tar_dims = []
for i in range(len(x.shape)):
if i == dim:
tar_dims.append(f"{x.shape[dim]-1}-i{i}")
else:
tar_dims.append(f"i{i}")
return x.reindex(x.shape, tar_dims)
jt.Var.flip = flip
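The flip reindex maps output index i on the flipped axis to input index (n - 1) - i; a NumPy sketch of the same mapping (input is illustrative only):
import numpy as np
x_np = np.array([[1, 2, 3, 4]])
dim = 1
n = x_np.shape[dim]
out = np.empty_like(x_np)
for i in range(n):
    out[:, i] = x_np[:, (n - 1) - i]                      # "(n-1) - i{dim}"
assert (out == np.flip(x_np, axis=dim)).all()             # [[4, 3, 2, 1]]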


@@ -290,6 +290,36 @@ class InstanceNorm2d(Module):
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
class GroupNorm(Module):
def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True, sync=True):
assert affine == None
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.is_train = is_train
self.sync = sync
self.weight = init.constant((num_channels,), "float32", 1.0)
self.bias = init.constant((num_channels,), "float32", 0.0)
def execute(self, x):
N,C,H,W = x.shape
assert C == self.num_channels
assert C % self.num_groups == 0
x_ = x.reindex([N, C//self.num_groups, self.num_groups, H, W], [
"i0", f"i2*{C//self.num_groups}+i1", "i3", "i4"
])
xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
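A NumPy reference for the per-(sample, group) statistics computed above, assuming the default affine weight of 1 and bias of 0; the layer uses E[x^2] - E[x]^2 while this sketch uses np.var, which is the same quantity. Shapes and eps below are illustrative only:
import numpy as np
N, C, H, W, G, eps = 2, 10, 4, 4, 2, 1e-5
x_np = np.random.randn(N, C, H, W).astype(np.float32)
xg = x_np.reshape(N, G, C // G, H, W)                     # consecutive channels per group
mean = xg.mean(axis=(2, 3, 4), keepdims=True)             # one mean per (sample, group)
var = xg.var(axis=(2, 3, 4), keepdims=True)
ref = ((xg - mean) / np.sqrt(var + eps)).reshape(N, C, H, W)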
Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 2)
@@ -558,6 +588,23 @@ class Sigmoid(Module):
def execute(self, x) :
return x.sigmoid()
class Softplus(Module):
r'''
SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
Args:
[in] beta (float): the beta value for the Softplus formulation. Default: 1.
[in] threshold (float): values above this revert to a linear function. Default: 20.
'''
def __init__(self, beta=1, threshold=20):
self.beta = beta
self.threshold = threshold
def execute(self, x):
return 1 / self.beta * jt.log(1 + (self.beta * x).exp())
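The forward pass above evaluates softplus(x) = (1/beta) * log(1 + exp(beta * x)); note that threshold is stored but not yet applied in this implementation, so large inputs are not switched to the linear branch. A NumPy sketch of the same formula (values are illustrative only):
import numpy as np
def softplus_ref(x, beta=1.0):
    # (1/beta) * log(1 + exp(beta * x)), matching the execute() above
    return np.log1p(np.exp(beta * x)) / beta
print(softplus_ref(np.array([-2.0, 0.0, 2.0]), beta=2.0))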
class Resize(Module):
def __init__(self, size, mode="nearest", align_corners=False):
super().__init__()
@@ -611,6 +658,52 @@ def upsample(img, size, mode="nearest", align_corners=False):
y = wid * (w / W)
return _interpolate(img, x, y, (nid,cid), mode)
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
r'''
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 are the left-top pixel of input, and values x = 1, y = 1 are the right-bottom pixel of input.
Args:
[in] input (var): the source input var, whose shape is (N, C, Hi, Wi)
[in] grid (var): the pixel locations, whose shape is (N, Ho, Wo, 2)
[in] mode (string): the interpolate way, default: bilinear.
[in] padding_mode (string): the padding way, default: zeros.
[out] output (var): the output var, whose shape is (N, C, Ho, Wo)
Example:
>>> x = jt.array([[[[1,2],[3,4]]]])
>>> print(x)
[[[[1 2]
[3 4]]]]
>>> grid = jt.array([[[[0.5, 0.5]]]])
>>> print(x.shape, grid.shape)
[1,1,2,2,] [1,1,1,2,]
>>> nn.grid_sample(x, grid)
[[[[3.25]]]]
'''
assert padding_mode == 'zeros'
Ni, Ci, Hi, Wi = input.shape
No, Ho, Wo, D = grid.shape
assert D == 2
assert Ni == No
assert len(input.shape) == 4 and len(grid.shape) == 4
nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo))
x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1)
y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
return _interpolate(input, x, y, (nid,cid), mode)
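The normalized grid coordinate g in [-1, 1] is mapped to a pixel coordinate as (g + 1) / 2 * (size - 1) before the bilinear blend in _interpolate. A NumPy sketch of that mapping and blend, reproducing the 3.25 from the docstring example (names are illustrative only):
import numpy as np
img = np.array([[1., 2.], [3., 4.]])
gx, gy = 0.5, 0.5                                         # normalized grid coords
r = (gy + 1) / 2 * (img.shape[0] - 1)                     # row coordinate, 0.75
c = (gx + 1) / 2 * (img.shape[1] - 1)                     # column coordinate, 0.75
r0, c0 = int(r), int(c)
r1, c1 = min(r0 + 1, img.shape[0] - 1), min(c0 + 1, img.shape[1] - 1)
dr, dc = r - r0, c - c0
val = (img[r0, c0] * (1 - dr) * (1 - dc) + img[r0, c1] * (1 - dr) * dc +
       img[r1, c0] * dr * (1 - dc) + img[r1, c1] * dr * dc)
print(val)                                                # 3.25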
class Upsample(Module):
def __init__(self, scale_factor=None, mode='nearest'):
self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)


@@ -23,23 +23,21 @@ except:
tnn = None
skip_this_test = True
def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, has_running=True, threshold=1e-5):
    jittor_arr = jt.array(arr)
    pytorch_arr = torch.Tensor(arr)
    if has_running:
        if is_train:
            assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
        else:
            assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
    if has_running:
        if is_train:
            assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
        else:
            assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
def check_equal_without_istrain(arr, j_layer, p_layer, threshold=1e-5):
@@ -100,5 +98,20 @@ class TestBatchNorm(unittest.TestCase):
model.eval()
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=False), model, False)
# ***************************************************************
# Test GroupNorm Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
class Model(tnn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer = tnn.GroupNorm(2, 10)
def forward(self, x):
return self.layer(x)
model = Model()
model.eval()
check_equal_with_istrain(arr, jnn.GroupNorm(2, 10, is_train=False), model, False, False)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,58 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(res1, res2):
assert np.allclose(res1.detach().numpy(), res2.numpy())
@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
def test_repeat(self):
arr = np.random.randn(16,3,224,224)
check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))
check_equal(torch.Tensor(arr).repeat(4,2,3,4), jt.array(arr).repeat(4,2,3,4))
print('pass repeat test ...')
def test_chunk(self):
arr = np.random.randn(16,3,224,224)
check_equal(torch.Tensor(arr).chunk(2,0)[0], jt.array(arr).chunk(2,0)[0])
check_equal(torch.Tensor(arr).chunk(2,0)[1], jt.array(arr).chunk(2,0)[1])
print('pass chunk test ...')
def test_stack(self):
arr1 = np.random.randn(16,3,224,224)
arr2 = np.random.randn(16,3,224,224)
check_equal(torch.stack([torch.Tensor(arr1), torch.Tensor(arr2)], 0), jt.stack([jt.array(arr1), jt.array(arr2)], 0))
print('pass stack test ...')
def test_flip(self):
arr = np.random.randn(16,3,224,224)
check_equal(torch.Tensor(arr).flip(0), jt.array(arr).flip(0))
check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
print('pass flip test ...')
if __name__ == "__main__":
unittest.main()


@@ -62,5 +62,13 @@ class TestRelu(unittest.TestCase):
check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
# ***************************************************************
# Test Softplus Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.Softplus (), tnn.Softplus ())
check_equal(arr, jnn.Softplus (2), tnn.Softplus (2))
check_equal(arr, jnn.Softplus (2, 99.9), tnn.Softplus (2, 99.9))
if __name__ == "__main__":
unittest.main()