mirror of https://github.com/Jittor/Jittor
cos&step lr_scheduler
This commit is contained in:
parent
4af3f7a0d2
commit
3dfa26bbce
|
@ -84,4 +84,58 @@ class ReduceLROnPlateau(object):
|
||||||
save = self.threshold + 1.0
|
save = self.threshold + 1.0
|
||||||
return a > b * save
|
return a > b * save
|
||||||
else:
|
else:
|
||||||
return a > b + self.threshold
|
return a > b + self.threshold
|
||||||
|
|
||||||
|
class CosineAnnealingLR(object):
    """Cosine-annealing learning-rate scheduler.

    Each call to :meth:`step` advances an internal epoch counter and writes a
    new learning rate to the wrapped optimizer. The rate is derived
    recursively from the optimizer's *current* rate, so it oscillates between
    the optimizer's initial rate and ``eta_min`` with a half-period of
    ``T_max`` steps.

    Args:
        optimizer: object exposing an ``lr`` attribute and a ``param_groups``
            iterable of dicts (both are read and written by this class).
        T_max: number of steps for the rate to go from its initial value
            down to ``eta_min``.
        eta_min: lower bound of the schedule.
        last_epoch: starting value of the epoch counter; ``-1`` means the
            first ``step()`` call is epoch 0.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        # The rate the optimizer had at construction time is the peak of the
        # cosine curve.
        self.base_lr = optimizer.lr
        # TODO: resuming from an arbitrary last_epoch is not supported yet.

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        if self.last_epoch == 0:
            return self.base_lr

        now_lr = self.optimizer.lr
        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous step sat exactly at the bottom of the curve, where
            # the ratio used below would have a zero denominator; restart the
            # upward half with an additive step instead.
            return (now_lr + (self.base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2)

        # Recursive form: rescale the distance to eta_min by the ratio of two
        # consecutive cosine factors.
        return ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (now_lr - self.eta_min) + self.eta_min)

    def step(self):
        """Advance one epoch and push the new rate to the optimizer."""
        self.last_epoch += 1
        self.update_lr(self.get_lr())

    def update_lr(self, new_lr):
        """Write ``new_lr`` to the optimizer and to each group that has an lr.

        Groups without an ``"lr"`` entry (or with it set to ``None``) are
        left untouched.
        """
        self.optimizer.lr = new_lr
        for param_group in self.optimizer.param_groups:
            # Identity test: `!= None` would invoke an overloaded comparison
            # on array-like learning rates.
            if param_group.get("lr") is not None:
                param_group["lr"] = new_lr
|
||||||
|
|
||||||
|
class MultiStepLR(object):
    """Multiply the optimizer's learning rate by ``gamma`` at each milestone.

    Each call to :meth:`step` advances an internal epoch counter; when the
    counter equals one of ``milestones`` the optimizer's current rate is
    scaled by ``gamma``, otherwise it is left as is.

    Args:
        optimizer: object exposing an ``lr`` attribute and a ``param_groups``
            iterable of dicts (both are read and written by this class).
        milestones: epoch indices at which the rate is decayed.
        gamma: multiplicative decay factor.
        last_epoch: starting value of the epoch counter; ``-1`` means the
            first ``step()`` call is epoch 0.
    """

    def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
        self.optimizer = optimizer
        # Copy so the shared default list (or the caller's list) is never
        # aliased by this instance.
        self.milestones = list(milestones)
        self.gamma = gamma
        self.last_epoch = last_epoch
        # TODO: resuming from an arbitrary last_epoch is not supported yet.

    def get_lr(self):
        """Return the learning rate for the current ``last_epoch``."""
        now_lr = self.optimizer.lr
        if self.last_epoch in self.milestones:
            # Was `now_lr *= gamma`: `gamma` is not a local or global name,
            # so reaching a milestone raised NameError. Build a new value
            # rather than multiplying in place, so a mutable lr object is not
            # modified behind the optimizer's back.
            return now_lr * self.gamma
        return now_lr

    def step(self):
        """Advance one epoch and push the new rate to the optimizer."""
        self.last_epoch += 1
        self.update_lr(self.get_lr())

    def update_lr(self, new_lr):
        """Write ``new_lr`` to the optimizer and to each group that has an lr.

        Groups without an ``"lr"`` entry (or with it set to ``None``) are
        left untouched.
        """
        self.optimizer.lr = new_lr
        for param_group in self.optimizer.param_groups:
            # Identity test: `!= None` would invoke an overloaded comparison
            # on array-like learning rates.
            if param_group.get("lr") is not None:
                param_group["lr"] = new_lr
|
|
@ -12,8 +12,8 @@
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
from jittor import nn
|
from jittor import nn
|
||||||
|
|
||||||
__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
|
__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet26', 'Resnet38', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
|
||||||
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
|
'resnet18', 'resnet34', 'resnet26', 'resnet38', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
|
||||||
|
|
||||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||||
conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
|
conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||||
|
@ -174,7 +174,15 @@ def Resnet50(pretrained=False, **kwargs):
|
||||||
|
|
||||||
resnet50 = Resnet50
|
resnet50 = Resnet50
|
||||||
|
|
||||||
def Resnet101(pretrained=False, **kwargs):
|
def Resnet38(**kwargs):
    """Build a ResNet from ``Bottleneck`` blocks with stage sizes [2, 3, 5, 2].

    All keyword arguments are forwarded unchanged to ``_resnet``.
    """
    stage_blocks = [2, 3, 5, 2]
    return _resnet(Bottleneck, stage_blocks, **kwargs)
|
||||||
|
resnet38 = Resnet38
|
||||||
|
|
||||||
|
def Resnet26(**kwargs):
    """Build a ResNet from ``Bottleneck`` blocks with stage sizes [1, 2, 4, 1].

    All keyword arguments are forwarded unchanged to ``_resnet``.
    """
    stage_blocks = [1, 2, 4, 1]
    return _resnet(Bottleneck, stage_blocks, **kwargs)
|
||||||
|
resnet26 = Resnet26
|
||||||
|
|
||||||
|
def Resnet101(**kwargs):
|
||||||
"""
|
"""
|
||||||
ResNet-101 model architecture.
|
ResNet-101 model architecture.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue