mirror of https://github.com/Jittor/Jittor
fix setitem bugs
This commit is contained in:
parent
d7cad0ca82
commit
fded962b54
|
@ -12,6 +12,7 @@ import numpy as np
|
|||
from jittor import pool
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def argmax_pool(x, size, stride, padding=0):
|
||||
return pool.pool(x, size, 'maximum', padding, stride)
|
||||
|
||||
|
@ -196,8 +197,14 @@ def setitem(x, slices, value):
|
|||
mask = jt.broadcast(slices, x)
|
||||
value = jt.broadcast(value, x)
|
||||
return x.assign(mask.ternary(value, x))
|
||||
if isinstance(slices, list):
|
||||
slices = tuple(slices)
|
||||
if isinstance(slices, Sequence):
|
||||
ss = []
|
||||
for s in slices:
|
||||
if isinstance(s, jt.Var) and s.dtype == "bool":
|
||||
ss.extend(s.where())
|
||||
else:
|
||||
ss.append(s)
|
||||
slices = tuple(ss)
|
||||
return x.assign(x.setitem(slices, value))
|
||||
|
||||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||
|
|
Loading…
Reference in New Issue