mirror of https://github.com/Jittor/Jittor
update cusparse trans
This commit is contained in:
parent
1bf6f73d4c
commit
02c3173def
|
@ -96,14 +96,11 @@ class TestSpmmCsrOp(unittest.TestCase):
|
|||
edge_index=jt.array([[0,0,1,2],[1,2,2,1]],dtype="int32")
|
||||
row_indices=edge_index[0,:]
|
||||
col_indices=edge_index[1,:]
|
||||
# print(row_indices)
|
||||
# print(col_indices)
|
||||
edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
|
||||
feature_dim=jt.size(x,1)
|
||||
output=jt.zeros(3,feature_dim)
|
||||
cusparse_ops.cusparse_spmmcoo(output,x,row_indices,col_indices,edge_weight,3,3,False, False).fetch_sync()
|
||||
print("Output:", output)
|
||||
# 定义预期的输出,需根据具体运算和实现来调整
|
||||
expected_output = np.array([
|
||||
[5.0, 4.0, 5.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
|
|
Loading…
Reference in New Issue