diff --git a/python/jittor/lr_scheduler.py b/python/jittor/lr_scheduler.py
index 6a8fd599..8e964759 100644
--- a/python/jittor/lr_scheduler.py
+++ b/python/jittor/lr_scheduler.py
@@ -84,4 +84,60 @@ class ReduceLROnPlateau(object):
             save = self.threshold + 1.0
             return a > b * save
         else:
-            return a > b
\ No newline at end of file
+            return a > b
+
+class CosineAnnealingLR(object):
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+        self.T_max = T_max
+        self.eta_min = eta_min
+        self.optimizer = optimizer
+        self.last_epoch = last_epoch
+        self.base_lr = optimizer.lr
+        # TODO: resuming from a given last_epoch is not supported yet
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return self.base_lr
+        now_lr = self.optimizer.lr
+        # Special case: at the trough of the cosine the recurrence below would divide by zero.
+        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
+            return (now_lr + (self.base_lr - self.eta_min) *
+                    (1 - math.cos(math.pi / self.T_max)) / 2)
+        # Chained form of the cosine schedule: derive this epoch's lr from the previous one.
+        return ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
+                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
+                (now_lr - self.eta_min) + self.eta_min)
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr(self.get_lr())
+
+    def update_lr(self, new_lr):
+        self.optimizer.lr = new_lr
+        for param_group in self.optimizer.param_groups:
+            if param_group.get("lr") is not None:
+                param_group["lr"] = new_lr
+
+class MultiStepLR(object):
+    def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
+        self.optimizer = optimizer
+        self.milestones = milestones
+        self.gamma = gamma
+        self.last_epoch = last_epoch
+        # TODO: resuming from a given last_epoch is not supported yet
+
+    def get_lr(self):
+        now_lr = self.optimizer.lr
+        if self.last_epoch in self.milestones:
+            now_lr *= self.gamma   # decay once at each milestone epoch
+        return now_lr
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr(self.get_lr())
+
+    def update_lr(self, new_lr):
+        self.optimizer.lr = new_lr
+        for param_group in self.optimizer.param_groups:
+            if param_group.get("lr") is not None:
+                param_group["lr"] = new_lr
\ No newline at end of file
diff --git a/python/jittor/models/resnet.py b/python/jittor/models/resnet.py
index c2d4b7d7..ceb08db6 100644
--- a/python/jittor/models/resnet.py
+++ b/python/jittor/models/resnet.py
@@ -12,8 +12,8 @@ import jittor as jt
 from jittor import nn
 
-__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
-           'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
+__all__ = ['ResNet', 'Resnet18', 'Resnet26', 'Resnet34', 'Resnet38', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
+           'resnet18', 'resnet26', 'resnet34', 'resnet38', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
 
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
@@ -174,7 +174,15 @@ def Resnet50(pretrained=False, **kwargs):
 
 resnet50 = Resnet50
 
-def Resnet101(pretrained=False, **kwargs):
+def Resnet38(**kwargs):
+    return _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
+resnet38 = Resnet38
+
+def Resnet26(**kwargs):
+    return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
+resnet26 = Resnet26
+
+def Resnet101(**kwargs):
     """ ResNet-101 model architecture.
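
A minimal usage sketch for the two new schedulers (not part of the patch; the SGD
constructor and the training-loop shape are assumptions based on jittor's optim API):

    from jittor import nn, optim
    from jittor.lr_scheduler import CosineAnnealingLR, MultiStepLR

    model = nn.Linear(10, 2)                           # stand-in model
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    # Anneal lr from 0.1 toward eta_min along a cosine curve over T_max epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-5)
    # Alternative: multiply lr by gamma at epochs 30 and 60.
    # scheduler = MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

    for epoch in range(50):
        # ... forward pass, compute loss, optimizer.step(loss) here ...
        scheduler.step()                               # updates optimizer.lr once per epoch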
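
For the new ResNet depths, the block counts follow the usual bookkeeping: each
Bottleneck holds 3 convolutions, plus the stem conv and the final fc, so [1, 2, 4, 1]
gives 3*(1+2+4+1)+2 = 26 layers and [2, 3, 5, 2] gives 3*(2+3+5+2)+2 = 38. A quick
smoke test (assuming jittor.models re-exports these names like the existing variants):

    import jittor as jt
    from jittor.models import resnet26, resnet38

    x = jt.random((1, 3, 224, 224))
    print(resnet26(num_classes=1000)(x).shape)   # expected: [1, 1000]
    print(resnet38(num_classes=1000)(x).shape)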