mirror of https://github.com/Jittor/Jittor
fix slices bugs
This commit is contained in:
parent
7c0798d057
commit
7f0444d0e0
|
@ -10,6 +10,7 @@
|
|||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import pool
|
||||
from collections.abc import Sequence
|
||||
|
||||
def argmax_pool(x, size, stride, padding=0):
|
||||
return pool.pool(x, size, 'maximum', padding, stride)
|
||||
|
@ -180,8 +181,14 @@ jt.Var.__setitem__ = setitem
|
|||
def getitem(x, slices):
|
||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||
return getitem(x, slices.where())
|
||||
if isinstance(slices, list):
|
||||
slices = tuple(slices)
|
||||
if isinstance(slices, Sequence):
|
||||
ss = []
|
||||
for s in slices:
|
||||
if isinstance(s, jt.Var) and s.dtype == "bool":
|
||||
ss.extend(s.where())
|
||||
else:
|
||||
ss.append(s)
|
||||
slices = tuple(ss)
|
||||
return x.getitem(slices)
|
||||
|
||||
def setitem(x, slices, value):
|
||||
|
|
Loading…
Reference in New Issue