add index_add_

This commit is contained in:
zhouwy19 2021-03-15 21:48:08 +08:00
parent 3e0396a14e
commit 6dde7e40cb
2 changed files with 22 additions and 0 deletions

View File

@ -13,6 +13,12 @@ import numpy as np
import math
from collections.abc import Sequence,Iterable
def index_add_(x, dim, index, tensor):
assert len(index.shape) == 1
assert tensor.shape[0] == index.shape[0]
x[(slice(None,),)*dim+(index,)] += tensor
jt.Var.index_add_ = index_add_
def __copy__(x):
return x.copy().detach()
jt.Var.__copy__ = __copy__

View File

@ -31,6 +31,22 @@ def check_equal(res1, res2, eps=1e-5):
@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
def test_index_add_(self):
x = np.ones((5,3))
a1 = torch.Tensor(x)
a1.index_add_(0, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
a2 = jt.array(x)
a2.index_add_(0, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
check_equal(a1, a2)
x = np.ones((3,5))
a1 = torch.Tensor(x)
a1.index_add_(1, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
a2 = jt.array(x)
a2.index_add_(1, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
check_equal(a1, a2)
print('pass index_add_ test ...')
def test_repeat(self):
arr = np.random.randn(16,3,224,224)
check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))