fix setitem bugs

This commit is contained in:
li-xl 2020-12-15 15:17:30 +08:00 committed by Dun Liang
parent d7cad0ca82
commit fded962b54
1 changed files with 9 additions and 2 deletions

View File

@ -12,6 +12,7 @@ import numpy as np
from jittor import pool
from collections.abc import Sequence
def argmax_pool(x, size, stride, padding=0):
return pool.pool(x, size, 'maximum', padding, stride)
@ -196,8 +197,14 @@ def setitem(x, slices, value):
mask = jt.broadcast(slices, x)
value = jt.broadcast(value, x)
return x.assign(mask.ternary(value, x))
if isinstance(slices, list):
slices = tuple(slices)
if isinstance(slices, Sequence):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.assign(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem