optimize conv3d

This commit is contained in:
Dun Liang 2021-06-10 15:11:03 +08:00
parent 77b293b6b8
commit b6fe53e984
11 changed files with 1093 additions and 77 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.22'
__version__ = '1.2.3.23'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -634,7 +634,7 @@ def compile_custom_ops(
if gen_name_ != "":
gen_name = gen_name_
if len(gen_name) > 100:
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
gen_name = gen_name[:80] + "___hash" + str(abs(hash(gen_name)))
includes = sorted(list(set(includes)))
includes = "".join(map(lambda x: f" -I'{x}' ", includes))

View File

@ -0,0 +1,288 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv3d_backward_w_op.h"
#include "cudnn_warper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
namespace jittor {
#pragma GCC diagnostic ignored "-Wunused-variable"
#ifndef JIT
CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
: x(x), dy(dy), kd(kd), kh(kh), kw(kw), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}
void CudnnConv3dBackwardWOp::infer_shape() {
ASSERTop(x->shape.size(),==,5);
ASSERTop(dy->shape.size(),==,5);
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
if (xformat == "ncdhw") {
x->shape.unpack(xn, xc, xd, xh, xw);
dy->shape.unpack(yn, yc, yd, yh, yw);
} else {
x->shape.unpack(xn, xd, xh, xw, xc);
dy->shape.unpack(yn, yd, yh, yw, yc);
}
wco = yc, wci = xc / groups;
wh = kh;
ww = kw;
wd = kd;
dw->set_shape(NanoVector(wco, wci, wd, wh, ww));
}
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << dy->dtype();
jk << _CS("][Tw:") << dw->dtype();
jk << ']';
}
static auto make_conv3d = get_op_info("cudnn_conv3d")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, string>();
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
if (xformat == "ncdhw") {
x->shape.unpack(xn, xc, xd, xh, xw);
dy->shape.unpack(yn, yc, yd, yh, yw);
} else {
x->shape.unpack(xn, xd, xh, xw, xc);
dy->shape.unpack(yn, yd, yh, yw, yc);
}
if (v_index == 0) {
return make_backwardx(dout, dy, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
} else {
return make_conv3d(x, dout, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
}
}
// unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
#else // JIT
#ifdef JIT_cuda
#pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dBackwardWOp::jit_run() {
auto w = dw;
auto y = dy;
cudnnHandle_t& handle_ = cudnn_handle;
cudnnTensorDescriptor_t cudnnIdesc;
cudnnFilterDescriptor_t cudnnFdesc;
cudnnTensorDescriptor_t cudnnOdesc;
cudnnConvolutionDescriptor_t cudnnConvDesc;
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
int sx[] = {0,0,0,0,1};
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
int strideX[5];
if (xformat == "ncdhw") {
x->shape.unpack(xn, xc, xd, xh, xw);
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
memcpy(strideX, tmp, sizeof(tmp));
} else {
x->shape.unpack(xn, xd, xh, xw, xc);
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
memcpy(strideX, tmp, sizeof(tmp));
}
int dimX[] = {xn, xc, xd, xh, xw};
// dimX: ncdhw
checkCudaErrors(cudnnSetTensorNdDescriptor(
cudnnIdesc, getDataType<Tx>(),
5, dimX, strideX
));
auto ws = w->shape;
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
// cudnn only support this two format
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
#define filterFormat_oihw CUDNN_TENSOR_NCHW
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
// dimW: KCRS(oihw)
checkCudaErrors(cudnnSetFilterNdDescriptor(
cudnnFdesc, getDataType<Tw>(),
// filterFormat_@WFORMAT, 5, dimW
filterFormat_oihw, 5, dimW
));
int padA[] = {paddingd, paddingh, paddingw};
int convstrideA[] = {strided, strideh, stridew};
int dilationA[] = {dilationd, dilationh, dilationw};
// difference between
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
// is the kernel rc order
// currently, No perf difference is observed between
// this two mode
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
cudnnConvDesc, 3,
padA, convstrideA, dilationA,
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
));
// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
int sy[] = {0,0,0,0,1};
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
int strideY[5];
if (xformat == "ncdhw") {
y->shape.unpack(yn, yc, yd, yh, yw);
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
memcpy(strideY, tmp, sizeof(tmp));
} else {
y->shape.unpack(yn, yd, yh, yw, yc);
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
memcpy(strideY, tmp, sizeof(tmp));
}
int dimY[] = {yn, yc, yd, yh, yw};
checkCudaErrors(cudnnSetTensorNdDescriptor(
cudnnOdesc, getDataType<Ty>(),
5, dimY, strideY
));
cudnnConvolutionBwdFilterAlgo_t algos[] = {
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
};
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
int perf_count;
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
cudnnConvolutionBwdFilterAlgo_t algo;
bool benchmark=true;
jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
auto iter = bwdw_algo_cache.find(jk.to_string());
if (iter!=bwdw_algo_cache.end()) algo = iter->second;
else {
if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false;
if (benchmark) {
size_t max_ws_size = 0;
for (int i = 0; i < num_algos; i++) {
size_t sz;
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
// continue if use too much workspace
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
}
size_t allocation;
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
handle_,
cudnnIdesc, x->ptr<Tx>(),
cudnnOdesc, y->ptr<Ty>(),
cudnnConvDesc,
cudnnFdesc, w->ptr<Tw>(),
num_algos,
&perf_count,
perf_results,
ws,
max_ws_size));
exe.temp_allocator->free(ws, max_ws_size, allocation);
} else {
checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
handle_,
cudnnIdesc,
cudnnOdesc,
cudnnConvDesc,
cudnnFdesc,
num_algos,
&perf_count,
perf_results));
}
int best_algo_idx=-1;
for (int i = 0; i < perf_count; i++)
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
best_algo_idx=i;
break;
}
ASSERT(best_algo_idx!=-1);
algo=perf_results[best_algo_idx].algo;
if (benchmark) {
bwdw_algo_cache[jk.to_string()] = algo;
if (bwdw_algo_cache.size()==max_cache_size)
LOGw << "backward w algorithm cache is full";
}
}
// TODO: warp work space
void *workSpace = 0;
size_t workSpaceSize;
checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc,
cudnnFdesc, algo, &workSpaceSize));
size_t allocation;
if (workSpaceSize > 0) {
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
}
float alpha=1, beta=0;
checkCudaErrors(cudnnConvolutionBackwardFilter(
handle_,
(void*)(&alpha),
cudnnIdesc, x->ptr<Tx>(),
cudnnOdesc, y->ptr<Ty>(),
cudnnConvDesc,
algo,
workSpace, workSpaceSize,
(void*)(&beta),
cudnnFdesc, w->ptr<Tw>())
);
if (workSpace)
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
}
#endif
#endif // JIT
} // jittor

View File

@ -0,0 +1,28 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct CudnnConv3dBackwardWOp : Op {
Var* x, * dy, * dw;
int kd, kh, kw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
string xformat;
CudnnConv3dBackwardWOp(Var* x, Var* y, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
const char* name() const override { return "cudnn_conv3d_backward_w"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,279 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv3d_backward_x_op.h"
#include "cudnn_warper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
namespace jittor {
#pragma GCC diagnostic ignored "-Wunused-variable"
#ifndef JIT
CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
: w(w), dy(dy), xd(depth), xh(height), xw(width), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
}
void CudnnConv3dBackwardXOp::infer_shape() {
ASSERTop(w->shape.size(),==,5);
ASSERTop(dy->shape.size(),==,5);
int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
w->shape.unpack(wco, wci, wd, wh, ww);
if (xformat == "ncdhw")
dy->shape.unpack(yn, yc, yd, yh, yw);
else
dy->shape.unpack(yn, yd, yh, yw, yc);
xn = yn, xc = wci * groups;
if (xformat == "ncdhw")
dx->set_shape(NanoVector(xn, xc, xd, xh, xw));
else
dx->set_shape(NanoVector(xn, xd, xh, xw, xc));
}
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << dx->dtype();
jk << _CS("][Ty:") << dy->dtype();
jk << _CS("][Tw:") << w->dtype();
jk << ']';
}
static auto make_conv3d = get_op_info("cudnn_conv3d")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, string>();
static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
w->shape.unpack(wco, wci, wd, wh, ww);
if (v_index == 0) {
return make_backwardw(dout, dy, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
} else {
return make_conv3d(dout, w, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
}
}
// unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
#else // JIT
#ifdef JIT_cuda
#pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dBackwardXOp::jit_run() {
auto x = dx;
auto y = dy;
cudnnHandle_t& handle_ = cudnn_handle;
cudnnTensorDescriptor_t cudnnIdesc;
cudnnFilterDescriptor_t cudnnFdesc;
cudnnTensorDescriptor_t cudnnOdesc;
cudnnConvolutionDescriptor_t cudnnConvDesc;
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
int sx[] = {0,0,0,0,1};
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
int strideX[5];
if (xformat == "ncdhw") {
x->shape.unpack(xn, xc, xd, xh, xw);
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
memcpy(strideX, tmp, sizeof(tmp));
} else {
x->shape.unpack(xn, xd, xh, xw, xc);
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
memcpy(strideX, tmp, sizeof(tmp));
}
int dimX[] = {xn, xc, xd, xh, xw};
// dimX: ncdhw
checkCudaErrors(cudnnSetTensorNdDescriptor(
cudnnIdesc, getDataType<Tx>(),
5, dimX, strideX
));
auto ws = w->shape;
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
// cudnn only support this two format
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
#define filterFormat_oihw CUDNN_TENSOR_NCHW
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
// dimW: KCRS(oihw)
checkCudaErrors(cudnnSetFilterNdDescriptor(
cudnnFdesc, getDataType<Tw>(),
// filterFormat_@WFORMAT, 5, dimW
filterFormat_oihw, 5, dimW
));
int padA[] = {paddingd, paddingh, paddingw};
int convstrideA[] = {strided, strideh, stridew};
int dilationA[] = {dilationd, dilationh, dilationw};
// difference between
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
// is the kernel rc order
// currently, No perf difference is observed between
// this two mode
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
cudnnConvDesc, 3,
padA, convstrideA, dilationA,
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
));
// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
int sy[] = {0,0,0,0,1};
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
int strideY[5];
if (xformat == "ncdhw") {
y->shape.unpack(yn, yc, yd, yh, yw);
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
memcpy(strideY, tmp, sizeof(tmp));
} else {
y->shape.unpack(yn, yd, yh, yw, yc);
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
memcpy(strideY, tmp, sizeof(tmp));
}
int dimY[] = {yn, yc, yd, yh, yw};
checkCudaErrors(cudnnSetTensorNdDescriptor(
cudnnOdesc, getDataType<Ty>(),
5, dimY, strideY
));
cudnnConvolutionBwdDataAlgo_t algos[] = {
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
};
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
int perf_count;
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
cudnnConvolutionBwdDataAlgo_t algo;
bool benchmark=true;
jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
auto iter = bwdx_algo_cache.find(jk.to_string());
if (iter!=bwdx_algo_cache.end()) algo = iter->second;
else {
if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false;
if (benchmark) {
size_t max_ws_size = 0;
for (int i = 0; i < num_algos; i++) {
size_t sz;
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
// continue if use too much workspace
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
}
size_t allocation;
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx(
handle_,
cudnnFdesc, w->ptr<Tw>(),
cudnnOdesc, y->ptr<Ty>(),
cudnnConvDesc,
cudnnIdesc, x->ptr<Tx>(),
num_algos,
&perf_count,
perf_results,
ws,
max_ws_size));
exe.temp_allocator->free(ws, max_ws_size, allocation);
} else {
checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7(
handle_,
cudnnFdesc,
cudnnOdesc,
cudnnConvDesc,
cudnnIdesc,
num_algos,
&perf_count,
perf_results));
}
int best_algo_idx=-1;
for (int i = 0; i < perf_count; i++)
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
best_algo_idx=i;
break;
}
ASSERT(best_algo_idx!=-1);
algo=perf_results[best_algo_idx].algo;
if (benchmark) {
bwdx_algo_cache[jk.to_string()] = algo;
if (bwdx_algo_cache.size()==max_cache_size)
LOGw << "backward x algorithm cache is full";
}
}
// TODO: warp work space
void *workSpace = 0;
size_t workSpaceSize;
checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize(
handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc,
cudnnIdesc, algo, &workSpaceSize));
size_t allocation;
if (workSpaceSize > 0) {
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
}
float alpha=1, beta=0;
checkCudaErrors(cudnnConvolutionBackwardData(
handle_,
(void*)(&alpha),
cudnnFdesc, w->ptr<Tw>(),
cudnnOdesc, y->ptr<Ty>(),
cudnnConvDesc,
algo,
workSpace, workSpaceSize,
(void*)(&beta),
cudnnIdesc, x->ptr<Tx>())
);
if (workSpace)
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
}
#endif
#endif // JIT
} // jittor

View File

@ -0,0 +1,28 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct CudnnConv3dBackwardXOp : Op {
Var* w, * dy, * dx;
int xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
string xformat;
CudnnConv3dBackwardXOp(Var* w, Var* y, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
const char* name() const override { return "cudnn_conv3d_backward_x"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,284 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "cudnn_conv3d_op.h"
#include "cudnn_warper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
namespace jittor {
#pragma GCC diagnostic ignored "-Wunused-variable"
#ifndef JIT
CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
: x(x), w(w), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
y = create_output(nullptr, dtype_infer(x->ns, w->ns));
}
void CudnnConv3dOp::infer_shape() {
ASSERTop(x->shape.size(),==,5);
ASSERTop(w->shape.size(),==,5);
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
if (xformat == "ncdhw")
x->shape.unpack(xn, xc, xd, xh, xw);
else
x->shape.unpack(xn, xd, xh, xw, xc);
w->shape.unpack(wco, wci, wd, wh, ww);
ASSERTop(wci * groups,==,xc);
yn = xn, yc = wco;
yd = (xd+paddingd*2-wd*dilationd+dilationd-1)/strided+1;
yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1;
yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1;
if (xformat == "ncdhw")
y->set_shape(NanoVector(yn, yc, yd, yh, yw));
else
y->set_shape(NanoVector(yn, yd, yh, yw, yc));
}
void CudnnConv3dOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][Tw:") << w->dtype();
jk << ']';
}
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
if (xformat == "ncdhw")
x->shape.unpack(xn, xc, xd, xh, xw);
else
x->shape.unpack(xn, xd, xh, xw, xc);
w->shape.unpack(wco, wci, wd, wh, ww);
if (v_index == 0) {
return make_backwardx(w, dout, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
} else {
return make_backwardw(x, dout, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
}
}
// unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
#else // JIT
#ifdef JIT_cuda
#pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dOp::jit_run() {
cudnnHandle_t& handle_ = cudnn_handle;
cudnnTensorDescriptor_t cudnnIdesc;
cudnnFilterDescriptor_t cudnnFdesc;
cudnnTensorDescriptor_t cudnnOdesc;
cudnnConvolutionDescriptor_t cudnnConvDesc;
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
int sx[] = {0,0,0,0,1};
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
int strideX[5];
if (xformat == "ncdhw") {
x->shape.unpack(xn, xc, xd, xh, xw);
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
memcpy(strideX, tmp, sizeof(tmp));
} else {
x->shape.unpack(xn, xd, xh, xw, xc);
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
memcpy(strideX, tmp, sizeof(tmp));
}
int dimX[] = {xn, xc, xd, xh, xw};
// dimX: ncdhw
checkCudaErrors(cudnnSetTensorNdDescriptor(
cudnnIdesc, getDataType<Tx>(),
5, dimX, strideX
));
auto ws = w->shape;
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
// cudnn only support this two format
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
#define filterFormat_oihw CUDNN_TENSOR_NCHW
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
// dimW: KCRS(oihw)
checkCudaErrors(cudnnSetFilterNdDescriptor(
cudnnFdesc, getDataType<Tw>(),
// filterFormat_@WFORMAT, 5, dimW
filterFormat_oihw, 5, dimW
));
int padA[] = {paddingd, paddingh, paddingw};
int convstrideA[] = {strided, strideh, stridew};
int dilationA[] = {dilationd, dilationh, dilationw};
// difference between
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
// is the kernel rc order
// currently, No perf difference is observed between
// this two mode
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
cudnnConvDesc, 3,
padA, convstrideA, dilationA,
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
));
// using tensor core
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
int sy[] = {0,0,0,0,1};
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
int strideY[5];
if (xformat == "ncdhw") {
y->shape.unpack(yn, yc, yd, yh, yw);
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
memcpy(strideY, tmp, sizeof(tmp));
} else {
y->shape.unpack(yn, yd, yh, yw, yc);
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
memcpy(strideY, tmp, sizeof(tmp));
}
int dimY[] = {yn, yc, yd, yh, yw};
checkCudaErrors(cudnnSetTensorNdDescriptor(
cudnnOdesc, getDataType<Ty>(),
5, dimY, strideY
));
cudnnConvolutionFwdAlgo_t algos[] = {
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
};
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
int perf_count;
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
cudnnConvolutionFwdAlgo_t algo;
bool benchmark=true;
jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
auto iter = fwd_algo_cache.find(jk.to_string());
if (iter!=fwd_algo_cache.end()) algo = iter->second;
else {
if (fwd_algo_cache.size()>=max_cache_size) benchmark = false;
if (benchmark) {
size_t max_ws_size = 0;
for (int i = 0; i < num_algos; i++) {
size_t sz;
cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
cudnnOdesc, algos[i], &sz);
// continue if use too much workspace
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
}
size_t allocation;
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx(
handle_,
cudnnIdesc, x->ptr<Tx>(),
cudnnFdesc, w->ptr<Tw>(),
cudnnConvDesc,
cudnnOdesc, y->ptr<Ty>(),
num_algos,
&perf_count,
perf_results,
ws,
max_ws_size));
exe.temp_allocator->free(ws, max_ws_size, allocation);
} else {
checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7(
handle_,
cudnnIdesc,
cudnnFdesc,
cudnnConvDesc,
cudnnOdesc,
num_algos,
&perf_count,
perf_results));
}
int best_algo_idx=-1;
for (int i = 0; i < perf_count; i++)
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
best_algo_idx=i;
break;
}
ASSERT(best_algo_idx!=-1);
algo=perf_results[best_algo_idx].algo;
if (benchmark) {
fwd_algo_cache[jk.to_string()] = algo;
if (fwd_algo_cache.size()==max_cache_size)
LOGw << "forward_ algorithm cache is full";
}
}
// TODO: warp work space
void *workSpace = 0;
size_t workSpaceSize;
checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize(
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
cudnnOdesc, algo, &workSpaceSize) );
size_t allocation;
if (workSpaceSize > 0) {
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
}
float alpha=1, beta=0;
checkCudaErrors(cudnnConvolutionForward(
handle_,
(void*)(&alpha),
cudnnIdesc, x->ptr<Tx>(),
cudnnFdesc, w->ptr<Tw>(),
cudnnConvDesc,
algo,
workSpace, workSpaceSize,
(void*)(&beta),
cudnnOdesc, y->ptr<Ty>())
);
if (workSpace)
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
}
#endif
#endif // JIT
} // jittor

View File

@ -0,0 +1,24 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct CudnnConv3dOp : Op {
Var* x, * w, * y;
int strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
string xformat;
CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw");
const char* name() const override { return "cudnn_conv3d"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor

View File

@ -21,6 +21,7 @@ from collections import OrderedDict
from jittor.pool import *
from jittor.optim import *
from jittor.misc import _pair, _triple
from jittor_utils import LOG
def matmul_transpose(a, b):
@ -639,7 +640,6 @@ class Conv1d(Module):
class Conv3d(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
LOG.w("Optimizations of Conv3d are working in progress, it maybe slow currently.")
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
@ -665,65 +665,7 @@ class Conv3d(Module):
self.bias = None
def execute(self, x):
if self.groups == 1:
N,C,H,W,D = x.shape
Kh, Kw, Kd = self.kernel_size
assert C==self.in_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1
xx = x.reindex([N,self.out_channels,C,oh,ow,od,Kh,Kw,Kd], [
'i0', # Nid
'i2', # Cid
f'i3*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
f'i4*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
f'i5*{self.stride[2]}-{self.padding[2]}+i8*{self.dilation[2]}', # Did+KDid
])
ww = self.weight.broadcast(xx.shape, [0,3,4,5])
yy = xx*ww
y = yy.sum([2,6,7,8]) # Kc, Kh, Kw, Kd
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3,4])
y = y + b
return y
else:
N,C,H,W,D = x.shape
Kh, Kw, Kd = self.kernel_size
G = self.groups
CpG = C // G # channels per group
assert C==self.in_channels
oc = self.out_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1
xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [
'i0', # Nid
f'i1*{CpG}+i3', # Gid
f'i4*{self.stride[0]}-{self.padding[0]}+i7*{self.dilation[0]}', # Hid+Khid
f'i5*{self.stride[1]}-{self.padding[1]}+i8*{self.dilation[1]}', # Wid+KWid
f'i6*{self.stride[2]}-{self.padding[2]}+i9*{self.dilation[2]}', # Did+KDid
])
# w: [oc, CpG, Kh, Kw, Kd]
ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, od, Kh, Kw, Kd], [
f'i1*{oc//G}+i2',
'i3',
'i7',
'i8',
'i9'
])
ww.compile_options = xx.compile_options = {"G":G,"C":C}
yy = xx*ww
y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [
'i0',
f'i1*{oc//G}+i2',
'i4',
'i5',
'i6'
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3,4])
y = y + b
return y
return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
padding = _pair(padding)
@ -790,12 +732,12 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
out_channels = weight.shape[0]
if groups == 1:
N,C,H,W,D = x.shape
Kh, Kw, Kd = weight.shape[-3:]
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1
xx = x.reindex([N,out_channels,C,oh,ow,od,Kh,Kw,Kd], [
N,C,D,H,W = x.shape
Kd, Kh, Kw = weight.shape[-3:]
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
xx = x.reindex([N,out_channels,C,od,oh,ow,Kd,Kh,Kw], [
'i0', # Nid
'i2', # Cid
f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
@ -810,15 +752,15 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
y = y + b
return y
else:
N,C,H,W,D = x.shape
Kh, Kw, Kd = weight.shape[-3:]
N,C,D,H,W = x.shape
Kd, Kh, Kw = weight.shape[-3:]
G = groups
CpG = C // G # channels per group
oc = out_channels
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1
xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
xx = x.reindex([N,G,oc//G,CpG,od,oh,ow,Kd,Kh,Kw], [
'i0', # Nid
f'i1*{CpG}+i3', # Gid
f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid
@ -835,7 +777,7 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
'i9'
])
yy = xx*ww
y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [
y = yy.reindex_reduce('add', [N, oc, od, oh, ow], [
'i0',
f'i1*{oc//G}+i2',
'i4',
@ -906,6 +848,45 @@ class ConvTranspose(Module):
y = y + b
return y
class ConvTranspose3d(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
padding=0, output_padding=0, groups=1, bias=True, dilation=1):
self.in_channels = in_channels
self.out_channels = out_channels
# added
self.dilation = dilation
self.group = groups
assert groups==1, "Group conv not supported yet."
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
# added
self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
self.real_padding = (
self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1],
self.dilation[2] * (self.kernel_size[2] - 1) - self.padding[2])
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
self.output_padding[1] < max(self.stride[1], self.dilation[1]) and \
self.output_padding[2] < max(self.stride[2], self.dilation[2]), \
"output padding must be smaller than max(stride, dilation)"
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
if bias:
fan=1
for i in self.weight.shape[1:]:
fan *= i
bound = 1 / math.sqrt(fan)
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
else:
self.bias = None
def execute(self, x):
return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
x = input
N,C,H,W = x.shape
@ -944,6 +925,47 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding
assert not bias, "Bias should be none or jittor var"
return y
def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
x = input
N,C,D,H,W = x.shape
i,o,d,h,w = weight.shape
assert C==i
assert groups==1, "Group conv not supported yet."
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
# added
padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
assert output_padding[0] < max(stride[0], dilation[0]) and \
output_padding[1] < max(stride[1], dilation[1]) and \
output_padding[2] < max(stride[2], dilation[2]), \
"output padding must be smaller than max(stride, dilation)"
stride_d, stride_h, stride_w = stride
padding_d, padding_h, padding_w = padding
dilation_d, dilation_h, dilation_w = dilation
d_out = (D-1) * stride_d + output_padding[0] - 2*padding_d + 1 + (d-1)*dilation_d
h_out = (H-1) * stride_h + output_padding[1] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + output_padding[2] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, d_out, h_out, w_out)
shape = (N, i, o, D, H, W, d, h, w)
xx = x.broadcast(shape, (2, 6, 7, 8)) # i,h,w
ww = weight.broadcast(shape, (0, 3, 4, 5)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_d}-{padding_d}+i6*{dilation_d}', # Did+Kdid
f'i4*{stride_h}-{padding_h}+i7*{dilation_h}', # Hid+Khid
f'i5*{stride_w}-{padding_w}+i8*{dilation_w}', # Wid+KWid
])
if isinstance(bias, jt.Var):
b = bias.broadcast(y.shape, [0,2,3,4])
y = y + b
else:
assert not bias, "Bias should be none or jittor var"
return y
conv_transpose2d = conv_transpose
def pad(x,padding, mode='constant', value=0):
@ -1286,7 +1308,7 @@ def linspace_from_neg_one(grid,num_steps,align_corners):
return jt.array(ra,dtype=grid.dtype)
def make_base_grid_4D(theta,N,C,H,W,align_corners):
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype);
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype)
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
base_grid[...,-1] = 1

View File

@ -238,6 +238,21 @@ struct NanoVector {
v[i] = at(i);
return v;
}
inline void _unpack(int i) {
return;
}
template<class... Args>
void _unpack(int i, int& x, Args&&... args) {
x = this->operator[](i);
_unpack(i+1, std::forward<Args>(args)...);
}
template<class... Args>
void unpack(Args&&... args) {
_unpack(0, std::forward<Args>(args)...);
}
};

View File

@ -129,6 +129,54 @@ class TestCudnnConvOp(unittest.TestCase):
check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
def test_conv3d(self):
def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
with jt.flag_scope(use_cuda=1):
x = jt.random(xshape)
w = jt.random(wshape)
y = jt.cudnn.ops.cudnn_conv3d(x, w, *stride, *padding, *dilation, group)
masky = jt.rand_like(y)
dx, dw = jt.grad(masky*y, [x, w])
y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
dx2, dw2 = jt.grad(masky*y2, [x, w])
np.testing.assert_allclose(y.data, y2.data)
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-5, atol=1e-3)
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
check((2,4,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
check((2,4,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
check((2,4,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3))
def test_conv_transpose3d(self):
def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
with jt.flag_scope(use_cuda=1):
x = jt.random(xshape)
w = jt.random(wshape)
y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
with jt.flag_scope(use_cuda=1):
y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
masky = jt.rand_like(y)
dx, dw = jt.grad(masky*y, [x, w])
dx2, dw2 = jt.grad(masky*y2, [x, w])
np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
check((2,5,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
check((2,5,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
check((2,5,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3))
if __name__ == "__main__":
unittest.main()