This commit is contained in:
Dun Liang 2021-06-24 21:39:43 +08:00
parent 385d60a147
commit 1246c37692
2 changed files with 2 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.42'
__version__ = '1.2.3.43'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -195,6 +195,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None):
output = output.transpose((0, 2, 3, 1))
output = output.reshape((-1, c_dim))
target = target.reshape((-1, ))
target_weight = jt.ones(target.shape[0], dtype='float32')
if weight is not None:
target_weight = weight[target]
@ -205,7 +206,6 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None):
target_weight
)
target = target.reshape((-1, ))
target = target.broadcast(output, [1])
target = target.index(1) == target