mirror of https://github.com/Jittor/Jittor
fix ce
This commit is contained in:
parent
385d60a147
commit
1246c37692
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.42'
|
||||
__version__ = '1.2.3.43'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -195,6 +195,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None):
|
|||
output = output.transpose((0, 2, 3, 1))
|
||||
output = output.reshape((-1, c_dim))
|
||||
|
||||
target = target.reshape((-1, ))
|
||||
target_weight = jt.ones(target.shape[0], dtype='float32')
|
||||
if weight is not None:
|
||||
target_weight = weight[target]
|
||||
|
@ -205,7 +206,6 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None):
|
|||
target_weight
|
||||
)
|
||||
|
||||
target = target.reshape((-1, ))
|
||||
target = target.broadcast(output, [1])
|
||||
target = target.index(1) == target
|
||||
|
||||
|
|
Loading…
Reference in New Issue