polish conv1d weight

This commit is contained in:
Dun Liang 2021-03-10 13:35:06 +08:00
parent a0f2516626
commit 40ce9e605a
3 changed files with 16 additions and 3 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.40'
__version__ = '1.2.2.41'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -618,13 +618,16 @@ class Conv1d(Module):
self.bias = bias
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)
# using list to escape module dfs
self._conv = [Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)]
self.weight = self._conv[0].weight.squeeze(-1)
def execute(self, x):
N,C,D = x.shape
assert C==self.in_channels
self._conv[0].weight = self.weight.unsqueeze(-1)
x = x.unsqueeze(-1)
x = self.conv(x)
x = self._conv[0](x)
y = x.squeeze(-1)
return y

View File

@ -96,6 +96,16 @@ class TestConvTranspose(unittest.TestCase):
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
def test_conv1d(self):
conv1d = jt.nn.Conv1d(10,20,5)
a = jt.rand((3,10,15))
b = conv1d(a)
b.sync()
assert b.shape == [3,20,11]
b = jt.nn.Conv1d(10,20,5, padding=2)(a)
assert b.shape == [3,20,15]
assert list(conv1d.state_dict().keys()) == ['weight']
if __name__ == "__main__":
unittest.main()