mirror of https://github.com/Jittor/Jittor
group convolution support
parent 2153b65856
commit d792ba55b3
@@ -15,6 +15,7 @@ test.py
extern/mkl/mkldnn_lnx*/*
data/
build/
venv/
*.md
!*.src.md
!README.md
@@ -41,8 +41,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
        shape[0], shape[1], shape[2], shape[3]));
}

CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);

@@ -55,7 +55,7 @@ void CudnnConvBackwardWOp::infer_shape() {
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    wco = yc, wci = xc;
    wco = yc, wci = xc / groups;
    wh = kernel_size;
    ww = kernel_size;
    set_shape(dw, "oihw", wformat, wco, wci, wh, ww);

@@ -96,6 +96,7 @@ void CudnnConvBackwardWOp::jit_run() {
    checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
    checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
    checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));

    int dimX[] = {
        (int)x->shape[findc("@XFORMAT", 'a')], // n
@@ -13,10 +13,10 @@ namespace jittor {

struct CudnnConvBackwardWOp : Op {
    Var* x, * dy, * dw;
    int kernel_size, stride, padding, dilation;
    int kernel_size, stride, padding, dilation, groups;
    string xformat, wformat, yformat;

    CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
    CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "cudnn_conv_backward_w"; }
    void infer_shape() override;
@@ -43,8 +43,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
        shape[0], shape[1], shape[2], shape[3]));
}

CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);

@@ -57,7 +57,7 @@ void CudnnConvBackwardXOp::infer_shape() {
    int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    xn = yn, xc = wci;
    xn = yn, xc = wci * groups;
    set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
}

@@ -96,6 +96,7 @@ void CudnnConvBackwardXOp::jit_run() {
    checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
    checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
    checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));

    int dimX[] = {
        (int)x->shape[findc("@XFORMAT", 'a')], // n
@@ -13,10 +13,10 @@ namespace jittor {

struct CudnnConvBackwardXOp : Op {
    Var* w, * dy, * dx;
    int xh, xw, stride, padding, dilation;
    int xh, xw, stride, padding, dilation, groups;
    string xformat, wformat, yformat;

    CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
    CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "cudnn_conv_backward_x"; }
    void infer_shape() override;
@@ -41,8 +41,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
        shape[0], shape[1], shape[2], shape[3]));
}

CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : x(x), w(w), stride(stride), padding(padding), dilation(dilation),
CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);

@@ -57,7 +57,7 @@ void CudnnConvOp::infer_shape() {
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    ASSERTop(wci,==,xc);
    ASSERTop(wci * groups,==,xc);
    yn = xn, yc = wco;
    yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
    yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;

@@ -97,6 +97,8 @@ void CudnnConvOp::jit_run() {
    checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
    checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
    checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));


    int dimX[] = {
        (int)x->shape[findc("@XFORMAT", 'a')], // n

@@ -240,7 +242,7 @@ void CudnnConvOp::jit_run() {
            LOGw << "forward_ algorithm cache is full";
        }
    }


    // TODO: warp work space
    void *workSpace = 0;
    size_t workSpaceSize;
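Aside (not part of the diff): the group handling in the ops above is mostly shape bookkeeping. The weight keeps out_channels output channels but only in_channels/groups input channels, so the forward op asserts wci * groups == xc and the backward ops divide or multiply by groups accordingly. A minimal Python sketch of the same arithmetic, using a hypothetical helper name that only mirrors the formulas in infer_shape:

def conv_shapes(xn, xc, xh, xw, out_channels, kernel_size, stride, padding, dilation, groups):
    # Mirrors CudnnConvOp::infer_shape: the per-group weight has xc // groups input channels.
    assert xc % groups == 0 and out_channels % groups == 0
    wci = xc // groups  # so wci * groups == xc, as asserted in infer_shape
    yh = (xh + padding*2 - kernel_size*dilation + dilation - 1) // stride + 1
    yw = (xw + padding*2 - kernel_size*dilation + dilation - 1) // stride + 1
    w_shape = (out_channels, wci, kernel_size, kernel_size)  # "oihw"
    y_shape = (xn, out_channels, yh, yw)                     # "abcd" == nchw
    return w_shape, y_shape

# Example: a [10, 8, 100, 100] input, 16 output channels, 3x3 kernel, stride 1,
# padding 0, dilation 1, groups=4 gives weight [16, 2, 3, 3] and output [10, 16, 98, 98].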
@@ -10,10 +10,10 @@ namespace jittor {

struct CudnnConvOp : Op {
    Var* x, * w, * y;
    int stride, padding, dilation;
    int stride, padding, dilation, groups;
    string xformat, wformat, yformat;
    /* CudnnConvOp: xformat abcd represents nchw */
    CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="");
    CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");

    const char* name() const override { return "cudnn_conv"; }
    void infer_shape() override;
@@ -275,8 +275,6 @@ Softmax = jt.make_module(softmax, 2)

class Conv(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        assert groups == 1

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)

@@ -284,32 +282,70 @@ class Conv(Module):
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        Kh, Kw = self.kernel_size
        assert groups==1, "Group conv not supported yet."
        self.weight = init.relu_invariant_gauss([out_channels, in_channels, Kh, Kw], dtype="float", mode="fan_out")
        self.groups = groups
        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'

        self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
        if bias:
            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
        else:
            self.bias = None

    def execute(self, x):
        N,C,H,W = x.shape
        Kh, Kw = self.kernel_size
        assert C==self.in_channels
        oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
        ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
        xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
            'i0', # Nid
            'i2', # Cid
            f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
            f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
        ])
        ww = self.weight.broadcast(xx.shape, [0,3,4])
        yy = xx*ww
        y = yy.sum([2,5,6]) # Kc, Kh, Kw
        if self.bias is not None:
            b = self.bias.broadcast(y.shape, [0,2,3])
            y = y + b
        return y
        if self.groups == 1:
            N,C,H,W = x.shape
            Kh, Kw = self.kernel_size
            assert C==self.in_channels
            oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
            ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
            xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
                'i0', # Nid
                'i2', # Cid
                f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
                f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
            ])
            ww = self.weight.broadcast(xx.shape, [0,3,4])
            yy = xx*ww
            y = yy.sum([2,5,6]) # Kc, Kh, Kw
            if self.bias is not None:
                b = self.bias.broadcast(y.shape, [0,2,3])
                y = y + b
            return y
        else:
            N,C,H,W = x.shape
            Kh, Kw = self.kernel_size
            G = self.groups
            CpG = C // G # channels per group
            assert C==self.in_channels
            oc = self.out_channels
            oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
            ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
            xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
                'i0', # Nid
                f'i1*{CpG}+i3', # Gid
                f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
                f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
            ])
            # w: [oc, CpG, Kh, Kw]
            ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
                f'i1*{oc//G}+i2',
                'i3',
                'i6',
                'i7'
            ])
            yy = xx*ww
            y = yy.reindex_reduce('add', [N, oc, oh, ow], [
                'i0',
                f'i1*{oc//G}+i2',
                'i4',
                'i5'
            ])
            if self.bias is not None:
                b = self.bias.broadcast(y.shape, [0,2,3])
                y = y + b
            return y


class ConvTranspose(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
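For context (not part of the diff): a short usage sketch of the grouped Conv module defined above. With the weight laid out as [out_channels, in_channels//groups, Kh, Kw], an 8-channel input passed through Conv(8, 16, 3, padding=1, groups=4) uses a [16, 2, 3, 3] weight and keeps the spatial size; the shapes below are illustrative only.

import jittor as jt
from jittor import nn

x = jt.random([2, 8, 32, 32])              # NCHW input
conv = nn.Conv(8, 16, 3, padding=1, groups=4)
y = conv(x)
print(conv.weight.shape)                   # expected: [16, 2, 3, 3]
print(y.shape)                             # expected: [2, 16, 32, 32]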
@@ -0,0 +1,133 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import os
import numpy as np
from jittor import compile_extern
# TODO: compare with pytorch

from jittor.test.test_log import find_log_with_re
if compile_extern.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None


def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=1, dilation=1, groups=1, init_method=None, w_=None):
    N,C,H,W = x.shape
    Kh, Kw = kernel_size, kernel_size
    G = groups
    CpG = C // G # channels per group
    padding = (padding, padding)
    dilation = (dilation, dilation)
    stride = (stride, stride)
    assert C==in_planes
    oc = out_planes
    oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
    ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1

    if w_ is None:
        if init_method==None:
            w = jt.make_var([oc, C // G, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
        else:
            w = jt.make_var([oc, C // G, Kh, Kw], init=init_method)
    else:
        w = w_

    xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
        'i0', # Nid
        f'i1*{CpG}+i3', # Gid
        f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
        f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
    ])
    # w: [oc, CpG, Kh, Kw]
    ww = w.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
        f'i1*{oc//G}+i2',
        'i3',
        'i6',
        'i7'
    ])

    yy = xx*ww
    y = yy.reindex_reduce('add', [N, oc, oh, ow], [
        'i0',
        f'i1*{oc//G}+i2',
        'i4',
        'i5'
    ])
    return y


def test_nchw(x, w, stride, padding, dilation, groups):
    _, in_planes, _, _ = x.shape
    out_planes, _, kernel_size, _ = w.shape
    return conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=stride, dilation=dilation, groups=groups, w_=w)


def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw

    # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
        log_v=10, log_vprefix="conv_tuner.cc=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cy.sync()

    assert np.allclose(y.data, cy.data)


def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw

    # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
        log_v=10, log_vprefix="conv_tuner.cc=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        dx, dw = jt.grad(y, [x, w])
        jt.sync([y, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cdx, cdw = jt.grad(cy, [x, w])
        jt.sync([cy, cdx, cdw])

    assert np.allclose(y.data, cy.data)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
    assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(dx.data - cdx.data).max())


class TestGroupConvTuner(unittest.TestCase):
    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
    def test_forward_cuda(self):
        for groups in [2, 4, 8]:
            check_forward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False)
            check_forward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False)
            check_forward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False)

    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
    def test_backward_cuda(self):
        for groups in [2, 4, 8]:
            check_backward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False)
            check_backward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False)
            check_backward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False)


if __name__ == "__main__":
    unittest.main()
@@ -18,6 +18,8 @@
#include "opt/expr.h"
#include "ops/op_register.h"

#include <algorithm>

namespace jittor {

using namespace expr;
@@ -124,6 +126,31 @@ struct OpInspector {
        }
    }

    // get last three index of binary mask
    void get_id(uint64 m, int& i, int& j, int& k) {
        if (m==0) failed=1;
        else {
            i=j=0;
            while (!(m&1)) i++,m>>=1;
            if (m<=1) {
                failed=1;
                return;
            }
            j=i+1,m>>=1;
            while (!(m&1)) j++,m>>=1;
            if (m<=1) {
                failed=1;
                return;
            }
            k=j+1, m>>=1;
            while (!(m&1)) k++,m>>=1;
            if (m!=1) {
                failed=1;
                return;
            }
        }
    }

    bool check_overlap(const vector<int>& v) {
        uint64 sum=0;
        for (auto a : v) {
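Aside (not part of the diff): the new three-index get_id overload above succeeds only when the mask has exactly three set bits, returning their positions and setting failed=1 otherwise. A minimal Python sketch of the same check, using a hypothetical helper name:

def get_id3(m: int):
    # Collect the positions of all set bits; None plays the role of failed=1.
    idx = [i for i in range(m.bit_length()) if (m >> i) & 1]
    return tuple(idx) if len(idx) == 3 else None

assert get_id3(0b10110) == (1, 2, 4)
assert get_id3(0b10010) is None   # only two bits set -> failed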
@@ -141,12 +168,17 @@ struct OpInspector {
        }
        if (check_overlap(order))
            return "";
        for (uint i=0; i<order.size(); i++) {
            if (order[i]>=(int)new_fmt.size()) {
        vector<pair<int, int>> order_;
        for (uint i = 0; i < order.size(); i++) {
            order_.push_back(pair<int, int>(order[i], i));
        }
        sort(order_.begin(), order_.end());
        for (uint i=0; i<order_.size(); i++) {
            if (order_[i].second>=(int)new_fmt.size()) {
                failed = 1;
                return "";
            }
            new_fmt[order[i]] = fmt[i];
            new_fmt[order_[i].second] = fmt[i];
        }
        return new_fmt;
    }
@@ -269,8 +301,8 @@ void ConvTuner::forwardTune(FusedOp* fop) {
        if (!has_op(relay_conv_name))
            continue;
        auto make_conv = get_op_info(relay_conv_name)
            .get_constructor<VarPtr, Var*, Var*, int, int, int, string, string, string>();
        auto rvar = make_conv(x, w, stride, padding, dilation, xformat, wformat, yformat);
            .get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
        auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
        auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
        if (rid>=0) {
            auto srid = "relay"+S(rid);
@@ -448,8 +480,8 @@ void ConvTuner::backwardTune(FusedOp* fop) {
            auto make_conv_w = get_op_info(
                fop->flags.get(NodeFlags::_cpu) ?
                "mkl_conv_backward_w" : "cudnn_conv_backward_w"
            ).get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
            auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, xformat, wformat, yformat);
            ).get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
            auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, 1, xformat, wformat, yformat);
            auto rid = fop->context->vrm.add_relay_group({{rvar_w, dw}});
            if (rid>=0) {
                auto srid = "relay"+S(rid);
@@ -462,8 +494,8 @@ void ConvTuner::backwardTune(FusedOp* fop) {
            auto make_conv_x = get_op_info(
                fop->flags.get(NodeFlags::_cpu) ?
                "mkl_conv_backward_x" : "cudnn_conv_backward_x"
            ).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, string, string, string>();
            auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, xformat, wformat, yformat);
            ).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
            auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, 1, xformat, wformat, yformat);
            auto rid = fop->context->vrm.add_relay_group({{rvar_x, dx}});
            if (rid>=0) {
                auto srid = "relay"+S(rid);
@@ -482,4 +514,330 @@ void ConvTuner::run(PassManager* pm, TunerManager* tm) {
    backwardTune(fop);
}

void GroupConvTuner::forwardTune(FusedOp* fop) {
    LOGvvvv << "tune group conv";
    for (Op* op : fop->ops) {
        if (op->name_ex()=="reindex_reduce.add") {
            auto rop = (ReindexReduceOp*)op;
            if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->y->input()->tflag==op->tflag))
                continue;
            auto bop = (BinaryOp*)(rop->y->input());

            if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
            if (!(bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="reindex")) return;
            auto riop1 = (ReindexOp*)(bop->x->input());
            auto riop2 = (ReindexOp*)(bop->y->input());
            LOGvvvv << "conv like op" << fop << fop->get_jit_key();
            OpInspector xoi(riop1);
            OpInspector woi(riop2);
            // determine which is which (since both are ReindexOp)
            if (xoi.mm[0] == -1 && woi.mm[0] == 0) {
                std::swap(xoi, woi);
            }
            OpInspector yoi(rop);
            int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
            int zn, zg, zci, zco, zh, zw, zwh, zww;
            zn = zg = zci = zco = zh = zw = zwh = zww = 0;
            xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
            xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
            xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
            xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
            xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
            LOGvvvv << "zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
            xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
            if (xoi.failed) continue;
            xn = xoi.mm[zn];
            xc = xoi.mm[zci];
            xh = xoi.mm[zh];
            xw = xoi.mm[zw];
            LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
            auto xformat = xoi.format("abcd", {xn, xc, xh, xw});
            LOGvvvv << "xformat =" << xformat;
            wci = woi.mm[zci];
            wco = woi.mm[zco];
            wh = woi.mm[zwh];
            ww = woi.mm[zww];
            auto wformat = xoi.format("iohw", {wci, wco, wh, ww});
            LOGvvvv << "wformat =" << wformat;
            yn = yoi.mm[zn];
            yc = yoi.mm[zco];
            yh = yoi.mm[zh];
            yw = yoi.mm[zw];
            auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
            LOGvvvv << "yformat =" << yformat;
            // mkl doesn't support "cdab" format
            if (yformat == "cdab") continue;
            // cuda doesn't support "iohw" format
            if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
            if (xoi.failed) continue;
            std::stringstream ss;
            // i@zh*stride+i@zwh+padding
            ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
            auto expr_h = expr::make(ss.str());
            ss.str("");
            ss << "i" << zw << "*stride+i" << zww << "*dilation+padding";
            auto expr_w = expr::make(ss.str());

            vector<unique_ptr<Expr>> rh, rw;
            auto src_h = expr::make(riop1->indexes[xh]);
            if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
                LOGvvvv << "Expr not match" << src_h << expr_h;
                continue;
            }
            if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
            auto src_w = expr::make(riop1->indexes[xw]);
            if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
                return;
            if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
            int stride_h = rh[0]->as_int();
            int padding_h = -rh[1]->as_int();
            int dilation_h = rh[2]->as_int();
            int stride_w = rw[0]->as_int();
            int padding_w = -rw[1]->as_int();
            int dilation_w = rw[2]->as_int();
            if (dilation_h < 1 || dilation_w < 1) continue;
            if (stride_h!=stride_w || padding_h!=padding_w || dilation_h!=dilation_w) {
                LOGvvvv << "cannot relay different stride and padding between h and w"
                    << stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w;
                continue;
            }
            LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;

            int stride = stride_h;
            int padding = padding_h;
            int dilation = dilation_h;
            Var* x = riop1->x;
            Var* w = riop2->x;

            int oh = (x->shape[xh]-w->shape[wh]*dilation_h+dilation_h-1+padding_h*2)/stride_h+1;
            int ow = (x->shape[xw]-w->shape[ww]*dilation_w+dilation_w-1+padding_w*2)/stride_w+1;
            if (oh != rop->x->shape[yh] || ow != rop->x->shape[yw]) continue;

            int groups = x->shape[xc] / w->shape[wci];
            LOGvvvv << "groups: " << groups;
            if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
                LOGi << "group conv does not support mkl";
                continue;
            }

            string relay_conv_name = "cudnn_conv";
            if (!has_op(relay_conv_name))
                continue;
            auto make_conv = get_op_info(relay_conv_name)
                .get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
            auto rvar = make_conv(x, w, stride, padding, dilation, groups, xformat, wformat, yformat);
            auto rid = fop->context->vrm.add_relay_group({{rvar, rop->x}});
            if (rid>=0) {
                auto srid = "relay"+S(rid);
                add_candidate(srid, 1);
                add_candidate(srid, 0);
                confidence = 20;
            }
        }
    }
}

void GroupConvTuner::backwardTune(FusedOp* fop) {
    for (Op* op : fop->ops) {
        int bo=0;
        Var *x=NULL, *y=NULL, *w=NULL;
        Var *dw=NULL, *dx=NULL;
        int height=0,width=0,kernel_size=0,stride=0, padding=0, dilation=1, groups=1;
        string xformat, yformat, wformat;
        if (op->name_ex() == "reindex_reduce.add") {
            auto rop = (ReindexReduceOp*)op;
            if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->y->input()->tflag==op->tflag))
                continue;
            auto bop = (BinaryOp*)(rop->y->input());
            if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
            if (!(bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="reindex")) return;
            auto riop1 = (ReindexOp*)(bop->x->input());
            auto riop2 = (ReindexOp*)(bop->y->input());
            LOGvvvv << "conv like op" << fop << fop->get_jit_key();

            OpInspector oi1(riop1);
            OpInspector oi2(riop2);


            if (oi1.mm[0] == 0 && oi2.mm[0] == 0) {
                // dw
                // x.mm [0,1,-1,1,2,3,2,3] y.mm [0,1,1,-1,2,3,-1,-1] w.mm [-1,0,0,1,-1,-1,2,3]
                OpInspector xoi(oi1.mm[2] == -1 ? riop1 : riop2);
                OpInspector yoi(oi1.mm[2] == -1 ? riop2 : riop1);
                OpInspector woi(rop);
                int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
                int zn, zg, zci, zco, zh, zw, zwh, zww;
                zn = zg = zci = zco = zh = zw = zwh = zww = 0;
                xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
                xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
                xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
                xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
                xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
                LOGvvvv << "group conv backward dw zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
                xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
                if (xoi.failed) continue;
                xn = xoi.mm[zn];
                xc = xoi.mm[zci];
                xh = xoi.mm[zh];
                xw = xoi.mm[zw];
                xformat = xoi.format("abcd", {xn, xc, xh, xw});
                wci = woi.mm[zci];
                wco = woi.mm[zco];
                wh = woi.mm[zwh];
                ww = woi.mm[zww];
                wformat = xoi.format("iohw", {wci, wco, wh, ww});
                yn = yoi.mm[zn];
                yc = yoi.mm[zco];
                yh = yoi.mm[zh];
                yw = yoi.mm[zw];
                yformat = xoi.format("abcd", {yn, yc, yh, yw});

                // mkl doesn't support "cdab" format
                if (yformat == "cdab") continue;
                // cuda doesn't support "iohw" format
                if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
                if (xoi.failed) continue;

                std::stringstream ss;
                // i@zh*stride+i@zwh+padding
                ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
                auto expr_h = expr::make(ss.str());

                vector<unique_ptr<Expr>> rh;
                auto src_h = expr::make(riop1->indexes[xh]);
                if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
                    LOGvvvv << "Expr not match" << src_h << expr_h;
                    continue;
                }
                if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;

                dw = rop->x;
                stride = rh[0]->as_int();
                padding = -rh[1]->as_int();
                dilation = rh[2]->as_int();
                kernel_size = dw->shape[wformat.find("h")];
                groups = (oi1.mm[2] == -1 ? riop1 : riop2)->x->shape[xc] / dw->shape[wci];

                if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
                    LOGi << "group conv does not support mkl";
                    continue;
                }

                LOGvvvv << stride << padding << dilation << kernel_size << groups;

                x = (oi1.mm[2] == -1 ? riop1 : riop2)->x;
                y = (oi1.mm[2] == -1 ? riop2 : riop1)->x;
                bo++;
            } else {
                // dx
                OpInspector woi(oi1.mm[0] == -1 ? riop1 : riop2);
                OpInspector yoi(oi1.mm[0] == -1 ? riop2 : riop1);
                OpInspector xoi(rop);
                int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
                int zn, zg, zci, zco, zh, zw, zwh, zww;
                zn = zg = zci = zco = zh = zw = zwh = zww = 0;
                xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
                xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
                xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
                xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
                xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
                LOGvvvv << "group conv backward dx zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
                xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
                if (xoi.failed) continue;
                xn = xoi.mm[zn];
                xc = xoi.mm[zci];
                xh = xoi.mm[zh];
                xw = xoi.mm[zw];
                xformat = xoi.format("abcd", {xn, xc, xh, xw});
                wci = woi.mm[zci];
                wco = woi.mm[zco];
                wh = woi.mm[zwh];
                ww = woi.mm[zww];
                wformat = xoi.format("iohw", {wci, wco, wh, ww});
                yn = yoi.mm[zn];
                yc = yoi.mm[zco];
                yh = yoi.mm[zh];
                yw = yoi.mm[zw];
                yformat = xoi.format("abcd", {yn, yc, yh, yw});
                // mkl doesn't support "cdab" format
                if (yformat == "cdab") continue;
                // cuda doesn't support "iohw" format
                if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
                if (xoi.failed) continue;

                std::stringstream ss;
                // i@zh*stride+i@zwh+padding
                ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
                auto expr_h = expr::make(ss.str());

                vector<unique_ptr<Expr>> rh;
                auto src_h = expr::make(rop->indexes[xh]);
                if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
                    LOGvvvv << "Expr not match" << src_h << expr_h;
                    continue;
                }
                if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;

                dx = rop->x;
                stride = rh[0]->as_int();
                padding = -rh[1]->as_int();
                dilation = rh[2]->as_int();
                height = dx->shape[xformat.find("c")];
                width = dx->shape[xformat.find("d")];
                groups = dx->shape[xc] / (oi1.mm[0] == -1 ? riop1 : riop2)->x->shape[wci];

                if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
                    LOGi << "group conv does not support mkl";
                    continue;
                }

                LOGvvvv << stride << padding << dilation << height << width << groups;

                w = (oi1.mm[0] == -1 ? riop1 : riop2)->x;
                y = (oi1.mm[0] == -1 ? riop2 : riop1)->x;
                bo+=2;
            }

        }

        // TODO: CUDA only support nchw(abcd)
        if (fop->flags.get(NodeFlags::_cuda) && (xformat != "abcd" || yformat != "abcd"))
            continue;

        if (bo&1) {
            auto make_conv_w = get_op_info(
                "cudnn_conv_backward_w"
            ).get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
            auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, groups, xformat, wformat, yformat);
            auto rid = fop->context->vrm.add_relay_group({{rvar_w, dw}});
            if (rid>=0) {
                auto srid = "relay"+S(rid);
                add_candidate(srid, 1);
                add_candidate(srid, 0);
                confidence = 20;
            }
        }
        if (bo&2) {
            auto make_conv_x = get_op_info(
                "cudnn_conv_backward_x"
            ).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
            auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, groups, xformat, wformat, yformat);
            auto rid = fop->context->vrm.add_relay_group({{rvar_x, dx}});
            if (rid>=0) {
                auto srid = "relay"+S(rid);
                add_candidate(srid, 1);
                add_candidate(srid, 0);
                confidence = 20;
            }
        }
    }
}

void GroupConvTuner::run(PassManager* pm, TunerManager* tm) {
    FusedOp* fop=tm->oc->op;

    forwardTune(fop);
    backwardTune(fop);
}

}
@@ -20,4 +20,11 @@ struct ConvTuner : Tuner {
    void run(PassManager* pm, TunerManager* tm);
};

struct GroupConvTuner : Tuner {
    GroupConvTuner() : Tuner("group_conv") {}
    void forwardTune(FusedOp* fop);
    void backwardTune(FusedOp* fop);
    void run(PassManager* pm, TunerManager* tm);
};

}
@@ -43,6 +43,7 @@ string TunerManager::tune() {
    run_tuner<ReduceTuner>(&pm);
    run_tuner<MatmulTuner>(&pm);
    run_tuner<ConvTuner>(&pm);
    run_tuner<GroupConvTuner>(&pm);

    // use the best tuner if it is confidence enough
    if (best_tuner && best_tuner->confidence) {
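Aside (not part of the diff): GroupConvTuner is registered after ConvTuner, so grouped reindex/reindex_reduce patterns that ConvTuner does not handle can still be relayed to the cudnn ops. A minimal way to exercise it, adapted from check_forward in the test above; this sketch assumes a CUDA build, and the shapes are illustrative:

import numpy as np
import jittor as jt
from jittor import nn

# Run a grouped convolution once with the tuner enabled (so GroupConvTuner may
# relay it to cudnn_conv) and once with the tuner disabled, then compare results.
x = jt.random([10, 8, 100, 100])
conv = nn.Conv(8, 16, 3, padding=1, groups=4)

with jt.flag_scope(use_cuda=1, enable_tuner=1):
    y = conv(x)
    y.sync()
with jt.flag_scope(use_cuda=0, enable_tuner=0):
    cy = conv(x)
    cy.sync()

assert np.allclose(y.data, cy.data)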