This commit is contained in:
zwy 2020-04-30 21:31:54 +08:00
commit 5f88c2e5c1
4 changed files with 176 additions and 0 deletions

View File

@@ -341,6 +341,24 @@ def detach(x):
return x.clone().stop_grad().clone()
# Monkey-patch: expose detach as a method on jittor's Var type (x.detach()).
Var.detach = detach
def std(x):
    """Sample standard deviation over all elements of x.

    Uses Bessel's correction (divides by N-1), matching torch.Tensor.std's
    default unbiased behavior.
    """
    # Total element count across every dimension.
    n = 1
    for d in x.shape:
        n *= d
    centered = x - x.mean()
    variance = centered.sqr().sum() / (n - 1)
    return variance.sqrt()
# Monkey-patch: expose std as a method on jittor's Var type (x.std()).
Var.std = std
def norm(x, k, dim):
    """Return the k-norm of x along dimension dim.

    Only k=1 (sum of absolute values) and k=2 (Euclidean norm) are
    supported, mirroring the torch.Tensor.norm calls tested elsewhere.

    Raises:
        ValueError: if k is not 1 or 2.
    """
    if k == 1:
        return x.abs().sum(dim)
    if k == 2:
        return (x ** 2).sum(dim).sqrt()
    # Explicit error instead of `assert`: asserts are stripped under
    # `python -O`, which would silently return None for unsupported k.
    raise ValueError(f"norm only supports k=1 or k=2, got {k!r}")
# Monkey-patch: expose norm as a method on jittor's Var type (x.norm(k, dim)).
Var.norm = norm
origin_reshape = reshape
def reshape(x, *shape):
if len(shape) == 1 and isinstance(shape[0], Sequence):

View File

@@ -112,6 +112,48 @@ def cross_entropy_loss(output, target, ignore_index=None):
else:
return loss.sum() / jt.maximum(mask.int().sum(), 1)
def mse_loss(output, target):
    """Mean squared error between output and target."""
    residual = output - target
    return residual.sqr().mean()
def bce_loss(output, target):
    """Binary cross entropy between probabilities `output` and labels `target`.

    Inputs to the log are clamped at 1e-20 to avoid log(0).
    """
    pos = target * jt.log(jt.maximum(output, 1e-20))
    neg = (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))
    return - (pos + neg).mean()
def l1_loss(output, target):
    """Mean absolute error (L1 loss) between output and target."""
    residual = output - target
    return residual.abs().mean()
class CrossEntropyLoss(Module):
    """Module wrapper around the functional cross_entropy_loss."""

    def __init__(self):
        pass

    def execute(self, output, target):
        loss = cross_entropy_loss(output, target)
        return loss
class MSELoss(Module):
    """Module wrapper around the functional mse_loss."""

    def __init__(self):
        pass

    def execute(self, output, target):
        loss = mse_loss(output, target)
        return loss
class BCELoss(Module):
    """Module wrapper around the functional bce_loss."""

    def __init__(self):
        pass

    def execute(self, output, target):
        loss = bce_loss(output, target)
        return loss
class L1Loss(Module):
    """Module wrapper around the functional l1_loss."""

    def __init__(self):
        pass

    def execute(self, output, target):
        loss = l1_loss(output, target)
        return loss
class BCEWithLogitsLoss(Module):
    """Sigmoid activation followed by binary cross entropy, as one module."""

    def __init__(self):
        self.sigmoid = Sigmoid()
        self.bce = BCELoss()

    def execute(self, output, target):
        probs = self.sigmoid(output)
        return self.bce(probs, target)
class SGD(object):
""" Usage:
optimizer = nn.SGD(model.parameters(), lr)

View File

@@ -0,0 +1,74 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import os
import numpy as np
import jittor.nn as jnn
from jittor.test.test_log import find_log_with_re
# Torch is optional: these tests compare against PyTorch, so skip the whole
# module when it is unavailable or fails to load.
skip_this_test = False
try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are not
    # swallowed; any real import/runtime failure still just skips the tests.
    skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestLoss(unittest.TestCase):
    """Compare Jittor loss modules with their PyTorch counterparts on random data."""

    def test_l1_loss(self):
        jt_loss = jnn.L1Loss()
        tc_loss = tnn.L1Loss()
        pred = np.random.randn(10, 100).astype(np.float32)
        gt = np.random.randn(10, 100).astype(np.float32)
        jt_y = jt_loss(jt.array(pred), jt.array(gt))
        tc_y = tc_loss(torch.from_numpy(pred), torch.from_numpy(gt))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_mse_loss(self):
        jt_loss = jnn.MSELoss()
        tc_loss = tnn.MSELoss()
        pred = np.random.randn(10, 100).astype(np.float32)
        gt = np.random.randn(10, 100).astype(np.float32)
        jt_y = jt_loss(jt.array(pred), jt.array(gt))
        tc_y = tc_loss(torch.from_numpy(pred), torch.from_numpy(gt))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_cross_entropy_loss(self):
        jt_loss = jnn.CrossEntropyLoss()
        tc_loss = tnn.CrossEntropyLoss()
        logits = np.random.randn(10, 10).astype(np.float32)
        labels = np.random.randint(10, size=(10,))
        jt_y = jt_loss(jt.array(logits), jt.array(labels))
        tc_y = tc_loss(torch.from_numpy(logits), torch.from_numpy(labels))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_bce_loss(self):
        jt_loss = jnn.BCELoss()
        tc_loss = tnn.BCELoss()
        jt_sig = jnn.Sigmoid()
        tc_sig = tnn.Sigmoid()
        logits = np.random.randn(100).astype(np.float32)
        labels = np.random.randint(2, size=(100,)).astype(np.float32)
        jt_y = jt_loss(jt_sig(jt.array(logits)), jt.array(labels))
        tc_y = tc_loss(tc_sig(torch.from_numpy(logits)), torch.from_numpy(labels))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_bce_with_logits_loss(self):
        jt_loss = jnn.BCEWithLogitsLoss()
        tc_loss = tnn.BCEWithLogitsLoss()
        logits = np.random.randn(100).astype(np.float32)
        labels = np.random.randint(2, size=(100,)).astype(np.float32)
        jt_y = jt_loss(jt.array(logits), jt.array(labels))
        tc_y = tc_loss(torch.from_numpy(logits), torch.from_numpy(labels))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())
# Allow running this test file directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,42 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import os
import numpy as np
import jittor.nn as jnn
from jittor.test.test_log import find_log_with_re
# Torch is optional: these tests compare against PyTorch, so skip the whole
# module when it is unavailable or fails to load.
skip_this_test = False
try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are not
    # swallowed; any real import/runtime failure still just skips the tests.
    skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestStd(unittest.TestCase):
    """Compare the Var.std / Var.norm monkey-patches against torch."""

    def test_std(self):
        data = np.random.randn(100, 1000).astype(np.float32)
        jt_x = jt.array(data)
        tc_x = torch.from_numpy(data)
        jt_std = jt_x.std().numpy()
        tc_std = tc_x.std().numpy()
        assert np.allclose(jt_std, tc_std, 1e-4), (data, jt_std, tc_std)

    def test_norm(self):
        data = np.random.randn(100, 1000).astype(np.float32)
        jt_x = jt.array(data)
        tc_x = torch.from_numpy(data)
        # Exercise both supported norms along both axes.
        for k in (1, 2):
            for dim in (1, 0):
                assert np.allclose(jt_x.norm(k, dim).numpy(),
                                   tc_x.norm(k, dim).numpy())
# Allow running this test file directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()