From 499d3ee99c3f7b2e79b95b9d2de2c5e2ddabf778 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Sun, 4 Dec 2022 20:47:38 +0800 Subject: [PATCH] add triu and tril function --- python/jittor/__init__.py | 14 ++++---- python/jittor/misc.py | 54 +++++++++++++++++++++++++++++- python/jittor/test/test_misc_op.py | 21 ++++++++++++ 3 files changed, 81 insertions(+), 8 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 1100ced6..68c75d0b 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.5.42' +__version__ = '1.3.5.43' from jittor_utils import lock with lock.lock_scope(): ori_int = int @@ -443,7 +443,7 @@ def liveness_info(): "lived_ops": core.number_of_lived_ops(), } -def ones(shape, dtype="float32"): +def ones(*shape, dtype="float32"): ''' Constructs a jittor Var with all elements set to 1. :param shape: The shape of the output Var. @@ -453,8 +453,8 @@ def ones(shape, dtype="float32"): :return: The output Var. :rtype: jittor.Var ''' - if not isinstance(shape, (NanoVector, Sequence)): - shape = (shape,) + if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): + shape = shape[0] return unary(1, dtype).broadcast(shape) def ones_like(x): @@ -467,7 +467,7 @@ def ones_like(x): ''' return ones(x.shape,x.dtype) -def zeros(shape, dtype="float32"): +def zeros(*shape, dtype="float32"): ''' Constructs a jittor Var with all elements set to 0. :param shape: The shape of the output Var. @@ -477,8 +477,8 @@ def zeros(shape, dtype="float32"): :return: The output Var. :rtype: jittor.Var ''' - if not isinstance(shape, (NanoVector, Sequence)): - shape = (shape,) + if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): + shape = shape[0] return unary(0, dtype).broadcast(shape) def full(shape,val,dtype="float32"): diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 5cc0b53e..9d397e28 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -1990,4 +1990,56 @@ def from_torch(x): ''' Convert torch Tensor to Jittor Var ''' - return jt.Var(x.cpu().numpy()) \ No newline at end of file + return jt.Var(x.cpu().numpy()) + +def triu(input: jt.Var, diagonal:int=0) -> jt.Var: + ''' Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. + + :param input: the input tensor. + :param diagonal: the diagonal to consider(int). + + Example:: + + a = jt.ones(3, 3) + b = jt.triu(a) + assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]]) + + b = jt.triu(a, diagonal=1) + assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]]) + + b = jt.triu(a, diagonal=-1) + assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]]) + + ''' + index = input.index() + mask = index[-2] <= index[-1] - diagonal + return input*mask +jt.Var.triu = triu + +def tril(input: jt.Var, diagonal:int=0) -> jt.Var: + ''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. + + :param input: the input tensor. + :param diagonal: the diagonal to consider(int). + + Example:: + + a = jt.ones(3, 3) + b = jt.tril(a) + assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]]) + + b = jt.tril(a, diagonal=1) + assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]]) + + b = jt.tril(a, diagonal=-1) + assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]]) + + ''' + index = input.index() + mask = index[-2] >= index[-1] - diagonal + return input*mask +jt.Var.tril = tril + +def all_equal(a: jt.Var, b: jt.Var) -> bool: + return (a == b).all().item() +jt.all_equal = all_equal diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py index 784971f4..61a5a4e9 100644 --- a/python/jittor/test/test_misc_op.py +++ b/python/jittor/test/test_misc_op.py @@ -343,5 +343,26 @@ class TestOther(unittest.TestCase): output = m(input) output.sync() + def test_tri(self): + a = jt.ones(3, 3) + b = jt.triu(a) + assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]]) + + b = jt.triu(a, diagonal=1) + assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]]) + + b = jt.triu(a, diagonal=-1) + assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]]) + + a = jt.ones(3, 3) + b = jt.tril(a) + assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]]) + + b = jt.tril(a, diagonal=1) + assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]]) + + b = jt.tril(a, diagonal=-1) + assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]]) + if __name__ == "__main__": unittest.main() \ No newline at end of file