polish conv1d bias

This commit is contained in:
Dun Liang 2021-03-10 15:25:55 +08:00
parent 40ce9e605a
commit 8616c01a74
3 changed files with 3 additions and 2 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.41'
__version__ = '1.2.2.42'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -621,6 +621,7 @@ class Conv1d(Module):
# using list to escape module dfs
self._conv = [Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)]
self.weight = self._conv[0].weight.squeeze(-1)
self.bias = self._conv[0].bias
def execute(self, x):
N,C,D = x.shape

View File

@ -105,7 +105,7 @@ class TestConvTranspose(unittest.TestCase):
assert b.shape == [3,20,11]
b = jt.nn.Conv1d(10,20,5, padding=2)(a)
assert b.shape == [3,20,15]
assert list(conv1d.state_dict().keys()) == ['weight']
assert sorted(list(conv1d.state_dict().keys())) == ['bias', 'weight'], conv1d.state_dict().keys()
if __name__ == "__main__":
unittest.main()