mirror of https://github.com/Jittor/Jittor
optimize conv3d
This commit is contained in:
parent
77b293b6b8
commit
b6fe53e984
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.22'
|
__version__ = '1.2.3.23'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -634,7 +634,7 @@ def compile_custom_ops(
|
||||||
if gen_name_ != "":
|
if gen_name_ != "":
|
||||||
gen_name = gen_name_
|
gen_name = gen_name_
|
||||||
if len(gen_name) > 100:
|
if len(gen_name) > 100:
|
||||||
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
|
gen_name = gen_name[:80] + "___hash" + str(abs(hash(gen_name)))
|
||||||
|
|
||||||
includes = sorted(list(set(includes)))
|
includes = sorted(list(set(includes)))
|
||||||
includes = "".join(map(lambda x: f" -I'{x}' ", includes))
|
includes = "".join(map(lambda x: f" -I'{x}' ", includes))
|
||||||
|
|
|
@ -0,0 +1,288 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers:
|
||||||
|
// Dun Liang <randonlang@gmail.com>
|
||||||
|
// Guowei Yang <471184555@qq.com>
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "mem/allocator.h"
|
||||||
|
#include "var.h"
|
||||||
|
#include "cudnn_conv3d_backward_w_op.h"
|
||||||
|
#include "cudnn_warper.h"
|
||||||
|
#include "executor.h"
|
||||||
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||||
|
|
||||||
|
#ifndef JIT
|
||||||
|
|
||||||
|
// Construct the 3D-convolution weight-gradient op.
//   x, dy                 : input vars (kept as raw pointers on the op)
//   kd, kh, kw            : filter extent in depth / height / width
//   stride*, padding*,
//   dilation*             : per-axis convolution parameters (d, h, w)
//   groups                : convolution group count
//   xformat               : layout tag of x and dy; "ncdhw" is the only value
//                           the op compares against explicitly
// The op is flagged CUDA-only and its single output `dw` gets the dtype
// inferred from dy and x.
CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(
    Var* x, Var* dy,
    int kd, int kh, int kw,
    int strided, int strideh, int stridew,
    int paddingd, int paddingh, int paddingw,
    int dilationd, int dilationh, int dilationw,
    int groups, string xformat)
    : x(x), dy(dy),
      kd(kd), kh(kh), kw(kw),
      strided(strided), strideh(strideh), stridew(stridew),
      paddingd(paddingd), paddingh(paddingh), paddingw(paddingw),
      dilationd(dilationd), dilationh(dilationh), dilationw(dilationw),
      groups(groups),
      xformat(move(xformat))
{
    // runs on GPU only
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    // shape is assigned later in infer_shape()
    dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}
|
||||||
|
|
||||||
|
void CudnnConv3dBackwardWOp::infer_shape() {
|
||||||
|
ASSERTop(x->shape.size(),==,5);
|
||||||
|
ASSERTop(dy->shape.size(),==,5);
|
||||||
|
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||||
|
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||||
|
dy->shape.unpack(yn, yc, yd, yh, yw);
|
||||||
|
} else {
|
||||||
|
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||||
|
dy->shape.unpack(yn, yd, yh, yw, yc);
|
||||||
|
}
|
||||||
|
wco = yc, wci = xc / groups;
|
||||||
|
wh = kh;
|
||||||
|
ww = kw;
|
||||||
|
wd = kd;
|
||||||
|
dw->set_shape(NanoVector(wco, wci, wd, wh, ww));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the element types of x, dy and dw to the JIT key as
// "[Tx:..][Ty:..][Tw:..]".
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
    jk << _CS("[Tx:") << x->dtype()
       << _CS("][Ty:") << dy->dtype()
       << _CS("][Tw:") << dw->dtype()
       << ']';
}
|
||||||
|
|
||||||
|
// Constructor handles for the sibling ops, looked up by registered op name.
// They are used by grad() below to build the backward graph.

// cudnn_conv3d(x, w, stride d/h/w, padding d/h/w, dilation d/h/w, groups, xformat)
static auto make_conv3d =
    get_op_info("cudnn_conv3d").get_constructor<
        VarPtr, Var*, Var*,
        int, int, int,
        int, int, int,
        int, int, int,
        int, string>();

// cudnn_conv3d_backward_x(w, dy, depth, height, width,
//                         stride d/h/w, padding d/h/w, dilation d/h/w, groups, xformat)
static auto make_backwardx =
    get_op_info("cudnn_conv3d_backward_x").get_constructor<
        VarPtr, Var*, Var*,
        int, int, int,
        int, int, int,
        int, int, int,
        int, int, int,
        int, string>();
|
||||||
|
|
||||||
|
|
||||||
|
// Gradient of dw = backward_w(x, dy) with respect to one of its inputs.
//   v_index == 0 (x) : backward_x with dout in the filter slot, sized to x's d/h/w
//   otherwise   (dy) : forward conv3d of x with dout in the filter slot
// `out` and `v` are not used.
VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    int in_n, in_c, in_d, in_h, in_w;
    int out_n, out_c, out_d, out_h, out_w;

    const bool channels_first = (xformat == "ncdhw");
    if (channels_first) {
        x->shape.unpack(in_n, in_c, in_d, in_h, in_w);
        dy->shape.unpack(out_n, out_c, out_d, out_h, out_w);
    } else {
        x->shape.unpack(in_n, in_d, in_h, in_w, in_c);
        dy->shape.unpack(out_n, out_d, out_h, out_w, out_c);
    }

    if (v_index != 0) {
        return make_conv3d(
            x, dout,
            strided, strideh, stridew,
            paddingd, paddingh, paddingw,
            dilationd, dilationh, dilationw,
            groups, xformat);
    }
    return make_backwardx(
        dout, dy,
        in_d, in_h, in_w,
        strided, strideh, stridew,
        paddingd, paddingh, paddingw,
        dilationd, dilationh, dilationw,
        groups, xformat);
}
|
||||||
|
|
||||||
|
// unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||||
|
|
||||||
|
#else // JIT
|
||||||
|
#ifdef JIT_cuda
|
||||||
|
|
||||||
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
|
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||||
|
|
||||||
|
// Map a C++ element type to the matching cudnnDataType_t.
// Only half1 and float are specialised; other types have no definition.
template <typename T_ELEM>
__inline__ cudnnDataType_t getDataType();

template <>
__inline__ cudnnDataType_t getDataType<half1>() {
    return CUDNN_DATA_HALF;
}

template <>
__inline__ cudnnDataType_t getDataType<float>() {
    return CUDNN_DATA_FLOAT;
}
|
||||||
|
|
||||||
|
void CudnnConv3dBackwardWOp::jit_run() {
|
||||||
|
auto w = dw;
|
||||||
|
auto y = dy;
|
||||||
|
cudnnHandle_t& handle_ = cudnn_handle;
|
||||||
|
|
||||||
|
cudnnTensorDescriptor_t cudnnIdesc;
|
||||||
|
cudnnFilterDescriptor_t cudnnFdesc;
|
||||||
|
cudnnTensorDescriptor_t cudnnOdesc;
|
||||||
|
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||||
|
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
|
||||||
|
|
||||||
|
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||||
|
int sx[] = {0,0,0,0,1};
|
||||||
|
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
|
||||||
|
int strideX[5];
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||||
|
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
|
||||||
|
memcpy(strideX, tmp, sizeof(tmp));
|
||||||
|
} else {
|
||||||
|
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||||
|
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
|
||||||
|
memcpy(strideX, tmp, sizeof(tmp));
|
||||||
|
}
|
||||||
|
int dimX[] = {xn, xc, xd, xh, xw};
|
||||||
|
// dimX: ncdhw
|
||||||
|
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||||
|
cudnnIdesc, getDataType<Tx>(),
|
||||||
|
5, dimX, strideX
|
||||||
|
));
|
||||||
|
|
||||||
|
auto ws = w->shape;
|
||||||
|
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
|
||||||
|
// cudnn only support this two format
|
||||||
|
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||||
|
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||||
|
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||||
|
|
||||||
|
// dimW: KCRS(oihw)
|
||||||
|
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||||
|
cudnnFdesc, getDataType<Tw>(),
|
||||||
|
// filterFormat_@WFORMAT, 5, dimW
|
||||||
|
filterFormat_oihw, 5, dimW
|
||||||
|
));
|
||||||
|
|
||||||
|
int padA[] = {paddingd, paddingh, paddingw};
|
||||||
|
int convstrideA[] = {strided, strideh, stridew};
|
||||||
|
int dilationA[] = {dilationd, dilationh, dilationw};
|
||||||
|
// difference between
|
||||||
|
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||||
|
// is the kernel rc order
|
||||||
|
// currently, No perf difference is observed between
|
||||||
|
// this two mode
|
||||||
|
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||||
|
cudnnConvDesc, 3,
|
||||||
|
padA, convstrideA, dilationA,
|
||||||
|
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||||
|
));
|
||||||
|
|
||||||
|
// using tensor core
|
||||||
|
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||||
|
|
||||||
|
|
||||||
|
int sy[] = {0,0,0,0,1};
|
||||||
|
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
|
||||||
|
int strideY[5];
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
y->shape.unpack(yn, yc, yd, yh, yw);
|
||||||
|
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
|
||||||
|
memcpy(strideY, tmp, sizeof(tmp));
|
||||||
|
} else {
|
||||||
|
y->shape.unpack(yn, yd, yh, yw, yc);
|
||||||
|
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
|
||||||
|
memcpy(strideY, tmp, sizeof(tmp));
|
||||||
|
}
|
||||||
|
int dimY[] = {yn, yc, yd, yh, yw};
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||||
|
cudnnOdesc, getDataType<Ty>(),
|
||||||
|
5, dimY, strideY
|
||||||
|
));
|
||||||
|
|
||||||
|
cudnnConvolutionBwdFilterAlgo_t algos[] = {
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
|
||||||
|
};
|
||||||
|
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||||
|
int perf_count;
|
||||||
|
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
||||||
|
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||||
|
bool benchmark=true;
|
||||||
|
|
||||||
|
jk.clear();
|
||||||
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||||
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||||
|
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
|
||||||
|
auto iter = bwdw_algo_cache.find(jk.to_string());
|
||||||
|
|
||||||
|
if (iter!=bwdw_algo_cache.end()) algo = iter->second;
|
||||||
|
else {
|
||||||
|
if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||||
|
if (benchmark) {
|
||||||
|
size_t max_ws_size = 0;
|
||||||
|
for (int i = 0; i < num_algos; i++) {
|
||||||
|
size_t sz;
|
||||||
|
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
|
||||||
|
// continue if use too much workspace
|
||||||
|
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||||
|
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||||
|
}
|
||||||
|
size_t allocation;
|
||||||
|
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
|
||||||
|
checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
||||||
|
handle_,
|
||||||
|
cudnnIdesc, x->ptr<Tx>(),
|
||||||
|
cudnnOdesc, y->ptr<Ty>(),
|
||||||
|
cudnnConvDesc,
|
||||||
|
cudnnFdesc, w->ptr<Tw>(),
|
||||||
|
num_algos,
|
||||||
|
&perf_count,
|
||||||
|
perf_results,
|
||||||
|
ws,
|
||||||
|
max_ws_size));
|
||||||
|
exe.temp_allocator->free(ws, max_ws_size, allocation);
|
||||||
|
} else {
|
||||||
|
checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
||||||
|
handle_,
|
||||||
|
cudnnIdesc,
|
||||||
|
cudnnOdesc,
|
||||||
|
cudnnConvDesc,
|
||||||
|
cudnnFdesc,
|
||||||
|
num_algos,
|
||||||
|
&perf_count,
|
||||||
|
perf_results));
|
||||||
|
}
|
||||||
|
int best_algo_idx=-1;
|
||||||
|
for (int i = 0; i < perf_count; i++)
|
||||||
|
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||||
|
best_algo_idx=i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
ASSERT(best_algo_idx!=-1);
|
||||||
|
algo=perf_results[best_algo_idx].algo;
|
||||||
|
if (benchmark) {
|
||||||
|
bwdw_algo_cache[jk.to_string()] = algo;
|
||||||
|
if (bwdw_algo_cache.size()==max_cache_size)
|
||||||
|
LOGw << "backward w algorithm cache is full";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: warp work space
|
||||||
|
void *workSpace = 0;
|
||||||
|
size_t workSpaceSize;
|
||||||
|
checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
||||||
|
handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc,
|
||||||
|
cudnnFdesc, algo, &workSpaceSize));
|
||||||
|
size_t allocation;
|
||||||
|
if (workSpaceSize > 0) {
|
||||||
|
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
|
||||||
|
}
|
||||||
|
float alpha=1, beta=0;
|
||||||
|
checkCudaErrors(cudnnConvolutionBackwardFilter(
|
||||||
|
handle_,
|
||||||
|
(void*)(&alpha),
|
||||||
|
cudnnIdesc, x->ptr<Tx>(),
|
||||||
|
cudnnOdesc, y->ptr<Ty>(),
|
||||||
|
cudnnConvDesc,
|
||||||
|
algo,
|
||||||
|
workSpace, workSpaceSize,
|
||||||
|
(void*)(&beta),
|
||||||
|
cudnnFdesc, w->ptr<Tw>())
|
||||||
|
);
|
||||||
|
if (workSpace)
|
||||||
|
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // JIT
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,28 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers:
|
||||||
|
// Dun Liang <randonlang@gmail.com>
|
||||||
|
// Guowei Yang <471184555@qq.com>
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#pragma once
|
||||||
|
#include "op.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
// Op computing the weight gradient of a 3D convolution through cuDNN.
// Registered under the name "cudnn_conv3d_backward_w".
struct CudnnConv3dBackwardWOp : Op {
|
||||||
|
// x, dy: inputs given to the constructor; dw: output created by the constructor.
Var* x, * dy, * dw;
|
||||||
|
// kd/kh/kw: filter extents used for dw's shape; the rest are the per-axis
// (d, h, w) stride, padding and dilation plus the group count.
int kd, kh, kw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
|
||||||
|
// Layout tag of x and dy; the implementation special-cases "ncdhw".
string xformat;
|
||||||
|
|
||||||
|
// The second parameter is named `y` here but is stored as `dy`.
CudnnConv3dBackwardWOp(Var* x, Var* y, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
|
||||||
|
|
||||||
|
const char* name() const override { return "cudnn_conv3d_backward_w"; }
|
||||||
|
// Builds the gradient var for input number v_index (0: x, otherwise: dy).
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||||
|
// Sets dw's shape from x, dy, groups and kd/kh/kw.
void infer_shape() override;
|
||||||
|
DECLARE_jit_run;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,279 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers:
|
||||||
|
// Dun Liang <randonlang@gmail.com>
|
||||||
|
// Guowei Yang <471184555@qq.com>
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "mem/allocator.h"
|
||||||
|
#include "var.h"
|
||||||
|
#include "cudnn_conv3d_backward_x_op.h"
|
||||||
|
#include "cudnn_warper.h"
|
||||||
|
#include "executor.h"
|
||||||
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||||
|
|
||||||
|
#ifndef JIT
|
||||||
|
|
||||||
|
// Construct the 3D-convolution input-gradient op.
//   w, dy                 : input vars (filter and output gradient)
//   depth, height, width  : spatial extent of the gradient to produce,
//                           stored as xd / xh / xw
//   stride*, padding*,
//   dilation*             : per-axis convolution parameters (d, h, w)
//   groups                : convolution group count
//   xformat               : layout tag of dy and dx; "ncdhw" is the only
//                           value the op compares against explicitly
// The op is flagged CUDA-only and its single output `dx` gets the dtype
// inferred from dy and w.
CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(
    Var* w, Var* dy,
    int depth, int height, int width,
    int strided, int strideh, int stridew,
    int paddingd, int paddingh, int paddingw,
    int dilationd, int dilationh, int dilationw,
    int groups, string xformat)
    : w(w), dy(dy),
      xd(depth), xh(height), xw(width),
      strided(strided), strideh(strideh), stridew(stridew),
      paddingd(paddingd), paddingh(paddingh), paddingw(paddingw),
      dilationd(dilationd), dilationh(dilationh), dilationw(dilationw),
      groups(groups),
      xformat(move(xformat))
{
    // runs on GPU only
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    // shape is assigned later in infer_shape()
    dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
}
|
||||||
|
|
||||||
|
void CudnnConv3dBackwardXOp::infer_shape() {
|
||||||
|
ASSERTop(w->shape.size(),==,5);
|
||||||
|
ASSERTop(dy->shape.size(),==,5);
|
||||||
|
int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||||
|
w->shape.unpack(wco, wci, wd, wh, ww);
|
||||||
|
if (xformat == "ncdhw")
|
||||||
|
dy->shape.unpack(yn, yc, yd, yh, yw);
|
||||||
|
else
|
||||||
|
dy->shape.unpack(yn, yd, yh, yw, yc);
|
||||||
|
xn = yn, xc = wci * groups;
|
||||||
|
if (xformat == "ncdhw")
|
||||||
|
dx->set_shape(NanoVector(xn, xc, xd, xh, xw));
|
||||||
|
else
|
||||||
|
dx->set_shape(NanoVector(xn, xd, xh, xw, xc));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the element types of dx, dy and w to the JIT key as
// "[Tx:..][Ty:..][Tw:..]".
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
    jk << _CS("[Tx:") << dx->dtype()
       << _CS("][Ty:") << dy->dtype()
       << _CS("][Tw:") << w->dtype()
       << ']';
}
|
||||||
|
|
||||||
|
|
||||||
|
// Constructor handles for the sibling ops, looked up by registered op name.
// They are used by grad() below to build the backward graph.

// cudnn_conv3d(x, w, stride d/h/w, padding d/h/w, dilation d/h/w, groups, xformat)
static auto make_conv3d =
    get_op_info("cudnn_conv3d").get_constructor<
        VarPtr, Var*, Var*,
        int, int, int,
        int, int, int,
        int, int, int,
        int, string>();

// cudnn_conv3d_backward_w(x, dy, kd, kh, kw,
//                         stride d/h/w, padding d/h/w, dilation d/h/w, groups, xformat)
static auto make_backwardw =
    get_op_info("cudnn_conv3d_backward_w").get_constructor<
        VarPtr, Var*, Var*,
        int, int, int,
        int, int, int,
        int, int, int,
        int, int, int,
        int, string>();
|
||||||
|
|
||||||
|
|
||||||
|
// Gradient of dx = backward_x(w, dy) with respect to one of its inputs.
//   v_index == 0 (w) : backward_w with dout in the x slot, sized to w's d/h/w
//   otherwise   (dy) : forward conv3d of dout with w
// `out` and `v` are not used.
VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    int f_out, f_in, f_d, f_h, f_w;
    w->shape.unpack(f_out, f_in, f_d, f_h, f_w);

    if (v_index != 0) {
        return make_conv3d(
            dout, w,
            strided, strideh, stridew,
            paddingd, paddingh, paddingw,
            dilationd, dilationh, dilationw,
            groups, xformat);
    }
    return make_backwardw(
        dout, dy,
        f_d, f_h, f_w,
        strided, strideh, stridew,
        paddingd, paddingh, paddingw,
        dilationd, dilationh, dilationw,
        groups, xformat);
}
|
||||||
|
// unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||||
|
|
||||||
|
#else // JIT
|
||||||
|
#ifdef JIT_cuda
|
||||||
|
|
||||||
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
|
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||||
|
|
||||||
|
// Map a C++ element type to the matching cudnnDataType_t.
// Only half1 and float are specialised; other types have no definition.
template <typename T_ELEM>
__inline__ cudnnDataType_t getDataType();

template <>
__inline__ cudnnDataType_t getDataType<half1>() {
    return CUDNN_DATA_HALF;
}

template <>
__inline__ cudnnDataType_t getDataType<float>() {
    return CUDNN_DATA_FLOAT;
}
|
||||||
|
|
||||||
|
void CudnnConv3dBackwardXOp::jit_run() {
|
||||||
|
auto x = dx;
|
||||||
|
auto y = dy;
|
||||||
|
cudnnHandle_t& handle_ = cudnn_handle;
|
||||||
|
|
||||||
|
cudnnTensorDescriptor_t cudnnIdesc;
|
||||||
|
cudnnFilterDescriptor_t cudnnFdesc;
|
||||||
|
cudnnTensorDescriptor_t cudnnOdesc;
|
||||||
|
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||||
|
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
|
||||||
|
|
||||||
|
|
||||||
|
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||||
|
int sx[] = {0,0,0,0,1};
|
||||||
|
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
|
||||||
|
int strideX[5];
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||||
|
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
|
||||||
|
memcpy(strideX, tmp, sizeof(tmp));
|
||||||
|
} else {
|
||||||
|
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||||
|
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
|
||||||
|
memcpy(strideX, tmp, sizeof(tmp));
|
||||||
|
}
|
||||||
|
int dimX[] = {xn, xc, xd, xh, xw};
|
||||||
|
// dimX: ncdhw
|
||||||
|
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||||
|
cudnnIdesc, getDataType<Tx>(),
|
||||||
|
5, dimX, strideX
|
||||||
|
));
|
||||||
|
|
||||||
|
auto ws = w->shape;
|
||||||
|
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
|
||||||
|
// cudnn only support this two format
|
||||||
|
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||||
|
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||||
|
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||||
|
|
||||||
|
// dimW: KCRS(oihw)
|
||||||
|
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||||
|
cudnnFdesc, getDataType<Tw>(),
|
||||||
|
// filterFormat_@WFORMAT, 5, dimW
|
||||||
|
filterFormat_oihw, 5, dimW
|
||||||
|
));
|
||||||
|
|
||||||
|
int padA[] = {paddingd, paddingh, paddingw};
|
||||||
|
int convstrideA[] = {strided, strideh, stridew};
|
||||||
|
int dilationA[] = {dilationd, dilationh, dilationw};
|
||||||
|
// difference between
|
||||||
|
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||||
|
// is the kernel rc order
|
||||||
|
// currently, No perf difference is observed between
|
||||||
|
// this two mode
|
||||||
|
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||||
|
cudnnConvDesc, 3,
|
||||||
|
padA, convstrideA, dilationA,
|
||||||
|
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||||
|
));
|
||||||
|
|
||||||
|
// using tensor core
|
||||||
|
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||||
|
|
||||||
|
|
||||||
|
int sy[] = {0,0,0,0,1};
|
||||||
|
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
|
||||||
|
int strideY[5];
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
y->shape.unpack(yn, yc, yd, yh, yw);
|
||||||
|
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
|
||||||
|
memcpy(strideY, tmp, sizeof(tmp));
|
||||||
|
} else {
|
||||||
|
y->shape.unpack(yn, yd, yh, yw, yc);
|
||||||
|
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
|
||||||
|
memcpy(strideY, tmp, sizeof(tmp));
|
||||||
|
}
|
||||||
|
int dimY[] = {yn, yc, yd, yh, yw};
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||||
|
cudnnOdesc, getDataType<Ty>(),
|
||||||
|
5, dimY, strideY
|
||||||
|
));
|
||||||
|
|
||||||
|
cudnnConvolutionBwdDataAlgo_t algos[] = {
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
|
||||||
|
};
|
||||||
|
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||||
|
int perf_count;
|
||||||
|
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
||||||
|
cudnnConvolutionBwdDataAlgo_t algo;
|
||||||
|
bool benchmark=true;
|
||||||
|
|
||||||
|
jk.clear();
|
||||||
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||||
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||||
|
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
|
||||||
|
auto iter = bwdx_algo_cache.find(jk.to_string());
|
||||||
|
|
||||||
|
if (iter!=bwdx_algo_cache.end()) algo = iter->second;
|
||||||
|
else {
|
||||||
|
if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||||
|
if (benchmark) {
|
||||||
|
size_t max_ws_size = 0;
|
||||||
|
for (int i = 0; i < num_algos; i++) {
|
||||||
|
size_t sz;
|
||||||
|
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
|
||||||
|
// continue if use too much workspace
|
||||||
|
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||||
|
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||||
|
}
|
||||||
|
size_t allocation;
|
||||||
|
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
|
||||||
|
checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx(
|
||||||
|
handle_,
|
||||||
|
cudnnFdesc, w->ptr<Tw>(),
|
||||||
|
cudnnOdesc, y->ptr<Ty>(),
|
||||||
|
cudnnConvDesc,
|
||||||
|
cudnnIdesc, x->ptr<Tx>(),
|
||||||
|
num_algos,
|
||||||
|
&perf_count,
|
||||||
|
perf_results,
|
||||||
|
ws,
|
||||||
|
max_ws_size));
|
||||||
|
exe.temp_allocator->free(ws, max_ws_size, allocation);
|
||||||
|
} else {
|
||||||
|
checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
||||||
|
handle_,
|
||||||
|
cudnnFdesc,
|
||||||
|
cudnnOdesc,
|
||||||
|
cudnnConvDesc,
|
||||||
|
cudnnIdesc,
|
||||||
|
num_algos,
|
||||||
|
&perf_count,
|
||||||
|
perf_results));
|
||||||
|
}
|
||||||
|
int best_algo_idx=-1;
|
||||||
|
for (int i = 0; i < perf_count; i++)
|
||||||
|
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||||
|
best_algo_idx=i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
ASSERT(best_algo_idx!=-1);
|
||||||
|
algo=perf_results[best_algo_idx].algo;
|
||||||
|
if (benchmark) {
|
||||||
|
bwdx_algo_cache[jk.to_string()] = algo;
|
||||||
|
if (bwdx_algo_cache.size()==max_cache_size)
|
||||||
|
LOGw << "backward x algorithm cache is full";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: warp work space
|
||||||
|
void *workSpace = 0;
|
||||||
|
size_t workSpaceSize;
|
||||||
|
checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize(
|
||||||
|
handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc,
|
||||||
|
cudnnIdesc, algo, &workSpaceSize));
|
||||||
|
size_t allocation;
|
||||||
|
if (workSpaceSize > 0) {
|
||||||
|
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
|
||||||
|
}
|
||||||
|
float alpha=1, beta=0;
|
||||||
|
checkCudaErrors(cudnnConvolutionBackwardData(
|
||||||
|
handle_,
|
||||||
|
(void*)(&alpha),
|
||||||
|
cudnnFdesc, w->ptr<Tw>(),
|
||||||
|
cudnnOdesc, y->ptr<Ty>(),
|
||||||
|
cudnnConvDesc,
|
||||||
|
algo,
|
||||||
|
workSpace, workSpaceSize,
|
||||||
|
(void*)(&beta),
|
||||||
|
cudnnIdesc, x->ptr<Tx>())
|
||||||
|
);
|
||||||
|
if (workSpace)
|
||||||
|
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // JIT
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,28 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers:
|
||||||
|
// Dun Liang <randonlang@gmail.com>
|
||||||
|
// Guowei Yang <471184555@qq.com>
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#pragma once
|
||||||
|
#include "op.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
// Op computing the input gradient of a 3D convolution through cuDNN.
// Registered under the name "cudnn_conv3d_backward_x".
struct CudnnConv3dBackwardXOp : Op {
|
||||||
|
// w, dy: inputs given to the constructor; dx: output created by the constructor.
Var* w, * dy, * dx;
|
||||||
|
// xd/xh/xw: spatial extent used for dx's shape; the rest are the per-axis
// (d, h, w) stride, padding and dilation plus the group count.
int xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
|
||||||
|
// Layout tag of dy and dx; the implementation special-cases "ncdhw".
string xformat;
|
||||||
|
|
||||||
|
// The second parameter is named `y` here but is stored as `dy`;
// depth/height/width are stored as xd/xh/xw.
CudnnConv3dBackwardXOp(Var* w, Var* y, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
|
||||||
|
|
||||||
|
const char* name() const override { return "cudnn_conv3d_backward_x"; }
|
||||||
|
// Builds the gradient var for input number v_index (0: w, otherwise: dy).
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||||
|
// Sets dx's shape from w, dy, groups and xd/xh/xw.
void infer_shape() override;
|
||||||
|
DECLARE_jit_run;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,284 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "var.h"
|
||||||
|
#include "cudnn_conv3d_op.h"
|
||||||
|
#include "cudnn_warper.h"
|
||||||
|
#include "executor.h"
|
||||||
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||||
|
|
||||||
|
#ifndef JIT
|
||||||
|
|
||||||
|
// Construct the forward 3D-convolution op.
//   x, w                  : input vars (data and filter)
//   stride*, padding*,
//   dilation*             : per-axis convolution parameters (d, h, w)
//   groups                : convolution group count
//   xformat               : layout tag of x and y; "ncdhw" is the only value
//                           the op compares against explicitly
// The op is flagged CUDA-only and its single output `y` gets the dtype
// inferred from x and w.
CudnnConv3dOp::CudnnConv3dOp(
    Var* x, Var* w,
    int strided, int strideh, int stridew,
    int paddingd, int paddingh, int paddingw,
    int dilationd, int dilationh, int dilationw,
    int groups, string xformat)
    : x(x), w(w),
      strided(strided), strideh(strideh), stridew(stridew),
      paddingd(paddingd), paddingh(paddingh), paddingw(paddingw),
      dilationd(dilationd), dilationh(dilationh), dilationw(dilationw),
      groups(groups),
      xformat(move(xformat))
{
    // runs on GPU only
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    // shape is assigned later in infer_shape()
    y = create_output(nullptr, dtype_infer(x->ns, w->ns));
}
|
||||||
|
|
||||||
|
void CudnnConv3dOp::infer_shape() {
|
||||||
|
ASSERTop(x->shape.size(),==,5);
|
||||||
|
ASSERTop(w->shape.size(),==,5);
|
||||||
|
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||||
|
if (xformat == "ncdhw")
|
||||||
|
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||||
|
else
|
||||||
|
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||||
|
w->shape.unpack(wco, wci, wd, wh, ww);
|
||||||
|
ASSERTop(wci * groups,==,xc);
|
||||||
|
yn = xn, yc = wco;
|
||||||
|
yd = (xd+paddingd*2-wd*dilationd+dilationd-1)/strided+1;
|
||||||
|
yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1;
|
||||||
|
yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1;
|
||||||
|
if (xformat == "ncdhw")
|
||||||
|
y->set_shape(NanoVector(yn, yc, yd, yh, yw));
|
||||||
|
else
|
||||||
|
y->set_shape(NanoVector(yn, yd, yh, yw, yc));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the element types of x, y and w to the JIT key as
// "[Tx:..][Ty:..][Tw:..]".
void CudnnConv3dOp::jit_prepare(JK& jk) {
    jk << _CS("[Tx:") << x->dtype()
       << _CS("][Ty:") << y->dtype()
       << _CS("][Tw:") << w->dtype()
       << ']';
}
|
||||||
|
|
||||||
|
// Constructor handles for the two backward ops, looked up by registered op
// name. They are used by grad() below to build the backward graph.

// cudnn_conv3d_backward_x(w, dy, depth, height, width,
//                         stride d/h/w, padding d/h/w, dilation d/h/w, groups, xformat)
static auto make_backwardx =
    get_op_info("cudnn_conv3d_backward_x").get_constructor<
        VarPtr, Var*, Var*,
        int, int, int,
        int, int, int,
        int, int, int,
        int, int, int,
        int, string>();

// cudnn_conv3d_backward_w(x, dy, kd, kh, kw,
//                         stride d/h/w, padding d/h/w, dilation d/h/w, groups, xformat)
static auto make_backwardw =
    get_op_info("cudnn_conv3d_backward_w").get_constructor<
        VarPtr, Var*, Var*,
        int, int, int,
        int, int, int,
        int, int, int,
        int, int, int,
        int, string>();
|
||||||
|
|
||||||
|
// Gradient of y = conv3d(x, w) with respect to one of its inputs.
//   v_index == 0 (x) : backward_x(w, dout) sized to x's d/h/w
//   otherwise   (w)  : backward_w(x, dout) sized to w's d/h/w
// `out` and `v` are not used.
VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    int in_n, in_c, in_d, in_h, in_w;
    if (xformat == "ncdhw") {
        x->shape.unpack(in_n, in_c, in_d, in_h, in_w);
    } else {
        x->shape.unpack(in_n, in_d, in_h, in_w, in_c);
    }

    int f_out, f_in, f_d, f_h, f_w;
    w->shape.unpack(f_out, f_in, f_d, f_h, f_w);

    if (v_index != 0) {
        return make_backwardw(
            x, dout,
            f_d, f_h, f_w,
            strided, strideh, stridew,
            paddingd, paddingh, paddingw,
            dilationd, dilationh, dilationw,
            groups, xformat);
    }
    return make_backwardx(
        w, dout,
        in_d, in_h, in_w,
        strided, strideh, stridew,
        paddingd, paddingh, paddingw,
        dilationd, dilationh, dilationw,
        groups, xformat);
}
|
||||||
|
|
||||||
|
// unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||||
|
|
||||||
|
#else // JIT
|
||||||
|
#ifdef JIT_cuda
|
||||||
|
|
||||||
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
|
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||||
|
|
||||||
|
// Maps an element type to its cuDNN data-type enum.
// Only the specializations below are defined.
template <typename Elem>
__inline__ cudnnDataType_t getDataType();

template <>
__inline__ cudnnDataType_t getDataType<half1>() {
    return CUDNN_DATA_HALF;
}

template <>
__inline__ cudnnDataType_t getDataType<float>() {
    return CUDNN_DATA_FLOAT;
}
|
||||||
|
|
||||||
|
void CudnnConv3dOp::jit_run() {
|
||||||
|
cudnnHandle_t& handle_ = cudnn_handle;
|
||||||
|
|
||||||
|
cudnnTensorDescriptor_t cudnnIdesc;
|
||||||
|
cudnnFilterDescriptor_t cudnnFdesc;
|
||||||
|
cudnnTensorDescriptor_t cudnnOdesc;
|
||||||
|
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||||
|
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||||
|
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
|
||||||
|
|
||||||
|
|
||||||
|
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||||
|
int sx[] = {0,0,0,0,1};
|
||||||
|
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
|
||||||
|
int strideX[5];
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||||
|
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
|
||||||
|
memcpy(strideX, tmp, sizeof(tmp));
|
||||||
|
} else {
|
||||||
|
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||||
|
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
|
||||||
|
memcpy(strideX, tmp, sizeof(tmp));
|
||||||
|
}
|
||||||
|
int dimX[] = {xn, xc, xd, xh, xw};
|
||||||
|
// dimX: ncdhw
|
||||||
|
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||||
|
cudnnIdesc, getDataType<Tx>(),
|
||||||
|
5, dimX, strideX
|
||||||
|
));
|
||||||
|
|
||||||
|
auto ws = w->shape;
|
||||||
|
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
|
||||||
|
// cudnn only support this two format
|
||||||
|
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||||
|
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||||
|
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||||
|
|
||||||
|
// dimW: KCRS(oihw)
|
||||||
|
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||||
|
cudnnFdesc, getDataType<Tw>(),
|
||||||
|
// filterFormat_@WFORMAT, 5, dimW
|
||||||
|
filterFormat_oihw, 5, dimW
|
||||||
|
));
|
||||||
|
|
||||||
|
int padA[] = {paddingd, paddingh, paddingw};
|
||||||
|
int convstrideA[] = {strided, strideh, stridew};
|
||||||
|
int dilationA[] = {dilationd, dilationh, dilationw};
|
||||||
|
// difference between
|
||||||
|
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||||
|
// is the kernel rc order
|
||||||
|
// currently, No perf difference is observed between
|
||||||
|
// this two mode
|
||||||
|
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||||
|
cudnnConvDesc, 3,
|
||||||
|
padA, convstrideA, dilationA,
|
||||||
|
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||||
|
));
|
||||||
|
|
||||||
|
// using tensor core
|
||||||
|
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||||
|
|
||||||
|
|
||||||
|
int sy[] = {0,0,0,0,1};
|
||||||
|
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
|
||||||
|
int strideY[5];
|
||||||
|
if (xformat == "ncdhw") {
|
||||||
|
y->shape.unpack(yn, yc, yd, yh, yw);
|
||||||
|
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
|
||||||
|
memcpy(strideY, tmp, sizeof(tmp));
|
||||||
|
} else {
|
||||||
|
y->shape.unpack(yn, yd, yh, yw, yc);
|
||||||
|
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
|
||||||
|
memcpy(strideY, tmp, sizeof(tmp));
|
||||||
|
}
|
||||||
|
int dimY[] = {yn, yc, yd, yh, yw};
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||||
|
cudnnOdesc, getDataType<Ty>(),
|
||||||
|
5, dimY, strideY
|
||||||
|
));
|
||||||
|
|
||||||
|
cudnnConvolutionFwdAlgo_t algos[] = {
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||||
|
};
|
||||||
|
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||||
|
int perf_count;
|
||||||
|
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
|
||||||
|
cudnnConvolutionFwdAlgo_t algo;
|
||||||
|
bool benchmark=true;
|
||||||
|
|
||||||
|
jk.clear();
|
||||||
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||||
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||||
|
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
|
||||||
|
auto iter = fwd_algo_cache.find(jk.to_string());
|
||||||
|
|
||||||
|
if (iter!=fwd_algo_cache.end()) algo = iter->second;
|
||||||
|
else {
|
||||||
|
if (fwd_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||||
|
if (benchmark) {
|
||||||
|
size_t max_ws_size = 0;
|
||||||
|
for (int i = 0; i < num_algos; i++) {
|
||||||
|
size_t sz;
|
||||||
|
cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
|
||||||
|
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||||
|
cudnnOdesc, algos[i], &sz);
|
||||||
|
// continue if use too much workspace
|
||||||
|
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||||
|
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||||
|
}
|
||||||
|
size_t allocation;
|
||||||
|
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
|
||||||
|
checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx(
|
||||||
|
handle_,
|
||||||
|
cudnnIdesc, x->ptr<Tx>(),
|
||||||
|
cudnnFdesc, w->ptr<Tw>(),
|
||||||
|
cudnnConvDesc,
|
||||||
|
cudnnOdesc, y->ptr<Ty>(),
|
||||||
|
num_algos,
|
||||||
|
&perf_count,
|
||||||
|
perf_results,
|
||||||
|
ws,
|
||||||
|
max_ws_size));
|
||||||
|
exe.temp_allocator->free(ws, max_ws_size, allocation);
|
||||||
|
} else {
|
||||||
|
checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7(
|
||||||
|
handle_,
|
||||||
|
cudnnIdesc,
|
||||||
|
cudnnFdesc,
|
||||||
|
cudnnConvDesc,
|
||||||
|
cudnnOdesc,
|
||||||
|
num_algos,
|
||||||
|
&perf_count,
|
||||||
|
perf_results));
|
||||||
|
}
|
||||||
|
int best_algo_idx=-1;
|
||||||
|
for (int i = 0; i < perf_count; i++)
|
||||||
|
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||||
|
best_algo_idx=i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
ASSERT(best_algo_idx!=-1);
|
||||||
|
algo=perf_results[best_algo_idx].algo;
|
||||||
|
if (benchmark) {
|
||||||
|
fwd_algo_cache[jk.to_string()] = algo;
|
||||||
|
if (fwd_algo_cache.size()==max_cache_size)
|
||||||
|
LOGw << "forward_ algorithm cache is full";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: warp work space
|
||||||
|
void *workSpace = 0;
|
||||||
|
size_t workSpaceSize;
|
||||||
|
checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize(
|
||||||
|
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||||
|
cudnnOdesc, algo, &workSpaceSize) );
|
||||||
|
size_t allocation;
|
||||||
|
if (workSpaceSize > 0) {
|
||||||
|
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
|
||||||
|
}
|
||||||
|
float alpha=1, beta=0;
|
||||||
|
checkCudaErrors(cudnnConvolutionForward(
|
||||||
|
handle_,
|
||||||
|
(void*)(&alpha),
|
||||||
|
cudnnIdesc, x->ptr<Tx>(),
|
||||||
|
cudnnFdesc, w->ptr<Tw>(),
|
||||||
|
cudnnConvDesc,
|
||||||
|
algo,
|
||||||
|
workSpace, workSpaceSize,
|
||||||
|
(void*)(&beta),
|
||||||
|
cudnnOdesc, y->ptr<Ty>())
|
||||||
|
);
|
||||||
|
if (workSpace)
|
||||||
|
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
|
||||||
|
|
||||||
|
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||||
|
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // JIT
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,24 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#pragma once
|
||||||
|
#include "op.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
struct CudnnConv3dOp : Op {
|
||||||
|
Var* x, * w, * y;
|
||||||
|
int strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
|
||||||
|
string xformat;
|
||||||
|
CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw");
|
||||||
|
|
||||||
|
const char* name() const override { return "cudnn_conv3d"; }
|
||||||
|
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||||
|
void infer_shape() override;
|
||||||
|
DECLARE_jit_run;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -21,6 +21,7 @@ from collections import OrderedDict
|
||||||
from jittor.pool import *
|
from jittor.pool import *
|
||||||
from jittor.optim import *
|
from jittor.optim import *
|
||||||
from jittor.misc import _pair, _triple
|
from jittor.misc import _pair, _triple
|
||||||
|
from jittor_utils import LOG
|
||||||
|
|
||||||
|
|
||||||
def matmul_transpose(a, b):
|
def matmul_transpose(a, b):
|
||||||
|
@ -639,7 +640,6 @@ class Conv1d(Module):
|
||||||
|
|
||||||
class Conv3d(Module):
|
class Conv3d(Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||||
LOG.w("Optimizations of Conv3d are working in progress, it maybe slow currently.")
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
|
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
|
||||||
|
@ -665,65 +665,7 @@ class Conv3d(Module):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
if self.groups == 1:
|
return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
N,C,H,W,D = x.shape
|
|
||||||
Kh, Kw, Kd = self.kernel_size
|
|
||||||
assert C==self.in_channels
|
|
||||||
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
|
||||||
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
|
||||||
od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1
|
|
||||||
xx = x.reindex([N,self.out_channels,C,oh,ow,od,Kh,Kw,Kd], [
|
|
||||||
'i0', # Nid
|
|
||||||
'i2', # Cid
|
|
||||||
f'i3*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
|
|
||||||
f'i4*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
|
|
||||||
f'i5*{self.stride[2]}-{self.padding[2]}+i8*{self.dilation[2]}', # Did+KDid
|
|
||||||
])
|
|
||||||
ww = self.weight.broadcast(xx.shape, [0,3,4,5])
|
|
||||||
yy = xx*ww
|
|
||||||
y = yy.sum([2,6,7,8]) # Kc, Kh, Kw, Kd
|
|
||||||
if self.bias is not None:
|
|
||||||
b = self.bias.broadcast(y.shape, [0,2,3,4])
|
|
||||||
y = y + b
|
|
||||||
return y
|
|
||||||
else:
|
|
||||||
N,C,H,W,D = x.shape
|
|
||||||
Kh, Kw, Kd = self.kernel_size
|
|
||||||
G = self.groups
|
|
||||||
CpG = C // G # channels per group
|
|
||||||
assert C==self.in_channels
|
|
||||||
oc = self.out_channels
|
|
||||||
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
|
||||||
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
|
||||||
od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1
|
|
||||||
xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [
|
|
||||||
'i0', # Nid
|
|
||||||
f'i1*{CpG}+i3', # Gid
|
|
||||||
f'i4*{self.stride[0]}-{self.padding[0]}+i7*{self.dilation[0]}', # Hid+Khid
|
|
||||||
f'i5*{self.stride[1]}-{self.padding[1]}+i8*{self.dilation[1]}', # Wid+KWid
|
|
||||||
f'i6*{self.stride[2]}-{self.padding[2]}+i9*{self.dilation[2]}', # Did+KDid
|
|
||||||
])
|
|
||||||
# w: [oc, CpG, Kh, Kw, Kd]
|
|
||||||
ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, od, Kh, Kw, Kd], [
|
|
||||||
f'i1*{oc//G}+i2',
|
|
||||||
'i3',
|
|
||||||
'i7',
|
|
||||||
'i8',
|
|
||||||
'i9'
|
|
||||||
])
|
|
||||||
ww.compile_options = xx.compile_options = {"G":G,"C":C}
|
|
||||||
yy = xx*ww
|
|
||||||
y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [
|
|
||||||
'i0',
|
|
||||||
f'i1*{oc//G}+i2',
|
|
||||||
'i4',
|
|
||||||
'i5',
|
|
||||||
'i6'
|
|
||||||
])
|
|
||||||
if self.bias is not None:
|
|
||||||
b = self.bias.broadcast(y.shape, [0,2,3,4])
|
|
||||||
y = y + b
|
|
||||||
return y
|
|
||||||
|
|
||||||
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
padding = _pair(padding)
|
padding = _pair(padding)
|
||||||
|
@ -790,12 +732,12 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
out_channels = weight.shape[0]
|
out_channels = weight.shape[0]
|
||||||
|
|
||||||
if groups == 1:
|
if groups == 1:
|
||||||
N,C,H,W,D = x.shape
|
N,C,D,H,W = x.shape
|
||||||
Kh, Kw, Kd = weight.shape[-3:]
|
Kd, Kh, Kw = weight.shape[-3:]
|
||||||
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||||
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||||
od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1
|
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
|
||||||
xx = x.reindex([N,out_channels,C,oh,ow,od,Kh,Kw,Kd], [
|
xx = x.reindex([N,out_channels,C,od,oh,ow,Kd,Kh,Kw], [
|
||||||
'i0', # Nid
|
'i0', # Nid
|
||||||
'i2', # Cid
|
'i2', # Cid
|
||||||
f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
|
f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
|
||||||
|
@ -810,15 +752,15 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
y = y + b
|
y = y + b
|
||||||
return y
|
return y
|
||||||
else:
|
else:
|
||||||
N,C,H,W,D = x.shape
|
N,C,D,H,W = x.shape
|
||||||
Kh, Kw, Kd = weight.shape[-3:]
|
Kd, Kh, Kw = weight.shape[-3:]
|
||||||
G = groups
|
G = groups
|
||||||
CpG = C // G # channels per group
|
CpG = C // G # channels per group
|
||||||
oc = out_channels
|
oc = out_channels
|
||||||
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||||
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||||
od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1
|
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
|
||||||
xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [
|
xx = x.reindex([N,G,oc//G,CpG,od,oh,ow,Kd,Kh,Kw], [
|
||||||
'i0', # Nid
|
'i0', # Nid
|
||||||
f'i1*{CpG}+i3', # Gid
|
f'i1*{CpG}+i3', # Gid
|
||||||
f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid
|
f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid
|
||||||
|
@ -835,7 +777,7 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
'i9'
|
'i9'
|
||||||
])
|
])
|
||||||
yy = xx*ww
|
yy = xx*ww
|
||||||
y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [
|
y = yy.reindex_reduce('add', [N, oc, od, oh, ow], [
|
||||||
'i0',
|
'i0',
|
||||||
f'i1*{oc//G}+i2',
|
f'i1*{oc//G}+i2',
|
||||||
'i4',
|
'i4',
|
||||||
|
@ -906,6 +848,45 @@ class ConvTranspose(Module):
|
||||||
y = y + b
|
y = y + b
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
class ConvTranspose3d(Module):
    ''' 3D transposed convolution layer; only groups == 1 is supported.

    kernel_size, stride, padding, output_padding and dilation each accept a
    single value (used for depth, height and width) or a 3-element
    tuple/list ordered (depth, height, width).

    The weight has shape (in_channels, out_channels, kd, kh, kw); the bias,
    when enabled, has shape (out_channels,).
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
        padding=0, output_padding=0, groups=1, bias=True, dilation=1):
        def triple(v):
            # A tuple/list supplies per-axis values; anything else is repeated.
            # Lists used to be wrapped as (list, list, list) by mistake.
            return tuple(v) if isinstance(v, (tuple, list)) else (v, v, v)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.dilation = triple(dilation)
        # the attribute is named `group` (singular); execute() reads it
        self.group = groups
        assert groups==1, "Group conv not supported yet."

        self.kernel_size = triple(kernel_size)
        self.stride = triple(stride)
        self.padding = triple(padding)
        # padding of the equivalent direct convolution; stored, not used by execute()
        self.real_padding = (
            self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
            self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1],
            self.dilation[2] * (self.kernel_size[2] - 1) - self.padding[2])
        self.output_padding = triple(output_padding)
        assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
            self.output_padding[1] < max(self.stride[1], self.dilation[1]) and \
            self.output_padding[2] < max(self.stride[2], self.dilation[2]), \
            "output padding must be smaller than max(stride, dilation)"

        self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
        if bias:
            # bound = 1/sqrt(product of all weight dims except the first)
            fan=1
            for i in self.weight.shape[1:]:
                fan *= i
            bound = 1 / math.sqrt(fan)
            self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
        else:
            self.bias = None

    def execute(self, x):
        return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)
|
||||||
|
|
||||||
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||||
x = input
|
x = input
|
||||||
N,C,H,W = x.shape
|
N,C,H,W = x.shape
|
||||||
|
@ -944,6 +925,47 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding
|
||||||
assert not bias, "Bias should be none or jittor var"
|
assert not bias, "Bias should be none or jittor var"
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    ''' 3D transposed convolution; only groups == 1 is supported.

    :param input: var of shape (N, C, D, H, W).
    :param weight: var of shape (C, out_channels, kd, kh, kw).
    :param bias: None, or a jittor var broadcast over the channel axis.
    :param stride, padding, output_padding, dilation: a single value or a
        3-element tuple/list ordered (depth, height, width).
    :return: var of shape (N, out_channels, d_out, h_out, w_out).
    '''
    def triple(v):
        # A tuple/list supplies per-axis values; anything else is repeated.
        # Lists used to be wrapped as (list, list, list) by mistake.
        return tuple(v) if isinstance(v, (tuple, list)) else (v, v, v)

    x = input
    N,C,D,H,W = x.shape
    i,o,d,h,w = weight.shape
    assert C==i
    assert groups==1, "Group conv not supported yet."
    stride = triple(stride)
    dilation = triple(dilation)
    padding = triple(padding)
    output_padding = triple(output_padding)
    assert output_padding[0] < max(stride[0], dilation[0]) and \
        output_padding[1] < max(stride[1], dilation[1]) and \
        output_padding[2] < max(stride[2], dilation[2]), \
        "output padding must be smaller than max(stride, dilation)"

    stride_d, stride_h, stride_w = stride
    padding_d, padding_h, padding_w = padding
    dilation_d, dilation_h, dilation_w = dilation

    d_out = (D-1) * stride_d + output_padding[0] - 2*padding_d + 1 + (d-1)*dilation_d
    h_out = (H-1) * stride_h + output_padding[1] - 2*padding_h + 1 + (h-1)*dilation_h
    w_out = (W-1) * stride_w + output_padding[2] - 2*padding_w + 1 + (w-1)*dilation_w
    out_shape = (N, o, d_out, h_out, w_out)
    # joint index space: (N, in, out, D, H, W, kd, kh, kw)
    shape = (N, i, o, D, H, W, d, h, w)
    xx = x.broadcast(shape, (2, 6, 7, 8)) # new axes: out, kd, kh, kw
    ww = weight.broadcast(shape, (0, 3, 4, 5)) # new axes: N, D, H, W
    # scatter-add every product to its output position
    y = (ww*xx).reindex_reduce("add", out_shape, [
        'i0', # N
        'i2', # o
        f'i3*{stride_d}-{padding_d}+i6*{dilation_d}', # Did+Kdid
        f'i4*{stride_h}-{padding_h}+i7*{dilation_h}', # Hid+Khid
        f'i5*{stride_w}-{padding_w}+i8*{dilation_w}', # Wid+KWid
    ])
    if isinstance(bias, jt.Var):
        b = bias.broadcast(y.shape, [0,2,3,4])
        y = y + b
    else:
        assert not bias, "Bias should be none or jittor var"
    return y
|
||||||
|
|
||||||
conv_transpose2d = conv_transpose
|
conv_transpose2d = conv_transpose
|
||||||
|
|
||||||
def pad(x,padding, mode='constant', value=0):
|
def pad(x,padding, mode='constant', value=0):
|
||||||
|
@ -1286,7 +1308,7 @@ def linspace_from_neg_one(grid,num_steps,align_corners):
|
||||||
return jt.array(ra,dtype=grid.dtype)
|
return jt.array(ra,dtype=grid.dtype)
|
||||||
|
|
||||||
def make_base_grid_4D(theta,N,C,H,W,align_corners):
|
def make_base_grid_4D(theta,N,C,H,W,align_corners):
|
||||||
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype);
|
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype)
|
||||||
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
||||||
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
||||||
base_grid[...,-1] = 1
|
base_grid[...,-1] = 1
|
||||||
|
|
|
@ -238,6 +238,21 @@ struct NanoVector {
|
||||||
v[i] = at(i);
|
v[i] = at(i);
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void _unpack(int i) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class... Args>
|
||||||
|
void _unpack(int i, int& x, Args&&... args) {
|
||||||
|
x = this->operator[](i);
|
||||||
|
_unpack(i+1, std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class... Args>
|
||||||
|
void unpack(Args&&... args) {
|
||||||
|
_unpack(0, std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -128,7 +128,55 @@ class TestCudnnConvOp(unittest.TestCase):
|
||||||
check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1)
|
check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1)
|
||||||
check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
|
check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
|
||||||
check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
|
check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
|
||||||
|
|
||||||
|
def test_conv3d(self):
    # Checks the cudnn_conv3d op (output and gradients w.r.t. input and
    # weight) against jt.nn.conv3d on several stride/padding/dilation/kernel
    # combinations.
    # NOTE(review): the flattened source lost its indentation; the end of the
    # flag_scope block is reconstructed from the blank line after the first
    # jt.grad call -- confirm against the original file.
    def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
        with jt.flag_scope(use_cuda=1):
            x = jt.random(xshape)
            w = jt.random(wshape)
            y = jt.cudnn.ops.cudnn_conv3d(x, w, *stride, *padding, *dilation, group)
            masky = jt.rand_like(y)
            dx, dw = jt.grad(masky*y, [x, w])

        # reference result from the generic implementation
        y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
        dx2, dw2 = jt.grad(masky*y2, [x, w])
        np.testing.assert_allclose(y.data, y2.data)
        np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-5, atol=1e-3)
        np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)

    xshape = (2,4,10,10,10)
    # (wshape, stride, padding, dilation)
    cases = [
        ((5,4,3,3,3), (1,1,1), (1,1,1), (1,1,1)),
        ((5,4,3,3,3), (2,2,2), (1,1,1), (1,1,1)),
        ((5,4,3,3,3), (2,2,2), (0,0,0), (1,1,1)),
        ((5,4,3,3,3), (1,2,3), (0,0,0), (1,1,1)),
        ((5,4,3,4,5), (1,1,1), (1,1,1), (1,1,1)),
        ((5,4,3,4,5), (1,2,3), (0,0,0), (1,1,1)),
        ((5,4,3,3,3), (1,1,1), (1,1,1), (1,2,3)),
    ]
    for wshape, stride, padding, dilation in cases:
        check(xshape, wshape, stride, padding, dilation=dilation)
|
||||||
|
|
||||||
|
def test_conv_transpose3d(self):
    # Checks the cudnn_conv3d_backward_x op, used as a transposed convolution
    # (output and gradients), against jt.nn.conv_transpose3d.
    # NOTE(review): the flattened source lost its indentation; the two
    # flag_scope blocks are reconstructed as sequential (not nested), using
    # the blank lines as boundaries -- confirm against the original file.
    def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
        with jt.flag_scope(use_cuda=1):
            x = jt.random(xshape)
            w = jt.random(wshape)

        # reference result; its spatial shape is passed to the cudnn op below
        y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)

        with jt.flag_scope(use_cuda=1):
            y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
            masky = jt.rand_like(y)
            dx, dw = jt.grad(masky*y, [x, w])

        dx2, dw2 = jt.grad(masky*y2, [x, w])
        np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
        np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
        np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)

    xshape = (2,5,10,10,10)
    # (wshape, stride, padding, dilation)
    cases = [
        ((5,4,3,3,3), (1,1,1), (1,1,1), (1,1,1)),
        ((5,4,3,3,3), (2,2,2), (1,1,1), (1,1,1)),
        ((5,4,3,3,3), (2,2,2), (0,0,0), (1,1,1)),
        ((5,4,3,3,3), (1,2,3), (0,0,0), (1,1,1)),
        ((5,4,3,4,5), (1,1,1), (1,1,1), (1,1,1)),
        ((5,4,3,4,5), (1,2,3), (0,0,0), (1,1,1)),
        ((5,4,3,3,3), (1,1,1), (1,1,1), (1,2,3)),
    ]
    for wshape, stride, padding, dilation in cases:
        check(xshape, wshape, stride, padding, dilation=dilation)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue