From aaf97d5f58a3fed8aa2f93be132a9ed9be8dd4a5 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Fri, 25 Jun 2021 14:35:32 +0800 Subject: [PATCH] add ce loss test --- python/jittor/__init__.py | 2 +- python/jittor/test/test_loss.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 2337427c..492314f8 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.3.44' +__version__ = '1.2.3.45' from jittor_utils import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/test/test_loss.py b/python/jittor/test/test_loss.py index 99a2c80c..e72a231b 100644 --- a/python/jittor/test/test_loss.py +++ b/python/jittor/test/test_loss.py @@ -85,6 +85,17 @@ class TestLoss(unittest.TestCase): tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) assert np.allclose(jt_y.numpy(), tc_y.numpy()) + def test_cross_entropy_weight_ignore(self): + weight = np.random.rand(4).astype('float32') + jt_loss = jnn.CrossEntropyLoss(weight=jt.array(weight), ignore_index=1) + tc_loss = tnn.CrossEntropyLoss(weight=torch.from_numpy(weight), ignore_index=1) + output = np.random.rand(32, 4, 512, 512).astype(np.float32) + target = np.random.randint(4, size=(32, 512, 512)) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_bce_loss(self): jt_loss=jnn.BCELoss() tc_loss=tnn.BCELoss()