add function version conv transpose

Dun Liang 2022-01-04 17:29:58 +08:00
parent f20fa98e44
commit 85ba4898cf
2 changed files with 97 additions and 33 deletions


@@ -1091,6 +1091,7 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
         b = bias.broadcast(y.shape, [0,2,3])
         y = y + b
     return y
+conv = conv2d

 def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     ''' Applies a 3D convolution over an input signal composed of several input planes.
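The new `conv` name simply re-exports `conv2d` as a shorter functional alias. A quick sanity check (shapes are illustrative, not taken from this diff):

import jittor as jt
import numpy as np

x = jt.random([1, 3, 8, 8])   # (N, C, H, W)
w = jt.random([4, 3, 3, 3])   # (out_channels, in_channels, kh, kw)
np.testing.assert_allclose(jt.nn.conv(x, w).data, jt.nn.conv2d(x, w).data)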
@@ -1324,42 +1325,91 @@ class ConvTranspose3d(Module):
         return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)

 def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    x = input
-    N,C,H,W = x.shape
-    i,o,h,w = weight.shape
-    assert C==i
-    assert groups==1, "Group conv not supported yet."
-    stride = stride if isinstance(stride, tuple) else (stride, stride)
-    dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
-    # normalize padding and output_padding to tuples
-    padding = padding if isinstance(padding, tuple) else (padding, padding)
-    output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding)
-    assert output_padding[0] < max(stride[0], dilation[0]) and \
-        output_padding[1] < max(stride[1], dilation[1]), \
-        "output padding must be smaller than max(stride, dilation)"
-    stride_h, stride_w = stride
-    padding_h, padding_w = padding
-    dilation_h, dilation_w = dilation
-    h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
-    w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
-    out_shape = (N, o, h_out, w_out)
-    shape = (N, i, o, H, W, h, w)
-    xx = x.broadcast(shape, (2, 5, 6)) # broadcast over new dims o, kh, kw
-    ww = weight.broadcast(shape, (0, 3, 4)) # broadcast over new dims N, H, W
-    y = (ww*xx).reindex_reduce("add", out_shape, [
-        'i0', # N
-        'i2', # o
-        f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
-        f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
-    ])
-    if isinstance(bias, jt.Var):
-        b = bias.broadcast(y.shape, [0,2,3])
-        y = y + b
-    else:
-        assert not bias, "Bias should be none or jittor var"
-    return y
+    if groups == 1:
+        x = input
+        N,C,H,W = x.shape
+        i,o,h,w = weight.shape
+        assert C==i
+        stride = stride if isinstance(stride, tuple) else (stride, stride)
+        dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
+        # normalize padding and output_padding to tuples
+        padding = padding if isinstance(padding, tuple) else (padding, padding)
+        output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding)
+        assert output_padding[0] < max(stride[0], dilation[0]) and \
+            output_padding[1] < max(stride[1], dilation[1]), \
+            "output padding must be smaller than max(stride, dilation)"
+        stride_h, stride_w = stride
+        padding_h, padding_w = padding
+        dilation_h, dilation_w = dilation
+        h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
+        w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
+        out_shape = (N, o, h_out, w_out)
+        shape = (N, i, o, H, W, h, w)
+        xx = x.broadcast(shape, (2, 5, 6)) # broadcast over new dims o, kh, kw
+        ww = weight.broadcast(shape, (0, 3, 4)) # broadcast over new dims N, H, W
+        y = (ww*xx).reindex_reduce("add", out_shape, [
+            'i0', # N
+            'i2', # o
+            f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
+            f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
+        ])
+        if isinstance(bias, jt.Var):
+            b = bias.broadcast(y.shape, [0,2,3])
+            y = y + b
+        else:
+            assert not bias, "Bias should be none or jittor var"
+        return y
+    else:
+        N,C,H,W = input.shape
+        i,o,h,w = weight.shape
+        G = groups
+        oc = o * G
+        CpG = C // G # channels per group
+        assert C % G == 0
+        assert C==i, (C, i)
+        stride = stride if isinstance(stride, tuple) else (stride, stride)
+        dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
+        # normalize padding and output_padding to tuples
+        padding = padding if isinstance(padding, tuple) else (padding, padding)
+        output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding)
+        assert output_padding[0] < max(stride[0], dilation[0]) and \
+            output_padding[1] < max(stride[1], dilation[1]), \
+            "output padding must be smaller than max(stride, dilation)"
+        stride_h, stride_w = stride
+        padding_h, padding_w = padding
+        dilation_h, dilation_w = dilation
+        oh = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
+        ow = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
+        out_shape = (N, oc, oh, ow)
+        # i0=N, i1=group, i2=out channel in group, i3=in channel in group, i4=H, i5=W, i6=kh, i7=kw
+        shape = [N, G, oc//G, CpG, H, W, h, w]
+        xx = input.reindex(shape, [
+            'i0',
+            f'i1*{CpG}+i3', # input channel id
+            'i4',
+            'i5'
+        ])
+        ww = weight.reindex(shape, [
+            f'i1*{CpG}+i3', # input channel id
+            'i2',
+            'i6',
+            'i7'
+        ])
+        ww.compile_options = xx.compile_options = {"G":G, "C":C}
+        y = (ww*xx).reindex_reduce("add", out_shape, [
+            'i0', # Nid
+            f'i1*{oc//G}+i2', # output channel id
+            f'i4*{stride_h}-{padding_h}+i6*{dilation_h}', # Hid+Khid
+            f'i5*{stride_w}-{padding_w}+i7*{dilation_w}', # Wid+KWid
+        ])
+        if bias is not None:
+            b = bias.broadcast(y.shape, [0,2,3])
+            y = y + b
+        return y
+conv_transpose2d = conv_transpose

 def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
     x = input
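Both branches size the output with the standard transposed-convolution formula, oh = (H-1)*stride - 2*padding + output_padding + (h-1)*dilation + 1, then scatter-add each input pixel's weighted kernel into that buffer via reindex_reduce. A minimal sketch of calling the new functional entry point (shapes illustrative, not taken from this diff):

import jittor as jt

x = jt.random([2, 4, 8, 8])   # (N, C, H, W)
w = jt.random([4, 3, 3, 3])   # (in_channels, out_channels, kh, kw) -- note the transposed layout
y = jt.nn.conv_transpose(x, w, stride=2, padding=1, output_padding=1)
# oh = (8-1)*2 - 2*1 + 1 + (3-1)*1 + 1 = 16
assert tuple(y.shape) == (2, 3, 16, 16)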


@@ -230,6 +230,20 @@ class TestReindexOp(unittest.TestCase):
         check_fused(len(x.shape))
         npy = conv_transpose_naive(x.data, w.data)
         assert np.allclose(npy, ny), (np.where(np.abs(npy-ny)>1e-4), npy[0,:4,:4,0], ny[0,:4,:4,0])

+    def test_conv_transpose_group(self):
+        N,C,H,W = 3,6,10,10
+        i,o,h,w = 6,2,3,3
+        g = 2
+        x = jt.random([N,C,H,W])
+        ww = jt.random([i,o,h,w])
+        ct = jt.nn.ConvTranspose(i, o*g, (h,w), groups=2, bias=False)
+        assert ct.weight.shape == ww.shape, (ct.weight.shape, ww.shape)
+        ct.weight = ww
+        y = ct(x)
+        y2 = jt.nn.conv_transpose(x, ww, groups=2)
+        np.testing.assert_allclose(y.data, y2.data)
+
     def test_conv_transpose_grad(self):
         N,H,W,C = 1,5,5,2
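The conv_transpose_naive reference used by the earlier test is not shown in this diff. As a sketch of what such a reference computes (my own NCHW variant covering stride only; the real helper's layout and signature may differ), the scatter-add form mirrors the reindex_reduce above:

import numpy as np

def conv_transpose_naive_nchw(x, w, stride=1):
    # hypothetical helper: x is (N, C, H, W), w is (C, O, KH, KW)
    N, C, H, W = x.shape
    _, O, KH, KW = w.shape
    oh = (H-1)*stride + KH   # output-size formula with padding=0, dilation=1
    ow = (W-1)*stride + KW
    y = np.zeros((N, O, oh, ow), x.dtype)
    for i in range(H):
        for j in range(W):
            # every input pixel stamps a weighted copy of the kernel
            y[:, :, i*stride:i*stride+KH, j*stride:j*stride+KW] += \
                np.einsum('nc,cokl->nokl', x[:, :, i, j], w)
    return y

Its output can be compared against jt.nn.conv_transpose(jt.array(x), jt.array(w), stride=stride).data for matching shapes.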