mirror of https://github.com/Jittor/Jittor
BCELoss: add optional per-element `weight` argument (and `size_average` flag) to `bce_loss`, `BCELoss`, and `BCEWithLogitsLoss`
This commit is contained in:
parent
a918260ab6
commit
843b983e90
|
@ -167,11 +167,16 @@ def cross_entropy_loss(output, target, ignore_index=None):
|
|||
def mse_loss(output, target):
    """Return the mean squared error between output and target.

    Args:
        output: predicted values.
        target: ground-truth values, same shape as output.

    Returns:
        Scalar variable: mean of element-wise squared differences.
    """
    return (output - target).sqr().mean()
|
||||
def bce_loss(output, target, weight=None, size_average=True):
    """Binary cross entropy loss.

    Args:
        output: predicted probabilities, expected in [0, 1]
            (e.g. the result of a sigmoid).
        target: ground-truth labels, same shape as output.
        weight: optional rescaling weight multiplied element-wise into the
            loss before reduction (must broadcast against the loss shape);
            None means no rescaling.
        size_average: if True, return the mean of the loss; otherwise the sum.

    Returns:
        Scalar variable with the reduced loss.
    """
    # Clamp both output and 1-output away from zero so log() stays finite
    # even for saturated predictions.
    loss = - (target * jt.log(jt.maximum(output, 1e-20))
              + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
    if weight is not None:
        loss *= weight
    if size_average:
        return loss.mean()
    return loss.sum()
||||
def l1_loss(output, target):
    """Return the mean absolute error (L1 loss) between output and target.

    Args:
        output: predicted values.
        target: ground-truth values, same shape as output.

    Returns:
        Scalar variable: mean of element-wise absolute differences.
    """
    return (output - target).abs().mean()
|
@ -189,10 +194,11 @@ class MSELoss(Module):
|
|||
return mse_loss(output, target)
|
||||
|
||||
class BCELoss(Module):
    """Binary cross entropy loss module; thin wrapper around bce_loss."""

    def __init__(self, weight=None, size_average=True):
        # weight: optional per-element rescaling weight (or None).
        # size_average: True -> mean reduction, False -> sum reduction.
        self.weight = weight
        self.size_average = size_average

    def execute(self, output, target):
        """Compute the loss for predicted probabilities vs. targets."""
        return bce_loss(output, target, self.weight, self.size_average)
||||
class L1Loss(Module):
|
||||
def __init__(self):
|
||||
|
@ -201,14 +207,17 @@ class L1Loss(Module):
|
|||
return l1_loss(output, target)
|
||||
|
||||
class BCEWithLogitsLoss(Module):
    """Binary cross entropy on raw logits: sigmoid followed by BCELoss."""

    def __init__(self, weight=None, size_average=True):
        # Forward the weight / reduction options to the wrapped BCELoss.
        self.sigmoid = Sigmoid()
        self.bce = BCELoss(weight, size_average)

    def execute(self, output, target):
        """Apply sigmoid to the logits, then compute the BCE loss."""
        output = self.sigmoid(output)
        output = self.bce(output, target)
        return output
||||
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
    """Functional form of BCEWithLogitsLoss.

    Args:
        input: raw (pre-sigmoid) logits.
        target: ground-truth labels, same shape as input.
        weight: optional per-element rescaling weight (or None).
        size_average: if True, mean reduction; otherwise sum.

    Returns:
        Scalar variable with the reduced loss.
    """
    return BCEWithLogitsLoss(weight, size_average)(input, target)
|
||||
def softmax(x, dim = None):
|
||||
if dim is None:
|
||||
x = (x - x.max()).exp()
|
||||
|
|
|
@ -49,7 +49,7 @@ class TestLoss(unittest.TestCase):
|
|||
jt_y=jt_loss(jt.array(output), jt.array(target))
|
||||
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
|
||||
def test_bce_loss(self):
|
||||
jt_loss=jnn.BCELoss()
|
||||
tc_loss=tnn.BCELoss()
|
||||
|
@ -60,6 +60,13 @@ class TestLoss(unittest.TestCase):
|
|||
jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
|
||||
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
weight=np.random.randn(100).astype(np.float32)
|
||||
jt_loss=jnn.BCELoss(weight=jt.array(weight), size_average=False)
|
||||
tc_loss=tnn.BCELoss(weight=torch.Tensor(weight), size_average=False)
|
||||
jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
|
||||
tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
|
||||
assert np.allclose(jt_y.numpy(), tc_y.numpy())
|
||||
|
||||
def test_bce_with_logits_loss(self):
|
||||
jt_loss=jnn.BCEWithLogitsLoss()
|
||||
|
|
|
@ -78,6 +78,32 @@ pjmap = {
|
|||
'extras': {},
|
||||
'delete': ['inplace'],
|
||||
},
|
||||
'relu': {
|
||||
'pytorch': {
|
||||
'args': 'input',
|
||||
},
|
||||
'jittor': {
|
||||
'module': 'nn',
|
||||
'name': 'relu',
|
||||
'args': 'x'
|
||||
},
|
||||
'links': {'input': 'x'},
|
||||
'extras': {},
|
||||
'delete': [],
|
||||
},
|
||||
'binary_cross_entropy_with_logits': {
|
||||
'pytorch': {
|
||||
'args': 'input, target, weight, size_average=True',
|
||||
},
|
||||
'jittor': {
|
||||
'module': 'nn',
|
||||
'name': 'binary_cross_entropy_with_logits',
|
||||
'args': 'input, target, weight, size_average=True'
|
||||
},
|
||||
'links': {},
|
||||
'extras': {},
|
||||
'delete': [],
|
||||
},
|
||||
'ReLU6': {
|
||||
'pytorch': {
|
||||
'args': 'inplace=False',
|
||||
|
@ -274,6 +300,23 @@ pjmap = {
|
|||
'links': {},
|
||||
'extras': {},
|
||||
},
|
||||
'clamp': {
|
||||
'pytorch': {
|
||||
'prefix': ['torch'],
|
||||
'args_prefix': 'input, min, max, out=None',
|
||||
'args': 'min, max, out=None',
|
||||
},
|
||||
'jittor': {
|
||||
'prefix': 'jt',
|
||||
'module': '',
|
||||
'name': 'clamp',
|
||||
'args_prefix': 'x, min_v, max_v',
|
||||
'args': 'min_v, max_v'
|
||||
},
|
||||
'links': {'min': 'min_v', 'max': 'max_v'},
|
||||
'extras': {},
|
||||
'delete': ['out'],
|
||||
},
|
||||
'permute': {
|
||||
'pytorch': {
|
||||
'prefix': [],
|
||||
|
@ -354,7 +397,7 @@ unsupport_ops = [
|
|||
# ***************************************************************
|
||||
# torch.nn
|
||||
# ***************************************************************
|
||||
'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
|
||||
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
|
||||
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
|
||||
|
|
Loading…
Reference in New Issue