update cusparse trans

lusz 2024-12-29 16:48:34 +08:00
parent 8419709e31
commit 1bf6f73d4c
6 changed files with 35 additions and 17 deletions

View File

@@ -35,5 +35,8 @@ static inline cudaDataType get_dtype(NanoString dtype) {
LOGf << "unsupported type: " << dtype;
return CUDA_R_32F;
}
static inline cusparseOperation_t get_trans_type(bool trans) {
if (trans) return CUSPARSE_OPERATION_TRANSPOSE;
else return CUSPARSE_OPERATION_NON_TRANSPOSE;
}
} // jittor
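
The new get_trans_type helper is a thin bridge from the boolean flags the ops now take to cuSPARSE's cusparseOperation_t enum; that enum also has a CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE member, which this helper does not expose. A minimal standalone sketch of the mapping (a hypothetical check harness, not part of this commit; it compiles against the CUDA toolkit headers and needs no GPU):

// Hypothetical standalone check of the bool -> cusparseOperation_t mapping.
// Build with: nvcc check_trans.cpp (or g++ with the CUDA include path).
#include <cusparse.h>
#include <cassert>

static inline cusparseOperation_t get_trans_type(bool trans) {
    if (trans) return CUSPARSE_OPERATION_TRANSPOSE;
    else return CUSPARSE_OPERATION_NON_TRANSPOSE;
}

int main() {
    assert(get_trans_type(true) == CUSPARSE_OPERATION_TRANSPOSE);
    assert(get_trans_type(false) == CUSPARSE_OPERATION_NON_TRANSPOSE);
    return 0;
}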

View File

@@ -12,8 +12,8 @@ using namespace std;
namespace jittor {
#ifndef JIT
CusparseSpmmcooOp::CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row_, int A_col_)
: outputVar(outputVar_), x(x_), row_indices(row_indices_), col_indices(col_indices_), value(value_), A_row(A_row_), A_col(A_col_) {
CusparseSpmmcooOp::CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row_, int A_col_, bool trans_A_, bool trans_B_)
: outputVar(outputVar_), x(x_), row_indices(row_indices_), col_indices(col_indices_), value(value_), A_row(A_row_), A_col(A_col_), trans_A(trans_A_), trans_B(trans_B_) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
@@ -56,8 +56,8 @@ void CusparseSpmmcooOp::jit_run() {
// CUSPARSE_SPMM_ALG_DEFAULT , &bufferSize) );
// checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
checkCudaErrors( cusparseSpMM(handle_,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE,
get_trans_type(trans_A),
get_trans_type(trans_B),
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, NULL) );
// checkCudaErrors( cudaFree(dBuffer) );
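
With the flags threaded through, this SpMM call now computes C = alpha * op(A) * op(B) + beta * C, where op() is either the identity or the transpose, chosen per matrix by trans_A and trans_B. The practical consequence is the usual shape rule: transposing swaps a matrix's row and column counts, and the inner dimensions must still agree after op() is applied. A small hypothetical helper (not part of this commit) spelling that rule out:

// Hypothetical shape check for C = op(A) * op(B); requires C++17.
#include <cassert>
#include <utility>

std::pair<int, int> spmm_output_shape(int A_row, int A_col,
                                      int B_row, int B_col,
                                      bool trans_A, bool trans_B) {
    int m  = trans_A ? A_col : A_row;   // rows of op(A)
    int k  = trans_A ? A_row : A_col;   // cols of op(A)
    int kb = trans_B ? B_col : B_row;   // rows of op(B)
    int n  = trans_B ? B_row : B_col;   // cols of op(B)
    assert(k == kb);                    // inner dimensions must match
    return {m, n};
}

int main() {
    // A 3x4 sparse matrix, transposed, times a 3x2 dense matrix -> 4x2 output.
    auto [m, n] = spmm_output_shape(3, 4, 3, 2, /*trans_A=*/true, /*trans_B=*/false);
    assert(m == 4 && n == 2);
    return 0;
}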

View File

@@ -18,7 +18,9 @@ struct CusparseSpmmcooOp : Op {
Var* output;
int A_row;
int A_col;
CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row, int A_col);
bool trans_A;
bool trans_B;
CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row, int A_col, bool trans_A, bool trans_B);
const char* name() const override { return "cusparse_spmmcoo"; }
DECLARE_jit_run;
};

View File

@@ -13,8 +13,8 @@ namespace jittor {
#ifndef JIT
CusparseSpmmcsrOp::CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row_, int A_col_)
: outputVar(outputVar_), x(x_), col_indices(col_indices_), value(value_), row_offset(row_offset_), A_row(A_row_), A_col(A_col_) {
CusparseSpmmcsrOp::CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row_, int A_col_, bool trans_A_, bool trans_B_)
: outputVar(outputVar_), x(x_), col_indices(col_indices_), value(value_), row_offset(row_offset_), A_row(A_row_), A_col(A_col_), trans_A(trans_A_), trans_B(trans_B_) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
@@ -49,14 +49,22 @@ void CusparseSpmmcsrOp::jit_run() {
checkCudaErrors( cusparseCreateDnMat(&matC, os[0], os[1],os[1], outputVar->ptr<T>(), dtype_C, CUSPARSE_ORDER_ROW) );
float alpha = 1.0f;
float beta = 0.0f;
// checkCudaErrors( cusparseSpMM_bufferSize(
// handle_,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// &alpha, matA, matB, &beta, matC, dtype_C,
// CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
checkCudaErrors( cusparseSpMM_bufferSize(
handle_,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE,
get_trans_type(trans_A),
get_trans_type(trans_B),
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
checkCudaErrors(cusparseSpMM(handle_, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, dBuffer));
// checkCudaErrors(cusparseSpMM(handle_, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, dtype_C, CUSPARSE_SPMM_CSR_ALG2, dBuffer)); // CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2, CUSPARSE_SPMM_COO_ALG4
checkCudaErrors(cusparseSpMM(handle_, get_trans_type(trans_A), get_trans_type(trans_B), &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, dBuffer));
checkCudaErrors( cudaFree(dBuffer) );
checkCudaErrors( cusparseDestroySpMat(matA) );
checkCudaErrors( cusparseDestroyDnMat(matB) );
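
One point worth noting in the CSR path: cusparseSpMM_bufferSize must be queried with the same op(A)/op(B) arguments later passed to cusparseSpMM, which is what the change above does by calling get_trans_type(trans_A)/get_trans_type(trans_B) in both places. A self-contained sketch of the whole sequence this op runs, assuming CUDA 11+ with the cuSPARSE generic API (both operations left non-transposed here; hypothetical example, not Jittor code):

// Standalone sketch of the bufferSize -> malloc -> SpMM -> free sequence.
// Build with: nvcc spmm_csr.cpp -lcusparse
#include <cuda_runtime.h>
#include <cusparse.h>
#include <cstdio>

// Works for both cudaError_t and cusparseStatus_t: success is 0 in each enum.
#define CHECK(x) do { if ((x) != 0) { printf("failure at line %d\n", __LINE__); return 1; } } while (0)

int main() {
    // A: 3x3 CSR sparse matrix, B: 3x2 dense row-major, C = 1.0 * A * B + 0.0 * C.
    const int A_row = 3, A_col = 3, nnz = 4, B_col = 2;
    int   hOffsets[A_row + 1] = {0, 2, 3, 4};
    int   hColumns[nnz]       = {1, 2, 2, 1};
    float hValues[nnz]        = {1.f, 1.f, 1.f, 1.f};
    float hB[A_col * B_col]   = {1, 2, 3, 4, 5, 6};
    float hC[A_row * B_col]   = {0};

    int *dOffsets, *dColumns; float *dValues, *dB, *dC;
    CHECK(cudaMalloc(&dOffsets, sizeof(hOffsets)));
    CHECK(cudaMalloc(&dColumns, sizeof(hColumns)));
    CHECK(cudaMalloc(&dValues,  sizeof(hValues)));
    CHECK(cudaMalloc(&dB, sizeof(hB)));
    CHECK(cudaMalloc(&dC, sizeof(hC)));
    CHECK(cudaMemcpy(dOffsets, hOffsets, sizeof(hOffsets), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(dColumns, hColumns, sizeof(hColumns), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(dValues,  hValues,  sizeof(hValues),  cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(dB, hB, sizeof(hB), cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(dC, hC, sizeof(hC), cudaMemcpyHostToDevice));

    cusparseHandle_t handle;
    cusparseSpMatDescr_t matA;
    cusparseDnMatDescr_t matB, matC;
    CHECK(cusparseCreate(&handle));
    CHECK(cusparseCreateCsr(&matA, A_row, A_col, nnz, dOffsets, dColumns, dValues,
                            CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
                            CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F));
    CHECK(cusparseCreateDnMat(&matB, A_col, B_col, B_col, dB, CUDA_R_32F, CUSPARSE_ORDER_ROW));
    CHECK(cusparseCreateDnMat(&matC, A_row, B_col, B_col, dC, CUDA_R_32F, CUSPARSE_ORDER_ROW));

    float alpha = 1.0f, beta = 0.0f;
    cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
    cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;

    // The workspace must be sized with the same opA/opB used in the SpMM call itself.
    size_t bufferSize = 0; void* dBuffer = nullptr;
    CHECK(cusparseSpMM_bufferSize(handle, opA, opB, &alpha, matA, matB, &beta,
                                  matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, &bufferSize));
    CHECK(cudaMalloc(&dBuffer, bufferSize));
    CHECK(cusparseSpMM(handle, opA, opB, &alpha, matA, matB, &beta,
                       matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, dBuffer));
    CHECK(cudaFree(dBuffer));

    CHECK(cusparseDestroySpMat(matA));
    CHECK(cusparseDestroyDnMat(matB));
    CHECK(cusparseDestroyDnMat(matC));
    CHECK(cusparseDestroy(handle));
    CHECK(cudaMemcpy(hC, dC, sizeof(hC), cudaMemcpyDeviceToHost));
    printf("C[0][0] = %f\n", hC[0]);   // 8.0 here: row 0 of A sums B rows 1 and 2
    return 0;
}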

View File

@@ -18,7 +18,9 @@ struct CusparseSpmmcsrOp : Op {
Var* output;
int A_row;
int A_col;
CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row, int A_col);
bool trans_A;
bool trans_B;
CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row, int A_col, bool trans_A, bool trans_B);
const char* name() const override { return "cusparse_spmmcsr"; }
DECLARE_jit_run;
};

View File

@@ -22,7 +22,7 @@ class TestSpmmCsrOp(unittest.TestCase):
output = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
@@ -41,7 +41,7 @@ class TestSpmmCsrOp(unittest.TestCase):
output = jt.zeros((3, 3), dtype="float16")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
@@ -60,7 +60,7 @@ class TestSpmmCsrOp(unittest.TestCase):
# output = jt.zeros((3, 3), dtype="float64")
# cusparse_ops.cusparse_spmmcsr(
# output, x, col_indices, csr_weight, row_offset,
# 3, 3
# 3, 3, False, False
# ).fetch_sync()
# expected_output = np.array([
# [12.0, 8.0, 4.0],
@@ -80,7 +80,7 @@ class TestSpmmCsrOp(unittest.TestCase):
output = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
@@ -96,11 +96,14 @@ class TestSpmmCsrOp(unittest.TestCase):
edge_index = jt.array([[0, 0, 1, 2], [1, 2, 2, 1]], dtype="int32")
row_indices = edge_index[0, :]
col_indices = edge_index[1, :]
# print(row_indices)
# print(col_indices)
edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
feature_dim = jt.size(x, 1)
output = jt.zeros(3, feature_dim)
cusparse_ops.cusparse_spmmcoo(output, x, row_indices, col_indices, edge_weight, 3, 3).fetch_sync()
cusparse_ops.cusparse_spmmcoo(output, x, row_indices, col_indices, edge_weight, 3, 3, False, False).fetch_sync()
print("Output:", output)
# Define the expected output; adjust according to the specific computation and implementation.
expected_output = np.array([
[5.0, 4.0, 5.0],
[1.0, 2.0, 3.0],