Merge pull request #185 from Jittor/zwy1

add index_add_
This commit is contained in:
zhouwy19 2021-03-16 11:21:23 +08:00 committed by GitHub
commit 655f3cc090
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 0 deletions

View File

@ -13,6 +13,28 @@ import numpy as np
import math
from collections.abc import Sequence,Iterable
def index_add_(x, dim, index, tensor):
""" Take out each index subscript vector of the dim dimension and add the corresponding tensor variable.
Example:
x = jt.ones((5,3))
tensor = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = jt.array([0,4,2])
x.index_add_(0, index, tensor)
print(x)
>>> jt.Var([[ 2., 3., 4.],
[ 1., 1., 1.],
[ 8., 9., 10.],
[ 1., 1., 1.],
[ 5., 6., 7.]])
"""
assert len(index.shape) == 1
assert tensor.shape[0] == index.shape[0]
x[(slice(None,),)*dim+(index,)] += tensor
jt.Var.index_add_ = index_add_
def __copy__(x):
return x.copy().detach()
jt.Var.__copy__ = __copy__

View File

@ -31,6 +31,22 @@ def check_equal(res1, res2, eps=1e-5):
@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
def test_index_add_(self):
x = np.ones((5,3))
a1 = torch.Tensor(x)
a1.index_add_(0, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
a2 = jt.array(x)
a2.index_add_(0, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
check_equal(a1, a2)
x = np.ones((3,5))
a1 = torch.Tensor(x)
a1.index_add_(1, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
a2 = jt.array(x)
a2.index_add_(1, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
check_equal(a1, a2)
print('pass index_add_ test ...')
def test_repeat(self):
arr = np.random.randn(16,3,224,224)
check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))