support tensorcore mode

li-xl 2021-11-02 17:45:08 +08:00
parent 66d9df2f82
commit 40ed259665
10 changed files with 84 additions and 16 deletions

View File

@@ -19,6 +19,8 @@ using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 #ifndef JIT
 static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul")
@@ -116,13 +118,25 @@ void CublasBatchedMatmulOp::jit_run() {
 k = bs[adim-2];
 }
 // a: [b,n,m], b: [b,m,k], c: [b,n,k]
-checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
+// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+// k, n, m, &alpha,
+// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
+// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+// c->ptr<T>(), k, k * n,
+// batch_size));
+checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
 CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
 k, n, m, &alpha,
-b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
-a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
-c->ptr<T>(), k, k * n,
-batch_size));
+b->ptr<T>(), get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, k * m,
+a->ptr<T>(), get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+c->ptr<T>(), get_dtype(c->dtype()), k, k * n,
+batch_size, computeType, algo));
 }
 #endif
 #endif // JIT
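For context, the pattern behind this change can be sketched in isolation. The following is a minimal, self-contained sketch, not Jittor code: the helper name and layout convention are illustrative, CUDA 11+ is assumed (where the compute type parameter is a cublasComputeType_t), and error checking is omitted. It multiplies a batch of row-major FP32 matrices by handing cuBLAS the operands in swapped order, mirroring the op above, and picks the algorithm from a runtime flag:

    #include <cublas_v2.h>

    // Minimal sketch: batched c = a @ b for row-major float arrays on the GPU.
    // a: [batch,n,m], b: [batch,m,k], c: [batch,n,k].
    // cuBLAS is column-major, so we compute c^T = b^T * a^T, as the op above
    // does; the leading dimensions and strides follow from that swap.
    void batched_matmul_f32(cublasHandle_t handle,
                            const float* a, const float* b, float* c,
                            int batch, int n, int m, int k,
                            bool use_tensorcore) {
        const float alpha = 1.0f, beta = 0.0f;
        cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                                               : CUBLAS_GEMM_DEFAULT;
        cublasGemmStridedBatchedEx(handle,
            CUBLAS_OP_N, CUBLAS_OP_N,
            k, n, m, &alpha,
            b, CUDA_R_32F, k, (long long)k * m,
            a, CUDA_R_32F, m, (long long)n * m, &beta,
            c, CUDA_R_32F, k, (long long)k * n,
            batch,
            CUBLAS_COMPUTE_32F,  // accumulate in FP32
            algo);
    }

Note that with pure FP32 inputs the CUBLAS_GEMM_DEFAULT_TENSOR_OP hint only pays off where the library is allowed a down-converted tensor-core path; on CUDA 11+ the math mode set via cublasSetMathMode largely supersedes the per-call algo hint.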

View File

@@ -16,6 +16,8 @@ using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 #ifndef JIT
 CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
@@ -54,6 +56,15 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
 #else // JIT
 #pragma clang diagnostic ignored "-Wtautological-compare"
+static inline cudaDataType get_dtype(NanoString dtype) {
+    if (dtype == ns_float32) return CUDA_R_32F;
+    // if (dtype == ns_float64) return CUDA_R_64F;
+    // if (dtype == ns_float16) return CUDA_R_16F;
+    LOGf << "not support type" << dtype;
+    return CUDA_R_32F;
+}
 void CublasMatmulOp::jit_run() {
 cublasHandle_t& handle_ = cublas_handle;
 const T alpha = 1.0f;
@@ -72,12 +83,24 @@ void CublasMatmulOp::jit_run() {
 k = bs[0];
 }
 // a: [n,m], b: [m,k], c: [n,k]
-checkCudaErrors(cublas@op@@gemm(handle_,
+cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
+// checkCudaErrors(cublas@op@@gemm(handle_,
+// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+// k, n, m, &alpha,
+// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
+// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
+// c->ptr<T>(), k));
+checkCudaErrors(cublasGemmEx(handle_,
 CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
 k, n, m, &alpha,
-b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
-a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
-c->ptr<T>(), k));
+b->ptr<T>(), get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m,
+a->ptr<T>(), get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
+c->ptr<T>(), get_dtype(c->dtype()), k,
+computeType, algo));
 }
 #endif // JIT
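The commented-out float16/float64 branches in get_dtype hint at where this is headed. As a hedged sketch only, not part of this commit (helper name hypothetical), the usual mixed-precision configuration that actually engages tensor cores pairs FP16 operands with FP32 accumulation through the same cublasGemmEx entry point:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Sketch: single GEMM via cublasGemmEx with FP16 inputs and FP32 compute,
    // the classic tensor-core setup. Row-major c = a @ b with a: [n,m],
    // b: [m,k], c: [n,k], again computed as c^T = b^T * a^T for cuBLAS.
    void matmul_f16_f32(cublasHandle_t handle,
                        const __half* a, const __half* b, float* c,
                        int n, int m, int k) {
        const float alpha = 1.0f, beta = 0.0f;  // alpha/beta match computeType
        cublasGemmEx(handle,
            CUBLAS_OP_N, CUBLAS_OP_N,
            k, n, m, &alpha,
            b, CUDA_R_16F, k,
            a, CUDA_R_16F, m, &beta,
            c, CUDA_R_32F, k,
            CUBLAS_COMPUTE_32F,              // FP32 accumulation
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);  // request tensor cores
    }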

View File

@@ -18,6 +18,8 @@ using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 #pragma GCC diagnostic ignored "-Wunused-variable"
 #ifndef JIT
@@ -158,7 +160,9 @@ void CudnnConv3dBackwardWOp::jit_run() {
 ));
 // using tensor core
 // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+if (use_tensorcore) {
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+}
 int sy[] = {0,0,0,0,1};

View File

@@ -18,6 +18,8 @@ using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 #pragma GCC diagnostic ignored "-Wunused-variable"
 #ifndef JIT
@@ -149,7 +151,9 @@ void CudnnConv3dBackwardXOp::jit_run() {
 ));
 // using tensor core
 // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+if (use_tensorcore) {
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+}
 int sy[] = {0,0,0,0,1};

View File

@@ -15,6 +15,8 @@ using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 #pragma GCC diagnostic ignored "-Wunused-variable"
 #ifndef JIT
@@ -150,7 +152,9 @@ void CudnnConv3dOp::jit_run() {
 ));
 // using tensor core
 // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+if (use_tensorcore) {
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+}
 int sy[] = {0,0,0,0,1};

View File

@@ -16,6 +16,9 @@
 using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 static inline int findc(const string& format, const char& c) {
 if (c==format[0]) return 0;
 if (c==format[1]) return 1;
@@ -149,7 +152,14 @@ void CudnnConvBackwardWOp::jit_run() {
 ));
 // using tensor core
 // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+if (use_tensorcore) {
+    // CUDNN_TENSOR_OP_MATH: Tensor Core operations are permitted, but cuDNN
+    // will not down-convert tensor datatypes in order to use Tensor Cores.
+    // CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION: Tensor Core operations are
+    // permitted, and cuDNN will actively down-convert tensor datatypes to use Tensor Cores.
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+}
 int dimY[] = {
 (int)y->shape[findc("@YFORMAT", 'a')], // n
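The cuDNN side of this commit reduces to a single descriptor call. A minimal sketch of the flag-to-math-mode mapping the convolution ops above now share (helper name hypothetical; CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION requires cuDNN 7.2+):

    #include <cudnn.h>

    // Sketch: pick the cuDNN math mode from a runtime flag. With the flag off,
    // the descriptor keeps CUDNN_DEFAULT_MATH, matching the ops above, which
    // simply skip the call in that case.
    void set_conv_math(cudnnConvolutionDescriptor_t conv_desc, int use_tensorcore) {
        cudnnMathType_t math = use_tensorcore
            ? CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION  // may down-convert FP32
            : CUDNN_DEFAULT_MATH;
        cudnnSetConvolutionMathType(conv_desc, math);
    }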

View File

@@ -25,6 +25,8 @@ static inline int findc(const char* format, const char& c) {
 namespace jittor {
+extern int use_tensorcore;
 #ifndef JIT
 static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
@@ -150,7 +152,9 @@ void CudnnConvBackwardXOp::jit_run() {
 ));
 // using tensor core
 // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+if (use_tensorcore) {
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+}
 int dimY[] = {
 (int)y->shape[findc("@YFORMAT", 'a')], // n

View File

@@ -14,6 +14,8 @@ using namespace std;
 namespace jittor {
+extern int use_tensorcore;
 static inline int findc(const char* format, const char& c) {
 if (c==format[0]) return 0;
 if (c==format[1]) return 1;
@@ -150,7 +152,9 @@ void CudnnConvOp::jit_run() {
 ));
 // using tensor core
 // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+if (use_tensorcore) {
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
+}
 int dimY[] = {
 (int)y->shape[findc("@YFORMAT", 'a')], // n

View File

@@ -788,7 +788,7 @@ int doTest(int algo, int* dimA, int* padA, int* convstrideA, int* filterdimA, cu
 CUDNN_CONVOLUTION,
 CUDNN_DATA_FLOAT) );
 if (mathType == 1) {
-    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+    checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
 }
 checkCudaErrors( cudnnSetTensorNdDescriptor(cudnnOdesc, getDataType<T_ELEM>(), convDim+2, outdimA, outstrideA) );

View File

@@ -20,6 +20,7 @@
 namespace jittor {
 DEFINE_FLAG(vector<int>, cuda_archs, {}, "Cuda arch");
+DEFINE_FLAG(int, use_tensorcore, 0, "use tensor core");
 unique_ptr<std::default_random_engine> eng;
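The flag follows Jittor's usual DEFINE_FLAG pattern: defined once here, then declared extern and branched on at runtime in each op, as the diffs above show. A condensed sketch of that wiring (illustrative only; the picker function is hypothetical):

    #include <cublas_v2.h>

    // Defined once in the flags translation unit:
    // DEFINE_FLAG(int, use_tensorcore, 0, "use tensor core");

    // Consumed in any op translation unit:
    extern int use_tensorcore;

    cublasGemmAlgo_t pick_gemm_algo() {
        return use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                              : CUBLAS_GEMM_DEFAULT;
    }

Like other DEFINE_FLAG values it defaults to 0 (off), and should be settable at runtime through Jittor's flag interface rather than by recompiling.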