mirror of https://github.com/Jittor/Jittor
support tensorcore mode
parent 66d9df2f82
commit 40ed259665

@@ -19,6 +19,8 @@ using namespace std;

namespace jittor {

extern int use_tensorcore;

#ifndef JIT

static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul")

@@ -116,13 +118,25 @@ void CublasBatchedMatmulOp::jit_run() {
k = bs[adim-2];
}
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;

// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
// c->ptr<T>(), k, k * n,
// batch_size));

checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), k, k * n,
batch_size));
b->ptr<T>(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
batch_size,computeType,algo));

}
#endif
#endif // JIT

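For reference, once the @op/@Trans_a/@Trans_b JIT template parameters are expanded for a float32, no-transpose case, the new call corresponds roughly to the standalone sketch below. This is an illustration, not code from the commit; the sizes, variable names, and the main() harness are ours.

    // Minimal sketch: batched GEMM through cublasGemmStridedBatchedEx with an
    // optional tensor-core algorithm, mirroring the pattern above (illustrative only).
    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        const int batch = 4, n = 64, m = 32, k = 128;  // c[b,n,k] = a[b,n,m] * b[b,m,k]
        const bool use_tensorcore = true;              // stands in for Jittor's flag

        float *a, *b, *c;
        cudaMalloc(&a, sizeof(float) * batch * n * m);
        cudaMalloc(&b, sizeof(float) * batch * m * k);
        cudaMalloc(&c, sizeof(float) * batch * n * k);

        cublasHandle_t handle;
        cublasCreate(&handle);

        const float alpha = 1.0f, beta = 0.0f;
        // Same selection as the patch: tensor-core algorithm only when the flag is set.
        cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                                               : CUBLAS_GEMM_DEFAULT;

        // Row-major a*b is issued as column-major b^T*a^T, hence b is passed first,
        // matching the argument order in the op.
        cublasStatus_t s = cublasGemmStridedBatchedEx(handle,
            CUBLAS_OP_N, CUBLAS_OP_N,
            k, n, m, &alpha,
            b, CUDA_R_32F, k, (long long)k * m,
            a, CUDA_R_32F, m, (long long)n * m, &beta,
            c, CUDA_R_32F, k, (long long)k * n,
            batch, CUBLAS_COMPUTE_32F, algo);
        printf("cublas status: %d\n", (int)s);

        cublasDestroy(handle);
        cudaFree(a); cudaFree(b); cudaFree(c);
        return 0;
    }
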
@@ -16,6 +16,8 @@ using namespace std;

namespace jittor {

extern int use_tensorcore;

#ifndef JIT

CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)

@@ -54,6 +56,15 @@ void CublasMatmulOp::jit_prepare(JK& jk) {

#else // JIT
#pragma clang diagnostic ignored "-Wtautological-compare"

static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
// if (dtype == ns_float64) return CUDA_R_64F;
// if (dtype == ns_float16) return CUDA_R_16F;
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}

void CublasMatmulOp::jit_run() {
cublasHandle_t& handle_ = cublas_handle;
const T alpha = 1.0f;

@@ -72,12 +83,24 @@ void CublasMatmulOp::jit_run() {
k = bs[0];
}
// a: [n,m], b: [m,k], c: [n,k]
checkCudaErrors(cublas@op@@gemm(handle_,
cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;

// checkCudaErrors(cublas@op@@gemm(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
// c->ptr<T>(), k));

checkCudaErrors(cublasGemmEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), k));
b->ptr<T>(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k,
computeType, algo));

}
#endif // JIT

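The single-GEMM path is the same idea without the batch strides. A minimal sketch of the expanded call, again assuming float32 inputs and no transposition (the gemm_rowmajor_f32 helper and its signature are ours, not the op's):

    // Sketch: c[n,k] = a[n,m] * b[m,k] in row-major storage, FP32 data and compute,
    // issued through cublasGemmEx so a tensor-core algorithm can be requested.
    #include <cublas_v2.h>

    cublasStatus_t gemm_rowmajor_f32(cublasHandle_t handle,
                                     const float* a, const float* b, float* c,
                                     int n, int m, int k, bool use_tensorcore) {
        const float alpha = 1.0f, beta = 0.0f;
        cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                                               : CUBLAS_GEMM_DEFAULT;
        // As above: row-major a*b computed as column-major b^T*a^T.
        return cublasGemmEx(handle,
            CUBLAS_OP_N, CUBLAS_OP_N,
            k, n, m, &alpha,
            b, CUDA_R_32F, k,
            a, CUDA_R_32F, m, &beta,
            c, CUDA_R_32F, k,
            CUBLAS_COMPUTE_32F, algo);
    }

Whether tensor cores are actually engaged for FP32 storage depends on the GPU generation and cuBLAS version (e.g. TF32 on Ampere); CUBLAS_GEMM_DEFAULT_TENSOR_OP is a request, not a guarantee.
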
@@ -18,6 +18,8 @@ using namespace std;

namespace jittor {

extern int use_tensorcore;

#pragma GCC diagnostic ignored "-Wunused-variable"

#ifndef JIT

@@ -158,7 +160,9 @@ void CudnnConv3dBackwardWOp::jit_run() {
));

// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
if(use_tensorcore){
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

int sy[] = {0,0,0,0,1};

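The same guarded cudnnSetConvolutionMathType call is inserted into each of the cuDNN convolution ops in this commit. In isolation the API usage looks like the sketch below; the wrapper function and its explicit CUDNN_DEFAULT_MATH branch are ours for illustration (the patch simply leaves the descriptor untouched when the flag is off):

    // Sketch: opting a cuDNN convolution descriptor into tensor cores.
    // CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION permits cuDNN to down-convert FP32
    // tensors internally so that tensor cores can be used.
    #include <cudnn.h>

    cudnnStatus_t set_conv_math(cudnnConvolutionDescriptor_t convDesc, bool use_tensorcore) {
        if (use_tensorcore)
            return cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION);
        // Default FMA math, no tensor cores (not part of the patch; shown for contrast).
        return cudnnSetConvolutionMathType(convDesc, CUDNN_DEFAULT_MATH);
    }
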
@@ -18,6 +18,8 @@ using namespace std;

namespace jittor {

extern int use_tensorcore;

#pragma GCC diagnostic ignored "-Wunused-variable"

#ifndef JIT

@@ -149,7 +151,9 @@ void CudnnConv3dBackwardXOp::jit_run() {
));

// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
if(use_tensorcore){
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

int sy[] = {0,0,0,0,1};

@@ -15,6 +15,8 @@ using namespace std;

namespace jittor {

extern int use_tensorcore;

#pragma GCC diagnostic ignored "-Wunused-variable"

#ifndef JIT

@@ -150,7 +152,9 @@ void CudnnConv3dOp::jit_run() {
));

// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
if(use_tensorcore){
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

int sy[] = {0,0,0,0,1};

@@ -16,6 +16,9 @@
using namespace std;

namespace jittor {

extern int use_tensorcore;

static inline int findc(const string& format, const char& c) {
if (c==format[0]) return 0;
if (c==format[1]) return 1;

@@ -149,7 +152,14 @@ void CudnnConvBackwardWOp::jit_run() {
));

// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
if(use_tensorcore){
// CUDNN_TENSOR_OP_MATH
// The use of Tensor Core operations is permitted but will not actively perform datatype down conversion on tensors in order to utilize Tensor Cores.
// CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION
// The use of Tensor Core operations is permitted and will actively perform datatype down conversion on tensors in order to utilize Tensor Cores.

checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

int dimY[] = {
(int)y->shape[findc("@YFORMAT", 'a')], // n

@@ -25,6 +25,8 @@ static inline int findc(const char* format, const char& c) {

namespace jittor {

extern int use_tensorcore;

#ifndef JIT

static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {

@@ -150,7 +152,9 @@ void CudnnConvBackwardXOp::jit_run() {
));

// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
if(use_tensorcore){
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

int dimY[] = {
(int)y->shape[findc("@YFORMAT", 'a')], // n

@@ -14,6 +14,8 @@ using namespace std;

namespace jittor {

extern int use_tensorcore;

static inline int findc(const char* format, const char& c) {
if (c==format[0]) return 0;
if (c==format[1]) return 1;

@@ -150,7 +152,9 @@ void CudnnConvOp::jit_run() {
));

// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
if(use_tensorcore){
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

int dimY[] = {
(int)y->shape[findc("@YFORMAT", 'a')], // n

@@ -788,7 +788,7 @@ int doTest(int algo, int* dimA, int* padA, int* convstrideA, int* filterdimA, cu
CUDNN_CONVOLUTION,
CUDNN_DATA_FLOAT) );
if (mathType == 1) {
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

checkCudaErrors( cudnnSetTensorNdDescriptor(cudnnOdesc, getDataType<T_ELEM>(), convDim+2, outdimA, outstrideA) );

@@ -20,6 +20,7 @@
namespace jittor {

DEFINE_FLAG(vector<int>, cuda_archs, {}, "Cuda arch");
DEFINE_FLAG(int, use_tensorcore, 0, "use tensor core");

unique_ptr<std::default_random_engine> eng;

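With the flag registered here (default 0, i.e. tensor cores off), the guarded paths above become switchable at runtime. In Jittor, flags declared via DEFINE_FLAG are normally exposed on the Python side through jt.flags, so enabling the new paths should amount to setting jt.flags.use_tensorcore = 1 before running matmul or convolution ops; this follows Jittor's usual flag mechanism rather than anything stated in this commit.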