mirror of https://github.com/Jittor/Jittor
add triu and tril function
This commit is contained in:
parent
44fdc718ab
commit
499d3ee99c
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.42'
|
||||
__version__ = '1.3.5.43'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -443,7 +443,7 @@ def liveness_info():
|
|||
"lived_ops": core.number_of_lived_ops(),
|
||||
}
|
||||
|
||||
def ones(shape, dtype="float32"):
|
||||
def ones(*shape, dtype="float32"):
|
||||
''' Constructs a jittor Var with all elements set to 1.
|
||||
|
||||
:param shape: The shape of the output Var.
|
||||
|
@ -453,8 +453,8 @@ def ones(shape, dtype="float32"):
|
|||
:return: The output Var.
|
||||
:rtype: jittor.Var
|
||||
'''
|
||||
if not isinstance(shape, (NanoVector, Sequence)):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
|
||||
shape = shape[0]
|
||||
return unary(1, dtype).broadcast(shape)
|
||||
|
||||
def ones_like(x):
|
||||
|
@ -467,7 +467,7 @@ def ones_like(x):
|
|||
'''
|
||||
return ones(x.shape,x.dtype)
|
||||
|
||||
def zeros(shape, dtype="float32"):
|
||||
def zeros(*shape, dtype="float32"):
|
||||
''' Constructs a jittor Var with all elements set to 0.
|
||||
|
||||
:param shape: The shape of the output Var.
|
||||
|
@ -477,8 +477,8 @@ def zeros(shape, dtype="float32"):
|
|||
:return: The output Var.
|
||||
:rtype: jittor.Var
|
||||
'''
|
||||
if not isinstance(shape, (NanoVector, Sequence)):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
|
||||
shape = shape[0]
|
||||
return unary(0, dtype).broadcast(shape)
|
||||
|
||||
def full(shape,val,dtype="float32"):
|
||||
|
|
|
@ -1991,3 +1991,55 @@ def from_torch(x):
|
|||
Convert torch Tensor to Jittor Var
|
||||
'''
|
||||
return jt.Var(x.cpu().numpy())
|
||||
|
||||
def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
|
||||
''' Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
|
||||
|
||||
:param input: the input tensor.
|
||||
:param diagonal: the diagonal to consider(int).
|
||||
|
||||
Example::
|
||||
|
||||
a = jt.ones(3, 3)
|
||||
b = jt.triu(a)
|
||||
assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]])
|
||||
|
||||
b = jt.triu(a, diagonal=1)
|
||||
assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]])
|
||||
|
||||
b = jt.triu(a, diagonal=-1)
|
||||
assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]])
|
||||
|
||||
'''
|
||||
index = input.index()
|
||||
mask = index[-2] <= index[-1] - diagonal
|
||||
return input*mask
|
||||
jt.Var.triu = triu
|
||||
|
||||
def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
|
||||
''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
|
||||
|
||||
:param input: the input tensor.
|
||||
:param diagonal: the diagonal to consider(int).
|
||||
|
||||
Example::
|
||||
|
||||
a = jt.ones(3, 3)
|
||||
b = jt.tril(a)
|
||||
assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]])
|
||||
|
||||
b = jt.tril(a, diagonal=1)
|
||||
assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]])
|
||||
|
||||
b = jt.tril(a, diagonal=-1)
|
||||
assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]])
|
||||
|
||||
'''
|
||||
index = input.index()
|
||||
mask = index[-2] >= index[-1] - diagonal
|
||||
return input*mask
|
||||
jt.Var.tril = tril
|
||||
|
||||
def all_equal(a: jt.Var, b: jt.Var) -> bool:
|
||||
return (a == b).all().item()
|
||||
jt.all_equal = all_equal
|
||||
|
|
|
@ -343,5 +343,26 @@ class TestOther(unittest.TestCase):
|
|||
output = m(input)
|
||||
output.sync()
|
||||
|
||||
def test_tri(self):
|
||||
a = jt.ones(3, 3)
|
||||
b = jt.triu(a)
|
||||
assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]])
|
||||
|
||||
b = jt.triu(a, diagonal=1)
|
||||
assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]])
|
||||
|
||||
b = jt.triu(a, diagonal=-1)
|
||||
assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]])
|
||||
|
||||
a = jt.ones(3, 3)
|
||||
b = jt.tril(a)
|
||||
assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]])
|
||||
|
||||
b = jt.tril(a, diagonal=1)
|
||||
assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]])
|
||||
|
||||
b = jt.tril(a, diagonal=-1)
|
||||
assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue