mirror of https://github.com/Jittor/Jittor
add index_add_ doc
This commit is contained in:
parent
6dde7e40cb
commit
a09c988299
|
@ -14,6 +14,22 @@ import math
|
|||
from collections.abc import Sequence,Iterable
|
||||
|
||||
def index_add_(x, dim, index, tensor):
|
||||
""" Take out each index subscript vector of the dim dimension and add the corresponding tensor variable.
|
||||
|
||||
Example:
|
||||
|
||||
x = jt.ones((5,3))
|
||||
tensor = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
index = jt.array([0,4,2])
|
||||
x.index_add_(0, index, tensor)
|
||||
print(x)
|
||||
|
||||
>>> jt.Var([[ 2., 3., 4.],
|
||||
[ 1., 1., 1.],
|
||||
[ 8., 9., 10.],
|
||||
[ 1., 1., 1.],
|
||||
[ 5., 6., 7.]])
|
||||
"""
|
||||
assert len(index.shape) == 1
|
||||
assert tensor.shape[0] == index.shape[0]
|
||||
x[(slice(None,),)*dim+(index,)] += tensor
|
||||
|
|
Loading…
Reference in New Issue