mirror of https://github.com/Jittor/Jittor
polish conv1d bias
This commit is contained in:
parent
40ce9e605a
commit
8616c01a74
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.41'
|
||||
__version__ = '1.2.2.42'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -621,6 +621,7 @@ class Conv1d(Module):
|
|||
# using list to escape module dfs
|
||||
self._conv = [Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)]
|
||||
self.weight = self._conv[0].weight.squeeze(-1)
|
||||
self.bias = self._conv[0].bias
|
||||
|
||||
def execute(self, x):
|
||||
N,C,D = x.shape
|
||||
|
|
|
@ -105,7 +105,7 @@ class TestConvTranspose(unittest.TestCase):
|
|||
assert b.shape == [3,20,11]
|
||||
b = jt.nn.Conv1d(10,20,5, padding=2)(a)
|
||||
assert b.shape == [3,20,15]
|
||||
assert list(conv1d.state_dict().keys()) == ['weight']
|
||||
assert sorted(list(conv1d.state_dict().keys())) == ['bias', 'weight'], conv1d.state_dict().keys()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue