fix slices bugs

This commit is contained in:
li-xl 2020-12-01 19:45:06 +08:00
parent 7c0798d057
commit 7f0444d0e0
1 changed files with 9 additions and 2 deletions

View File

@ -10,6 +10,7 @@
import jittor as jt import jittor as jt
import numpy as np import numpy as np
from jittor import pool from jittor import pool
from collections.abc import Sequence
def argmax_pool(x, size, stride, padding=0): def argmax_pool(x, size, stride, padding=0):
return pool.pool(x, size, 'maximum', padding, stride) return pool.pool(x, size, 'maximum', padding, stride)
@ -180,8 +181,14 @@ jt.Var.__setitem__ = setitem
def getitem(x, slices): def getitem(x, slices):
if isinstance(slices, jt.Var) and slices.dtype == "bool": if isinstance(slices, jt.Var) and slices.dtype == "bool":
return getitem(x, slices.where()) return getitem(x, slices.where())
if isinstance(slices, list): if isinstance(slices, Sequence):
slices = tuple(slices) ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.getitem(slices) return x.getitem(slices)
def setitem(x, slices, value): def setitem(x, slices, value):