fix slices bugs

This commit is contained in:
li-xl 2020-12-01 19:45:06 +08:00
parent 7c0798d057
commit 7f0444d0e0
1 changed files with 9 additions and 2 deletions

View File

@ -10,6 +10,7 @@
import jittor as jt
import numpy as np
from jittor import pool
from collections.abc import Sequence
def argmax_pool(x, size, stride, padding=0):
return pool.pool(x, size, 'maximum', padding, stride)
@ -180,8 +181,14 @@ jt.Var.__setitem__ = setitem
def getitem(x, slices):
if isinstance(slices, jt.Var) and slices.dtype == "bool":
return getitem(x, slices.where())
if isinstance(slices, list):
slices = tuple(slices)
if isinstance(slices, Sequence):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.getitem(slices)
def setitem(x, slices, value):