mirror of https://github.com/Jittor/Jittor
Merge branch 'Jittor:master' into master
This commit is contained in:
commit
4225804df2
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.9.12'
|
||||
__version__ = '1.3.9.13'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -26,7 +26,7 @@ with lock.lock_scope():
|
|||
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
|
||||
if core.get_device_count() == 0:
|
||||
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
||||
from .compile_extern import cudnn, curand, cublas, cufft
|
||||
from .compile_extern import cudnn, curand, cublas, cufft, cusparse
|
||||
from .init_cupy import numpy2cupy
|
||||
|
||||
from typing import List, Tuple
|
||||
|
|
|
@ -2,7 +2,7 @@ from jittor_core import *
|
|||
from jittor_core.ops import *
|
||||
from .misc import *
|
||||
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional
|
||||
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
|
||||
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, cusparse as cusparse ,mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
|
||||
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
|
||||
from .contrib import concat as concat
|
||||
from .nn import bmm as bmm, bmm_transpose as bmm_transpose, matmul as matmul
|
||||
|
|
|
@ -224,7 +224,7 @@ def setup_cuda_extern():
|
|||
line = traceback.format_exc()
|
||||
LOG.w(f"CUDA found but cub is not loaded:\n{line}")
|
||||
|
||||
libs = ["cublas", "cudnn", "curand", "cufft"]
|
||||
libs = ["cublas", "cudnn", "curand", "cufft", "cusparse"]
|
||||
# in cuda 11.4, module memory consumption:
|
||||
# default context: 259 MB
|
||||
# cublas: 340 MB
|
||||
|
@ -240,6 +240,9 @@ def setup_cuda_extern():
|
|||
msg += """Develop version of CUDNN not found,
|
||||
please refer to CUDA offical tar file installation:
|
||||
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar"""
|
||||
if lib_name == "cusparse":
|
||||
msg += """CUSPARSE library is not loaded,
|
||||
please ensure it is installed along with the CUDA toolkit."""
|
||||
if platform.machine() in ["x86_64", "AMD64"]:
|
||||
msg += f"""
|
||||
or you can let jittor install cuda and cudnn for you:
|
||||
|
@ -300,6 +303,13 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
link_flags = f"-l{lib_name} -L\"{os.path.dirname(culib_path)}\""
|
||||
# print("link_flags", link_flags, culib_path)
|
||||
|
||||
if lib_name == "cusparse":
    # Best-effort: locate libcusparse.so in the CUDA / extra library
    # directories and dlopen it.  A failure here is only logged as a
    # warning and does not abort the setup.
    try:
        cusparse_spmv_path = search_file([cuda_lib, extra_lib_path], "libcusparse.so")
        ctypes.CDLL(cusparse_spmv_path, dlopen_flags)
    except Exception:
        # Catch Exception instead of using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        LOG.w("Failed to load cusparse-specific shared libraries.")
|
||||
|
||||
# find all source files
|
||||
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
|
||||
culib_src_files = []
|
||||
|
@ -693,7 +703,7 @@ if FIX_TORCH_ERROR:
|
|||
except:
|
||||
pass
|
||||
|
||||
cudnn = cublas = curand = cufft = None
|
||||
cudnn = cublas = curand = cufft = cusparse = None
|
||||
setup_mpi()
|
||||
rank = mpi.world_rank() if in_mpi else 0
|
||||
world_size = mpi.world_size() if in_mpi else 1
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
// Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
#include <cusparse.h>
|
||||
|
||||
#include "utils/log.h"
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_emu.h"
|
||||
#include "common.h"
|
||||
#include "misc/nano_string.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
EXTERN_LIB cusparseHandle_t cusparse_handle;
|
||||
|
||||
static inline cusparseIndexType_t get_index_dtype(NanoString dtype) {
    // Map a Jittor integer dtype to the matching cuSPARSE index type.
    // Only int32 and int64 are recognised; anything else is reported
    // through LOGf.
    if (dtype == ns_int64) {
        return CUSPARSE_INDEX_64I;
    }
    if (dtype == ns_int32) {
        return CUSPARSE_INDEX_32I;
    }
    LOGf << "not support type" << dtype;
    // Fallback return value placed after the LOGf report.
    return CUSPARSE_INDEX_32I;
}
|
||||
|
||||
static inline cudaDataType get_dtype(NanoString dtype) {
    // Map a Jittor floating-point dtype to the matching CUDA data type.
    // bfloat16 is only mapped when IS_ROCM is not defined; unknown dtypes
    // are reported through LOGf.
    if (dtype == ns_float16) {
        return CUDA_R_16F;
    }
    if (dtype == ns_float32) {
        return CUDA_R_32F;
    }
    if (dtype == ns_float64) {
        return CUDA_R_64F;
    }
#ifndef IS_ROCM
    if (dtype == ns_bfloat16) {
        return CUDA_R_16BF;
    }
#endif
    LOGf << "not support type" << dtype;
    // Fallback return value placed after the LOGf report.
    return CUDA_R_32F;
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,70 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
// Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cusparse_spmmcoo_op.h"
|
||||
#include "cusparse_wrapper.h"
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
#ifndef JIT
|
||||
|
||||
// Build a COO sparse-dense matmul op.
//   outputVar_   : dense matrix that receives the product
//   x_           : dense right-hand matrix
//   row_indices_ : COO row index of every non-zero
//   col_indices_ : COO column index of every non-zero
//   value_       : value of every non-zero
//   A_row_/A_col_: dimensions of the sparse matrix
CusparseSpmmcooOp::CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row_, int A_col_)
    // Initializers are listed in member declaration order (x before
    // outputVar), which is the order C++ actually runs them in.
    : x(x_), outputVar(outputVar_), row_indices(row_indices_), col_indices(col_indices_), value(value_), A_row(A_row_), A_col(A_col_) {
    // This op has a CUDA implementation only.
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_manual_set_vnbb);
    // The check only requires both dense operands to be floating point,
    // so the message states exactly that.
    ASSERT(x->dtype().is_float() && outputVar->dtype().is_float()) << "x and outputVar should both be floating-point types";
    output = create_output(nullptr, x->dtype());
}
|
||||
|
||||
void CusparseSpmmcooOp::jit_prepare(JK& jk) {
    // Register the JIT type parameters: T follows the dtype of x,
    // Tindex follows the dtype of the column indices.
    const auto value_type = x->dtype();
    const auto index_type = col_indices->dtype();
    add_jit_define(jk, "T", value_type);
    add_jit_define(jk, "Tindex", index_type);
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
void CusparseSpmmcooOp::jit_run() {
|
||||
cusparseSpMatDescr_t matA;
|
||||
cusparseDnMatDescr_t matB, matC;
|
||||
cusparseHandle_t &handle_ = cusparse_handle;
|
||||
// void* dBuffer = NULL;
|
||||
// size_t bufferSize = 0;
|
||||
const auto& xs = x->shape;
|
||||
const auto& vs = value->shape;
|
||||
const auto& os = outputVar->shape;
|
||||
ASSERT(xs==os)<<"matrix A and matrix C size not match";
|
||||
ASSERT(A_col==xs[0])<<"matrix A and matrix B size not match";
|
||||
auto dtype_A = get_dtype(value->dtype());
|
||||
auto dtype_B = get_dtype(x->dtype());
|
||||
auto dtype_C = get_dtype(outputVar->dtype());
|
||||
auto dtype_index = get_index_dtype(col_indices->dtype());
|
||||
checkCudaErrors( cusparseCreateCoo(&matA, A_row, A_col, vs[0], row_indices->ptr<Tindex>(), col_indices->ptr<Tindex>(), value->ptr<T>(), dtype_index, CUSPARSE_INDEX_BASE_ZERO, dtype_A) );
|
||||
checkCudaErrors( cusparseCreateDnMat(&matB, xs[0], xs[1], xs[1], x->ptr<T>(), dtype_B, CUSPARSE_ORDER_ROW) );
|
||||
checkCudaErrors( cusparseCreateDnMat(&matC, os[0], os[1],os[1], outputVar->ptr<T>(), dtype_C, CUSPARSE_ORDER_ROW) );
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
// checkCudaErrors( cusparseSpMM_bufferSize(
|
||||
// handle_,
|
||||
// CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
// CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
// &alpha, matA, matB, &beta, matC, CUDA_R_32F,
|
||||
// CUSPARSE_SPMM_ALG_DEFAULT , &bufferSize) );
|
||||
// checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
|
||||
checkCudaErrors( cusparseSpMM(handle_,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_ALG_DEFAULT, NULL) );
|
||||
// checkCudaErrors( cudaFree(dBuffer) );
|
||||
checkCudaErrors( cusparseDestroySpMat(matA) );
|
||||
checkCudaErrors( cusparseDestroyDnMat(matB) );
|
||||
checkCudaErrors( cusparseDestroyDnMat(matC) );
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
// Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
#include "cusparse.h"
|
||||
namespace jittor {
|
||||
|
||||
struct CusparseSpmmcooOp : Op {
|
||||
Var* x;
|
||||
Var* outputVar;
|
||||
Var* row_indices;
|
||||
Var* col_indices;
|
||||
Var* value;
|
||||
Var* output;
|
||||
int A_row;
|
||||
int A_col;
|
||||
CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row,int A_col);
|
||||
const char* name() const override { return "cusparse_spmmcoo"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,67 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
// Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cusparse_spmmcsr_op.h"
|
||||
#include "cusparse_wrapper.h"
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
// Build a CSR sparse-dense matmul op.
//   outputVar_   : dense matrix that receives the product
//   x_           : dense right-hand matrix
//   col_indices_ : CSR column index of every non-zero
//   value_       : value of every non-zero
//   row_offset_  : CSR row offsets
//   A_row_/A_col_: dimensions of the sparse matrix
CusparseSpmmcsrOp::CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row_, int A_col_)
    // Initializers are listed in member declaration order
    // (x, outputVar, col_indices, row_offset, value, ...), which is the
    // order C++ actually runs them in.
    : x(x_), outputVar(outputVar_), col_indices(col_indices_), row_offset(row_offset_), value(value_), A_row(A_row_), A_col(A_col_) {
    // This op has a CUDA implementation only.
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_manual_set_vnbb);
    // The check only requires both dense operands to be floating point,
    // so the message states exactly that.
    ASSERT(x->dtype().is_float() && outputVar->dtype().is_float()) << "x and outputVar should both be floating-point types";
    output = create_output(nullptr, x->dtype());
}
|
||||
|
||||
void CusparseSpmmcsrOp::jit_prepare(JK& jk) {
    // Register the JIT type parameters: T follows the dtype of x,
    // Tindex follows the dtype of the column indices.
    const auto value_type = x->dtype();
    const auto index_type = col_indices->dtype();
    add_jit_define(jk, "T", value_type);
    add_jit_define(jk, "Tindex", index_type);
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
void CusparseSpmmcsrOp::jit_run() {
|
||||
cusparseSpMatDescr_t matA;
|
||||
cusparseDnMatDescr_t matB, matC;
|
||||
void* dBuffer = NULL;
|
||||
size_t bufferSize = 0;
|
||||
cusparseHandle_t &handle_ = cusparse_handle;
|
||||
const auto& xs = x->shape;
|
||||
const auto& vs = value->shape;
|
||||
const auto& os = outputVar->shape;
|
||||
ASSERT(xs==os)<<"matrix A and matrix C size not match";
|
||||
ASSERT(A_col==xs[0])<<"matrix A and matrix B size not match";
|
||||
auto dtype_A = get_dtype(value->dtype());
|
||||
auto dtype_B = get_dtype(x->dtype());
|
||||
auto dtype_C = get_dtype(outputVar->dtype());
|
||||
auto dtype_index = get_index_dtype(col_indices->dtype());
|
||||
checkCudaErrors( cusparseCreateCsr(&matA, A_row, A_col, vs[0], row_offset->ptr<Tindex>(), col_indices->ptr<Tindex>(), value->ptr<T>(), dtype_index, dtype_index, CUSPARSE_INDEX_BASE_ZERO, dtype_A) );
|
||||
checkCudaErrors( cusparseCreateDnMat(&matB, xs[0], xs[1], xs[1], x->ptr<T>(), dtype_B, CUSPARSE_ORDER_ROW) );
|
||||
checkCudaErrors( cusparseCreateDnMat(&matC, os[0], os[1],os[1], outputVar->ptr<T>(), dtype_C, CUSPARSE_ORDER_ROW) );
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
checkCudaErrors( cusparseSpMM_bufferSize(
|
||||
handle_,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
|
||||
checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
|
||||
checkCudaErrors(cusparseSpMM(handle_, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2 , dBuffer));
|
||||
checkCudaErrors( cudaFree(dBuffer) );
|
||||
checkCudaErrors( cusparseDestroySpMat(matA) );
|
||||
checkCudaErrors( cusparseDestroyDnMat(matB) );
|
||||
checkCudaErrors( cusparseDestroyDnMat(matC) );
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
// Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
#include "cusparse.h"
|
||||
namespace jittor {
|
||||
|
||||
struct CusparseSpmmcsrOp : Op {
|
||||
Var* x;
|
||||
Var* outputVar;
|
||||
Var* col_indices;
|
||||
Var* row_offset;
|
||||
Var* value;
|
||||
Var* output;
|
||||
int A_row;
|
||||
int A_col;
|
||||
CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row,int A_col);
|
||||
const char* name() const override { return "cusparse_spmmcsr"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,31 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
// Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "cusparse_wrapper.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
cusparseHandle_t cusparse_handle;
|
||||
|
||||
struct cusparse_initer {
|
||||
|
||||
inline cusparse_initer() {
|
||||
if (!get_device_count()) return;
|
||||
checkCudaErrors(cusparseCreate(&cusparse_handle));
|
||||
LOGv << "cusparseCreate finished" << (void*)cusparse_handle;
|
||||
}
|
||||
|
||||
inline ~cusparse_initer() {
|
||||
if (!get_device_count()) return;
|
||||
LOGv << "cusparseDestroy:" << (void*)cusparse_handle;
|
||||
checkCudaErrors(cusparseDestroy(cusparse_handle));
|
||||
LOGv << "cusparseDestroy finished";
|
||||
}
|
||||
|
||||
} init;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// These are CUDA Helper functions for initialization and error checking
|
||||
#include <cuda_runtime.h>
|
||||
#include <cusparse.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
#ifdef CUSPARSEAPI
|
||||
// cuSPARSE API errors
|
||||
// Translate a cuSPARSE status code into a human-readable description.
// Codes without a dedicated entry map to "unknown error".
const char *_cudaGetErrorEnum(cusparseStatus_t status) {
    const char *text = "unknown error";
    switch (status) {
        case CUSPARSE_STATUS_SUCCESS:
            text = "success";
            break;
        case CUSPARSE_STATUS_NOT_INITIALIZED:
            text = "library not initialized";
            break;
        case CUSPARSE_STATUS_ALLOC_FAILED:
            text = "resource allocation failed";
            break;
        case CUSPARSE_STATUS_INVALID_VALUE:
            text = "an invalid numeric value was used as an argument";
            break;
        case CUSPARSE_STATUS_ARCH_MISMATCH:
            text = "an absent device architectural feature is required";
            break;
        case CUSPARSE_STATUS_MAPPING_ERROR:
            text = "an access to GPU memory space failed";
            break;
        case CUSPARSE_STATUS_EXECUTION_FAILED:
            text = "the GPU program failed to execute";
            break;
        case CUSPARSE_STATUS_INTERNAL_ERROR:
            text = "an internal operation failed";
            break;
        case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
            text = "the matrix type is not supported by this function";
            break;
        case CUSPARSE_STATUS_ZERO_PIVOT:
            text = "an entry of the matrix is either structural zero or numerical zero (singular block)";
            break;
        default:
            break;
    }
    return text;
}
|
||||
#endif
|
|
@ -269,7 +269,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs
|
|||
help_name = ""+target_scope_name+'.'+name
|
||||
else:
|
||||
help_name = name
|
||||
if lib_name in ["mpi", "nccl", "cudnn", "curand" "cufft", "cublas", "mkl"]:
|
||||
# Ops that belong to an extern library get the library name as a prefix of
# their help path.  Every entry needs its own comma: without one, adjacent
# string literals are concatenated ("curand" "cufft" -> "curandcufft") and
# neither library would ever match.
if lib_name in ["mpi", "nccl", "cudnn", "curand", "cufft", "cublas", "mkl", "cusparse"]:
    help_name = lib_name+'.'+help_name
|
||||
help_cmd = f"help(jt.{help_name})"
|
||||
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
# Maintainers: Shizhan Lu <578752274@qq.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
from jittor.compile_extern import cusparse_ops
|
||||
|
||||
|
||||
class TestSpmmCsrOp(unittest.TestCase):
    """Forward checks for the cuSPARSE SpMM ops on small, fixed 3x3 problems."""

    def _check_csr_forward(self, value_dtype, index_dtype):
        """Run ``cusparse_spmmcsr`` once and compare with the known product.

        Args:
            value_dtype: dtype name used for the dense matrices and CSR values.
            index_dtype: dtype name used for the CSR column indices and row
                offsets.
        """
        x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype=value_dtype)
        col_indices = jt.array([0, 1, 1, 2], dtype=index_dtype)
        row_offset = jt.array([0, 2, 3, 4], dtype=index_dtype)
        csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype=value_dtype)
        output = jt.zeros((3, 3), dtype=value_dtype)
        cusparse_ops.cusparse_spmmcsr(
            output, x, col_indices, csr_weight, row_offset,
            3, 3
        ).fetch_sync()
        # The sparse matrix is [[3, 1, 0], [0, 4, 0], [0, 0, 2]] and every
        # row of x is [3, 2, 1]; the result is read back from ``output``.
        expected_output = np.array([
            [12.0, 8.0, 4.0],
            [12.0, 8.0, 4.0],
            [6.0, 4.0, 2.0]
        ], dtype=value_dtype)
        np.testing.assert_allclose(output.data, expected_output, atol=1e-5)

    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_csr_forward_float32_int32(self):
        self._check_csr_forward("float32", "int32")

    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_csr_forward_float16_int32(self):
        self._check_csr_forward("float16", "int32")

    # NOTE: a float64 / int32 variant of the CSR test was commented out in
    # the original file and is intentionally not run here.

    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_csr_forward_float32_int64(self):
        self._check_csr_forward("float32", "int64")

    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_coo(self):
        """Check ``cusparse_spmmcoo`` with four unit-weight entries."""
        x = jt.array([[3.0, 2.0, 1.0], [4.0, 2.0, 2.0], [1.0, 2.0, 3.0]], dtype="float32")
        edge_index = jt.array([[0, 0, 1, 2], [1, 2, 2, 1]], dtype="int32")
        row_indices = edge_index[0, :]
        col_indices = edge_index[1, :]
        edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
        feature_dim = jt.size(x, 1)
        output = jt.zeros(3, feature_dim)
        cusparse_ops.cusparse_spmmcoo(output, x, row_indices, col_indices, edge_weight, 3, 3).fetch_sync()
        # Row 0 sums rows 1 and 2 of x, row 1 copies row 2, row 2 copies row 1.
        expected_output = np.array([
            [5.0, 4.0, 5.0],
            [1.0, 2.0, 3.0],
            [4.0, 2.0, 2.0]
        ], dtype="float32")
        np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
|
||||
|
||||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in New Issue