add conv transpose function

Dun Liang 2020-12-13 19:03:43 +08:00
parent fa64a66dad
commit c8252ac7fb
3 changed files with 75 additions and 1 deletion

@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.2'
+__version__ = '1.2.2.3'
 from . import lock
 with lock.lock_scope():
     ori_int = int

@@ -707,6 +707,45 @@ class ConvTranspose(Module):
            y = y + b
        return y

def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    x = input
    N,C,H,W = x.shape
    i,o,h,w = weight.shape
    assert C==i
    assert groups==1, "Group conv not supported yet."
    stride = stride if isinstance(stride, tuple) else (stride, stride)
    dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
    # normalize padding and output_padding to (h, w) tuples as well
    padding = padding if isinstance(padding, tuple) else (padding, padding)
    output_padding = output_padding if isinstance(output_padding, tuple) else (output_padding, output_padding)
    assert output_padding[0] < max(stride[0], dilation[0]) and \
        output_padding[1] < max(stride[1], dilation[1]), \
        "output padding must be smaller than max(stride, dilation)"
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
    w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
    out_shape = (N, o, h_out, w_out)
    shape = (N, i, o, H, W, h, w)
    xx = x.broadcast(shape, (2, 5, 6))      # broadcast x over the o,h,w dims
    ww = weight.broadcast(shape, (0, 3, 4)) # broadcast weight over the N,H,W dims
    y = (ww*xx).reindex_reduce("add", out_shape, [
        'i0', # N
        'i2', # o
        f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+KHid
        f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
    ])
    if isinstance(bias, jt.Var):
        b = bias.broadcast(y.shape, [0,2,3])
        y = y + b
    else:
        assert not bias, "bias should be None or a jittor Var"
    return y
conv_transpose2d = conv_transpose

def pad(x, padding, mode='constant', value=0):
    assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
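
A minimal usage sketch of the new function (the shapes and parameter values here are our own illustration, not part of the commit; it assumes jittor and numpy are installed). The printed shape follows the h_out/w_out formula above:

import jittor as jt
import numpy as np

x = jt.array(np.random.rand(4, 5, 10, 10).astype("float32"))  # N,C,H,W
w = jt.array(np.random.rand(5, 6, 3, 3).astype("float32"))    # i,o,h,w
y = jt.nn.conv_transpose2d(x, w, stride=2, padding=1, output_padding=1)
# h_out = (10-1)*2 + 1 - 2*1 + 1 + (3-1)*1 = 20, so y.shape is [4, 6, 20, 20]
print(y.shape)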

@@ -60,6 +60,41 @@ class TestConvTranspose(unittest.TestCase):
        check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)

    def test_function(self):
        def check(data_shape, weights_shape, stride=1, dilation=1):
            N,C,H,W = data_shape
            i,o,h,w = weights_shape
            img = np.random.rand(N,C,H,W).astype("float32")
            weights = np.random.rand(i,o,h,w).astype("float32")
            m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
            m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
            m1.weight.data = weights
            m2.weight.data = torch.Tensor(weights)
            x = jt.array(img)
            # out1 = m1(x)
            out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False)
            mask = jt.random(out1.shape)
            out1 = out1*mask
            tx = torch.Tensor(img)
            tx.requires_grad = True
            out2 = m2(tx) * torch.Tensor(mask.data)
            with jt.log_capture_scope(log_silent=1,
                log_vprefix="var_re=0,conv=0,op.cc=100") as logs:
                assert np.allclose(out1.data, out2.data)
                dx, dw = jt.grad(out1, [x, m1.weight])
                jt.sync([dx, dw])
                out2.sum().backward()
                assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
                assert np.allclose(dx.data, tx.grad.numpy())
            assert len(find_log_with_re(logs, "conv")) == 3

        check((4, 5, 10, 10), (5, 6, 3, 3))
        check((4, 5, 10, 10), (5, 6, 3, 3), 2)
        check((4, 5, 100, 100), (5, 6, 4, 4), 2)
        check((4, 5, 100, 100), (5, 6, 4, 4), 3)
        check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)

if __name__ == "__main__":
    unittest.main()
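
For readers unfamiliar with reindex_reduce, the scatter-add it performs in conv_transpose is equivalent to the following plain-NumPy reference (a sketch for illustration only; conv_transpose_ref is our name and not part of the commit, and integer stride/padding/dilation are assumed for brevity):

import numpy as np

def conv_transpose_ref(x, weight, stride=1, padding=0, output_padding=0, dilation=1):
    N, C, H, W = x.shape
    i, o, h, w = weight.shape
    assert C == i
    h_out = (H-1)*stride + output_padding - 2*padding + 1 + (h-1)*dilation
    w_out = (W-1)*stride + output_padding - 2*padding + 1 + (w-1)*dilation
    y = np.zeros((N, o, h_out, w_out), x.dtype)
    for hi in range(H):
        for wi in range(W):
            for kh in range(h):
                for kw in range(w):
                    # same index expressions as the reindex_reduce above:
                    # output row = i3*stride - padding + i5*dilation
                    oh = hi*stride - padding + kh*dilation
                    ow = wi*stride - padding + kw*dilation
                    if 0 <= oh < h_out and 0 <= ow < w_out:
                        # scatter-add one (N, o) contribution per input pixel
                        # and kernel tap, summing over input channels
                        y[:, :, oh, ow] += x[:, :, hi, wi] @ weight[:, :, kh, kw]
    return y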