From aefe719770282c880f740e6d711720eb11e67fda Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Sat, 20 Feb 2021 16:34:24 +0800 Subject: [PATCH] add stride padding(h,w) support and ctrl-c quick exit --- .../cudnn/ops/cudnn_conv_backward_w_op.cc | 12 +- .../cuda/cudnn/ops/cudnn_conv_backward_w_op.h | 4 +- .../cudnn/ops/cudnn_conv_backward_x_op.cc | 12 +- .../cuda/cudnn/ops/cudnn_conv_backward_x_op.h | 4 +- extern/cuda/cudnn/ops/cudnn_conv_op.cc | 16 +-- extern/cuda/cudnn/ops/cudnn_conv_op.h | 4 +- extern/mkl/ops/mkl_conv_backward_w_op.cc | 12 +- extern/mkl/ops/mkl_conv_backward_w_op.h | 4 +- extern/mkl/ops/mkl_conv_backward_x_op.cc | 19 +-- extern/mkl/ops/mkl_conv_backward_x_op.h | 4 +- extern/mkl/ops/mkl_conv_op.cc | 16 +-- extern/mkl/ops/mkl_conv_op.h | 4 +- python/jittor/__init__.py | 2 +- python/jittor/depthwise_conv.py | 1 + python/jittor/nn.py | 2 + python/jittor/test/test_inception.py | 127 ++++++++++++++++++ python/jittor/test/test_misc_issue.py | 4 +- python/jittor/test/test_mkl_conv_op.py | 12 +- python/jittor/test/test_setitem.py | 21 +++ src/jit_key.cc | 3 +- src/opt/tuner/conv_tuner.cc | 27 ++-- src/test/test_sfrl_allocator.cc | 15 +-- src/utils/log.cc | 3 +- 23 files changed, 233 insertions(+), 95 deletions(-) create mode 100644 python/jittor/test/test_inception.py diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc index 3e82c9bb..a341786c 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc @@ -43,8 +43,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a, shape[0], shape[1], shape[2], shape[3])); } -CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat) - : x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups), +CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), dy(dy), kh(kh), kw(kw), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); @@ -134,9 +134,9 @@ void CudnnConvBackwardWOp::jit_run() { filterFormat_@WFORMAT, 4, dimW )); - int padA[] = {padding, padding}; - int convstrideA[] = {stride, stride}; - int dilationA[] = {dilation, dilation}; + int padA[] = {paddingh, paddingw}; + int convstrideA[] = {strideh, stridew}; + int dilationA[] = {dilationh, dilationw}; // difference between // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION // is the kernel rc order @@ -187,7 +187,7 @@ void CudnnConvBackwardWOp::jit_run() { jk.clear(); jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; - jk << padding << "," <second; diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h index db30b31e..ab102674 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h @@ -14,10 +14,10 @@ namespace jittor { struct CudnnConvBackwardWOp : Op { Var* x, * dy, * dw; - int kh, kw, 
stride, padding, dilation, groups; + int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; string xformat, wformat, yformat; - CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); const char* name() const override { return "cudnn_conv_backward_w"; } void infer_shape() override; diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc index a604f679..9a27db81 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc @@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a, shape[0], shape[1], shape[2], shape[3])); } -CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat) - : w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups), +CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : w(w), dy(dy), xh(height), xw(width), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); @@ -135,9 +135,9 @@ void CudnnConvBackwardXOp::jit_run() { filterFormat_@WFORMAT, 4, dimW )); - int padA[] = {padding, padding}; - int convstrideA[] = {stride, stride}; - int dilationA[] = {dilation, dilation}; + int padA[] = {paddingh, paddingw}; + int convstrideA[] = {strideh, stridew}; + int dilationA[] = {dilationh, dilationw}; // difference between // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION // is the kernel rc order @@ -188,7 +188,7 @@ void CudnnConvBackwardXOp::jit_run() { jk.clear(); jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; - jk << padding << "," <second; diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h index 9c13ee87..a9537cd3 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h @@ -14,10 +14,10 @@ namespace jittor { struct CudnnConvBackwardXOp : Op { Var* w, * dy, * dx; - int xh, xw, stride, padding, dilation, groups; + int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; string xformat, wformat, yformat; - CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); const char* name() const override { return 
"cudnn_conv_backward_x"; } void infer_shape() override; diff --git a/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_op.cc index 7f3305ed..6dcef7b3 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_op.cc @@ -42,8 +42,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a, shape[0], shape[1], shape[2], shape[3])); } -CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat) - : x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups), +CudnnConvOp::CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), w(w), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); @@ -60,8 +60,8 @@ void CudnnConvOp::infer_shape() { get_shape(w, "oihw", wformat, wco, wci, wh, ww); ASSERTop(wci * groups,==,xc); yn = xn, yc = wco; - yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1; - yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1; + yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1; + yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1; set_shape(y, "abcd", yformat, yn, yc, yh, yw); } @@ -135,9 +135,9 @@ void CudnnConvOp::jit_run() { filterFormat_@WFORMAT, 4, dimW )); - int padA[] = {padding, padding}; - int convstrideA[] = {stride, stride}; - int dilationA[] = {dilation, dilation}; + int padA[] = {paddingh, paddingw}; + int convstrideA[] = {strideh, stridew}; + int dilationA[] = {dilationh, dilationw}; // difference between // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION // is the kernel rc order @@ -190,7 +190,7 @@ void CudnnConvOp::jit_run() { jk.clear(); jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; - jk << padding << "," <second; diff --git a/extern/cuda/cudnn/ops/cudnn_conv_op.h b/extern/cuda/cudnn/ops/cudnn_conv_op.h index a57f0228..8f082297 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_op.h +++ b/extern/cuda/cudnn/ops/cudnn_conv_op.h @@ -11,10 +11,10 @@ namespace jittor { struct CudnnConvOp : Op { Var* x, * w, * y; - int stride, padding, dilation, groups; + int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; string xformat, wformat, yformat; /* CudnnConvOp: xformat abcd represents nchw */ - CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); + CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); const char* name() const override { return "cudnn_conv"; } void infer_shape() override; diff --git a/extern/mkl/ops/mkl_conv_backward_w_op.cc b/extern/mkl/ops/mkl_conv_backward_w_op.cc index fca73dbe..117285d7 100644 --- a/extern/mkl/ops/mkl_conv_backward_w_op.cc +++ b/extern/mkl/ops/mkl_conv_backward_w_op.cc @@ -46,8 +46,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a, shape[0], shape[1], shape[2], shape[3])); } -MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* 
dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat) - : x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups), +MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), dy(dy), kh(kh), kw(kw), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); } @@ -119,10 +119,10 @@ void MklConvBackwardWOp::jit_run() { memory::dims conv_weights_tz = groups>1 ? memory::dims{groups, ch_out/groups, ch_in/groups, kh, kw} : memory::dims{ch_out, ch_in, kh, kw}; - memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kh*dilation+dilation-1)/stride+1, (width+padding*2-kw*dilation+dilation-1)/stride+1}; - memory::dims conv_strides = {stride, stride}; - memory::dims conv_padding = {padding, padding}; - memory::dims conv_dilation = {dilation-1, dilation-1}; + memory::dims conv_dst_tz = {batch, ch_out, (height+paddingh*2-kh*dilationh+dilationh-1)/strideh+1, (width+paddingw*2-kw*dilationw+dilationw-1)/stridew+1}; + memory::dims conv_strides = {strideh, stridew}; + memory::dims conv_padding = {paddingh, paddingw}; + memory::dims conv_dilation = {dilationh-1, dilationw-1}; if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw); diff --git a/extern/mkl/ops/mkl_conv_backward_w_op.h b/extern/mkl/ops/mkl_conv_backward_w_op.h index 912af516..53162886 100644 --- a/extern/mkl/ops/mkl_conv_backward_w_op.h +++ b/extern/mkl/ops/mkl_conv_backward_w_op.h @@ -14,10 +14,10 @@ namespace jittor { struct MklConvBackwardWOp : Op { Var* x, * dy, * dw; - int kh, kw, stride, padding, dilation, groups; + int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; string xformat, wformat, yformat; - MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); const char* name() const override { return "mkl_conv_backward_w"; } void infer_shape() override; diff --git a/extern/mkl/ops/mkl_conv_backward_x_op.cc b/extern/mkl/ops/mkl_conv_backward_x_op.cc index 91b9154a..201524b6 100644 --- a/extern/mkl/ops/mkl_conv_backward_x_op.cc +++ b/extern/mkl/ops/mkl_conv_backward_x_op.cc @@ -46,8 +46,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a, shape[0], shape[1], shape[2], shape[3])); } -MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat) - : w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups), +MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : w(w), dy(dy), xh(height), xw(width), strideh(strideh), stridew(stridew), 
paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); } @@ -97,7 +97,8 @@ void MklConvBackwardXOp::jit_run() { int height = dx->shape[findc("@XFORMAT",'c')]; int width = dx->shape[findc("@XFORMAT",'d')]; int ch_out = w->shape[findc("@WFORMAT",'o')]; - int kernel_size = w->shape[findc("@WFORMAT",'h')]; + int kernel_sizeh = w->shape[findc("@WFORMAT",'h')]; + int kernel_sizew = w->shape[findc("@WFORMAT",'w')]; auto* __restrict__ conv_weights = w->ptr(); auto* __restrict__ net_diff_dst = dy->ptr(); @@ -114,12 +115,12 @@ void MklConvBackwardXOp::jit_run() { memory::dims conv_src_tz = {batch, ch_in, height, width}; memory::dims conv_weights_tz = groups>1 - ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size} - : memory::dims{ch_out, ch_in, kernel_size, kernel_size}; - memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1}; - memory::dims conv_strides = {stride, stride}; - memory::dims conv_padding = {padding, padding}; - memory::dims conv_dilation = {dilation-1, dilation-1}; + ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_sizeh, kernel_sizew} + : memory::dims{ch_out, ch_in, kernel_sizeh, kernel_sizew}; + memory::dims conv_dst_tz = {batch, ch_out, (height+paddingh*2-kernel_sizeh*dilationh+dilationh-1)/strideh+1, (width+paddingw*2-kernel_sizew*dilationw+dilationw-1)/stridew+1}; + memory::dims conv_strides = {strideh, stridew}; + memory::dims conv_padding = {paddingh, paddingw}; + memory::dims conv_dilation = {dilationh-1, dilationw-1}; if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw); diff --git a/extern/mkl/ops/mkl_conv_backward_x_op.h b/extern/mkl/ops/mkl_conv_backward_x_op.h index 357062c7..4f2e8d9d 100644 --- a/extern/mkl/ops/mkl_conv_backward_x_op.h +++ b/extern/mkl/ops/mkl_conv_backward_x_op.h @@ -14,10 +14,10 @@ namespace jittor { struct MklConvBackwardXOp : Op { Var* w, * dy, * dx; - int xh, xw, stride, padding, dilation, groups; + int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; string xformat, wformat, yformat; - MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + MklConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); const char* name() const override { return "mkl_conv_backward_x"; } void infer_shape() override; diff --git a/extern/mkl/ops/mkl_conv_op.cc b/extern/mkl/ops/mkl_conv_op.cc index 951cabe0..fb9c87e3 100644 --- a/extern/mkl/ops/mkl_conv_op.cc +++ b/extern/mkl/ops/mkl_conv_op.cc @@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a, shape[0], shape[1], shape[2], shape[3])); } -MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat) - : x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups), +MklConvOp::MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), w(w), 
strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { y = create_output(nullptr, dtype_infer(x->ns, w->ns)); if (!this->yformat.size()) @@ -61,8 +61,8 @@ void MklConvOp::infer_shape() { get_shape(w, "oihw", wformat, wco, wci, wh, ww); ASSERTop(wci * groups,==,xc); yn = xn, yc = wco; - yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1; - yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1; + yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1; + yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1; set_shape(y, "abcd", yformat, yn, yc, yh, yw); } @@ -104,7 +104,7 @@ void MklConvOp::jit_run() { using dt = memory::data_type; if (tag::@XFORMAT==tag::nhwc && tag::@YFORMAT==tag::nhwc && tag::@WFORMAT==tag::hwio - && stride==1 && padding==0 && dilation==1 && ws[0]==1 && ws[1]==1 + && strideh==1 && stridew==1 && paddingh==0 && paddingw==0 && dilationh==1 && dilationw==1 && ws[0]==1 && ws[1]==1 && dt::@Tx==dt::f32 && dt::@Ty==dt::f32 && dt::@Tw==dt::f32) { auto m = xs[0]*xs[1]*xs[2]; auto n = ws[3]; @@ -133,9 +133,9 @@ void MklConvOp::jit_run() { ? memory::dims{groups, wco/groups, wci, wh, ww} : memory::dims{wco, wci, wh, ww}; memory::dims conv1_dst_tz = {yn, yc, yh, yw}; - memory::dims conv1_strides = { stride, stride }; - memory::dims conv1_padding = { padding, padding }; - memory::dims conv1_dilation = { dilation-1, dilation-1 }; + memory::dims conv1_strides = { strideh, stridew }; + memory::dims conv1_padding = { paddingh, paddingw }; + memory::dims conv1_dilation = { dilationh-1, dilationw-1 }; if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw); diff --git a/extern/mkl/ops/mkl_conv_op.h b/extern/mkl/ops/mkl_conv_op.h index 4bb29183..8ae08f07 100644 --- a/extern/mkl/ops/mkl_conv_op.h +++ b/extern/mkl/ops/mkl_conv_op.h @@ -14,10 +14,10 @@ namespace jittor { struct MklConvOp : Op { Var* x, * w, * y; - int stride, padding, dilation, groups; + int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; string xformat, wformat, yformat; /* MklConvOp: xformat abcd represents nchw */ - MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); + MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); const char* name() const override { return "mkl_conv"; } void infer_shape() override; diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index eda495e9..c3bdaaaa 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.31' +__version__ = '1.2.2.32' from . 
import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/depthwise_conv.py b/python/jittor/depthwise_conv.py index e95d78e2..98a43157 100644 --- a/python/jittor/depthwise_conv.py +++ b/python/jittor/depthwise_conv.py @@ -28,6 +28,7 @@ class DepthwiseConv(Function): ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 filter_height, filter_width = Kh, Kw self.Khw = Kh, Kw + assert oh>0 and ow>0 output = jt.code( [N, C, oh, ow], x.dtype, diff --git a/python/jittor/nn.py b/python/jittor/nn.py index a5a2b71b..988ec12f 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -554,6 +554,7 @@ class Conv(Module): assert C==self.in_channels oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 + assert oh>0 and ow>0 xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [ 'i0', # Nid 'i2', # Cid @@ -576,6 +577,7 @@ class Conv(Module): oc = self.out_channels oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 + assert oh>0 and ow>0 xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ 'i0', # Nid f'i1*{CpG}+i3', # Gid diff --git a/python/jittor/test/test_inception.py b/python/jittor/test/test_inception.py new file mode 100644 index 00000000..2b6df3fe --- /dev/null +++ b/python/jittor/test/test_inception.py @@ -0,0 +1,127 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import inception +import numpy as np +import sys, os +import random +import math +import unittest +from jittor.test.test_reorder_tuner import simple_parser +from jittor.test.test_log import find_log_with_re +from jittor.dataset.mnist import MNIST +import jittor.transform as trans +import time + +skip_this_test = False + +class MnistNet(Module): + def __init__(self): + self.model = inception.inception_v3() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestInception(unittest.TestCase): + @classmethod + def setUpClass(self): + # hyper-parameters + self.batch_size = 32 + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(300)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + self.train_loader.total_len = self.batch_size * 300 + + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_stat_allocator=1) + def test_inception(self): + self.setup_seed(1) + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + global prev + prev = time.time() + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + + for batch_idx, (data, target) in enumerate(self.train_loader): + + # train step + with jt.log_capture_scope( + log_silent=1, + log_v=1, log_vprefix="op.cc=100,exe=10", + ) as logs: + # breakpoint() + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(batch_idx, loss, output, target): + # print train info + global prev + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + loss_list.append(loss[0]) + acc_list.append(acc) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' + .format(0, batch_idx, 300,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev)) + # prev = time.time() + jt.fetch(batch_idx, loss, output, target, callback) + + log_conv = find_log_with_re(logs, + "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") + log_matmul = find_log_with_re(logs, + "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*") + if batch_idx > 2: + assert len(log_conv)==283 and len(log_matmul)==6, (len(log_conv), len(log_matmul)) + + mem_used = jt.flags.stat_allocator_total_alloc_byte \ + -jt.flags.stat_allocator_total_free_byte + # assert mem_used < 4e9, mem_used + # TODO: why bigger? 
+ assert mem_used < 15.6e9, mem_used + # example log: + # Train Epoch: 0 [0/100 (0%)] Loss: 2.352903 Acc: 0.110000 + # Train Epoch: 0 [1/100 (1%)] Loss: 2.840830 Acc: 0.080000 + # Train Epoch: 0 [2/100 (2%)] Loss: 3.473594 Acc: 0.100000 + # Train Epoch: 0 [3/100 (3%)] Loss: 3.131615 Acc: 0.200000 + # Train Epoch: 0 [4/100 (4%)] Loss: 2.524094 Acc: 0.230000 + # Train Epoch: 0 [5/100 (5%)] Loss: 7.780025 Acc: 0.080000 + # Train Epoch: 0 [6/100 (6%)] Loss: 3.890721 Acc: 0.160000 + # Train Epoch: 0 [7/100 (7%)] Loss: 6.370137 Acc: 0.140000 + # Train Epoch: 0 [8/100 (8%)] Loss: 11.390827 Acc: 0.150000 + # Train Epoch: 0 [9/100 (9%)] Loss: 21.598564 Acc: 0.080000 + # Train Epoch: 0 [10/100 (10%)] Loss: 23.369165 Acc: 0.130000 + # Train Epoch: 0 [20/100 (20%)] Loss: 4.804510 Acc: 0.100000 + # Train Epoch: 0 [30/100 (30%)] Loss: 3.393924 Acc: 0.110000 + # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000 + # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000 + + assert jt.core.number_of_lived_vars() < 50000, jt.core.number_of_lived_vars() + + jt.sync_all(True) + assert np.mean(loss_list[-20:])<1 + assert np.mean(acc_list[-20:])>0.5 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_misc_issue.py b/python/jittor/test/test_misc_issue.py index 01b29940..53a023ca 100644 --- a/python/jittor/test/test_misc_issue.py +++ b/python/jittor/test/test_misc_issue.py @@ -57,7 +57,7 @@ oihw = [4, 3, 5, 5] import jittor as jt x = jt.random(nchw) w = jt.random(oihw) -jt.mkl_ops.mkl_conv(x, w, 1, 2).sync() +jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync() jt.dirty_fix_pytorch_runtime_error() @@ -88,7 +88,7 @@ m(torch.rand(*nchw)) import jittor as jt x = jt.random(nchw) w = jt.random(oihw) -jt.mkl_ops.mkl_conv(x, w, 1, 2).sync() +jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync() """ diff --git a/python/jittor/test/test_mkl_conv_op.py b/python/jittor/test/test_mkl_conv_op.py index 4d187f61..b23f67ac 100644 --- a/python/jittor/test/test_mkl_conv_op.py +++ b/python/jittor/test/test_mkl_conv_op.py @@ -56,7 +56,7 @@ class TestMklConvOp(unittest.TestCase): def test_forward(self): a = np.random.rand(1,3,224,224).astype(np.float32) b = np.random.rand(64,3,7,7).astype(np.float32) - c = jt.mkl_ops.mkl_conv(a,b,2,3).data + c = jt.mkl_ops.mkl_conv(a,b,2,2,3,3).data a_jt = jt.array(a) b_jt = jt.array(b) @@ -81,7 +81,7 @@ class TestMklConvOp(unittest.TestCase): def check(xshape, wshape, stride, pad): a = np.random.rand(*xshape).astype(np.float32) b = np.random.rand(*wshape).astype(np.float32) - c = jt.mkl_ops.mkl_conv(a,b,stride,pad,1,xformat="acdb",wformat="hwio").data + c = jt.mkl_ops.mkl_conv(a,b,stride,stride,pad,pad,1,1,xformat="acdb",wformat="hwio").data a_jt = jt.array(a) b_jt = jt.array(b) @@ -114,8 +114,8 @@ class TestMklConvOp(unittest.TestCase): a = np.random.rand(n,c,H,W).astype(np.float32) b = np.random.rand(o,i,h,w).astype(np.float32) da = np.random.rand(n,o,H,W).astype(np.float32) - dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1).data - dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1).data + dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1).data + dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,1,1,1).data a_jt = jt.array(a) b_jt = jt.array(b) @@ -160,8 +160,8 @@ class TestMklConvOp(unittest.TestCase): a = np.random.rand(n,H,W,c).astype(np.float32) b = np.random.rand(h,w,i,o).astype(np.float32) da = np.random.rand(n,H,W,o).astype(np.float32) - dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data - dw = 
jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data + dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data + dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data a_jt = jt.array(a) b_jt = jt.array(b) diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py index 102030e5..83262252 100644 --- a/python/jittor/test/test_setitem.py +++ b/python/jittor/test/test_setitem.py @@ -148,6 +148,27 @@ class TestSetitem(unittest.TestCase): a[1,1] = -2 assert (a[0].numpy() == [-1,2]).all(), a[0].numpy() assert (a[1].numpy() == [3,-2]).all(), a[1].numpy() + + # def test_scatter(self): + # src = jt.arange(1, 11).reshape((2, 5)) + # index = jt.array([[0, 1, 2, 0]]) + # print(index.shape, src.shape) + # x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src) + # print(x) + +# def scatter(x, dim, index, src, reduce='void'): +# shape = index.shape +# indexes = [ jt.index(shape, i) for i in range(dim) ] +# indexes.append(index) +# print(indexes) +# return x.setitem(tuple(indexes), src, reduce) + +# def scatter_(x, dim, index, src, reduce='void'): +# return x.assign(x.scatter(dim, index, src, reduce)) + +# jt.Var.scatter = scatter +# jt.Var.scatter_ = scatter_ + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/src/jit_key.cc b/src/jit_key.cc index 365ce780..83fa845b 100644 --- a/src/jit_key.cc +++ b/src/jit_key.cc @@ -37,8 +37,7 @@ JitKey::JitKey() { JitKey::~JitKey() { auto buffer_end_page = get_buffer_end_page((size_t)&buffer[buffer_size-1]); LOGvv << "un-protect page" << (void*)buffer_end_page; - ASSERT(0== - mprotect((void*)buffer_end_page, page_size, PROT_READ|PROT_WRITE|PROT_EXEC)); + mprotect((void*)buffer_end_page, page_size, PROT_READ|PROT_WRITE|PROT_EXEC); protected_page = 0; } diff --git a/src/opt/tuner/conv_tuner.cc b/src/opt/tuner/conv_tuner.cc index eedfc739..874d0b2b 100644 --- a/src/opt/tuner/conv_tuner.cc +++ b/src/opt/tuner/conv_tuner.cc @@ -331,20 +331,11 @@ void ConvTuner::forwardTune(FusedOp* fop) { int padding_w = -rw[1]->as_int(); int dilation_w = rw[2]->as_int(); if (dilation_h < 1 || dilation_w < 1) continue; - if (stride_h!=stride_w || padding_h!=padding_w || dilation_h!=dilation_w) { - LOGw << "cannot relay different stride and padding between h and w" - << stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w - << "This may cause low performance. Please send us issue if you need it."; - continue; - } LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h; if (xformat == "bacd") { LOGvvvv << "mkl not support bacd, continue"; continue; } - int stride = stride_h; - int padding = padding_h; - int dilation = dilation_h; Var* x = x_id == 0 ? xoi.op->output(0) : xoi.op->input(0); Var* w = w_id == 0 ? woi.op->output(0) : woi.op->input(0); Var* y = y_id == 0 ? 
yoi.op->output(0) : yoi.op->input(0); @@ -371,9 +362,9 @@ void ConvTuner::forwardTune(FusedOp* fop) { if (!has_op(relay_conv_name)) continue; auto make_conv = get_op_info(relay_conv_name) - .get_constructor(); - LOGvvvv << x << w << stride << padding << dilation << groups << xformat << wformat << yformat; - rvar = make_conv(x, w, stride, padding, dilation, groups, xformat, wformat, yformat); + .get_constructor(); + LOGvvvv << x << w << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat; + rvar = make_conv(x, w, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat); } else if (x_id == 0) { relay_conv_name = fop->flags.get(NodeFlags::_cpu) ? @@ -383,9 +374,9 @@ void ConvTuner::forwardTune(FusedOp* fop) { auto height = x->shape[xformat.find("c")]; auto width = x->shape[xformat.find("d")]; auto make_conv_x = get_op_info(relay_conv_name) - .get_constructor(); - LOGvvvv << w << y << height << width << stride << padding << dilation << groups << xformat << wformat << yformat; - rvar = make_conv_x(w, y, height, width, stride, padding, dilation, groups, xformat, wformat, yformat); + .get_constructor(); + LOGvvvv << w << y << height << width << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat; + rvar = make_conv_x(w, y, height, width, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat); } else { relay_conv_name = fop->flags.get(NodeFlags::_cpu) ? "mkl_conv_backward_w" : "cudnn_conv_backward_w"; @@ -393,10 +384,10 @@ void ConvTuner::forwardTune(FusedOp* fop) { continue; auto kh = w->shape[wformat.find("h")]; auto kw = w->shape[wformat.find("w")]; - LOGvvvv << x << y << kh << stride << padding << dilation << groups << xformat << wformat << yformat; + LOGvvvv << x << y << kh << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat; auto make_conv_w = get_op_info(relay_conv_name) - .get_constructor(); - rvar = make_conv_w(x, y, kh, kw, stride, padding, dilation, groups, xformat, wformat, yformat); + .get_constructor(); + rvar = make_conv_w(x, y, kh, kw, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat); } LOGvvvv << relay_conv_name << "output:" << rvar; diff --git a/src/test/test_sfrl_allocator.cc b/src/test/test_sfrl_allocator.cc index f74b9e03..420ccf64 100644 --- a/src/test/test_sfrl_allocator.cc +++ b/src/test/test_sfrl_allocator.cc @@ -23,7 +23,6 @@ JIT_TEST(sfrl_allocator_time) { Allocator* allocator = get_allocator(); int max_allc_num = 10000; size_t id[max_allc_num]; - void* addr[max_allc_num]; size_t temp[max_allc_num]; std::vector tasks; tasks.push_back(TestTask(20000000, 1000, 1000, 400.0)); @@ -35,12 +34,12 @@ JIT_TEST(sfrl_allocator_time) { for (size_t k = 0; k < tasks[i].times1; ++k) { for (size_t j = 0; j < tasks[i].times2; ++j) { temp[j] = j; - addr[j] = allocator->alloc(tasks[i].size, id[j]); + allocator->alloc(tasks[i].size, id[j]); if (j > 0) std::swap(temp[j], temp[rand() % j]); } for (size_t j = 0; j < tasks[i].times2; ++j) { - allocator->free(addr[temp[j]], tasks[i].size, id[temp[j]]); + allocator->free(0, tasks[i].size, id[temp[j]]); } } auto end = std::chrono::duration_cast( @@ -55,7 +54,6 @@ JIT_TEST(sfrl_allocator_share) { Allocator* allocator = get_allocator(); int max_allc_num = 10000; size_t id[max_allc_num]; - void* 
addr[max_allc_num]; size_t temp[max_allc_num]; std::vector tasks; tasks.push_back(TestTask(20000000, 1000, 1000, 400.0)); @@ -72,13 +70,12 @@ JIT_TEST(sfrl_allocator_share) { if (rand() % 10 != 0 && j > 0) { id[j] = id[rand() % j]; allocator->share_with(tasks[i].size, id[j]); - addr[j] = addr[id[j]]; } else { - addr[j] = allocator->alloc(tasks[i].size, id[j]); + allocator->alloc(tasks[i].size, id[j]); } } for (size_t j = 0; j < tasks[i].times2; ++j) { - allocator->free(addr[temp[j]], tasks[i].size, id[temp[j]]); + allocator->free(0, tasks[i].size, id[temp[j]]); } } auto end = std::chrono::duration_cast( @@ -93,7 +90,6 @@ JIT_TEST(sfrl_allocator_share_without_size_and_ptr) { Allocator* allocator = get_allocator(); int max_allc_num = 1000; size_t id[max_allc_num]; - void* addr[max_allc_num]; size_t temp[max_allc_num]; std::vector tasks; tasks.push_back(TestTask(20000000, 100, 100, 400.0)); @@ -108,9 +104,8 @@ JIT_TEST(sfrl_allocator_share_without_size_and_ptr) { if (rand() % 10 != 0 && j > 0) { id[j] = id[rand() % j]; allocator->share_with(0, id[j]); - addr[j] = addr[id[j]]; } else { - addr[j] = allocator->alloc(tasks[i].size, id[j]); + allocator->alloc(tasks[i].size, id[j]); } } for (size_t j = 0; j < tasks[i].times2; ++j) { diff --git a/src/utils/log.cc b/src/utils/log.cc index 64ea0369..27e4a909 100644 --- a/src/utils/log.cc +++ b/src/utils/log.cc @@ -189,8 +189,9 @@ static int _pid = getpid(); void segfault_sigaction(int signal, siginfo_t *si, void *arg) { if (signal == SIGINT) { - if (_pid == getpid()) + if (_pid == getpid()) { LOGe << "Caught SIGINT, quick exit"; + } exited = true; std::quick_exit(1); }
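
The updated convolution ops all share one output-shape rule, now evaluated separately for height and width. Below is a minimal Python sketch of that rule (illustrative only; conv_out_size is a hypothetical helper written for this note, not a function added by the patch):

    def conv_out_size(x, k, stride, padding, dilation):
        # Same arithmetic as yh/yw in infer_shape():
        # (x + 2*padding - dilation*(k-1) - 1) // stride + 1
        return (x + padding * 2 - k * dilation + dilation - 1) // stride + 1

    # With the per-dimension parameters this patch introduces, the two axes may differ:
    oh = conv_out_size(224, 7, stride=2, padding=3, dilation=1)   # -> 112
    ow = conv_out_size(224, 7, stride=1, padding=0, dilation=1)   # -> 218

This mirrors the updated test calls such as jt.mkl_ops.mkl_conv(a, b, 2, 2, 3, 3), where stride and padding are now passed as separate (h, w) scalars instead of a single value per attribute.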