polish setitem bool mask

This commit is contained in:
Dun Liang 2021-03-08 22:01:21 +08:00
parent 1d665e4dbf
commit fcaf0f9da5
3 changed files with 11 additions and 10 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.38'
__version__ = '1.2.2.39'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -155,12 +155,12 @@ def slice_var_index(x, slices):
x.stop_fuse()
return (out_shape, out_index, 0, [], extras)
def slice_var(x, slices):
def _slice_var_old(x, slices):
reindex_args = slice_var_index(x, slices)
x.stop_fuse()
return x.reindex(*reindex_args).stop_fuse()
def setitem(x, slices, value):
def _setitem_old(x, slices, value):
reindex_args = slice_var_index(x, slices)
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
@ -176,9 +176,6 @@ def setitem(x, slices, value):
x.assign(out)
return x
jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem
# PATCH
def getitem(x, slices):
if isinstance(slices, jt.Var) and slices.dtype == "bool":
@ -195,10 +192,8 @@ def getitem(x, slices):
def setitem(x, slices, value):
if isinstance(slices, jt.Var) and slices.dtype == "bool":
mask = jt.broadcast(slices, x)
value = jt.broadcast(value, x)
return x.assign(mask.ternary(value, x))
if isinstance(slices, Sequence):
slices = tuple(slices.where())
elif isinstance(slices, Sequence):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":

View File

@ -190,6 +190,12 @@ class TestSetitem(unittest.TestCase):
@jt.flag_scope(use_cuda=1)
def test_gather_cuda(self):
self.test_gather()
def test_setitem_bool(self):
a = jt.array([1,2,3,4])
b = jt.array([True,False,True,False])
a[b] = jt.array([-1,-2])
assert (a.data == [-1,2,-2,4]).all()