fix duplicate definition in cudnnops

This commit is contained in:
Exusial 2024-12-18 14:19:04 +08:00
parent ac78f57a7e
commit 9ce77dfb82
3 changed files with 6 additions and 6 deletions

View File

@ -95,8 +95,8 @@ VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dBackwardWOp::jit_run() {
auto w = dw;

View File

@ -85,8 +85,8 @@ VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dBackwardXOp::jit_run() {
auto x = dx;

View File

@ -88,8 +88,8 @@ VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dOp::jit_run() {
cudnnHandle_t& handle_ = cudnn_handle;