diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 8c2d486c..f3de97e9 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -761,3 +761,4 @@ from . import nn
 from .nn import matmul
 from . import contrib
 from .contrib import concat
+from .misc import *
\ No newline at end of file
diff --git a/python/jittor/misc.py b/python/jittor/misc.py
new file mode 100644
index 00000000..523e7fb8
--- /dev/null
+++ b/python/jittor/misc.py
@@ -0,0 +1,144 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors:
+#     Dun Liang .
+#     Wenyang Zhou <576825820@qq.com>
+#
+# All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import jittor as jt
+import numpy as np
+from collections.abc import Sequence
+
+def repeat(x, *shape):
+    r'''
+    Repeats this var along the specified dimensions.
+
+    Args:
+
+        [in] x (var): jittor var.
+
+        [in] shape (tuple): int or tuple. The number of times to repeat this var along each dimension.
+
+    Example:
+
+        >>> x = jt.array([1, 2, 3])
+
+        >>> x.repeat(4, 2)
+        [[ 1, 2, 3, 1, 2, 3],
+        [ 1, 2, 3, 1, 2, 3],
+        [ 1, 2, 3, 1, 2, 3],
+        [ 1, 2, 3, 1, 2, 3]]
+
+        >>> x.repeat(4, 2, 1).size()
+        [4, 2, 3,]
+    '''
+    if len(shape) == 1 and isinstance(shape[0], Sequence):
+        shape = shape[0]
+    len_x_shape = len(x.shape)
+    len_shape = len(shape)
+    x_shape = x.shape
+    rep_shape = shape
+    if len_x_shape < len_shape:
+        x_shape = (len_shape - len_x_shape) * [1] + x.shape
+        x = x.broadcast(x_shape)
+    elif len_x_shape > len_shape:
+        rep_shape = (len_x_shape - len_shape) * [1] + shape
+    tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist()
+    dims = []
+    for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}")
+    return x.reindex(tar_shape, dims)
+jt.Var.repeat = repeat
+
+def chunk(x, chunks, dim=0):
+    r'''
+    Splits a var into a specific number of chunks. Each chunk is a view of the input var.
+
+    Last chunk will be smaller if the var size along the given dimension dim is not divisible by chunks.
+
+    Args:
+
+        [in] input (var) – the var to split.
+
+        [in] chunks (int) – number of chunks to return.
+
+        [in] dim (int) – dimension along which to split the var.
+
+    Example:
+
+        >>> x = jt.random((10,3,3))
+
+        >>> res = jt.chunk(x, 2, 0)
+
+        >>> print(res[0].shape, res[1].shape)
+        [5,3,3,] [5,3,3,]
+    '''
+    l = x.shape[dim]
+    res = []
+    if l <= chunks:
+        for i in range(l):
+            res.append(x[(slice(None,),)*dim+([i,],)])
+    else:
+        nums = (l-1) // chunks + 1
+        for i in range(chunks-1):
+            res.append(x[(slice(None,),)*dim+(slice(i*nums,(i+1)*nums),)])
+        if (i+1)*nums < l:
+            res.append(x[(slice(None,),)*dim+(slice((i+1)*nums,None),)])
+    return res
+jt.Var.chunk = chunk
+
+def stack(x, dim=0):
+    r'''
+    Concatenates sequence of vars along a new dimension.
+
+    All vars need to be of the same size.
+
+    Args:
+
+        [in] x (sequence of vars) – sequence of vars to concatenate.
+
+        [in] dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).
+
+    Example:
+
+        >>> a1 = jt.array([[1,2,3]])
+
+        >>> a2 = jt.array([[4,5,6]])
+
+        >>> jt.stack([a1, a2], 0)
+        [[[1 2 3]]
+        [[4 5 6]]]
+    '''
+    assert isinstance(x, list)
+    assert len(x) >= 2
+    res = [x_.unsqueeze(dim) for x_ in x]
+    return jt.contrib.concat(res, dim=dim)
+jt.Var.stack = stack
+
+def flip(x, dim=0):
+    r'''
+    Reverses the order of an n-D var along the given axis.
+
+    Args:
+
+        [in] input (var) – the input var.
+
+        [in] dim (int) – the axis to flip on.
+
+    Example:
+
+        >>> x = jt.array([[1,2,3,4]])
+
+        >>> x.flip(1)
+        [[4 3 2 1]]
+    '''
+    assert isinstance(dim, int)
+    tar_dims = []
+    for i in range(len(x.shape)):
+        if i == dim:
+            tar_dims.append(f"{x.shape[dim]-1}-i{i}")
+        else:
+            tar_dims.append(f"i{i}")
+    return x.reindex(x.shape, tar_dims)
+jt.Var.flip = flip
\ No newline at end of file
diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 65d011a2..e7308b48 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -290,6 +290,36 @@ class InstanceNorm2d(Module):
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b
 
+class GroupNorm(Module):
+    def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True, sync=True):
+        assert affine == None
+        self.num_groups = num_groups
+        self.num_channels = num_channels
+        self.eps = eps
+        self.is_train = is_train
+        self.sync = sync
+        self.weight = init.constant((num_channels,), "float32", 1.0)
+        self.bias = init.constant((num_channels,), "float32", 0.0)
+
+    def execute(self, x):
+        N,C,H,W = x.shape
+        assert C == self.num_channels
+        assert C % self.num_groups == 0
+        x_ = x.reindex([N, int(C/self.num_groups), self.num_groups, H, W], [
+            "i0", f"i2*{C/self.num_groups}+i1", "i3", "i4"
+        ])
+        xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})", "0", "0"])
+        x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})", "0", "0"])
+        if self.sync and jt.in_mpi:
+            xmean = xmean.mpi_all_reduce("mean")
+            x2mean = x2mean.mpi_all_reduce("mean")
+
+        xvar = jt.maximum(x2mean-xmean*xmean, 0)
+        norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+        w = self.weight.broadcast(x, [0,2,3])
+        b = self.bias.broadcast(x, [0,2,3])
+        return norm_x * w + b
+
 Relu = jt.make_module(relu)
 ReLU = Relu
 Leaky_relu = jt.make_module(leaky_relu, 2)
@@ -558,6 +588,23 @@ class Sigmoid(Module):
     def execute(self, x) :
         return x.sigmoid()
 
+class Softplus(Module):
+    r'''
+    SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
+
+    Args:
+
+        [in] beta (float): the beta value for the Softplus formulation. Default: 1.
+
+        [in] threshold (float): values above this revert to a linear function. Default: 20.
+    '''
+    def __init__(self, beta=1, threshold=20):
+        self.beta = beta
+        self.threshold = threshold
+
+    def execute(self, x):
+        return 1 / self.beta * jt.log(1 + (self.beta * x).exp())
+
 class Resize(Module):
     def __init__(self, size, mode="nearest", align_corners=False):
         super().__init__()
@@ -611,6 +658,52 @@ def upsample(img, size, mode="nearest", align_corners=False):
         y = wid * (w / W)
     return _interpolate(img, x, y, (nid,cid), mode)
 
+def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
+    r'''
+    Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
+
+    grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 correspond to the left-top pixel of input, and values x = 1, y = 1 correspond to the right-bottom pixel of input.
+
+    Args:
+
+        [in] input (var): the source input var, whose shape is (N, C, Hi, Wi)
+
+        [in] grid (var): the pixel locations, whose shape is (N, Ho, Wo, 2)
+
+        [in] mode (string): the interpolation mode, default: bilinear.
+
+        [in] padding_mode (string): the padding mode, default: zeros.
+
+        [out] output (var): the output var, whose shape is (N, C, Ho, Wo)
+
+    Example:
+
+        >>> x = jt.array([[[[1,2],[3,4]]]])
+
+        >>> print(x)
+        [[[[1 2]
+        [3 4]]]]
+
+        >>> grid = jt.array([[[[0.5, 0.5]]]])
+
+        >>> print(x.shape, grid.shape)
+        [1,1,2,2,] [1,1,1,2,]
+
+        >>> nn.grid_sample(x, grid)
+        [[[[3.25]]]]
+    '''
+    assert padding_mode == 'zeros'
+    Ni, Ci, Hi, Wi = input.shape
+    No, Ho, Wo, D = grid.shape
+    assert D == 2
+    assert Ni == No
+    assert len(input.shape) == 4 and len(grid.shape) == 4
+
+    nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo))
+    x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1)
+    y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
+    return _interpolate(input, x, y, (nid,cid), mode)
+
 class Upsample(Module):
     def __init__(self, scale_factor=None, mode='nearest'):
         self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
diff --git a/python/jittor/test/test_batchnorm.py b/python/jittor/test/test_batchnorm.py
index 3812d7a6..cd93f89b 100644
--- a/python/jittor/test/test_batchnorm.py
+++ b/python/jittor/test/test_batchnorm.py
@@ -23,23 +23,21 @@ except:
     tnn = None
     skip_this_test = True
 
-def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, threshold=1e-5):
+def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, has_running=True, threshold=1e-5):
     jittor_arr = jt.array(arr)
     pytorch_arr = torch.Tensor(arr)
-    if is_train:
-        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
-    else:
-        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
+    if has_running:
+        if is_train:
+            assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        else:
+            assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
     jittor_result = j_layer(jittor_arr)
     pytorch_result = p_layer(pytorch_arr)
-    if is_train:
-        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
-    else:
-        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
-        # assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
+    if has_running:
+        if is_train:
+            assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        else:
+            assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
     assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
 
 def check_equal_without_istrain(arr, j_layer, p_layer, threshold=1e-5):
@@ -100,5 +98,20 @@ class TestBatchNorm(unittest.TestCase):
         model.eval()
         check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=False), model, False)
 
+        # ***************************************************************
+        # Test GroupNorm Layer
+        # ***************************************************************
+        arr = np.random.randn(16,10,224,224)
+
+        class Model(tnn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.layer = tnn.GroupNorm(2, 10)
+            def forward(self, x):
+                return self.layer(x)
+        model = Model()
+        model.eval()
+        check_equal_with_istrain(arr, jnn.GroupNorm(2, 10, is_train=False), model, False, False)
+
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py
new file mode 100644
index 00000000..33c82890
--- /dev/null
+++ b/python/jittor/test/test_misc_op.py
@@ -0,0 +1,58 @@
+
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors:
+#     Wenyang Zhou <576825820@qq.com>
+#     Dun Liang .
+# All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import jittor as jt
+import numpy as np
+import jittor.nn as jnn
+
+skip_this_test = False
+
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    import torch.nn as tnn
+except:
+    torch = None
+    tnn = None
+    skip_this_test = True
+
+def check_equal(res1, res2):
+    assert np.allclose(res1.detach().numpy(), res2.numpy())
+
+@unittest.skipIf(skip_this_test, "No Torch found")
+class TestPad(unittest.TestCase):
+    def test_repeat(self):
+        arr = np.random.randn(16,3,224,224)
+        check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))
+        check_equal(torch.Tensor(arr).repeat(4,2,3,4), jt.array(arr).repeat(4,2,3,4))
+        print('pass repeat test ...')
+
+    def test_chunk(self):
+        arr = np.random.randn(16,3,224,224)
+        check_equal(torch.Tensor(arr).chunk(2,0)[0], jt.array(arr).chunk(2,0)[0])
+        check_equal(torch.Tensor(arr).chunk(2,0)[1], jt.array(arr).chunk(2,0)[1])
+        print('pass chunk test ...')
+
+    def test_stack(self):
+        arr1 = np.random.randn(16,3,224,224)
+        arr2 = np.random.randn(16,3,224,224)
+        check_equal(torch.stack([torch.Tensor(arr1), torch.Tensor(arr2)], 0), jt.stack([jt.array(arr1), jt.array(arr2)], 0))
+        print('pass stack test ...')
+
+    def test_flip(self):
+        arr = np.random.randn(16,3,224,224)
+        check_equal(torch.Tensor(arr).flip(0), jt.array(arr).flip(0))
+        check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
+        check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
+        check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
+        print('pass flip test ...')
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/python/jittor/test/test_relu.py b/python/jittor/test/test_relu.py
index a61c9b7d..fefdec55 100644
--- a/python/jittor/test/test_relu.py
+++ b/python/jittor/test/test_relu.py
@@ -61,6 +61,14 @@ class TestRelu(unittest.TestCase):
         check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
         check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
+
+        # ***************************************************************
+        # Test Softplus Layer
+        # ***************************************************************
+        arr = np.random.randn(16,10,224,224)
+        check_equal(arr, jnn.Softplus(), tnn.Softplus())
+        check_equal(arr, jnn.Softplus(2), tnn.Softplus(2))
+        check_equal(arr, jnn.Softplus(2, 99.9), tnn.Softplus(2, 99.9))
 
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
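
The ops added in python/jittor/misc.py can be exercised end to end. A minimal usage sketch, assuming jittor is built with this patch applied so the misc ops are re-exported from the top-level package:

import jittor as jt

# repeat: tile a 1-D var into a 4x6 matrix (each row is [1, 2, 3] repeated twice).
x = jt.array([1, 2, 3])
print(x.repeat(4, 2).shape)

# chunk: split a var into two pieces along dim 0, giving shapes [5,3,3,] and [5,3,3,].
a = jt.random((10, 3, 3))
c0, c1 = jt.chunk(a, 2, 0)
print(c0.shape, c1.shape)

# stack: join two vars along a new leading dimension, giving shape [2,1,3,].
s = jt.stack([jt.array([[1, 2, 3]]), jt.array([[4, 5, 6]])], 0)
print(s.shape)

# flip: reverse the last axis, giving [[4 3 2 1]].
print(jt.array([[1, 2, 3, 4]]).flip(1))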
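
The layers added to python/jittor/nn.py follow the same calling convention as the existing modules. A minimal sketch that mirrors the shapes used in the docstrings and the new tests; the exact numeric output depends on the random input:

import jittor as jt
from jittor import nn

# GroupNorm: 2 groups over 10 channels; the output keeps the input shape (16, 10, 224, 224).
x = jt.random((16, 10, 224, 224))
gn = nn.GroupNorm(2, 10, is_train=False)
y = gn(x)

# Softplus with beta=2, as exercised in test_relu.py.
sp = nn.Softplus(2)
z = sp(jt.random((4, 8)))

# grid_sample: bilinear sampling of a 2x2 input at the normalized location (0.5, 0.5) -> 3.25.
img = jt.array([[[[1., 2.], [3., 4.]]]])   # (N, C, Hi, Wi) = (1, 1, 2, 2)
grid = jt.array([[[[0.5, 0.5]]]])          # (N, Ho, Wo, 2) = (1, 1, 1, 2)
out = nn.grid_sample(img, grid)
print(out)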