feat: bidirectional

This commit is contained in:
lzhengning 2021-04-29 13:32:18 +08:00
parent 5f6c170121
commit 18923fcf3d
2 changed files with 116 additions and 54 deletions

View File

@ -1524,9 +1524,7 @@ class LSTM(jt.Module):
self.batch_first = batch_first
self.bidirectional = bidirectional
self.proj_size = proj_size
assert bidirectional == False, 'bidirectional is not supported now'
assert dropout == 0, 'dropout is not supported now'
self.dropout = dropout
num_directions = 1 + bidirectional
k = math.sqrt(1 / hidden_size)
@ -1542,7 +1540,7 @@ class LSTM(jt.Module):
for l in range(num_layers):
if l == 0:
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * input_size)
build_unit(f'weight_ih_l{l}', 4 * hidden_size, input_size)
else:
if proj_size > 0:
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * proj_size)
@ -1559,12 +1557,43 @@ class LSTM(jt.Module):
build_unit(f'bias_ih_l{l}', 4 * hidden_size)
build_unit(f'bias_hh_l{l}', 4 * hidden_size)
def call_lstm_sequence(self, input, h, c, suffix):
if 'reverse' in suffix:
input = input[::-1]
output = []
for s in range(input.shape[0]):
y = matmul_transpose(input[s], getattr(self, f'weight_ih_{suffix}'))
y = y + matmul_transpose(h, getattr(self, f'weight_hh_{suffix}'))
if self.bias:
y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
i = y[:, :self.hidden_size].sigmoid()
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
o = y[:, 3 * self.hidden_size:].sigmoid()
c = f * c + i * g
h = o * c.tanh()
if self.proj_size > 0:
h = matmul_transpose(h, getattr(self, f'weight_hr_{suffix}'))
output.append(h)
if 'reverse' in suffix:
output = output[::-1]
output = jt.stack(output, dim=0)
return output, h, c
def execute(self, input, hx):
if self.batch_first:
input = input.permute(1, 0, 2)
num_directions = 2 if self.bidirectional else 1
if hx is None:
num_directions = 2 if self.bidirectional else 1
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
h_zeros = jt.zeros(self.num_layers * num_directions,
input.shape[1], real_hidden_size,
@ -1572,37 +1601,34 @@ class LSTM(jt.Module):
c_zeros = jt.zeros(self.num_layers * num_directions,
input.shape[1], self.hidden_size,
dtype=input.dtype, device=input.device)
h, c = h_zeros, c_zeros
h0, c0 = h_zeros, c_zeros
else:
h, c = hx
h0, c0 = hx
output = []
for s in range(input.shape[0]):
for l in range(self.num_layers):
if l == 0:
y = matmul_transpose(input[s], getattr(self, f'weight_ih_l{l}'))
else:
y = matmul_transpose(h[l - 1], getattr(self, f'weight_ih_l{l}'))
hn = []
cn = []
y = y + matmul_transpose(h[l], getattr(self, f'weight_hh_l{l}'))
for l in range(self.num_layers):
output = []
if self.bias:
y = y + getattr(self, f'bias_ih_l{l}') + getattr(self, f'bias_hh_l{l}')
h = h0[l * num_directions: (l + 1) * num_directions]
c = c0[l * num_directions: (l + 1) * num_directions]
i = y[:, :self.hidden_size].sigmoid()
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
o = y[:, 3 * self.hidden_size:].sigmoid()
output, _h, _c = self.call_lstm_sequence(input, h[0], c[0], f'l{l}')
hn.append(_h)
cn.append(_c)
c[l] = f * c[l] + i * g
rh = o * c[l].tanh()
if self.bidirectional:
output_b, _h, _c = self.call_lstm_sequence(input, h[1], c[1], f'l{l}_reverse')
output = jt.concat([output, output_b], dim=-1)
hn.append(_h)
cn.append(_c)
if self.proj_size > 0:
h[l] = matmul_transpose(rh, getattr(self, f'weight_hr_l{l}'))
else:
h[l] = rh
if self.dropout > 0:
input = dropout(output, p=self.dropout)
else:
input = output
output.append(h[-1])
output = jt.stack(output, dim=0)
return output, (h, c)
hn = jt.stack(hn, dim=0)
cn = jt.stack(cn, dim=0)
return output, (hn, cn)

View File

@ -1,8 +1,8 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Zheng-Ning Liu <lzhengning@gmail.com>
#
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Zheng-Ning Liu <lzhengning@gmail.com>
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -27,9 +27,9 @@ except:
def check_equal(t_rnn, j_rnn, input, h0, c0):
j_rnn.load_state_dict(t_rnn.state_dict())
t_output, (th, tc) = t_rnn(torch.from_numpy(input),
t_output, (th, tc) = t_rnn(torch.from_numpy(input),
(torch.from_numpy(h0), torch.from_numpy(c0)))
j_output, (jh, jc) = j_rnn(jt.float32(input),
j_output, (jh, jc) = j_rnn(jt.float32(input),
(jt.float32(h0), jt.float32(c0)))
assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
@ -43,7 +43,7 @@ class TestLSTM(unittest.TestCase):
np_h0 = torch.randn(3, 20).numpy()
np_c0 = torch.randn(3, 20).numpy()
t_rnn = tnn.LSTMCell(10, 20)
t_rnn = tnn.LSTMCell(10, 20)
input = torch.randn(2, 3, 10)
h0 = torch.from_numpy(np_h0)
c0 = torch.from_numpy(np_c0)
@ -70,38 +70,74 @@ class TestLSTM(unittest.TestCase):
assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
def test_lstm(self):
h0 = np.random.rand(1, 2, 20).astype(np.float32)
c0 = np.random.rand(1, 2, 20).astype(np.float32)
input = np.random.rand(5, 2, 10).astype(np.float32)
h0 = np.random.rand(1, 24, 200).astype(np.float32)
c0 = np.random.rand(1, 24, 200).astype(np.float32)
input = np.random.rand(32, 24, 100).astype(np.float32)
t_rnn = tnn.LSTM(10, 20)
j_rnn = nn.LSTM(10, 20)
t_rnn = tnn.LSTM(100, 200)
j_rnn = nn.LSTM(100, 200)
check_equal(t_rnn, j_rnn, input, h0, c0)
proj_size = 13
h0 = np.random.rand(1, 2, proj_size).astype(np.float32)
c0 = np.random.rand(1, 2, 20).astype(np.float32)
input = np.random.rand(5, 2, 10).astype(np.float32)
t_rnn = tnn.LSTM(10, 20, proj_size=proj_size)
j_rnn = nn.LSTM(10, 20, proj_size=proj_size)
h0 = np.random.rand(1, 24, proj_size).astype(np.float32)
c0 = np.random.rand(1, 24, 200).astype(np.float32)
input = np.random.rand(32, 24, 100).astype(np.float32)
t_rnn = tnn.LSTM(100, 200, proj_size=proj_size)
j_rnn = nn.LSTM(100, 200, proj_size=proj_size)
check_equal(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 4, 20).astype(np.float32)
c0 = np.random.rand(2, 4, 20).astype(np.float32)
input = np.random.rand(5, 4, 10).astype(np.float32)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
c0 = np.random.rand(4, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.LSTM(10, 20, num_layers=2)
j_rnn = nn.LSTM(10, 20, num_layers=2)
t_rnn = tnn.LSTM(100, 200, num_layers=4)
j_rnn = nn.LSTM(100, 200, num_layers=4)
check_equal(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
c0 = np.random.rand(2, 4, 20).astype(np.float32)
input = np.random.rand(5, 4, 10).astype(np.float32)
t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
check_equal(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 1, 200).astype(np.float32)
c0 = np.random.rand(2, 1, 200).astype(np.float32)
input = np.random.rand(5, 1, 100).astype(np.float32)
t_rnn = tnn.LSTM(100, 200, bidirectional=True)
j_rnn = nn.LSTM(100, 200, bidirectional=True)
check_equal(t_rnn, j_rnn, input, h0, c0)
proj_size = 13
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
c0 = np.random.rand(2, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
check_equal(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
c0 = np.random.rand(4, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
check_equal(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 4, 200).astype(np.float32)
c0 = np.random.rand(2, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False)
j_rnn = nn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False)
t_rnn.eval()
j_rnn.eval()
check_equal(t_rnn, j_rnn, input, h0, c0)
if __name__ == "__main__":
unittest.main()