mirror of https://github.com/Jittor/Jittor
jt.array type check
This commit is contained in:
parent
b6d262e6b6
commit
4883d75e1d
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.22'
|
||||
__version__ = '1.2.2.23'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -8,6 +8,7 @@ import unittest
|
|||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import compile_extern
|
||||
from jittor.test.test_core import expect_error
|
||||
|
||||
class TestArray(unittest.TestCase):
|
||||
def test_data(self):
|
||||
|
@ -153,6 +154,11 @@ class TestArray(unittest.TestCase):
|
|||
assert isinstance(b, jt.Var)
|
||||
assert (b.data == [1,2,3,4]).all()
|
||||
|
||||
def test_tuple_array(self):
|
||||
a = jt.array((4,5))
|
||||
expect_error(lambda : jt.array({}))
|
||||
expect_error(lambda : jt.array("asdasd"))
|
||||
expect_error(lambda : jt.array(jt))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -121,7 +121,8 @@ struct RingBuffer {
|
|||
rr = c2 << size_bit;
|
||||
rr_next = rr + size;
|
||||
}
|
||||
CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size.";
|
||||
CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size. Current size:"
|
||||
<< this->size << "Required size:" << rr_next - r;
|
||||
while (rr_next > l + this->size) {
|
||||
wait();
|
||||
}
|
||||
|
|
|
@ -176,7 +176,7 @@ static void getitem_inplace(GetitemOp* op) {
|
|||
auto ou = op->outputs().front();
|
||||
|
||||
// return if input or output's shape is variable
|
||||
if (in->num < 0 || ou->num < 0)
|
||||
if (in->num <= 0 || ou->num <= 0)
|
||||
return;
|
||||
|
||||
VarSlices vs = op->vs;
|
||||
|
|
|
@ -46,6 +46,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
|||
ArrayArgs args;
|
||||
PyObjHolder holder;
|
||||
args.ptr = nullptr;
|
||||
allocation.ptr = nullptr;
|
||||
if (PyFloat_CheckExact(obj)) {
|
||||
tmp_data.f32 = PyFloat_AS_DOUBLE(obj);
|
||||
args = {&tmp_data, 1, ns_float32};
|
||||
|
@ -63,7 +64,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
|||
args = move(fetch_sync({ptr}).at(0));
|
||||
} else
|
||||
if (Py_TYPE(obj) == PyArray_Type ||
|
||||
PyList_CheckExact(obj) ||
|
||||
PyList_CheckExact(obj) || PyTuple_CheckExact(obj) ||
|
||||
PyObject_TypeCheck(obj, PyNumberArrType_Type)
|
||||
) {
|
||||
if (Py_TYPE(obj) != PyArray_Type) {
|
||||
|
@ -97,6 +98,8 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
|||
}
|
||||
|
||||
}
|
||||
} else {
|
||||
LOGf << "type <" >> Py_TYPE(obj)->tp_name >> "> not support for jittor array";
|
||||
}
|
||||
NanoVector shape = args.shape;
|
||||
output = create_output(shape, args.dtype);
|
||||
|
|
Loading…
Reference in New Issue