mirror of https://github.com/Jittor/Jittor
fix duplicate definition in cudnnops
This commit is contained in:
parent
ac78f57a7e
commit
9ce77dfb82
|
@ -95,8 +95,8 @@ VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConv3dBackwardWOp::jit_run() {
|
||||
auto w = dw;
|
||||
|
|
|
@ -85,8 +85,8 @@ VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConv3dBackwardXOp::jit_run() {
|
||||
auto x = dx;
|
||||
|
|
|
@ -88,8 +88,8 @@ VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConv3dOp::jit_run() {
|
||||
cudnnHandle_t& handle_ = cudnn_handle;
|
||||
|
|
Loading…
Reference in New Issue