support build triu/tril of an inf matrix

The original implementation results in nan because zeros * inf = nan.
Suggest to use jt.ternary op instead of multiply a mask matrix.
This commit is contained in:
lzhengning 2023-03-22 13:16:04 +08:00
parent 576b0c9e03
commit 3d40091693
2 changed files with 3 additions and 3 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.6.12'
__version__ = '1.3.6.13'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -2013,7 +2013,7 @@ def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
'''
index = input.index()
mask = index[-2] <= index[-1] - diagonal
return input*mask
return jt.ternary(mask, input, jt.zeros_like(input))
jt.Var.triu = triu
def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
@ -2037,7 +2037,7 @@ def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
'''
index = input.index()
mask = index[-2] >= index[-1] - diagonal
return input*mask
return jt.ternary(mask, input, jt.zeros_like(input))
jt.Var.tril = tril
def all_equal(a: jt.Var, b: jt.Var) -> bool: