Merge pull request #623 from 514flowey/master

Fix Get Item Problem. Warning: This change has not passed a completely check.
This commit is contained in:
Zikai Xiao 2025-06-10 21:52:37 +08:00 committed by GitHub
commit 330dec69d2
3 changed files with 51 additions and 1 deletions

View File

@ -456,7 +456,7 @@ def setup_cutt():
def install_cutlass(root_folder):
# Modified from: https://github.com/ap-hynninen/cutlass
url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/cutlass.zip"
filename = "cutlass.zip"
fullname = os.path.join(root_folder, filename)

View File

@ -572,6 +572,8 @@ void GetitemOp::jit_run() {
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? ))))));
)
@for(d, 0, IDIM, if (iid@d < 0) iid@d += ishape@d;
)
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
op[oid] = ip[iid];
}

View File

@ -0,0 +1,48 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestGetItemSimple(unittest.TestCase):
def test_get_by_pos_int(self):
a = jt.array([-2,3,4,-5,-6])
b = a[3]
b.sync()
assert b.item() == -5
def test_get_by_neg_int(self):
a = jt.array([-2,3,4,-5,-6])
b = a[-3]
b.sync()
assert b.item() == 4
def test_get_slice(self):
a = jt.array([-2,3,4,-5,-6])
b = a[-1:-3:-1].numpy().tolist()
assert len(b) == 2
assert b[0] == -6
assert b[1] == -5
def test_get_by_list(self):
a = jt.array([-2,3,4,-5,-6])
b = a[[-1, -3, 1]].numpy().tolist()
assert len(b) == 3
assert b[0] == -6
assert b[1] == 4
assert b[2] == 3
def test_multidim_by_points(self):
a = jt.arange(24).reshape(2, 3, 4)
b = jt.array([0, 1, 0])
c = jt.array([0, -1, 1])
d = jt.array([-2, 0, 3])
e = a[(b, c, d)].numpy().tolist()
assert len(e) == 3
assert e[0] == 2
assert e[1] == 20
assert e[2] == 7
if __name__ == "__main__":
jt.flags.use_cuda = True
unittest.main()