add bool setitem alias and vector_to_tuple

This commit is contained in:
Dun Liang 2022-01-10 12:36:28 +08:00
parent 0c35e1a29b
commit f36693c797
6 changed files with 42 additions and 4 deletions

View File

@ -1467,7 +1467,8 @@ from . import nn
from . import attention
from . import lr_scheduler
from . import linalg
from .nn import matmul
from .nn import matmul, \
bmm, bmm_transpose
from . import contrib
from . import numpy2cupy
from .contrib import concat

View File

@ -358,7 +358,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
jit_cc_src.append(f"""
/*{doc_string}*/
// @pyjt({",".join(pyjt_names)})
vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
vector_to_tuple<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
{ f'return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));'
if "replace_outputs" not in attrs else
f'''auto rt = make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));

View File

@ -180,7 +180,7 @@ def _setitem_old(x, slices, value):
# PATCH
def getitem(x, slices):
if isinstance(slices, jt.Var) and slices.dtype == "bool":
return getitem(x, tuple(slices.where()))
return getitem(x, slices.where())
if isinstance(slices, tuple):
ss = []
for s in slices:
@ -193,7 +193,14 @@ def getitem(x, slices):
def setitem(x, slices, value):
if isinstance(slices, jt.Var) and slices.dtype == "bool":
slices = tuple(slices.where())
if slices.shape == x.shape:
if isinstance(value, (int, float)):
value = jt.array(value).broadcast(x.shape)
return x.assign(slices.ternary(value, x))
elif isinstance(value, jt.Var) and value.shape == [1,]:
value = jt.broadcast(value, x.shape)
return x.assign(slices.ternary(value, x))
slices = slices.where()
elif isinstance(slices, tuple):
ss = []
for s in slices:

View File

@ -21,6 +21,13 @@
namespace jittor {
template<class T>
struct vector_to_tuple {
typedef T value_type;
vector<T> x;
vector_to_tuple(vector<T>&& _) :x(move(_)) {}
};
#define DEF_IS(check_type, return_type) \
template<class T> \
typename std::enable_if<std::is_same<T, check_type>::value, return_type>::type
@ -462,6 +469,7 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj);
typename std::enable_if<is_##check_type<T>::value, return_type>::type
CHECK_IS_1(vector);
CHECK_IS_1(vector_to_tuple);
CHECK_IS_2(map);
DEF_IS_2(map, bool) is_type(PyObject* obj);
@ -499,6 +507,17 @@ DEF_IS_1(vector, PyObject*) to_py_tuple(const T& a) {
return list.release();
}
DEF_IS_1(vector_to_tuple, PyObject*) to_py_object(const T& a) {
PyObjHolder list(PyTuple_New(a.x.size()));
for (uint i=0; i<a.x.size(); i++) {
PyObject* o = to_py_object<typename T::value_type>(a.x[i]);
CHECK(o);
// PyTuple_SET_ITEM borrow ownership, we do not hold this
PyTuple_SET_ITEM(list.obj, i, o);
}
return list.release();
}
DEF_IS_1(vector, PyObject*) to_py_object(T&& a) {
PyObjHolder list(PyList_New(a.size()));
for (uint i=0; i<a.size(); i++) {

View File

@ -206,6 +206,16 @@ class TestSetitem(unittest.TestCase):
b = jt.array([True,False,True,False])
a[b] = jt.array([-1,-2])
assert (a.data == [-1,2,-2,4]).all()
def test_setitem_bool2(self):
a = jt.array([1,2,3,4])
b = jt.array([True,False,True,False])
a[b] = jt.array([-1])
assert (a.data == [-1,2,-1,4]).all(), a
a = jt.array([1,2,3,4])
b = jt.array([True,False,True,False])
a[b] = -1
assert (a.data == [-1,2,-1,4]).all(), a
def test_slice_none(self):
a = jt.array([1,2])

View File

@ -49,6 +49,7 @@ class TestWhereOp(unittest.TestCase):
def test_reduce_dep(self):
a = jt.random([100,100])
index = self.where(a>0.5)
assert isinstance(index, tuple)
x = a.reindex_var(index)
xsum =x.sum()
na = a.data