support tensorcore

This commit is contained in:
Dun Liang 2021-12-30 16:28:21 +08:00
parent 928f7ae5be
commit 7b9113f828
2 changed files with 9 additions and 8 deletions

View File

@ -15,9 +15,18 @@
#include "helper_cuda.h"
#include "fp16_emu.h"
#include "common.h"
#include "misc/nano_string.h"
namespace jittor {
EXTERN_LIB cublasHandle_t cublas_handle;
static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
// if (dtype == ns_float64) return CUDA_R_64F;
// if (dtype == ns_float16) return CUDA_R_16F;
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}
} // jittor

View File

@ -57,14 +57,6 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
#else // JIT
#pragma clang diagnostic ignored "-Wtautological-compare"
static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
// if (dtype == ns_float64) return CUDA_R_64F;
// if (dtype == ns_float16) return CUDA_R_16F;
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}
void CublasMatmulOp::jit_run() {
cublasHandle_t& handle_ = cublas_handle;
const T alpha = 1.0f;