BCELoss add weight

zhowuy19 2020-08-14 17:32:42 +08:00
parent a918260ab6
commit 843b983e90
3 changed files with 72 additions and 13 deletions

Changed file 1 of 3:

@@ -167,11 +167,16 @@ def cross_entropy_loss(output, target, ignore_index=None):
 def mse_loss(output, target):
     return (output-target).sqr().mean()
 
-def bce_loss(output, target, size_average=True):
+def bce_loss(output, target, weight=None, size_average=True):
+    loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
+    if weight is not None:
+        loss *= weight
     if size_average:
-        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
+        return loss.mean()
     else:
-        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
+        return loss.sum()
 
 def l1_loss(output, target):
     return (output-target).abs().mean()
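The rewrite above factors the per-element loss into a single `loss` variable so an optional `weight` can be multiplied in before the mean/sum reduction. As a reference for the arithmetic only, here is a small NumPy sketch of the same computation; the data and the helper name `weighted_bce` are made up for illustration and are not part of this commit:

import numpy as np

def weighted_bce(probs, target, weight=None, size_average=True, eps=1e-20):
    # Same formula as bce_loss above: clamp the log arguments at eps,
    # apply the optional per-element weight, then reduce.
    loss = -(target * np.log(np.maximum(probs, eps))
             + (1 - target) * np.log(np.maximum(1 - probs, eps)))
    if weight is not None:
        loss = loss * weight
    return loss.mean() if size_average else loss.sum()

probs  = 1.0 / (1.0 + np.exp(-np.random.randn(100)))
target = (np.random.rand(100) > 0.5).astype(np.float64)
weight = np.full(100, 2.0)
print(weighted_bce(probs, target))                              # unweighted mean
print(weighted_bce(probs, target, weight, size_average=False))  # weighted sum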
@@ -189,10 +194,11 @@ class MSELoss(Module):
         return mse_loss(output, target)
 
 class BCELoss(Module):
-    def __init__(self):
-        pass
-    def execute(self, output, target, size_average=True):
-        return bce_loss(output, target, size_average)
+    def __init__(self, weight=None, size_average=True):
+        self.weight = weight
+        self.size_average = size_average
+    def execute(self, output, target):
+        return bce_loss(output, target, self.weight, self.size_average)
 
 class L1Loss(Module):
     def __init__(self):
@@ -201,14 +207,17 @@ class L1Loss(Module):
         return l1_loss(output, target)
 
 class BCEWithLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, weight=None, size_average=True):
         self.sigmoid = Sigmoid()
-        self.bce = BCELoss()
-    def execute(self, output, target, size_average=True):
+        self.bce = BCELoss(weight, size_average)
+    def execute(self, output, target):
         output = self.sigmoid(output)
-        output = self.bce(output, target, size_average)
+        output = self.bce(output, target)
         return output
 
+def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
+    return BCEWithLogitsLoss(weight, size_average)(input, target)
+
 def softmax(x, dim = None):
     if dim is None:
         x = (x - x.max()).exp()
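Taken together, the three hunks above move `weight` and `size_average` from `execute()` into the constructors and add a functional entry point. A usage sketch, under the assumption that jittor is installed and these symbols are exported from `jittor.nn` as defined here (tensor values are made up):

import numpy as np
import jittor as jt
from jittor import nn

logits = jt.array(np.random.randn(8).astype(np.float32))
probs  = nn.Sigmoid()(logits)
target = jt.array((np.random.rand(8) > 0.5).astype(np.float32))
weight = jt.array(np.full(8, 2.0, dtype=np.float32))  # per-element weight

# Module form: weight/size_average are now constructor arguments, not execute() arguments.
loss_fn = nn.BCELoss(weight=weight, size_average=False)
l1 = loss_fn(probs, target)

# New functional form: sigmoid plus weighted BCE in one call.
l2 = nn.binary_cross_entropy_with_logits(logits, target, weight=weight, size_average=False)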

Changed file 2 of 3:

@@ -49,7 +49,7 @@ class TestLoss(unittest.TestCase):
         jt_y=jt_loss(jt.array(output), jt.array(target))
         tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
         assert np.allclose(jt_y.numpy(), tc_y.numpy())
 
     def test_bce_loss(self):
         jt_loss=jnn.BCELoss()
         tc_loss=tnn.BCELoss()
@@ -60,6 +60,13 @@ class TestLoss(unittest.TestCase):
         jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
         tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
         assert np.allclose(jt_y.numpy(), tc_y.numpy())
+
+        weight=np.random.randn(100).astype(np.float32)
+        jt_loss=jnn.BCELoss(weight=jt.array(weight), size_average=False)
+        tc_loss=tnn.BCELoss(weight=torch.Tensor(weight), size_average=False)
+        jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
+        tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
+        assert np.allclose(jt_y.numpy(), tc_y.numpy())
 
     def test_bce_with_logits_loss(self):
         jt_loss=jnn.BCEWithLogitsLoss()

Changed file 3 of 3:

@@ -78,6 +78,32 @@ pjmap = {
         'extras': {},
         'delete': ['inplace'],
     },
+    'relu': {
+        'pytorch': {
+            'args': 'input',
+        },
+        'jittor': {
+            'module': 'nn',
+            'name': 'relu',
+            'args': 'x'
+        },
+        'links': {'input': 'x'},
+        'extras': {},
+        'delete': [],
+    },
+    'binary_cross_entropy_with_logits': {
+        'pytorch': {
+            'args': 'input, target, weight, size_average=True',
+        },
+        'jittor': {
+            'module': 'nn',
+            'name': 'binary_cross_entropy_with_logits',
+            'args': 'input, target, weight, size_average=True'
+        },
+        'links': {},
+        'extras': {},
+        'delete': [],
+    },
     'ReLU6': {
         'pytorch': {
             'args': 'inplace=False',
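The two new pjmap entries tell the PyTorch-to-Jittor converter how to rewrite the functional `relu` and `binary_cross_entropy_with_logits` calls; `links` renames `input` to `x` for relu, while the BCE entry keeps its argument list unchanged. The snippet below only illustrates the Jittor-side calls these entries point at, with hypothetical inputs; it is not converter output:

import numpy as np
import jittor as jt
from jittor import nn

x = jt.array(np.random.randn(4, 3).astype(np.float32))
target = jt.array((np.random.rand(4, 3) > 0.5).astype(np.float32))

# torch.nn.functional.relu(input)  ->  nn.relu(x)   ('links' maps input to x)
y = nn.relu(x)

# F.binary_cross_entropy_with_logits(input, target, weight, size_average=True)
#   ->  nn.binary_cross_entropy_with_logits(input, target, weight, size_average=True)
loss = nn.binary_cross_entropy_with_logits(x, target)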
@@ -274,6 +300,23 @@ pjmap = {
         'links': {},
         'extras': {},
     },
+    'clamp': {
+        'pytorch': {
+            'prefix': ['torch'],
+            'args_prefix': 'input, min, max, out=None',
+            'args': 'min, max, out=None',
+        },
+        'jittor': {
+            'prefix': 'jt',
+            'module': '',
+            'name': 'clamp',
+            'args_prefix': 'x, min_v, max_v',
+            'args': 'min_v, max_v'
+        },
+        'links': {'min': 'min_v', 'max': 'max_v'},
+        'extras': {},
+        'delete': ['out'],
+    },
     'permute': {
         'pytorch': {
             'prefix': [],
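The `clamp` entry handles a prefixed call: `torch.clamp(input, min, max, out=None)` is meant to become `jt.clamp(x, min_v, max_v)`, with `min`/`max` renamed through `links` and `out` dropped through `delete`. A small sketch of the Jittor-side call the entry targets, with made-up values:

import numpy as np
import jittor as jt

x = jt.array(np.random.randn(5).astype(np.float32))
# torch.clamp(x, min=-0.5, max=0.5) corresponds to:
y = jt.clamp(x, min_v=-0.5, max_v=0.5)
print(y.numpy())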
@@ -354,7 +397,7 @@ unsupport_ops = [
     # ***************************************************************
     # torch.nn
     # ***************************************************************
-    'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
+    'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
     'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
     'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
     'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',