mirror of https://github.com/Jittor/Jittor
add triu and tril function
This commit is contained in:
parent
44fdc718ab
commit
499d3ee99c
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.5.42'
|
__version__ = '1.3.5.43'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
@ -443,7 +443,7 @@ def liveness_info():
|
||||||
"lived_ops": core.number_of_lived_ops(),
|
"lived_ops": core.number_of_lived_ops(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def ones(shape, dtype="float32"):
|
def ones(*shape, dtype="float32"):
|
||||||
''' Constructs a jittor Var with all elements set to 1.
|
''' Constructs a jittor Var with all elements set to 1.
|
||||||
|
|
||||||
:param shape: The shape of the output Var.
|
:param shape: The shape of the output Var.
|
||||||
|
@ -453,8 +453,8 @@ def ones(shape, dtype="float32"):
|
||||||
:return: The output Var.
|
:return: The output Var.
|
||||||
:rtype: jittor.Var
|
:rtype: jittor.Var
|
||||||
'''
|
'''
|
||||||
if not isinstance(shape, (NanoVector, Sequence)):
|
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
|
||||||
shape = (shape,)
|
shape = shape[0]
|
||||||
return unary(1, dtype).broadcast(shape)
|
return unary(1, dtype).broadcast(shape)
|
||||||
|
|
||||||
def ones_like(x):
|
def ones_like(x):
|
||||||
|
@ -467,7 +467,7 @@ def ones_like(x):
|
||||||
'''
|
'''
|
||||||
return ones(x.shape,x.dtype)
|
return ones(x.shape,x.dtype)
|
||||||
|
|
||||||
def zeros(shape, dtype="float32"):
|
def zeros(*shape, dtype="float32"):
|
||||||
''' Constructs a jittor Var with all elements set to 0.
|
''' Constructs a jittor Var with all elements set to 0.
|
||||||
|
|
||||||
:param shape: The shape of the output Var.
|
:param shape: The shape of the output Var.
|
||||||
|
@ -477,8 +477,8 @@ def zeros(shape, dtype="float32"):
|
||||||
:return: The output Var.
|
:return: The output Var.
|
||||||
:rtype: jittor.Var
|
:rtype: jittor.Var
|
||||||
'''
|
'''
|
||||||
if not isinstance(shape, (NanoVector, Sequence)):
|
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
|
||||||
shape = (shape,)
|
shape = shape[0]
|
||||||
return unary(0, dtype).broadcast(shape)
|
return unary(0, dtype).broadcast(shape)
|
||||||
|
|
||||||
def full(shape,val,dtype="float32"):
|
def full(shape,val,dtype="float32"):
|
||||||
|
|
|
@ -1990,4 +1990,56 @@ def from_torch(x):
|
||||||
'''
|
'''
|
||||||
Convert torch Tensor to Jittor Var
|
Convert torch Tensor to Jittor Var
|
||||||
'''
|
'''
|
||||||
return jt.Var(x.cpu().numpy())
|
return jt.Var(x.cpu().numpy())
|
||||||
|
|
||||||
|
def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
|
||||||
|
''' Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
|
||||||
|
|
||||||
|
:param input: the input tensor.
|
||||||
|
:param diagonal: the diagonal to consider(int).
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
a = jt.ones(3, 3)
|
||||||
|
b = jt.triu(a)
|
||||||
|
assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]])
|
||||||
|
|
||||||
|
b = jt.triu(a, diagonal=1)
|
||||||
|
assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]])
|
||||||
|
|
||||||
|
b = jt.triu(a, diagonal=-1)
|
||||||
|
assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]])
|
||||||
|
|
||||||
|
'''
|
||||||
|
index = input.index()
|
||||||
|
mask = index[-2] <= index[-1] - diagonal
|
||||||
|
return input*mask
|
||||||
|
jt.Var.triu = triu
|
||||||
|
|
||||||
|
def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
|
||||||
|
''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
|
||||||
|
|
||||||
|
:param input: the input tensor.
|
||||||
|
:param diagonal: the diagonal to consider(int).
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
a = jt.ones(3, 3)
|
||||||
|
b = jt.tril(a)
|
||||||
|
assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]])
|
||||||
|
|
||||||
|
b = jt.tril(a, diagonal=1)
|
||||||
|
assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]])
|
||||||
|
|
||||||
|
b = jt.tril(a, diagonal=-1)
|
||||||
|
assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]])
|
||||||
|
|
||||||
|
'''
|
||||||
|
index = input.index()
|
||||||
|
mask = index[-2] >= index[-1] - diagonal
|
||||||
|
return input*mask
|
||||||
|
jt.Var.tril = tril
|
||||||
|
|
||||||
|
def all_equal(a: jt.Var, b: jt.Var) -> bool:
|
||||||
|
return (a == b).all().item()
|
||||||
|
jt.all_equal = all_equal
|
||||||
|
|
|
@ -343,5 +343,26 @@ class TestOther(unittest.TestCase):
|
||||||
output = m(input)
|
output = m(input)
|
||||||
output.sync()
|
output.sync()
|
||||||
|
|
||||||
|
def test_tri(self):
|
||||||
|
a = jt.ones(3, 3)
|
||||||
|
b = jt.triu(a)
|
||||||
|
assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]])
|
||||||
|
|
||||||
|
b = jt.triu(a, diagonal=1)
|
||||||
|
assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]])
|
||||||
|
|
||||||
|
b = jt.triu(a, diagonal=-1)
|
||||||
|
assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]])
|
||||||
|
|
||||||
|
a = jt.ones(3, 3)
|
||||||
|
b = jt.tril(a)
|
||||||
|
assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]])
|
||||||
|
|
||||||
|
b = jt.tril(a, diagonal=1)
|
||||||
|
assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]])
|
||||||
|
|
||||||
|
b = jt.tril(a, diagonal=-1)
|
||||||
|
assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
Loading…
Reference in New Issue