mirror of https://github.com/Jittor/Jittor
feat: bidirectional
parent 5f6c170121
commit 18923fcf3d
Diff of the LSTM implementation (the mirror does not show the file path):

```diff
@@ -1524,9 +1524,7 @@ class LSTM(jt.Module):
         self.batch_first = batch_first
         self.bidirectional = bidirectional
         self.proj_size = proj_size
-        assert bidirectional == False, 'bidirectional is not supported now'
-        assert dropout == 0, 'dropout is not supported now'
         self.dropout = dropout
 
         num_directions = 1 + bidirectional
         k = math.sqrt(1 / hidden_size)
```
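Dropping the two asserts is what actually turns the feature on, and `num_directions = 1 + bidirectional` now doubles the per-layer parameter count. A minimal sketch of the resulting registration pattern, assuming a `build_unit`-style helper and the `_reverse` suffix the diff uses later (pure NumPy stand-in, not the module's real initializer):

```python
import math
import numpy as np

input_size, hidden_size, bidirectional = 10, 20, True
num_directions = 1 + bidirectional      # 2 when bidirectional, else 1
k = math.sqrt(1 / hidden_size)          # uniform init bound, as in the diff

params = {}
def build_unit(name, *shape):
    # stand-in for the module's build_unit: one uniform(-k, k) parameter
    params[name] = np.random.uniform(-k, k, shape).astype(np.float32)

# each direction gets its own weights; the reverse direction is suffixed '_reverse'
suffixes = [''] if num_directions == 1 else ['', '_reverse']
for suffix in suffixes:
    build_unit(f'weight_ih_l0{suffix}', 4 * hidden_size, input_size)
    build_unit(f'weight_hh_l0{suffix}', 4 * hidden_size, hidden_size)

print(sorted(params))
# ['weight_hh_l0', 'weight_hh_l0_reverse', 'weight_ih_l0', 'weight_ih_l0_reverse']
```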
```diff
@@ -1542,7 +1540,7 @@ class LSTM(jt.Module):
 
         for l in range(num_layers):
             if l == 0:
-                build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * input_size)
+                build_unit(f'weight_ih_l{l}', 4 * hidden_size, input_size)
             else:
                 if proj_size > 0:
                     build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * proj_size)
```
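The `l == 0` change only matters once there are two directions: layer 0 reads the raw input (width `input_size`), while deeper layers read the concatenation of both directions' outputs (width `num_directions * hidden_size`, or `num_directions * proj_size` when projections are on). A shape sketch under those assumptions, matching the PyTorch convention the tests compare against:

```python
# expected weight_ih shape per layer of a bidirectional LSTM
input_size, hidden_size, proj_size, num_directions = 10, 20, 0, 2

def weight_ih_shape(layer):
    if layer == 0:
        return (4 * hidden_size, input_size)           # raw input, no direction factor
    width = proj_size if proj_size > 0 else hidden_size
    return (4 * hidden_size, num_directions * width)   # both directions concatenated

print(weight_ih_shape(0))   # (80, 10)
print(weight_ih_shape(1))   # (80, 40)
```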
```diff
@@ -1559,12 +1557,43 @@ class LSTM(jt.Module):
                 build_unit(f'bias_ih_l{l}', 4 * hidden_size)
                 build_unit(f'bias_hh_l{l}', 4 * hidden_size)
 
+    def call_lstm_sequence(self, input, h, c, suffix):
+        if 'reverse' in suffix:
+            input = input[::-1]
+
+        output = []
+        for s in range(input.shape[0]):
+            y = matmul_transpose(input[s], getattr(self, f'weight_ih_{suffix}'))
+            y = y + matmul_transpose(h, getattr(self, f'weight_hh_{suffix}'))
+
+            if self.bias:
+                y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
+
+            i = y[:, :self.hidden_size].sigmoid()
+            f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
+            g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
+            o = y[:, 3 * self.hidden_size:].sigmoid()
+            c = f * c + i * g
+            h = o * c.tanh()
+
+            if self.proj_size > 0:
+                h = matmul_transpose(h, getattr(self, f'weight_hr_{suffix}'))
+
+            output.append(h)
+
+        if 'reverse' in suffix:
+            output = output[::-1]
+        output = jt.stack(output, dim=0)
+
+        return output, h, c
+
     def execute(self, input, hx):
         if self.batch_first:
             input = input.permute(1, 0, 2)
 
+        num_directions = 2 if self.bidirectional else 1
+
         if hx is None:
-            num_directions = 2 if self.bidirectional else 1
             real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
             h_zeros = jt.zeros(self.num_layers * num_directions,
                                input.shape[1], real_hidden_size,
```
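The new `call_lstm_sequence` runs one direction of one layer over the whole sequence: a `_reverse`-suffixed call flips the input first and flips the outputs back, so both directions return time-aligned tensors, and `y` stacks the four gate pre-activations, split in `i, f, g, o` order. A self-contained NumPy sketch of the same loop (hypothetical shapes, activations written out):

```python
import numpy as np

def lstm_sequence(x, h, c, W_ih, W_hh, b, reverse=False):
    """One direction of one LSTM layer, gate order i, f, g, o as in the diff."""
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    H = h.shape[-1]
    if reverse:
        x = x[::-1]                       # walk the sequence backwards
    out = []
    for s in range(x.shape[0]):
        y = x[s] @ W_ih.T + h @ W_hh.T + b
        i = sigmoid(y[:, :H])             # input gate
        f = sigmoid(y[:, H:2 * H])        # forget gate
        g = np.tanh(y[:, 2 * H:3 * H])    # candidate cell state
        o = sigmoid(y[:, 3 * H:])         # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        out.append(h)
    if reverse:
        out = out[::-1]                   # restore forward time order
    return np.stack(out), h, c

T, B, I, H = 5, 2, 10, 20
x = np.random.randn(T, B, I).astype(np.float32)
h0 = np.zeros((B, H), np.float32)
c0 = np.zeros((B, H), np.float32)
W_ih = np.random.randn(4 * H, I).astype(np.float32)
W_hh = np.random.randn(4 * H, H).astype(np.float32)
y, h, c = lstm_sequence(x, h0, c0, W_ih, W_hh, np.zeros(4 * H, np.float32))
print(y.shape, h.shape, c.shape)          # (5, 2, 20) (2, 20) (2, 20)
```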
```diff
@@ -1572,37 +1601,34 @@ class LSTM(jt.Module):
             c_zeros = jt.zeros(self.num_layers * num_directions,
                                input.shape[1], self.hidden_size,
                                dtype=input.dtype, device=input.device)
-            h, c = h_zeros, c_zeros
+            h0, c0 = h_zeros, c_zeros
         else:
-            h, c = hx
+            h0, c0 = hx
 
-        output = []
-        for s in range(input.shape[0]):
-            for l in range(self.num_layers):
-                if l == 0:
-                    y = matmul_transpose(input[s], getattr(self, f'weight_ih_l{l}'))
-                else:
-                    y = matmul_transpose(h[l - 1], getattr(self, f'weight_ih_l{l}'))
-
-                y = y + matmul_transpose(h[l], getattr(self, f'weight_hh_l{l}'))
-
-                if self.bias:
-                    y = y + getattr(self, f'bias_ih_l{l}') + getattr(self, f'bias_hh_l{l}')
-
-                i = y[:, :self.hidden_size].sigmoid()
-                f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
-                g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
-                o = y[:, 3 * self.hidden_size:].sigmoid()
-
-                c[l] = f * c[l] + i * g
-                rh = o * c[l].tanh()
-
-                if self.proj_size > 0:
-                    h[l] = matmul_transpose(rh, getattr(self, f'weight_hr_l{l}'))
-                else:
-                    h[l] = rh
-
-            output.append(h[-1])
-
-        output = jt.stack(output, dim=0)
-        return output, (h, c)
+        hn = []
+        cn = []
+
+        for l in range(self.num_layers):
+            output = []
+
+            h = h0[l * num_directions: (l + 1) * num_directions]
+            c = c0[l * num_directions: (l + 1) * num_directions]
+
+            output, _h, _c = self.call_lstm_sequence(input, h[0], c[0], f'l{l}')
+            hn.append(_h)
+            cn.append(_c)
+
+            if self.bidirectional:
+                output_b, _h, _c = self.call_lstm_sequence(input, h[1], c[1], f'l{l}_reverse')
+                output = jt.concat([output, output_b], dim=-1)
+                hn.append(_h)
+                cn.append(_c)
+
+            if self.dropout > 0:
+                input = dropout(output, p=self.dropout)
+            else:
+                input = output
+
+        hn = jt.stack(hn, dim=0)
+        cn = jt.stack(cn, dim=0)
+        return output, (hn, cn)
```
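End to end, the rewritten `execute` slices each layer's states out of `h0`/`c0`, runs the forward pass (and the reverse pass when enabled), concatenates the two direction outputs on the feature axis, and feeds the result to the next layer through the optional dropout. A usage sketch, assuming a Jittor build that includes this commit:

```python
import jittor as jt
from jittor import nn

rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True)
x = jt.random((5, 4, 100))        # (seq_len, batch, input_size)

output, (hn, cn) = rnn(x, None)   # hx=None -> zero-initialized states
print(output.shape)               # (5, 4, 400): last dim is num_directions * hidden_size
print(hn.shape, cn.shape)         # (4, 4, 200) each: first dim is num_layers * num_directions
```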
Diff of the LSTM unit tests (the mirror does not show the file path):
```diff
@@ -1,8 +1,8 @@
 # ***************************************************************
 # Copyright (c) 2021 Jittor. All Rights Reserved.
 # Maintainers:
 #     Zheng-Ning Liu <lzhengning@gmail.com>
 #
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
```
```diff
@@ -27,9 +27,9 @@ except:
 def check_equal(t_rnn, j_rnn, input, h0, c0):
     j_rnn.load_state_dict(t_rnn.state_dict())
 
     t_output, (th, tc) = t_rnn(torch.from_numpy(input),
                                (torch.from_numpy(h0), torch.from_numpy(c0)))
     j_output, (jh, jc) = j_rnn(jt.float32(input),
                                (jt.float32(h0), jt.float32(c0)))
 
     assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
```
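`check_equal` copies the PyTorch parameters into the Jittor module via `load_state_dict`, so both networks compute from identical weights; the comparison then only has to absorb floating-point accumulation error (`rtol=1e-03, atol=1e-06`). A sketch of the same comparison factored into a reusable helper (hypothetical name, assuming both frameworks are importable):

```python
import numpy as np

def assert_close(t_result, j_result, rtol=1e-03, atol=1e-06):
    # hypothetical helper: compare one PyTorch tensor with one Jittor Var
    t = t_result.detach().numpy()
    j = j_result.data                 # .data exposes a Jittor Var as a numpy array
    assert np.allclose(t, j, rtol=rtol, atol=atol), \
        f"max abs diff {np.abs(t - j).max():.3e}"
```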
```diff
@@ -43,7 +43,7 @@ class TestLSTM(unittest.TestCase):
         np_h0 = torch.randn(3, 20).numpy()
         np_c0 = torch.randn(3, 20).numpy()
 
         t_rnn = tnn.LSTMCell(10, 20)
         input = torch.randn(2, 3, 10)
         h0 = torch.from_numpy(np_h0)
         c0 = torch.from_numpy(np_c0)
```
```diff
@@ -70,38 +70,74 @@ class TestLSTM(unittest.TestCase):
         assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
 
     def test_lstm(self):
-        h0 = np.random.rand(1, 2, 20).astype(np.float32)
-        c0 = np.random.rand(1, 2, 20).astype(np.float32)
-        input = np.random.rand(5, 2, 10).astype(np.float32)
+        h0 = np.random.rand(1, 24, 200).astype(np.float32)
+        c0 = np.random.rand(1, 24, 200).astype(np.float32)
+        input = np.random.rand(32, 24, 100).astype(np.float32)
 
-        t_rnn = tnn.LSTM(10, 20)
-        j_rnn = nn.LSTM(10, 20)
+        t_rnn = tnn.LSTM(100, 200)
+        j_rnn = nn.LSTM(100, 200)
         check_equal(t_rnn, j_rnn, input, h0, c0)
 
         proj_size = 13
-        h0 = np.random.rand(1, 2, proj_size).astype(np.float32)
-        c0 = np.random.rand(1, 2, 20).astype(np.float32)
-        input = np.random.rand(5, 2, 10).astype(np.float32)
-        t_rnn = tnn.LSTM(10, 20, proj_size=proj_size)
-        j_rnn = nn.LSTM(10, 20, proj_size=proj_size)
+        h0 = np.random.rand(1, 24, proj_size).astype(np.float32)
+        c0 = np.random.rand(1, 24, 200).astype(np.float32)
+        input = np.random.rand(32, 24, 100).astype(np.float32)
+        t_rnn = tnn.LSTM(100, 200, proj_size=proj_size)
+        j_rnn = nn.LSTM(100, 200, proj_size=proj_size)
         check_equal(t_rnn, j_rnn, input, h0, c0)
 
-        h0 = np.random.rand(2, 4, 20).astype(np.float32)
-        c0 = np.random.rand(2, 4, 20).astype(np.float32)
-        input = np.random.rand(5, 4, 10).astype(np.float32)
+        h0 = np.random.rand(4, 4, 200).astype(np.float32)
+        c0 = np.random.rand(4, 4, 200).astype(np.float32)
+        input = np.random.rand(5, 4, 100).astype(np.float32)
 
-        t_rnn = tnn.LSTM(10, 20, num_layers=2)
-        j_rnn = nn.LSTM(10, 20, num_layers=2)
+        t_rnn = tnn.LSTM(100, 200, num_layers=4)
+        j_rnn = nn.LSTM(100, 200, num_layers=4)
         check_equal(t_rnn, j_rnn, input, h0, c0)
 
         h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
         c0 = np.random.rand(2, 4, 20).astype(np.float32)
         input = np.random.rand(5, 4, 10).astype(np.float32)
 
         t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
         j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
         check_equal(t_rnn, j_rnn, input, h0, c0)
 
+        h0 = np.random.rand(2, 1, 200).astype(np.float32)
+        c0 = np.random.rand(2, 1, 200).astype(np.float32)
+        input = np.random.rand(5, 1, 100).astype(np.float32)
+
+        t_rnn = tnn.LSTM(100, 200, bidirectional=True)
+        j_rnn = nn.LSTM(100, 200, bidirectional=True)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+        proj_size = 13
+        h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
+        c0 = np.random.rand(2, 4, 200).astype(np.float32)
+        input = np.random.rand(5, 4, 100).astype(np.float32)
+
+        t_rnn = tnn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
+        j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+        h0 = np.random.rand(4, 4, 200).astype(np.float32)
+        c0 = np.random.rand(4, 4, 200).astype(np.float32)
+        input = np.random.rand(5, 4, 100).astype(np.float32)
+
+        t_rnn = tnn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
+        j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+        h0 = np.random.rand(2, 4, 200).astype(np.float32)
+        c0 = np.random.rand(2, 4, 200).astype(np.float32)
+        input = np.random.rand(5, 4, 100).astype(np.float32)
+
+        t_rnn = tnn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False)
+        j_rnn = nn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False)
+        t_rnn.eval()
+        j_rnn.eval()
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
 
 if __name__ == "__main__":
     unittest.main()
```
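The added cases cover plain bidirectional, bidirectional with `proj_size`, stacked bidirectional without bias, and stacked dropout; the dropout case calls `eval()` on both networks first, so dropout is a no-op and the comparison stays deterministic. A sketch of running just this test, assuming the file is saved as `test_lstm.py` with both frameworks installed:

```python
import unittest

# hypothetical module name; adjust to wherever the file lives
suite = unittest.defaultTestLoader.loadTestsFromName('test_lstm.TestLSTM.test_lstm')
unittest.TextTestRunner(verbosity=2).run(suite)
```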