mirror of https://github.com/Jittor/Jittor
add rnncell and grucell
parent 295393b556
commit 46e6a66411
@@ -1460,7 +1460,7 @@ class LSTMCell(jt.Module):
     def execute(self, input, hx=None):
         if hx is None:
-            zeros = jt.zeros(input.shape[0], self.hidden_size, dtype=input.dtype)
+            zeros = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
             h, c = zeros, zeros
         else:
             h, c = hx
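
The one-line fix above packs the two dimensions into a single shape tuple, leaving dtype as the only keyword argument; judging from the change, the previous form did not construct the intended (batch, hidden_size) variable. A minimal sketch of the corrected call (assuming a working jittor install):

    import jittor as jt

    x = jt.randn((3, 10))        # a (batch, features) input
    hidden_size = 20

    # shape passed as one tuple, as in the + line above
    zeros = jt.zeros((x.shape[0], hidden_size), dtype=x.dtype)
    print(zeros.shape)           # expected: [3, 20]
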
@@ -1481,6 +1481,120 @@ class LSTMCell(jt.Module):
         return h, c
+
+class RNNCell(jt.Module):
+    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
+        ''' An Elman RNN cell with tanh or ReLU non-linearity.
+
+        :param input_size: The number of expected features in the input
+        :type input_size: int
+
+        :param hidden_size: The number of features in the hidden state
+        :type hidden_size: int
+
+        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
+        :type bias: bool, optional
+
+        :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'.
+        :type nonlinearity: str, optional
+
+        Example:
+
+        >>> rnn = nn.RNNCell(10, 20)
+        >>> input = jt.randn((6, 3, 10))
+        >>> hx = jt.randn((3, 20))
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+        '''
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.bias = bias
+        self.nonlinearity = nonlinearity
+
+        k = math.sqrt(1 / hidden_size)
+        self.weight_ih = init.uniform((hidden_size, input_size), 'float32', -k, k)
+        self.weight_hh = init.uniform((hidden_size, hidden_size), 'float32', -k, k)
+
+        if bias:
+            self.bias_ih = init.uniform((hidden_size,), 'float32', -k, k)
+            self.bias_hh = init.uniform((hidden_size,), 'float32', -k, k)
+
+    def execute(self, input, hx=None):
+        if hx is None:
+            hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
+
+        y = matmul_transpose(input, self.weight_ih) + matmul_transpose(hx, self.weight_hh)
+
+        if self.bias:
+            y = y + self.bias_ih + self.bias_hh
+
+        if self.nonlinearity == 'tanh':
+            y = y.tanh()
+        elif self.nonlinearity == 'relu':
+            y = relu(y)
+        else:
+            raise RuntimeError("Unknown nonlinearity: {}".format(self.nonlinearity))
+
+        return y
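
For reference, the cell above implements the standard Elman update h' = tanh(x @ W_ih.T + b_ih + h @ W_hh.T + b_hh), with matmul_transpose multiplying by the transposed weight. A quick NumPy cross-check of that formula against the Jittor cell (a sketch, assuming jittor and numpy are installed and reading parameters as NumPy arrays via .data):

    import numpy as np
    import jittor as jt
    from jittor import nn

    cell = nn.RNNCell(10, 20)                      # tanh non-linearity by default
    x = np.random.rand(3, 10).astype(np.float32)   # (batch, input_size)
    h = np.random.rand(3, 20).astype(np.float32)   # (batch, hidden_size)

    jy = cell(jt.array(x), jt.array(h)).data       # run the Jittor cell

    # recompute with plain NumPy from the cell's own parameters
    ref = np.tanh(x @ cell.weight_ih.data.T + cell.bias_ih.data
                  + h @ cell.weight_hh.data.T + cell.bias_hh.data)
    assert np.allclose(jy, ref, rtol=1e-3, atol=1e-6)
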
+class GRUCell(jt.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        ''' A gated recurrent unit (GRU) cell.
+
+        :param input_size: The number of expected features in the input
+        :type input_size: int
+
+        :param hidden_size: The number of features in the hidden state
+        :type hidden_size: int
+
+        :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
+        :type bias: bool, optional
+
+        Example:
+
+        >>> rnn = nn.GRUCell(10, 20)
+        >>> input = jt.randn((6, 3, 10))
+        >>> hx = jt.randn((3, 20))
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+        '''
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        k = math.sqrt(1 / hidden_size)
+        self.weight_ih = init.uniform((3*hidden_size, input_size), 'float32', -k, k)
+        self.weight_hh = init.uniform((3*hidden_size, hidden_size), 'float32', -k, k)
+
+        if bias:
+            self.bias_ih = init.uniform((3*hidden_size,), 'float32', -k, k)
+            self.bias_hh = init.uniform((3*hidden_size,), 'float32', -k, k)
+
+    def execute(self, input, hx=None):
+        if hx is None:
+            hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype)
+
+        gi = matmul_transpose(input, self.weight_ih)
+        gh = matmul_transpose(hx, self.weight_hh)
+
+        if self.bias:
+            gi += self.bias_ih
+            gh += self.bias_hh
+
+        i_r, i_i, i_n = gi.chunk(3, 1)
+        h_r, h_i, h_n = gh.chunk(3, 1)
+
+        resetgate = jt.sigmoid(i_r + h_r)
+        inputgate = jt.sigmoid(i_i + h_i)
+        newgate = jt.tanh(i_n + resetgate * h_n)
+        hy = newgate + inputgate * (hx - newgate)
+        return hy
+
 class RNNBase(Module):
     def __init__(self, mode: str, input_size: int, hidden_size: int,
                  num_layers: int = 1, bias: bool = True, batch_first: bool = False,
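
A note on the GRUCell arithmetic: with r = sigmoid(i_r + h_r), z = sigmoid(i_i + h_i) and n = tanh(i_n + r * h_n), the final line hy = newgate + inputgate * (hx - newgate) is an algebraic rearrangement of the textbook GRU update h' = (1 - z) * n + z * h. A tiny NumPy check of that identity:

    import numpy as np

    rng = np.random.default_rng(0)
    n, z, h = rng.random(5), rng.random(5), rng.random(5)

    # committed form vs. textbook form of the GRU state update
    assert np.allclose(n + z * (h - n), (1 - z) * n + z * h)
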
@@ -119,6 +119,60 @@ class TestRNN(unittest.TestCase):
         j_output = j_output.data
         assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
+
+    def test_rnn_cell(self):
+        np_h0 = torch.randn(3, 20).numpy()
+
+        t_rnn = tnn.RNNCell(10, 20)
+        input = torch.randn(2, 3, 10)
+        h0 = torch.from_numpy(np_h0)
+        t_output = []
+        for i in range(input.size()[0]):
+            h0 = t_rnn(input[i], h0)
+            t_output.append(h0)
+        t_output = torch.stack(t_output, dim=0)
+
+        j_rnn = nn.RNNCell(10, 20)
+        j_rnn.load_state_dict(t_rnn.state_dict())
+
+        input = jt.float32(input.numpy())
+        h0 = jt.float32(np_h0)
+        j_output = []
+        for i in range(input.size()[0]):
+            h0 = j_rnn(input[i], h0)
+            j_output.append(h0)
+        j_output = jt.stack(j_output, dim=0)
+
+        t_output = t_output.detach().numpy()
+        j_output = j_output.data
+        assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
+
+    def test_gru_cell(self):
+        np_h0 = torch.randn(3, 20).numpy()
+
+        t_rnn = tnn.GRUCell(10, 20)
+        input = torch.randn(2, 3, 10)
+        h0 = torch.from_numpy(np_h0)
+        t_output = []
+        for i in range(input.size()[0]):
+            h0 = t_rnn(input[i], h0)
+            t_output.append(h0)
+        t_output = torch.stack(t_output, dim=0)
+
+        j_rnn = nn.GRUCell(10, 20)
+        j_rnn.load_state_dict(t_rnn.state_dict())
+
+        input = jt.float32(input.numpy())
+        h0 = jt.float32(np_h0)
+        j_output = []
+        for i in range(input.size()[0]):
+            h0 = j_rnn(input[i], h0)
+            j_output.append(h0)
+        j_output = jt.stack(j_output, dim=0)
+
+        t_output = t_output.detach().numpy()
+        j_output = j_output.data
+        assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
+
     def test_lstm(self):
         h0 = np.random.rand(1, 24, 200).astype(np.float32)
         c0 = np.random.rand(1, 24, 200).astype(np.float32)
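
The two new tests use PyTorch as the reference implementation. A torch-free check in the same spirit (a sketch, assuming only jittor and numpy): run the Jittor GRUCell once and recompute the update by hand from the cell's own parameters, mirroring execute() above.

    import numpy as np
    import jittor as jt
    from jittor import nn

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    cell = nn.GRUCell(10, 20)
    x = np.random.rand(3, 10).astype(np.float32)
    h = np.random.rand(3, 20).astype(np.float32)

    jy = cell(jt.array(x), jt.array(h)).data

    # NumPy re-implementation of GRUCell.execute
    gi = x @ cell.weight_ih.data.T + cell.bias_ih.data
    gh = h @ cell.weight_hh.data.T + cell.bias_hh.data
    i_r, i_i, i_n = np.split(gi, 3, axis=1)
    h_r, h_i, h_n = np.split(gh, 3, axis=1)
    r = sigmoid(i_r + h_r)
    z = sigmoid(i_i + h_i)
    n = np.tanh(i_n + r * h_n)
    assert np.allclose(jy, n + z * (h - n), rtol=1e-3, atol=1e-6)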