add triu and tril function

This commit is contained in:
Dun Liang 2022-12-04 20:47:38 +08:00
parent 44fdc718ab
commit 499d3ee99c
3 changed files with 81 additions and 8 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.42'
__version__ = '1.3.5.43'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -443,7 +443,7 @@ def liveness_info():
"lived_ops": core.number_of_lived_ops(),
}
def ones(shape, dtype="float32"):
def ones(*shape, dtype="float32"):
''' Constructs a jittor Var with all elements set to 1.
:param shape: The shape of the output Var.
@ -453,8 +453,8 @@ def ones(shape, dtype="float32"):
:return: The output Var.
:rtype: jittor.Var
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
return unary(1, dtype).broadcast(shape)
def ones_like(x):
@ -467,7 +467,7 @@ def ones_like(x):
'''
return ones(x.shape,x.dtype)
def zeros(shape, dtype="float32"):
def zeros(*shape, dtype="float32"):
''' Constructs a jittor Var with all elements set to 0.
:param shape: The shape of the output Var.
@ -477,8 +477,8 @@ def zeros(shape, dtype="float32"):
:return: The output Var.
:rtype: jittor.Var
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
return unary(0, dtype).broadcast(shape)
def full(shape,val,dtype="float32"):

View File

@ -1991,3 +1991,55 @@ def from_torch(x):
Convert torch Tensor to Jittor Var
'''
return jt.Var(x.cpu().numpy())
def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
''' Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
:param input: the input tensor.
:param diagonal: the diagonal to consider(int).
Example::
a = jt.ones(3, 3)
b = jt.triu(a)
assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]])
b = jt.triu(a, diagonal=1)
assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]])
b = jt.triu(a, diagonal=-1)
assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]])
'''
index = input.index()
mask = index[-2] <= index[-1] - diagonal
return input*mask
jt.Var.triu = triu
def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
:param input: the input tensor.
:param diagonal: the diagonal to consider(int).
Example::
a = jt.ones(3, 3)
b = jt.tril(a)
assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]])
b = jt.tril(a, diagonal=1)
assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]])
b = jt.tril(a, diagonal=-1)
assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]])
'''
index = input.index()
mask = index[-2] >= index[-1] - diagonal
return input*mask
jt.Var.tril = tril
def all_equal(a: jt.Var, b: jt.Var) -> bool:
return (a == b).all().item()
jt.all_equal = all_equal

View File

@ -343,5 +343,26 @@ class TestOther(unittest.TestCase):
output = m(input)
output.sync()
def test_tri(self):
a = jt.ones(3, 3)
b = jt.triu(a)
assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]])
b = jt.triu(a, diagonal=1)
assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]])
b = jt.triu(a, diagonal=-1)
assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]])
a = jt.ones(3, 3)
b = jt.tril(a)
assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]])
b = jt.tril(a, diagonal=1)
assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]])
b = jt.tril(a, diagonal=-1)
assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]])
if __name__ == "__main__":
unittest.main()