mirror of https://github.com/Jittor/Jittor
266 lines
7.9 KiB
Python
266 lines
7.9 KiB
Python
# ***************************************************************
|
||
# Copyright (c) 2020 Jittor. Authors:
|
||
# Dun Liang <randonlang@gmail.com>.
|
||
# Wenyang Zhou <576825820@qq.com>
|
||
#
|
||
# All Rights Reserved.
|
||
# This file is subject to the terms and conditions defined in
|
||
# file 'LICENSE.txt', which is part of this source code package.
|
||
# ***************************************************************
|
||
import jittor as jt
|
||
import numpy as np
|
||
import math
|
||
from collections.abc import Sequence
|
||
|
||
def repeat(x, *shape):
    r'''
    Repeats this var along the specified dimensions.

    Args:

        x (var): jittor var.

        shape (tuple): int or tuple. The number of times to repeat this var
            along each dimension. The counts may be given as separate
            arguments or as one sequence. When fewer counts than ``x`` has
            dimensions are given, the missing leading counts are treated
            as 1; when more are given, ``x`` is first broadcast with leading
            dimensions of size 1.

    Returns:

        The result of ``x.reindex`` whose i-th dimension has size
        ``x_shape[i] * shape[i]`` after the alignment described above.

    Example:

        >>> x = jt.array([1, 2, 3])

        >>> x.repeat(4, 2)
        [[ 1,  2,  3,  1,  2,  3],
         [ 1,  2,  3,  1,  2,  3],
         [ 1,  2,  3,  1,  2,  3],
         [ 1,  2,  3,  1,  2,  3]]

        >>> x.repeat(4, 2, 1).size()
        [4, 2, 3,]
    '''
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = shape[0]
    # Work on plain lists: the counts arrive as a tuple when passed as
    # varargs, and ``list + tuple`` raises TypeError.
    rep_shape = list(shape)
    x_shape = list(x.shape)
    extra = len(rep_shape) - len(x_shape)
    if extra > 0:
        # More counts than dimensions: add leading dimensions of size 1.
        x_shape = [1] * extra + x_shape
        x = x.broadcast(x_shape)
    elif extra < 0:
        # Fewer counts than dimensions: leading dimensions are repeated once.
        rep_shape = [1] * (-extra) + rep_shape
    tar_shape = [size * rep for size, rep in zip(x_shape, rep_shape)]
    # Output index i along an axis reads source index ``i % size``.
    dims = [f"i{i}%{size}" for i, size in enumerate(x_shape)]
    return x.reindex(tar_shape, dims)
|
||
# Attach repeat to jt.Var so it can be called as x.repeat(...), as in the
# docstring example of repeat.
jt.Var.repeat = repeat
|
||
|
||
def chunk(x, chunks, dim=0):
    r'''
    Splits a var into at most ``chunks`` pieces along dimension ``dim``.

    Every piece holds ``ceil(l / chunks)`` entries along ``dim`` (``l`` being
    the size of that dimension), except the last one, which holds whatever is
    left. Fewer than ``chunks`` pieces are returned when that size does not
    allow more non-empty pieces; an empty dimension yields an empty list.

    Args:

        x (var) – the var to split.

        chunks (int) – number of chunks to return. Must be positive.

        dim (int) – dimension along which to split the var. Negative values
            count from the last dimension.

    Returns:

        A list of the pieces, in order along ``dim``.

    Raises:

        ValueError: if ``chunks`` is not positive.

    Example:

        >>> x = jt.random((10,3,3))

        >>> res = jt.chunk(x, 2, 0)

        >>> print(res[0].shape, res[1].shape)
        [5,3,3,] [5,3,3,]
    '''
    if chunks <= 0:
        raise ValueError(f"chunks must be a positive integer, got {chunks}")
    if dim < 0:
        # The index prefix below is built with ``* dim``, so it needs a
        # non-negative axis number.
        dim += len(x.shape)
    l = x.shape[dim]
    if l == 0:
        return []
    nums = (l - 1) // chunks + 1  # ceil(l / chunks)
    prefix = (slice(None),) * dim
    # Stepping by ``nums`` yields only non-empty pieces and also covers the
    # chunks == 1 case.
    return [
        x[prefix + (slice(start, min(start + nums, l)),)]
        for start in range(0, l, nums)
    ]
|
||
# Attach chunk to jt.Var so it is also reachable as an attribute of a var.
jt.Var.chunk = chunk
|
||
|
||
def stack(x, dim=0):
    r'''
    Joins a list of vars along a newly inserted dimension.

    Each var gets a new axis at position ``dim`` (via ``unsqueeze``) and the
    results are concatenated along that axis.

    Args:

        x (list of vars) – the vars to join. Must be a ``list`` holding at
            least two entries; both conditions are checked with assertions.

        dim (int) – position of the inserted dimension.

    Example:

        >>> a1 = jt.array([[1,2,3]])

        >>> a2 = jt.array([[4,5,6]])

        >>> jt.stack([a1, a2], 0)
        [[[1 2 3]]
         [[4 5 6]]]
    '''
    assert isinstance(x, list)
    assert len(x) >= 2
    expanded = []
    for var in x:
        expanded.append(var.unsqueeze(dim))
    return jt.contrib.concat(expanded, dim=dim)
|
||
# Attach stack to jt.Var.
# NOTE(review): when called as a bound method the var itself would be passed
# as ``x``, and stack asserts that ``x`` is a list — presumably only the
# function form jt.stack([...]) is usable; confirm.
jt.Var.stack = stack
|
||
|
||
def flip(x, dim=0):
    r'''
    Reverse the order of a n-D var along the given axis or axes.

    Args:

        x (var) – the input var.

        dim (int, or a list or tuple of ints) – axis or axes to flip on.
            Negative values count from the last dimension.

    Returns:

        The result of ``x.reindex`` with the same shape as ``x``, reading
        each flipped axis back to front.

    Raises:

        TypeError: if ``dim`` is neither an int nor a list/tuple of ints.

        ValueError: if an axis is out of range for ``x``.

    Example:

        >>> x = jt.array([[1,2,3,4]])

        >>> x.flip(1)
        [[4 3 2 1]]
    '''
    if isinstance(dim, int):
        dims = (dim,)
    elif isinstance(dim, (list, tuple)) and all(isinstance(d, int) for d in dim):
        dims = tuple(dim)
    else:
        raise TypeError(
            f"dim must be an int or a list/tuple of ints, got {dim!r}")
    ndim = len(x.shape)
    flip_axes = set()
    for d in dims:
        if not -ndim <= d < ndim:
            raise ValueError(
                f"dim {d} is out of range for a var with {ndim} dimensions")
        # Map negative axes onto their non-negative index; otherwise they
        # would never match the positional index used below.
        flip_axes.add(d % ndim)
    tar_dims = [
        f"{x.shape[i]-1}-i{i}" if i in flip_axes else f"i{i}"
        for i in range(ndim)
    ]
    return x.reindex(x.shape, tar_dims)
|
||
# Attach flip to jt.Var so it can be called as x.flip(...), as in the
# docstring example of flip.
jt.Var.flip = flip
|
||
|
||
def cross(input, other, dim=-1):
    r'''
    Returns the cross product of vectors in dimension dim of input and other.

    The cross product is computed component-wise as
    (a1,a2,a3) x (b1,b2,b3) = (a2b3-a3b2, a3b1-a1b3, a1b2-a2b1)

    input and other must have the same shape, and the size of their dim
    dimension must be 3; both conditions are checked with assertions.

    Args:

        input (var) – the first input var.

        other (var) – the second input var.

        dim (int, optional) – the dimension to take the cross-product in.
            Defaults to -1, the last dimension; negative values count from
            the last dimension.

    Example:

        >>> input = jt.random((6,3))

        >>> other = jt.random((6,3))

        >>> jt.cross(input, other, dim=1)
        [[-0.42732686  0.6827885  -0.49206433]
         [ 0.4651107   0.27036983 -0.5580432 ]
         [-0.31933784  0.10543461  0.09676848]
         [-0.58346975 -0.21417202  0.55176204]
         [-0.40861478  0.01496297  0.38638002]
         [ 0.18393655 -0.04907863 -0.17928357]]

        >>> jt.cross(input, other)
        [[-0.42732686  0.6827885  -0.49206433]
         [ 0.4651107   0.27036983 -0.5580432 ]
         [-0.31933784  0.10543461  0.09676848]
         [-0.58346975 -0.21417202  0.55176204]
         [-0.40861478  0.01496297  0.38638002]
         [ 0.18393655 -0.04907863 -0.17928357]]
    '''
    assert input.shape==other.shape, "input shape and other shape must be same"
    if dim < 0:
        dim += len(input.shape)
    assert input.shape[dim] == 3, "input dim shape must be 3"
    # Index prefix selecting everything in the dimensions before ``dim``.
    lead = (slice(None,),) * dim
    u0, u1, u2 = (input[lead + (k,)] for k in (0, 1, 2))
    v0, v1, v2 = (other[lead + (k,)] for k in (0, 1, 2))
    components = [
        u1 * v2 - u2 * v1,
        u2 * v0 - u0 * v2,
        u0 * v1 - u1 * v0,
    ]
    # Integer indexing dropped ``dim``; put it back before concatenating.
    return jt.contrib.concat([c.unsqueeze(dim) for c in components], dim=dim)
|
||
# Attach cross to jt.Var so it is also reachable as an attribute of a var.
jt.Var.cross = cross
|
||
|
||
def normalize(input, p=2, dim=1, eps=1e-12):
    r'''
    Performs L_p normalization of inputs over specified dimension.

    Only p == 2 is implemented; any other value fails an assertion. The
    input is divided by its L2 norm along ``dim`` (kept as a dimension so it
    lines up with the input), with the norm bounded below by ``eps``.

    Args:

        input – input array of any shape

        p (float) – the exponent value in the norm formulation. Default: 2

        dim (int) – the dimension to reduce. Default: 1

        eps (float) – small value to avoid division by zero. Default: 1e-12

    Example:

        >>> x = jt.random((6,3))

        >>> x
        [[0.18777736 0.9739261  0.77647036]
         [0.13710196 0.27282116 0.30533272]
         [0.7272278  0.5174613  0.9719775 ]
         [0.02566639 0.37504175 0.32676998]
         [0.0231761  0.5207773  0.70337296]
         [0.58966476 0.49547017 0.36724383]]

        >>> jt.normalize(x)
        [[0.14907198 0.7731768  0.61642134]
         [0.31750825 0.63181424 0.7071063 ]
         [0.5510936  0.39213243 0.736565  ]
         [0.05152962 0.7529597  0.656046  ]
         [0.02647221 0.59484214 0.80340654]
         [0.6910677  0.58067477 0.4303977 ]]
    '''
    assert p == 2
    if p == 2:
        squared_sum = input.sqr().sum(dim, True)
        l2_norm = squared_sum.sqrt()
        # Clamp the norm from below so an all-zero slice does not divide by 0.
        denominator = jt.maximum(l2_norm, eps)
        return input / denominator
|
||
# Attach normalize to jt.Var so it is also reachable as an attribute of a var.
jt.Var.normalize = normalize
|
||
|
||
def unbind(x, dim=0):
    r'''
    Removes a var dimension.

    Returns a list of all slices along a given dimension, already without it.

    Args:

        x (var) – the var to unbind

        dim (int) – dimension to remove. Negative values count from the
            last dimension.

    Example:

        a = jt.random((3,3))
        b = jt.unbind(a, 0)

    '''
    if dim < 0:
        dim += len(x.shape)
    # Index prefix selecting everything in the dimensions before ``dim``.
    lead = (slice(None),) * dim
    pieces = []
    for i in range(x.shape[dim]):
        pieces.append(x[lead + (i,)])
    return pieces
|
||
|
||
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
    r'''
    Arranges a batch of images into a single grid image.

    Args:

        x – a var whose shape unpacks into (b, c, h, w), or a list, which is
            first combined with ``jt.stack``.

        nrow (int) – number of images placed side by side in each grid row.

        padding (int) – number of padding pixels around every image.

        normalize (bool) – if True, rescale ``x`` with its global min and
            max as ``(x - min) / (max - min)`` before building the grid.

        range – not supported; must be None (checked with an assertion).

        scale_each – not supported; must be False (checked with an
            assertion).

        pad_value – passed to ``reindex`` as ``overflow_value``; presumably
            the value used wherever the computed source index is out of
            range (the padding and any unused cells) — confirm against the
            reindex documentation.

    Returns:

        The result of ``x.reindex`` with shape
        ``[c, h*rows + (rows+1)*padding, w*nrow + (nrow+1)*padding]`` where
        ``rows = ceil(b / nrow)``.
    '''
    assert range == None
    assert scale_each == False
    if isinstance(x, list):
        x = jt.stack(x)
    if normalize:
        lowest, highest = x.min(), x.max()
        x = (x - lowest) / (highest - lowest)
    b, c, h, w = x.shape
    ncol = math.ceil(b / nrow)
    # Height and width of one grid cell: an image plus its leading padding.
    cell_h = padding + h
    cell_w = padding + w
    grid_shape = [
        c,
        h * ncol + (ncol + 1) * padding,
        w * nrow + (nrow + 1) * padding,
    ]
    source_index = [
        # batch index: grid row * images per row + grid column
        f"i1/{cell_h}*{nrow}+i2/{cell_w}",
        # channel
        "i0",
        # pixel offsets inside the cell, shifted by the leading padding
        f"i1-i1/{cell_h}*{cell_h}-{padding}",
        f"i2-i2/{cell_w}*{cell_w}-{padding}",
    ]
    return x.reindex(grid_shape, source_index, overflow_value=pad_value)
|