mirror of https://github.com/Jittor/Jittor
fix ce
This commit is contained in:
parent
385d60a147
commit
1246c37692
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.42'
|
__version__ = '1.2.3.43'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -195,6 +195,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None):
|
||||||
output = output.transpose((0, 2, 3, 1))
|
output = output.transpose((0, 2, 3, 1))
|
||||||
output = output.reshape((-1, c_dim))
|
output = output.reshape((-1, c_dim))
|
||||||
|
|
||||||
|
target = target.reshape((-1, ))
|
||||||
target_weight = jt.ones(target.shape[0], dtype='float32')
|
target_weight = jt.ones(target.shape[0], dtype='float32')
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
target_weight = weight[target]
|
target_weight = weight[target]
|
||||||
|
@ -205,7 +206,6 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None):
|
||||||
target_weight
|
target_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
target = target.reshape((-1, ))
|
|
||||||
target = target.broadcast(output, [1])
|
target = target.broadcast(output, [1])
|
||||||
target = target.index(1) == target
|
target = target.index(1) == target
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue