mirror of https://github.com/Jittor/Jittor
add bool setitem alias and vector_to_tuple
This commit is contained in:
parent
0c35e1a29b
commit
f36693c797
|
@ -1467,7 +1467,8 @@ from . import nn
|
|||
from . import attention
|
||||
from . import lr_scheduler
|
||||
from . import linalg
|
||||
from .nn import matmul
|
||||
from .nn import matmul, \
|
||||
bmm, bmm_transpose
|
||||
from . import contrib
|
||||
from . import numpy2cupy
|
||||
from .contrib import concat
|
||||
|
|
|
@ -358,7 +358,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
|||
jit_cc_src.append(f"""
|
||||
/*{doc_string}*/
|
||||
// @pyjt({",".join(pyjt_names)})
|
||||
vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
|
||||
vector_to_tuple<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
|
||||
{ f'return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));'
|
||||
if "replace_outputs" not in attrs else
|
||||
f'''auto rt = make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
|
||||
|
|
|
@ -180,7 +180,7 @@ def _setitem_old(x, slices, value):
|
|||
# PATCH
|
||||
def getitem(x, slices):
|
||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||
return getitem(x, tuple(slices.where()))
|
||||
return getitem(x, slices.where())
|
||||
if isinstance(slices, tuple):
|
||||
ss = []
|
||||
for s in slices:
|
||||
|
@ -193,7 +193,14 @@ def getitem(x, slices):
|
|||
|
||||
def setitem(x, slices, value):
|
||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||
slices = tuple(slices.where())
|
||||
if slices.shape == x.shape:
|
||||
if isinstance(value, (int, float)):
|
||||
value = jt.array(value).broadcast(x.shape)
|
||||
return x.assign(slices.ternary(value, x))
|
||||
elif isinstance(value, jt.Var) and value.shape == [1,]:
|
||||
value = jt.broadcast(value, x.shape)
|
||||
return x.assign(slices.ternary(value, x))
|
||||
slices = slices.where()
|
||||
elif isinstance(slices, tuple):
|
||||
ss = []
|
||||
for s in slices:
|
||||
|
|
|
@ -21,6 +21,13 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
template<class T>
|
||||
struct vector_to_tuple {
|
||||
typedef T value_type;
|
||||
vector<T> x;
|
||||
vector_to_tuple(vector<T>&& _) :x(move(_)) {}
|
||||
};
|
||||
|
||||
#define DEF_IS(check_type, return_type) \
|
||||
template<class T> \
|
||||
typename std::enable_if<std::is_same<T, check_type>::value, return_type>::type
|
||||
|
@ -462,6 +469,7 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj);
|
|||
typename std::enable_if<is_##check_type<T>::value, return_type>::type
|
||||
|
||||
CHECK_IS_1(vector);
|
||||
CHECK_IS_1(vector_to_tuple);
|
||||
|
||||
CHECK_IS_2(map);
|
||||
DEF_IS_2(map, bool) is_type(PyObject* obj);
|
||||
|
@ -499,6 +507,17 @@ DEF_IS_1(vector, PyObject*) to_py_tuple(const T& a) {
|
|||
return list.release();
|
||||
}
|
||||
|
||||
DEF_IS_1(vector_to_tuple, PyObject*) to_py_object(const T& a) {
|
||||
PyObjHolder list(PyTuple_New(a.x.size()));
|
||||
for (uint i=0; i<a.x.size(); i++) {
|
||||
PyObject* o = to_py_object<typename T::value_type>(a.x[i]);
|
||||
CHECK(o);
|
||||
// PyTuple_SET_ITEM borrow ownership, we do not hold this
|
||||
PyTuple_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
}
|
||||
|
||||
DEF_IS_1(vector, PyObject*) to_py_object(T&& a) {
|
||||
PyObjHolder list(PyList_New(a.size()));
|
||||
for (uint i=0; i<a.size(); i++) {
|
||||
|
|
|
@ -206,6 +206,16 @@ class TestSetitem(unittest.TestCase):
|
|||
b = jt.array([True,False,True,False])
|
||||
a[b] = jt.array([-1,-2])
|
||||
assert (a.data == [-1,2,-2,4]).all()
|
||||
|
||||
def test_setitem_bool2(self):
|
||||
a = jt.array([1,2,3,4])
|
||||
b = jt.array([True,False,True,False])
|
||||
a[b] = jt.array([-1])
|
||||
assert (a.data == [-1,2,-1,4]).all(), a
|
||||
a = jt.array([1,2,3,4])
|
||||
b = jt.array([True,False,True,False])
|
||||
a[b] = -1
|
||||
assert (a.data == [-1,2,-1,4]).all(), a
|
||||
|
||||
def test_slice_none(self):
|
||||
a = jt.array([1,2])
|
||||
|
|
|
@ -49,6 +49,7 @@ class TestWhereOp(unittest.TestCase):
|
|||
def test_reduce_dep(self):
|
||||
a = jt.random([100,100])
|
||||
index = self.where(a>0.5)
|
||||
assert isinstance(index, tuple)
|
||||
x = a.reindex_var(index)
|
||||
xsum =x.sum()
|
||||
na = a.data
|
||||
|
|
Loading…
Reference in New Issue