mirror of https://github.com/Jittor/Jittor
add jittor var pickle-able
This commit is contained in:
parent
3be3a5499c
commit
d33d76c791
|
@ -378,6 +378,7 @@ Var.clamp = clamp
|
|||
def type_as(a, b):
|
||||
return a.unary(op=b.dtype)
|
||||
Var.type_as = type_as
|
||||
Var.astype = Var.cast
|
||||
|
||||
def masked_fill(x, mask, value):
|
||||
assert list(x.shape) == list(mask.shape)
|
||||
|
@ -403,6 +404,44 @@ def argmin(x, dim, keepdims:bool=False):
|
|||
return x.arg_reduce("min", dim, keepdims)
|
||||
Var.argmin = argmin
|
||||
|
||||
def randn(*size, dtype="float32", requires_grad=True):
|
||||
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
|
||||
arr = jt.random(size, dtype, "normal")
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def rand(*size, dtype="float32", requires_grad=True):
|
||||
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
|
||||
arr = jt.random(size, dtype)
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def rand_like(x, dtype=None):
|
||||
if dtype is None: dtype = x.dtype
|
||||
return jt.random(x.shape, x.dtype)
|
||||
|
||||
def randn_like(x, dtype=None):
|
||||
if dtype is None: dtype = x.dtype
|
||||
return jt.random(x.shape, x.dtype, "normal")
|
||||
|
||||
def randint(low, high=None, shape=(1,), dtype="int32"):
|
||||
if high is None: low, high = 0, low
|
||||
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
|
||||
return v.astype(dtype)
|
||||
|
||||
def randint_like(x, low, high=None):
|
||||
return randint(low, high, x.shape, x.dtype)
|
||||
|
||||
def normal(mean, std, size=None, dtype="float32"):
|
||||
if size is None:
|
||||
if isinstance(mean, Var) and isinstance(std, Var):
|
||||
assert mean.shape == std.shape
|
||||
size = mean.shape
|
||||
else:
|
||||
if isinstance(mean, Var): size = mean.shape
|
||||
if isinstance(std, Var): size = std.shape
|
||||
return jt.init.gauss(size, dtype, mean, std)
|
||||
|
||||
def attrs(var):
|
||||
return {
|
||||
"is_stop_fuse": var.is_stop_fuse(),
|
||||
|
@ -966,10 +1005,10 @@ Var.double = Var.float64
|
|||
|
||||
# __array__ interface is used for np.array(jt_var)
|
||||
Var.__array__ = Var.numpy
|
||||
# __getstate__, __setstate__, __module__ is used for pickle.dump and pickle.load
|
||||
Var.__getstate__ = Var.numpy
|
||||
Var.__setstate__ = Var.__init__
|
||||
Var.__array_priority__ = 2000
|
||||
# __reduce__, __module__ is used for pickle.dump and pickle.load
|
||||
Var.__module__ = "jittor"
|
||||
Var.__reduce__ = lambda self: (Var, (self.data,))
|
||||
|
||||
from . import nn
|
||||
from . import attention
|
||||
|
@ -981,26 +1020,3 @@ from . import numpy2cupy
|
|||
from .contrib import concat
|
||||
from .misc import *
|
||||
from . import sparse
|
||||
|
||||
|
||||
def randn(*size, dtype="float32", requires_grad=False):
|
||||
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
|
||||
arr = jt.random(size, dtype, "normal")
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def rand(*size, dtype="float32", requires_grad=False):
|
||||
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
|
||||
arr = jt.random(size, dtype)
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def normal(mean, std, size=None, dtype="float32"):
|
||||
if size is None:
|
||||
if isinstance(mean, Var) and isinstance(std, Var):
|
||||
assert mean.shape == std.shape
|
||||
size = mean.shape
|
||||
else:
|
||||
if isinstance(mean, Var): size = mean.shape
|
||||
if isinstance(std, Var): size = std.shape
|
||||
return jt.init.gauss(size, dtype, mean, std)
|
||||
|
|
|
@ -147,11 +147,11 @@ class TestArray(unittest.TestCase):
|
|||
|
||||
def test_pickle(self):
|
||||
import pickle
|
||||
a = jt.Var([1,2,3])
|
||||
s = pickle.dumps(a)
|
||||
a = jt.Var([1,2,3,4])
|
||||
s = pickle.dumps(a, pickle.HIGHEST_PROTOCOL)
|
||||
b = pickle.loads(s)
|
||||
assert isinstance(b, jt.Var)
|
||||
assert (b.data == [1,2,3]).all()
|
||||
assert (b.data == [1,2,3,4]).all()
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -87,6 +87,18 @@ class TestRandomOp(unittest.TestCase):
|
|||
def test_normal_cuda(self):
|
||||
self.test_normal()
|
||||
|
||||
def test_other_rand(self):
|
||||
a = jt.array([1.0,2.0,3.0])
|
||||
b = jt.rand_like(a)
|
||||
c = jt.randn_like(a)
|
||||
assert b.shape == c.shape
|
||||
assert b.shape == a.shape
|
||||
print(b, c)
|
||||
assert jt.randint(10, 20, (2000,)).min() == 10
|
||||
assert jt.randint(10, 20, (2000,)).max() == 19
|
||||
assert jt.randint(10, shape=(2000,)).max() == 9
|
||||
assert jt.randint_like(a, 10).shape == a.shape
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -15,8 +15,15 @@ from jittor.dataset.mnist import MNIST
|
|||
import jittor.transform as trans
|
||||
from tqdm import tqdm
|
||||
|
||||
class BBox:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def __eq__(self, other):
|
||||
return bool((self.x == other.x).all())
|
||||
|
||||
def test_ring_buffer():
|
||||
buffer = jt.RingBuffer(1000)
|
||||
buffer = jt.RingBuffer(2000)
|
||||
def test_send_recv(data):
|
||||
print("test send recv", type(data))
|
||||
buffer.push(data)
|
||||
|
@ -65,6 +72,9 @@ def test_ring_buffer():
|
|||
|
||||
test_send_recv(jt.array(np.random.rand(10,10)))
|
||||
|
||||
bbox = BBox(jt.array(np.random.rand(10,10)))
|
||||
test_send_recv(bbox)
|
||||
|
||||
expect_error(lambda: test_send_recv(np.random.rand(10,1000)))
|
||||
|
||||
|
||||
|
|
|
@ -121,8 +121,8 @@ struct RingBuffer {
|
|||
rr = c2 << size_bit;
|
||||
rr_next = rr + size;
|
||||
}
|
||||
CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size.";
|
||||
while (rr_next > l + this->size) {
|
||||
CHECKop(size,<=,this->size) << "Buffer size too small, please increase buffer size.";
|
||||
wait();
|
||||
}
|
||||
offset = rr_next;
|
||||
|
|
|
@ -15,8 +15,9 @@ namespace jittor {
|
|||
static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restrict__ offset) {
|
||||
PyObjHolder pickle(PyImport_ImportModule("pickle"));
|
||||
PyObjHolder dumps(PyObject_GetAttrString(pickle.obj, "dumps"));
|
||||
PyObjHolder proto(PyObject_GetAttrString(pickle.obj, "HIGHEST_PROTOCOL"));
|
||||
rb->push_t<uint8>(6, offset);
|
||||
PyObjHolder ret(PyObject_CallFunctionObjArgs(dumps.obj, obj, nullptr));
|
||||
PyObjHolder ret(PyObject_CallFunctionObjArgs(dumps.obj, obj, proto.obj, nullptr));
|
||||
obj = ret.obj;
|
||||
Py_ssize_t size;
|
||||
char* s;
|
||||
|
|
Loading…
Reference in New Issue