mirror of https://github.com/Jittor/Jittor
polish save
This commit is contained in:
parent
cd085765a0
commit
1acda62f59
|
@ -478,7 +478,17 @@ def load(path):
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
def save(params_dict, path):
|
def save(params_dict, path):
|
||||||
safepickle(params_dict, path)
|
def dfs(x):
|
||||||
|
if isinstance(x, list):
|
||||||
|
for i in range(len(x)):
|
||||||
|
x[i] = dfs(x[i])
|
||||||
|
elif isinstance(x, dict):
|
||||||
|
for k in x:
|
||||||
|
x[k] = dfs(x[k])
|
||||||
|
elif isinstance(x, Var):
|
||||||
|
return x.numpy()
|
||||||
|
return x
|
||||||
|
safepickle(dfs(params_dict), path)
|
||||||
|
|
||||||
def _uniq(x):
|
def _uniq(x):
|
||||||
a = set()
|
a = set()
|
||||||
|
|
|
@ -101,5 +101,15 @@ class TestPad(unittest.TestCase):
|
||||||
assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy())
|
assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy())
|
||||||
print('pass unbind test ...')
|
print('pass unbind test ...')
|
||||||
|
|
||||||
|
class TestOther(unittest.TestCase):
|
||||||
|
def test_save(self):
|
||||||
|
pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]
|
||||||
|
jt.save(pp, "/tmp/xx.pkl")
|
||||||
|
x = jt.load("/tmp/xx.pkl")
|
||||||
|
assert x[:2] == [1,2]
|
||||||
|
assert (x[2] == np.array([1,2,3])).all()
|
||||||
|
assert x[3]['a'] == [1,2,3]
|
||||||
|
assert (x[3]['b'] == np.array([1,2,3])).all()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
Loading…
Reference in New Issue