mirror of https://github.com/Jittor/Jittor

Commit 233bd28355 (parent 1c19d5837d): polish group conv tuner
@@ -16,7 +16,7 @@ struct CudnnConvBackwardWOp : Op {
 int kernel_size, stride, padding, dilation, groups;
 string xformat, wformat, yformat;
 
-CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
+CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
 
 const char* name() const override { return "cudnn_conv_backward_w"; }
 void infer_shape() override;

@@ -16,7 +16,7 @@ struct CudnnConvBackwardXOp : Op {
 int xh, xw, stride, padding, dilation, groups;
 string xformat, wformat, yformat;
 
-CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
+CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
 
 const char* name() const override { return "cudnn_conv_backward_x"; }
 void infer_shape() override;

@@ -13,7 +13,7 @@ struct CudnnConvOp : Op {
 int stride, padding, dilation, groups;
 string xformat, wformat, yformat;
 /* CudnnConvOp: xformat abcd represents nchw */
-CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
+CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
 
 const char* name() const override { return "cudnn_conv"; }
 void infer_shape() override;

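The only change to these three cuDNN op headers is defaulting `groups` to 1, so call sites that predate group support keep compiling while group-aware callers pass the count explicitly. A toy sketch of that signature concern (stand-in types, not the Jittor classes):

// --- illustrative sketch (not part of the commit) ---
#include <iostream>
#include <string>

struct ToyConv {
    int stride, padding, dilation, groups;
    std::string wformat;
    // groups gets a default so it can sit before the already-defaulted format args
    ToyConv(int stride, int padding, int dilation, int groups=1,
            std::string wformat="oihw")
        : stride(stride), padding(padding), dilation(dilation),
          groups(groups), wformat(wformat) {}
};

int main() {
    ToyConv plain(1, 0, 1);       // old-style call, groups defaults to 1
    ToyConv grouped(1, 0, 1, 4);  // group conv, groups passed explicitly
    std::cout << plain.groups << " " << grouped.groups << "\n";  // 1 4
}
// --- end sketch ---
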
@@ -46,7 +46,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
 }
 
 MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
-: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
+: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
 xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
 dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
 }

@@ -57,7 +57,7 @@ void MklConvBackwardWOp::infer_shape() {
 int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
 get_shape(x, "abcd", xformat, xn, xc, xh, xw);
 get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
-wco = yc, wci = xc;
+wco = yc, wci = xc / groups;
 wh = kernel_size;
 ww = kernel_size;
 set_shape(dw, "oihw", wformat, wco, wci, wh, ww);

@@ -113,12 +113,16 @@ void MklConvBackwardWOp::jit_run() {
 std::vector<std::unordered_map<int, memory>> net_bwd_args;
 
 memory::dims conv_src_tz = {batch, ch_in, height, width};
-memory::dims conv_weights_tz = {ch_out, ch_in, kernel_size, kernel_size};
+memory::dims conv_weights_tz = groups>1
+    ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size}
+    : memory::dims{ch_out, ch_in, kernel_size, kernel_size};
 memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
 memory::dims conv_strides = {stride, stride};
 memory::dims conv_padding = {padding, padding};
 memory::dims conv_dilation = {dilation-1, dilation-1};
 
+if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw);
+
 auto conv_user_src_memory
 = memory({{conv_src_tz}, dt::@Tx, tag::@XFORMAT}, eng, net_src);
 

@@ -144,7 +148,7 @@ void MklConvBackwardWOp::jit_run() {
 = memory({{conv_dst_tz}, dt::@Ty, tag::YFORMAT}, eng, net_diff_dst);
 
 auto conv_user_diff_weights_memory
-= memory({{conv_weights_tz}, dt::@Tw, tag::WFORMAT}, eng, conv_user_diff_weights_buffer);
+= memory({{conv_weights_tz}, dt::@Tw, groups>1 ? tag::goihw : tag::@WFORMAT}, eng, conv_user_diff_weights_buffer);
 
 auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any);
 auto conv_diff_weights_md

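oneDNN (MKL-DNN) expects grouped convolution weights in a 5-D layout: the hunks above switch conv_weights_tz from {oc, ic, kh, kw} to {groups, oc/groups, ic/groups, kh, kw} and pick the goihw memory tag whenever groups > 1. A standalone sketch of that dimension bookkeeping (plain C++ with a vector standing in for oneDNN's memory::dims, not the oneDNN API):

// --- illustrative sketch (not part of the commit) ---
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using dims = std::vector<int64_t>;  // stand-in for oneDNN memory::dims

// Build weight dims the way the patched jit_run() does:
// plain conv   -> {oc, ic, kh, kw}           (tag oihw)
// grouped conv -> {g, oc/g, ic/g, kh, kw}    (tag goihw)
dims weight_dims(int64_t oc, int64_t ic, int64_t k, int64_t groups) {
    assert(groups >= 1 && oc % groups == 0 && ic % groups == 0);
    if (groups > 1)
        return {groups, oc / groups, ic / groups, k, k};
    return {oc, ic, k, k};
}

int main() {
    for (auto d : weight_dims(16, 8, 3, 4))  // oc=16, ic=8, 3x3 kernel, 4 groups
        std::cout << d << ' ';               // prints: 4 4 2 3 3
    std::cout << '\n';
}
// --- end sketch ---
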
@@ -13,10 +13,10 @@ namespace jittor {
 
 struct MklConvBackwardWOp : Op {
 Var* x, * dy, * dw;
-int kernel_size, stride, padding, dilation;
+int kernel_size, stride, padding, dilation, groups;
 string xformat, wformat, yformat;
 
-MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
+MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
 
 const char* name() const override { return "mkl_conv_backward_w"; }
 void infer_shape() override;

@@ -46,7 +46,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
 }
 
 MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
-: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
+: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups),
 xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
 dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
 }

@@ -57,7 +57,7 @@ void MklConvBackwardXOp::infer_shape() {
 int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw;
 get_shape(w, "oihw", wformat, wco, wci, wh, ww);
 get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
-xn = yn, xc = wci;
+xn = yn, xc = wci * groups;
 set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
 }
 

@@ -111,14 +111,18 @@ void MklConvBackwardXOp::jit_run() {
 std::vector<std::unordered_map<int, memory>> net_bwd_args;
 
 memory::dims conv_src_tz = {batch, ch_in, height, width};
-memory::dims conv_weights_tz = {ch_out, ch_in, kernel_size, kernel_size};
+memory::dims conv_weights_tz = groups>1
+    ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size}
+    : memory::dims{ch_out, ch_in, kernel_size, kernel_size};
 memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
 memory::dims conv_strides = {stride, stride};
 memory::dims conv_padding = {padding, padding};
 memory::dims conv_dilation = {dilation-1, dilation-1};
 
+if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw);
+
 auto conv_user_weights_memory
-= memory({{conv_weights_tz}, dt::@Tw, tag::@WFORMAT}, eng, conv_weights);
+= memory({{conv_weights_tz}, dt::@Tw, groups>1 ? tag::goihw : tag::@WFORMAT}, eng, conv_weights);
 
 auto conv_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any);
 auto conv_weights_md = memory::desc({conv_weights_tz}, dt::@Tw, tag::any);

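The infer_shape changes are two sides of the same relation: backward-w computes the per-group weight input channels as wci = xc / groups, and backward-x recovers the full input channel count as xc = wci * groups from the oihw weight shape. A small self-contained sketch of that bookkeeping (hypothetical helpers, plain C++):

// --- illustrative sketch (not part of the commit) ---
#include <array>
#include <cassert>
#include <iostream>

// Shapes as {n, c, h, w} for tensors and {oc, ic_per_group, kh, kw} for weights.
using Shape = std::array<int, 4>;

// backward-w: weight-gradient channels from x and dy shapes (wci = xc / groups)
Shape infer_dw(const Shape& x, const Shape& dy, int kernel_size, int groups) {
    assert(x[1] % groups == 0);
    return {dy[1], x[1] / groups, kernel_size, kernel_size};
}

// backward-x: input-gradient channels from the weight shape (xc = wci * groups)
Shape infer_dx(const Shape& w, const Shape& dy, int height, int width, int groups) {
    return {dy[0], w[1] * groups, height, width};
}

int main() {
    Shape x{10, 8, 40, 50}, dy{10, 16, 40, 50}, w{16, 2, 5, 5};
    Shape dw = infer_dw(x, dy, 5, 4), dx = infer_dx(w, dy, 40, 50, 4);
    std::cout << dw[0] << ',' << dw[1] << "  " << dx[0] << ',' << dx[1] << "\n";  // 16,2  10,8
}
// --- end sketch ---
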
@@ -13,10 +13,10 @@ namespace jittor {
 
 struct MklConvBackwardXOp : Op {
 Var* w, * dy, * dx;
-int xh, xw, stride, padding, dilation;
+int xh, xw, stride, padding, dilation, groups;
 string xformat, wformat, yformat;
 
-MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
+MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
 
 const char* name() const override { return "mkl_conv_backward_x"; }
 void infer_shape() override;

@@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
 }
 
 MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
-: x(x), w(w), stride(stride), padding(padding), dilation(dilation),
+: x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups),
 xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
 y = create_output(nullptr, dtype_infer(x->ns, w->ns));
 if (!this->yformat.size())

@@ -58,7 +58,7 @@ void MklConvOp::infer_shape() {
 int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
 get_shape(x, "abcd", xformat, xn, xc, xh, xw);
 get_shape(w, "oihw", wformat, wco, wci, wh, ww);
-ASSERTop(wci,==,xc);
+ASSERTop(wci * groups,==,xc);
 yn = xn, yc = wco;
 yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
 yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;

@@ -124,18 +124,22 @@ void MklConvOp::jit_run() {
 get_shape(y, "abcd", yformat, yn, yc, yh, yw);
 
 memory::dims conv1_src_tz = {xn, xc, xh, xw};
-memory::dims conv1_weights_tz = {wco, wci, wh, ww};
+memory::dims conv1_weights_tz = groups>1
+    ? memory::dims{groups, wco/groups, wci, wh, ww}
+    : memory::dims{wco, wci, wh, ww};
 memory::dims conv1_dst_tz = {yn, yc, yh, yw};
 memory::dims conv1_strides = { stride, stride };
 memory::dims conv1_padding = { padding, padding };
 memory::dims conv1_dilation = { dilation-1, dilation-1 };
 
+if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw);
+
 auto user_src_memory = memory(
 { { conv1_src_tz }, dt::@Tx, tag::@XFORMAT }, eng, x->mem_ptr);
 auto user_dst_memory = memory(
 { { conv1_dst_tz }, dt::@Ty, tag::@YFORMAT }, eng, y->mem_ptr);
 auto user_weights_memory = memory(
-{ { conv1_weights_tz }, dt::@Tw, tag::@WFORMAT }, eng, w->mem_ptr);
+{ { conv1_weights_tz }, dt::@Tw, groups>1 ? tag::goihw : tag::@WFORMAT }, eng, w->mem_ptr);
 
 auto conv1_src_md = memory::desc({ conv1_src_tz }, dt::@Tx, tag::any);
 auto conv1_weights_md

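MklConvOp::infer_shape keeps the usual dilated-convolution output size, written as (xh + padding*2 - wh*dilation + dilation - 1)/stride + 1, which equals the more familiar (xh + 2*padding - ((wh-1)*dilation + 1))/stride + 1, and the channel assertion becomes wci * groups == xc. A quick check of both identities (plain C++, not the op code):

// --- illustrative sketch (not part of the commit) ---
#include <cassert>
#include <iostream>

// Output extent of a dilated convolution, in the form used by infer_shape().
int conv_out(int x, int k, int stride, int padding, int dilation) {
    int a = (x + padding*2 - k*dilation + dilation - 1) / stride + 1;
    int b = (x + padding*2 - ((k - 1)*dilation + 1)) / stride + 1;  // equivalent form
    assert(a == b);
    return a;
}

int main() {
    // One of the test shapes: 40x50 input, 5x5 kernel, stride 1, padding 1, dilation 2
    std::cout << conv_out(40, 5, 1, 1, 2) << "x" << conv_out(50, 5, 1, 1, 2) << "\n";  // 34x44
    // Grouped channel check: per-group input channels times groups must equal xc.
    int xc = 8, wci = 2, groups = 4;
    assert(wci * groups == xc);
}
// --- end sketch ---
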
@@ -13,10 +13,10 @@ namespace jittor {
 
 struct MklConvOp : Op {
 Var* x, * w, * y;
-int stride, padding, dilation;
+int stride, padding, dilation, groups;
 string xformat, wformat, yformat;
 /* MklConvOp: xformat abcd represents nchw */
-MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
+MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
 
 const char* name() const override { return "mkl_conv"; }
 void infer_shape() override;

@@ -93,7 +93,7 @@ def check_forward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
 else:
 op_name = "mkl_conv"
 with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
-log_v=0, log_vprefix="op.cc=100", compile_options={"test":266}
+log_v=0, log_vprefix="op.cc=100,conv_tuner=1000", compile_options={"test":266}
 ) as raw_log:
 x = jt.random(xshape)
 w = jt.random(wshape)

@@ -118,7 +118,7 @@ def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
 op_name = "mkl_conv"
 
 with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
-log_v=1, log_vprefix="op.cc=100,exe=1000", compile_options={"test":244}
+log_v=1, log_vprefix="op.cc=100,exe=1000,conv_t=1000", compile_options={"test":244}
 ) as raw_log:
 x = jt.random(xshape)
 w = jt.random(wshape)

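The widened log_vprefix strings are what let the tests capture the tuner's LOGvvvv output: judging by the values used here, the setting is a comma-separated list of source-prefix=verbosity entries such as "op.cc=100,conv_tuner=1000". A small sketch of parsing such a spec (generic string handling; an assumption about the format, not Jittor's logging implementation):

// --- illustrative sketch (not part of the commit) ---
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct VPrefix { std::string prefix; int verbosity; };

// Parse "op.cc=100,exe=1000,conv_t=1000" into (prefix, verbosity) pairs.
std::vector<VPrefix> parse_vprefix(const std::string& spec) {
    std::vector<VPrefix> out;
    std::stringstream ss(spec);
    std::string item;
    while (std::getline(ss, item, ',')) {
        auto eq = item.find('=');
        if (eq == std::string::npos) continue;
        out.push_back({item.substr(0, eq), std::stoi(item.substr(eq + 1))});
    }
    return out;
}

int main() {
    for (auto& p : parse_vprefix("op.cc=100,exe=1000,conv_t=1000"))
        std::cout << p.prefix << " -> " << p.verbosity << "\n";
}
// --- end sketch ---
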
@@ -77,7 +77,7 @@ def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, n
 
 # only check cudnn
 with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
-log_v=10, log_vprefix="conv_tuner.cc=1000"
+log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
 ) as raw_log:
 x = jt.random(xshape)
 w = jt.random(wshape)

@@ -87,6 +87,8 @@ def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, n
 cy = test_func(x, w, stride, padding, dilation, groups)
 cy.sync()
 
+logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
+assert len(logs)==1
 assert np.allclose(y.data, cy.data)
 
 

@@ -96,11 +98,12 @@ def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda,
 
 # only check cudnn
 with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
-log_v=10, log_vprefix="conv_tuner.cc=1000"
+log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
 ) as raw_log:
 x = jt.random(xshape)
 w = jt.random(wshape)
 y = test_func(x, w, stride, padding, dilation, groups)
+y.sync()
 dx, dw = jt.grad(y, [x, w])
 jt.sync([y, dx, dw])
 with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}):

@@ -108,6 +111,8 @@ def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda,
 cdx, cdw = jt.grad(cy, [x, w])
 jt.sync([cy, cdx, cdw])
 
+logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
+assert len(logs)==3
 assert np.allclose(y.data, cy.data)
 assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
 assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(dx.data - cdx.data).max())

@@ -121,12 +126,24 @@ class TestGroupConvTuner(unittest.TestCase):
 check_forward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False)
 check_forward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False)
 
+def test_forward(self):
+for groups in [2, 4, 8]:
+check_forward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 0, False)
+check_forward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 0, False)
+check_forward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 0, False)
+
 @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
 def test_backward_cuda(self):
 for groups in [2, 4, 8]:
 check_backward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False)
 check_backward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False)
 check_backward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False)
 
+def test_backward(self):
+for groups in [2, 4, 8]:
+check_backward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 0, False)
+check_backward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 0, False)
+check_backward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 0, False)
+
 
 if __name__ == "__main__":

@@ -19,6 +19,7 @@
 #include "ops/op_register.h"
 
 #include <algorithm>
+#include <cstring>
 
 namespace jittor {
 

@@ -32,8 +33,10 @@ struct OpInspector {
 uint64 m1=0, m2=0, m3=0;
 // which dimension map
 vector<int> mm;
+Op* op;
 bool failed=0;
-OpInspector(ReindexOp* op) {
+
+void init(ReindexOp* op) {
 unordered_map<string, int> p;
 mm.resize(op->y->shape.size(), -1);
 for (uint i=0; i<op->y->shape.size(); i++)

@@ -59,7 +62,9 @@ struct OpInspector {
 m2 = ((1ll<<mm.size())-1) ^ (m1|m3);
 }
 
-OpInspector(ReindexReduceOp* op) {
+OpInspector(ReindexOp* op) : op(op) { init(op); }
+
+void init(ReindexReduceOp* op) {
 unordered_map<string, int> p;
 mm.resize(op->y->shape.size(), -1);
 for (uint i=0; i<op->y->shape.size(); i++)

@@ -85,7 +90,9 @@ struct OpInspector {
 m2 = ((1ll<<mm.size())-1) ^ (m1|m3);
 }
 
-OpInspector(BroadcastToOp* op) {
+OpInspector(ReindexReduceOp* op) : op(op) { init(op); }
+
+void init(BroadcastToOp* op) {
 mm.resize(op->z->shape.size(), 0);
 m2 = op->bcast_mask;
 m1 = ((1ll<<mm.size())-1) ^ (m2);

@@ -93,13 +100,30 @@ struct OpInspector {
 if ((m1>>i)&1) mm[i] = j++;
 }
 
-OpInspector(ReduceOp* op) {
+OpInspector(BroadcastToOp* op) : op(op) { init(op); }
+
+void init(ReduceOp* op) {
 mm.resize(op->x->shape.size(), 0);
 m2 = op->reduce_mask;
 m1 = ((1ll<<mm.size())-1) ^ (m2);
 for (uint i=0,j=0; i<op->x->shape.size(); i++)
 if ((m1>>i)&1) mm[i] = j++;
 }
 
+OpInspector(ReduceOp* op) : op(op) { init(op); }
+
+OpInspector(Op* op) : op(op) {
+if (strcmp(op->name(), "reduce") == 0)
+init((ReduceOp*)op);
+else if (strcmp(op->name(), "broadcast_to") == 0)
+init((BroadcastToOp*)op);
+else if (strcmp(op->name(), "reindex") == 0)
+init((ReindexOp*)op);
+else if (strcmp(op->name(), "reindex_reduce") == 0)
+init((ReindexReduceOp*)op);
+else
+failed = 1;
+}
+
 // get last one index of binary mask
 void get_id(uint64 m, int& i) {

@@ -160,7 +184,7 @@ struct OpInspector {
 return 0;
 }
 
-string format2(const string& fmt, const vector<int>& order) {
+string format(const string& fmt, const vector<int>& order) {
 string new_fmt = fmt;
 if (order.size() != fmt.size()) {
 failed = 1;

@@ -177,29 +201,6 @@ struct OpInspector {
 }
 return new_fmt;
 }
-
-string format(const string& fmt, const vector<int>& order) {
-string new_fmt = fmt;
-if (order.size() != fmt.size()) {
-failed = 1;
-return "";
-}
-if (check_overlap(order))
-return "";
-vector<pair<int, int>> order_;
-for (uint i = 0; i < order.size(); i++) {
-order_.push_back(pair<int, int>(order[i], i));
-}
-sort(order_.begin(), order_.end());
-for (uint i=0; i<order_.size(); i++) {
-if (order_[i].second>=(int)new_fmt.size()) {
-failed = 1;
-return "";
-}
-new_fmt[order_[i].second] = fmt[i];
-}
-return new_fmt;
-}
 };
 
 std::ostream& operator<<(std::ostream& os, const OpInspector& oi) {

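OpInspector describes each dimension of a reindex/broadcast/reduce op with the bitmasks m1/m2/m3 and the dimension map mm; the refactor above turns the per-type constructors into init() overloads and adds an OpInspector(Op*) constructor that dispatches on op->name(), so the tuner can inspect any of the four op kinds through one entry point. A reduced sketch of the reduce-op case, mirroring m2 = reduce_mask; m1 = ((1ll<<n)-1) ^ m2 (standalone C++, not the Jittor types):

// --- illustrative sketch (not part of the commit) ---
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // A reduce over a 4-d input where dimensions 2 and 3 are reduced away;
    // the reduce mask has one bit per input dimension.
    int n = 4;
    uint64_t m2 = (1ull << 2) | (1ull << 3);  // reduced dims
    uint64_t m1 = ((1ull << n) - 1) ^ m2;     // kept dims

    // mm[i]: position of input dim i in the output for kept dims, 0 otherwise,
    // filled the same way init(ReduceOp*) does.
    std::vector<int> mm(n, 0);
    for (int i = 0, j = 0; i < n; i++)
        if ((m1 >> i) & 1) mm[i] = j++;

    for (int i = 0; i < n; i++)
        std::cout << "dim " << i << ": "
                  << (((m1 >> i) & 1) ? "kept, out dim " : "reduced, mm=")
                  << mm[i] << "\n";
}
// --- end sketch ---
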
@ -214,362 +215,64 @@ std::ostream& operator<<(std::ostream& os, const OpInspector& oi) {
|
|||
|
||||
void ConvTuner::forwardTune(FusedOp* fop) {
|
||||
for (Op* op : fop->ops)
|
||||
if (op->name_ex()=="reduce.add") {
|
||||
auto rop = (ReduceOp*)op;
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
if (op->name_ex()=="reduce.add" || op->name_ex()=="reindex_reduce.add") {
|
||||
// reduce op and reindex reduce op have the same memory layout
|
||||
// it is ok to force cast.
|
||||
auto op_iop = op->input(0)->input();
|
||||
if (!(op_iop
|
||||
&& op_iop->name_ex()=="binary.multiply"
|
||||
&& op_iop->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->x->input());
|
||||
auto bop = (BinaryOp*)op_iop;
|
||||
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!((bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="broadcast_to") ||
|
||||
(bop->y->input()->name_ex()=="reindex" && bop->x->input()->name_ex()=="broadcast_to"))) return;
|
||||
// riop1 reindex -> xx
|
||||
// riop2 broadcast -> ww
|
||||
auto riop1 = bop->x->input()->name_ex()=="reindex" ?
|
||||
(ReindexOp*)(bop->x->input()) : (ReindexOp*)(bop->y->input());
|
||||
auto riop2 = bop->x->input()->name_ex()=="reindex" ?
|
||||
(BroadcastToOp*)(bop->y->input()) : (BroadcastToOp*)(bop->x->input());
|
||||
if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;
|
||||
|
||||
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
||||
int ok = 0;
|
||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
|
||||
OpInspector xoi(riop1);
|
||||
LOGvvvv << "inspect x:" << xoi << riop1->indexes;
|
||||
OpInspector woi(riop2);
|
||||
LOGvvvv << "inspect w:" << woi;
|
||||
OpInspector yoi(rop);
|
||||
LOGvvvv << "inspect y:" << yoi;
|
||||
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
|
||||
int zn, zci, zco, zh, zw, zwh, zww;
|
||||
zn = zci = zco = zh = zw = zwh = zww = 0;
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m1 & woi.m1 & yoi.m2, zci);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m1 & yoi.m1, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zwh, zww);
|
||||
LOGvvvv << "zn,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zci,zco,zh,zw,zwh,zww});
|
||||
if (xoi.failed) continue;
|
||||
xn = xoi.mm[zn];
|
||||
xc = xoi.mm[zci];
|
||||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
auto xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
auto wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
auto yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
// cuda doesn't support "iohw" format
|
||||
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
|
||||
if (xoi.failed) continue;
|
||||
std::stringstream ss;
|
||||
// i@zh*stride+i@zwh+padding
|
||||
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
|
||||
auto expr_h = expr::make(ss.str());
|
||||
ss.str("");
|
||||
ss << "i" << zw << "*stride+i" << zww << "*dilation+padding";
|
||||
auto expr_w = expr::make(ss.str());
|
||||
|
||||
vector<unique_ptr<Expr>> rh, rw;
|
||||
auto src_h = expr::make(riop1->indexes[xh]);
|
||||
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
|
||||
LOGvvvv << "Expr not match" << src_h << expr_h;
|
||||
continue;
|
||||
}
|
||||
LOGvvvv << "H Expr matched" << src_h << expr_h;
|
||||
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
|
||||
auto src_w = expr::make(riop1->indexes[xw]);
|
||||
if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
|
||||
return;
|
||||
LOGvvvv << "W Expr matched" << src_w << expr_w;
|
||||
if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
|
||||
int stride_h = rh[0]->as_int();
|
||||
int padding_h = -rh[1]->as_int();
|
||||
int dilation_h = rh[2]->as_int();
|
||||
int stride_w = rw[0]->as_int();
|
||||
int padding_w = -rw[1]->as_int();
|
||||
int dilation_w = rw[2]->as_int();
|
||||
if (dilation_h < 1 || dilation_w < 1) continue;
|
||||
if (stride_h!=stride_w || padding_h!=padding_w || dilation_h!=dilation_w) {
|
||||
LOGvvvv << "cannot relay different stride and padding between h and w"
|
||||
<< stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w;
|
||||
continue;
|
||||
}
|
||||
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
|
||||
if (xformat == "bacd" && dilation_h != 1) {
|
||||
LOGvvvv << "mkl not support bacd dilation, continue";
|
||||
continue;
|
||||
}
|
||||
int stride = stride_h;
|
||||
int padding = padding_h;
|
||||
int dilation = dilation_h;
|
||||
Var* x = riop1->x;
|
||||
Var* w = riop2->x;
|
||||
|
||||
int oh = (x->shape[xh]-w->shape[wh]*dilation_h+dilation_h-1+padding_h*2)/stride_h+1;
|
||||
int ow = (x->shape[xw]-w->shape[ww]*dilation_w+dilation_w-1+padding_w*2)/stride_w+1;
|
||||
if (oh != rop->y->shape[yh] || ow != rop->y->shape[yw]) continue;
|
||||
|
||||
string relay_conv_name = fop->flags.get(NodeFlags::_cpu) ?
|
||||
"mkl_conv" : "cudnn_conv";
|
||||
if (!has_op(relay_conv_name))
|
||||
continue;
|
||||
auto make_conv = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
|
||||
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
add_candidate(srid, 1);
|
||||
add_candidate(srid, 0);
|
||||
confidence = 20;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvTuner::backwardTune(FusedOp* fop) {
|
||||
for (Op* op : fop->ops) {
|
||||
int bo=0;
|
||||
Var *x=NULL, *y=NULL, *w=NULL;
|
||||
Var *dw=NULL, *dx=NULL;
|
||||
int height=0,width=0,kernel_size=0,stride=0, padding=0, dilation=1;
|
||||
string xformat, yformat, wformat;
|
||||
if (op->name_ex() == "reduce.add") {
|
||||
auto rop = (ReduceOp*)op;
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->x->input());
|
||||
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!((bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="broadcast_to") ||
|
||||
(bop->y->input()->name_ex()=="reindex" && bop->x->input()->name_ex()=="broadcast_to"))) continue;
|
||||
auto riop1 = bop->x->input()->name_ex()=="reindex" ? (ReindexOp*)(bop->x->input()) : (ReindexOp*)(bop->y->input());
|
||||
auto riop2 = bop->x->input()->name_ex()=="reindex" ? (BroadcastToOp*)(bop->y->input()) : (BroadcastToOp*)(bop->x->input());
|
||||
|
||||
OpInspector xoi(riop1);
|
||||
LOGvvvv << "inspect x:" << xoi << riop1->indexes;
|
||||
OpInspector yoi(riop2);
|
||||
LOGvvvv << "inspect y:" << yoi;
|
||||
OpInspector woi(rop);
|
||||
LOGvvvv << "inspect w:" << woi;
|
||||
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
|
||||
int zn, zci, zco, zh, zw, zwh, zww;
|
||||
zn = zci = zco = zh = zw = zwh = zww = 0;
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m1 & woi.m1 & yoi.m2, zci);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m1 & yoi.m1, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zwh, zww);
|
||||
LOGvvvv << "zn,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zci,zco,zh,zw,zwh,zww});
|
||||
if (xoi.failed) continue;
|
||||
xn = xoi.mm[zn];
|
||||
xc = xoi.mm[zci];
|
||||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
// cuda doesn't support "iohw" format
|
||||
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
|
||||
if (xoi.failed) continue;
|
||||
|
||||
std::stringstream ss;
|
||||
// i@zh*stride+i@zwh+padding
|
||||
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
|
||||
auto expr_h = expr::make(ss.str());
|
||||
|
||||
vector<unique_ptr<Expr>> rh;
|
||||
auto src_h = expr::make(riop1->indexes[xh]);
|
||||
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
|
||||
LOGvvvv << "Expr not match" << src_h << expr_h;
|
||||
continue;
|
||||
}
|
||||
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;
|
||||
|
||||
dw = rop->y;
|
||||
stride = rh[0]->as_int();
|
||||
padding = -rh[1]->as_int();
|
||||
dilation = rh[2]->as_int();
|
||||
kernel_size = dw->shape[wformat.find("h")];
|
||||
x = riop1->x;
|
||||
y = riop2->x;
|
||||
bo++;
|
||||
LOGvvvv << "backward_w get stride padding and dilation" << stride << padding << dilation;
|
||||
} else if (op->name_ex() == "reindex_reduce.add") {
|
||||
auto rop = (ReindexReduceOp*)op;
|
||||
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->y->input());
|
||||
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!((bop->x->input()->name_ex()=="broadcast_to" && bop->y->input()->name_ex()=="broadcast_to"))) return;
|
||||
auto riop1 = (BroadcastToOp*)(bop->x->input());
|
||||
auto riop2 = (BroadcastToOp*)(bop->y->input());
|
||||
|
||||
OpInspector woi(riop1);
|
||||
LOGvvvv << "inspect w:" << woi;
|
||||
OpInspector yoi(riop2);
|
||||
LOGvvvv << "inspect y:" << yoi;
|
||||
OpInspector xoi(rop);
|
||||
LOGvvvv << "inspect x:" << xoi;
|
||||
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
|
||||
int zn, zci, zco, zh, zw, zwh, zww;
|
||||
zn = zci = zco = zh = zw = zwh = zww = 0;
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m1 & woi.m1 & yoi.m2, zci);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m1 & yoi.m1, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zwh, zww);
|
||||
LOGvvvv << "zn,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zci,zco,zh,zw,zwh,zww});
|
||||
if (xoi.failed) continue;
|
||||
xn = xoi.mm[zn];
|
||||
xc = xoi.mm[zci];
|
||||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
// cuda doesn't support "iohw" format
|
||||
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
|
||||
if (xoi.failed) continue;
|
||||
|
||||
std::stringstream ss;
|
||||
// i@zh*stride+i@zwh+padding
|
||||
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
|
||||
auto expr_h = expr::make(ss.str());
|
||||
|
||||
vector<unique_ptr<Expr>> rh;
|
||||
auto src_h = expr::make(rop->indexes[xh]);
|
||||
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
|
||||
LOGvvvv << "Expr not match" << src_h << expr_h;
|
||||
continue;
|
||||
}
|
||||
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;
|
||||
|
||||
dx = rop->x;
|
||||
stride = rh[0]->as_int();
|
||||
padding = -rh[1]->as_int();
|
||||
dilation = rh[2]->as_int();
|
||||
height = dx->shape[xformat.find("c")];
|
||||
width = dx->shape[xformat.find("d")];
|
||||
w = riop1->x;
|
||||
y = riop2->x;
|
||||
bo+=2;
|
||||
LOGvvvv << "backward_x get stride padding and dilation" << stride << padding << dilation;
|
||||
}
|
||||
|
||||
// TODO: CUDA only support nchw(abcd)
|
||||
if (fop->flags.get(NodeFlags::_cuda) && (xformat != "abcd" || yformat != "abcd"))
|
||||
continue;
|
||||
|
||||
if (bo&1) {
|
||||
auto make_conv_w = get_op_info(
|
||||
fop->flags.get(NodeFlags::_cpu) ?
|
||||
"mkl_conv_backward_w" : "cudnn_conv_backward_w"
|
||||
).get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
|
||||
auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar_w, dw}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
add_candidate(srid, 1);
|
||||
add_candidate(srid, 0);
|
||||
confidence = 20;
|
||||
}
|
||||
}
|
||||
if (bo&2) {
|
||||
auto make_conv_x = get_op_info(
|
||||
fop->flags.get(NodeFlags::_cpu) ?
|
||||
"mkl_conv_backward_x" : "cudnn_conv_backward_x"
|
||||
).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
|
||||
auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar_x, dx}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
add_candidate(srid, 1);
|
||||
add_candidate(srid, 0);
|
||||
confidence = 20;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvTuner::run(PassManager* pm, TunerManager* tm) {
|
||||
FusedOp* fop=tm->oc->op;
|
||||
|
||||
forwardTune(fop);
|
||||
backwardTune(fop);
|
||||
}
|
||||
|
||||
void GroupConvTuner::forwardTune(FusedOp* fop) {
|
||||
LOGvvvv << "tune group conv";
|
||||
for (Op* op : fop->ops) {
|
||||
if (op->name_ex()=="reindex_reduce.add") {
|
||||
auto rop = (ReindexReduceOp*)op;
|
||||
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->y->input()->tflag==op->tflag))
|
||||
for (int y_id=0; y_id<3; y_id++)
|
||||
for (int x_id=0; x_id<3; x_id++)
|
||||
for (int w_id=0; w_id<3; w_id++) {
|
||||
if (ok) break;
|
||||
if (x_id == y_id || x_id == w_id || y_id == w_id) continue;
|
||||
LOGvvvv << "try" << x_id << y_id << w_id;
|
||||
OpInspector xoi(ops[x_id]);
|
||||
OpInspector yoi(ops[y_id]);
|
||||
OpInspector woi(ops[w_id]);
|
||||
vector<string>* xop_indexes;
|
||||
if (strcmp(xoi.op->name(), "reindex") == 0) {
|
||||
xop_indexes = &((ReindexOp*)xoi.op)->indexes;
|
||||
} else
|
||||
if (strcmp(xoi.op->name(), "reindex_reduce") == 0) {
|
||||
xop_indexes = &((ReindexReduceOp*)xoi.op)->indexes;
|
||||
} else
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->y->input());
|
||||
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!(bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="reindex")) return;
|
||||
auto riop1 = (ReindexOp*)(bop->x->input());
|
||||
auto riop2 = (ReindexOp*)(bop->y->input());
|
||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
|
||||
OpInspector xoi(riop1);
|
||||
OpInspector woi(riop2);
|
||||
// determine which is which (since both are ReindexOp)
|
||||
if (xoi.mm[0] == -1 && woi.mm[0] == 0) {
|
||||
std::swap(xoi, woi);
|
||||
}
|
||||
OpInspector yoi(rop);
|
||||
if (xoi.failed || yoi.failed || woi.failed) continue;
|
||||
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
|
||||
int zn, zg, zci, zco, zh, zw, zwh, zww;
|
||||
zn = zg = zci = zco = zh = zw = zwh = zww = 0;
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
|
||||
LOGvvvv << "zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
|
||||
zn = zci = zco = zh = zw = zwh = zww = 0;
|
||||
if (bop->x->shape.size() == 7) {
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m1 & woi.m1 & yoi.m2, zci);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m1 & yoi.m1, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zwh, zww);
|
||||
LOGvvvv << "zn,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zci,zco,zh,zw,zwh,zww});
|
||||
zg = -1;
|
||||
} else {
|
||||
if (bop->x->shape.size() != 8)
|
||||
continue;
|
||||
// group conv
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
|
||||
LOGvvvv << "zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
|
||||
}
|
||||
if (xoi.failed) continue;
|
||||
xn = xoi.mm[zn];
|
||||
xc = xoi.mm[zci];
|
||||
|
@@ -590,6 +293,7 @@ void GroupConvTuner::forwardTune(FusedOp* fop) {
 yw = yoi.mm[zw];
 auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
 LOGvvvv << "yformat =" << yformat;
+
 // mkl doesn't support "cdab" format
 if (yformat == "cdab") continue;
 // cuda doesn't support "iohw" format

@@ -604,15 +308,17 @@ void GroupConvTuner::forwardTune(FusedOp* fop) {
 auto expr_w = expr::make(ss.str());
 
 vector<unique_ptr<Expr>> rh, rw;
-auto src_h = expr::make(riop1->indexes[xh]);
+auto src_h = expr::make(xop_indexes->at(xh));
 if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
 LOGvvvv << "Expr not match" << src_h << expr_h;
 continue;
 }
+LOGvvvv << "H Expr matched" << src_h << expr_h;
 if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
-auto src_w = expr::make(riop1->indexes[xw]);
+auto src_w = expr::make(xop_indexes->at(xw));
 if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
-return;
+continue;
+LOGvvvv << "W Expr matched" << src_w << expr_w;
 if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
 int stride_h = rh[0]->as_int();
 int padding_h = -rh[1]->as_int();

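Replacing riop1->indexes[...] with xop_indexes->at(...) lets the tuner read the height/width index expressions from whichever op actually carries them, a reindex or a reindex_reduce; the earlier part of the rewritten forwardTune selects xop_indexes with exactly that kind of name() dispatch. A reduced sketch of the pattern (toy structs, not Jittor's Op classes):

// --- illustrative sketch (not part of the commit) ---
#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins: both "reindex" and "reindex_reduce" carry an indexes vector.
struct ToyOp { virtual const char* name() const = 0; virtual ~ToyOp() = default; };
struct ToyReindex : ToyOp {
    std::vector<std::string> indexes{"i0", "i2*2+i5*1-1", "i3*2+i6*1-1"};
    const char* name() const override { return "reindex"; }
};
struct ToyReindexReduce : ToyOp {
    std::vector<std::string> indexes{"i0", "i2*1+i5*2-0", "i3*1+i6*2-0"};
    const char* name() const override { return "reindex_reduce"; }
};

// Pick the index list from whichever op type we are looking at, as the tuner does.
const std::vector<std::string>* pick_indexes(ToyOp* op) {
    if (std::string(op->name()) == "reindex")
        return &static_cast<ToyReindex*>(op)->indexes;
    if (std::string(op->name()) == "reindex_reduce")
        return &static_cast<ToyReindexReduce*>(op)->indexes;
    return nullptr;  // the tuner would `continue` here
}

int main() {
    ToyReindexReduce rr;
    if (auto* idx = pick_indexes(&rr))
        std::cout << idx->at(1) << "\n";  // the height index expression
}
// --- end sketch ---
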
@ -622,247 +328,95 @@ void GroupConvTuner::forwardTune(FusedOp* fop) {
|
|||
int dilation_w = rw[2]->as_int();
|
||||
if (dilation_h < 1 || dilation_w < 1) continue;
|
||||
if (stride_h!=stride_w || padding_h!=padding_w || dilation_h!=dilation_w) {
|
||||
LOGvvvv << "cannot relay different stride and padding between h and w"
|
||||
<< stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w;
|
||||
LOGw << "cannot relay different stride and padding between h and w"
|
||||
<< stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w
|
||||
<< "This may cause low performance. Please send us issue if you need it.";
|
||||
continue;
|
||||
}
|
||||
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
|
||||
|
||||
if (xformat == "bacd") {
|
||||
LOGvvvv << "mkl not support bacd, continue";
|
||||
continue;
|
||||
}
|
||||
int stride = stride_h;
|
||||
int padding = padding_h;
|
||||
int dilation = dilation_h;
|
||||
Var* x = riop1->x;
|
||||
Var* w = riop2->x;
|
||||
Var* x = x_id == 0 ? xoi.op->output(0) : xoi.op->input(0);
|
||||
Var* w = w_id == 0 ? woi.op->output(0) : woi.op->input(0);
|
||||
Var* y = y_id == 0 ? yoi.op->output(0) : yoi.op->input(0);
|
||||
|
||||
int oh = (x->shape[xh]-w->shape[wh]*dilation_h+dilation_h-1+padding_h*2)/stride_h+1;
|
||||
int ow = (x->shape[xw]-w->shape[ww]*dilation_w+dilation_w-1+padding_w*2)/stride_w+1;
|
||||
if (oh != rop->x->shape[yh] || ow != rop->x->shape[yw]) continue;
|
||||
|
||||
int groups = x->shape[xc] / w->shape[wci];
|
||||
if (oh != y->shape[yh] || ow != y->shape[yw]) {
|
||||
LOGvvvv << "shape not match" << "(" >> oh >> "," >> ow >> ") !="
|
||||
<< "(" >> y->shape[yh] >> "," >> y->shape[yw] >> ")";
|
||||
continue;
|
||||
}
|
||||
int groups = zg==-1 ? 1 : x->shape[xc] / w->shape[wci];
|
||||
LOGvvvv << "groups: " << groups;
|
||||
if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
|
||||
LOGi << "group conv does not support mkl";
|
||||
if (groups>1 && wformat != "oihw")
|
||||
continue;
|
||||
}
|
||||
|
||||
string relay_conv_name = "cudnn_conv";
|
||||
if (!has_op(relay_conv_name))
|
||||
continue;
|
||||
auto make_conv = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
|
||||
auto rvar = make_conv(x, w, stride, padding, dilation, groups, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->x}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
add_candidate(srid, 1);
|
||||
add_candidate(srid, 0);
|
||||
confidence = 20;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
VarPtr rvar;
|
||||
int rid;
|
||||
string relay_conv_name;
|
||||
|
||||
void GroupConvTuner::backwardTune(FusedOp* fop) {
|
||||
for (Op* op : fop->ops) {
|
||||
int bo=0;
|
||||
Var *x=NULL, *y=NULL, *w=NULL;
|
||||
Var *dw=NULL, *dx=NULL;
|
||||
int height=0,width=0,kernel_size=0,stride=0, padding=0, dilation=1, groups=1;
|
||||
string xformat, yformat, wformat;
|
||||
if (op->name_ex() == "reindex_reduce.add") {
|
||||
auto rop = (ReindexReduceOp*)op;
|
||||
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->y->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->y->input());
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!(bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="reindex")) return;
|
||||
auto riop1 = (ReindexOp*)(bop->x->input());
|
||||
auto riop2 = (ReindexOp*)(bop->y->input());
|
||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
|
||||
|
||||
OpInspector oi1(riop1);
|
||||
OpInspector oi2(riop2);
|
||||
|
||||
|
||||
if (oi1.mm[0] == 0 && oi2.mm[0] == 0) {
|
||||
// dw
|
||||
// x.mm [0,1,-1,1,2,3,2,3] y.mm [0,1,1,-1,2,3,-1,-1] w.mm [-1,0,0,1,-1,-1,2,3]
|
||||
OpInspector xoi(oi1.mm[2] == -1 ? riop1 : riop2);
|
||||
OpInspector yoi(oi1.mm[2] == -1 ? riop2 : riop1);
|
||||
OpInspector woi(rop);
|
||||
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
|
||||
int zn, zg, zci, zco, zh, zw, zwh, zww;
|
||||
zn = zg = zci = zco = zh = zw = zwh = zww = 0;
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
|
||||
LOGvvvv << "group conv backward dw zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
|
||||
if (xoi.failed) continue;
|
||||
xn = xoi.mm[zn];
|
||||
xc = xoi.mm[zci];
|
||||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
// cuda doesn't support "iohw" format
|
||||
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
|
||||
if (xoi.failed) continue;
|
||||
|
||||
std::stringstream ss;
|
||||
// i@zh*stride+i@zwh+padding
|
||||
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
|
||||
auto expr_h = expr::make(ss.str());
|
||||
|
||||
vector<unique_ptr<Expr>> rh;
|
||||
auto src_h = expr::make(riop1->indexes[xh]);
|
||||
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
|
||||
LOGvvvv << "Expr not match" << src_h << expr_h;
|
||||
if (y_id == 0) {
|
||||
relay_conv_name = fop->flags.get(NodeFlags::_cpu) ?
|
||||
"mkl_conv" : "cudnn_conv";
|
||||
if (!has_op(relay_conv_name))
|
||||
continue;
|
||||
}
|
||||
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;
|
||||
|
||||
dw = rop->x;
|
||||
stride = rh[0]->as_int();
|
||||
padding = -rh[1]->as_int();
|
||||
dilation = rh[2]->as_int();
|
||||
kernel_size = dw->shape[wformat.find("h")];
|
||||
groups = (oi1.mm[2] == -1 ? riop1 : riop2)->x->shape[xc] / dw->shape[wci];
|
||||
|
||||
if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
|
||||
LOGi << "group conv does not support mkl";
|
||||
auto make_conv = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
|
||||
LOGvvvv << x << w << stride << padding << dilation << groups << xformat << wformat << yformat;
|
||||
rvar = make_conv(x, w, stride, padding, dilation, groups, xformat, wformat, yformat);
|
||||
} else
|
||||
if (x_id == 0) {
|
||||
relay_conv_name = fop->flags.get(NodeFlags::_cpu) ?
|
||||
"mkl_conv_backward_x" : "cudnn_conv_backward_x";
|
||||
if (!has_op(relay_conv_name))
|
||||
continue;
|
||||
}
|
||||
|
||||
LOGvvvv << stride << padding << dilation << kernel_size << groups;
|
||||
|
||||
x = (oi1.mm[2] == -1 ? riop1 : riop2)->x;
|
||||
y = (oi1.mm[2] == -1 ? riop2 : riop1)->x;
|
||||
bo++;
|
||||
auto height = x->shape[xformat.find("c")];
|
||||
auto width = x->shape[xformat.find("d")];
|
||||
auto make_conv_x = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
|
||||
LOGvvvv << w << y << height << width << stride << padding << dilation << groups << xformat << wformat << yformat;
|
||||
rvar = make_conv_x(w, y, height, width, stride, padding, dilation, groups, xformat, wformat, yformat);
|
||||
} else {
|
||||
// dx
|
||||
OpInspector woi(oi1.mm[0] == -1 ? riop1 : riop2);
|
||||
OpInspector yoi(oi1.mm[0] == -1 ? riop2 : riop1);
|
||||
OpInspector xoi(rop);
|
||||
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
|
||||
int zn, zg, zci, zco, zh, zw, zwh, zww;
|
||||
zn = zg = zci = zco = zh = zw = zwh = zww = 0;
|
||||
xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn);
|
||||
xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg);
|
||||
xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw);
|
||||
xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco);
|
||||
xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww);
|
||||
LOGvvvv << "group conv backward dx zn,zg,zci,zco,zh,zw,zwh,zww =" << vector<int>{zn,zg,zci,zco,zh,zw,zwh,zww};
|
||||
xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww});
|
||||
if (xoi.failed) continue;
|
||||
xn = xoi.mm[zn];
|
||||
xc = xoi.mm[zci];
|
||||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
// cuda doesn't support "iohw" format
|
||||
if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue;
|
||||
if (xoi.failed) continue;
|
||||
|
||||
std::stringstream ss;
|
||||
// i@zh*stride+i@zwh+padding
|
||||
ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding";
|
||||
auto expr_h = expr::make(ss.str());
|
||||
|
||||
vector<unique_ptr<Expr>> rh;
|
||||
auto src_h = expr::make(rop->indexes[xh]);
|
||||
if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) {
|
||||
LOGvvvv << "Expr not match" << src_h << expr_h;
|
||||
relay_conv_name = fop->flags.get(NodeFlags::_cpu) ?
|
||||
"mkl_conv_backward_w" : "cudnn_conv_backward_w";
|
||||
if (!has_op(relay_conv_name))
|
||||
continue;
|
||||
auto kh = w->shape[wformat.find("h")];
|
||||
auto kw = w->shape[wformat.find("w")];
|
||||
if (kh != kw) {
|
||||
LOGvvvv << "TODO: relay conv_backward_w when kh != kw" << kh << kw;
|
||||
continue;
|
||||
}
|
||||
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number)) continue;
|
||||
|
||||
dx = rop->x;
|
||||
stride = rh[0]->as_int();
|
||||
padding = -rh[1]->as_int();
|
||||
dilation = rh[2]->as_int();
|
||||
height = dx->shape[xformat.find("c")];
|
||||
width = dx->shape[xformat.find("d")];
|
||||
groups = dx->shape[xc] / (oi1.mm[0] == -1 ? riop1 : riop2)->x->shape[wci];
|
||||
|
||||
if (fop->flags.get(NodeFlags::_cpu) && groups > 1) {
|
||||
LOGi << "group conv does not support mkl";
|
||||
continue;
|
||||
}
|
||||
|
||||
LOGvvvv << stride << padding << dilation << height << width << groups;
|
||||
|
||||
w = (oi1.mm[0] == -1 ? riop1 : riop2)->x;
|
||||
y = (oi1.mm[0] == -1 ? riop2 : riop1)->x;
|
||||
bo+=2;
|
||||
LOGvvvv << x << y << kh << stride << padding << dilation << groups << xformat << wformat << yformat;
|
||||
auto make_conv_w = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
|
||||
rvar = make_conv_w(x, y, kh, stride, padding, dilation, groups, xformat, wformat, yformat);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO: CUDA only support nchw(abcd)
|
||||
if (fop->flags.get(NodeFlags::_cuda) && (xformat != "abcd" || yformat != "abcd"))
|
||||
continue;
|
||||
|
||||
if (bo&1) {
|
||||
auto make_conv_w = get_op_info(
|
||||
"cudnn_conv_backward_w"
|
||||
).get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
|
||||
auto rvar_w = make_conv_w(x, y, kernel_size, stride, padding, dilation, groups, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar_w, dw}});
|
||||
LOGvvvv << relay_conv_name << "output:" << rvar;
|
||||
rid = fop->context->vrm.add_relay_group({{rvar, op->output(0)}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
add_candidate(srid, 1);
|
||||
add_candidate(srid, 0);
|
||||
confidence = 20;
|
||||
ok = 1;
|
||||
LOGvvvv << "ok" << x_id << y_id << w_id;
|
||||
}
|
||||
}
|
||||
if (bo&2) {
|
||||
auto make_conv_x = get_op_info(
|
||||
"cudnn_conv_backward_x"
|
||||
).get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
|
||||
auto rvar_x = make_conv_x(w, y, height, width, stride, padding, dilation, groups, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar_x, dx}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
add_candidate(srid, 1);
|
||||
add_candidate(srid, 0);
|
||||
confidence = 20;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GroupConvTuner::run(PassManager* pm, TunerManager* tm) {
|
||||
void ConvTuner::run(PassManager* pm, TunerManager* tm) {
|
||||
FusedOp* fop=tm->oc->op;
|
||||
|
||||
forwardTune(fop);
|
||||
backwardTune(fop);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -16,14 +16,7 @@ namespace jittor {
 struct ConvTuner : Tuner {
 ConvTuner() : Tuner("conv") {}
-void forwardTune(FusedOp* fop);
-void backwardTune(FusedOp* fop);
-void run(PassManager* pm, TunerManager* tm);
-};
-
-struct GroupConvTuner : Tuner {
-GroupConvTuner() : Tuner("group_conv") {}
 void forwardTune(FusedOp* fop);
-void backwardTune(FusedOp* fop);
+// void backwardTune(FusedOp* fop);
 void run(PassManager* pm, TunerManager* tm);
 };
 

@@ -43,7 +43,6 @@ string TunerManager::tune() {
 run_tuner<ReduceTuner>(&pm);
 run_tuner<MatmulTuner>(&pm);
 run_tuner<ConvTuner>(&pm);
-run_tuner<GroupConvTuner>(&pm);
 
 // use the best tuner if it is confidence enough
 if (best_tuner && best_tuner->confidence) {