mirror of https://github.com/Jittor/Jittor
Add a functional version of conv_transpose (with groups support)
This commit is contained in:
parent
f20fa98e44
commit
85ba4898cf
|
@ -1091,6 +1091,7 @@ def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|||
b = bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
return y
|
||||
conv = conv2d
|
||||
|
||||
def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
''' Applies a 3D convolution over an input signal composed of several input planes.
|
||||
|
@ -1324,42 +1325,91 @@ class ConvTranspose3d(Module):
|
|||
return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)
|
||||
|
||||
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    ''' Applies a 2D transposed convolution (deconvolution) over an input image.

    :param input: input Var of shape (N, C, H, W)
    :param weight: kernel Var of shape (C, out_channels//groups, h, w)
    :param bias: optional bias Var of shape (out_channels,); None disables the bias
    :param stride: stride of the transposed convolution, int or (h, w) tuple
    :param padding: zero-padding that was applied by the forward conv, int or tuple
    :param output_padding: extra size added to one side of the output, int or tuple;
        must be smaller than max(stride, dilation)
    :param groups: number of blocked channel groups (input and output channels are
        split into `groups` independent groups)
    :param dilation: kernel dilation, int or (h, w) tuple
    :return: output Var of shape (N, out_channels, h_out, w_out)
    '''
    x = input
    N, C, H, W = x.shape
    i, o, h, w = weight.shape
    # weight dim 0 must equal the input channel count (o is out_channels//groups)
    assert C == i, (C, i)
    # Normalize scalar arguments to (h, w) tuples.
    stride = stride if isinstance(stride, tuple) else (stride, stride)
    dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
    padding = padding if isinstance(padding, tuple) else (padding, padding)
    output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding)
    assert output_padding[0] < max(stride[0], dilation[0]) and \
        output_padding[1] < max(stride[1], dilation[1]), \
        "output padding must be smaller than max(stride, dilation)"
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation

    # Output spatial size of the transposed convolution (inverse of the conv
    # output-size formula, plus output_padding).
    h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
    w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w

    if groups == 1:
        out_shape = (N, o, h_out, w_out)
        # dims: N, in-channel, out-channel, H, W, kh, kw
        shape = (N, i, o, H, W, h, w)
        xx = x.broadcast(shape, (2, 5, 6))      # broadcast over o, kh, kw
        ww = weight.broadcast(shape, (0, 3, 4)) # broadcast over N, H, W
        # Scatter-add each input pixel times the kernel into the output window.
        y = (ww*xx).reindex_reduce("add", out_shape, [
            'i0', # N
            'i2', # o
            f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
            f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
        ])
    else:
        G = groups
        oc = o * G   # total output channels
        CpG = C // G # input channels per group
        assert C % G == 0
        out_shape = (N, oc, h_out, w_out)
        # dims: N, group, out-channel-in-group, in-channel-in-group, H, W, kh, kw
        shape = [N, G, oc//G, CpG, H, W, h, w]
        xx = x.reindex(shape, [
            'i0',
            f'i1*{CpG}+i3', # input channel = group*CpG + channel-in-group
            'i4',
            'i5',
        ])
        ww = weight.reindex(shape, [
            f'i1*{CpG}+i3', # weight dim 0 is the input channel (extent C)
            'i2',
            'i6',
            'i7',
        ])
        ww.compile_options = xx.compile_options = {"G": G, "C": C}
        y = (ww*xx).reindex_reduce("add", out_shape, [
            'i0', # Nid
            f'i1*{oc//G}+i2', # output channel = group*(oc//G) + out-channel-in-group
            f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
            f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
        ])
    if isinstance(bias, jt.Var):
        b = bias.broadcast(y.shape, [0, 2, 3])
        y = y + b
    else:
        assert not bias, "Bias should be none or jittor var"
    return y
conv_transpose2d = conv_transpose
|
||||
|
||||
def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
x = input
|
||||
|
|
|
@ -230,6 +230,20 @@ class TestReindexOp(unittest.TestCase):
|
|||
check_fused(len(x.shape))
|
||||
npy = conv_transpose_naive(x.data, w.data)
|
||||
assert np.allclose(npy, ny), (np.where(np.abs(npy-ny)>1e-4), npy[0,:4,:4,0], ny[0,:4,:4,0])
|
||||
|
||||
|
||||
def test_conv_transpose_group(self):
    """Grouped conv_transpose (functional) must match the ConvTranspose module."""
    N, C, H, W = 3, 6, 10, 10
    i, o, h, w = 6, 2, 3, 3
    g = 2
    x = jt.random([N, C, H, W])
    weights = jt.random([i, o, h, w])
    # Module path: ConvTranspose with the same grouped kernel, bias disabled.
    module = jt.nn.ConvTranspose(i, o*g, (h, w), groups=2, bias=False)
    assert module.weight.shape == weights.shape, (module.weight.shape, weights.shape)
    module.weight = weights
    y_module = module(x)
    # Functional path with identical arguments.
    y_func = jt.nn.conv_transpose(x, weights, groups=2)
    np.testing.assert_allclose(y_module.data, y_func.data)
|
||||
|
||||
def test_conv_transpose_grad(self):
|
||||
N,H,W,C = 1,5,5,2
|
||||
|
|
Loading…
Reference in New Issue