mirror of https://github.com/Jittor/Jittor
support `jt.any` with argument `dim`
This commit is contained in:
parent
e0047b6fbd
commit
e1052067c0
|
@ -312,11 +312,14 @@ def change_function():
|
|||
def gather_acl(input, dim, index):
|
||||
return GatherACL()(input, dim, index)
|
||||
|
||||
def any_acl(input):
|
||||
if jt.sum(input != 0).item() > 0:
|
||||
return jt.array([True])
|
||||
def any_acl(input, dim=None):
|
||||
if dim is None:
|
||||
if jt.sum(input != 0).item() > 0:
|
||||
return jt.array([True])
|
||||
else:
|
||||
return jt.array([False])
|
||||
else:
|
||||
return jt.array([False])
|
||||
return jt.sum(input != 0, dim=dim) > 0
|
||||
|
||||
from .aclops.cumsum_op import CumsumACL
|
||||
|
||||
|
|
|
@ -559,6 +559,62 @@ class TestACL(unittest.TestCase):
|
|||
assert b.item() == True
|
||||
print("test any (test case 5) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_6(self):
|
||||
a = jt.array([[False, True, False], [False, False, True],
|
||||
[True, True, False]])
|
||||
b = self.measure_time(lambda: a.any())
|
||||
assert b.item() == True
|
||||
print("test any (test case 6) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_7(self):
|
||||
a = jt.array([[False, False, False], [False, False, True],
|
||||
[True, True, False]])
|
||||
b = self.measure_time(lambda: jt.any(a, dim=1))
|
||||
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=1"
|
||||
print("test any (test case 7) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_8(self):
|
||||
a = jt.array([[False, True, False], [False, False, True],
|
||||
[False, True, False]])
|
||||
b = self.measure_time(lambda: jt.any(a, dim=0))
|
||||
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
|
||||
print("test any (test case 8) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_9(self):
|
||||
a = jt.array([[False, True, False], [False, False, True],
|
||||
[False, True, False]])
|
||||
b = self.measure_time(lambda: a.any(dim=0))
|
||||
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
|
||||
print("test any (test case 9) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_10(self):
|
||||
# 测试在 dim=0 上检查每列是否有非零元素
|
||||
a = jt.array([[0, 1, 0], [0, 0, 0]])
|
||||
b = self.measure_time(lambda: jt.any(a, dim=0))
|
||||
assert (b.numpy() == [False, True, False]).all(), "Unexpected result for dim=0"
|
||||
print("test any (test case 10) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_11(self):
|
||||
# 测试在 dim=0 上检查每列是否有非零元素
|
||||
a = jt.array([[0.0, 1.0, -1.0], [0, 0, 0]])
|
||||
b = self.measure_time(lambda: jt.any(a, dim=0))
|
||||
assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
|
||||
print("test any (test case 11) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_any_12(self):
|
||||
# 测试在 dim=0 上检查每列是否有非零元素
|
||||
a = jt.array([[0.0, 1.0, -1.0], [0, 0, 0]])
|
||||
b = self.measure_time(lambda: jt.any(a, dim=1))
|
||||
assert (b.numpy() == [True, False]).all(), "Unexpected result for dim=0"
|
||||
print("test any (test case 12) success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_scatter(self):
|
||||
a = jt.array([[1, 2], [3, 4]])
|
||||
|
|
Loading…
Reference in New Issue