mirror of https://github.com/Jittor/Jittor
add conv transpose function
This commit is contained in:
parent
fa64a66dad
commit
c8252ac7fb
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.2'
|
||||
__version__ = '1.2.2.3'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -707,6 +707,45 @@ class ConvTranspose(Module):
|
|||
y = y + b
|
||||
return y
|
||||
|
||||
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
x = input
|
||||
N,C,H,W = x.shape
|
||||
i,o,h,w = weight.shape
|
||||
assert C==i
|
||||
assert groups==1, "Group conv not supported yet."
|
||||
stride = stride if isinstance(stride, tuple) else (stride, stride)
|
||||
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
||||
# added
|
||||
padding = padding if isinstance(padding, tuple) else (padding, padding)
|
||||
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
|
||||
assert output_padding[0] < max(stride[0], dilation[0]) and \
|
||||
output_padding[1] < max(stride[1], dilation[1]), \
|
||||
"output padding must be smaller than max(stride, dilation)"
|
||||
|
||||
stride_h, stride_w = stride
|
||||
padding_h, padding_w = padding
|
||||
dilation_h, dilation_w = dilation
|
||||
|
||||
h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
|
||||
w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
|
||||
out_shape = (N, o, h_out, w_out)
|
||||
shape = (N, i, o, H, W, h, w)
|
||||
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
|
||||
ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W
|
||||
y = (ww*xx).reindex_reduce("add", out_shape, [
|
||||
'i0', # N
|
||||
'i2', # o
|
||||
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
|
||||
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
|
||||
])
|
||||
if isinstance(bias, jt.Var):
|
||||
b = bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
else:
|
||||
assert not bias, "Bias should be none or jittor var"
|
||||
return y
|
||||
|
||||
conv_transpose2d = conv_transpose
|
||||
|
||||
def pad(x,padding, mode='constant', value=0):
|
||||
assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
|
||||
|
|
|
@ -60,6 +60,41 @@ class TestConvTranspose(unittest.TestCase):
|
|||
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
|
||||
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
|
||||
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
|
||||
|
||||
def test_function(self):
|
||||
def check(data_shape, weights_shape, stride=1, dilation=1):
|
||||
N,C,H,W = data_shape
|
||||
i,o,h,w = weights_shape
|
||||
img = np.random.rand(N,C,H,W).astype("float32")
|
||||
weights = np.random.rand(i,o,h,w).astype("float32")
|
||||
m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
|
||||
m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
|
||||
m1.weight.data = weights
|
||||
m2.weight.data = torch.Tensor(weights)
|
||||
x = jt.array(img)
|
||||
# out1 = m1(x)
|
||||
out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False)
|
||||
mask = jt.random(out1.shape)
|
||||
out1 = out1*mask
|
||||
tx = torch.Tensor(img)
|
||||
tx.requires_grad = True
|
||||
out2 = m2(tx) * torch.Tensor(mask.data)
|
||||
with jt.log_capture_scope(log_silent=1,
|
||||
log_vprefix="var_re=0,conv=0,op.cc=100") as logs:
|
||||
assert np.allclose(out1.data, out2.data)
|
||||
dx, dw = jt.grad(out1, [x, m1.weight])
|
||||
jt.sync([dx, dw])
|
||||
out2.sum().backward()
|
||||
assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
|
||||
assert np.allclose(dx.data, tx.grad.numpy())
|
||||
assert len(find_log_with_re(logs, "conv")) == 3
|
||||
check((4, 5, 10, 10), (5, 6, 3, 3))
|
||||
check((4, 5, 10, 10), (5, 6, 3, 3), 2)
|
||||
check((4, 5, 100, 100), (5, 6, 4, 4), 2)
|
||||
check((4, 5, 100, 100), (5, 6, 4, 4), 3)
|
||||
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
|
||||
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
|
||||
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue