mirror of https://github.com/Jittor/Jittor
commit
908ec890c5
|
@ -341,6 +341,24 @@ def detach(x):
|
|||
return x.clone().stop_grad().clone()
|
||||
Var.detach = detach
|
||||
|
||||
def std(x):
|
||||
matsize=1
|
||||
for i in x.shape:
|
||||
matsize *= i
|
||||
out=(x-x.mean()).sqr().sum()
|
||||
out=out/(matsize-1)
|
||||
out=out.sqrt()
|
||||
return out
|
||||
Var.std = std
|
||||
|
||||
def norm(x, k, dim):
|
||||
assert k==1 or k==2
|
||||
if k==1:
|
||||
return x.abs().sum(dim)
|
||||
if k==2:
|
||||
return (x**2).sum(dim).sqrt()
|
||||
Var.norm = norm
|
||||
|
||||
origin_reshape = reshape
|
||||
def reshape(x, *shape):
|
||||
if len(shape) == 1 and isinstance(shape[0], Sequence):
|
||||
|
|
|
@ -100,6 +100,48 @@ def cross_entropy_loss(output, target, ignore_index=None):
|
|||
else:
|
||||
return loss.sum() / jt.maximum(mask.int().sum(), 1)
|
||||
|
||||
def mse_loss(output, target):
|
||||
return (output-target).sqr().mean()
|
||||
|
||||
def bce_loss(output, target):
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
|
||||
|
||||
def l1_loss(output, target):
|
||||
return (output-target).abs().mean()
|
||||
|
||||
class CrossEntropyLoss(Module):
|
||||
def __init__(self):
|
||||
pass
|
||||
def execute(self, output, target):
|
||||
return cross_entropy_loss(output, target)
|
||||
|
||||
class MSELoss(Module):
|
||||
def __init__(self):
|
||||
pass
|
||||
def execute(self, output, target):
|
||||
return mse_loss(output, target)
|
||||
|
||||
class BCELoss(Module):
|
||||
def __init__(self):
|
||||
pass
|
||||
def execute(self, output, target):
|
||||
return bce_loss(output, target)
|
||||
|
||||
class L1Loss(Module):
|
||||
def __init__(self):
|
||||
pass
|
||||
def execute(self, output, target):
|
||||
return l1_loss(output, target)
|
||||
|
||||
class BCEWithLogitsLoss(Module):
|
||||
def __init__(self):
|
||||
self.sigmoid = Sigmoid()
|
||||
self.bce = BCELoss()
|
||||
def execute(self, output, target):
|
||||
output = self.sigmoid(output)
|
||||
output = self.bce(output, target)
|
||||
return output
|
||||
|
||||
class SGD(object):
|
||||
""" Usage:
|
||||
optimizer = nn.SGD(model.parameters(), lr)
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import os
|
||||
import numpy as np
|
||||
import jittor.nn as jnn
|
||||
|
||||
from jittor.test.test_log import find_log_with_re
|
||||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestLoss(unittest.TestCase):
|
||||
def test_l1_loss(self):
|
||||
jt_loss=jnn.L1Loss()
|
||||
tc_loss=tnn.L1Loss()
|
||||
output=np.random.randn(10,100).astype(np.float32)
|
||||
target=np.random.randn(10,100).astype(np.float32)
|
||||
jt_y=jt_loss(jt.array(output), jt.array(target))
|
||||
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
def test_mse_loss(self):
|
||||
jt_loss=jnn.MSELoss()
|
||||
tc_loss=tnn.MSELoss()
|
||||
output=np.random.randn(10,100).astype(np.float32)
|
||||
target=np.random.randn(10,100).astype(np.float32)
|
||||
jt_y=jt_loss(jt.array(output), jt.array(target))
|
||||
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
def test_cross_entropy_loss(self):
|
||||
jt_loss=jnn.CrossEntropyLoss()
|
||||
tc_loss=tnn.CrossEntropyLoss()
|
||||
output=np.random.randn(10,10).astype(np.float32)
|
||||
target=np.random.randint(10, size=(10))
|
||||
jt_y=jt_loss(jt.array(output), jt.array(target))
|
||||
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
def test_bce_loss(self):
|
||||
jt_loss=jnn.BCELoss()
|
||||
tc_loss=tnn.BCELoss()
|
||||
jt_sig = jnn.Sigmoid()
|
||||
tc_sig = tnn.Sigmoid()
|
||||
output=np.random.randn(100).astype(np.float32)
|
||||
target=np.random.randint(2, size=(100)).astype(np.float32)
|
||||
jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
|
||||
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
def test_bce_with_logits_loss(self):
|
||||
jt_loss=jnn.BCEWithLogitsLoss()
|
||||
tc_loss=tnn.BCEWithLogitsLoss()
|
||||
output=np.random.randn(100).astype(np.float32)
|
||||
target=np.random.randint(2, size=(100)).astype(np.float32)
|
||||
jt_y=jt_loss(jt.array(output), jt.array(target))
|
||||
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,42 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import os
|
||||
import numpy as np
|
||||
import jittor.nn as jnn
|
||||
|
||||
from jittor.test.test_log import find_log_with_re
|
||||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestStd(unittest.TestCase):
|
||||
def test_std(self):
|
||||
x=np.random.randn(100,1000).astype(np.float32)
|
||||
jt_x=jt.array(x)
|
||||
tc_x=torch.from_numpy(x)
|
||||
assert np.allclose(jt_x.std().numpy(), tc_x.std().numpy(), 1e-4) ,(x, jt_x.std().numpy(), tc_x.std().numpy())
|
||||
|
||||
def test_norm(self):
|
||||
x=np.random.randn(100,1000).astype(np.float32)
|
||||
jt_x=jt.array(x)
|
||||
tc_x=torch.from_numpy(x)
|
||||
assert np.allclose(jt_x.norm(1,1).numpy(), tc_x.norm(1,1).numpy())
|
||||
assert np.allclose(jt_x.norm(1,0).numpy(), tc_x.norm(1,0).numpy())
|
||||
assert np.allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy())
|
||||
assert np.allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue