mirror of https://github.com/Jittor/Jittor
add index_add_
This commit is contained in:
parent
3e0396a14e
commit
6dde7e40cb
|
@ -13,6 +13,12 @@ import numpy as np
|
|||
import math
|
||||
from collections.abc import Sequence,Iterable
|
||||
|
||||
def index_add_(x, dim, index, tensor):
|
||||
assert len(index.shape) == 1
|
||||
assert tensor.shape[0] == index.shape[0]
|
||||
x[(slice(None,),)*dim+(index,)] += tensor
|
||||
jt.Var.index_add_ = index_add_
|
||||
|
||||
def __copy__(x):
|
||||
return x.copy().detach()
|
||||
jt.Var.__copy__ = __copy__
|
||||
|
|
|
@ -31,6 +31,22 @@ def check_equal(res1, res2, eps=1e-5):
|
|||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestPad(unittest.TestCase):
|
||||
def test_index_add_(self):
|
||||
x = np.ones((5,3))
|
||||
a1 = torch.Tensor(x)
|
||||
a1.index_add_(0, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
|
||||
a2 = jt.array(x)
|
||||
a2.index_add_(0, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
||||
check_equal(a1, a2)
|
||||
|
||||
x = np.ones((3,5))
|
||||
a1 = torch.Tensor(x)
|
||||
a1.index_add_(1, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
|
||||
a2 = jt.array(x)
|
||||
a2.index_add_(1, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
||||
check_equal(a1, a2)
|
||||
print('pass index_add_ test ...')
|
||||
|
||||
def test_repeat(self):
|
||||
arr = np.random.randn(16,3,224,224)
|
||||
check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))
|
||||
|
|
Loading…
Reference in New Issue