add jittor var pickle-able

This commit is contained in:
Dun Liang 2021-01-29 15:20:40 +08:00
parent 3be3a5499c
commit d33d76c791
6 changed files with 71 additions and 32 deletions

View File

@ -378,6 +378,7 @@ Var.clamp = clamp
def type_as(a, b):
return a.unary(op=b.dtype)
Var.type_as = type_as
Var.astype = Var.cast
def masked_fill(x, mask, value):
assert list(x.shape) == list(mask.shape)
@ -403,6 +404,44 @@ def argmin(x, dim, keepdims:bool=False):
return x.arg_reduce("min", dim, keepdims)
Var.argmin = argmin
def randn(*size, dtype="float32", requires_grad=True):
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
def rand(*size, dtype="float32", requires_grad=True):
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
arr = jt.random(size, dtype)
if not requires_grad: return arr.stop_grad()
return arr
def rand_like(x, dtype=None):
if dtype is None: dtype = x.dtype
return jt.random(x.shape, x.dtype)
def randn_like(x, dtype=None):
if dtype is None: dtype = x.dtype
return jt.random(x.shape, x.dtype, "normal")
def randint(low, high=None, shape=(1,), dtype="int32"):
if high is None: low, high = 0, low
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
return v.astype(dtype)
def randint_like(x, low, high=None):
return randint(low, high, x.shape, x.dtype)
def normal(mean, std, size=None, dtype="float32"):
if size is None:
if isinstance(mean, Var) and isinstance(std, Var):
assert mean.shape == std.shape
size = mean.shape
else:
if isinstance(mean, Var): size = mean.shape
if isinstance(std, Var): size = std.shape
return jt.init.gauss(size, dtype, mean, std)
def attrs(var):
return {
"is_stop_fuse": var.is_stop_fuse(),
@ -966,10 +1005,10 @@ Var.double = Var.float64
# __array__ interface is used for np.array(jt_var)
Var.__array__ = Var.numpy
# __getstate__, __setstate__, __module__ is used for pickle.dump and pickle.load
Var.__getstate__ = Var.numpy
Var.__setstate__ = Var.__init__
Var.__array_priority__ = 2000
# __reduce__, __module__ is used for pickle.dump and pickle.load
Var.__module__ = "jittor"
Var.__reduce__ = lambda self: (Var, (self.data,))
from . import nn
from . import attention
@ -981,26 +1020,3 @@ from . import numpy2cupy
from .contrib import concat
from .misc import *
from . import sparse
def randn(*size, dtype="float32", requires_grad=False):
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
def rand(*size, dtype="float32", requires_grad=False):
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
arr = jt.random(size, dtype)
if not requires_grad: return arr.stop_grad()
return arr
def normal(mean, std, size=None, dtype="float32"):
if size is None:
if isinstance(mean, Var) and isinstance(std, Var):
assert mean.shape == std.shape
size = mean.shape
else:
if isinstance(mean, Var): size = mean.shape
if isinstance(std, Var): size = std.shape
return jt.init.gauss(size, dtype, mean, std)

View File

@ -147,11 +147,11 @@ class TestArray(unittest.TestCase):
def test_pickle(self):
import pickle
a = jt.Var([1,2,3])
s = pickle.dumps(a)
a = jt.Var([1,2,3,4])
s = pickle.dumps(a, pickle.HIGHEST_PROTOCOL)
b = pickle.loads(s)
assert isinstance(b, jt.Var)
assert (b.data == [1,2,3]).all()
assert (b.data == [1,2,3,4]).all()

View File

@ -87,6 +87,18 @@ class TestRandomOp(unittest.TestCase):
def test_normal_cuda(self):
self.test_normal()
def test_other_rand(self):
a = jt.array([1.0,2.0,3.0])
b = jt.rand_like(a)
c = jt.randn_like(a)
assert b.shape == c.shape
assert b.shape == a.shape
print(b, c)
assert jt.randint(10, 20, (2000,)).min() == 10
assert jt.randint(10, 20, (2000,)).max() == 19
assert jt.randint(10, shape=(2000,)).max() == 9
assert jt.randint_like(a, 10).shape == a.shape
if __name__ == "__main__":
unittest.main()

View File

@ -15,8 +15,15 @@ from jittor.dataset.mnist import MNIST
import jittor.transform as trans
from tqdm import tqdm
class BBox:
def __init__(self, x):
self.x = x
def __eq__(self, other):
return bool((self.x == other.x).all())
def test_ring_buffer():
buffer = jt.RingBuffer(1000)
buffer = jt.RingBuffer(2000)
def test_send_recv(data):
print("test send recv", type(data))
buffer.push(data)
@ -65,6 +72,9 @@ def test_ring_buffer():
test_send_recv(jt.array(np.random.rand(10,10)))
bbox = BBox(jt.array(np.random.rand(10,10)))
test_send_recv(bbox)
expect_error(lambda: test_send_recv(np.random.rand(10,1000)))

View File

@ -121,8 +121,8 @@ struct RingBuffer {
rr = c2 << size_bit;
rr_next = rr + size;
}
CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size.";
while (rr_next > l + this->size) {
CHECKop(size,<=,this->size) << "Buffer size too small, please increase buffer size.";
wait();
}
offset = rr_next;

View File

@ -15,8 +15,9 @@ namespace jittor {
static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restrict__ offset) {
PyObjHolder pickle(PyImport_ImportModule("pickle"));
PyObjHolder dumps(PyObject_GetAttrString(pickle.obj, "dumps"));
PyObjHolder proto(PyObject_GetAttrString(pickle.obj, "HIGHEST_PROTOCOL"));
rb->push_t<uint8>(6, offset);
PyObjHolder ret(PyObject_CallFunctionObjArgs(dumps.obj, obj, nullptr));
PyObjHolder ret(PyObject_CallFunctionObjArgs(dumps.obj, obj, proto.obj, nullptr));
obj = ret.obj;
Py_ssize_t size;
char* s;