diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 3fc495ae..8a177685 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -478,7 +478,17 @@ def load(path): return model_dict def save(params_dict, path): - safepickle(params_dict, path) + def dfs(x): + if isinstance(x, list): + for i in range(len(x)): + x[i] = dfs(x[i]) + elif isinstance(x, dict): + for k in x: + x[k] = dfs(x[k]) + elif isinstance(x, Var): + return x.numpy() + return x + safepickle(dfs(params_dict), path) def _uniq(x): a = set() diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py index b29091a5..dda1c74a 100644 --- a/python/jittor/test/test_misc_op.py +++ b/python/jittor/test/test_misc_op.py @@ -101,5 +101,15 @@ class TestPad(unittest.TestCase): assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy()) print('pass unbind test ...') +class TestOther(unittest.TestCase): + def test_save(self): + pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}] + jt.save(pp, "/tmp/xx.pkl") + x = jt.load("/tmp/xx.pkl") + assert x[:2] == [1,2] + assert (x[2] == np.array([1,2,3])).all() + assert x[3]['a'] == [1,2,3] + assert (x[3]['b'] == np.array([1,2,3])).all() + if __name__ == "__main__": unittest.main() \ No newline at end of file