add pickle and array interface

This commit is contained in:
Dun Liang 2021-01-28 14:24:48 +08:00
parent 7bfc871b1b
commit 3d56c776ba
2 changed files with 23 additions and 2 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.20'
__version__ = '1.2.2.21'
from . import lock
with lock.lock_scope():
ori_int = int
@ -17,6 +17,7 @@ with lock.lock_scope():
from . import compiler
from .compiler import LOG, has_cuda
from .compiler import compile_custom_ops, compile_custom_op
import jittor_core
import jittor_core as core
from jittor_core import *
from jittor_core.ops import *
@ -963,6 +964,13 @@ Var.float = Var.float32
double = float64
Var.double = Var.float64
# __array__ interface is used for np.array(jt_var)
Var.__array__ = Var.numpy
# __getstate__, __setstate__, __module__ is used for pickle.dump and pickle.load
Var.__getstate__ = Var.numpy
Var.__setstate__ = lambda self, x: self.__init__(x)
Var.__module__ = "jittor"
from . import nn
from . import attention
from . import lr_scheduler
@ -995,4 +1003,4 @@ def normal(mean, std, size=None, dtype="float32"):
else:
if isinstance(mean, Var): size = mean.shape
if isinstance(std, Var): size = std.shape
return jt.init.gauss(size, dtype, mean, std)
return jt.init.gauss(size, dtype, mean, std)

View File

@ -140,6 +140,19 @@ class TestArray(unittest.TestCase):
assert (a.numpy() == [1,2,3]).all()
assert (b.numpy() == [1,2,3]).all()
def test_np_array(self):
a = jt.Var([1,2,3])
b = np.array(a)
assert (b==[1,2,3]).all()
def test_pickle(self):
import pickle
a = jt.Var([1,2,3])
s = pickle.dumps(a)
b = pickle.loads(s)
assert isinstance(b, jt.Var)
assert (b.data == [1,2,3]).all()
if __name__ == "__main__":