mirror of https://github.com/Jittor/Jittor
fix slices bugs
This commit is contained in:
parent
7c0798d057
commit
7f0444d0e0
|
@ -10,6 +10,7 @@
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jittor import pool
|
from jittor import pool
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
def argmax_pool(x, size, stride, padding=0):
|
def argmax_pool(x, size, stride, padding=0):
|
||||||
return pool.pool(x, size, 'maximum', padding, stride)
|
return pool.pool(x, size, 'maximum', padding, stride)
|
||||||
|
@ -180,8 +181,14 @@ jt.Var.__setitem__ = setitem
|
||||||
def getitem(x, slices):
|
def getitem(x, slices):
|
||||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||||
return getitem(x, slices.where())
|
return getitem(x, slices.where())
|
||||||
if isinstance(slices, list):
|
if isinstance(slices, Sequence):
|
||||||
slices = tuple(slices)
|
ss = []
|
||||||
|
for s in slices:
|
||||||
|
if isinstance(s, jt.Var) and s.dtype == "bool":
|
||||||
|
ss.extend(s.where())
|
||||||
|
else:
|
||||||
|
ss.append(s)
|
||||||
|
slices = tuple(ss)
|
||||||
return x.getitem(slices)
|
return x.getitem(slices)
|
||||||
|
|
||||||
def setitem(x, slices, value):
|
def setitem(x, slices, value):
|
||||||
|
|
Loading…
Reference in New Issue