mirror of https://github.com/Jittor/Jittor
add pickle and array interface
This commit is contained in:
parent
7bfc871b1b
commit
3d56c776ba
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.20'
|
||||
__version__ = '1.2.2.21'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -17,6 +17,7 @@ with lock.lock_scope():
|
|||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
import jittor_core
|
||||
import jittor_core as core
|
||||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
|
@ -963,6 +964,13 @@ Var.float = Var.float32
|
|||
double = float64
|
||||
Var.double = Var.float64
|
||||
|
||||
# __array__ interface is used for np.array(jt_var)
|
||||
Var.__array__ = Var.numpy
|
||||
# __getstate__, __setstate__, __module__ is used for pickle.dump and pickle.load
|
||||
Var.__getstate__ = Var.numpy
|
||||
Var.__setstate__ = lambda self, x: self.__init__(x)
|
||||
Var.__module__ = "jittor"
|
||||
|
||||
from . import nn
|
||||
from . import attention
|
||||
from . import lr_scheduler
|
||||
|
@ -995,4 +1003,4 @@ def normal(mean, std, size=None, dtype="float32"):
|
|||
else:
|
||||
if isinstance(mean, Var): size = mean.shape
|
||||
if isinstance(std, Var): size = std.shape
|
||||
return jt.init.gauss(size, dtype, mean, std)
|
||||
return jt.init.gauss(size, dtype, mean, std)
|
||||
|
|
|
@ -140,6 +140,19 @@ class TestArray(unittest.TestCase):
|
|||
assert (a.numpy() == [1,2,3]).all()
|
||||
assert (b.numpy() == [1,2,3]).all()
|
||||
|
||||
def test_np_array(self):
|
||||
a = jt.Var([1,2,3])
|
||||
b = np.array(a)
|
||||
assert (b==[1,2,3]).all()
|
||||
|
||||
def test_pickle(self):
|
||||
import pickle
|
||||
a = jt.Var([1,2,3])
|
||||
s = pickle.dumps(a)
|
||||
b = pickle.loads(s)
|
||||
assert isinstance(b, jt.Var)
|
||||
assert (b.data == [1,2,3]).all()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue