mirror of https://github.com/Jittor/Jittor
polish conv transpose group
parent 922c0d8246
commit ac66897047

jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.3.1.26'
+__version__ = '1.3.1.27'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

cudnn_conv_backward_w_op.cc
@@ -12,6 +12,7 @@
 #include "cudnn_conv_backward_w_op.h"
 #include "cudnn_wrapper.h"
 #include "executor.h"
+#include "ops/op_register.h"

 using namespace std;

@@ -72,6 +73,30 @@ void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
     jk << _CS("][YFORMAT:") << yformat;
     jk << ']';
 }
+
+static auto make_conv = get_op_info("cudnn_conv")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
+static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
+
+VarPtr CudnnConvBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
+
+    if (xformat == "nchw") {
+        x->shape.unpack(xn, xc, xh, xw);
+        dy->shape.unpack(yn, yc, yh, yw);
+    } else {
+        x->shape.unpack(xn, xh, xw, xc);
+        dy->shape.unpack(yn, yh, yw, yc);
+    }
+
+    if (v_index == 0) {
+        return make_backwardx(dout, dy, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
+    } else {
+        return make_conv(x, dout, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
+    }
+}
+
 unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;

 #else // JIT

cudnn_conv_backward_w_op.h
@@ -20,6 +20,7 @@ struct CudnnConvBackwardWOp : Op {
     CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");

     const char* name() const override { return "cudnn_conv_backward_w"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
     void infer_shape() override;
     DECLARE_jit_run;
 };

cudnn_conv_backward_x_op.cc
@@ -12,6 +12,7 @@
 #include "cudnn_conv_backward_x_op.h"
 #include "cudnn_wrapper.h"
 #include "executor.h"
+#include "ops/op_register.h"

 using namespace std;

@@ -72,6 +73,22 @@ void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
     jk << _CS("][YFORMAT:") << yformat;
     jk << ']';
 }
+
+static auto make_conv = get_op_info("cudnn_conv")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
+static auto make_backwardw = get_op_info("cudnn_conv_backward_w")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
+
+VarPtr CudnnConvBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    int xn, xc, wh, ww, wci, wco, yn, yc, yd, yh, yw;
+    w->shape.unpack(wco, wci, wh, ww);
+
+    if (v_index == 0) {
+        return make_backwardw(dout, dy, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
+    } else {
+        return make_conv(dout, w, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
+    }
+}
 unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;

 #else // JIT

cudnn_conv_backward_x_op.h
@@ -20,6 +20,7 @@ struct CudnnConvBackwardXOp : Op {
     CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");

     const char* name() const override { return "cudnn_conv_backward_x"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
     void infer_shape() override;
     DECLARE_jit_run;
 };

cudnn_conv_op.cc
@@ -9,6 +9,7 @@
 #include "cudnn_conv_op.h"
 #include "cudnn_wrapper.h"
 #include "executor.h"
+#include "ops/op_register.h"

 using namespace std;

@@ -74,6 +75,24 @@ void CudnnConvOp::jit_prepare(JK& jk) {
     jk << _CS("][YFORMAT:") << yformat;
     jk << ']';
 }
+static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
+static auto make_backwardw = get_op_info("cudnn_conv_backward_w")
+    .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
+VarPtr CudnnConvOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
+    if (xformat == "ncdhw")
+        x->shape.unpack(xn, xc, xh, xw);
+    else
+        x->shape.unpack(xn, xh, xw, xc);
+    w->shape.unpack(wco, wci, wh, ww);
+    if (v_index == 0) {
+        return make_backwardx(w, dout, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
+    } else {
+        return make_backwardw(x, dout, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
+    }
+}
+
 unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;

 #else // JIT

cudnn_conv_op.h
@@ -17,6 +17,7 @@ struct CudnnConvOp : Op {
     CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");

     const char* name() const override { return "cudnn_conv"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
     void infer_shape() override;
     DECLARE_jit_run;
 };
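
The three .cc hunks above register a grad() for each cuDNN convolution op, built from the other two ops with groups passed through, so the backward ops themselves become differentiable. A rough sketch of what that enables, not part of this commit, assuming a CUDA build where the cudnn ops are selected:

import jittor as jt
from jittor import nn

jt.flags.use_cuda = 1                      # assumption: CUDA build, so cudnn ops are used
x = jt.random((2, 4, 16, 16))
w = jt.random((8, 4, 3, 3))
y = nn.conv2d(x, w, padding=1, groups=1)
gx = jt.grad(y.sum(), x)                   # a conv backward op now appears in the graph
gw = jt.grad(gx.sum(), w)                  # differentiating through it needs the new grad()
print(gx.shape, gw.shape)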

nn.py
@@ -803,9 +803,6 @@ class Conv(Module):
         assert in_channels % groups == 0, 'in_channels must be divisible by groups'
         assert out_channels % groups == 0, 'out_channels must be divisible by groups'
         Kh, Kw = self.kernel_size
-        self.groups = groups
-        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
-        assert out_channels % groups == 0, 'out_channels must be divisible by groups'

         # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
         self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
@@ -1195,8 +1192,7 @@ class ConvTranspose(Module):

         # added
         self.dilation = dilation
-        self.group = groups
-        assert groups==1, "Group conv not supported yet."
+        self.groups = groups

         self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
         self.stride = stride if isinstance(stride, tuple) else (stride, stride)
@@ -1209,8 +1205,10 @@ class ConvTranspose(Module):
         assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
             self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
             "output padding must be smaller than max(stride, dilation)"
+        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
+        assert out_channels % groups == 0, 'out_channels must be divisible by groups'

-        self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
+        self.weight = init.invariant_uniform((in_channels, out_channels//groups) + self.kernel_size, dtype="float")
         if bias:
             fan=1
             for i in self.weight.shape[1:]:
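
With groups, the ConvTranspose weight is now laid out as (in_channels, out_channels//groups, Kh, Kw), matching torch.nn.ConvTranspose2d. An illustrative check, not from the repo:

import jittor as jt
m = jt.nn.ConvTranspose(6, 6, 5, stride=2, dilation=3, groups=2, bias=False)
print(m.weight.shape)   # (in_channels, out_channels//groups, Kh, Kw) -> [6, 3, 5, 5]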

nn.py
@@ -1221,29 +1219,70 @@ class ConvTranspose(Module):
             self.bias = None

     def execute(self, x):
-        N,C,H,W = x.shape
-        i,o,h,w = self.weight.shape
-        assert C==i
-        stride_h, stride_w = self.stride
-        padding_h, padding_w = self.padding
-        dilation_h, dilation_w = self.dilation
+        if self.groups == 1:
+            N,C,H,W = x.shape
+            i,o,h,w = self.weight.shape
+            assert C==i
+            stride_h, stride_w = self.stride
+            padding_h, padding_w = self.padding
+            dilation_h, dilation_w = self.dilation
+
+            h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
+            w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
+            out_shape = (N, o, h_out, w_out)
+            shape = (N, i, o, H, W, h, w)
+            xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
+            ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
+            y = (ww*xx).reindex_reduce("add", out_shape, [
+                'i0', # N
+                'i2', # o
+                f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
+                f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
+            ])
+            if self.bias is not None:
+                b = self.bias.broadcast(y.shape, [0,2,3])
+                y = y + b
+            return y
+        else:
+            N,C,H,W = x.shape
+            Kh, Kw = self.kernel_size
+            i,o,h,w = self.weight.shape
+            oc = self.out_channels
+            G = self.groups
+            CpG = C // G # channels per group
+            assert C==self.in_channels
+            stride_h, stride_w = self.stride
+            padding_h, padding_w = self.padding
+            dilation_h, dilation_w = self.dilation
+
+            oh = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
+            ow = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
+            out_shape = (N, oc, oh, ow)
+            shape = [N,G,oc//G,CpG,oh,ow,Kh,Kw]
+            xx = x.reindex(shape, [
+                'i0',
+                f'i1*{oc//G}+i2',
+                'i4',
+                'i5'
+            ])
+            ww = self.weight.reindex(shape, [
+                f'i1*{oc//G}+i2',
+                'i3',
+                'i6',
+                'i7'
+            ])
+            ww.compile_options = xx.compile_options = {"G":G,"C":C}
+            y = (ww*xx).reindex_reduce("add", out_shape, [
+                'i0', # Nid
+                f'i1*{CpG}+i3', # Gid
+                f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
+                f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
+            ])
+            if self.bias is not None:
+                b = self.bias.broadcast(y.shape, [0,2,3])
+                y = y + b
+            return y

-        h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
-        w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
-        out_shape = (N, o, h_out, w_out)
-        shape = (N, i, o, H, W, h, w)
-        xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
-        ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
-        y = (ww*xx).reindex_reduce("add", out_shape, [
-            'i0', # N
-            'i2', # o
-            f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
-            f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
-        ])
-        if self.bias is not None:
-            b = self.bias.broadcast(y.shape, [0,2,3])
-            y = y + b
-        return y

 class ConvTranspose3d(Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
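
Both branches of execute() use the same output-size rule, oh = (H-1)*stride - 2*padding + output_padding + 1 + (Kh-1)*dilation, and likewise for the width. A small sketch, not part of the commit, that checks the grouped branch against this formula:

import jittor as jt

N, C, H, W = 2, 6, 10, 10
m = jt.nn.ConvTranspose(C, 6, 5, stride=2, padding=1, dilation=3, groups=2, bias=False)
y = m(jt.random((N, C, H, W)))

oh = (H - 1) * 2 - 2 * 1 + 0 + 1 + (5 - 1) * 3   # = 29
print(y.shape)                                   # expected [2, 6, 29, 29]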

var_holder.h
@@ -240,12 +240,12 @@ struct VarHolder {
     // @pyjt(__set__data)
     inline void set_data(ArrayArgs&& array) {
         sync(true);
-        ASSERT(array.dtype.dsize() == var->dtype().dsize()
+        CHECK(array.dtype.dsize() == var->dtype().dsize()
             && array.dtype.is_int() == var->dtype().is_int());
         int64 size = array.dtype.dsize();
         for (int i=0; i<array.shape.size(); i++)
             size *= array.shape[i];
-        ASSERT(size==var->size);
+        CHECK(size==var->size);
         #ifdef HAS_CUDA
         migrate_to_cpu(var, exe.allocator);
         #endif
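
The ASSERT to CHECK change affects how a mismatched .data assignment is reported; presumably it surfaces as an ordinary error rather than a hard failure. A hedged illustration (whether the error is catchable from Python is an assumption, not something this diff states):

import numpy as np
import jittor as jt

v = jt.zeros((2, 3))
try:
    v.data = np.zeros((4, 4), dtype="float32")   # same dtype class, wrong total size
except Exception as e:
    print("rejected:", type(e).__name__)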

test_conv_transpose.py
@@ -29,13 +29,13 @@ class TestConvTranspose(unittest.TestCase):
         self.test()

     def test(self):
-        def check(data_shape, weights_shape, stride=1, dilation=1):
+        def check(data_shape, weights_shape, stride=1, dilation=1, groups=1):
             N,C,H,W = data_shape
             i,o,h,w = weights_shape
             img = np.random.rand(N,C,H,W).astype("float32")
-            weights = np.random.rand(i,o,h,w).astype("float32")
-            m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
-            m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
+            weights = np.random.rand(i,o//groups,h,w).astype("float32")
+            m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups)
+            m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups)
             m1.weight.data = weights
             m2.weight.data = torch.Tensor(weights)
             x = jt.array(img)
@@ -61,6 +61,7 @@ class TestConvTranspose(unittest.TestCase):
         check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
         check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
         check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
+        check((4, 6, 100, 100), (6, 6, 5, 5), 2, 3, 2)

     def test_function(self):
         def check(data_shape, weights_shape, stride=1, dilation=1):
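
The new test case exercises groups=2 with equal input and output channels. Below is a standalone sketch of the same jittor-vs-torch parity check (not the repo's test; it assumes torch is installed and uses a smaller input for brevity):

import numpy as np
import jittor as jt
import torch

i, o, k, g = 6, 6, 5, 2
x = np.random.rand(4, i, 10, 10).astype("float32")
w = np.random.rand(i, o // g, k, k).astype("float32")

m1 = jt.nn.ConvTranspose(i, o, k, stride=2, dilation=3, groups=g, bias=False)
m2 = torch.nn.ConvTranspose2d(i, o, k, stride=2, dilation=3, groups=g, bias=False)
m1.weight.data = w
m2.weight.data = torch.Tensor(w)

y1 = m1(jt.array(x)).data
y2 = m2(torch.Tensor(x)).detach().numpy()
print(np.allclose(y1, y2, atol=1e-4))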