add index_add_ doc

This commit is contained in:
zhouwy19 2021-03-16 10:27:26 +08:00
parent 6dde7e40cb
commit a09c988299
1 changed files with 16 additions and 0 deletions

View File

@ -14,6 +14,22 @@ import math
from collections.abc import Sequence,Iterable
def index_add_(x, dim, index, tensor):
""" Take out each index subscript vector of the dim dimension and add the corresponding tensor variable.
Example:
x = jt.ones((5,3))
tensor = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = jt.array([0,4,2])
x.index_add_(0, index, tensor)
print(x)
>>> jt.Var([[ 2., 3., 4.],
[ 1., 1., 1.],
[ 8., 9., 10.],
[ 1., 1., 1.],
[ 5., 6., 7.]])
"""
assert len(index.shape) == 1
assert tensor.shape[0] == index.shape[0]
x[(slice(None,),)*dim+(index,)] += tensor