Add Index Check for Get Item. Warning: It may slow down the speed, and has not passed a fully check!

This commit is contained in:
514flowey 2024-12-16 22:19:55 +08:00
parent 9ee61d26f1
commit 9e7e479df2
2 changed files with 50 additions and 0 deletions

View File

@ -572,6 +572,8 @@ void GetitemOp::jit_run() {
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? ))))));
)
@for(d, 0, IDIM, if (iid@d < 0) iid@d += ishape@d;
)
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
op[oid] = ip[iid];
}

View File

@ -0,0 +1,48 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestGetItemSimple(unittest.TestCase):
def test_get_by_pos_int(self):
a = jt.array([-2,3,4,-5,-6])
b = a[3]
b.sync()
assert b.item() == -5
def test_get_by_neg_int(self):
a = jt.array([-2,3,4,-5,-6])
b = a[-3]
b.sync()
assert b.item() == 4
def test_get_slice(self):
a = jt.array([-2,3,4,-5,-6])
b = a[-1:-3:-1].numpy().tolist()
assert len(b) == 2
assert b[0] == -6
assert b[1] == -5
def test_get_by_list(self):
a = jt.array([-2,3,4,-5,-6])
b = a[[-1, -3, 1]].numpy().tolist()
assert len(b) == 3
assert b[0] == -6
assert b[1] == 4
assert b[2] == 3
def test_multidim_by_points(self):
a = jt.arange(24).reshape(2, 3, 4)
b = jt.array([0, 1, 0])
c = jt.array([0, -1, 1])
d = jt.array([-2, 0, 3])
e = a[(b, c, d)].numpy().tolist()
assert len(e) == 3
assert e[0] == 2
assert e[1] == 20
assert e[2] == 7
if __name__ == "__main__":
jt.flags.use_cuda = True
unittest.main()