mirror of https://github.com/Jittor/Jittor
update cusparse trans
This commit is contained in:
parent
8419709e31
commit
1bf6f73d4c
|
@ -35,5 +35,8 @@ static inline cudaDataType get_dtype(NanoString dtype) {
|
|||
LOGf << "not support type" << dtype;
|
||||
return CUDA_R_32F;
|
||||
}
|
||||
|
||||
// Translate a boolean "transpose?" flag into the corresponding cuSPARSE
// operation enum value:
//   true  -> CUSPARSE_OPERATION_TRANSPOSE
//   false -> CUSPARSE_OPERATION_NON_TRANSPOSE
static inline cusparseOperation_t get_trans_type(bool trans) {
    return trans ? CUSPARSE_OPERATION_TRANSPOSE
                 : CUSPARSE_OPERATION_NON_TRANSPOSE;
}
|
||||
} // jittor
|
|
@ -12,8 +12,8 @@ using namespace std;
|
|||
namespace jittor {
|
||||
#ifndef JIT
|
||||
|
||||
CusparseSpmmcooOp::CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row_,int A_col_)
|
||||
: outputVar(outputVar_), x(x_),row_indices(row_indices_), col_indices(col_indices_), value(value_),A_row(A_row_),A_col(A_col_) {
|
||||
CusparseSpmmcooOp::CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row_,int A_col_,bool trans_A_,bool trans_B_)
|
||||
: outputVar(outputVar_), x(x_),row_indices(row_indices_), col_indices(col_indices_), value(value_),A_row(A_row_),A_col(A_col_),trans_A(trans_A_),trans_B(trans_B_) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_manual_set_vnbb);
|
||||
|
@ -56,8 +56,8 @@ void CusparseSpmmcooOp::jit_run() {
|
|||
// CUSPARSE_SPMM_ALG_DEFAULT , &bufferSize) );
|
||||
// checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
|
||||
checkCudaErrors( cusparseSpMM(handle_,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
get_trans_type(trans_A),
|
||||
get_trans_type(trans_B),
|
||||
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_ALG_DEFAULT, NULL) );
|
||||
// checkCudaErrors( cudaFree(dBuffer) );
|
||||
|
|
|
@ -18,7 +18,9 @@ struct CusparseSpmmcooOp : Op {
|
|||
Var* output;
|
||||
int A_row;
|
||||
int A_col;
|
||||
CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row,int A_col);
|
||||
bool trans_A;
|
||||
bool trans_B;
|
||||
CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row,int A_col,bool trans_A,bool trans_B);
|
||||
const char* name() const override { return "cusparse_spmmcoo"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
|
|
@ -13,8 +13,8 @@ namespace jittor {
|
|||
|
||||
#ifndef JIT
|
||||
|
||||
CusparseSpmmcsrOp::CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row_,int A_col_)
|
||||
: outputVar(outputVar_), x(x_), col_indices(col_indices_), value(value_),row_offset(row_offset_),A_row(A_row_),A_col(A_col_){
|
||||
CusparseSpmmcsrOp::CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row_,int A_col_,bool trans_A_,bool trans_B_)
|
||||
: outputVar(outputVar_), x(x_), col_indices(col_indices_), value(value_),row_offset(row_offset_),A_row(A_row_),A_col(A_col_),trans_A(trans_A_),trans_B(trans_B_){
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_manual_set_vnbb);
|
||||
|
@ -49,14 +49,22 @@ void CusparseSpmmcsrOp::jit_run() {
|
|||
checkCudaErrors( cusparseCreateDnMat(&matC, os[0], os[1],os[1], outputVar->ptr<T>(), dtype_C, CUSPARSE_ORDER_ROW) );
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
// checkCudaErrors( cusparseSpMM_bufferSize(
|
||||
// handle_,
|
||||
// CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
// CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
// &alpha, matA, matB, &beta, matC, dtype_C,
|
||||
// CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
|
||||
checkCudaErrors( cusparseSpMM_bufferSize(
|
||||
handle_,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
get_trans_type(trans_A),
|
||||
get_trans_type(trans_B),
|
||||
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
|
||||
checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
|
||||
checkCudaErrors(cusparseSpMM(handle_, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2 , dBuffer));
|
||||
|
||||
// checkCudaErrors(cusparseSpMM(handle_, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, dtype_C, CUSPARSE_SPMM_CSR_ALG2 , dBuffer)); //CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2 , CUSPARSE_SPMM_COO_ALG4
|
||||
checkCudaErrors(cusparseSpMM(handle_, get_trans_type(trans_A), get_trans_type(trans_B), &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2 , dBuffer));
|
||||
checkCudaErrors( cudaFree(dBuffer) );
|
||||
checkCudaErrors( cusparseDestroySpMat(matA) );
|
||||
checkCudaErrors( cusparseDestroyDnMat(matB) );
|
||||
|
|
|
@ -18,7 +18,9 @@ struct CusparseSpmmcsrOp : Op {
|
|||
Var* output;
|
||||
int A_row;
|
||||
int A_col;
|
||||
CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row,int A_col);
|
||||
bool trans_A;
|
||||
bool trans_B;
|
||||
CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row,int A_col,bool trans_A,bool trans_B);
|
||||
const char* name() const override { return "cusparse_spmmcsr"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
|
|
@ -22,7 +22,7 @@ class TestSpmmCsrOp(unittest.TestCase):
|
|||
output = jt.zeros((3, 3), dtype="float32")
|
||||
cusparse_ops.cusparse_spmmcsr(
|
||||
output, x, col_indices, csr_weight, row_offset,
|
||||
3, 3
|
||||
3, 3 ,False, False
|
||||
).fetch_sync()
|
||||
expected_output = np.array([
|
||||
[12.0, 8.0, 4.0],
|
||||
|
@ -41,7 +41,7 @@ class TestSpmmCsrOp(unittest.TestCase):
|
|||
output = jt.zeros((3, 3), dtype="float16")
|
||||
cusparse_ops.cusparse_spmmcsr(
|
||||
output, x, col_indices, csr_weight, row_offset,
|
||||
3, 3
|
||||
3, 3,False, False
|
||||
).fetch_sync()
|
||||
expected_output = np.array([
|
||||
[12.0, 8.0, 4.0],
|
||||
|
@ -60,7 +60,7 @@ class TestSpmmCsrOp(unittest.TestCase):
|
|||
# output = jt.zeros((3, 3), dtype="float64")
|
||||
# cusparse_ops.cusparse_spmmcsr(
|
||||
# output, x, col_indices, csr_weight, row_offset,
|
||||
# 3, 3
|
||||
# 3, 3,False, False
|
||||
# ).fetch_sync()
|
||||
# expected_output = np.array([
|
||||
# [12.0, 8.0, 4.0],
|
||||
|
@ -80,7 +80,7 @@ class TestSpmmCsrOp(unittest.TestCase):
|
|||
output = jt.zeros((3, 3), dtype="float32")
|
||||
cusparse_ops.cusparse_spmmcsr(
|
||||
output, x, col_indices, csr_weight, row_offset,
|
||||
3, 3
|
||||
3, 3,False, False
|
||||
).fetch_sync()
|
||||
expected_output = np.array([
|
||||
[12.0, 8.0, 4.0],
|
||||
|
@ -96,11 +96,14 @@ class TestSpmmCsrOp(unittest.TestCase):
|
|||
edge_index=jt.array([[0,0,1,2],[1,2,2,1]],dtype="int32")
|
||||
row_indices=edge_index[0,:]
|
||||
col_indices=edge_index[1,:]
|
||||
# print(row_indices)
|
||||
# print(col_indices)
|
||||
edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
|
||||
feature_dim=jt.size(x,1)
|
||||
output=jt.zeros(3,feature_dim)
|
||||
cusparse_ops.cusparse_spmmcoo(output,x,row_indices,col_indices,edge_weight,3,3).fetch_sync()
|
||||
cusparse_ops.cusparse_spmmcoo(output,x,row_indices,col_indices,edge_weight,3,3,False, False).fetch_sync()
|
||||
print("Output:", output)
|
||||
# Define the expected output; adjust it according to the specific operation and implementation
|
||||
expected_output = np.array([
|
||||
[5.0, 4.0, 5.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
|
|
Loading…
Reference in New Issue