diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 37719b09..6752d0a6 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.3.22'
+__version__ = '1.2.3.23'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py
index d33a89ec..c3fc2007 100644
--- a/python/jittor/compiler.py
+++ b/python/jittor/compiler.py
@@ -634,7 +634,7 @@ def compile_custom_ops(
     if gen_name_ != "":
         gen_name = gen_name_
     if len(gen_name) > 100:
-        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
+        gen_name = gen_name[:80] + "___hash" + str(abs(hash(gen_name)))
 
     includes = sorted(list(set(includes)))
     includes = "".join(map(lambda x: f" -I'{x}' ", includes))
diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc
new file mode 100644
index 00000000..b18b41df
--- /dev/null
+++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc
@@ -0,0 +1,288 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Dun Liang
+//     Guowei Yang <471184555@qq.com>
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#include "mem/allocator.h"
+#include "var.h"
+#include "cudnn_conv3d_backward_w_op.h"
+#include "cudnn_warper.h"
+#include "executor.h"
+#include "ops/op_register.h"
+
+using namespace std;
+
+namespace jittor {
+
+#pragma GCC diagnostic ignored "-Wunused-variable"
+
+#ifndef JIT
+
+CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
+        : x(x), dy(dy), kd(kd), kh(kh), kw(kw), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
+      xformat(move(xformat)) {
+    flags.set(NodeFlags::_cuda, 1);
+    flags.set(NodeFlags::_cpu, 0);
+    dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
+}
+
+void CudnnConv3dBackwardWOp::infer_shape() {
+    ASSERTop(x->shape.size(),==,5);
+    ASSERTop(dy->shape.size(),==,5);
+    int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
+
+    if (xformat == "ncdhw") {
+        x->shape.unpack(xn, xc, xd, xh, xw);
+        dy->shape.unpack(yn, yc, yd, yh, yw);
+    } else {
+        x->shape.unpack(xn, xd, xh, xw, xc);
+        dy->shape.unpack(yn, yd, yh, yw, yc);
+    }
+    wco = yc, wci = xc / groups;
+    wh = kh;
+    ww = kw;
+    wd = kd;
+    dw->set_shape(NanoVector(wco, wci, wd, wh, ww));
+}
+
+void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
+    jk << _CS("[Tx:") << x->dtype();
+    jk << _CS("][Ty:") << dy->dtype();
+    jk << _CS("][Tw:") << dw->dtype();
+    jk << ']';
+}
+
+static auto make_conv3d = get_op_info("cudnn_conv3d")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, string>();
+static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
+
+
+VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
+
+    if (xformat == "ncdhw") {
+        x->shape.unpack(xn, xc, xd, xh, xw);
+        dy->shape.unpack(yn, yc, yd, yh, yw);
+    } else {
+        x->shape.unpack(xn, xd, xh, xw, xc);
+        dy->shape.unpack(yn, yd, yh, yw, yc);
+    }
+
+    if (v_index == 0) {
+        return make_backwardx(dout, dy, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
+    } else {
+        return make_conv3d(x, dout, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
+    }
+}
+
+// unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
+
+#else // JIT
+#ifdef JIT_cuda
+
+#pragma clang diagnostic ignored "-Wtautological-compare"
+
+extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
+
+template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
+template <> __inline__ cudnnDataType_t getDataType<half>() { return CUDNN_DATA_HALF; }
+template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
+
+void CudnnConv3dBackwardWOp::jit_run() {
+    auto w = dw;
+    auto y = dy;
+    cudnnHandle_t& handle_ = cudnn_handle;
+
+    cudnnTensorDescriptor_t cudnnIdesc;
+    cudnnFilterDescriptor_t cudnnFdesc;
+    cudnnTensorDescriptor_t cudnnOdesc;
+    cudnnConvolutionDescriptor_t cudnnConvDesc;
+
+    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
+    checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
+    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
+    checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
+    checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
+
+    int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
+    int sx[] = {0,0,0,0,1};
+    for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
+    int strideX[5];
+    if (xformat == "ncdhw") {
+        x->shape.unpack(xn, xc, xd, xh, xw);
+        int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
+        memcpy(strideX, tmp, sizeof(tmp));
+    } else {
+        x->shape.unpack(xn, xd, xh, xw, xc);
+        int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
+        memcpy(strideX, tmp, sizeof(tmp));
+    }
+    int dimX[] = {xn, xc, xd, xh, xw};
+    // dimX: ncdhw
+    checkCudaErrors(cudnnSetTensorNdDescriptor(
+        cudnnIdesc, getDataType<Tx>(),
+        5, dimX, strideX
+    ));
+
+    auto ws = w->shape;
+    int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
+    // cudnn only supports these two formats
+    // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
+    #define filterFormat_oihw CUDNN_TENSOR_NCHW
+    #define filterFormat_ohwi CUDNN_TENSOR_NHWC
+
+    // dimW: KCRS(oihw)
+    checkCudaErrors(cudnnSetFilterNdDescriptor(
+        cudnnFdesc, getDataType<Tw>(),
+        // filterFormat_@WFORMAT, 5, dimW
+        filterFormat_oihw, 5, dimW
+    ));
+
+    int padA[] = {paddingd, paddingh, paddingw};
+    int convstrideA[] = {strided, strideh, stridew};
+    int dilationA[] = {dilationd, dilationh, dilationw};
+    // the difference between
+    // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
+    // is the kernel rc order;
+    // currently, no perf difference is observed between
+    // these two modes
+    checkCudaErrors(cudnnSetConvolutionNdDescriptor(
+        cudnnConvDesc, 3,
+        padA, convstrideA, dilationA,
+        CUDNN_CROSS_CORRELATION, getDataType<Ty>()
+    ));
+
+    // using tensor core
+    // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
+
+
+    int sy[] = {0,0,0,0,1};
+    for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
+    int strideY[5];
+    if (xformat == "ncdhw") {
+        y->shape.unpack(yn, yc, yd, yh, yw);
+        int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
+        memcpy(strideY, tmp, sizeof(tmp));
+    } else {
+        y->shape.unpack(yn, yd, yh, yw, yc);
+        int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
+        memcpy(strideY, tmp, sizeof(tmp));
+    }
+    int dimY[] = {yn, yc, yd, yh, yw};
+
+    checkCudaErrors(cudnnSetTensorNdDescriptor(
+        cudnnOdesc, getDataType<Ty>(),
+        5, dimY, strideY
+    ));
+
+    cudnnConvolutionBwdFilterAlgo_t algos[] = {
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
+        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
+    };
+    int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
+    int perf_count;
+    cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
+    cudnnConvolutionBwdFilterAlgo_t algo;
+    bool benchmark=true;
+
+    jk.clear();
+    jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
+    jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
+    jk << paddingd << paddingh << paddingw << "," << strided << strideh << stridew << ",";
+    jk << dilationd << dilationh << dilationw << "," << groups << ".";
+    auto iter = bwdw_algo_cache.find(jk.to_string());
+
+    if (iter!=bwdw_algo_cache.end()) algo = iter->second;
+    else {
+        if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false;
+        if (benchmark) {
+            size_t max_ws_size = 0;
+            for (int i = 0; i < num_algos; i++) {
+                size_t sz;
+                cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
+                // continue if it uses too much workspace
+                if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
+                if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
+            }
+            size_t allocation;
+            void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
+            checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
+                handle_,
+                cudnnIdesc, x->ptr<Tx>(),
+                cudnnOdesc, y->ptr<Ty>(),
+                cudnnConvDesc,
+                cudnnFdesc, w->ptr<Tw>(),
+                num_algos,
+                &perf_count,
+                perf_results,
+                ws,
+                max_ws_size));
+            exe.temp_allocator->free(ws, max_ws_size, allocation);
+        } else {
+            checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
+                handle_,
+                cudnnIdesc,
+                cudnnOdesc,
+                cudnnConvDesc,
+                cudnnFdesc,
+                num_algos,
+                &perf_count,
+                perf_results));
+        }
+        int best_algo_idx=-1;
+        for (int i = 0; i < perf_count; i++)
+            if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
+                best_algo_idx=i;
+                break;
+            }
+        ASSERT(best_algo_idx!=-1);
+        algo=perf_results[best_algo_idx].algo;
+        if (benchmark) {
+            bwdw_algo_cache[jk.to_string()] = algo;
+            if (bwdw_algo_cache.size()==max_cache_size)
+                LOGw << "backward w algorithm cache is full";
+        }
+    }
+
+    // TODO: wrap work space
+    void *workSpace = 0;
+    size_t workSpaceSize;
+    checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize(
+        handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc,
+        cudnnFdesc, algo, &workSpaceSize));
+    size_t allocation;
+    if (workSpaceSize > 0) {
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
+    }
+    float alpha=1, beta=0;
+    checkCudaErrors(cudnnConvolutionBackwardFilter(
+        handle_,
+        (void*)(&alpha),
+        cudnnIdesc, x->ptr<Tx>(),
+        cudnnOdesc, y->ptr<Ty>(),
+        cudnnConvDesc,
+        algo,
+        workSpace, workSpaceSize,
+        (void*)(&beta),
+        cudnnFdesc, w->ptr<Tw>())
+    );
+    if (workSpace)
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
+
+    checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
+    checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
+    checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
+    checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
+}
+#endif
+#endif // JIT
+
+} // jittor
diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.h
new file mode 100644
index 00000000..2b15ff8b
--- /dev/null
+++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.h
@@ -0,0 +1,28 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Dun Liang
+//     Guowei Yang <471184555@qq.com>
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "op.h"
+
+namespace jittor {
+
+struct CudnnConv3dBackwardWOp : Op {
+    Var* x, * dy, * dw;
+    int kd, kh, kw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
+    string xformat;
+
+    CudnnConv3dBackwardWOp(Var* x, Var* y, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
+
+    const char* name() const override { return "cudnn_conv3d_backward_w"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    void infer_shape() override;
+    DECLARE_jit_run;
+};
+
+} // jittor
diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc
new file mode 100644
index 00000000..2a4debdd
--- /dev/null
+++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc
@@ -0,0 +1,279 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Dun Liang
+//     Guowei Yang <471184555@qq.com>
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// *************************************************************** +#include "mem/allocator.h" +#include "var.h" +#include "cudnn_conv3d_backward_x_op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "ops/op_register.h" + +using namespace std; + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat) + : w(w), dy(dy), xd(depth), xh(height), xw(width), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); +} + +void CudnnConv3dBackwardXOp::infer_shape() { + ASSERTop(w->shape.size(),==,5); + ASSERTop(dy->shape.size(),==,5); + int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + w->shape.unpack(wco, wci, wd, wh, ww); + if (xformat == "ncdhw") + dy->shape.unpack(yn, yc, yd, yh, yw); + else + dy->shape.unpack(yn, yd, yh, yw, yc); + xn = yn, xc = wci * groups; + if (xformat == "ncdhw") + dx->set_shape(NanoVector(xn, xc, xd, xh, xw)); + else + dx->set_shape(NanoVector(xn, xd, xh, xw, xc)); +} + +void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) { + jk << _CS("[Tx:") << dx->dtype(); + jk << _CS("][Ty:") << dy->dtype(); + jk << _CS("][Tw:") << w->dtype(); + jk << ']'; +} + + +static auto make_conv3d = get_op_info("cudnn_conv3d") + .get_constructor(); +static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w") + .get_constructor(); + + +VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + w->shape.unpack(wco, wci, wd, wh, ww); + + if (v_index == 0) { + return make_backwardw(dout, dy, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } else { + return make_conv3d(dout, w, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } +} +// unordered_map bwdx_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +extern unordered_map bwdx_algo_cache; + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnConv3dBackwardXOp::jit_run() { + auto x = dx; + auto y = dy; + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + int sx[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1]; + int strideX[5]; + if 
(xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]}; + memcpy(strideX, tmp, sizeof(tmp)); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]}; + memcpy(strideX, tmp, sizeof(tmp)); + } + int dimX[] = {xn, xc, xd, xh, xw}; + // dimX: ncdhw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 5, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + // filterFormat_@WFORMAT, 5, dimW + filterFormat_oihw, 5, dimW + )); + + int padA[] = {paddingd, paddingh, paddingw}; + int convstrideA[] = {strided, strideh, stridew}; + int dilationA[] = {dilationd, dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 3, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + + // using tensor core + // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) ); + + + int sy[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1]; + int strideY[5]; + if (xformat == "ncdhw") { + y->shape.unpack(yn, yc, yd, yh, yw); + int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]}; + memcpy(strideY, tmp, sizeof(tmp)); + } else { + y->shape.unpack(yn, yd, yh, yw, yc); + int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]}; + memcpy(strideY, tmp, sizeof(tmp)); + } + int dimY[] = {yn, yc, yd, yh, yw}; + + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 5, dimY, strideY + )); + + cudnnConvolutionBwdDataAlgo_t algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + }; + int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + int perf_count; + cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos]; + cudnnConvolutionBwdDataAlgo_t algo; + bool benchmark=true; + + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; + jk << paddingd << paddingh << paddingw << "," << strided << strideh <second; + else { + if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx( + handle_, + cudnnFdesc, w->ptr(), + 
cudnnOdesc, y->ptr(), + cudnnConvDesc, + cudnnIdesc, x->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7( + handle_, + cudnnFdesc, + cudnnOdesc, + cudnnConvDesc, + cudnnIdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + bwdx_algo_cache[jk.to_string()] = algo; + if (bwdx_algo_cache.size()==max_cache_size) + LOGw << "backward x algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize( + handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, + cudnnIdesc, algo, &workSpaceSize)); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionBackwardData( + handle_, + (void*)(&alpha), + cudnnFdesc, w->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnIdesc, x->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.h new file mode 100644 index 00000000..e510072f --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConv3dBackwardXOp : Op { + Var* w, * dy, * dx; + int xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups; + string xformat; + + CudnnConv3dBackwardXOp(Var* w, Var* y, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw"); + + const char* name() const override { return "cudnn_conv3d_backward_x"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc new file mode 100644 index 00000000..743656aa --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc @@ -0,0 +1,284 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. 
+// Maintainers: Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_conv3d_op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "ops/op_register.h" + +using namespace std; + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat) + : x(x), w(w), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); +} + +void CudnnConv3dOp::infer_shape() { + ASSERTop(x->shape.size(),==,5); + ASSERTop(w->shape.size(),==,5); + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + if (xformat == "ncdhw") + x->shape.unpack(xn, xc, xd, xh, xw); + else + x->shape.unpack(xn, xd, xh, xw, xc); + w->shape.unpack(wco, wci, wd, wh, ww); + ASSERTop(wci * groups,==,xc); + yn = xn, yc = wco; + yd = (xd+paddingd*2-wd*dilationd+dilationd-1)/strided+1; + yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1; + yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1; + if (xformat == "ncdhw") + y->set_shape(NanoVector(yn, yc, yd, yh, yw)); + else + y->set_shape(NanoVector(yn, yd, yh, yw, yc)); +} + +void CudnnConv3dOp::jit_prepare(JK& jk) { + jk << _CS("[Tx:") << x->dtype(); + jk << _CS("][Ty:") << y->dtype(); + jk << _CS("][Tw:") << w->dtype(); + jk << ']'; +} + +static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x") + .get_constructor(); +static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w") + .get_constructor(); + +VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + if (xformat == "ncdhw") + x->shape.unpack(xn, xc, xd, xh, xw); + else + x->shape.unpack(xn, xd, xh, xw, xc); + w->shape.unpack(wco, wci, wd, wh, ww); + if (v_index == 0) { + return make_backwardx(w, dout, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } else { + return make_backwardw(x, dout, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } +} + +// unordered_map fwd_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +extern unordered_map fwd_algo_cache; + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnConv3dOp::jit_run() { + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + 
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + int sx[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1]; + int strideX[5]; + if (xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]}; + memcpy(strideX, tmp, sizeof(tmp)); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]}; + memcpy(strideX, tmp, sizeof(tmp)); + } + int dimX[] = {xn, xc, xd, xh, xw}; + // dimX: ncdhw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 5, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + // filterFormat_@WFORMAT, 5, dimW + filterFormat_oihw, 5, dimW + )); + + int padA[] = {paddingd, paddingh, paddingw}; + int convstrideA[] = {strided, strideh, stridew}; + int dilationA[] = {dilationd, dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 3, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + + // using tensor core + // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) ); + + + int sy[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1]; + int strideY[5]; + if (xformat == "ncdhw") { + y->shape.unpack(yn, yc, yd, yh, yw); + int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]}; + memcpy(strideY, tmp, sizeof(tmp)); + } else { + y->shape.unpack(yn, yd, yh, yw, yc); + int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]}; + memcpy(strideY, tmp, sizeof(tmp)); + } + int dimY[] = {yn, yc, yd, yh, yw}; + + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 5, dimY, strideY + )); + + cudnnConvolutionFwdAlgo_t algos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + int perf_count; + cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos]; + cudnnConvolutionFwdAlgo_t algo; + bool benchmark=true; + + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; + jk << paddingd << paddingh << paddingw << "," << strided << strideh <second; + else { + if (fwd_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize( + handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algos[i], &sz); 
+ // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx( + handle_, + cudnnIdesc, x->ptr(), + cudnnFdesc, w->ptr(), + cudnnConvDesc, + cudnnOdesc, y->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7( + handle_, + cudnnIdesc, + cudnnFdesc, + cudnnConvDesc, + cudnnOdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + fwd_algo_cache[jk.to_string()] = algo; + if (fwd_algo_cache.size()==max_cache_size) + LOGw << "forward_ algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize( + handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algo, &workSpaceSize) ); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionForward( + handle_, + (void*)(&alpha), + cudnnIdesc, x->ptr(), + cudnnFdesc, w->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnOdesc, y->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h new file mode 100644 index 00000000..6aaf36a0 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConv3dOp : Op { + Var* x, * w, * y; + int strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups; + string xformat; + CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw"); + + const char* name() const override { return "cudnn_conv3d"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 5d33a493..3d164ae2 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -21,6 +21,7 @@ from collections import OrderedDict from jittor.pool import * from jittor.optim import * from jittor.misc import _pair, _triple +from jittor_utils import LOG def matmul_transpose(a, b): @@ -639,7 +640,6 @@ class Conv1d(Module): class Conv3d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): - LOG.w("Optimizations of Conv3d are working in progress, it maybe slow currently.") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) @@ -665,65 +665,7 @@ class Conv3d(Module): self.bias = None def execute(self, x): - if self.groups == 1: - N,C,H,W,D = x.shape - Kh, Kw, Kd = self.kernel_size - assert C==self.in_channels - oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 - ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 - od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1 - xx = x.reindex([N,self.out_channels,C,oh,ow,od,Kh,Kw,Kd], [ - 'i0', # Nid - 'i2', # Cid - f'i3*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid - f'i4*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid - f'i5*{self.stride[2]}-{self.padding[2]}+i8*{self.dilation[2]}', # Did+KDid - ]) - ww = self.weight.broadcast(xx.shape, [0,3,4,5]) - yy = xx*ww - y = yy.sum([2,6,7,8]) # Kc, Kh, Kw, Kd - if self.bias is not None: - b = self.bias.broadcast(y.shape, [0,2,3,4]) - y = y + b - return y - else: - N,C,H,W,D = x.shape - Kh, Kw, Kd = self.kernel_size - G = self.groups - CpG = C // G # channels per group - assert C==self.in_channels - oc = self.out_channels - oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 - ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 - od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1 - xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [ - 'i0', # Nid - f'i1*{CpG}+i3', # Gid - f'i4*{self.stride[0]}-{self.padding[0]}+i7*{self.dilation[0]}', # Hid+Khid - f'i5*{self.stride[1]}-{self.padding[1]}+i8*{self.dilation[1]}', # Wid+KWid - f'i6*{self.stride[2]}-{self.padding[2]}+i9*{self.dilation[2]}', # Did+KDid - ]) - # w: [oc, CpG, Kh, Kw, Kd] - ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, od, Kh, Kw, Kd], [ - f'i1*{oc//G}+i2', - 'i3', - 'i7', - 'i8', - 'i9' - ]) - ww.compile_options = xx.compile_options = {"G":G,"C":C} - yy = xx*ww - y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [ - 'i0', - f'i1*{oc//G}+i2', - 'i4', - 'i5', - 'i6' - ]) - if 
self.bias is not None: - b = self.bias.broadcast(y.shape, [0,2,3,4]) - y = y + b - return y + return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): padding = _pair(padding) @@ -790,12 +732,12 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): out_channels = weight.shape[0] if groups == 1: - N,C,H,W,D = x.shape - Kh, Kw, Kd = weight.shape[-3:] - oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 - ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 - od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1 - xx = x.reindex([N,out_channels,C,oh,ow,od,Kh,Kw,Kd], [ + N,C,D,H,W = x.shape + Kd, Kh, Kw = weight.shape[-3:] + od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1 + oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1 + ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1 + xx = x.reindex([N,out_channels,C,od,oh,ow,Kd,Kh,Kw], [ 'i0', # Nid 'i2', # Cid f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid @@ -810,15 +752,15 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): y = y + b return y else: - N,C,H,W,D = x.shape - Kh, Kw, Kd = weight.shape[-3:] + N,C,D,H,W = x.shape + Kd, Kh, Kw = weight.shape[-3:] G = groups CpG = C // G # channels per group oc = out_channels - oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 - ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 - od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1 - xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [ + od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1 + oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1 + ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1 + xx = x.reindex([N,G,oc//G,CpG,od,oh,ow,Kd,Kh,Kw], [ 'i0', # Nid f'i1*{CpG}+i3', # Gid f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid @@ -835,7 +777,7 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 'i9' ]) yy = xx*ww - y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [ + y = yy.reindex_reduce('add', [N, oc, od, oh, ow], [ 'i0', f'i1*{oc//G}+i2', 'i4', @@ -906,6 +848,45 @@ class ConvTranspose(Module): y = y + b return y +class ConvTranspose3d(Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ + padding=0, output_padding=0, groups=1, bias=True, dilation=1): + self.in_channels = in_channels + self.out_channels = out_channels + + # added + self.dilation = dilation + self.group = groups + assert groups==1, "Group conv not supported yet." 
+ + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) + # added + self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding) + self.real_padding = ( + self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], + self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1], + self.dilation[2] * (self.kernel_size[2] - 1) - self.padding[2]) + self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding) + assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ + self.output_padding[1] < max(self.stride[1], self.dilation[1]) and \ + self.output_padding[2] < max(self.stride[2], self.dilation[2]), \ + "output padding must be smaller than max(stride, dilation)" + + self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float") + if bias: + fan=1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound) + else: + self.bias = None + + def execute(self, x): + return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation) + def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): x = input N,C,H,W = x.shape @@ -944,6 +925,47 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding assert not bias, "Bias should be none or jittor var" return y +def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + x = input + N,C,D,H,W = x.shape + i,o,d,h,w = weight.shape + assert C==i + assert groups==1, "Group conv not supported yet." 
+    stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+    dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
+    # added
+    padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
+    output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
+    assert output_padding[0] < max(stride[0], dilation[0]) and \
+        output_padding[1] < max(stride[1], dilation[1]) and \
+        output_padding[2] < max(stride[2], dilation[2]), \
+        "output padding must be smaller than max(stride, dilation)"
+
+    stride_d, stride_h, stride_w = stride
+    padding_d, padding_h, padding_w = padding
+    dilation_d, dilation_h, dilation_w = dilation
+
+    d_out = (D-1) * stride_d + output_padding[0] - 2*padding_d + 1 + (d-1)*dilation_d
+    h_out = (H-1) * stride_h + output_padding[1] - 2*padding_h + 1 + (h-1)*dilation_h
+    w_out = (W-1) * stride_w + output_padding[2] - 2*padding_w + 1 + (w-1)*dilation_w
+    out_shape = (N, o, d_out, h_out, w_out)
+    shape = (N, i, o, D, H, W, d, h, w)
+    xx = x.broadcast(shape, (2, 6, 7, 8)) # o,d,h,w
+    ww = weight.broadcast(shape, (0, 3, 4, 5)) # N,D,H,W
+    y = (ww*xx).reindex_reduce("add", out_shape, [
+        'i0', # N
+        'i2', # o
+        f'i3*{stride_d}-{padding_d}+i6*{dilation_d}', # Did+Kdid
+        f'i4*{stride_h}-{padding_h}+i7*{dilation_h}', # Hid+Khid
+        f'i5*{stride_w}-{padding_w}+i8*{dilation_w}', # Wid+KWid
+    ])
+    if isinstance(bias, jt.Var):
+        b = bias.broadcast(y.shape, [0,2,3,4])
+        y = y + b
+    else:
+        assert not bias, "Bias should be none or jittor var"
+    return y
+
 conv_transpose2d = conv_transpose
 
 def pad(x,padding, mode='constant', value=0):
@@ -1286,7 +1308,7 @@ def linspace_from_neg_one(grid,num_steps,align_corners):
     return jt.array(ra,dtype=grid.dtype)
 
 def make_base_grid_4D(theta,N,C,H,W,align_corners):
-    base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype);
+    base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype)
     base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
     base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
     base_grid[...,-1] = 1
diff --git a/python/jittor/src/misc/nano_vector.h b/python/jittor/src/misc/nano_vector.h
index 86115cc6..ca814abf 100644
--- a/python/jittor/src/misc/nano_vector.h
+++ b/python/jittor/src/misc/nano_vector.h
@@ -238,6 +238,21 @@ struct NanoVector {
             v[i] = at(i);
         return v;
     }
+
+    inline void _unpack(int i) {
+        return;
+    }
+
+    template <class... Args>
+    void _unpack(int i, int& x, Args&&... args) {
+        x = this->operator[](i);
+        _unpack(i+1, std::forward<Args>(args)...);
+    }
+
+    template <class... Args>
+    void unpack(Args&&... args) {
+        _unpack(0, std::forward<Args>(args)...);
+    }
 };
diff --git a/python/jittor/test/test_cudnn_op.py b/python/jittor/test/test_cudnn_op.py
index 03187273..00438b85 100644
--- a/python/jittor/test/test_cudnn_op.py
+++ b/python/jittor/test/test_cudnn_op.py
@@ -128,7 +128,55 @@ class TestCudnnConvOp(unittest.TestCase):
         check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1)
         check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
         check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
-    
+
+    def test_conv3d(self):
+        def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
+            with jt.flag_scope(use_cuda=1):
+                x = jt.random(xshape)
+                w = jt.random(wshape)
+                y = jt.cudnn.ops.cudnn_conv3d(x, w, *stride, *padding, *dilation, group)
+                masky = jt.rand_like(y)
+                dx, dw = jt.grad(masky*y, [x, w])
+
+            y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
+            dx2, dw2 = jt.grad(masky*y2, [x, w])
+            np.testing.assert_allclose(y.data, y2.data)
+            np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-5, atol=1e-3)
+            np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
+
+        check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
+        check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
+        check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
+        check((2,4,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
+        check((2,4,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
+        check((2,4,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
+        check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3))
+
+    def test_conv_transpose3d(self):
+        def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
+            with jt.flag_scope(use_cuda=1):
+                x = jt.random(xshape)
+                w = jt.random(wshape)
+
+            y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
+
+            with jt.flag_scope(use_cuda=1):
+                y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
+                masky = jt.rand_like(y)
+                dx, dw = jt.grad(masky*y, [x, w])
+
+            dx2, dw2 = jt.grad(masky*y2, [x, w])
+            np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
+            np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
+            np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
+
+        check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
+        check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
+        check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
+        check((2,5,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
+        check((2,5,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
+        check((2,5,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
+        check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3))
 
 if __name__ == "__main__":
     unittest.main()
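For reference, a minimal usage sketch of the 3D convolution APIs this patch adds or rewires: jt.nn.Conv3d / jt.nn.conv3d now assume NCDHW layout, and jt.nn.ConvTranspose3d / jt.nn.conv_transpose3d are new. The shapes and hyper-parameters below are illustrative, and a CUDA-enabled build is assumed when exercising the cuDNN ops tested above; without CUDA the same module calls use the reindex-based implementation.

import jittor as jt
from jittor import nn

jt.flags.use_cuda = 1   # assumption: CUDA/cuDNN build; the cudnn_conv3d ops are CUDA-only

x = jt.random((2, 4, 10, 10, 10))                 # NCDHW input: N=2, C=4, D=H=W=10
conv = nn.Conv3d(4, 5, kernel_size=3, stride=1, padding=1)
y = conv(x)                                       # same as nn.conv3d(x, conv.weight, conv.bias, ...)

# output_padding must stay below max(stride, dilation); groups other than 1 are not supported yet
deconv = nn.ConvTranspose3d(5, 4, kernel_size=3, stride=2, padding=1, output_padding=1)
z = deconv(y)                                     # same as nn.conv_transpose3d(y, deconv.weight, ...)
print(y.shape, z.shape)                           # (2,5,10,10,10) and (2,4,20,20,20)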