polish save

This commit is contained in:
Dun Liang 2020-12-25 16:35:12 +08:00
parent cd085765a0
commit 1acda62f59
2 changed files with 21 additions and 1 deletions

View File

@ -478,7 +478,17 @@ def load(path):
return model_dict return model_dict
def save(params_dict, path): def save(params_dict, path):
safepickle(params_dict, path) def dfs(x):
if isinstance(x, list):
for i in range(len(x)):
x[i] = dfs(x[i])
elif isinstance(x, dict):
for k in x:
x[k] = dfs(x[k])
elif isinstance(x, Var):
return x.numpy()
return x
safepickle(dfs(params_dict), path)
def _uniq(x): def _uniq(x):
a = set() a = set()

View File

@ -101,5 +101,15 @@ class TestPad(unittest.TestCase):
assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy()) assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy())
print('pass unbind test ...') print('pass unbind test ...')
class TestOther(unittest.TestCase):
def test_save(self):
pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]
jt.save(pp, "/tmp/xx.pkl")
x = jt.load("/tmp/xx.pkl")
assert x[:2] == [1,2]
assert (x[2] == np.array([1,2,3])).all()
assert x[3]['a'] == [1,2,3]
assert (x[3]['b'] == np.array([1,2,3])).all()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()