mirror of https://github.com/Jittor/Jittor
fix bool setitem and reshape NanoVector
This commit is contained in:
parent
40cdd27a01
commit
431ab5c70f
|
@ -7,7 +7,7 @@
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
__version__ = '1.1.6.5'
|
__version__ = '1.1.6.6'
|
||||||
from . import lock
|
from . import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
from . import compiler
|
from . import compiler
|
||||||
|
@ -196,6 +196,7 @@ def clean():
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
cast = unary
|
cast = unary
|
||||||
|
Var.cast = Var.cast
|
||||||
|
|
||||||
def array(data, dtype=None):
|
def array(data, dtype=None):
|
||||||
if isinstance(data, core.Var):
|
if isinstance(data, core.Var):
|
||||||
|
@ -250,7 +251,7 @@ Var.norm = norm
|
||||||
|
|
||||||
origin_reshape = reshape
|
origin_reshape = reshape
|
||||||
def reshape(x, *shape):
|
def reshape(x, *shape):
|
||||||
if len(shape) == 1 and isinstance(shape[0], Sequence):
|
if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)):
|
||||||
shape = shape[0]
|
shape = shape[0]
|
||||||
return origin_reshape(x, shape)
|
return origin_reshape(x, shape)
|
||||||
reshape.__doc__ = origin_reshape.__doc__
|
reshape.__doc__ = origin_reshape.__doc__
|
||||||
|
@ -258,7 +259,7 @@ Var.view = Var.reshape = view = reshape
|
||||||
|
|
||||||
origin_transpose = transpose
|
origin_transpose = transpose
|
||||||
def transpose(x, *dim):
|
def transpose(x, *dim):
|
||||||
if len(dim) == 1 and isinstance(dim[0], Sequence):
|
if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
|
||||||
dim = dim[0]
|
dim = dim[0]
|
||||||
return origin_transpose(x, dim)
|
return origin_transpose(x, dim)
|
||||||
transpose.__doc__ = origin_transpose.__doc__
|
transpose.__doc__ = origin_transpose.__doc__
|
||||||
|
|
|
@ -147,6 +147,7 @@ def setitem(x, slices, value):
|
||||||
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
|
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
|
||||||
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
|
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
|
||||||
value = jt.broadcast(value, xslice)
|
value = jt.broadcast(value, xslice)
|
||||||
|
value = value.cast(x.dtype)
|
||||||
one = jt.broadcast(1, xslice)
|
one = jt.broadcast(1, xslice)
|
||||||
if not isinstance(reindex_args[0][0], jt.Var):
|
if not isinstance(reindex_args[0][0], jt.Var):
|
||||||
reindex_args = (x.shape,) + reindex_args[1:]
|
reindex_args = (x.shape,) + reindex_args[1:]
|
||||||
|
|
|
@ -33,8 +33,8 @@ class TestBinaryOp(unittest.TestCase):
|
||||||
assert (x == 8).all()
|
assert (x == 8).all()
|
||||||
x = (jt.array(2) ** jt.array(3)).data
|
x = (jt.array(2) ** jt.array(3)).data
|
||||||
assert (x == 8).all()
|
assert (x == 8).all()
|
||||||
a = [1,2,3]
|
a = np.array([1,2,3])
|
||||||
b = [7,10,13]
|
b = np.array([7,10,13])
|
||||||
check("logical_and", a, b)
|
check("logical_and", a, b)
|
||||||
check("logical_or", a, b)
|
check("logical_or", a, b)
|
||||||
check("logical_xor", a, b)
|
check("logical_xor", a, b)
|
||||||
|
@ -79,6 +79,8 @@ class TestBinaryOp(unittest.TestCase):
|
||||||
|
|
||||||
def test_r(self):
|
def test_r(self):
|
||||||
def check(op, a, b):
|
def check(op, a, b):
|
||||||
|
a = np.array(a)
|
||||||
|
b = np.array(b)
|
||||||
if jt.flags.use_cuda and op == "@":
|
if jt.flags.use_cuda and op == "@":
|
||||||
return
|
return
|
||||||
jb = jt.array(b)
|
jb = jt.array(b)
|
||||||
|
|
|
@ -15,9 +15,9 @@ class TestConcatOp(unittest.TestCase):
|
||||||
res2 = jt.contrib.concat(tmp, dim=dim)
|
res2 = jt.contrib.concat(tmp, dim=dim)
|
||||||
assert (res1!=res2).data.sum()==0, "concat fail..."
|
assert (res1!=res2).data.sum()==0, "concat fail..."
|
||||||
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
|
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
|
||||||
check([jt.array(range(24)).reshape((1,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
|
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
|
||||||
check([jt.array(range(120)).reshape((5,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
|
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
|
||||||
check([jt.array(range(5)).reshape((5,1)), jt.array(range(1)).reshape((1,1))])
|
check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
|
||||||
print('concat success...')
|
print('concat success...')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -50,7 +50,7 @@ class TestNanoString(unittest.TestCase):
|
||||||
assert str(jt.NanoString(jt.float32)) == "float32"
|
assert str(jt.NanoString(jt.float32)) == "float32"
|
||||||
assert str(jt.NanoString(jt.float64)) == "float64"
|
assert str(jt.NanoString(jt.float64)) == "float64"
|
||||||
assert str(jt.NanoString(jt.int8)) == "int8"
|
assert str(jt.NanoString(jt.int8)) == "int8"
|
||||||
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
|
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32"
|
||||||
assert str(jt.NanoString(jt.sum)) == "add"
|
assert str(jt.NanoString(jt.sum)) == "add"
|
||||||
|
|
||||||
def get_error_str(call):
|
def get_error_str(call):
|
||||||
|
|
|
@ -53,9 +53,9 @@ class TestNode(unittest.TestCase):
|
||||||
da, db = jt.grad(c, [a, b])
|
da, db = jt.grad(c, [a, b])
|
||||||
da.name('da')
|
da.name('da')
|
||||||
db.name('db')
|
db.name('db')
|
||||||
check(5,7,5) # dc, 3, da, 1, db, 1
|
check(5,6,4) # dc, 3, da, 1, db, 1
|
||||||
del a, b, c
|
del a, b, c
|
||||||
check(2,6,4)
|
check(2,5,3)
|
||||||
da.sync(), db.sync()
|
da.sync(), db.sync()
|
||||||
check(2,2,0)
|
check(2,2,0)
|
||||||
del da, db
|
del da, db
|
||||||
|
|
|
@ -70,6 +70,10 @@ class TestReshapeOp(unittest.TestCase):
|
||||||
assert a.flatten(1).shape == [2,12]
|
assert a.flatten(1).shape == [2,12]
|
||||||
assert a.flatten(0,-2).shape == [6,4]
|
assert a.flatten(0,-2).shape == [6,4]
|
||||||
|
|
||||||
|
def test_reshape_var(self):
|
||||||
|
a = jt.zeros(10)
|
||||||
|
b = a.reshape(a.shape)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
|
@ -0,0 +1,24 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2020 Jittor.
|
||||||
|
# Authors:
|
||||||
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
|
# All Rights Reserved.
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
|
import unittest
|
||||||
|
import jittor as jt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class TestSlice(unittest.TestCase):
|
||||||
|
def test_slice_bool(self):
|
||||||
|
a = jt.zeros(10, "bool")
|
||||||
|
a[1] = True
|
||||||
|
a[2] = 1
|
||||||
|
assert a.dtype == "bool"
|
||||||
|
print(a)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -25,12 +25,12 @@ class TestUnaryOp(unittest.TestCase):
|
||||||
assert jt.float64(1).data.dtype == "float64"
|
assert jt.float64(1).data.dtype == "float64"
|
||||||
assert (jt.abs(-1) == 1).data.all()
|
assert (jt.abs(-1) == 1).data.all()
|
||||||
assert (abs(-jt.float64(1)) == 1).data.all()
|
assert (abs(-jt.float64(1)) == 1).data.all()
|
||||||
a = [-1,2,3,0]
|
a = np.array([-1,2,3,0])
|
||||||
check("abs", a)
|
check("abs", a)
|
||||||
check("negative", a)
|
check("negative", a)
|
||||||
check("logical_not", a)
|
check("logical_not", a)
|
||||||
check("bitwise_not", a)
|
check("bitwise_not", a)
|
||||||
b = [1.1, 2.2, 3.3, 4.4, -1, 0]
|
b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
|
||||||
check("log", a)
|
check("log", a)
|
||||||
check("exp", a)
|
check("exp", a)
|
||||||
check("sqrt", a)
|
check("sqrt", a)
|
||||||
|
@ -42,7 +42,7 @@ class TestUnaryOp(unittest.TestCase):
|
||||||
"cos", "arccos", "cosh", "arccosh",
|
"cos", "arccos", "cosh", "arccosh",
|
||||||
"sigmoid",
|
"sigmoid",
|
||||||
]
|
]
|
||||||
a = [1.1, 2.2, 3.3, 4.4]
|
a = np.array([1.1, 2.2, 3.3, 4.4])
|
||||||
for op in ops:
|
for op in ops:
|
||||||
if op == "abs":
|
if op == "abs":
|
||||||
b = np.array(a+[-1,])
|
b = np.array(a+[-1,])
|
||||||
|
|
|
@ -156,7 +156,7 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
|
||||||
if (dsize==8) return ns_int64;
|
if (dsize==8) return ns_int64;
|
||||||
if (dsize==4) return ns_int32;
|
if (dsize==4) return ns_int32;
|
||||||
if (dsize==2) return ns_int16;
|
if (dsize==2) return ns_int16;
|
||||||
return ns_int8;
|
return v1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,10 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
|
||||||
ASSERT(ns.is_unary() | ns.is_dtype());
|
ASSERT(ns.is_unary() | ns.is_dtype());
|
||||||
NanoString dtype;
|
NanoString dtype;
|
||||||
if (ns.is_dtype()) {
|
if (ns.is_dtype()) {
|
||||||
|
if (ns == x->dtype()) {
|
||||||
|
forward(x);
|
||||||
|
return;
|
||||||
|
}
|
||||||
dtype = ns;
|
dtype = ns;
|
||||||
ns = ns_cast;
|
ns = ns_cast;
|
||||||
} else if (ns.is_bool())
|
} else if (ns.is_bool())
|
||||||
|
|
|
@ -128,6 +128,7 @@ inline int64 PyArray_Size(PyArray_Proxy* arr) {
|
||||||
union tmp_data_t {
|
union tmp_data_t {
|
||||||
int32 i32;
|
int32 i32;
|
||||||
float32 f32;
|
float32 f32;
|
||||||
|
int8 i8;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern tmp_data_t tmp_data;
|
extern tmp_data_t tmp_data;
|
||||||
|
|
|
@ -261,6 +261,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == PyArray_Type ||
|
return Py_TYPE(obj) == PyArray_Type ||
|
||||||
PyFloat_CheckExact(obj) ||
|
PyFloat_CheckExact(obj) ||
|
||||||
PyLong_CheckExact(obj) ||
|
PyLong_CheckExact(obj) ||
|
||||||
|
PyBool_Check(obj) ||
|
||||||
PyList_CheckExact(obj) ||
|
PyList_CheckExact(obj) ||
|
||||||
Py_TYPE(obj) == &PyjtVarHolder.ht_type;
|
Py_TYPE(obj) == &PyjtVarHolder.ht_type;
|
||||||
}
|
}
|
||||||
|
@ -289,6 +290,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
||||||
tmp_data.i32 = PyLong_AsLong(obj);
|
tmp_data.i32 = PyLong_AsLong(obj);
|
||||||
return {&tmp_data, 1, ns_int32};
|
return {&tmp_data, 1, ns_int32};
|
||||||
}
|
}
|
||||||
|
if (PyBool_Check(obj)) {
|
||||||
|
tmp_data.i8 = obj == Py_True;
|
||||||
|
return {&tmp_data, 1, ns_bool};
|
||||||
|
}
|
||||||
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) {
|
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) {
|
||||||
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
||||||
return move(fetch_sync({ptr}).at(0));
|
return move(fetch_sync({ptr}).at(0));
|
||||||
|
|
Loading…
Reference in New Issue