add scatter gather support

This commit is contained in:
Dun Liang 2021-03-07 16:11:28 +08:00
parent 8bf423980e
commit 1d665e4dbf
7 changed files with 158 additions and 21 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.2.37' __version__ = '1.2.2.38'
from . import lock from . import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -1073,3 +1073,89 @@ inline static void searchsorted(
cpu_src=_searchsorted_src, cpu_src=_searchsorted_src,
cuda_header=_searchsorted_header, cuda_header=_searchsorted_header,
cuda_src=_searchsorted_src) cuda_src=_searchsorted_src)
def scatter(x:jt.Var, dim:int, index:jt.Var, src:jt.Var, reduce='void'):
    ''' Return the result of writing values of ``src`` into ``x`` at the
    positions selected by ``index`` along axis ``dim``.

    For a 3-D array the write follows this rule:

        x[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
        x[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
        x[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

    The loop ranges are given by ``index.shape``; when ``src`` has a
    different shape it is first cropped with ``src[:s0, :s1, ...]`` using
    the extents of ``index``.

    Parameters::

        * x (jt.Var)  input array
        * dim (int)  the axis along which to index
        * index (jt.Var)  the indices of elements to scatter
        * src (jt.Var)  the source element(s) to scatter
        * reduce (str, optional)  reduction operation handed to ``setitem``
          unchanged; the examples below use 'add' and 'multiply', the
          default is 'void'

    Example::

        src = jt.arange(1, 11).reshape((2, 5))
        index = jt.array([[0, 1, 2, 0]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
        assert (x.data ==
            [[1, 0, 0, 4, 0],
            [0, 2, 0, 0, 0],
            [0, 0, 3, 0, 0]]).all()
        index = jt.array([[0, 1, 2], [0, 1, 4]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
        assert (x.data ==
            [[1, 2, 3, 0, 0],
            [6, 7, 0, 0, 8],
            [0, 0, 0, 0, 0]]).all()
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
                jt.array(1.23), reduce='multiply')
        assert np.allclose(x.data,
            [[2.0000, 2.0000, 2.4600, 2.0000],
            [2.0000, 2.0000, 2.0000, 2.4600]]), x
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
                jt.array(1.23), reduce='add')
        assert np.allclose(x.data,
            [[2.0000, 2.0000, 3.2300, 2.0000],
            [2.0000, 2.0000, 2.0000, 3.2300]])
    '''
    target_shape = index.shape
    if src.shape != target_shape:
        # Keep only the leading region of src that index actually covers.
        crop = tuple(slice(None, extent) for extent in target_shape)
        src = src[crop]
    # One string key 'i<k>' per axis of index; the key for `dim` is then
    # replaced by the index Var itself.
    keys = []
    for axis in range(len(target_shape)):
        keys.append(f'i{axis}')
    keys[dim] = index
    return x.setitem(tuple(keys), src, reduce)
def scatter_(x, dim, index, src, reduce='void'):
    ''' In-place variant of ``scatter``: computes ``x.scatter(...)`` and
    assigns the result back to ``x``. Returns whatever ``x.assign`` returns.
    '''
    scattered = x.scatter(dim, index, src, reduce)
    return x.assign(scattered)
# Attach both variants to jt.Var so they can be called as methods
# (x.scatter(...) / x.scatter_(...)).
jt.Var.scatter = scatter
jt.Var.scatter_ = scatter_
def gather(x, dim, index):
    ''' Collect values of ``x`` along axis ``dim`` at the positions given
    by ``index``.

    For a 3-D array the result follows this rule:

        out[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
        out[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
        out[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2

    Parameters::

        * x (jt.Var)  the source array
        * dim (int)  the axis along which to index
        * index (jt.Var)  the indices of elements to gather

    Example::

        t = jt.array([[1, 2], [3, 4]])
        data = t.gather(1, jt.array([[0, 0], [1, 0]]))
        assert (data.data == [[ 1, 1], [ 4, 3]]).all()
        data = t.gather(0, jt.array([[0, 0], [1, 0]]))
        assert (data.data == [[ 1, 2], [ 3, 2]]).all()
    '''
    rank = len(index.shape)
    # One string key 'i<k>' per axis of index; the key for `dim` is then
    # replaced by the index Var itself.
    keys = []
    for axis in range(rank):
        keys.append(f'i{axis}')
    keys[dim] = index
    return x.getitem(tuple(keys))
# Attach gather to jt.Var so it can be called as a method (x.gather(...)).
jt.Var.gather = gather

View File

@ -149,26 +149,49 @@ class TestSetitem(unittest.TestCase):
assert (a[0].numpy() == [-1,2]).all(), a[0].numpy() assert (a[0].numpy() == [-1,2]).all(), a[0].numpy()
assert (a[1].numpy() == [3,-2]).all(), a[1].numpy() assert (a[1].numpy() == [3,-2]).all(), a[1].numpy()
# def test_scatter(self): def test_scatter(self):
# src = jt.arange(1, 11).reshape((2, 5)) src = jt.arange(1, 11).reshape((2, 5))
# index = jt.array([[0, 1, 2, 0]]) index = jt.array([[0, 1, 2, 0]])
# print(index.shape, src.shape) x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
# x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src) assert (x.data ==
# print(x) [[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]]).all()
index = jt.array([[0, 1, 2], [0, 1, 4]])
x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
assert (x.data ==
[[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]]).all()
x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
jt.array(1.23), reduce='multiply')
assert np.allclose(x.data,
[[2.0000, 2.0000, 2.4600, 2.0000],
[2.0000, 2.0000, 2.0000, 2.4600]]), x
x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
jt.array(1.23), reduce='add')
assert np.allclose(x.data,
[[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
def test_gather(self):
    # Gather from a 2x2 array along each axis with the same index pattern
    # and compare against the hand-computed result.
    source = jt.array([[1, 2], [3, 4]])
    cases = (
        (1, [[ 1, 1], [ 4, 3]]),
        (0, [[ 1, 2], [ 3, 2]]),
    )
    for dim, expected in cases:
        got = source.gather(dim, jt.array([[0, 0], [1, 0]])).data
        assert (got == expected).all()
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_scatter_cuda(self):
    # Re-run the scatter assertions under the flag_scope(use_cuda=1)
    # decorator; skipped when jt.compiler.has_cuda is false.
    run_scatter_checks = self.test_scatter
    run_scatter_checks()
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_gather_cuda(self):
    # Re-run the gather assertions under the flag_scope(use_cuda=1)
    # decorator; skipped when jt.compiler.has_cuda is false.
    run_gather_checks = self.test_gather
    run_gather_checks()
# def scatter(x, dim, index, src, reduce='void'):
# shape = index.shape
# indexes = [ jt.index(shape, i) for i in range(dim) ]
# indexes.append(index)
# print(indexes)
# return x.setitem(tuple(indexes), src, reduce)
# def scatter_(x, dim, index, src, reduce='void'):
# return x.assign(x.scatter(dim, index, src, reduce))
# jt.Var.scatter = scatter
# jt.Var.scatter_ = scatter_
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -117,6 +117,10 @@ void GetitemOp::infer_slices(
auto& v = s.slice.start; auto& v = s.slice.start;
if (v<0) v += in_shape_i; if (v<0) v += in_shape_i;
CHECK(v>=0 && v<in_shape_i) << "slice overflow, " << v << "not in [0,">>in_shape_i>>")"; CHECK(v>=0 && v<in_shape_i) << "slice overflow, " << v << "not in [0,">>in_shape_i>>")";
} else
if (s.is_str()) {
i_to_vs[i] = vid++;
i_to_o[i] = -1;
} else { } else {
// slice // slice
auto& slice = s.slice; auto& slice = s.slice;
@ -401,6 +405,10 @@ void GetitemOp::jit_prepare(JK& jk) {
if (iv>=0 && io==-1) { if (iv>=0 && io==-1) {
if (v.is_int()) { if (v.is_int()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1"); jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
} else
if (v.is_str()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
} else { } else {
ASSERT(v.is_var()); ASSERT(v.is_var());
auto var = v.var; auto var = v.var;
@ -498,9 +506,10 @@ void GetitemOp::jit_run() {
@if(IV@d==-2, 0, @if(IV@d==-2, 0,
@if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d), @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
@if(VS@d==-1, vi@d, @if(VS@d==-1, vi@d,
@if(VS@d==-5, VSS@d,
@if(VS@d>=0, @if(VS@d>=0,
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))]) index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? ))))); , ??? ))))));
) )
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d); auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
op[oid] = ip[iid]; op[oid] = ip[iid];

View File

@ -206,6 +206,10 @@ void SetitemOp::jit_prepare(JK& jk) {
if (iv>=0 && io==-1) { if (iv>=0 && io==-1) {
if (v.is_int()) { if (v.is_int()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1"); jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
} else
if (v.is_str()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
} else { } else {
ASSERT(v.is_var()); ASSERT(v.is_var());
auto var = v.var; auto var = v.var;
@ -323,9 +327,10 @@ void SetitemOp::jit_run() {
@if(IV@d==-2, 0, @if(IV@d==-2, 0,
@if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d), @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
@if(VS@d==-1, vi@d, @if(VS@d==-1, vi@d,
@if(VS@d==-5, VSS@d,
@if(VS@d>=0, @if(VS@d>=0,
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))]) index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? ))))); , ??? ))))));
) )
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d); auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);

View File

@ -719,6 +719,7 @@ DEF_IS(VarSlices, bool) is_type(PyObject* obj) {
PySlice_Check(obj) || PySlice_Check(obj) ||
(Py_TYPE(obj) == &PyEllipsis_Type) || (Py_TYPE(obj) == &PyEllipsis_Type) ||
obj == Py_None || obj == Py_None ||
PyUnicode_CheckExact(obj) ||
is_type<VarHolder*>(obj); is_type<VarHolder*>(obj);
} }
@ -733,6 +734,9 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
if (Py_TYPE(obj) == &PyEllipsis_Type) { if (Py_TYPE(obj) == &PyEllipsis_Type) {
var_slice->set_ellipsis(); var_slice->set_ellipsis();
} else } else
if (PyUnicode_CheckExact(obj)) {
var_slice->set_str(from_py_object<string>(obj));
} else
if (obj == Py_None) { if (obj == Py_None) {
var_slice->set_none(); var_slice->set_none();
} else } else

View File

@ -21,11 +21,21 @@ union VarSlice {
inline bool is_ellipsis() const { return slice.mask == -2; } inline bool is_ellipsis() const { return slice.mask == -2; }
inline bool is_none() const { return slice.mask == -3; } inline bool is_none() const { return slice.mask == -3; }
inline bool is_int() const { return slice.mask == -4; } inline bool is_int() const { return slice.mask == -4; }
inline bool is_str() const { return slice.mask == -5; }
inline bool is_slice() const { return slice.mask >= 0; } inline bool is_slice() const { return slice.mask >= 0; }
inline void set_var(Var* v) { slice.mask = -1; var = v; } inline void set_var(Var* v) { slice.mask = -1; var = v; }
inline void set_ellipsis() { slice.mask = -2; } inline void set_ellipsis() { slice.mask = -2; }
inline void set_none() { slice.mask = -3; } inline void set_none() { slice.mask = -3; }
inline void set_int(int64 v) { slice.mask = -4; i = v; } inline void set_int(int64 v) { slice.mask = -4; i = v; }
inline void set_str(const string& s) {
slice.mask = -5;
CHECK(s.size() < 16) << "String slice too long" << s;
auto v = (int64*)s.c_str();
slice.start = v[0];
slice.stop = v[1];
slice.step = s.size();
}
// Returns the address of this object viewed as a char buffer; used to read
// back the text that set_str() packed into slice.start / slice.stop.
inline char* get_str() {
    return reinterpret_cast<char*>(this);
}
}; };
struct VarSlices { struct VarSlices {