mirror of https://github.com/Jittor/Jittor
add ce loss test
This commit is contained in:
parent
ef66d6d832
commit
aaf97d5f58
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.44'
|
__version__ = '1.2.3.45'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -85,6 +85,17 @@ class TestLoss(unittest.TestCase):
|
||||||
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||||
|
|
||||||
|
def test_cross_entropy_weight_ignore(self):
|
||||||
|
weight = np.random.rand(4).astype('float32')
|
||||||
|
jt_loss = jnn.CrossEntropyLoss(weight=jt.array(weight), ignore_index=1)
|
||||||
|
tc_loss = tnn.CrossEntropyLoss(weight=torch.from_numpy(weight), ignore_index=1)
|
||||||
|
output = np.random.rand(32, 4, 512, 512).astype(np.float32)
|
||||||
|
target = np.random.randint(4, size=(32, 512, 512))
|
||||||
|
jt_y=jt_loss(jt.array(output), jt.array(target))
|
||||||
|
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||||
|
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||||
|
|
||||||
|
|
||||||
def test_bce_loss(self):
|
def test_bce_loss(self):
|
||||||
jt_loss=jnn.BCELoss()
|
jt_loss=jnn.BCELoss()
|
||||||
tc_loss=tnn.BCELoss()
|
tc_loss=tnn.BCELoss()
|
||||||
|
|
Loading…
Reference in New Issue