diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 06e96931..3f9076ef 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -1460,7 +1460,7 @@ class LSTMCell(jt.Module):
 
     def execute(self, input, hx = None):
         if hx is None:
-            zeros = jt.zeros(input.shape[0], self.hidden_size, dtype=input.dtype)
+            zeros = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
             h, c = zeros, zeros
         else:
             h, c = hx
@@ -1481,6 +1481,120 @@ class LSTMCell(jt.Module):
 
         return h, c
 
+class RNNCell(jt.Module):
+    def __init__(self, input_size, hidden_size, bias=True, nonlinearity = "tanh"):
+        ''' An Elman RNN cell with tanh or ReLU non-linearity.
+
+        :param input_size: The number of expected features in the input
+        :type input_size: int
+
+        :param hidden_size: The number of features in the hidden state
+        :type hidden_size: int
+
+        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
+        :type bias: bool, optional
+
+        :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'.
+        :type nonlinearity: str, optional
+
+        Example:
+
+        >>> rnn = nn.RNNCell(10, 20)
+        >>> input = jt.randn((6, 3, 10))
+        >>> hx = jt.randn((3, 20))
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+        '''
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.bias = bias
+        self.nonlinearity = nonlinearity
+
+        k = math.sqrt(1 / hidden_size)
+        self.weight_ih = init.uniform((hidden_size, input_size), 'float32', -k, k)
+        self.weight_hh = init.uniform((hidden_size, hidden_size), 'float32', -k, k)
+
+        if bias:
+            self.bias_ih = init.uniform((hidden_size,), 'float32', -k, k)
+            self.bias_hh = init.uniform((hidden_size,), 'float32', -k, k)
+
+    def execute(self, input, hx = None):
+        if hx is None:
+            hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
+
+        y = matmul_transpose(input, self.weight_ih) + matmul_transpose(hx, self.weight_hh)
+
+        if self.bias:
+            y = y + self.bias_ih + self.bias_hh
+
+        if self.nonlinearity == 'tanh':
+            y = y.tanh()
+        elif self.nonlinearity == 'relu':
+            y = relu(y)
+        else:
+            raise RuntimeError("Unknown nonlinearity: {}".format(self.nonlinearity))
+
+        return y
+
+class GRUCell(jt.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        ''' A gated recurrent unit (GRU) cell.
+
+        :param input_size: The number of expected features in the input
+        :type input_size: int
+
+        :param hidden_size: The number of features in the hidden state
+        :type hidden_size: int
+
+        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
+        :type bias: bool, optional
+
+        Example:
+
+        >>> rnn = nn.GRUCell(10, 20)
+        >>> input = jt.randn((6, 3, 10))
+        >>> hx = jt.randn((3, 20))
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+        '''
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        k = math.sqrt(1 / hidden_size)
+        self.weight_ih = init.uniform((3*hidden_size, input_size), 'float32', -k, k)
+        self.weight_hh = init.uniform((3*hidden_size, hidden_size), 'float32', -k, k)
+
+        if bias:
+            self.bias_ih = init.uniform((3*hidden_size,), 'float32', -k, k)
+            self.bias_hh = init.uniform((3*hidden_size,), 'float32', -k, k)
+
+    def execute(self, input, hx = None):
+        if hx is None:
+            hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
+
+        gi = matmul_transpose(input, self.weight_ih)
+        gh = matmul_transpose(hx, self.weight_hh)
+
+        if self.bias:
+            gi += self.bias_ih
+            gh += self.bias_hh
+
+        i_r, i_i, i_n = gi.chunk(3, 1)
+        h_r, h_i, h_n = gh.chunk(3, 1)
+
+        resetgate = jt.sigmoid(i_r + h_r)
+        inputgate = jt.sigmoid(i_i + h_i)
+        newgate = jt.tanh(i_n + resetgate * h_n)
+        hy = newgate + inputgate * (hx - newgate)
+        return hy
+
 class RNNBase(Module):
     def __init__(self, mode: str, input_size: int, hidden_size: int,
         num_layers: int = 1, bias: bool = True, batch_first: bool = False,
diff --git a/python/jittor/test/test_rnn.py b/python/jittor/test/test_rnn.py
index efed034e..51e03668 100644
--- a/python/jittor/test/test_rnn.py
+++ b/python/jittor/test/test_rnn.py
@@ -119,6 +119,60 @@ class TestRNN(unittest.TestCase):
         j_output = j_output.data
         assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
 
+    def test_rnn_cell(self):
+        np_h0 = torch.randn(3, 20).numpy()
+
+        t_rnn = tnn.RNNCell(10, 20)
+        input = torch.randn(2, 3, 10)
+        h0 = torch.from_numpy(np_h0)
+        t_output = []
+        for i in range(input.size()[0]):
+            h0 = t_rnn(input[i], h0)
+            t_output.append(h0)
+        t_output = torch.stack(t_output, dim=0)
+
+        j_rnn = nn.RNNCell(10, 20)
+        j_rnn.load_state_dict(t_rnn.state_dict())
+
+        input = jt.float32(input.numpy())
+        h0 = jt.float32(np_h0)
+        j_output = []
+        for i in range(input.size()[0]):
+            h0 = j_rnn(input[i], h0)
+            j_output.append(h0)
+        j_output = jt.stack(j_output, dim=0)
+
+        t_output = t_output.detach().numpy()
+        j_output = j_output.data
+        assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
+
+    def test_gru_cell(self):
+        np_h0 = torch.randn(3, 20).numpy()
+
+        t_rnn = tnn.GRUCell(10, 20)
+        input = torch.randn(2, 3, 10)
+        h0 = torch.from_numpy(np_h0)
+        t_output = []
+        for i in range(input.size()[0]):
+            h0 = t_rnn(input[i], h0)
+            t_output.append(h0)
+        t_output = torch.stack(t_output, dim=0)
+
+        j_rnn = nn.GRUCell(10, 20)
+        j_rnn.load_state_dict(t_rnn.state_dict())
+
+        input = jt.float32(input.numpy())
+        h0 = jt.float32(np_h0)
+        j_output = []
+        for i in range(input.size()[0]):
+            h0 = j_rnn(input[i], h0)
+            j_output.append(h0)
+        j_output = jt.stack(j_output, dim=0)
+
+        t_output = t_output.detach().numpy()
+        j_output = j_output.data
+        assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
+
     def test_lstm(self):
         h0 = np.random.rand(1, 24, 200).astype(np.float32)
         c0 = np.random.rand(1, 24, 200).astype(np.float32)
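
For reviewers who want to sanity-check the gate algebra, here is a minimal pure-NumPy sketch of the update that `GRUCell.execute` in the patch computes; the function name and the random stand-in weights below are illustrative only, not part of the patch. Note that `hy = newgate + inputgate * (hx - newgate)` is algebraically `(1 - inputgate) * newgate + inputgate * hx`, i.e. the standard GRU update.

```python
import numpy as np

def gru_cell_reference(x, h, w_ih, w_hh, b_ih, b_hh):
    """Pure-NumPy mirror of GRUCell.execute (name is hypothetical)."""
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    gi = x @ w_ih.T + b_ih   # input projections, gates stacked r|z|n
    gh = h @ w_hh.T + b_hh   # hidden projections, gates stacked r|z|n
    i_r, i_i, i_n = np.split(gi, 3, axis=1)
    h_r, h_i, h_n = np.split(gh, 3, axis=1)
    resetgate = sigmoid(i_r + h_r)            # r_t
    inputgate = sigmoid(i_i + h_i)            # z_t
    newgate = np.tanh(i_n + resetgate * h_n)  # n_t
    # h_t = n_t + z_t * (h_{t-1} - n_t) == (1 - z_t) * n_t + z_t * h_{t-1}
    return newgate + inputgate * (h - newgate)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 10))      # batch 3, input_size 10
h = rng.standard_normal((3, 20))      # batch 3, hidden_size 20
w_ih = rng.standard_normal((60, 10))  # 3*hidden_size rows
w_hh = rng.standard_normal((60, 20))
b_ih = rng.standard_normal(60)
b_hh = rng.standard_normal(60)
assert gru_cell_reference(x, h, w_ih, w_hh, b_ih, b_hh).shape == (3, 20)
```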
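And the docstring example expanded into a runnable script, with shapes following the (seq_len, batch, feature) layout used in the tests; swap `RNNCell` for `GRUCell` and the loop is unchanged:

```python
import jittor as jt
from jittor import nn

rnn = nn.RNNCell(10, 20)          # input_size=10, hidden_size=20
input = jt.randn((6, 3, 10))      # seq_len 6, batch 3
hx = jt.randn((3, 20))            # initial hidden state
output = []
for i in range(6):
    hx = rnn(input[i], hx)        # one time step
    output.append(hx)
output = jt.stack(output, dim=0)  # (6, 3, 20)
```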