fix bool setitem and reshape NanoVector

This commit is contained in:
Dun Liang 2020-07-27 16:02:01 +08:00
parent 40cdd27a01
commit 431ab5c70f
13 changed files with 57 additions and 15 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.6.5'
__version__ = '1.1.6.6'
from . import lock
with lock.lock_scope():
from . import compiler
@ -196,6 +196,7 @@ def clean():
gc.collect()
cast = unary
Var.cast = Var.cast
def array(data, dtype=None):
if isinstance(data, core.Var):
@ -250,7 +251,7 @@ Var.norm = norm
origin_reshape = reshape
def reshape(x, *shape):
if len(shape) == 1 and isinstance(shape[0], Sequence):
if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
return origin_reshape(x, shape)
reshape.__doc__ = origin_reshape.__doc__
@ -258,7 +259,7 @@ Var.view = Var.reshape = view = reshape
origin_transpose = transpose
def transpose(x, *dim):
if len(dim) == 1 and isinstance(dim[0], Sequence):
if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
dim = dim[0]
return origin_transpose(x, dim)
transpose.__doc__ = origin_transpose.__doc__

View File

@ -147,6 +147,7 @@ def setitem(x, slices, value):
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
value = jt.broadcast(value, xslice)
value = value.cast(x.dtype)
one = jt.broadcast(1, xslice)
if not isinstance(reindex_args[0][0], jt.Var):
reindex_args = (x.shape,) + reindex_args[1:]

View File

@ -33,8 +33,8 @@ class TestBinaryOp(unittest.TestCase):
assert (x == 8).all()
x = (jt.array(2) ** jt.array(3)).data
assert (x == 8).all()
a = [1,2,3]
b = [7,10,13]
a = np.array([1,2,3])
b = np.array([7,10,13])
check("logical_and", a, b)
check("logical_or", a, b)
check("logical_xor", a, b)
@ -79,6 +79,8 @@ class TestBinaryOp(unittest.TestCase):
def test_r(self):
def check(op, a, b):
a = np.array(a)
b = np.array(b)
if jt.flags.use_cuda and op == "@":
return
jb = jt.array(b)

View File

@ -15,9 +15,9 @@ class TestConcatOp(unittest.TestCase):
res2 = jt.contrib.concat(tmp, dim=dim)
assert (res1!=res2).data.sum()==0, "concat fail..."
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
check([jt.array(range(24)).reshape((1,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
check([jt.array(range(120)).reshape((5,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
check([jt.array(range(5)).reshape((5,1)), jt.array(range(1)).reshape((1,1))])
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
print('concat success...')
if __name__ == "__main__":

View File

@ -50,7 +50,7 @@ class TestNanoString(unittest.TestCase):
assert str(jt.NanoString(jt.float32)) == "float32"
assert str(jt.NanoString(jt.float64)) == "float64"
assert str(jt.NanoString(jt.int8)) == "int8"
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32"
assert str(jt.NanoString(jt.sum)) == "add"
def get_error_str(call):

View File

@ -53,9 +53,9 @@ class TestNode(unittest.TestCase):
da, db = jt.grad(c, [a, b])
da.name('da')
db.name('db')
check(5,7,5) # dc, 3, da, 1, db, 1
check(5,6,4) # dc, 3, da, 1, db, 1
del a, b, c
check(2,6,4)
check(2,5,3)
da.sync(), db.sync()
check(2,2,0)
del da, db

View File

@ -70,6 +70,10 @@ class TestReshapeOp(unittest.TestCase):
assert a.flatten(1).shape == [2,12]
assert a.flatten(0,-2).shape == [6,4]
def test_reshape_var(self):
a = jt.zeros(10)
b = a.reshape(a.shape)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,24 @@
# ***************************************************************
# Copyright (c) 2020 Jittor.
# Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestSlice(unittest.TestCase):
def test_slice_bool(self):
a = jt.zeros(10, "bool")
a[1] = True
a[2] = 1
assert a.dtype == "bool"
print(a)
if __name__ == "__main__":
unittest.main()

View File

@ -25,12 +25,12 @@ class TestUnaryOp(unittest.TestCase):
assert jt.float64(1).data.dtype == "float64"
assert (jt.abs(-1) == 1).data.all()
assert (abs(-jt.float64(1)) == 1).data.all()
a = [-1,2,3,0]
a = np.array([-1,2,3,0])
check("abs", a)
check("negative", a)
check("logical_not", a)
check("bitwise_not", a)
b = [1.1, 2.2, 3.3, 4.4, -1, 0]
b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
check("log", a)
check("exp", a)
check("sqrt", a)
@ -42,7 +42,7 @@ class TestUnaryOp(unittest.TestCase):
"cos", "arccos", "cosh", "arccosh",
"sigmoid",
]
a = [1.1, 2.2, 3.3, 4.4]
a = np.array([1.1, 2.2, 3.3, 4.4])
for op in ops:
if op == "abs":
b = np.array(a+[-1,])

View File

@ -156,7 +156,7 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
if (dsize==8) return ns_int64;
if (dsize==4) return ns_int32;
if (dsize==2) return ns_int16;
return ns_int8;
return v1;
}
}

View File

@ -75,6 +75,10 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
ASSERT(ns.is_unary() | ns.is_dtype());
NanoString dtype;
if (ns.is_dtype()) {
if (ns == x->dtype()) {
forward(x);
return;
}
dtype = ns;
ns = ns_cast;
} else if (ns.is_bool())

View File

@ -128,6 +128,7 @@ inline int64 PyArray_Size(PyArray_Proxy* arr) {
union tmp_data_t {
int32 i32;
float32 f32;
int8 i8;
};
extern tmp_data_t tmp_data;

View File

@ -261,6 +261,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == PyArray_Type ||
PyFloat_CheckExact(obj) ||
PyLong_CheckExact(obj) ||
PyBool_Check(obj) ||
PyList_CheckExact(obj) ||
Py_TYPE(obj) == &PyjtVarHolder.ht_type;
}
@ -289,6 +290,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
tmp_data.i32 = PyLong_AsLong(obj);
return {&tmp_data, 1, ns_int32};
}
if (PyBool_Check(obj)) {
tmp_data.i8 = obj == Py_True;
return {&tmp_data, 1, ns_bool};
}
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) {
auto ptr = GET_RAW_PTR(VarHolder, obj);
return move(fetch_sync({ptr}).at(0));