fix vary shape setitem

This commit is contained in:
Dun Liang 2020-10-17 15:08:55 +08:00
parent d31cdab30f
commit 05a67323ae
3 changed files with 16 additions and 1 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.0.5'
__version__ = '1.2.0.6'
from . import lock
with lock.lock_scope():
from . import compiler

View File

@ -133,6 +133,13 @@ class TestSlice(unittest.TestCase):
assert np.allclose(da.numpy(), nda, atol = 1e-3)
assert np.allclose(db.numpy(), ndb, atol = 1e-3)
def test_vary_shape_setitem(self):
a = jt.array([1,2,3,4,5])
b = jt.array([1,2,3,4,5])
c = jt.where(b>3)
a[c] = 0
assert (a.data == [1,2,3,0,0]).all()
if __name__ == "__main__":

View File

@ -176,6 +176,14 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}
void SetitemOp::jit_prepare() {
for (int i=0; i<o_shape.size(); i++)
if (o_shape[i]<0) {
// because output shape is inferd, check in
// executor not work
// reinfer shape if o_shape has vary shape
infer_shape();
break;
}
auto data = input(1);
add_jit_define("OP", op);
add_jit_define("Td", data->dtype());