diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 62f22c97..f12fdae7 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -167,11 +167,16 @@ def cross_entropy_loss(output, target, ignore_index=None):
 def mse_loss(output, target):
     return (output-target).sqr().mean()
 
-def bce_loss(output, target, size_average=True):
+def bce_loss(output, target, weight=None, size_average=True):
+    loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20)))
+
+    if weight is not None:
+        loss *= weight
+
     if size_average:
-        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
+        return loss.mean()
     else:
-        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
+        return loss.sum()
 
 def l1_loss(output, target):
     return (output-target).abs().mean()
@@ -189,10 +194,11 @@ class MSELoss(Module):
         return mse_loss(output, target)
 
 class BCELoss(Module):
-    def __init__(self):
-        pass
-    def execute(self, output, target, size_average=True):
-        return bce_loss(output, target, size_average)
+    def __init__(self, weight=None, size_average=True):
+        self.weight = weight
+        self.size_average = size_average
+    def execute(self, output, target):
+        return bce_loss(output, target, self.weight, self.size_average)
 
 class L1Loss(Module):
     def __init__(self):
@@ -201,14 +207,17 @@ class L1Loss(Module):
         pass
     def execute(self, output, target):
         return l1_loss(output, target)
 class BCEWithLogitsLoss(Module):
-    def __init__(self):
+    def __init__(self, weight=None, size_average=True):
         self.sigmoid = Sigmoid()
-        self.bce = BCELoss()
-    def execute(self, output, target, size_average=True):
+        self.bce = BCELoss(weight, size_average)
+    def execute(self, output, target):
         output = self.sigmoid(output)
-        output = self.bce(output, target, size_average)
+        output = self.bce(output, target)
         return output
+def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True): 
+    return BCEWithLogitsLoss(weight, size_average)(input, target)
+
 def softmax(x, dim = None):
     if dim is None:
         x = (x - x.max()).exp()
diff --git a/python/jittor/test/test_loss.py b/python/jittor/test/test_loss.py
index 9afb5ddd..9a96dc66 100644
--- a/python/jittor/test/test_loss.py
+++ b/python/jittor/test/test_loss.py
@@ -49,7 +49,7 @@ class TestLoss(unittest.TestCase):
         jt_y=jt_loss(jt.array(output), jt.array(target))
         tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
         assert np.allclose(jt_y.numpy(), tc_y.numpy())
-    
+
     def test_bce_loss(self):
         jt_loss=jnn.BCELoss()
         tc_loss=tnn.BCELoss()
@@ -60,6 +60,13 @@ class TestLoss(unittest.TestCase):
         jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
         tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
         assert np.allclose(jt_y.numpy(), tc_y.numpy())
+
+        weight=np.random.randn(100).astype(np.float32)
+        jt_loss=jnn.BCELoss(weight=jt.array(weight), size_average=False)
+        tc_loss=tnn.BCELoss(weight=torch.Tensor(weight), size_average=False)
+        jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target))
+        tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target))
+        assert np.allclose(jt_y.numpy(), tc_y.numpy())
 
     def test_bce_with_logits_loss(self):
         jt_loss=jnn.BCEWithLogitsLoss()
diff --git a/python/jittor/utils/pytorch_converter.py b/python/jittor/utils/pytorch_converter.py
index f0293748..19961e6a 100644
--- a/python/jittor/utils/pytorch_converter.py
+++ b/python/jittor/utils/pytorch_converter.py
@@ -78,6 +78,32 @@ pjmap = {
         'extras': {},
         'delete': ['inplace'],
     },
+    'relu': {
+        'pytorch': {
+            'args': 'input',
+        },
+        'jittor': {
+            'module': 'nn',
+            'name': 'relu',
+            'args': 'x'
+        },
+        'links': {'input': 'x'},
+        'extras': {},
+        'delete': [],
+    },
+    'binary_cross_entropy_with_logits': {
+        'pytorch': {
+            'args': 'input, target, weight, size_average=True',
+        },
+        'jittor': {
+            'module': 'nn',
+            'name': 'binary_cross_entropy_with_logits',
+            'args': 'input, target, weight, size_average=True' 
+        },
+        'links': {},
+        'extras': {},
+        'delete': [],
+    },
     'ReLU6': {
         'pytorch': {
             'args': 'inplace=False',
@@ -274,6 +300,23 @@ pjmap = {
         'links': {},
         'extras': {},
     },
+    'clamp': {
+        'pytorch': {
+            'prefix': ['torch'],
+            'args_prefix': 'input, min, max, out=None',
+            'args': 'min, max, out=None',
+        },
+        'jittor': {
+            'prefix': 'jt',
+            'module': '',
+            'name': 'clamp',
+            'args_prefix': 'x, min_v, max_v',
+            'args': 'min_v, max_v'
+        },
+        'links': {'min': 'min_v', 'max': 'max_v'},
+        'extras': {},
+        'delete': ['out'],
+    },
     'permute': {
         'pytorch': {
             'prefix': [],
@@ -354,7 +397,7 @@ unsupport_ops = [
     # ***************************************************************
     # torch.nn
     # ***************************************************************
-    'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict', 
+    'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict', 
    'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
    'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
    'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',