polish matmul

This commit is contained in:
lixl19 2022-03-18 16:21:34 +08:00
parent ad57ec890f
commit 1fa89771c4
2 changed files with 37 additions and 17 deletions

View File

@ -128,6 +128,18 @@ void CublasBatchedMatmulOp::jit_run() {
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUBLAS_COMPUTE_16F;
}
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cudaDataType_t computeType = CUDA_R_32F;
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
#endif
checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
@ -135,15 +147,13 @@ void CublasBatchedMatmulOp::jit_run() {
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
batch_size,computeType,algo));
#else
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), k, k * n,
batch_size));
#endif
// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
// c->ptr<T>(), k, k * n,
// batch_size));
}
#endif
#endif // JIT

View File

@ -85,6 +85,18 @@ void CublasMatmulOp::jit_run() {
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUBLAS_COMPUTE_16F;
}
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cudaDataType_t computeType = CUDA_R_32F;
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
#endif
checkCudaErrors(cublasGemmEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
@ -92,15 +104,13 @@ void CublasMatmulOp::jit_run() {
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k,
computeType, algo));
#else
checkCudaErrors(cublas@op@@gemm(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), k));
// checkCudaErrors(cublas@op@@gemm(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
// c->ptr<T>(), k));
#endif
}
#endif // JIT