add res2net pretrain

This commit is contained in:
Dun Liang 2021-04-03 15:30:03 +08:00
parent b031e5d617
commit c7b78f570e
3 changed files with 166 additions and 111 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.57'
__version__ = '1.2.2.58'
from . import lock
with lock.lock_scope():
ori_int = int
@ -59,6 +59,14 @@ def safeunpickle(path):
from jittor.utils.misc import download_url_to_local
download_url_to_local(path, base, compiler.ck_path, None)
path = fname
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
model_dict = torch.load(path, map_location=torch.device('cpu'))
return model_dict
with open(path, "rb") as f:
s = f.read()
if not s.endswith(b"HCAJSLHD"):
@ -680,15 +688,7 @@ def display_memory_info():
def load(path: str):
''' loads an object from a file.
'''
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
model_dict = torch.load(path, map_location=torch.device('cpu'))
else:
model_dict = safeunpickle(path)
model_dict = safeunpickle(path)
return model_dict
def save(params_dict, path: str):
@ -725,7 +725,7 @@ class Module:
def __init__(self, *args, **kw):
pass
def execute(self, *args, **kw):
raise NotImplementedError
raise NotImplementedError("Please implement 'execute' method of "+str(type(self)))
def __call__(self, *args, **kw):
return self.execute(*args, **kw)
def __repr__(self):

View File

@ -5,10 +5,21 @@ from jittor import init
from jittor.contrib import concat, argmax_pool
import math
model_urls = {
'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
}
class Bottle2neck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
""" Constructor
Args:
inplanes: input channel dimensionality
@ -22,31 +33,31 @@ class Bottle2neck(nn.Module):
super(Bottle2neck, self).__init__()
width = int(math.floor(planes * (baseWidth/64.0)))
self.conv1 = nn.Conv(inplanes, width*scale, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm(width*scale)
assert scale > 1, 'Res2Net degenerates to ResNet when scales = 1.'
self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
if scale == 1:
self.nums = 1
else:
self.nums = scale -1
if stype == 'stage':
self.pool = nn.Pool(kernel_size=3, stride = stride, padding=1, op='mean')
self.convs = nn.ModuleList()
self.bns = nn.ModuleList()
self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
convs = []
bns = []
for i in range(self.nums):
self.convs.append(nn.Conv(width, width, kernel_size=3, stride = stride, dilation=dilation, padding=dilation, bias=False))
self.bns.append(nn.BatchNorm(width))
convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.conv3 = nn.Conv(width*scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm(planes * self.expansion)
self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = downsample
self.stype = stype
self.scale = scale
self.width = width
self.stride = stride
self.dilation = dilation
def execute(self, x):
residual = x
@ -54,22 +65,23 @@ class Bottle2neck(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = out
outs = []
spx = jt.split(out, self.width, 1)
for i in range(self.nums):
if i==0 or self.stype=='stage':
sp = spx[:, i*self.width: (i+1)*self.width]
else:
sp = sp + spx[:, i*self.width: (i+1)*self.width]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
outs.append(sp)
if self.stype=='normal' or self.stride==1:
outs.append(spx[:, self.nums*self.width: (self.nums+1)*self.width])
elif self.stype=='stage':
outs.append(self.pool(spx[:, self.nums*self.width: (self.nums+1)*self.width]))
out = concat(outs, 1)
if i==0 or self.stype=='stage':
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = jt.concat((out, sp), 1)
if self.scale != 1 and self.stype=='normal':
out = jt.concat((out, spx[self.nums]),1)
elif self.scale != 1 and self.stype=='stage':
out = jt.concat((out, self.pool(spx[self.nums])),1)
out = self.conv3(out)
out = self.bn3(out)
@ -82,103 +94,138 @@ class Bottle2neck(nn.Module):
return out
class Res2Net(nn.Module):
class Res2Net(Module):
def __init__(self, block, layers, output_stride, baseWidth = 26, scale = 4):
def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
self.inplanes = 64
super(Res2Net, self).__init__()
self.baseWidth = baseWidth
self.scale = scale
self.inplanes = 64
blocks = [1, 2, 4]
if output_stride == 16:
strides = [1, 2, 2, 1]
dilations = [1, 1, 1, 2]
elif output_stride == 8:
strides = [1, 2, 1, 1]
dilations = [1, 1, 2, 4]
else:
raise NotImplementedError
# Modules
self.conv1 = nn.Sequential(
nn.Conv(3, 32, 3, 2, 1, bias=False),
nn.BatchNorm(32),
nn.ReLU(),
nn.Conv(32, 32, 3, 1, 1, bias=False),
nn.BatchNorm(32),
nn.ReLU(),
nn.Conv(32, 64, 3, 1, 1, bias=False)
)
self.bn1 = nn.BatchNorm(64)
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
# self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Pool(kernel_size=stride, stride=stride,
ceil_mode=True, op='mean'),
nn.Conv(self.inplanes, planes * block.expansion,
kernel_size=1, stride=1, bias=False),
nn.BatchNorm(planes * block.expansion),
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, dilation, downsample,
layers.append(block(self.inplanes, planes, stride, downsample=downsample,
stype='stage', baseWidth = self.baseWidth, scale=self.scale))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation, baseWidth = self.baseWidth, scale=self.scale))
layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))
return nn.Sequential(*layers)
def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Pool(kernel_size=stride, stride=stride,
ceil_mode=True, op='mean'),
nn.Conv(self.inplanes, planes * block.expansion,
kernel_size=1, stride=1, bias=False),
nn.BatchNorm(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
downsample=downsample, stype='stage', baseWidth = self.baseWidth, scale=self.scale))
self.inplanes = planes * block.expansion
for i in range(1, len(blocks)):
layers.append(block(self.inplanes, planes, stride=1,
dilation=blocks[i]*dilation, baseWidth = self.baseWidth, scale=self.scale))
return nn.Sequential(*layers)
def execute(self, input):
x = self.conv1(input)
def execute(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = argmax_pool(x, 2, 2)
x = self.maxpool(x)
x = self.layer1(x)
low_level_feat = x
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x, low_level_feat
def res2net50(output_stride=16):
model = Res2Net(Bottle2neck, [3,4,6,3], output_stride)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def res2net50(pretrained=False, **kwargs):
"""Constructs a Res2Net-50 model.
Res2Net-50 refers to the Res2Net-50_26w_4s.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
if pretrained:
model.load(model_urls['res2net50_26w_4s'])
return model
def res2net101(output_stride=16):
model = Res2Net(Bottle2neck, [3,4,23,3], output_stride)
def res2net50_26w_4s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
if pretrained:
model.load(model_urls['res2net50_26w_4s'])
return model
def res2net101_26w_4s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
if pretrained:
model.load(model_urls['res2net101_26w_4s'])
return model
res2net101 = res2net101_26w_4s
def res2net50_26w_6s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs)
if pretrained:
model.load(model_urls['res2net50_26w_6s'])
return model
def res2net50_26w_8s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50_26w_4s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs)
if pretrained:
model.load(model_urls['res2net50_26w_8s'])
return model
def res2net50_48w_2s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50_48w_2s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs)
if pretrained:
model.load(model_urls['res2net50_48w_2s'])
return model
def res2net50_14w_8s(pretrained=False, **kwargs):
"""Constructs a Res2Net-50_14w_8s model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs)
if pretrained:
model.load(model_urls['res2net50_14w_8s'])
return model

View File

@ -158,6 +158,14 @@ jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
assert n == 2
assert list(x.keys()) == [0,1]
# def test_res2net(self):
# import jittor.models
# net = jittor.models.res2net50(True)
# img = jt.random((2,3,224,224))
# out = net(img)
# print(out.shape, out.sum())
# assert out.shape == [2,1000]
if __name__ == "__main__":
unittest.main()