add cross & normalize

zhouwy19 2020-07-31 09:34:49 +08:00
parent 1e32486f59
commit 87b1933447
2 changed files with 119 additions and 12 deletions

@@ -17,9 +17,9 @@ def repeat(x, *shape):
     Args:
-        [in] x (var): jittor var.
+        x (var): jittor var.
-        [in] shape (tuple): int or tuple. The number of times to repeat this var along each dimension.
+        shape (tuple): int or tuple. The number of times to repeat this var along each dimension.
     Example:
@@ -59,11 +59,11 @@ def chunk(x, chunks, dim=0):
     Args:
-        [in] input (var) the var to split.
+        input (var) the var to split.
-        [in] chunks (int) number of chunks to return.
+        chunks (int) number of chunks to return.
-        [in] dim (int) dimension along which to split the var.
+        dim (int) dimension along which to split the var.
     Example:
@@ -96,9 +96,9 @@ def stack(x, dim=0):
     Args:
-        [in] x (sequence of vars) sequence of vars to concatenate.
+        x (sequence of vars) sequence of vars to concatenate.
-        [in] dim (int) dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).
+        dim (int) dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).
     Example:
@@ -122,9 +122,9 @@ def flip(x, dim=0):
     Args:
-        [in] input (var) the input var.
+        input (var) the input var.
-        [in] dims (a list or tuple) axis to flip on.
+        dims (a list or tuple) axis to flip on.
     Example:
@@ -141,4 +141,92 @@ def flip(x, dim=0):
         else:
             tar_dims.append(f"i{i}")
     return x.reindex(x.shape, tar_dims)
-jt.Var.flip = flip
+jt.Var.flip = flip
+def cross(input, other, dim=-1):
+    r'''
+    Returns the cross product of vectors in dimension dim of input and other.
+    The cross product can be calculated by (a1,a2,a3) x (b1,b2,b3) = (a2b3-a3b2, a3b1-a1b3, a1b2-a2b1).
+    input and other must have the same shape, and the size of their dim dimension must be 3.
+    If dim is not given, it defaults to the last dimension.
+    Args:
+        input (var): the first input var.
+        other (var): the second input var.
+        dim (int, optional): the dimension to take the cross product in. Default: -1
+    Example:
+        >>> input = jt.random((6,3))
+        >>> other = jt.random((6,3))
+        >>> jt.cross(input, other, dim=1)
+        [[-0.42732686  0.6827885  -0.49206433]
+         [ 0.4651107   0.27036983 -0.5580432 ]
+         [-0.31933784  0.10543461  0.09676848]
+         [-0.58346975 -0.21417202  0.55176204]
+         [-0.40861478  0.01496297  0.38638002]
+         [ 0.18393655 -0.04907863 -0.17928357]]
+        >>> jt.cross(input, other)
+        [[-0.42732686  0.6827885  -0.49206433]
+         [ 0.4651107   0.27036983 -0.5580432 ]
+         [-0.31933784  0.10543461  0.09676848]
+         [-0.58346975 -0.21417202  0.55176204]
+         [-0.40861478  0.01496297  0.38638002]
+         [ 0.18393655 -0.04907863 -0.17928357]]
+    '''
+    assert input.shape == other.shape, "input and other must have the same shape"
+    if dim < 0: dim += len(input.shape)
+    assert input.shape[dim] == 3, "the size of dimension dim must be 3"
+    # (slice(None),)*dim + (i,) selects component i along dimension dim
+    a1 = input[(slice(None),)*dim+(1,)]*other[(slice(None),)*dim+(2,)] - input[(slice(None),)*dim+(2,)]*other[(slice(None),)*dim+(1,)]
+    a2 = input[(slice(None),)*dim+(2,)]*other[(slice(None),)*dim+(0,)] - input[(slice(None),)*dim+(0,)]*other[(slice(None),)*dim+(2,)]
+    a3 = input[(slice(None),)*dim+(0,)]*other[(slice(None),)*dim+(1,)] - input[(slice(None),)*dim+(1,)]*other[(slice(None),)*dim+(0,)]
+    return jt.contrib.concat([a1.unsqueeze(dim), a2.unsqueeze(dim), a3.unsqueeze(dim)], dim=dim)
+jt.Var.cross = cross
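
For reference, the component formula above can be sanity-checked against numpy.cross. A minimal sketch, assuming a Jittor build with this commit applied (the shapes and tolerance are illustrative, not part of the commit):

    import numpy as np
    import jittor as jt

    a = np.random.randn(6, 3).astype(np.float32)
    b = np.random.randn(6, 3).astype(np.float32)
    # dim defaults to the last dimension, matching np.cross on the last axis
    out = jt.array(a).cross(jt.array(b)).numpy()
    assert np.allclose(out, np.cross(a, b), atol=1e-5)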
+def normalize(input, p=2, dim=1, eps=1e-12):
+    r'''
+    Performs L_p normalization of inputs over the specified dimension,
+    i.e. each slice v along dimension dim is divided by max(||v||_p, eps).
+    Args:
+        input (var): input var of any shape.
+        p (float): the exponent value in the norm formulation. Default: 2
+        dim (int): the dimension to reduce. Default: 1
+        eps (float): small value to avoid division by zero. Default: 1e-12
+    Example:
+        >>> x = jt.random((6,3))
+        >>> x
+        [[0.18777736 0.9739261  0.77647036]
+         [0.13710196 0.27282116 0.30533272]
+         [0.7272278  0.5174613  0.9719775 ]
+         [0.02566639 0.37504175 0.32676998]
+         [0.0231761  0.5207773  0.70337296]
+         [0.58966476 0.49547017 0.36724383]]
+        >>> jt.normalize(x)
+        [[0.14907198 0.7731768  0.61642134]
+         [0.31750825 0.63181424 0.7071063 ]
+         [0.5510936  0.39213243 0.736565  ]
+         [0.05152962 0.7529597  0.656046  ]
+         [0.02647221 0.59484214 0.80340654]
+         [0.6910677  0.58067477 0.4303977 ]]
+    '''
+    # only p == 2 is supported for now
+    assert p == 2
+    return input / jt.maximum(input.sqr().sum(dim, True).sqrt(), eps)
+jt.Var.normalize = normalize
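
As a quick check of the L2 case, the result should match dividing each slice by its NumPy-computed norm. A minimal sketch (shapes and tolerance are illustrative, not part of the commit):

    import numpy as np
    import jittor as jt

    x = np.random.rand(6, 3).astype(np.float32)
    # default dim=1: divide each row by max(||row||_2, eps)
    got = jt.normalize(jt.array(x)).numpy()
    want = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)
    assert np.allclose(got, want, atol=1e-5)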

@@ -23,8 +23,8 @@ except:
     tnn = None
     skip_this_test = True
-def check_equal(res1, res2):
-    assert np.allclose(res1.detach().numpy(), res2.numpy())
+def check_equal(res1, res2, eps=1e-5):
+    assert np.allclose(res1.detach().numpy(), res2.numpy(), eps)
 @unittest.skipIf(skip_this_test, "No Torch found")
 class TestPad(unittest.TestCase):
@@ -54,5 +54,24 @@ class TestPad(unittest.TestCase):
         check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
         print('pass flip test ...')
+    def test_cross(self):
+        arr1 = np.random.randn(16,3,224,224,3)
+        arr2 = np.random.randn(16,3,224,224,3)
+        check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1)
+        check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=-4), jt.array(arr1).cross(jt.array(arr2), dim=-4), 1e-1)
+        check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=-1), jt.array(arr1).cross(jt.array(arr2), dim=-1), 1e-1)
+        check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=4), jt.array(arr1).cross(jt.array(arr2), dim=4), 1e-1)
+        print('pass cross test ...')
+    def test_normalize(self):
+        arr = np.random.randn(16,3,224,224,3)
+        check_equal(tnn.functional.normalize(torch.Tensor(arr)), jt.normalize(jt.array(arr)))
+        check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=0), jt.normalize(jt.array(arr), dim=0), 1e-1)
+        check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=1), jt.normalize(jt.array(arr), dim=1), 1e-1)
+        check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=-1), jt.normalize(jt.array(arr), dim=-1), 1e-1)
+        check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=2), jt.normalize(jt.array(arr), dim=2), 1e-1)
+        check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=3), jt.normalize(jt.array(arr), dim=3), 1e-1)
+        print('pass normalize test ...')
 if __name__ == "__main__":
     unittest.main()