group convolution support

bennyguo 2020-04-29 02:22:22 +08:00
parent 2153b65856
commit d792ba55b3
12 changed files with 587 additions and 47 deletions
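For context, a minimal usage sketch of what this commit enables, assuming the groups argument on the Conv module changed in the nn.py hunk below (shapes here are illustrative):

import jittor as jt
from jittor import nn

# grouped convolution: 8 input channels, 16 output channels, 4 groups,
# so each group maps 2 input channels to 4 output channels
conv = nn.Conv(8, 16, kernel_size=3, stride=1, padding=1, groups=4)
print(conv.weight.shape)           # [16, 8//4, 3, 3] == [16, 2, 3, 3]
x = jt.random([10, 8, 32, 32])     # NCHW input
y = conv(x)                        # output shape [10, 16, 32, 32]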

.gitignore
View File

@@ -15,6 +15,7 @@ test.py
extern/mkl/mkldnn_lnx*/*
data/
build/
venv/
*.md
!*.src.md
!README.md

View File

@@ -41,8 +41,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@@ -55,7 +55,7 @@ void CudnnConvBackwardWOp::infer_shape() {
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
get_shape(x, "abcd", xformat, xn, xc, xh, xw);
get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
wco = yc, wci = xc;
wco = yc, wci = xc / groups;
wh = kernel_size;
ww = kernel_size;
set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
@@ -96,6 +96,7 @@ void CudnnConvBackwardWOp::jit_run() {
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
int dimX[] = {
(int)x->shape[findc("@XFORMAT", 'a')], // n

View File

@@ -13,10 +13,10 @@ namespace jittor {
struct CudnnConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kernel_size, stride, padding, dilation;
int kernel_size, stride, padding, dilation, groups;
string xformat, wformat, yformat;
CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_w"; }
void infer_shape() override;

View File

@@ -43,8 +43,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@@ -57,7 +57,7 @@ void CudnnConvBackwardXOp::infer_shape() {
int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw;
get_shape(w, "oihw", wformat, wco, wci, wh, ww);
get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
xn = yn, xc = wci;
xn = yn, xc = wci * groups;
set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
}
@@ -96,6 +96,7 @@ void CudnnConvBackwardXOp::jit_run() {
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
int dimX[] = {
(int)x->shape[findc("@XFORMAT", 'a')], // n

View File

@@ -13,10 +13,10 @@ namespace jittor {
struct CudnnConvBackwardXOp : Op {
Var* w, * dy, * dx;
int xh, xw, stride, padding, dilation;
int xh, xw, stride, padding, dilation, groups;
string xformat, wformat, yformat;
CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_x"; }
void infer_shape() override;

View File

@@ -41,8 +41,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
: x(x), w(w), stride(stride), padding(padding), dilation(dilation),
CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@@ -57,7 +57,7 @@ void CudnnConvOp::infer_shape() {
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
get_shape(x, "abcd", xformat, xn, xc, xh, xw);
get_shape(w, "oihw", wformat, wco, wci, wh, ww);
ASSERTop(wci,==,xc);
ASSERTop(wci * groups,==,xc);
yn = xn, yc = wco;
yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;
@@ -97,6 +97,8 @@ void CudnnConvOp::jit_run() {
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
int dimX[] = {
(int)x->shape[findc("@XFORMAT", 'a')], // n
@@ -240,7 +242,7 @@ void CudnnConvOp::jit_run() {
LOGw << "forward_ algorithm cache is full";
}
}
// TODO: wrap work space
void *workSpace = 0;
size_t workSpaceSize;

View File

@@ -10,10 +10,10 @@ namespace jittor {
struct CudnnConvOp : Op {
Var* x, * w, * y;
int stride, padding, dilation;
int stride, padding, dilation, groups;
string xformat, wformat, yformat;
/* CudnnConvOp: xformat abcd represents nchw */
CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="");
CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "cudnn_conv"; }
void infer_shape() override;

View File

@@ -275,8 +275,6 @@ Softmax = jt.make_module(softmax, 2)
class Conv(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
assert groups == 1
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
@@ -284,32 +282,70 @@ class Conv(Module):
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
Kh, Kw = self.kernel_size
assert groups==1, "Group conv not supported yet."
self.weight = init.relu_invariant_gauss([out_channels, in_channels, Kh, Kw], dtype="float", mode="fan_out")
self.groups = groups
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
if bias:
self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
else:
self.bias = None
def execute(self, x):
N,C,H,W = x.shape
Kh, Kw = self.kernel_size
assert C==self.in_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
'i0', # Nid
'i2', # Cid
f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
])
ww = self.weight.broadcast(xx.shape, [0,3,4])
yy = xx*ww
y = yy.sum([2,5,6]) # Kc, Kh, Kw
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
if self.groups == 1:
N,C,H,W = x.shape
Kh, Kw = self.kernel_size
assert C==self.in_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
'i0', # Nid
'i2', # Cid
f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
])
ww = self.weight.broadcast(xx.shape, [0,3,4])
yy = xx*ww
y = yy.sum([2,5,6]) # Kc, Kh, Kw
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
else:
N,C,H,W = x.shape
Kh, Kw = self.kernel_size
G = self.groups
CpG = C // G # channels per group
assert C==self.in_channels
oc = self.out_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
'i0', # Nid
f'i1*{CpG}+i3', # Gid
f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
])
# w: [oc, CpG, Kh, Kw]
ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
f'i1*{oc//G}+i2',
'i3',
'i6',
'i7'
])
yy = xx*ww
y = yy.reindex_reduce('add', [N, oc, oh, ow], [
'i0',
f'i1*{oc//G}+i2',
'i4',
'i5'
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
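The grouped branch above computes the same result as splitting the input channels into groups slices and convolving each slice with its own block of filters. A minimal NumPy sketch of that semantics (group_conv2d_ref is a hypothetical helper written only to illustrate; it is not part of the patch):

import numpy as np

def group_conv2d_ref(x, w, stride=1, padding=0, dilation=1, groups=1):
    # x: [N, C, H, W], w: [OC, C//groups, Kh, Kw], same layout as self.weight above
    N, C, H, W = x.shape
    OC, CpG, Kh, Kw = w.shape
    assert C == CpG * groups and OC % groups == 0
    OCg = OC // groups
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    oh = (H + padding * 2 - Kh * dilation + dilation - 1) // stride + 1
    ow = (W + padding * 2 - Kw * dilation + dilation - 1) // stride + 1
    y = np.zeros((N, OC, oh, ow), dtype=x.dtype)
    for g in range(groups):
        xs = xp[:, g * CpG:(g + 1) * CpG]      # this group's input channels
        ws = w[g * OCg:(g + 1) * OCg]          # this group's filters
        for i in range(oh):
            for j in range(ow):
                patch = xs[:, :,
                           i * stride:i * stride + dilation * (Kh - 1) + 1:dilation,
                           j * stride:j * stride + dilation * (Kw - 1) + 1:dilation]
                # contract over (channels-per-group, Kh, Kw) -> [N, OC//groups]
                y[:, g * OCg:(g + 1) * OCg, i, j] = np.tensordot(
                    patch, ws, axes=([1, 2, 3], [1, 2, 3]))
    return y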
class ConvTranspose(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \

View File

@@ -0,0 +1,133 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import os
import numpy as np
from jittor import compile_extern, init
# TODO: compare with pytorch
from jittor.test.test_log import find_log_with_re
if compile_extern.has_cuda:
from jittor.compile_extern import cublas_ops, cudnn_ops
else:
cublas_ops = cudnn_ops = None
def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=1, dilation=1, groups=1, init_method=None, w_=None):
N,C,H,W = x.shape
Kh, Kw = kernel_size, kernel_size
G = groups
CpG = C // G # channels per group
padding = (padding, padding)
dilation = (dilation, dilation)
stride = (stride, stride)
assert C==in_planes
oc = out_planes
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
if w_ is None:
if init_method==None:
w = jt.make_var([oc, C // G, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
else:
w = jt.make_var([oc, C // G, Kh, Kw], init=init_method)
else:
w = w_
xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
'i0', # Nid
f'i1*{CpG}+i3', # Gid
f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
])
# w: [oc, CpG, Kh, Kw]
ww = w.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
f'i1*{oc//G}+i2',
'i3',
'i6',
'i7'
])
yy = xx*ww
y = yy.reindex_reduce('add', [N, oc, oh, ow], [
'i0',
f'i1*{oc//G}+i2',
'i4',
'i5'
])
return y
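The oh/ow expression in conv_nchw is the usual convolution output-size formula; a quick numeric check of the identity, using the 40x50 test shapes below (illustrative only):

# H=40, padding=1, kernel 5, dilation 2, stride 1
H, padding, Kh, dilation, stride = 40, 1, 5, 2, 1
oh = (H + padding * 2 - Kh * dilation + dilation - 1) // stride + 1
assert oh == (H + 2 * padding - dilation * (Kh - 1) - 1) // stride + 1 == 34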
def test_nchw(x, w, stride, padding, dilation, groups):
_, in_planes, _, _ = x.shape
out_planes, _, kernel_size, _ = w.shape
return conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=stride, dilation=dilation, groups=groups, w_=w)
def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
assert nhwc == 0
test_func = test_nchw
# only check cudnn
with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
log_v=10, log_vprefix="conv_tuner.cc=1000"
) as raw_log:
x = jt.random(xshape)
w = jt.random(wshape)
y = test_func(x, w, stride, padding, dilation, groups)
y.sync()
with jt.flag_scope(use_cuda=0, enable_tuner=0):
cy = test_func(x, w, stride, padding, dilation, groups)
cy.sync()
assert np.allclose(y.data, cy.data)
def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
assert nhwc == 0
test_func = test_nchw
# only check cudnn
with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
log_v=10, log_vprefix="conv_tuner.cc=1000"
) as raw_log:
x = jt.random(xshape)
w = jt.random(wshape)
y = test_func(x, w, stride, padding, dilation, groups)
dx, dw = jt.grad(y, [x, w])
jt.sync([y, dx, dw])
with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}):
cy = test_func(x, w, stride, padding, dilation, groups)
cdx, cdw = jt.grad(cy, [x, w])
jt.sync([cy, cdx, cdw])
assert np.allclose(y.data, cy.data)
assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(dx.data - cdx.data).max())
class TestGroupConvTuner(unittest.TestCase):
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
def test_forward_cuda(self):
for groups in [2, 4, 8]:
check_forward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False)
check_forward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False)
check_forward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False)
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
def test_backward_cuda(self):
for groups in [2, 4, 8]:
check_backward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False)
check_backward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False)
check_backward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False)
if __name__ == "__main__":
unittest.main()

View File

@@ -18,6 +18,8 @@
#include "opt/expr.h"
#include "ops/op_register.h"
#include <algorithm>
namespace jittor {
using namespace expr;
@@ -124,6 +126,31 @@ struct OpInspector {
}
}
// get the indices of the three set bits of a binary mask (fails unless exactly three are set)
void get_id(uint64 m, int& i, int& j, int& k) {
if (m==0) failed=1;
else {
i=j=0;
while (!(m&1)) i++,m>>=1;
if (m<=1) {
failed=1;
return;
}
j=i+1,m>>=1;
while (!(m&1)) j++,m>>=1;
if (m<=1) {
failed=1;
return;
}
k=j+1, m>>=1;
while (!(m&1)) k++,m>>=1;
if (m!=1) {
failed=1;
return;
}
}
}
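The three-argument get_id added above reads off the bit positions of the mask and marks failure unless exactly three bits are set; an equivalent sketch in Python (illustrative only):

def get_id3(m):
    # positions of set bits, valid only when exactly three bits are set
    bits = [i for i in range(64) if (m >> i) & 1]
    return tuple(bits) if len(bits) == 3 else None  # None plays the role of failed=1

assert get_id3(0b10110) == (1, 2, 4)
assert get_id3(0b00110) is None   # two set bits -> failure, like the m<=1 early returns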
bool check_overlap(const vector<int>& v) {
uint64 sum=0;
for (auto a : v) {
@@ -141,12 +168,17 @@
}
if (check_overlap(order))
return "";
for (uint i=0; i<order.size(); i++) {
if (order[i]>=(int)new_fmt.size()) {
vector<pair<int, int>> order_;
for (uint i = 0; i < order.size(); i++) {
order_.push_back(pair<int, int>(order[i], i));
}
sort(order_.begin(), order_.end());
for (uint i=0; i<order_.size(); i++) {
if (order_[i].second>=(int)new_fmt.size()) {
failed = 1;
return "";
}
new_fmt[order[i]] = fmt[i];
new_fmt[order_[i].second] = fmt[i];
}
return new_fmt;
}
@@ -269,8 +301,8 @@ void ConvTuner::forwardTune(FusedOp* fop) {
if (!has_op(relay_conv_name))
continue;
auto make_conv = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, string, string, string>();
auto rvar = make_conv(x, w, stride, padding, dilation, xformat, wformat, yformat);
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
if (rid>=0) {
auto srid = "relay"+S(rid);
@@ -448,8 +480,8 @@ void ConvTuner::backwardTune(FusedOp* fop) {
auto make_conv_w = get_op_info(
fop->flags.get(NodeFlags::_cpu) ?
"mkl_conv_backward_w" : "cudnn_conv_backward_w"
).get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, xformat, wformat, yformat);
).get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar_w, dw}});
if (rid>=0) {
auto srid = "relay"+S(rid);
@@ -462,8 +494,8 @@ void ConvTuner::backwardTune(FusedOp* fop) {
auto make_conv_x = get_op_info(
fop->flags.get(NodeFlags::_cpu) ?
"mkl_conv_backward_x" : "cudnn_conv_backward_x"
).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, string, string, string>();
auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, xformat, wformat, yformat);
).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar_x, dx}});
if (rid>=0) {
auto srid = "relay"+S(rid);
@@ -482,4 +514,330 @@ void ConvTuner::run(PassManager* pm, TunerManager* tm) {
backwardTune(fop);
}
void GroupConvTuner::forwardTune(FusedOp* fop) {
LOGvvvv << "tune group conv";
for (Op* op : fop->ops) {
if (op->name_ex()=="reindex_reduce.add") {
auto rop = (ReindexReduceOp*)op;
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->y->input()->tflag==op->tflag))
continue;
auto bop = (BinaryOp*)(rop->y->input());
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
if (!(bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="reindex")) return;
auto riop1 = (ReindexOp*)(bop->x->input());
auto riop2 = (ReindexOp*)(bop->y->input());
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
OpInspector xoi(riop1);
OpInspector woi(riop2);
// determine which is which (since both are ReindexOp)
if (xoi.mm[0] == -1 && woi.mm[0] == 0) {
std::swap(xoi, woi);
}
OpInspector yoi(rop);
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
int zn, zg, zci, zco, zh, zw, zwh, zww;
zn = zg = zci = zco = zh = zw = zwh = zww = 0;
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
LOGvvvv << "zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
if (xoi.failed) continue;
xn = xoi.mm[zn];
xc = xoi.mm[zci];
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
auto xformat = xoi.format("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
auto wformat = xoi.format("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
// cuda doesn't support "iohw" format
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
if (xoi.failed) continue;
std::stringstream ss;
// i@zh*stride+i@zwh*dilation+padding
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
auto expr_h = expr::make(ss.str());
ss.str("");
ss << "i" << zw << "*stride+i" << zww << "*dilation+padding";
auto expr_w = expr::make(ss.str());
vector<unique_ptr<Expr>> rh, rw;
auto src_h = expr::make(riop1->indexes[xh]);
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
LOGvvvv << "Expr not match" << src_h << expr_h;
continue;
}
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
auto src_w = expr::make(riop1->indexes[xw]);
if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
return;
if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
int stride_h = rh[0]->as_int();
int padding_h = -rh[1]->as_int();
int dilation_h = rh[2]->as_int();
int stride_w = rw[0]->as_int();
int padding_w = -rw[1]->as_int();
int dilation_w = rw[2]->as_int();
if (dilation_h < 1 || dilation_w < 1) continue;
if (stride_h!=stride_w || padding_h!=padding_w || dilation_h!=dilation_w) {
LOGvvvv << "cannot relay different stride and padding between h and w"
<< stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w;
continue;
}
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
int stride = stride_h;
int padding = padding_h;
int dilation = dilation_h;
Var* x = riop1->x;
Var* w = riop2->x;
int oh = (x->shape[xh]-w->shape[wh]*dilation_h+dilation_h-1+padding_h*2)/stride_h+1;
int ow = (x->shape[xw]-w->shape[ww]*dilation_w+dilation_w-1+padding_w*2)/stride_w+1;
if (oh != rop->x->shape[yh] || ow != rop->x->shape[yw]) continue;
int groups = x->shape[xc] / w->shape[wci];
LOGvvvv << "groups: " << groups;
if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
LOGi << "group conv is not supported by mkl";
continue;
}
string relay_conv_name = "cudnn_conv";
if (!has_op(relay_conv_name))
continue;
auto make_conv = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
auto rvar = make_conv(x, w, stride, padding, dilation, groups, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->x}});
if (rid>=0) {
auto srid = "relay"+S(rid);
add_candidate(srid, 1);
add_candidate(srid, 0);
confidence = 20;
}
}
}
}
void GroupConvTuner::backwardTune(FusedOp* fop) {
for (Op* op : fop->ops) {
int bo=0;
Var *x=NULL, *y=NULL, *w=NULL;
Var *dw=NULL, *dx=NULL;
int height=0,width=0,kernel_size=0,stride=0, padding=0, dilation=1, groups=1;
string xformat, yformat, wformat;
if (op->name_ex() == "reindex_reduce.add") {
auto rop = (ReindexReduceOp*)op;
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->y->input()->tflag==op->tflag))
continue;
auto bop = (BinaryOp*)(rop->y->input());
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
if (!(bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="reindex")) return;
auto riop1 = (ReindexOp*)(bop->x->input());
auto riop2 = (ReindexOp*)(bop->y->input());
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
OpInspector oi1(riop1);
OpInspector oi2(riop2);
if (oi1.mm[0] == 0 && oi2.mm[0] == 0) {
// dw
// x.mm [0,1,-1,1,2,3,2,3] y.mm [0,1,1,-1,2,3,-1,-1] w.mm [-1,0,0,1,-1,-1,2,3]
OpInspector xoi(oi1.mm[2] == -1 ? riop1 : riop2);
OpInspector yoi(oi1.mm[2] == -1 ? riop2 : riop1);
OpInspector woi(rop);
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
int zn, zg, zci, zco, zh, zw, zwh, zww;
zn = zg = zci = zco = zh = zw = zwh = zww = 0;
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
LOGvvvv << "group conv backward dw zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
if (xoi.failed) continue;
xn = xoi.mm[zn];
xc = xoi.mm[zci];
xh = xoi.mm[zh];
xw = xoi.mm[zw];
xformat = xoi.format("abcd", {xn, xc, xh, xw});
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
wformat = xoi.format("iohw", {wci, wco, wh, ww});
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
yformat = xoi.format("abcd", {yn, yc, yh, yw});
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
// cuda doesn't support "iohw" format
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
if (xoi.failed) continue;
std::stringstream ss;
// i@zh*stride+i@zwh*dilation+padding
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
auto expr_h = expr::make(ss.str());
vector<unique_ptr<Expr>> rh;
auto src_h = expr::make(riop1->indexes[xh]);
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
LOGvvvv << "Expr not match" << src_h << expr_h;
continue;
}
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;
dw = rop->x;
stride = rh[0]->as_int();
padding = -rh[1]->as_int();
dilation = rh[2]->as_int();
kernel_size = dw->shape[wformat.find("h")];
groups = (oi1.mm[2] == -1 ? riop1 : riop2)->x->shape[xc] / dw->shape[wci];
if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
LOGi << "group conv is not supported by mkl";
continue;
}
LOGvvvv << stride << padding << dilation << kernel_size << groups;
x = (oi1.mm[2] == -1 ? riop1 : riop2)->x;
y = (oi1.mm[2] == -1 ? riop2 : riop1)->x;
bo++;
} else {
// dx
OpInspector woi(oi1.mm[0] == -1 ? riop1 : riop2);
OpInspector yoi(oi1.mm[0] == -1 ? riop2 : riop1);
OpInspector xoi(rop);
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
int zn, zg, zci, zco, zh, zw, zwh, zww;
zn = zg = zci = zco = zh = zw = zwh = zww = 0;
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
LOGvvvv << "group conv backward dx zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
if (xoi.failed) continue;
xn = xoi.mm[zn];
xc = xoi.mm[zci];
xh = xoi.mm[zh];
xw = xoi.mm[zw];
xformat = xoi.format("abcd", {xn, xc, xh, xw});
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
wformat = xoi.format("iohw", {wci, wco, wh, ww});
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
yformat = xoi.format("abcd", {yn, yc, yh, yw});
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
// cuda doesn't support "iohw" format
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
if (xoi.failed) continue;
std::stringstream ss;
// i@zh*stride+i@zwh*dilation+padding
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
auto expr_h = expr::make(ss.str());
vector<unique_ptr<Expr>> rh;
auto src_h = expr::make(rop->indexes[xh]);
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
LOGvvvv << "Expr not match" << src_h << expr_h;
continue;
}
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;
dx = rop->x;
stride = rh[0]->as_int();
padding = -rh[1]->as_int();
dilation = rh[2]->as_int();
height = dx->shape[xformat.find("c")];
width = dx->shape[xformat.find("d")];
groups = dx->shape[xc] / (oi1.mm[0] == -1 ? riop1 : riop2)->x->shape[wci];
if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
LOGi << "group conv is not supported by mkl";
continue;
}
LOGvvvv << stride << padding << dilation << height << width << groups;
w = (oi1.mm[0] == -1 ? riop1 : riop2)->x;
y = (oi1.mm[0] == -1 ? riop2 : riop1)->x;
bo+=2;
}
}
// TODO: CUDA only supports nchw (abcd)
if (fop->flags.get(NodeFlags::_cuda) && (xformat != "abcd" || yformat != "abcd"))
continue;
if (bo&1) {
auto make_conv_w = get_op_info(
"cudnn_conv_backward_w"
).get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, groups, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar_w, dw}});
if (rid>=0) {
auto srid = "relay"+S(rid);
add_candidate(srid, 1);
add_candidate(srid, 0);
confidence = 20;
}
}
if (bo&2) {
auto make_conv_x = get_op_info(
"cudnn_conv_backward_x"
).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, groups, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar_x, dx}});
if (rid>=0) {
auto srid = "relay"+S(rid);
add_candidate(srid, 1);
add_candidate(srid, 0);
confidence = 20;
}
}
}
}
void GroupConvTuner::run(PassManager* pm, TunerManager* tm) {
FusedOp* fop=tm->oc->op;
forwardTune(fop);
backwardTune(fop);
}
}

View File

@@ -20,4 +20,11 @@ struct ConvTuner : Tuner {
void run(PassManager* pm, TunerManager* tm);
};
struct GroupConvTuner : Tuner {
GroupConvTuner() : Tuner("group_conv") {}
void forwardTune(FusedOp* fop);
void backwardTune(FusedOp* fop);
void run(PassManager* pm, TunerManager* tm);
};
}

View File

@@ -43,6 +43,7 @@ string TunerManager::tune() {
run_tuner<ReduceTuner>(&pm);
run_tuner<MatmulTuner>(&pm);
run_tuner<ConvTuner>(&pm);
run_tuner<GroupConvTuner>(&pm);
// use the best tuner if it is confident enough
if (best_tuner && best_tuner->confidence) {