mirror of https://github.com/Jittor/Jittor
polish sub module save load
This commit is contained in:
parent
42dfaaed2e
commit
4fb462e1b9
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.101'
|
||||
__version__ = '1.2.3.102'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -69,6 +69,15 @@ def safeunpickle(path):
|
|||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
model_dict = torch.load(path, map_location=torch.device('cpu'))
|
||||
try:
|
||||
for k, v in model_dict.items():
|
||||
try:
|
||||
if not isinstance(v, np.ndarray) and hasattr(v, "cpu"):
|
||||
model_dict[k] = v.cpu().detach().numpy()
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
return model_dict
|
||||
with open(path, "rb") as f:
|
||||
s = f.read()
|
||||
|
@ -796,12 +805,30 @@ class Module:
|
|||
return _uniq(ps)
|
||||
|
||||
def named_parameters(self):
|
||||
ps = self.parameters()
|
||||
return [ (p.name(), p) for p in ps ]
|
||||
uniq_set = set()
|
||||
ps = {}
|
||||
stack = []
|
||||
def callback(parents, k, v, n):
|
||||
stack.append(str(k))
|
||||
dc = v.__dict__
|
||||
if isinstance(v, nn.ParameterList):
|
||||
dc = v.params
|
||||
for k2, p in dc.items():
|
||||
if isinstance(k2, str) and k2.startswith("_"): continue
|
||||
if isinstance(p, Var):
|
||||
if id(p) in uniq_set: continue
|
||||
uniq_set.add(id(p))
|
||||
pname = ".".join(stack[1:]+[str(k2)])
|
||||
ps[pname] = p
|
||||
if len(pname) > len(p.name()):
|
||||
p.name(pname)
|
||||
def callback_leave(parents, k, v, n):
|
||||
stack.pop()
|
||||
self.dfs([], None, callback, callback_leave)
|
||||
return ps
|
||||
|
||||
def state_dict(self):
|
||||
ps = self.parameters()
|
||||
return { p.name(): p for p in ps }
|
||||
return self.named_parameters()
|
||||
|
||||
def load_state_dict(self, params):
|
||||
self.load_parameters(params)
|
||||
|
@ -1011,10 +1038,13 @@ Arguments of hook are defined as::
|
|||
>>> net.save('net.pkl')
|
||||
>>> net.load('net.pkl')
|
||||
'''
|
||||
params = self.parameters()
|
||||
params = self.named_parameters()
|
||||
params_dict = {}
|
||||
for p in params:
|
||||
params_dict[p.name()] = p.data
|
||||
for k, v in params.items():
|
||||
if isinstance(v, Var):
|
||||
params_dict[k] = v.numpy()
|
||||
else:
|
||||
params_dict[k] = v
|
||||
safepickle(params_dict, path)
|
||||
|
||||
def load(self, path: str):
|
||||
|
|
|
@ -1616,7 +1616,8 @@ class Sequential(Module):
|
|||
return
|
||||
parents.append(self)
|
||||
for k,v in self.layers.items():
|
||||
v.dfs(parents, k, callback, callback_leave)
|
||||
if isinstance(v, Module):
|
||||
v.dfs(parents, k, callback, callback_leave)
|
||||
parents.pop()
|
||||
if callback_leave:
|
||||
callback_leave(parents, k, self, n_children)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def expect_error(func):
|
||||
try:
|
||||
|
@ -86,6 +87,17 @@ class TestCore(unittest.TestCase):
|
|||
c = np.matmul(a, b)
|
||||
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||
assert np.all(jtc == c)
|
||||
|
||||
def test_save_load_sub_module(self):
|
||||
class Net(jt.Module):
|
||||
def __init__(self):
|
||||
self.conv1 = jt.nn.Conv(3,3,3)
|
||||
net = Net()
|
||||
assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
|
||||
assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']
|
||||
pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl")
|
||||
net.conv1.save(pkl_name)
|
||||
net.conv1.load(pkl_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -47,7 +47,17 @@ class TestLoadPth(unittest.TestCase):
|
|||
jt_out = jt_model(jt_img)
|
||||
torch_out = torch_model(torch_img)
|
||||
print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
|
||||
assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-4
|
||||
assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3
|
||||
|
||||
pth_name = os.path.join(jt.flags.cache_path, "x.pth")
|
||||
torch.save(torch_model.state_dict, pth_name)
|
||||
jt_model.load(pth_name)
|
||||
|
||||
# output
|
||||
jt_out = jt_model(jt_img)
|
||||
# torch_out = torch_model(torch_img)
|
||||
print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
|
||||
assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -165,6 +165,8 @@ jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
|
|||
n += 1
|
||||
assert n == 2
|
||||
assert list(x.keys()) == [0,1]
|
||||
p = x.parameters()
|
||||
assert len(p)==0
|
||||
|
||||
# def test_res2net(self):
|
||||
# import jittor.models
|
||||
|
|
Loading…
Reference in New Issue