mirror of https://github.com/Jittor/Jittor
densenet & lr_scheduler
parent 03aefb620a
commit f32bc4ae49

@@ -777,6 +777,8 @@ double = float64
Var.double = Var.float64

from . import nn
from . import attention
from . import lr_scheduler
from . import linalg
from .nn import matmul
from . import contrib
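With the import above in place, the scheduler module added later in this commit becomes reachable from the package root, which is how the new test drives it. A one-line sanity check, assuming the new file is registered as jittor.lr_scheduler:

import jittor as jt
# the class added below is exposed on the package root via `from . import lr_scheduler`
assert hasattr(jt.lr_scheduler, "ReduceLROnPlateau")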
@@ -0,0 +1,86 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import jittor as jt
from jittor.optim import Optimizer
import math

# Reduce the learning rate when the monitored metric has stopped improving.
class ReduceLROnPlateau(object):
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8):
        assert factor < 1.0, "factor should be < 1.0."
        assert isinstance(optimizer, Optimizer), '{} is not an Optimizer'.format(type(optimizer).__name__)
        assert mode in {'min', 'max'}, 'mode ' + mode + ' is unknown!'
        assert threshold_mode in {'rel', 'abs'}, 'threshold mode ' + threshold_mode + ' is unknown!'

        if isinstance(min_lr, (list, tuple)):
            assert len(min_lr) == len(optimizer.param_groups), "expected {} min_lrs, got {}".format(len(optimizer.param_groups), len(min_lr))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
        self.factor = factor
        self.optimizer = optimizer
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.n_cd = 0        # remaining cooldown steps after a reduction
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.n_bad = 0       # consecutive steps without improvement
        self.eps = eps
        self.last_epoch = 0
        self.loss_best = math.inf if mode == "min" else -math.inf

    def step(self, loss, epoch=None):
        # convert `loss` to a plain float, in case it is a zero-dim Var/Tensor
        loss_now = float(loss)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.better(loss_now, self.loss_best):
            self.loss_best = loss_now
            self.n_bad = 0
        else:
            self.n_bad += 1

        if self.n_cd > 0:
            self.n_cd -= 1
            self.n_bad = 0

        if self.n_bad > self.patience:
            self.update_lr(epoch)
            self.n_cd = self.cooldown
            self.n_bad = 0

    def update_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group.get("lr", self.optimizer.lr))
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                if param_group.get("lr") is not None:
                    param_group["lr"] = new_lr
                else:
                    self.optimizer.lr = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate of group {} from {:.4e} to {:.4e}.'.format(epoch, i, old_lr, new_lr))

    def better(self, a, b):
        # decide whether metric `a` improves on the best value `b`
        if self.mode == 'min' and self.threshold_mode == 'rel':
            save = 1.0 - self.threshold
            return a < b * save
        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < b - self.threshold
        elif self.mode == 'max' and self.threshold_mode == 'rel':
            save = self.threshold + 1.0
            return a > b * save
        else:
            return a > b + self.threshold
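For orientation, a minimal sketch of how this scheduler is meant to be driven from a training loop. The toy model, learning rate, and the constant placeholder loss are illustrative assumptions; only the ReduceLROnPlateau constructor and step() calls come from the class above.

import jittor as jt
from jittor import nn

model = nn.Linear(10, 2)                      # toy model, stands in for a real network
optimizer = jt.optim.SGD(model.parameters(), lr=0.1)

# halve the lr once the monitored loss fails to improve for more than 2 steps
scheduler = jt.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

for epoch in range(10):
    val_loss = 1.0                            # placeholder: a loss that has plateaued
    scheduler.step(val_loss)                  # reduces optimizer.lr after patience is exceeded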
@@ -17,3 +17,5 @@ from .mnasnet import *
from . import shufflenetv2
from .shufflenetv2 import *
from .res2net import res2net50, res2net101
from . import densenet
from .densenet import *
@@ -0,0 +1,97 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
# This model is generated by the PyTorch converter.
import jittor as jt
from jittor import nn
from jittor import init
from collections import OrderedDict

def Densenet169(pretrained=False, **kwargs):
    '''Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    '''
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
    assert not pretrained, "pretrained weights are not supported yet"
    # if pretrained:
    #     save = torch.load('parameter.pkl')
    #     params = model.parameters()
    #     model.load_parameters(save)
    return model

class _DenseLayer(nn.Sequential):

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm(num_input_features))
        self.add_module('relu1', nn.ReLU())
        self.add_module('conv1', nn.Conv(num_input_features, bn_size * growth_rate, 1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU())
        self.add_module('conv2', nn.Conv(bn_size * growth_rate, growth_rate, 3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.drop = nn.Dropout(self.drop_rate)

    def execute(self, x):
        new_features = super(_DenseLayer, self).execute(x)
        if self.drop_rate > 0:
            new_features = self.drop(new_features)
        # dense connectivity: concatenate the input with the new feature maps
        return jt.contrib.concat([x, new_features], dim=1)

class _DenseBlock(nn.Sequential):

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)

class _Transition(nn.Sequential):

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm(num_input_features))
        self.add_module('relu', nn.ReLU())
        self.add_module('conv', nn.Conv(num_input_features, num_output_features, 1, stride=1, bias=False))
        self.add_module('pool', nn.Pool(2, stride=2, op='mean'))

class DenseNet(nn.Module):
    '''Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    '''

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv(3, num_init_features, 7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm(num_init_features)),
            ('relu0', nn.ReLU()),
            ('pool0', nn.Pool(3, stride=2, padding=1, op='maximum')),
        ]))
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2
        self.features.add_module('norm5', nn.BatchNorm(num_features))
        self.fc = nn.Linear(num_features, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv):
                nn.init.invariant_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def execute(self, x):
        features = self.features(x)
        out = nn.relu(features)
        out = jt.pool.pool(out, kernel_size=7, op="mean", stride=1).reshape([features.shape[0], -1])
        out = jt.sigmoid(self.fc(out))
        return out
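As a quick, hedged smoke test of the model defined above (random input instead of real images; the batch size and the printed shape are illustrative assumptions, not part of the commit):

import jittor as jt
from jittor.models import densenet

model = densenet.Densenet169()        # random init; pretrained weights are not supported yet
x = jt.random((2, 3, 224, 224))       # fake batch of two 224x224 RGB images
y = model(x)
print(y.shape)                        # expected [2, 1000]: sigmoid class scores from execute()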

@@ -0,0 +1,102 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn, Module
from jittor.models import densenet
import numpy as np
import sys, os
import random
import math
import unittest
from jittor.test.test_reorder_tuner import simple_parser
from jittor.test.test_log import find_log_with_re
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import time

skip_this_test = False

class MnistNet(Module):
    def __init__(self):
        self.model = densenet.Densenet169()
        self.layer = nn.Linear(1000, 10)

    def execute(self, x):
        x = self.model(x)
        x = self.layer(x)
        return x

@unittest.skipIf(skip_this_test, "skip_this_test")
class TestDensenet(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        # hyper-parameters
        self.batch_size = 100
        self.weight_decay = 0.0001
        self.momentum = 0.9
        self.learning_rate = 0.1
        # mnist dataset
        self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
            .set_attrs(batch_size=self.batch_size, shuffle=True)
        self.train_loader.num_workers = 4

    # setup random seed
    def setup_seed(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        jt.seed(seed)

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1, use_stat_allocator=1)
    def test_densenet(self):
        self.setup_seed(1)
        loss_list = []
        acc_list = []
        mnist_net = MnistNet()
        global prev
        prev = time.time()
        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
        # SGD = jt.optim.Adam(mnist_net.parameters(), lr=0.0001)

        for batch_idx, (data, target) in enumerate(self.train_loader):
            output = mnist_net(data)
            loss = nn.cross_entropy_loss(output, target)
            SGD.step(loss)

            def callback(batch_idx, loss, output, target):
                # print train info
                global prev
                pred = np.argmax(output, axis=1)
                acc = np.mean(target == pred)
                loss_list.append(loss[0])
                acc_list.append(acc)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                      .format(0, batch_idx, 600, 1. * batch_idx / 6.0, loss[0], acc, time.time() - prev))
                # prev = time.time()

            jt.fetch(batch_idx, loss, output, target, callback)
        # Train Epoch: 0 [0/600 (0%)]   Loss: 2.402650  Acc: 0.060000
        # Train Epoch: 0 [1/600 (0%)]   Loss: 2.770145  Acc: 0.100000
        # Train Epoch: 0 [2/600 (0%)]   Loss: 3.528072  Acc: 0.100000
        # Train Epoch: 0 [3/600 (0%)]   Loss: 2.992042  Acc: 0.100000
        # Train Epoch: 0 [4/600 (1%)]   Loss: 4.672772  Acc: 0.060000
        # Train Epoch: 0 [5/600 (1%)]   Loss: 5.003410  Acc: 0.080000
        # Train Epoch: 0 [6/600 (1%)]   Loss: 5.417546  Acc: 0.100000
        # Train Epoch: 0 [7/600 (1%)]   Loss: 5.137665  Acc: 0.100000
        # Train Epoch: 0 [8/600 (1%)]   Loss: 5.241075  Acc: 0.070000
        # Train Epoch: 0 [9/600 (2%)]   Loss: 4.515363  Acc: 0.100000
        # Train Epoch: 0 [10/600 (2%)]  Loss: 3.357187  Acc: 0.170000
        # Train Epoch: 0 [20/600 (3%)]  Loss: 2.265879  Acc: 0.100000
        # Train Epoch: 0 [30/600 (5%)]  Loss: 2.107000  Acc: 0.250000
        # Train Epoch: 0 [40/600 (7%)]  Loss: 1.918214  Acc: 0.290000
        # Train Epoch: 0 [50/600 (8%)]  Loss: 1.645694  Acc: 0.400000

        jt.sync_all(True)
        assert np.mean(loss_list[-50:]) < 0.3
        assert np.mean(acc_list[-50:]) > 0.9

if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,52 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import random

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
except:
    torch = None
    skip_this_test = True

def check_equal(q, k, v, tatt, jatt):
    # compare a PyTorch attention module (tatt) against its Jittor counterpart (jatt);
    # not exercised by the test below
    tq = torch.from_numpy(q)
    jq = jt.array(q)
    tk = torch.from_numpy(k)
    jk = jt.array(k)
    tv = torch.from_numpy(v)
    jv = jt.array(v)

    jatt.load_parameters(tatt.state_dict())
    ty, tw = tatt(tq, tk, tv)
    jy, jw = jatt(jq, jk, jv)
    assert np.allclose(ty.detach().numpy(), jy.numpy(), rtol=1e-3)
    assert np.allclose(tw.detach().numpy(), jw.numpy(), rtol=1e-3)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestLRScheduler(unittest.TestCase):
    def test_reduce_lr_on_plateau(self):
        # drive Jittor's and PyTorch's ReduceLROnPlateau with the same losses
        # and check that the resulting learning rates stay in sync
        j_opt = jt.optim.SGD([jt.array([1])], 1.0)
        t_opt = torch.optim.SGD([torch.ones([1])], 1.0)
        j_scheduler = jt.lr_scheduler.ReduceLROnPlateau(j_opt)
        t_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(t_opt)
        for i in range(100):
            loss = random.random()
            j_scheduler.step(loss)
            t_scheduler.step(loss)
            assert j_opt.lr == t_opt.state_dict()['param_groups'][0]['lr']

if __name__ == "__main__":
    unittest.main()