diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index da5c529c..d3fd9d52 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -6,7 +6,7 @@
 #     Wenyang Zhou <576825820@qq.com>
 #     Meng-Hao Guo
 #     Dun Liang.
-#
+#     Zheng-Ning Liu
 #
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
@@ -1417,3 +1417,192 @@ def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
     return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]
 
 ModuleList = Sequential
+
+
+class LSTMCell(jt.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        ''' A long short-term memory (LSTM) cell.
+
+        :param input_size: The number of expected features in the input
+        :type input_size: int
+
+        :param hidden_size: The number of features in the hidden state
+        :type hidden_size: int
+
+        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
+        :type bias: bool, optional
+
+        Example:
+
+        >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
+        >>> input = jt.randn(2, 3, 10) # (time_steps, batch, input_size)
+        >>> hx = jt.randn(3, 20) # (batch, hidden_size)
+        >>> cx = jt.randn(3, 20)
+        >>> output = []
+        >>> for i in range(input.shape[0]):
+        ...     hx, cx = rnn(input[i], (hx, cx))
+        ...     output.append(hx)
+        >>> output = jt.stack(output, dim=0)
+        '''
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        k = math.sqrt(1 / hidden_size)
+        self.weight_ih = init.uniform((4 * hidden_size, input_size), 'float32', -k, k)
+        self.weight_hh = init.uniform((4 * hidden_size, hidden_size), 'float32', -k, k)
+
+        if bias:
+            self.bias_ih = init.uniform((4 * hidden_size,), 'float32', -k, k)
+            self.bias_hh = init.uniform((4 * hidden_size,), 'float32', -k, k)
+
+    def execute(self, input, hx=None):
+        if hx is None:
+            zeros = jt.zeros(input.shape[0], self.hidden_size, dtype=input.dtype)
+            h, c = zeros, zeros
+        else:
+            h, c = hx
+
+        y = matmul_transpose(input, self.weight_ih) + matmul_transpose(h, self.weight_hh)
+
+        if self.bias:
+            y = y + self.bias_ih + self.bias_hh
+
+        i = y[:, :self.hidden_size].sigmoid()
+        f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
+        g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
+        o = y[:, 3 * self.hidden_size:].sigmoid()
+
+        c = f * c + i * g
+        h = o * c.tanh()
+
+        return h, c
+
+
+class LSTM(jt.Module):
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
+                 batch_first=False, dropout=0, bidirectional=False, proj_size=0):
+        ''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
+
+        :param input_size: The number of expected features in the input.
+        :type input_size: int
+
+        :param hidden_size: The number of features in the hidden state.
+        :type hidden_size: int
+
+        :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1
+        :type num_layers: int, optional
+
+        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
+        :type bias: bool, optional
+
+        :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
+        :type batch_first: bool, optional
+
+        :param dropout: [Not implemented] If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
+        :type dropout: float, optional
+
+        :param bidirectional: [Not implemented] If True, becomes a bidirectional LSTM. Default: False
+        :type bidirectional: bool, optional
+
+        :param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0
+        :type proj_size: int, optional
+
+        Example:
+        >>> rnn = nn.LSTM(10, 20, 2)
+        >>> input = jt.randn(5, 3, 10)
+        >>> h0 = jt.randn(2, 3, 20)
+        >>> c0 = jt.randn(2, 3, 20)
+        >>> output, (hn, cn) = rnn(input, (h0, c0))
+        '''
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.bias = bias
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.bidirectional = bidirectional
+        self.proj_size = proj_size
+
+        assert not bidirectional, 'bidirectional is not supported yet'
+        assert dropout == 0, 'dropout is not supported yet'
+
+        num_directions = 1 + bidirectional
+        k = math.sqrt(1 / hidden_size)
+
+        def build_unit(name, in_channels, out_channels=None):
+            if out_channels is not None:
+                shape = (in_channels, out_channels)
+            else:
+                shape = (in_channels,)
+            setattr(self, name, init.uniform(shape, 'float32', -k, k))
+            if self.bidirectional:
+                setattr(self, name + '_reverse', init.uniform(shape, 'float32', -k, k))
+
+        for l in range(num_layers):
+            if l == 0:
+                build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * input_size)
+            else:
+                if proj_size > 0:
+                    build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * proj_size)
+                else:
+                    build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * hidden_size)
+
+            if proj_size > 0:
+                build_unit(f'weight_hh_l{l}', 4 * hidden_size, proj_size)
+                build_unit(f'weight_hr_l{l}', proj_size, hidden_size)
+            else:
+                build_unit(f'weight_hh_l{l}', 4 * hidden_size, hidden_size)
+
+            if bias:
+                build_unit(f'bias_ih_l{l}', 4 * hidden_size)
+                build_unit(f'bias_hh_l{l}', 4 * hidden_size)
+
+    def execute(self, input, hx=None):
+        if self.batch_first:
+            input = input.permute(1, 0, 2)
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
+            h_zeros = jt.zeros(self.num_layers * num_directions,
+                               input.shape[1], real_hidden_size,
+                               dtype=input.dtype)
+            c_zeros = jt.zeros(self.num_layers * num_directions,
+                               input.shape[1], self.hidden_size,
+                               dtype=input.dtype)
+            h, c = h_zeros, c_zeros
+        else:
+            h, c = hx
+
+        output = []
+        for s in range(input.shape[0]):
+            for l in range(self.num_layers):
+                if l == 0:
+                    y = matmul_transpose(input[s], getattr(self, f'weight_ih_l{l}'))
+                else:
+                    y = matmul_transpose(h[l - 1], getattr(self, f'weight_ih_l{l}'))
+
+                y = y + matmul_transpose(h[l], getattr(self, f'weight_hh_l{l}'))
+
+                if self.bias:
+                    y = y + getattr(self, f'bias_ih_l{l}') + getattr(self, f'bias_hh_l{l}')
+
+                i = y[:, :self.hidden_size].sigmoid()
+                f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
+                g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
+                o = y[:, 3 * self.hidden_size:].sigmoid()
+
+                c[l] = f * c[l] + i * g
+                rh = o * c[l].tanh()
+
+                if self.proj_size > 0:
+                    h[l] = matmul_transpose(rh, getattr(self, f'weight_hr_l{l}'))
+                else:
+                    h[l] = rh
+
+            output.append(h[-1])
+
+        output = jt.stack(output, dim=0)
+        return output, (h, c)
diff --git a/python/jittor/test/test_lstm.py b/python/jittor/test/test_lstm.py
new file mode 100644
index 00000000..3889ae3c
--- /dev/null
+++ b/python/jittor/test/test_lstm.py
@@ -0,0 +1,107 @@
+# ***************************************************************
+# Copyright (c) 2021 Jittor. All Rights Reserved.
+# Maintainers:
+#     Zheng-Ning Liu
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import jittor as jt
+import jittor.nn as nn
+import numpy as np
+
+
+skip_this_test = False
+
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    import torch.nn as tnn
+except:
+    torch = None
+    tnn = None
+    skip_this_test = True
+
+
+def check_equal(t_rnn, j_rnn, input, h0, c0):
+    j_rnn.load_state_dict(t_rnn.state_dict())
+
+    t_output, (th, tc) = t_rnn(torch.from_numpy(input),
+                               (torch.from_numpy(h0), torch.from_numpy(c0)))
+    j_output, (jh, jc) = j_rnn(jt.float32(input),
+                               (jt.float32(h0), jt.float32(c0)))
+
+    assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
+    assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06)
+    assert np.allclose(tc.detach().numpy(), jc.data, rtol=1e-03, atol=1e-06)
+
+
+@unittest.skipIf(skip_this_test, "No Torch found")
+class TestLSTM(unittest.TestCase):
+    def test_lstm_cell(self):
+        np_h0 = torch.randn(3, 20).numpy()
+        np_c0 = torch.randn(3, 20).numpy()
+
+        t_rnn = tnn.LSTMCell(10, 20)
+        input = torch.randn(2, 3, 10)
+        h0 = torch.from_numpy(np_h0)
+        c0 = torch.from_numpy(np_c0)
+        t_output = []
+        for i in range(input.size()[0]):
+            h0, c0 = t_rnn(input[i], (h0, c0))
+            t_output.append(h0)
+        t_output = torch.stack(t_output, dim=0)
+
+        j_rnn = nn.LSTMCell(10, 20)
+        j_rnn.load_state_dict(t_rnn.state_dict())
+
+        input = jt.float32(input.numpy())
+        h0 = jt.float32(np_h0)
+        c0 = jt.float32(np_c0)
+        j_output = []
+        for i in range(input.size()[0]):
+            h0, c0 = j_rnn(input[i], (h0, c0))
+            j_output.append(h0)
+        j_output = jt.stack(j_output, dim=0)
+
+        t_output = t_output.detach().numpy()
+        j_output = j_output.data
+        assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
+
+    def test_lstm(self):
+        h0 = np.random.rand(1, 2, 20).astype(np.float32)
+        c0 = np.random.rand(1, 2, 20).astype(np.float32)
+        input = np.random.rand(5, 2, 10).astype(np.float32)
+
+        t_rnn = tnn.LSTM(10, 20)
+        j_rnn = nn.LSTM(10, 20)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+        proj_size = 13
+        h0 = np.random.rand(1, 2, proj_size).astype(np.float32)
+        c0 = np.random.rand(1, 2, 20).astype(np.float32)
+        input = np.random.rand(5, 2, 10).astype(np.float32)
+        t_rnn = tnn.LSTM(10, 20, proj_size=proj_size)
+        j_rnn = nn.LSTM(10, 20, proj_size=proj_size)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+        h0 = np.random.rand(2, 4, 20).astype(np.float32)
+        c0 = np.random.rand(2, 4, 20).astype(np.float32)
+        input = np.random.rand(5, 4, 10).astype(np.float32)
+
+        t_rnn = tnn.LSTM(10, 20, num_layers=2)
+        j_rnn = nn.LSTM(10, 20, num_layers=2)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+        h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
+        c0 = np.random.rand(2, 4, 20).astype(np.float32)
+        input = np.random.rand(5, 4, 10).astype(np.float32)
+
+        t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
+        j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
+        check_equal(t_rnn, j_rnn, input, h0, c0)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
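
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the new nn.LSTM added above is expected to be called, using the (seq_len, batch, feature) convention documented in its docstring. With proj_size > 0, the per-step hidden state, the stacked output, and hn carry proj_size features, while cn keeps hidden_size; the shapes here follow that convention and are the only assumptions made.

    import jittor as jt
    import jittor.nn as nn

    # Two stacked LSTM layers with a projection of size 13 (example values only).
    rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, proj_size=13)

    x = jt.randn(5, 3, 10)    # (seq_len, batch, input_size)
    h0 = jt.randn(2, 3, 13)   # (num_layers, batch, proj_size)
    c0 = jt.randn(2, 3, 20)   # (num_layers, batch, hidden_size)

    output, (hn, cn) = rnn(x, (h0, c0))
    print(output.shape)        # expected [5, 3, 13]
    print(hn.shape, cn.shape)  # expected [2, 3, 13] and [2, 3, 20]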