polish conv transpose group

Dun Liang 2021-12-17 17:54:41 +08:00
parent 922c0d8246
commit ac66897047
10 changed files with 139 additions and 35 deletions

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.26'
__version__ = '1.3.1.27'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@@ -12,6 +12,7 @@
#include "cudnn_conv_backward_w_op.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
@@ -72,6 +73,30 @@ void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
static auto make_conv = get_op_info("cudnn_conv")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
VarPtr CudnnConvBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
if (xformat == "nchw") {
x->shape.unpack(xn, xc, xh, xw);
dy->shape.unpack(yn, yc, yh, yw);
} else {
x->shape.unpack(xn, xh, xw, xc);
dy->shape.unpack(yn, yh, yw, yc);
}
if (v_index == 0) {
return make_backwardx(dout, dy, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
} else {
return make_conv(x, dout, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
}
}
unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
#else // JIT
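Defining grad() on cudnn_conv_backward_w is what lets a convolution's weight gradient be differentiated again. A minimal sketch of the behaviour this enables, assuming a CUDA build so the cudnn path is the one exercised (shapes are illustrative only):

import jittor as jt
jt.flags.use_cuda = 1           # assumption: CUDA build, so the cudnn_conv* ops run

x = jt.random((2, 4, 16, 16))
w = jt.random((8, 4, 3, 3))
y = jt.nn.conv2d(x, w, padding=1)

dw = jt.grad(y.sum(), w)        # first-order gradient, produced by a backward-w op
ddx = jt.grad(dw.sum(), x)      # second-order gradient, needs the grad() defined above
print(ddx.shape)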

View File

@@ -20,6 +20,7 @@ struct CudnnConvBackwardWOp : Op {
CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_w"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};

View File

@@ -12,6 +12,7 @@
#include "cudnn_conv_backward_x_op.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
@@ -72,6 +73,22 @@ void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
static auto make_conv = get_op_info("cudnn_conv")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
static auto make_backwardw = get_op_info("cudnn_conv_backward_w")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
VarPtr CudnnConvBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, wh, ww, wci, wco, yn, yc, yd, yh, yw;
w->shape.unpack(wco, wci, wh, ww);
if (v_index == 0) {
return make_backwardw(dout, dy, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
} else {
return make_conv(dout, w, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
}
}
unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
#else // JIT
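cudnn_conv_backward_x computes the input gradient of a convolution, which for a constant output gradient coincides with a transposed convolution using the same weight. A small sketch of that identity (groups == 1; sizes chosen so the two shapes match exactly):

import numpy as np
import jittor as jt

x = jt.array(np.random.rand(2, 3, 8, 8).astype("float32"))
w = jt.array(np.random.rand(5, 3, 3, 3).astype("float32"))
y = jt.nn.conv2d(x, w, stride=1, padding=1)

dx = jt.grad(y.sum(), x)                                    # input gradient of the conv
dx_ref = jt.nn.conv_transpose2d(jt.ones_like(y), w, stride=1, padding=1)
print(np.allclose(dx.data, dx_ref.data, atol=1e-3))         # expect True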

View File

@@ -20,6 +20,7 @@ struct CudnnConvBackwardXOp : Op {
CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_x"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};

View File

@@ -9,6 +9,7 @@
#include "cudnn_conv_op.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
@@ -74,6 +75,24 @@ void CudnnConvOp::jit_prepare(JK& jk) {
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
static auto make_backwardw = get_op_info("cudnn_conv_backward_w")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
VarPtr CudnnConvOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
if (xformat == "nchw")
x->shape.unpack(xn, xc, xh, xw);
else
x->shape.unpack(xn, xh, xw, xc);
w->shape.unpack(wco, wci, wh, ww);
if (v_index == 0) {
return make_backwardx(w, dout, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
} else {
return make_backwardw(x, dout, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
}
}
unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
#else // JIT
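The forward op's grad() dispatches to the two backward ops above. A quick sanity check of the resulting weight gradient against a finite difference; it runs on whichever conv backend is active, and the tolerance is deliberately loose:

import numpy as np
import jittor as jt

x = jt.array(np.random.rand(1, 2, 8, 8).astype("float32"))
w = jt.array(np.random.rand(3, 2, 3, 3).astype("float32"))
dw = jt.grad(jt.nn.conv2d(x, w, padding=1).sum(), w)        # weight gradient via backward-w

def scalar_loss(wv):
    # same loss with a fresh weight array, returned as a plain float
    return jt.nn.conv2d(x, jt.array(wv), padding=1).sum().data.reshape(-1)[0]

eps = 1e-2
wp = w.data.copy(); wp[0, 0, 0, 0] += eps
fd = (scalar_loss(wp) - scalar_loss(w.data)) / eps          # finite-difference estimate
print(np.allclose(fd, dw.data[0, 0, 0, 0], rtol=1e-2))      # expect True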

View File

@@ -17,6 +17,7 @@ struct CudnnConvOp : Op {
CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "cudnn_conv"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};

View File

@@ -803,9 +803,6 @@ class Conv(Module):
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
Kh, Kw = self.kernel_size
self.groups = groups
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
@@ -1195,8 +1192,7 @@ class ConvTranspose(Module):
# added
self.dilation = dilation
self.group = groups
assert groups==1, "Group conv not supported yet."
self.groups = groups
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
@@ -1209,8 +1205,10 @@ class ConvTranspose(Module):
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
self.weight = init.invariant_uniform((in_channels, out_channels//groups) + self.kernel_size, dtype="float")
if bias:
fan=1
for i in self.weight.shape[1:]:
@@ -1221,29 +1219,70 @@ class ConvTranspose(Module):
self.bias = None
def execute(self, x):
N,C,H,W = x.shape
i,o,h,w = self.weight.shape
assert C==i
stride_h, stride_w = self.stride
padding_h, padding_w = self.padding
dilation_h, dilation_w = self.dilation
if self.groups == 1:
N,C,H,W = x.shape
i,o,h,w = self.weight.shape
assert C==i
stride_h, stride_w = self.stride
padding_h, padding_w = self.padding
dilation_h, dilation_w = self.dilation
h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, h_out, w_out)
shape = (N, i, o, H, W, h, w)
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
else:
N,C,H,W = x.shape
Kh, Kw = self.kernel_size
i,o,h,w = self.weight.shape
oc = self.out_channels
G = self.groups
CpG = C // G # channels per group
assert C==self.in_channels
stride_h, stride_w = self.stride
padding_h, padding_w = self.padding
dilation_h, dilation_w = self.dilation
oh = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
ow = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, oc, oh, ow)
shape = [N,G,oc//G,CpG,H,W,Kh,Kw]
xx = x.reindex(shape, [
'i0', # Nid
f'i1*{CpG}+i3', # input channel: group * CpG + in-channel within group
'i4', # Hid
'i5' # Wid
])
ww = self.weight.reindex(shape, [
f'i1*{CpG}+i3', # weight dim 0 is in_channels
'i2', # out-channel within group
'i6', # Khid
'i7' # Kwid
])
ww.compile_options = xx.compile_options = {"G":G,"C":C}
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # Nid
f'i1*{oc//G}+i2', # output channel: group * (oc//G) + out-channel within group
f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, h_out, w_out)
shape = (N, i, o, H, W, h, w)
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
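With the groups==1 assertion removed above, ConvTranspose now accepts grouped configurations. A minimal usage sketch mirroring the new test case further down; the printed output size follows the h_out/w_out formula used in execute():

import jittor as jt

m = jt.nn.ConvTranspose(6, 6, 5, stride=2, dilation=3, groups=2, bias=False)
print(m.weight.shape)    # [6, 3, 5, 5] = (in_channels, out_channels // groups, Kh, Kw)

x = jt.random((4, 6, 100, 100))
y = m(x)
# h_out = (H-1)*stride - 2*padding + output_padding + 1 + (Kh-1)*dilation
#       = 99*2 - 0 + 0 + 1 + 4*3 = 211
print(y.shape)           # [4, 6, 211, 211]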
class ConvTranspose3d(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \

View File

@@ -240,12 +240,12 @@ struct VarHolder {
// @pyjt(__set__data)
inline void set_data(ArrayArgs&& array) {
sync(true);
ASSERT(array.dtype.dsize() == var->dtype().dsize()
CHECK(array.dtype.dsize() == var->dtype().dsize()
&& array.dtype.is_int() == var->dtype().is_int());
int64 size = array.dtype.dsize();
for (int i=0; i<array.shape.size(); i++)
size *= array.shape[i];
ASSERT(size==var->size);
CHECK(size==var->size);
#ifdef HAS_CUDA
migrate_to_cpu(var, exe.allocator);
#endif
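The two CHECKs above validate arrays assigned through the @pyjt(__set__data) binding. A hedged sketch of what they guard against (the exact exception type surfaced on the Python side depends on the build):

import numpy as np
import jittor as jt

v = jt.zeros((2, 3), dtype="float32")
v.data = np.ones((2, 3), dtype="float32")         # ok: same dsize, int-ness and total size
try:
    v.data = np.ones((4, 4), dtype="float32")     # total size differs -> rejected
except Exception as e:
    print("rejected:", type(e).__name__)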

View File

@@ -29,13 +29,13 @@ class TestConvTranspose(unittest.TestCase):
self.test()
def test(self):
def check(data_shape, weights_shape, stride=1, dilation=1):
def check(data_shape, weights_shape, stride=1, dilation=1, groups=1):
N,C,H,W = data_shape
i,o,h,w = weights_shape
img = np.random.rand(N,C,H,W).astype("float32")
weights = np.random.rand(i,o,h,w).astype("float32")
m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
weights = np.random.rand(i,o//groups,h,w).astype("float32")
m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups)
m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups)
m1.weight.data = weights
m2.weight.data = torch.Tensor(weights)
x = jt.array(img)
@@ -61,6 +61,7 @@ class TestConvTranspose(unittest.TestCase):
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
check((4, 6, 100, 100), (6, 6, 5, 5), 2, 3, 2)
def test_function(self):
def check(data_shape, weights_shape, stride=1, dilation=1):