diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index fe0a3548..fac31a2d 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -7,7 +7,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.1.6.5' +__version__ = '1.1.6.6' from . import lock with lock.lock_scope(): from . import compiler @@ -196,6 +196,7 @@ def clean(): gc.collect() cast = unary +Var.cast = Var.cast def array(data, dtype=None): if isinstance(data, core.Var): @@ -250,7 +251,7 @@ Var.norm = norm origin_reshape = reshape def reshape(x, *shape): - if len(shape) == 1 and isinstance(shape[0], Sequence): + if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)): shape = shape[0] return origin_reshape(x, shape) reshape.__doc__ = origin_reshape.__doc__ @@ -258,7 +259,7 @@ Var.view = Var.reshape = view = reshape origin_transpose = transpose def transpose(x, *dim): - if len(dim) == 1 and isinstance(dim[0], Sequence): + if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)): dim = dim[0] return origin_transpose(x, dim) transpose.__doc__ = origin_transpose.__doc__ diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index bb0c05bd..b759bf70 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -147,6 +147,7 @@ def setitem(x, slices, value): reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:] xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse() value = jt.broadcast(value, xslice) + value = value.cast(x.dtype) one = jt.broadcast(1, xslice) if not isinstance(reindex_args[0][0], jt.Var): reindex_args = (x.shape,) + reindex_args[1:] diff --git a/python/jittor/test/test_binary_op.py b/python/jittor/test/test_binary_op.py index 1a261a8b..f5d97990 100644 --- a/python/jittor/test/test_binary_op.py +++ b/python/jittor/test/test_binary_op.py @@ -33,8 +33,8 @@ class TestBinaryOp(unittest.TestCase): assert (x == 8).all() x = (jt.array(2) ** jt.array(3)).data assert (x == 8).all() - a = [1,2,3] - b = [7,10,13] + a = np.array([1,2,3]) + b = np.array([7,10,13]) check("logical_and", a, b) check("logical_or", a, b) check("logical_xor", a, b) @@ -79,6 +79,8 @@ class TestBinaryOp(unittest.TestCase): def test_r(self): def check(op, a, b): + a = np.array(a) + b = np.array(b) if jt.flags.use_cuda and op == "@": return jb = jt.array(b) diff --git a/python/jittor/test/test_concat_op.py b/python/jittor/test/test_concat_op.py index 61ed3074..04b413e4 100644 --- a/python/jittor/test/test_concat_op.py +++ b/python/jittor/test/test_concat_op.py @@ -15,9 +15,9 @@ class TestConcatOp(unittest.TestCase): res2 = jt.contrib.concat(tmp, dim=dim) assert (res1!=res2).data.sum()==0, "concat fail..." check([jt.array([[1],[2]]), jt.array([[2],[2]])]) - check([jt.array(range(24)).reshape((1,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))]) - check([jt.array(range(120)).reshape((5,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))]) - check([jt.array(range(5)).reshape((5,1)), jt.array(range(1)).reshape((1,1))]) + check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))]) + check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))]) + check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))]) print('concat success...') if __name__ == "__main__": diff --git a/python/jittor/test/test_nano_string.py b/python/jittor/test/test_nano_string.py index 1f26e0a1..5587d7a4 100644 --- a/python/jittor/test/test_nano_string.py +++ b/python/jittor/test/test_nano_string.py @@ -50,7 +50,7 @@ class TestNanoString(unittest.TestCase): assert str(jt.NanoString(jt.float32)) == "float32" assert str(jt.NanoString(jt.float64)) == "float64" assert str(jt.NanoString(jt.int8)) == "int8" - assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64" + assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32" assert str(jt.NanoString(jt.sum)) == "add" def get_error_str(call): diff --git a/python/jittor/test/test_node.py b/python/jittor/test/test_node.py index b84afc52..7ca104d4 100644 --- a/python/jittor/test/test_node.py +++ b/python/jittor/test/test_node.py @@ -53,9 +53,9 @@ class TestNode(unittest.TestCase): da, db = jt.grad(c, [a, b]) da.name('da') db.name('db') - check(5,7,5) # dc, 3, da, 1, db, 1 + check(5,6,4) # dc, 3, da, 1, db, 1 del a, b, c - check(2,6,4) + check(2,5,3) da.sync(), db.sync() check(2,2,0) del da, db diff --git a/python/jittor/test/test_reshape.py b/python/jittor/test/test_reshape.py index 10e23447..d2880d3e 100644 --- a/python/jittor/test/test_reshape.py +++ b/python/jittor/test/test_reshape.py @@ -70,6 +70,10 @@ class TestReshapeOp(unittest.TestCase): assert a.flatten(1).shape == [2,12] assert a.flatten(0,-2).shape == [6,4] + def test_reshape_var(self): + a = jt.zeros(10) + b = a.reshape(a.shape) + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_slice.py b/python/jittor/test/test_slice.py new file mode 100644 index 00000000..a4402475 --- /dev/null +++ b/python/jittor/test/test_slice.py @@ -0,0 +1,24 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. +# Authors: +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + + +class TestSlice(unittest.TestCase): + def test_slice_bool(self): + a = jt.zeros(10, "bool") + a[1] = True + a[2] = 1 + assert a.dtype == "bool" + print(a) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_unary_op.py b/python/jittor/test/test_unary_op.py index c573b909..c25986e2 100644 --- a/python/jittor/test/test_unary_op.py +++ b/python/jittor/test/test_unary_op.py @@ -25,12 +25,12 @@ class TestUnaryOp(unittest.TestCase): assert jt.float64(1).data.dtype == "float64" assert (jt.abs(-1) == 1).data.all() assert (abs(-jt.float64(1)) == 1).data.all() - a = [-1,2,3,0] + a = np.array([-1,2,3,0]) check("abs", a) check("negative", a) check("logical_not", a) check("bitwise_not", a) - b = [1.1, 2.2, 3.3, 4.4, -1, 0] + b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0]) check("log", a) check("exp", a) check("sqrt", a) @@ -42,7 +42,7 @@ class TestUnaryOp(unittest.TestCase): "cos", "arccos", "cosh", "arccosh", "sigmoid", ] - a = [1.1, 2.2, 3.3, 4.4] + a = np.array([1.1, 2.2, 3.3, 4.4]) for op in ops: if op == "abs": b = np.array(a+[-1,]) diff --git a/src/misc/nano_string.h b/src/misc/nano_string.h index 313f69f4..7fa71737 100644 --- a/src/misc/nano_string.h +++ b/src/misc/nano_string.h @@ -156,7 +156,7 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) { if (dsize==8) return ns_int64; if (dsize==4) return ns_int32; if (dsize==2) return ns_int16; - return ns_int8; + return v1; } } diff --git a/src/ops/unary_op.cc b/src/ops/unary_op.cc index 9a652cee..ac6432f2 100644 --- a/src/ops/unary_op.cc +++ b/src/ops/unary_op.cc @@ -75,6 +75,10 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) { ASSERT(ns.is_unary() | ns.is_dtype()); NanoString dtype; if (ns.is_dtype()) { + if (ns == x->dtype()) { + forward(x); + return; + } dtype = ns; ns = ns_cast; } else if (ns.is_bool()) diff --git a/src/pyjt/numpy.h b/src/pyjt/numpy.h index 73d1c14d..a2853471 100644 --- a/src/pyjt/numpy.h +++ b/src/pyjt/numpy.h @@ -128,6 +128,7 @@ inline int64 PyArray_Size(PyArray_Proxy* arr) { union tmp_data_t { int32 i32; float32 f32; + int8 i8; }; extern tmp_data_t tmp_data; diff --git a/src/pyjt/py_converter.h b/src/pyjt/py_converter.h index cf544362..626dace1 100644 --- a/src/pyjt/py_converter.h +++ b/src/pyjt/py_converter.h @@ -261,6 +261,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) { return Py_TYPE(obj) == PyArray_Type || PyFloat_CheckExact(obj) || PyLong_CheckExact(obj) || + PyBool_Check(obj) || PyList_CheckExact(obj) || Py_TYPE(obj) == &PyjtVarHolder.ht_type; } @@ -289,6 +290,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) { tmp_data.i32 = PyLong_AsLong(obj); return {&tmp_data, 1, ns_int32}; } + if (PyBool_Check(obj)) { + tmp_data.i8 = obj == Py_True; + return {&tmp_data, 1, ns_bool}; + } if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) { auto ptr = GET_RAW_PTR(VarHolder, obj); return move(fetch_sync({ptr}).at(0));