mirror of https://github.com/Jittor/Jittor
feat: add lstm (#202)
This commit is contained in:
parent
4a14967cc5
commit
97e0652650
|
@ -6,7 +6,7 @@
|
||||||
# Wenyang Zhou <576825820@qq.com>
|
# Wenyang Zhou <576825820@qq.com>
|
||||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||||
# Dun Liang <randonlang@gmail.com>.
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
#
|
# Zheng-Ning Liu <lzhengning@gmail.com>
|
||||||
#
|
#
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
@ -1417,3 +1417,192 @@ def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
|
||||||
return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]
|
return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]
|
||||||
|
|
||||||
ModuleList = Sequential
|
ModuleList = Sequential
|
||||||
|
|
||||||
|
|
||||||
|
class LSTMCell(jt.Module):
|
||||||
|
def __init__(self, input_size, hidden_size, bias=True):
|
||||||
|
''' A long short-term memory (LSTM) cell.
|
||||||
|
|
||||||
|
:param input_size: The number of expected features in the input
|
||||||
|
:type input_size: int
|
||||||
|
|
||||||
|
:param hidden_size: The number of features in the hidden state
|
||||||
|
:type hidden_size: int
|
||||||
|
|
||||||
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
||||||
|
:type bias: bool, optional
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
|
||||||
|
>>> input = jt.randn(2, 3, 10) # (time_steps, batch, input_size)
|
||||||
|
>>> hx = jt.randn(3, 20) # (batch, hidden_size)
|
||||||
|
>>> cx = jt.randn(3, 20)
|
||||||
|
>>> output = []
|
||||||
|
>>> for i in range(input.shape[0]):
|
||||||
|
hx, cx = rnn(input[i], (hx, cx))
|
||||||
|
output.append(hx)
|
||||||
|
>>> output = jt.stack(output, dim=0)
|
||||||
|
'''
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
k = math.sqrt(1 / hidden_size)
|
||||||
|
self.weight_ih = init.uniform((4 * hidden_size, input_size), 'float32', -k, k)
|
||||||
|
self.weight_hh = init.uniform((4 * hidden_size, hidden_size), 'float32', -k, k)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias_ih = init.uniform((4 * hidden_size,), 'float32', -k, k)
|
||||||
|
self.bias_hh = init.uniform((4 * hidden_size,), 'float32', -k, k)
|
||||||
|
|
||||||
|
def execute(self, input, hx = None):
|
||||||
|
if hx is None:
|
||||||
|
zeros = jt.zeros(input.shape[0], self.hidden_size, dtype=input.dtype)
|
||||||
|
h, c = zeros, zeros
|
||||||
|
else:
|
||||||
|
h, c = hx
|
||||||
|
|
||||||
|
y = matmul_transpose(input, self.weight_ih) + matmul_transpose(h, self.weight_hh)
|
||||||
|
|
||||||
|
if self.bias:
|
||||||
|
y = y + self.bias_ih + self.bias_hh
|
||||||
|
|
||||||
|
i = y[:, :self.hidden_size].sigmoid()
|
||||||
|
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
|
||||||
|
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
|
||||||
|
o = y[:, 3 * self.hidden_size:].sigmoid()
|
||||||
|
|
||||||
|
c = f * c + i * g
|
||||||
|
h = o * c.tanh()
|
||||||
|
|
||||||
|
return h, c
|
||||||
|
|
||||||
|
|
||||||
|
class LSTM(jt.Module):
|
||||||
|
def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
|
||||||
|
batch_first=False, dropout=0, bidirectional=False, proj_size=0):
|
||||||
|
''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
|
||||||
|
|
||||||
|
:param input_size: The number of expected features in the input.
|
||||||
|
:type input_size: int
|
||||||
|
|
||||||
|
:param hidden_size: The number of features in the hidden state.
|
||||||
|
:type hidden_size: int
|
||||||
|
|
||||||
|
:param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1
|
||||||
|
:type num_layers: int, optinal
|
||||||
|
|
||||||
|
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
|
||||||
|
:type bias: bool, optional
|
||||||
|
|
||||||
|
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
|
||||||
|
:type bias: bool, optional
|
||||||
|
|
||||||
|
:param dropout: [Not implemented] If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
|
||||||
|
:type dropout: float, optional
|
||||||
|
|
||||||
|
:param bidirectional: [Not implemented] If True, becomes a bidirectional LSTM. Default: False
|
||||||
|
:type bidirectional: bool, optional
|
||||||
|
|
||||||
|
:param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0
|
||||||
|
:type proj_size: int, optional
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> rnn = nn.LSTM(10, 20, 2)
|
||||||
|
>>> input = jt.randn(5, 3, 10)
|
||||||
|
>>> h0 = jt.randn(2, 3, 20)
|
||||||
|
>>> c0 = jt.randn(2, 3, 20)
|
||||||
|
>>> output, (hn, cn) = rnn(input, (h0, c0))
|
||||||
|
'''
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.bias = bias
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.batch_first = batch_first
|
||||||
|
self.bidirectional = bidirectional
|
||||||
|
self.proj_size = proj_size
|
||||||
|
|
||||||
|
assert bidirectional == False, 'bidirectional is not supported now'
|
||||||
|
assert dropout == 0, 'dropout is not supported now'
|
||||||
|
|
||||||
|
num_directions = 1 + bidirectional
|
||||||
|
k = math.sqrt(1 / hidden_size)
|
||||||
|
|
||||||
|
def build_unit(name, in_channels, out_channels=None):
|
||||||
|
if out_channels is not None:
|
||||||
|
shape = (in_channels, out_channels)
|
||||||
|
else:
|
||||||
|
shape = (in_channels,)
|
||||||
|
setattr(self, name, init.uniform(shape, 'float32', -k, k))
|
||||||
|
if self.bidirectional:
|
||||||
|
setattr(self, name + '_reverse', init.uniform(shape, 'float32', -k, k))
|
||||||
|
|
||||||
|
for l in range(num_layers):
|
||||||
|
if l == 0:
|
||||||
|
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * input_size)
|
||||||
|
else:
|
||||||
|
if proj_size > 0:
|
||||||
|
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * proj_size)
|
||||||
|
else:
|
||||||
|
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * hidden_size)
|
||||||
|
|
||||||
|
if proj_size > 0:
|
||||||
|
build_unit(f'weight_hh_l{l}', 4 * hidden_size, proj_size)
|
||||||
|
build_unit(f'weight_hr_l{l}', proj_size, hidden_size)
|
||||||
|
else:
|
||||||
|
build_unit(f'weight_hh_l{l}', 4 * hidden_size, hidden_size)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
build_unit(f'bias_ih_l{l}', 4 * hidden_size)
|
||||||
|
build_unit(f'bias_hh_l{l}', 4 * hidden_size)
|
||||||
|
|
||||||
|
def execute(self, input, hx):
|
||||||
|
if self.batch_first:
|
||||||
|
input = input.permute(1, 0, 2)
|
||||||
|
|
||||||
|
if hx is None:
|
||||||
|
num_directions = 2 if self.bidirectional else 1
|
||||||
|
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
|
||||||
|
h_zeros = jt.zeros(self.num_layers * num_directions,
|
||||||
|
input.shape[1], real_hidden_size,
|
||||||
|
dtype=input.dtype, device=input.device)
|
||||||
|
c_zeros = jt.zeros(self.num_layers * num_directions,
|
||||||
|
input.shape[1], self.hidden_size,
|
||||||
|
dtype=input.dtype, device=input.device)
|
||||||
|
h, c = h_zeros, c_zeros
|
||||||
|
else:
|
||||||
|
h, c = hx
|
||||||
|
|
||||||
|
output = []
|
||||||
|
for s in range(input.shape[0]):
|
||||||
|
for l in range(self.num_layers):
|
||||||
|
if l == 0:
|
||||||
|
y = matmul_transpose(input[s], getattr(self, f'weight_ih_l{l}'))
|
||||||
|
else:
|
||||||
|
y = matmul_transpose(h[l - 1], getattr(self, f'weight_ih_l{l}'))
|
||||||
|
|
||||||
|
y = y + matmul_transpose(h[l], getattr(self, f'weight_hh_l{l}'))
|
||||||
|
|
||||||
|
if self.bias:
|
||||||
|
y = y + getattr(self, f'bias_ih_l{l}') + getattr(self, f'bias_hh_l{l}')
|
||||||
|
|
||||||
|
i = y[:, :self.hidden_size].sigmoid()
|
||||||
|
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
|
||||||
|
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
|
||||||
|
o = y[:, 3 * self.hidden_size:].sigmoid()
|
||||||
|
|
||||||
|
c[l] = f * c[l] + i * g
|
||||||
|
rh = o * c[l].tanh()
|
||||||
|
|
||||||
|
if self.proj_size > 0:
|
||||||
|
h[l] = matmul_transpose(rh, getattr(self, f'weight_hr_l{l}'))
|
||||||
|
else:
|
||||||
|
h[l] = rh
|
||||||
|
|
||||||
|
output.append(h[-1])
|
||||||
|
|
||||||
|
output = jt.stack(output, dim=0)
|
||||||
|
return output, (h, c)
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
# Maintainers:
|
||||||
|
# Zheng-Ning Liu <lzhengning@gmail.com>
|
||||||
|
#
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
|
import unittest
|
||||||
|
import jittor as jt
|
||||||
|
import jittor.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
skip_this_test = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
jt.dirty_fix_pytorch_runtime_error()
|
||||||
|
import torch
|
||||||
|
import torch.nn as tnn
|
||||||
|
except:
|
||||||
|
torch = None
|
||||||
|
tnn = None
|
||||||
|
skip_this_test = True
|
||||||
|
|
||||||
|
|
||||||
|
def check_equal(t_rnn, j_rnn, input, h0, c0):
|
||||||
|
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||||
|
|
||||||
|
t_output, (th, tc) = t_rnn(torch.from_numpy(input),
|
||||||
|
(torch.from_numpy(h0), torch.from_numpy(c0)))
|
||||||
|
j_output, (jh, jc) = j_rnn(jt.float32(input),
|
||||||
|
(jt.float32(h0), jt.float32(c0)))
|
||||||
|
|
||||||
|
assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
|
||||||
|
assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||||
|
assert np.allclose(tc.detach().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||||
|
class TestLSTM(unittest.TestCase):
|
||||||
|
def test_lstm_cell(self):
|
||||||
|
np_h0 = torch.randn(3, 20).numpy()
|
||||||
|
np_c0 = torch.randn(3, 20).numpy()
|
||||||
|
|
||||||
|
t_rnn = tnn.LSTMCell(10, 20)
|
||||||
|
input = torch.randn(2, 3, 10)
|
||||||
|
h0 = torch.from_numpy(np_h0)
|
||||||
|
c0 = torch.from_numpy(np_c0)
|
||||||
|
t_output = []
|
||||||
|
for i in range(input.size()[0]):
|
||||||
|
h0, c0 = t_rnn(input[i], (h0, c0))
|
||||||
|
t_output.append(h0)
|
||||||
|
t_output = torch.stack(t_output, dim=0)
|
||||||
|
|
||||||
|
j_rnn = nn.LSTMCell(10, 20)
|
||||||
|
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||||
|
|
||||||
|
input = jt.float32(input.numpy())
|
||||||
|
h0 = jt.float32(np_h0)
|
||||||
|
c0 = jt.float32(np_c0)
|
||||||
|
j_output = []
|
||||||
|
for i in range(input.size()[0]):
|
||||||
|
h0, c0 = j_rnn(input[i], (h0, c0))
|
||||||
|
j_output.append(h0)
|
||||||
|
j_output = jt.stack(j_output, dim=0)
|
||||||
|
|
||||||
|
t_output = t_output.detach().numpy()
|
||||||
|
j_output = j_output.data
|
||||||
|
assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
|
||||||
|
|
||||||
|
def test_lstm(self):
|
||||||
|
h0 = np.random.rand(1, 2, 20).astype(np.float32)
|
||||||
|
c0 = np.random.rand(1, 2, 20).astype(np.float32)
|
||||||
|
input = np.random.rand(5, 2, 10).astype(np.float32)
|
||||||
|
|
||||||
|
t_rnn = tnn.LSTM(10, 20)
|
||||||
|
j_rnn = nn.LSTM(10, 20)
|
||||||
|
check_equal(t_rnn, j_rnn, input, h0, c0)
|
||||||
|
|
||||||
|
proj_size = 13
|
||||||
|
h0 = np.random.rand(1, 2, proj_size).astype(np.float32)
|
||||||
|
c0 = np.random.rand(1, 2, 20).astype(np.float32)
|
||||||
|
input = np.random.rand(5, 2, 10).astype(np.float32)
|
||||||
|
t_rnn = tnn.LSTM(10, 20, proj_size=proj_size)
|
||||||
|
j_rnn = nn.LSTM(10, 20, proj_size=proj_size)
|
||||||
|
check_equal(t_rnn, j_rnn, input, h0, c0)
|
||||||
|
|
||||||
|
h0 = np.random.rand(2, 4, 20).astype(np.float32)
|
||||||
|
c0 = np.random.rand(2, 4, 20).astype(np.float32)
|
||||||
|
input = np.random.rand(5, 4, 10).astype(np.float32)
|
||||||
|
|
||||||
|
t_rnn = tnn.LSTM(10, 20, num_layers=2)
|
||||||
|
j_rnn = nn.LSTM(10, 20, num_layers=2)
|
||||||
|
check_equal(t_rnn, j_rnn, input, h0, c0)
|
||||||
|
|
||||||
|
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
|
||||||
|
c0 = np.random.rand(2, 4, 20).astype(np.float32)
|
||||||
|
input = np.random.rand(5, 4, 10).astype(np.float32)
|
||||||
|
|
||||||
|
t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
|
||||||
|
j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
|
||||||
|
check_equal(t_rnn, j_rnn, input, h0, c0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue