cos&step lr_scheduler

This commit is contained in:
cxjyxx_me 2020-12-08 11:13:07 +08:00
parent 8d79e298e5
commit 86b7aaaa4e
3 changed files with 66 additions and 4 deletions

View File

@ -83,4 +83,58 @@ class ReduceLROnPlateau(object):
save = self.threshold + 1.0
return a > b * save
else:
return a > b + self.threshold
return a > b + self.threshold
class CosineAnnealingLR(object):
def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
self.T_max = T_max
self.eta_min = eta_min
self.optimizer = optimizer
self.last_epoch = last_epoch
self.base_lr = optimizer.lr
#TODO set last_epoch is not ready
def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
now_lr = self.optimizer.lr
if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return (now_lr + (self.base_lr - self.eta_min) *
(1 - math.cos(math.pi / self.T_max)) / 2)
return ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
(1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
(now_lr - self.eta_min) + self.eta_min)
def step(self):
self.last_epoch += 1
self.update_lr(self.get_lr())
def update_lr(self, new_lr):
self.optimizer.lr = new_lr
for i, param_group in enumerate(self.optimizer.param_groups):
if param_group.get("lr")!=None:
param_group["lr"] = new_lr
class MultiStepLR(object):
def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
self.optimizer = optimizer
self.milestones = milestones
self.gamma = gamma
self.last_epoch = last_epoch
#TODO set last_epoch is not ready
def get_lr(self):
now_lr = self.optimizer.lr
if (self.last_epoch in self.milestones):
now_lr *= gamma
return now_lr
def step(self):
self.last_epoch += 1
self.update_lr(self.get_lr())
def update_lr(self, new_lr):
self.optimizer.lr = new_lr
for i, param_group in enumerate(self.optimizer.param_groups):
if param_group.get("lr")!=None:
param_group["lr"] = new_lr

View File

@ -11,8 +11,8 @@
import jittor as jt
from jittor import nn
__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet26', 'Resnet38', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
'resnet18', 'resnet34', 'resnet26', 'resnet38', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
@ -166,6 +166,14 @@ def Resnet50(**kwargs):
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
resnet50 = Resnet50
def Resnet38(**kwargs):
return _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
resnet38 = Resnet38
def Resnet26(**kwargs):
return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
resnet26 = Resnet26
def Resnet101(**kwargs):
"""
ResNet-101 model architecture.

View File

@ -91,7 +91,7 @@ void PassManager::run_passes() {
run_pass<SolveConflictDefinePass>();
run_pass<MergeLoopVarPass>();
run_pass<ConstVarPass>();
// run_pass<ConstVarPass>();
run_pass<RestridePass>();