support `jt.any` with argument `dim`

This commit is contained in:
CHEN Xinsheng 2025-02-18 23:27:30 +08:00
parent e0047b6fbd
commit e1052067c0
2 changed files with 63 additions and 4 deletions

View File

@ -312,11 +312,14 @@ def change_function():
def gather_acl(input, dim, index):
return GatherACL()(input, dim, index)
def any_acl(input):
if jt.sum(input != 0).item() > 0:
return jt.array([True])
def any_acl(input, dim=None):
if dim is None:
if jt.sum(input != 0).item() > 0:
return jt.array([True])
else:
return jt.array([False])
else:
return jt.array([False])
return jt.sum(input != 0, dim=dim) > 0
from .aclops.cumsum_op import CumsumACL

View File

@ -559,6 +559,62 @@ class TestACL(unittest.TestCase):
assert b.item() == True
print("test any (test case 5) success")
@jt.flag_scope(use_acl=1)
def test_any_6(self):
a = jt.array([[False, True, False], [False, False, True],
[True, True, False]])
b = self.measure_time(lambda: a.any())
assert b.item() == True
print("test any (test case 6) success")
@jt.flag_scope(use_acl=1)
def test_any_7(self):
a = jt.array([[False, False, False], [False, False, True],
[True, True, False]])
b = self.measure_time(lambda: jt.any(a, dim=1))
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=1"
print("test any (test case 7) success")
@jt.flag_scope(use_acl=1)
def test_any_8(self):
a = jt.array([[False, True, False], [False, False, True],
[False, True, False]])
b = self.measure_time(lambda: jt.any(a, dim=0))
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
print("test any (test case 8) success")
@jt.flag_scope(use_acl=1)
def test_any_9(self):
a = jt.array([[False, True, False], [False, False, True],
[False, True, False]])
b = self.measure_time(lambda: a.any(dim=0))
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
print("test any (test case 9) success")
@jt.flag_scope(use_acl=1)
def test_any_10(self):
# 测试在 dim=0 上检查每列是否有非零元素
a = jt.array([[0, 1, 0], [0, 0, 0]])
b = self.measure_time(lambda: jt.any(a, dim=0))
assert (b.numpy() == [False, True, False]).all(), "Unexpected result for dim=0"
print("test any (test case 10) success")
@jt.flag_scope(use_acl=1)
def test_any_11(self):
# 测试在 dim=0 上检查每列是否有非零元素
a = jt.array([[0.0, 1.0, -1.0], [0, 0, 0]])
b = self.measure_time(lambda: jt.any(a, dim=0))
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
print("test any (test case 11) success")
@jt.flag_scope(use_acl=1)
def test_any_12(self):
# 测试在 dim=0 上检查每列是否有非零元素
a = jt.array([[0.0, 1.0, -1.0], [0, 0, 0]])
b = self.measure_time(lambda: jt.any(a, dim=1))
assert (b.numpy() == [True, False]).all(), "Unexpected result for dim=0"
print("test any (test case 12) success")
@jt.flag_scope(use_acl=1)
def test_scatter(self):
a = jt.array([[1, 2], [3, 4]])