mirror of https://github.com/Jittor/Jittor
polish setitem inplace opt
This commit is contained in:
parent
9c03fd4f75
commit
b4cb572c90
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.66'
|
||||
__version__ = '1.2.3.67'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -186,8 +186,8 @@ static void getitem_inplace(GetitemOp* op) {
|
|||
VarSlice s = vs.slices[i];
|
||||
if (!(s.is_slice())) return;
|
||||
Slice ss = s.slice;
|
||||
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
|
||||
return;
|
||||
if (!(ss.start == 0 && (ss.mask&2) && ss.step == 1))
|
||||
return;
|
||||
}
|
||||
|
||||
VarSlice s = vs.slices[0];
|
||||
|
|
|
@ -73,6 +73,11 @@ class TestSetitem(unittest.TestCase):
|
|||
arr4_res.data[0,0,1,1] = 1
|
||||
assert arr4[0,0,1,1] == 1
|
||||
|
||||
arr4 = jt.random((4,2,3,3))
|
||||
arr4_res = arr4[...,:,:2]
|
||||
arr4_res.data[0,0,1,1] = 1
|
||||
assert arr4[0,0,1,1] != 1
|
||||
|
||||
arr5 = jt.random((4,2,3,3))
|
||||
arr5_res = arr5[1:3,:,:,:]
|
||||
arr5_res.data[1,0,1,1] = 1
|
||||
|
|
Loading…
Reference in New Issue