mirror of https://github.com/Jittor/Jittor
fix bool setitem and reshape NanoVector
This commit is contained in:
parent
40cdd27a01
commit
431ab5c70f
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.6.5'
|
||||
__version__ = '1.1.6.6'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
@ -196,6 +196,7 @@ def clean():
|
|||
gc.collect()
|
||||
|
||||
cast = unary
|
||||
Var.cast = Var.cast
|
||||
|
||||
def array(data, dtype=None):
|
||||
if isinstance(data, core.Var):
|
||||
|
@ -250,7 +251,7 @@ Var.norm = norm
|
|||
|
||||
origin_reshape = reshape
|
||||
def reshape(x, *shape):
|
||||
if len(shape) == 1 and isinstance(shape[0], Sequence):
|
||||
if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)):
|
||||
shape = shape[0]
|
||||
return origin_reshape(x, shape)
|
||||
reshape.__doc__ = origin_reshape.__doc__
|
||||
|
@ -258,7 +259,7 @@ Var.view = Var.reshape = view = reshape
|
|||
|
||||
origin_transpose = transpose
|
||||
def transpose(x, *dim):
|
||||
if len(dim) == 1 and isinstance(dim[0], Sequence):
|
||||
if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
|
||||
dim = dim[0]
|
||||
return origin_transpose(x, dim)
|
||||
transpose.__doc__ = origin_transpose.__doc__
|
||||
|
|
|
@ -147,6 +147,7 @@ def setitem(x, slices, value):
|
|||
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
|
||||
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
|
||||
value = jt.broadcast(value, xslice)
|
||||
value = value.cast(x.dtype)
|
||||
one = jt.broadcast(1, xslice)
|
||||
if not isinstance(reindex_args[0][0], jt.Var):
|
||||
reindex_args = (x.shape,) + reindex_args[1:]
|
||||
|
|
|
@ -33,8 +33,8 @@ class TestBinaryOp(unittest.TestCase):
|
|||
assert (x == 8).all()
|
||||
x = (jt.array(2) ** jt.array(3)).data
|
||||
assert (x == 8).all()
|
||||
a = [1,2,3]
|
||||
b = [7,10,13]
|
||||
a = np.array([1,2,3])
|
||||
b = np.array([7,10,13])
|
||||
check("logical_and", a, b)
|
||||
check("logical_or", a, b)
|
||||
check("logical_xor", a, b)
|
||||
|
@ -79,6 +79,8 @@ class TestBinaryOp(unittest.TestCase):
|
|||
|
||||
def test_r(self):
|
||||
def check(op, a, b):
|
||||
a = np.array(a)
|
||||
b = np.array(b)
|
||||
if jt.flags.use_cuda and op == "@":
|
||||
return
|
||||
jb = jt.array(b)
|
||||
|
|
|
@ -15,9 +15,9 @@ class TestConcatOp(unittest.TestCase):
|
|||
res2 = jt.contrib.concat(tmp, dim=dim)
|
||||
assert (res1!=res2).data.sum()==0, "concat fail..."
|
||||
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
|
||||
check([jt.array(range(24)).reshape((1,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
|
||||
check([jt.array(range(120)).reshape((5,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
|
||||
check([jt.array(range(5)).reshape((5,1)), jt.array(range(1)).reshape((1,1))])
|
||||
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
|
||||
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
|
||||
check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
|
||||
print('concat success...')
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -50,7 +50,7 @@ class TestNanoString(unittest.TestCase):
|
|||
assert str(jt.NanoString(jt.float32)) == "float32"
|
||||
assert str(jt.NanoString(jt.float64)) == "float64"
|
||||
assert str(jt.NanoString(jt.int8)) == "int8"
|
||||
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
|
||||
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32"
|
||||
assert str(jt.NanoString(jt.sum)) == "add"
|
||||
|
||||
def get_error_str(call):
|
||||
|
|
|
@ -53,9 +53,9 @@ class TestNode(unittest.TestCase):
|
|||
da, db = jt.grad(c, [a, b])
|
||||
da.name('da')
|
||||
db.name('db')
|
||||
check(5,7,5) # dc, 3, da, 1, db, 1
|
||||
check(5,6,4) # dc, 3, da, 1, db, 1
|
||||
del a, b, c
|
||||
check(2,6,4)
|
||||
check(2,5,3)
|
||||
da.sync(), db.sync()
|
||||
check(2,2,0)
|
||||
del da, db
|
||||
|
|
|
@ -70,6 +70,10 @@ class TestReshapeOp(unittest.TestCase):
|
|||
assert a.flatten(1).shape == [2,12]
|
||||
assert a.flatten(0,-2).shape == [6,4]
|
||||
|
||||
def test_reshape_var(self):
|
||||
a = jt.zeros(10)
|
||||
b = a.reshape(a.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,24 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor.
|
||||
# Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestSlice(unittest.TestCase):
|
||||
def test_slice_bool(self):
|
||||
a = jt.zeros(10, "bool")
|
||||
a[1] = True
|
||||
a[2] = 1
|
||||
assert a.dtype == "bool"
|
||||
print(a)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -25,12 +25,12 @@ class TestUnaryOp(unittest.TestCase):
|
|||
assert jt.float64(1).data.dtype == "float64"
|
||||
assert (jt.abs(-1) == 1).data.all()
|
||||
assert (abs(-jt.float64(1)) == 1).data.all()
|
||||
a = [-1,2,3,0]
|
||||
a = np.array([-1,2,3,0])
|
||||
check("abs", a)
|
||||
check("negative", a)
|
||||
check("logical_not", a)
|
||||
check("bitwise_not", a)
|
||||
b = [1.1, 2.2, 3.3, 4.4, -1, 0]
|
||||
b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
|
||||
check("log", a)
|
||||
check("exp", a)
|
||||
check("sqrt", a)
|
||||
|
@ -42,7 +42,7 @@ class TestUnaryOp(unittest.TestCase):
|
|||
"cos", "arccos", "cosh", "arccosh",
|
||||
"sigmoid",
|
||||
]
|
||||
a = [1.1, 2.2, 3.3, 4.4]
|
||||
a = np.array([1.1, 2.2, 3.3, 4.4])
|
||||
for op in ops:
|
||||
if op == "abs":
|
||||
b = np.array(a+[-1,])
|
||||
|
|
|
@ -156,7 +156,7 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
|
|||
if (dsize==8) return ns_int64;
|
||||
if (dsize==4) return ns_int32;
|
||||
if (dsize==2) return ns_int16;
|
||||
return ns_int8;
|
||||
return v1;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -75,6 +75,10 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
|
|||
ASSERT(ns.is_unary() | ns.is_dtype());
|
||||
NanoString dtype;
|
||||
if (ns.is_dtype()) {
|
||||
if (ns == x->dtype()) {
|
||||
forward(x);
|
||||
return;
|
||||
}
|
||||
dtype = ns;
|
||||
ns = ns_cast;
|
||||
} else if (ns.is_bool())
|
||||
|
|
|
@ -128,6 +128,7 @@ inline int64 PyArray_Size(PyArray_Proxy* arr) {
|
|||
union tmp_data_t {
|
||||
int32 i32;
|
||||
float32 f32;
|
||||
int8 i8;
|
||||
};
|
||||
|
||||
extern tmp_data_t tmp_data;
|
||||
|
|
|
@ -261,6 +261,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
|||
return Py_TYPE(obj) == PyArray_Type ||
|
||||
PyFloat_CheckExact(obj) ||
|
||||
PyLong_CheckExact(obj) ||
|
||||
PyBool_Check(obj) ||
|
||||
PyList_CheckExact(obj) ||
|
||||
Py_TYPE(obj) == &PyjtVarHolder.ht_type;
|
||||
}
|
||||
|
@ -289,6 +290,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
|||
tmp_data.i32 = PyLong_AsLong(obj);
|
||||
return {&tmp_data, 1, ns_int32};
|
||||
}
|
||||
if (PyBool_Check(obj)) {
|
||||
tmp_data.i8 = obj == Py_True;
|
||||
return {&tmp_data, 1, ns_bool};
|
||||
}
|
||||
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) {
|
||||
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
||||
return move(fetch_sync({ptr}).at(0));
|
||||
|
|
Loading…
Reference in New Issue