mirror of https://github.com/Jittor/Jittor
add cross & normalize
This commit is contained in:
parent
1e32486f59
commit
87b1933447
|
@ -17,9 +17,9 @@ def repeat(x, *shape):
|
|||
|
||||
Args:
|
||||
|
||||
[in] x (var): jittor var.
|
||||
x (var): jittor var.
|
||||
|
||||
[in] shape (tuple): int or tuple. The number of times to repeat this var along each dimension.
|
||||
shape (tuple): int or tuple. The number of times to repeat this var along each dimension.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -59,11 +59,11 @@ def chunk(x, chunks, dim=0):
|
|||
|
||||
Args:
|
||||
|
||||
[in] input (var) – the var to split.
|
||||
input (var) – the var to split.
|
||||
|
||||
[in] chunks (int) – number of chunks to return.
|
||||
chunks (int) – number of chunks to return.
|
||||
|
||||
[in] dim (int) – dimension along which to split the var.
|
||||
dim (int) – dimension along which to split the var.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -96,9 +96,9 @@ def stack(x, dim=0):
|
|||
|
||||
Args:
|
||||
|
||||
[in] x (sequence of vars) – sequence of vars to concatenate.
|
||||
x (sequence of vars) – sequence of vars to concatenate.
|
||||
|
||||
[in] dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).
|
||||
dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive).
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -122,9 +122,9 @@ def flip(x, dim=0):
|
|||
|
||||
Args:
|
||||
|
||||
[in] input (var) – the input var.
|
||||
input (var) – the input var.
|
||||
|
||||
[in] dims (a list or tuple) – axis to flip on.
|
||||
dims (a list or tuple) – axis to flip on.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -141,4 +141,92 @@ def flip(x, dim=0):
|
|||
else:
|
||||
tar_dims.append(f"i{i}")
|
||||
return x.reindex(x.shape, tar_dims)
|
||||
jt.Var.flip = flip
|
||||
jt.Var.flip = flip
|
||||
|
||||
def cross(input, other, dim=-1):
|
||||
r'''
|
||||
Returns the cross product of vectors in dimension dim of input and other.
|
||||
|
||||
the cross product can be calculated by (a1,a2,a3) x (b1,b2,b3) = (a2b3-a3b2, a3b1-a1b3, a1b2-a2b1)
|
||||
|
||||
input and other must have the same size, and the size of their dim dimension should be 3.
|
||||
|
||||
If dim is not given, it defaults to the first dimension found with the size 3.
|
||||
|
||||
Args:
|
||||
|
||||
input (Tensor) – the input tensor.
|
||||
|
||||
other (Tensor) – the second input tensor
|
||||
|
||||
dim (int, optional) – the dimension to take the cross-product in.
|
||||
|
||||
out (Tensor, optional) – the output tensor.
|
||||
|
||||
Example:
|
||||
|
||||
>>> input = jt.random((6,3))
|
||||
|
||||
>>> other = jt.random((6,3))
|
||||
|
||||
>>> jt.cross(input, other, dim=1)
|
||||
[[-0.42732686 0.6827885 -0.49206433]
|
||||
[ 0.4651107 0.27036983 -0.5580432 ]
|
||||
[-0.31933784 0.10543461 0.09676848]
|
||||
[-0.58346975 -0.21417202 0.55176204]
|
||||
[-0.40861478 0.01496297 0.38638002]
|
||||
[ 0.18393655 -0.04907863 -0.17928357]]
|
||||
|
||||
>>> jt.cross(input, other)
|
||||
[[-0.42732686 0.6827885 -0.49206433]
|
||||
[ 0.4651107 0.27036983 -0.5580432 ]
|
||||
[-0.31933784 0.10543461 0.09676848]
|
||||
[-0.58346975 -0.21417202 0.55176204]
|
||||
[-0.40861478 0.01496297 0.38638002]
|
||||
[ 0.18393655 -0.04907863 -0.17928357]]
|
||||
'''
|
||||
assert input.shape==other.shape, "input shape and other shape must be same"
|
||||
if dim < 0: dim += len(input.shape)
|
||||
assert input.shape[dim] == 3, "input dim shape must be 3"
|
||||
a1 = input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(2,)]-input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(1,)]
|
||||
a2 = input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(0,)]-input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(2,)]
|
||||
a3 = input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(1,)]-input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(0,)]
|
||||
return jt.contrib.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
|
||||
jt.Var.cross = cross
|
||||
|
||||
def normalize(input, p=2, dim=1, eps=1e-12):
|
||||
r'''
|
||||
Performs L_p normalization of inputs over specified dimension.
|
||||
|
||||
Args:
|
||||
|
||||
input – input array of any shape
|
||||
|
||||
p (float) – the exponent value in the norm formulation. Default: 2
|
||||
|
||||
dim (int) – the dimension to reduce. Default: 1
|
||||
|
||||
eps (float) – small value to avoid division by zero. Default: 1e-12
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = jt.random((6,3))
|
||||
[[0.18777736 0.9739261 0.77647036]
|
||||
[0.13710196 0.27282116 0.30533272]
|
||||
[0.7272278 0.5174613 0.9719775 ]
|
||||
[0.02566639 0.37504175 0.32676998]
|
||||
[0.0231761 0.5207773 0.70337296]
|
||||
[0.58966476 0.49547017 0.36724383]]
|
||||
|
||||
>>> jt.normalize(x)
|
||||
[[0.14907198 0.7731768 0.61642134]
|
||||
[0.31750825 0.63181424 0.7071063 ]
|
||||
[0.5510936 0.39213243 0.736565 ]
|
||||
[0.05152962 0.7529597 0.656046 ]
|
||||
[0.02647221 0.59484214 0.80340654]
|
||||
[0.6910677 0.58067477 0.4303977 ]]
|
||||
'''
|
||||
assert p == 2
|
||||
if p == 2:
|
||||
return input / jt.maximum(input.sqr().sum(dim,True).sqrt(), eps)
|
||||
jt.Var.normalize = normalize
|
|
@ -23,8 +23,8 @@ except:
|
|||
tnn = None
|
||||
skip_this_test = True
|
||||
|
||||
def check_equal(res1, res2):
|
||||
assert np.allclose(res1.detach().numpy(), res2.numpy())
|
||||
def check_equal(res1, res2, eps=1e-5):
|
||||
assert np.allclose(res1.detach().numpy(), res2.numpy(), eps)
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestPad(unittest.TestCase):
|
||||
|
@ -54,5 +54,24 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
|
||||
print('pass flip test ...')
|
||||
|
||||
def test_cross(self):
|
||||
arr1 = np.random.randn(16,3,224,224,3)
|
||||
arr2 = np.random.randn(16,3,224,224,3)
|
||||
check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1)
|
||||
check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=-4), jt.array(arr1).cross(jt.array(arr2), dim=-4), 1e-1)
|
||||
check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=-1), jt.array(arr1).cross(jt.array(arr2), dim=-1), 1e-1)
|
||||
check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=4), jt.array(arr1).cross(jt.array(arr2), dim=4), 1e-1)
|
||||
print('pass cross test ...')
|
||||
|
||||
def test_normalize(self):
|
||||
arr = np.random.randn(16,3,224,224,3)
|
||||
check_equal(tnn.functional.normalize(torch.Tensor(arr)), jt.normalize(jt.array(arr)))
|
||||
check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=0), jt.normalize(jt.array(arr), dim=0), 1e-1)
|
||||
check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=1), jt.normalize(jt.array(arr), dim=1), 1e-1)
|
||||
check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=-1), jt.normalize(jt.array(arr), dim=-1), 1e-1)
|
||||
check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=2), jt.normalize(jt.array(arr), dim=2), 1e-1)
|
||||
check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=3), jt.normalize(jt.array(arr), dim=3), 1e-1)
|
||||
print('pass normalize test ...')
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue