mirror of https://github.com/Jittor/Jittor
polish save
This commit is contained in:
parent
cd085765a0
commit
1acda62f59
|
@ -478,7 +478,17 @@ def load(path):
|
|||
return model_dict
|
||||
|
||||
def save(params_dict, path):
|
||||
safepickle(params_dict, path)
|
||||
def dfs(x):
|
||||
if isinstance(x, list):
|
||||
for i in range(len(x)):
|
||||
x[i] = dfs(x[i])
|
||||
elif isinstance(x, dict):
|
||||
for k in x:
|
||||
x[k] = dfs(x[k])
|
||||
elif isinstance(x, Var):
|
||||
return x.numpy()
|
||||
return x
|
||||
safepickle(dfs(params_dict), path)
|
||||
|
||||
def _uniq(x):
|
||||
a = set()
|
||||
|
|
|
@ -101,5 +101,15 @@ class TestPad(unittest.TestCase):
|
|||
assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy())
|
||||
print('pass unbind test ...')
|
||||
|
||||
class TestOther(unittest.TestCase):
|
||||
def test_save(self):
|
||||
pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]
|
||||
jt.save(pp, "/tmp/xx.pkl")
|
||||
x = jt.load("/tmp/xx.pkl")
|
||||
assert x[:2] == [1,2]
|
||||
assert (x[2] == np.array([1,2,3])).all()
|
||||
assert x[3]['a'] == [1,2,3]
|
||||
assert (x[3]['b'] == np.array([1,2,3])).all()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue