add ce loss test

This commit is contained in:
Dun Liang 2021-06-25 14:35:32 +08:00
parent ef66d6d832
commit aaf97d5f58
2 changed files with 12 additions and 1 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.3.44' __version__ = '1.2.3.45'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -85,6 +85,17 @@ class TestLoss(unittest.TestCase):
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
assert np.allclose(jt_y.numpy(), tc_y.numpy()) assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_cross_entropy_weight_ignore(self):
weight = np.random.rand(4).astype('float32')
jt_loss = jnn.CrossEntropyLoss(weight=jt.array(weight), ignore_index=1)
tc_loss = tnn.CrossEntropyLoss(weight=torch.from_numpy(weight), ignore_index=1)
output = np.random.rand(32, 4, 512, 512).astype(np.float32)
target = np.random.randint(4, size=(32, 512, 512))
jt_y=jt_loss(jt.array(output), jt.array(target))
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_bce_loss(self): def test_bce_loss(self):
jt_loss=jnn.BCELoss() jt_loss=jnn.BCELoss()
tc_loss=tnn.BCELoss() tc_loss=tnn.BCELoss()