jt.array type check

This commit is contained in:
Dun Liang 2021-01-30 16:29:55 +08:00
parent b6d262e6b6
commit 4883d75e1d
5 changed files with 14 additions and 4 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.22'
__version__ = '1.2.2.23'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -8,6 +8,7 @@ import unittest
import jittor as jt
import numpy as np
from jittor import compile_extern
from jittor.test.test_core import expect_error
class TestArray(unittest.TestCase):
def test_data(self):
@ -153,6 +154,11 @@ class TestArray(unittest.TestCase):
assert isinstance(b, jt.Var)
assert (b.data == [1,2,3,4]).all()
def test_tuple_array(self):
a = jt.array((4,5))
expect_error(lambda : jt.array({}))
expect_error(lambda : jt.array("asdasd"))
expect_error(lambda : jt.array(jt))
if __name__ == "__main__":

View File

@ -121,7 +121,8 @@ struct RingBuffer {
rr = c2 << size_bit;
rr_next = rr + size;
}
CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size.";
CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size. Current size:"
<< this->size << "Required size:" << rr_next - r;
while (rr_next > l + this->size) {
wait();
}

View File

@ -176,7 +176,7 @@ static void getitem_inplace(GetitemOp* op) {
auto ou = op->outputs().front();
// return if input or output's shape is variable
if (in->num < 0 || ou->num < 0)
if (in->num <= 0 || ou->num <= 0)
return;
VarSlices vs = op->vs;

View File

@ -46,6 +46,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
ArrayArgs args;
PyObjHolder holder;
args.ptr = nullptr;
allocation.ptr = nullptr;
if (PyFloat_CheckExact(obj)) {
tmp_data.f32 = PyFloat_AS_DOUBLE(obj);
args = {&tmp_data, 1, ns_float32};
@ -63,7 +64,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
args = move(fetch_sync({ptr}).at(0));
} else
if (Py_TYPE(obj) == PyArray_Type ||
PyList_CheckExact(obj) ||
PyList_CheckExact(obj) || PyTuple_CheckExact(obj) ||
PyObject_TypeCheck(obj, PyNumberArrType_Type)
) {
if (Py_TYPE(obj) != PyArray_Type) {
@ -97,6 +98,8 @@ ArrayOp::ArrayOp(PyObject* obj) {
}
}
} else {
LOGf << "type <" >> Py_TYPE(obj)->tp_name >> "> not support for jittor array";
}
NanoVector shape = args.shape;
output = create_output(shape, args.dtype);