mirror of https://github.com/Jittor/Jittor
commit
655f3cc090
|
@ -13,6 +13,28 @@ import numpy as np
|
|||
import math
|
||||
from collections.abc import Sequence,Iterable
|
||||
|
||||
def index_add_(x, dim, index, tensor):
|
||||
""" Take out each index subscript vector of the dim dimension and add the corresponding tensor variable.
|
||||
|
||||
Example:
|
||||
|
||||
x = jt.ones((5,3))
|
||||
tensor = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
index = jt.array([0,4,2])
|
||||
x.index_add_(0, index, tensor)
|
||||
print(x)
|
||||
|
||||
>>> jt.Var([[ 2., 3., 4.],
|
||||
[ 1., 1., 1.],
|
||||
[ 8., 9., 10.],
|
||||
[ 1., 1., 1.],
|
||||
[ 5., 6., 7.]])
|
||||
"""
|
||||
assert len(index.shape) == 1
|
||||
assert tensor.shape[0] == index.shape[0]
|
||||
x[(slice(None,),)*dim+(index,)] += tensor
|
||||
jt.Var.index_add_ = index_add_
|
||||
|
||||
def __copy__(x):
|
||||
return x.copy().detach()
|
||||
jt.Var.__copy__ = __copy__
|
||||
|
|
|
@ -31,6 +31,22 @@ def check_equal(res1, res2, eps=1e-5):
|
|||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestPad(unittest.TestCase):
|
||||
def test_index_add_(self):
|
||||
x = np.ones((5,3))
|
||||
a1 = torch.Tensor(x)
|
||||
a1.index_add_(0, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
|
||||
a2 = jt.array(x)
|
||||
a2.index_add_(0, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
||||
check_equal(a1, a2)
|
||||
|
||||
x = np.ones((3,5))
|
||||
a1 = torch.Tensor(x)
|
||||
a1.index_add_(1, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
|
||||
a2 = jt.array(x)
|
||||
a2.index_add_(1, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
||||
check_equal(a1, a2)
|
||||
print('pass index_add_ test ...')
|
||||
|
||||
def test_repeat(self):
|
||||
arr = np.random.randn(16,3,224,224)
|
||||
check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))
|
||||
|
|
Loading…
Reference in New Issue