BCELoss add weight

This commit is contained in:
zhowuy19 2020-08-14 17:32:42 +08:00
parent a918260ab6
commit 843b983e90
3 changed files with 72 additions and 13 deletions

View File

@ -167,11 +167,16 @@ def cross_entropy_loss(output, target, ignore_index=None):
def mse_loss(output, target):
return (output-target).sqr().mean()
def bce_loss(output, target, size_average=True):
def bce_loss(output, target, weight=None, size_average=True):
loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
if weight is not None:
loss *= weight
if size_average:
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
return loss.mean()
else:
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
return loss.sum()
def l1_loss(output, target):
return (output-target).abs().mean()
@ -189,10 +194,11 @@ class MSELoss(Module):
return mse_loss(output, target)
class BCELoss(Module):
def __init__(self):
pass
def execute(self, output, target, size_average=True):
return bce_loss(output, target, size_average)
def __init__(self, weight=None, size_average=True):
self.weight = weight
self.size_average = size_average
def execute(self, output, target):
return bce_loss(output, target, self.weight, self.size_average)
class L1Loss(Module):
def __init__(self):
@ -201,14 +207,17 @@ class L1Loss(Module):
return l1_loss(output, target)
class BCEWithLogitsLoss(Module):
def __init__(self):
def __init__(self, weight=None, size_average=True):
self.sigmoid = Sigmoid()
self.bce = BCELoss()
def execute(self, output, target, size_average=True):
self.bce = BCELoss(weight, size_average)
def execute(self, output, target):
output = self.sigmoid(output)
output = self.bce(output, target, size_average)
output = self.bce(output, target)
return output
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
return BCEWithLogitsLoss(weight, size_average)(input, target)
def softmax(x, dim = None):
if dim is None:
x = (x - x.max()).exp()

View File

@ -61,6 +61,13 @@ class TestLoss(unittest.TestCase):
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
assert np.allclose(jt_y.numpy(), tc_y.numpy())
weight=np.random.randn(100).astype(np.float32)
jt_loss=jnn.BCELoss(weight=jt.array(weight), size_average=False)
tc_loss=tnn.BCELoss(weight=torch.Tensor(weight), size_average=False)
jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_bce_with_logits_loss(self):
jt_loss=jnn.BCEWithLogitsLoss()
tc_loss=tnn.BCEWithLogitsLoss()

View File

@ -78,6 +78,32 @@ pjmap = {
'extras': {},
'delete': ['inplace'],
},
'relu': {
'pytorch': {
'args': 'input',
},
'jittor': {
'module': 'nn',
'name': 'relu',
'args': 'x'
},
'links': {'input': 'x'},
'extras': {},
'delete': [],
},
'binary_cross_entropy_with_logits': {
'pytorch': {
'args': 'input, target, weight, size_average=True',
},
'jittor': {
'module': 'nn',
'name': 'binary_cross_entropy_with_logits',
'args': 'input, target, weight, size_average=True'
},
'links': {},
'extras': {},
'delete': [],
},
'ReLU6': {
'pytorch': {
'args': 'inplace=False',
@ -274,6 +300,23 @@ pjmap = {
'links': {},
'extras': {},
},
'clamp': {
'pytorch': {
'prefix': ['torch'],
'args_prefix': 'input, min, max, out=None',
'args': 'min, max, out=None',
},
'jittor': {
'prefix': 'jt',
'module': '',
'name': 'clamp',
'args_prefix': 'x, min_v, max_v',
'args': 'min_v, max_v'
},
'links': {'min': 'min_v', 'max': 'max_v'},
'extras': {},
'delete': ['out'],
},
'permute': {
'pytorch': {
'prefix': [],
@ -354,7 +397,7 @@ unsupport_ops = [
# ***************************************************************
# torch.nn
# ***************************************************************
'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',