diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 2abf5cde..fe2f360d 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -8,7 +8,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.37'
+__version__ = '1.2.2.38'
 from . import lock
 with lock.lock_scope():
     ori_int = int
diff --git a/python/jittor/misc.py b/python/jittor/misc.py
index f88f4d6c..60bd1b88 100644
--- a/python/jittor/misc.py
+++ b/python/jittor/misc.py
@@ -1073,3 +1073,89 @@ inline static void searchsorted(
         cpu_src=_searchsorted_src,
         cuda_header=_searchsorted_header,
         cuda_src=_searchsorted_src)
+
+
+def scatter(x:jt.Var, dim:int, index:jt.Var, src:jt.Var, reduce='void'):
+    ''' if x is a 3-D array, rewrite x like:
+
+    x[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
+    x[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
+    x[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
+
+Parameters::
+
+    * x (jt.Var) – input array
+    * dim (int) – the axis along which to index
+    * index (jt.Var) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns x unchanged.
+    * src (jt.Var) – the source element(s) to scatter.
+    * reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'. Defaults to 'void', which is passed through to x.setitem as-is.
+
+Example::
+
+    src = jt.arange(1, 11).reshape((2, 5))
+    index = jt.array([[0, 1, 2, 0]])
+    x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
+    assert (x.data ==
+        [[1, 0, 0, 4, 0],
+        [0, 2, 0, 0, 0],
+        [0, 0, 3, 0, 0]]).all()
+    index = jt.array([[0, 1, 2], [0, 1, 4]])
+    x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
+    assert (x.data ==
+        [[1, 2, 3, 0, 0],
+        [6, 7, 0, 0, 8],
+        [0, 0, 0, 0, 0]]).all()
+    x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
+        jt.array(1.23), reduce='multiply')
+    assert np.allclose(x.data,
+        [[2.0000, 2.0000, 2.4600, 2.0000],
+        [2.0000, 2.0000, 2.0000, 2.4600]]), x
+    x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
+        jt.array(1.23), reduce='add')
+    assert np.allclose(x.data,
+        [[2.0000, 2.0000, 3.2300, 2.0000],
+        [2.0000, 2.0000, 2.0000, 3.2300]])
+
+    '''
+    shape = index.shape
+    if src.shape != shape:
+        src = src[tuple( slice(None,s) for s in shape )]
+    indexes = [ f'i{i}' for i in range(len(shape)) ]
+    indexes[dim] = index
+    return x.setitem(tuple(indexes), src, reduce)
+
+def scatter_(x, dim, index, src, reduce='void'):
+    return x.assign(x.scatter(dim, index, src, reduce))
+
+jt.Var.scatter = scatter
+jt.Var.scatter_ = scatter_
+
+def gather(x, dim, index):
+    ''' if x is a 3-D array, reindex x like:
+
+    out[i][j][k] = x[index[i][j][k]][j][k]  # if dim == 0
+    out[i][j][k] = x[i][index[i][j][k]][k]  # if dim == 1
+    out[i][j][k] = x[i][j][index[i][j][k]]  # if dim == 2
+
+
+Parameters::
+
+    * x (jt.Var) – the source array
+    * dim (int) – the axis along which to index
+    * index (jt.Var) – the indices of elements to gather
+
+Example::
+
+    t = jt.array([[1, 2], [3, 4]])
+    data = t.gather(1, jt.array([[0, 0], [1, 0]]))
+    assert (data.data == [[ 1, 1], [ 4, 3]]).all()
+    data = t.gather(0, jt.array([[0, 0], [1, 0]]))
+    assert (data.data == [[ 1, 2], [ 3, 2]]).all()
+
+    '''
+    shape = index.shape
+    indexes = [ f'i{i}' for i in range(len(shape)) ]
+    indexes[dim] = index
+    return x.getitem(tuple(indexes))
+
+jt.Var.gather = gather
diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py
index 83262252..75daa950 100644
--- a/python/jittor/test/test_setitem.py
+++ b/python/jittor/test/test_setitem.py
@@ -149,26 +149,49 @@ class TestSetitem(unittest.TestCase):
         assert (a[0].numpy() == [-1,2]).all(), a[0].numpy()
         assert (a[1].numpy() == [3,-2]).all(), a[1].numpy()
 
-    # def test_scatter(self):
-    #     src = jt.arange(1, 11).reshape((2, 5))
-    #     index = jt.array([[0, 1, 2, 0]])
-    #     print(index.shape, src.shape)
-    #     x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
-    #     print(x)
+    def test_scatter(self):
+        src = jt.arange(1, 11).reshape((2, 5))
+        index = jt.array([[0, 1, 2, 0]])
+        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
+        assert (x.data ==
+            [[1, 0, 0, 4, 0],
+            [0, 2, 0, 0, 0],
+            [0, 0, 3, 0, 0]]).all()
+        index = jt.array([[0, 1, 2], [0, 1, 4]])
+        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
+        assert (x.data ==
+            [[1, 2, 3, 0, 0],
+            [6, 7, 0, 0, 8],
+            [0, 0, 0, 0, 0]]).all()
+        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
+            jt.array(1.23), reduce='multiply')
+        assert np.allclose(x.data,
+            [[2.0000, 2.0000, 2.4600, 2.0000],
+            [2.0000, 2.0000, 2.0000, 2.4600]]), x
+        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
+            jt.array(1.23), reduce='add')
+        assert np.allclose(x.data,
+            [[2.0000, 2.0000, 3.2300, 2.0000],
+            [2.0000, 2.0000, 2.0000, 3.2300]])
+
+    def test_gather(self):
+        t = jt.array([[1, 2], [3, 4]])
+        data = t.gather(1, jt.array([[0, 0], [1, 0]])).data
+        assert (data == [[ 1, 1], [ 4, 3]]).all()
+        data = t.gather(0, jt.array([[0, 0], [1, 0]])).data
+        assert (data == [[ 1, 2], [ 3, 2]]).all()
+
+    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
+    @jt.flag_scope(use_cuda=1)
+    def test_scatter_cuda(self):
+        self.test_scatter()
+
+    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
+    @jt.flag_scope(use_cuda=1)
+    def test_gather_cuda(self):
+        self.test_gather()
 
 
-# def scatter(x, dim, index, src, reduce='void'):
-#     shape = index.shape
-#     indexes = [ jt.index(shape, i) for i in range(dim) ]
-#     indexes.append(index)
-#     print(indexes)
-#     return x.setitem(tuple(indexes), src, reduce)
-
-# def scatter_(x, dim, index, src, reduce='void'):
-#     return x.assign(x.scatter(dim, index, src, reduce))
-
-# jt.Var.scatter = scatter
-# jt.Var.scatter_ = scatter_
 
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
diff --git a/src/ops/getitem_op.cc b/src/ops/getitem_op.cc
index 59ac5162..2882f404 100644
--- a/src/ops/getitem_op.cc
+++ b/src/ops/getitem_op.cc
@@ -117,6 +117,10 @@ void GetitemOp::infer_slices(
             auto& v = s.slice.start;
             if (v<0) v += in_shape_i;
             CHECK(v>=0 && v>in_shape_i>>")";
+        } else
+        if (s.is_str()) {
+            i_to_vs[i] = vid++;
+            i_to_o[i] = -1;
         } else {
             // slice
             auto& slice = s.slice;
@@ -401,6 +405,10 @@ void GetitemOp::jit_prepare(JK& jk) {
         if (iv>=0 && io==-1) {
             if (v.is_int()) {
                 jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
+            } else
+            if (v.is_str()) {
+                jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
+                jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
             } else {
                 ASSERT(v.is_var());
                 auto var = v.var;
@@ -498,9 +506,10 @@ void GetitemOp::jit_run() {
             @if(IV@d==-2, 0,
             @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
             @if(VS@d==-1, vi@d,
+            @if(VS@d==-5, VSS@d,
             @if(VS@d>=0,
                 index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
-            , ??? )))));
+            , ??? ))))));
         )
         auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
         op[oid] = ip[iid];
diff --git a/src/ops/setitem_op.cc b/src/ops/setitem_op.cc
index 0b8ae47d..3bf94802 100644
--- a/src/ops/setitem_op.cc
+++ b/src/ops/setitem_op.cc
@@ -206,6 +206,10 @@ void SetitemOp::jit_prepare(JK& jk) {
         if (iv>=0 && io==-1) {
             if (v.is_int()) {
                 jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
+            } else
+            if (v.is_str()) {
+                jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
+                jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
             } else {
                 ASSERT(v.is_var());
                 auto var = v.var;
@@ -323,9 +327,10 @@ void SetitemOp::jit_run() {
             @if(IV@d==-2, 0,
             @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
             @if(VS@d==-1, vi@d,
+            @if(VS@d==-5, VSS@d,
             @if(VS@d>=0,
                 index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
-            , ??? )))));
+            , ??? ))))));
         )
         auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
 
diff --git a/src/pyjt/py_converter.h b/src/pyjt/py_converter.h
index 598b08b7..8186b77a 100644
--- a/src/pyjt/py_converter.h
+++ b/src/pyjt/py_converter.h
@@ -719,6 +719,7 @@ DEF_IS(VarSlices, bool) is_type(PyObject* obj) {
         PySlice_Check(obj) ||
         (Py_TYPE(obj) == &PyEllipsis_Type) ||
         obj == Py_None ||
+        PyUnicode_CheckExact(obj) ||
         is_type(obj);
 }
 
@@ -733,6 +734,9 @@ void load_var_slice(PyObject* obj, T* var_slice, vector>&
     if (Py_TYPE(obj) == &PyEllipsis_Type) {
         var_slice->set_ellipsis();
     } else
+    if (PyUnicode_CheckExact(obj)) {
+        var_slice->set_str(from_py_object(obj));
+    } else
     if (obj == Py_None) {
         var_slice->set_none();
     } else
diff --git a/src/var_slices.h b/src/var_slices.h
index 8b933595..235c03e2 100644
--- a/src/var_slices.h
+++ b/src/var_slices.h
@@ -21,11 +21,24 @@ union VarSlice {
     inline bool is_ellipsis() const { return slice.mask == -2; }
     inline bool is_none() const { return slice.mask == -3; }
     inline bool is_int() const { return slice.mask == -4; }
+    inline bool is_str() const { return slice.mask == -5; }
     inline bool is_slice() const { return slice.mask >= 0; }
     inline void set_var(Var* v) { slice.mask = -1; var = v; }
     inline void set_ellipsis() { slice.mask = -2; }
     inline void set_none() { slice.mask = -3; }
     inline void set_int(int64 v) { slice.mask = -4; i = v; }
+    inline void set_str(const string& s) {
+        slice.mask = -5;
+        CHECK(s.size() < 16) << "String slice too long: " << s;
+        // Store the name in place over start/stop (the bytes get_str() reads).
+        // Zero-fill first so the result is always NUL-terminated, then copy at
+        // most 15 chars: never read past the end of a short string's buffer.
+        slice.start = slice.stop = 0;
+        s.copy((char*)this, 15);
+        slice.step = s.size();
+    }
+    // Returns the NUL-terminated name written in place by set_str().
+    inline char* get_str() {return (char*)this;}
 };
 
 struct VarSlices {