mirror of https://github.com/Jittor/Jittor
add pretrained weights
This commit is contained in:
parent
ac052171ce
commit
fa64a66dad
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.1'
|
||||
__version__ = '1.2.2.2'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -33,9 +33,38 @@ from collections import OrderedDict
|
|||
from collections.abc import Sequence, Mapping
|
||||
import types
|
||||
import pickle
|
||||
import sys
|
||||
import hashlib
|
||||
import sys, os
|
||||
import traceback
|
||||
|
||||
|
||||
def safepickle(obj, path):
    """Pickle *obj* to *path* with an integrity trailer.

    The on-disk layout is: pickled payload, a 20-byte sha1 digest of that
    payload, then the 8-byte magic tag b"HCAJSLHD" that safeunpickle uses
    to recognize checksummed files.
    """
    payload = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    digest = hashlib.sha1(payload).digest()
    blob = payload + bytes(digest) + b"HCAJSLHD"
    with open(path, 'wb') as f:
        f.write(blob)
|
||||
|
||||
def safeunpickle(path):
    """Load an object written by safepickle, verifying its sha1 trailer.

    A "jittorhub://" prefix is rewritten to the Jittor checkpoint CDN, and
    http(s) URLs are downloaded into compiler.ck_path first.  Files without
    the b"HCAJSLHD" magic tag fall back to a plain (unverified) unpickle;
    a checksum mismatch raises ValueError.
    """
    if path.startswith("jittorhub://"):
        path = path.replace("jittorhub://", "https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/")
    if path.startswith("https:") or path.startswith("http:"):
        base = path.split("/")[-1]
        local_path = os.path.join(compiler.ck_path, base)
        from jittor.utils.misc import download_url_to_local
        download_url_to_local(path, base, compiler.ck_path, None)
        path = local_path
    with open(path, "rb") as f:
        blob = f.read()
    if not blob.endswith(b"HCAJSLHD"):
        # NOTE(review): pickle.loads on data without a checksum trailer is
        # unsafe for untrusted files -- kept for backward compatibility.
        return pickle.loads(blob)
    # Trailer layout: [payload][20-byte sha1][8-byte magic tag]
    payload = blob[:-28]
    stored_digest = blob[-28:-8]
    if hashlib.sha1(payload).digest() != stored_digest:
        raise ValueError("Pickle checksum does not match! path: "+path)
    return pickle.loads(payload)
|
||||
|
||||
class _call_no_record_scope:
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, *exc): pass
|
||||
|
@ -436,8 +465,7 @@ def display_memory_info():
|
|||
core.display_memory_info(fileline)
|
||||
|
||||
def load(path):
|
||||
pkl_file = open(path, 'rb')
|
||||
model_dict = pickle.load(pkl_file)
|
||||
model_dict = safeunpickle(path)
|
||||
return model_dict
|
||||
|
||||
def _uniq(x):
|
||||
|
@ -647,8 +675,7 @@ class Module:
|
|||
params_dict = {}
|
||||
for p in params:
|
||||
params_dict[p.name()] = p.data
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
|
||||
safepickle(params_dict, path)
|
||||
|
||||
def load(self, path):
|
||||
if path.endswith(".pth"):
|
||||
|
@ -659,8 +686,7 @@ class Module:
|
|||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
|
||||
return
|
||||
with open(path, 'rb') as f:
|
||||
self.load_parameters(pickle.load(f))
|
||||
self.load_parameters(safeunpickle(path))
|
||||
|
||||
def eval(self):
|
||||
def callback(parents, k, v, n):
|
||||
|
|
|
@ -894,6 +894,8 @@ make_cache_dir(cache_path)
|
|||
make_cache_dir(os.path.join(cache_path, "jit"))
|
||||
make_cache_dir(os.path.join(cache_path, "obj_files"))
|
||||
make_cache_dir(os.path.join(cache_path, "gen"))
|
||||
ck_path = os.path.join(cache_path, "checkpoints")
|
||||
make_cache_dir(ck_path)
|
||||
|
||||
# build cache_compile
|
||||
cc_flags += f" -I{jittor_path}/src "
|
||||
|
|
|
@ -61,6 +61,7 @@ class AlexNet(nn.Module):
|
|||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def alexnet(**kwargs):
|
||||
def alexnet(pretrained=False, **kwargs):
|
||||
model = AlexNet(**kwargs)
|
||||
if pretrained: model.load("jittorhub://alexnet.pkl")
|
||||
return model
|
||||
|
|
|
@ -21,7 +21,7 @@ def densenet121(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet121.pkl")
|
||||
return model
|
||||
|
||||
def densenet161(pretrained=False, **kwargs):
|
||||
|
@ -32,7 +32,7 @@ def densenet161(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet161.pkl")
|
||||
return model
|
||||
|
||||
def densenet169(pretrained=False, **kwargs):
|
||||
|
@ -43,7 +43,7 @@ def densenet169(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet169.pkl")
|
||||
return model
|
||||
|
||||
def densenet201(pretrained=False, **kwargs):
|
||||
|
@ -54,7 +54,7 @@ def densenet201(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet201.pkl")
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -12,8 +12,10 @@ from jittor import nn
|
|||
|
||||
__all__ = ['GoogLeNet', 'googlenet']
|
||||
|
||||
def googlenet(**kwargs):
|
||||
return GoogLeNet(**kwargs)
|
||||
def googlenet(pretrained=False, **kwargs):
|
||||
model = GoogLeNet(**kwargs)
|
||||
if pretrained: model.load("jittorhub://googlenet.pkl")
|
||||
return model
|
||||
|
||||
class GoogLeNet(nn.Module):
|
||||
""" GoogLeNet model architecture.
|
||||
|
|
|
@ -4,7 +4,9 @@ from jittor import nn
|
|||
__all__ = ['Inception3', 'inception_v3']
|
||||
|
||||
def inception_v3(pretrained=False, progress=True, **kwargs):
|
||||
return Inception3(**kwargs)
|
||||
model = Inception3(**kwargs)
|
||||
if pretrained: model.load("jittorhub://inception_v3.pkl")
|
||||
return model
|
||||
|
||||
class Inception3(nn.Module):
|
||||
""" Inceptionv3 model architecture.
|
||||
|
|
|
@ -90,18 +90,22 @@ class MNASNet(nn.Module):
|
|||
x = x.mean([2, 3])
|
||||
return self.classifier(x)
|
||||
|
||||
def mnasnet0_5(**kwargs):
|
||||
def mnasnet0_5(pretrained=False, **kwargs):
|
||||
model = MNASNet(0.5, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet0_5.pkl")
|
||||
return model
|
||||
|
||||
def mnasnet0_75(**kwargs):
|
||||
def mnasnet0_75(pretrained=False, **kwargs):
|
||||
model = MNASNet(0.75, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet0_75.pkl")
|
||||
return model
|
||||
|
||||
def mnasnet1_0(**kwargs):
|
||||
def mnasnet1_0(pretrained=False, **kwargs):
|
||||
model = MNASNet(1.0, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet1_0.pkl")
|
||||
return model
|
||||
|
||||
def mnasnet1_3(**kwargs):
|
||||
def mnasnet1_3(pretrained=False, **kwargs):
|
||||
model = MNASNet(1.3, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet1_3.pkl")
|
||||
return model
|
||||
|
|
|
@ -93,7 +93,8 @@ class MobileNetV2(nn.Module):
|
|||
def execute(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def mobilenet_v2():
|
||||
def mobilenet_v2(pretrained=False):
|
||||
model = MobileNetV2()
|
||||
if pretrained: model.load("jittorhub://mobilenet_v2.pkl")
|
||||
return model
|
||||
|
||||
|
|
|
@ -154,19 +154,26 @@ def _resnet(block, layers, **kwargs):
|
|||
model = ResNet(block, layers, **kwargs)
|
||||
return model
|
||||
|
||||
def Resnet18(**kwargs):
|
||||
return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
def Resnet18(pretrained=False, **kwargs):
|
||||
model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet18.pkl")
|
||||
return model
|
||||
resnet18 = Resnet18
|
||||
|
||||
def Resnet34(**kwargs):
|
||||
return _resnet( BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
def Resnet34(pretrained=False, **kwargs):
|
||||
model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet34.pkl")
|
||||
return model
|
||||
resnet34 = Resnet34
|
||||
|
||||
def Resnet50(**kwargs):
|
||||
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
def Resnet50(pretrained=False, **kwargs):
|
||||
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet50.pkl")
|
||||
return model
|
||||
|
||||
resnet50 = Resnet50
|
||||
|
||||
def Resnet101(**kwargs):
|
||||
def Resnet101(pretrained=False, **kwargs):
|
||||
"""
|
||||
ResNet-101 model architecture.
|
||||
|
||||
|
@ -180,28 +187,38 @@ def Resnet101(**kwargs):
|
|||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
resnet101 = Resnet101
|
||||
|
||||
def Resnet152(**kwargs):
|
||||
return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
def Resnet152(pretrained=False, **kwargs):
|
||||
model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet152.pkl")
|
||||
return model
|
||||
resnet152 = Resnet152
|
||||
|
||||
def Resnext50_32x4d(**kwargs):
|
||||
def Resnext50_32x4d(pretrained=False, **kwargs):
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnext50_32x4d.pkl")
|
||||
return model
|
||||
resnext50_32x4d = Resnext50_32x4d
|
||||
|
||||
def Resnext101_32x8d(**kwargs):
|
||||
def Resnext101_32x8d(pretrained=False, **kwargs):
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnext101_32x8d.pkl")
|
||||
return model
|
||||
resnext101_32x8d = Resnext101_32x8d
|
||||
|
||||
def Wide_resnet50_2(**kwargs):
|
||||
def Wide_resnet50_2(pretrained=False, **kwargs):
|
||||
kwargs['width_per_group'] = (64 * 2)
|
||||
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://wide_resnet50_2.pkl")
|
||||
return model
|
||||
wide_resnet50_2 = Wide_resnet50_2
|
||||
|
||||
def Wide_resnet101_2(**kwargs):
|
||||
def Wide_resnet101_2(pretrained=False, **kwargs):
|
||||
kwargs['width_per_group'] = (64 * 2)
|
||||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://wide_resnet101_2.pkl")
|
||||
return model
|
||||
wide_resnet101_2 = Wide_resnet101_2
|
|
@ -93,14 +93,22 @@ def _shufflenetv2(arch, *args):
|
|||
model = ShuffleNetV2(*args)
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x0_5():
|
||||
return _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
|
||||
def shufflenet_v2_x0_5(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x0_5.pkl")
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x1_0():
|
||||
return _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
|
||||
def shufflenet_v2_x1_0(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x1_0.pkl")
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x1_5():
|
||||
return _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
|
||||
def shufflenet_v2_x1_5(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x1_5.pkl")
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x2_0():
|
||||
return _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
|
||||
def shufflenet_v2_x2_0(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x2_0.pkl")
|
||||
return model
|
||||
|
|
|
@ -83,8 +83,12 @@ def _squeezenet(version, **kwargs):
|
|||
model = SqueezeNet(version, **kwargs)
|
||||
return model
|
||||
|
||||
def squeezenet1_0(**kwargs):
|
||||
return _squeezenet('1_0', **kwargs)
|
||||
def squeezenet1_0(pretrained=False, **kwargs):
|
||||
model = _squeezenet('1_0', **kwargs)
|
||||
if pretrained: model.load("jittorhub://squeezenet1_0.pkl")
|
||||
return model
|
||||
|
||||
def squeezenet1_1(**kwargs):
|
||||
return _squeezenet('1_1', **kwargs)
|
||||
def squeezenet1_1(pretrained=False, **kwargs):
|
||||
model = _squeezenet('1_1', **kwargs)
|
||||
if pretrained: model.load("jittorhub://squeezenet1_1.pkl")
|
||||
return model
|
||||
|
|
|
@ -67,33 +67,49 @@ def _vgg(arch, cfg, batch_norm, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def vgg11(**kwargs):
|
||||
return _vgg('vgg11', 'A', False, **kwargs)
|
||||
def vgg11(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg11', 'A', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg11.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg11_bn(**kwargs):
|
||||
return _vgg('vgg11_bn', 'A', True, **kwargs)
|
||||
def vgg11_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg11_bn', 'A', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg11_bn.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg13(**kwargs):
|
||||
return _vgg('vgg13', 'B', False, **kwargs)
|
||||
def vgg13(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg13', 'B', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg13.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg13_bn(**kwargs):
|
||||
return _vgg('vgg13_bn', 'B', True, **kwargs)
|
||||
def vgg13_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg13_bn', 'B', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg13_bn.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg16(**kwargs):
|
||||
return _vgg('vgg16', 'D', False, **kwargs)
|
||||
def vgg16(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg16', 'D', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg16.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg16_bn(**kwargs):
|
||||
return _vgg('vgg16_bn', 'D', True, **kwargs)
|
||||
def vgg16_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg16_bn', 'D', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg16_bn.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg19(**kwargs):
|
||||
return _vgg('vgg19', 'E', False, **kwargs)
|
||||
def vgg19(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg19', 'E', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg19.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg19_bn(**kwargs):
|
||||
return _vgg('vgg19_bn', 'E', True, **kwargs)
|
||||
def vgg19_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg19_bn', 'E', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg19_bn.pkl")
|
||||
return model
|
|
@ -38,7 +38,7 @@ def download_url_to_local(url, filename, root_folder, md5):
|
|||
ensure_dir(root_folder)
|
||||
file_path = os.path.join(root_folder, filename)
|
||||
if check_file_exist(file_path, md5):
|
||||
print("Data file has been downloaded and verified")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
print('Downloading ' + url + ' to ' + file_path)
|
||||
|
|
|
@ -91,7 +91,8 @@ void PassManager::run_passes() {
|
|||
|
||||
run_pass<SolveConflictDefinePass>();
|
||||
run_pass<MergeLoopVarPass>();
|
||||
run_pass<ConstVarPass>();
|
||||
// tmp disable ConstVarPass
|
||||
// run_pass<ConstVarPass>();
|
||||
|
||||
run_pass<RestridePass>();
|
||||
|
||||
|
|
Loading…
Reference in New Issue