mirror of https://github.com/Jittor/Jittor
polish setitem bool mask
This commit is contained in:
parent
1d665e4dbf
commit
fcaf0f9da5
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.38'
|
||||
__version__ = '1.2.2.39'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -155,12 +155,12 @@ def slice_var_index(x, slices):
|
|||
x.stop_fuse()
|
||||
return (out_shape, out_index, 0, [], extras)
|
||||
|
||||
def slice_var(x, slices):
|
||||
def _slice_var_old(x, slices):
|
||||
reindex_args = slice_var_index(x, slices)
|
||||
x.stop_fuse()
|
||||
return x.reindex(*reindex_args).stop_fuse()
|
||||
|
||||
def setitem(x, slices, value):
|
||||
def _setitem_old(x, slices, value):
|
||||
reindex_args = slice_var_index(x, slices)
|
||||
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
|
||||
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
|
||||
|
@ -176,9 +176,6 @@ def setitem(x, slices, value):
|
|||
x.assign(out)
|
||||
return x
|
||||
|
||||
jt.Var.__getitem__ = jt.Var.slice_var = slice_var
|
||||
jt.Var.__setitem__ = setitem
|
||||
|
||||
# PATCH
|
||||
def getitem(x, slices):
|
||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||
|
@ -195,10 +192,8 @@ def getitem(x, slices):
|
|||
|
||||
def setitem(x, slices, value):
|
||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||
mask = jt.broadcast(slices, x)
|
||||
value = jt.broadcast(value, x)
|
||||
return x.assign(mask.ternary(value, x))
|
||||
if isinstance(slices, Sequence):
|
||||
slices = tuple(slices.where())
|
||||
elif isinstance(slices, Sequence):
|
||||
ss = []
|
||||
for s in slices:
|
||||
if isinstance(s, jt.Var) and s.dtype == "bool":
|
||||
|
|
|
@ -190,6 +190,12 @@ class TestSetitem(unittest.TestCase):
|
|||
@jt.flag_scope(use_cuda=1)
|
||||
def test_gather_cuda(self):
|
||||
self.test_gather()
|
||||
|
||||
def test_setitem_bool(self):
|
||||
a = jt.array([1,2,3,4])
|
||||
b = jt.array([True,False,True,False])
|
||||
a[b] = jt.array([-1,-2])
|
||||
assert (a.data == [-1,2,-2,4]).all()
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue