mirror of https://github.com/Jittor/Jittor
polish matmul
This commit is contained in:
parent
ad57ec890f
commit
1fa89771c4
|
@ -128,6 +128,18 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
cudaDataType_t computeType = CUDA_R_32F;
|
||||
if (use_tensorcore) {
|
||||
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
}
|
||||
if (a->dtype() == ns_float16
|
||||
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
|
||||
computeType = CUDA_R_16F;
|
||||
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
}
|
||||
#endif
|
||||
checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
|
@ -135,15 +147,13 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
|
||||
c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
|
||||
batch_size,computeType,algo));
|
||||
#else
|
||||
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
|
||||
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
|
||||
c->ptr<T>(), k, k * n,
|
||||
batch_size));
|
||||
#endif
|
||||
// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
|
||||
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
// k, n, m, &alpha,
|
||||
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
|
||||
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
|
||||
// c->ptr<T>(), k, k * n,
|
||||
// batch_size));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
|
|
@ -85,6 +85,18 @@ void CublasMatmulOp::jit_run() {
|
|||
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
cudaDataType_t computeType = CUDA_R_32F;
|
||||
if (use_tensorcore) {
|
||||
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
}
|
||||
if (a->dtype() == ns_float16
|
||||
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
|
||||
computeType = CUDA_R_16F;
|
||||
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
}
|
||||
#endif
|
||||
checkCudaErrors(cublasGemmEx(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
|
@ -92,15 +104,13 @@ void CublasMatmulOp::jit_run() {
|
|||
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
c->ptr<T>(),get_dtype(c->dtype()), k,
|
||||
computeType, algo));
|
||||
#else
|
||||
checkCudaErrors(cublas@op@@gemm(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
|
||||
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
c->ptr<T>(), k));
|
||||
// checkCudaErrors(cublas@op@@gemm(handle_,
|
||||
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
// k, n, m, &alpha,
|
||||
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
|
||||
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
// c->ptr<T>(), k));
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
#endif // JIT
|
||||
|
|
Loading…
Reference in New Issue