update cusparse trans

This commit is contained in:
lusz 2024-12-29 17:23:49 +08:00
parent 1bf6f73d4c
commit 02c3173def
1 changed files with 0 additions and 3 deletions

View File

@ -96,14 +96,11 @@ class TestSpmmCsrOp(unittest.TestCase):
edge_index=jt.array([[0,0,1,2],[1,2,2,1]],dtype="int32")
row_indices=edge_index[0,:]
col_indices=edge_index[1,:]
# print(row_indices)
# print(col_indices)
edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
feature_dim=jt.size(x,1)
output=jt.zeros(3,feature_dim)
cusparse_ops.cusparse_spmmcoo(output,x,row_indices,col_indices,edge_weight,3,3,False, False).fetch_sync()
print("Output:", output)
# 定义预期的输出,需根据具体运算和实现来调整
expected_output = np.array([
[5.0, 4.0, 5.0],
[1.0, 2.0, 3.0],