mirror of https://github.com/Jittor/Jittor
fix setitem
This commit is contained in:
parent
d31cdab30f
commit
84c10c7a06
|
@ -188,7 +188,7 @@ def setitem(x, slices, value):
|
||||||
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
if isinstance(slices, jt.Var) and slices.dtype == "bool":
|
||||||
mask = jt.broadcast(slices, x)
|
mask = jt.broadcast(slices, x)
|
||||||
value = jt.broadcast(value, x)
|
value = jt.broadcast(value, x)
|
||||||
return x.assign(mask.ternary(value, mask))
|
return x.assign(mask.ternary(value, x))
|
||||||
if isinstance(slices, list):
|
if isinstance(slices, list):
|
||||||
slices = tuple(slices)
|
slices = tuple(slices)
|
||||||
return x.assign(x.setitem(slices, value))
|
return x.assign(x.setitem(slices, value))
|
||||||
|
|
Loading…
Reference in New Issue