polish sub module save load

Dun Liang 2021-09-13 11:03:28 +08:00
parent 42dfaaed2e
commit 4fb462e1b9
5 changed files with 65 additions and 10 deletions

python/jittor/__init__.py

@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.3.101'
+__version__ = '1.2.3.102'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -69,6 +69,15 @@ def safeunpickle(path):
         except:
             raise RuntimeError("pytorch need to be installed when load pth format.")
         model_dict = torch.load(path, map_location=torch.device('cpu'))
+        try:
+            for k, v in model_dict.items():
+                try:
+                    if not isinstance(v, np.ndarray) and hasattr(v, "cpu"):
+                        model_dict[k] = v.cpu().detach().numpy()
+                except:
+                    pass
+        except:
+            pass
         return model_dict
     with open(path, "rb") as f:
         s = f.read()
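
Note: the try/except block added to safeunpickle above converts every tensor-like value in a loaded .pth checkpoint to a plain numpy array, so the rest of the loading path never touches torch tensors. A minimal standalone sketch of the same conversion (to_numpy_state_dict is a hypothetical helper name, not part of this commit):

    import numpy as np

    def to_numpy_state_dict(model_dict):
        # Mirrors the safeunpickle change: anything tensor-like
        # (i.e. exposing .cpu()) is detached and converted to numpy;
        # values that are already arrays pass through unchanged.
        for k, v in model_dict.items():
            if not isinstance(v, np.ndarray) and hasattr(v, "cpu"):
                model_dict[k] = v.cpu().detach().numpy()
        return model_dict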
@@ -796,12 +805,30 @@ class Module:
         return _uniq(ps)
 
     def named_parameters(self):
-        ps = self.parameters()
-        return [ (p.name(), p) for p in ps ]
+        uniq_set = set()
+        ps = {}
+        stack = []
+        def callback(parents, k, v, n):
+            stack.append(str(k))
+            dc = v.__dict__
+            if isinstance(v, nn.ParameterList):
+                dc = v.params
+            for k2, p in dc.items():
+                if isinstance(k2, str) and k2.startswith("_"): continue
+                if isinstance(p, Var):
+                    if id(p) in uniq_set: continue
+                    uniq_set.add(id(p))
+                    pname = ".".join(stack[1:]+[str(k2)])
+                    ps[pname] = p
+                    if len(pname) > len(p.name()):
+                        p.name(pname)
+        def callback_leave(parents, k, v, n):
+            stack.pop()
+        self.dfs([], None, callback, callback_leave)
+        return ps
 
     def state_dict(self):
-        ps = self.parameters()
-        return { p.name(): p for p in ps }
+        return self.named_parameters()
 
     def load_state_dict(self, params):
         self.load_parameters(params)
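
Note: named_parameters now walks the module tree with dfs, deduplicates shared variables by id, and returns a dict keyed by dotted paths relative to the module it is called on; previously it returned a flat list of (name, var) tuples built from parameters(). state_dict simply reuses it. A usage sketch mirroring the test added in test_core.py below:

    import jittor as jt

    class Net(jt.Module):
        def __init__(self):
            self.conv1 = jt.nn.Conv(3, 3, 3)

    net = Net()
    # Keys are relative to the module the call starts from.
    assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
    assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']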
@@ -1011,10 +1038,13 @@ Arguments of hook are defined as::
             >>> net.save('net.pkl')
             >>> net.load('net.pkl')
         '''
-        params = self.parameters()
+        params = self.named_parameters()
         params_dict = {}
-        for p in params:
-            params_dict[p.name()] = p.data
+        for k, v in params.items():
+            if isinstance(v, Var):
+                params_dict[k] = v.numpy()
+            else:
+                params_dict[k] = v
         safepickle(params_dict, path)
 
     def load(self, path: str):
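
Note: save now serializes the named_parameters() dict, converting each Var to numpy before pickling, so the stored keys are relative to whichever module save is called on. That is what makes saving and loading a submodule on its own work. Continuing the Net sketch above (the .pkl path is illustrative):

    net.conv1.save("conv1.pkl")   # stored keys: 'weight', 'bias'
    net.conv1.load("conv1.pkl")   # round-trips independently of the parent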

python/jittor/nn.py

@@ -1616,7 +1616,8 @@ class Sequential(Module):
             return
         parents.append(self)
         for k,v in self.layers.items():
-            v.dfs(parents, k, callback, callback_leave)
+            if isinstance(v, Module):
+                v.dfs(parents, k, callback, callback_leave)
         parents.pop()
         if callback_leave:
             callback_leave(parents, k, self, n_children)
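
Note: the isinstance guard lets a Sequential hold entries that are not Modules; dfs now skips them instead of failing on a missing dfs attribute, and such entries simply contribute no parameters. A small sketch, assuming Sequential accepts a plain callable as a layer (as the test in the last file below suggests):

    import jittor as jt

    # Mixing a Module with a plain callable; dfs skips the lambda.
    x = jt.nn.Sequential(jt.nn.ReLU(), lambda v: v * 2)
    assert list(x.keys()) == [0, 1]
    assert len(x.parameters()) == 0   # neither entry owns parameters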

python/jittor/test/test_core.py

@@ -7,6 +7,7 @@
 import unittest
 import jittor as jt
 import numpy as np
+import os
 
 def expect_error(func):
     try:
@@ -86,6 +87,17 @@ class TestCore(unittest.TestCase):
         c = np.matmul(a, b)
         jtc = jt.matmul(jt.array(a), jt.array(b)).data
         assert np.all(jtc == c)
 
+    def test_save_load_sub_module(self):
+        class Net(jt.Module):
+            def __init__(self):
+                self.conv1 = jt.nn.Conv(3,3,3)
+        net = Net()
+        assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
+        assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']
+        pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl")
+        net.conv1.save(pkl_name)
+        net.conv1.load(pkl_name)
+
 if __name__ == "__main__":
     unittest.main()

python/jittor/test/test_load_pth.py

@@ -47,7 +47,17 @@ class TestLoadPth(unittest.TestCase):
         jt_out = jt_model(jt_img)
         torch_out = torch_model(torch_img)
         print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
-        assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-4
+        assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3
+
+        pth_name = os.path.join(jt.flags.cache_path, "x.pth")
+        torch.save(torch_model.state_dict(), pth_name)
+        jt_model.load(pth_name)
+
+        # output
+        jt_out = jt_model(jt_img)
+        # torch_out = torch_model(torch_img)
+        print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
+        assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3
+
 if __name__ == "__main__":
     unittest.main()
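
Note: the extended test round-trips a PyTorch checkpoint through Jittor and relaxes the output tolerance from 1e-4 to 1e-3. The round trip works because safeunpickle (first file above) converts each tensor in the .pth state dict to numpy before parameters are matched by name. In outline (names as in the test, not standalone code):

    # torch.save(torch_model.state_dict(), pth_name)   # write the .pth checkpoint
    # jt_model.load(pth_name)                          # safeunpickle turns tensors into
    #                                                  # numpy arrays, then loads by name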

View File

@@ -165,6 +165,8 @@ jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
             n += 1
         assert n == 2
         assert list(x.keys()) == [0,1]
+        p = x.parameters()
+        assert len(p)==0
 
     # def test_res2net(self):
     #     import jittor.models