mirror of https://github.com/Jittor/Jittor
fix vary shape setitem
This commit is contained in:
parent
d31cdab30f
commit
05a67323ae
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.0.5'
|
||||
__version__ = '1.2.0.6'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -133,6 +133,13 @@ class TestSlice(unittest.TestCase):
|
|||
assert np.allclose(da.numpy(), nda, atol = 1e-3)
|
||||
assert np.allclose(db.numpy(), ndb, atol = 1e-3)
|
||||
|
||||
def test_vary_shape_setitem(self):
|
||||
a = jt.array([1,2,3,4,5])
|
||||
b = jt.array([1,2,3,4,5])
|
||||
c = jt.where(b>3)
|
||||
a[c] = 0
|
||||
assert (a.data == [1,2,3,0,0]).all()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -176,6 +176,14 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void SetitemOp::jit_prepare() {
|
||||
for (int i=0; i<o_shape.size(); i++)
|
||||
if (o_shape[i]<0) {
|
||||
// because output shape is inferd, check in
|
||||
// executor not work
|
||||
// reinfer shape if o_shape has vary shape
|
||||
infer_shape();
|
||||
break;
|
||||
}
|
||||
auto data = input(1);
|
||||
add_jit_define("OP", op);
|
||||
add_jit_define("Td", data->dtype());
|
||||
|
|
Loading…
Reference in New Issue