mirror of https://github.com/Jittor/Jittor
support build triu/tril of an inf matrix
The original implementation results in nan because zeros * inf = nan. Suggest to use jt.ternary op instead of multiply a mask matrix.
This commit is contained in:
parent
576b0c9e03
commit
3d40091693
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.6.12'
|
||||
__version__ = '1.3.6.13'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -2013,7 +2013,7 @@ def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
|
|||
'''
|
||||
index = input.index()
|
||||
mask = index[-2] <= index[-1] - diagonal
|
||||
return input*mask
|
||||
return jt.ternary(mask, input, jt.zeros_like(input))
|
||||
jt.Var.triu = triu
|
||||
|
||||
def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
|
||||
|
@ -2037,7 +2037,7 @@ def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
|
|||
'''
|
||||
index = input.index()
|
||||
mask = index[-2] >= index[-1] - diagonal
|
||||
return input*mask
|
||||
return jt.ternary(mask, input, jt.zeros_like(input))
|
||||
jt.Var.tril = tril
|
||||
|
||||
def all_equal(a: jt.Var, b: jt.Var) -> bool:
|
||||
|
|
Loading…
Reference in New Issue