mirror of https://github.com/Jittor/Jittor
support tensorcore
This commit is contained in:
parent
928f7ae5be
commit
7b9113f828
|
@ -15,9 +15,18 @@
|
|||
#include "helper_cuda.h"
|
||||
#include "fp16_emu.h"
|
||||
#include "common.h"
|
||||
#include "misc/nano_string.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
EXTERN_LIB cublasHandle_t cublas_handle;
|
||||
|
||||
// Map a jittor NanoString dtype to the corresponding CUDA/cuBLAS data type
// enum (cudaDataType) used by the *Ex GEMM entry points.
//
// Supports float32, float64 and float16 — fp16 is the dtype the tensor-core
// path needs, so its mapping is enabled here rather than left commented out.
// Any other dtype is reported via LOGf ("f" presumably means fatal in
// jittor's logging macros — confirm); the trailing return of CUDA_R_32F is
// only there to satisfy the compiler on that path.
static inline cudaDataType get_dtype(NanoString dtype) {
    if (dtype == ns_float32) return CUDA_R_32F;
    if (dtype == ns_float64) return CUDA_R_64F;
    if (dtype == ns_float16) return CUDA_R_16F;
    LOGf << "not support type" << dtype;
    return CUDA_R_32F;
}
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -57,14 +57,6 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
|
|||
#else // JIT
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
// Translate a jittor NanoString dtype into the matching cudaDataType enum.
// Only float32 is accepted at the moment; the fp64/fp16 mappings are kept
// commented out until those dtypes are supported. Any other dtype falls
// through to the LOGf report below, and float32 is returned so the function
// has a well-defined value on every path.
static inline cudaDataType get_dtype(NanoString dtype) {
    if (dtype != ns_float32) {
        // if (dtype == ns_float64) return CUDA_R_64F;
        // if (dtype == ns_float16) return CUDA_R_16F;
        LOGf << "not support type" << dtype;
    }
    return CUDA_R_32F;
}
|
||||
|
||||
void CublasMatmulOp::jit_run() {
|
||||
cublasHandle_t& handle_ = cublas_handle;
|
||||
const T alpha = 1.0f;
|
||||
|
|
Loading…
Reference in New Issue