feat: add rnn, gru

This commit is contained in:
lzhengning 2021-04-29 14:58:24 +08:00
parent 18923fcf3d
commit 026dfb8fa2
2 changed files with 337 additions and 112 deletions

View File

@ -1480,7 +1480,183 @@ class LSTMCell(jt.Module):
return h, c
class LSTM(jt.Module):
class RNNBase(Module):
def __init__(self, mode: str, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True, batch_first: bool = False,
dropout: float = 0, bidirectional: bool = False,
proj_size: int = 0) -> None:
super().__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
self.bidirectional = bidirectional
self.proj_size = proj_size
if mode == 'LSTM':
gate_size = 4 * hidden_size
elif mode == 'GRU':
gate_size = 3 * hidden_size
elif mode == 'RNN':
gate_size = hidden_size
else:
raise ValueError("Unrecognized RNN mode: " + mode)
num_directions = 1 + bidirectional
k = math.sqrt(1 / hidden_size)
def build_unit(name, in_channels, out_channels=None):
if out_channels is not None:
shape = (in_channels, out_channels)
else:
shape = (in_channels,)
setattr(self, name, init.uniform(shape, 'float32', -k, k))
if self.bidirectional:
setattr(self, name + '_reverse', init.uniform(shape, 'float32', -k, k))
for layer in range(num_layers):
if layer == 0:
build_unit(f'weight_ih_l{layer}', gate_size, input_size)
else:
if proj_size > 0:
build_unit(f'weight_ih_l{layer}', gate_size, num_directions * proj_size)
else:
build_unit(f'weight_ih_l{layer}', gate_size, num_directions * hidden_size)
if proj_size > 0:
build_unit(f'weight_hh_l{layer}', gate_size, proj_size)
build_unit(f'weight_hr_l{layer}', proj_size, hidden_size)
else:
build_unit(f'weight_hh_l{layer}', gate_size, hidden_size)
if bias:
build_unit(f'bias_ih_l{layer}', gate_size)
build_unit(f'bias_hh_l{layer}', gate_size)
def call_rnn_sequence(self, input, hidden, suffix):
if 'reverse' in suffix:
input = input[::-1]
output = []
for s in range(input.shape[0]):
out, hidden = self.call_rnn_cell(input[s], hidden, suffix)
output.append(out)
if 'reverse' in suffix:
output = output[::-1]
output = jt.stack(output, dim=0)
return output, hidden
def execute(self, input, hx):
if self.batch_first:
input = input.permute(1, 0, 2)
num_directions = 2 if self.bidirectional else 1
if hx is None:
hx = self.default_init_state()
hidden_n = []
for l in range(self.num_layers):
output = []
if isinstance(hx, tuple):
hidden = [h[l * num_directions] for h in hx]
else:
hidden = hx[l * num_directions]
output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}')
hidden_n.append(_hidden)
if self.bidirectional:
if isinstance(hx, tuple):
hidden = [h[l * num_directions + 1] for h in hx]
else:
hidden = hx[l * num_directions + 1]
output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse')
output = jt.concat([output, output_b], dim=-1)
hidden_n.append(_hidden)
if self.dropout > 0:
input = dropout(output, p=self.dropout)
else:
input = output
if isinstance(hx, tuple):
hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n))
else:
hidden_n = jt.stack(hidden_n, dim=0)
return output, hidden_n
class RNN(RNNBase):
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False,
dropout: float = 0, bidirectional: bool = False) -> None:
''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence.
:param input_size: The number of expected features in the input.
:type input_size: int
:param hidden_size: The number of features in the hidden state.
:type hidden_size: int
:param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
:type num_layers: int, optinal
:param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
:type nonlinearity: str, optional
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
:type bias: bool, optional
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
:type bias: bool, optional
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
:type dropout: float, optional
:param bidirectional: If True, becomes a bidirectional RNN. Default: False
:type bidirectional: bool, optional
Example:
>>> rnn = nn.RNN(10, 20, 2)
>>> input = jt.randn(5, 3, 10)
>>> h0 = jt.randn(2, 3, 20)
>>> output, hn = rnn(input, h0)
'''
super().__init__('RNN', input_size, hidden_size, num_layers=num_layers,
bias=bias, batch_first=batch_first, dropout=dropout,
bidirectional=bidirectional)
if not nonlinearity in ['tanh', 'relu']:
raise ValueError('Unrecognized nonlinearity: ' + nonlinearity)
self.nonlinearity = nonlinearity
def call_rnn_cell(self, input, hidden, suffix):
y = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}'))
y = y + matmul_transpose(hidden, getattr(self, f'weight_hh_{suffix}'))
if self.bias:
y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
if self.nonlinearity == 'tanh':
h = jt.tanh(y)
else:
h = jt.relu(y)
return h, h
class LSTM(RNNBase):
def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
batch_first=False, dropout=0, bidirectional=False, proj_size=0):
''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
@ -1500,10 +1676,10 @@ class LSTM(jt.Module):
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
:type bias: bool, optional
:param dropout: [Not implemented] If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
:type dropout: float, optional
:param bidirectional: [Not implemented] If True, becomes a bidirectional LSTM. Default: False
:param bidirectional: If True, becomes a bidirectional LSTM. Default: False
:type bidirectional: bool, optional
:param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0
@ -1516,119 +1692,80 @@ class LSTM(jt.Module):
>>> c0 = jt.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
'''
super().__init__()
super().__init__('LSTM', input_size, hidden_size, num_layers=num_layers,
bias=bias, batch_first=batch_first, dropout=dropout,
bidirectional=bidirectional, proj_size=proj_size)
self.hidden_size = hidden_size
self.bias = bias
self.num_layers = num_layers
self.batch_first = batch_first
self.bidirectional = bidirectional
self.proj_size = proj_size
self.dropout = dropout
def call_rnn_cell(self, input, hidden, suffix):
h, c = hidden
y = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}'))
y = y + matmul_transpose(h, getattr(self, f'weight_hh_{suffix}'))
if self.bias:
y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
num_directions = 1 + bidirectional
k = math.sqrt(1 / hidden_size)
i = y[:, :self.hidden_size].sigmoid()
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
o = y[:, 3 * self.hidden_size:].sigmoid()
c = f * c + i * g
h = o * c.tanh()
def build_unit(name, in_channels, out_channels=None):
if out_channels is not None:
shape = (in_channels, out_channels)
else:
shape = (in_channels,)
setattr(self, name, init.uniform(shape, 'float32', -k, k))
if self.bidirectional:
setattr(self, name + '_reverse', init.uniform(shape, 'float32', -k, k))
if self.proj_size > 0:
h = matmul_transpose(h, getattr(self, f'weight_hr_{suffix}'))
for l in range(num_layers):
if l == 0:
build_unit(f'weight_ih_l{l}', 4 * hidden_size, input_size)
else:
if proj_size > 0:
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * proj_size)
else:
build_unit(f'weight_ih_l{l}', 4 * hidden_size, num_directions * hidden_size)
return h, (h, c)
if proj_size > 0:
build_unit(f'weight_hh_l{l}', 4 * hidden_size, proj_size)
build_unit(f'weight_hr_l{l}', proj_size, hidden_size)
else:
build_unit(f'weight_hh_l{l}', 4 * hidden_size, hidden_size)
if bias:
build_unit(f'bias_ih_l{l}', 4 * hidden_size)
build_unit(f'bias_hh_l{l}', 4 * hidden_size)
class GRU(RNNBase):
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
bias: bool = True, batch_first: bool = False, dropout: float = 0,
bidirectional: bool = False) -> None:
''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
def call_lstm_sequence(self, input, h, c, suffix):
if 'reverse' in suffix:
input = input[::-1]
:param input_size: The number of expected features in the input.
:type input_size: int
output = []
for s in range(input.shape[0]):
y = matmul_transpose(input[s], getattr(self, f'weight_ih_{suffix}'))
y = y + matmul_transpose(h, getattr(self, f'weight_hh_{suffix}'))
if self.bias:
y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}')
:param hidden_size: The number of features in the hidden state.
:type hidden_size: int
i = y[:, :self.hidden_size].sigmoid()
f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid()
g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh()
o = y[:, 3 * self.hidden_size:].sigmoid()
c = f * c + i * g
h = o * c.tanh()
:param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1
:type num_layers: int, optinal
if self.proj_size > 0:
h = matmul_transpose(h, getattr(self, f'weight_hr_{suffix}'))
:param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
:type bias: bool, optional
output.append(h)
:param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
:type bias: bool, optional
if 'reverse' in suffix:
output = output[::-1]
output = jt.stack(output, dim=0)
:param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0
:type dropout: float, optional
return output, h, c
:param bidirectional: If True, becomes a bidirectional GRU. Default: False
:type bidirectional: bool, optional
def execute(self, input, hx):
if self.batch_first:
input = input.permute(1, 0, 2)
Example:
>>> rnn = nn.GRU(10, 20, 2)
>>> input = jt.randn(5, 3, 10)
>>> h0 = jt.randn(2, 3, 20)
>>> output, hn = rnn(input, h0)
'''
super().__init__('GRU', input_size, hidden_size, num_layers=num_layers,
bias=bias, batch_first=batch_first, dropout=dropout,
bidirectional=bidirectional)
num_directions = 2 if self.bidirectional else 1
def call_rnn_cell(self, input, hidden, suffix):
ih = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}'))
hh = matmul_transpose(hidden, getattr(self, f'weight_hh_{suffix}'))
if self.bias:
ih = ih + getattr(self, f'bias_ih_{suffix}')
hh = hh + getattr(self, f'bias_hh_{suffix}')
if hx is None:
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
h_zeros = jt.zeros(self.num_layers * num_directions,
input.shape[1], real_hidden_size,
dtype=input.dtype, device=input.device)
c_zeros = jt.zeros(self.num_layers * num_directions,
input.shape[1], self.hidden_size,
dtype=input.dtype, device=input.device)
h0, c0 = h_zeros, c_zeros
else:
h0, c0 = hx
hs = self.hidden_size
r = (ih[:, :hs] + hh[:, :hs]).sigmoid()
z = (ih[:, hs: 2 * hs] + hh[:, hs: 2 * hs]).sigmoid()
n = (ih[:, 2 * hs:] + r * hh[:, 2 * hs:]).tanh()
h = (1 - z) * n + z * hidden
hn = []
cn = []
for l in range(self.num_layers):
output = []
h = h0[l * num_directions: (l + 1) * num_directions]
c = c0[l * num_directions: (l + 1) * num_directions]
output, _h, _c = self.call_lstm_sequence(input, h[0], c[0], f'l{l}')
hn.append(_h)
cn.append(_c)
if self.bidirectional:
output_b, _h, _c = self.call_lstm_sequence(input, h[1], c[1], f'l{l}_reverse')
output = jt.concat([output, output_b], dim=-1)
hn.append(_h)
cn.append(_c)
if self.dropout > 0:
input = dropout(output, p=self.dropout)
else:
input = output
hn = jt.stack(hn, dim=0)
cn = jt.stack(cn, dim=0)
return output, (hn, cn)
return h, h

View File

@ -24,11 +24,23 @@ except:
skip_this_test = True
def check_equal(t_rnn, j_rnn, input, h0, c0):
def check_equal_1(t_rnn, j_rnn, input, h0):
j_rnn.load_state_dict(t_rnn.state_dict())
t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0))
j_output, jh = j_rnn(jt.float32(input), jt.float32(h0))
assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06)
def check_equal_2(t_rnn, j_rnn, input, h0, c0):
j_rnn.load_state_dict(t_rnn.state_dict())
t_output, (th, tc) = t_rnn(torch.from_numpy(input),
(torch.from_numpy(h0), torch.from_numpy(c0)))
j_output, (jh, jc) = j_rnn(jt.float32(input),
(jt.float32(h0), jt.float32(c0)))
@ -38,7 +50,45 @@ def check_equal(t_rnn, j_rnn, input, h0, c0):
@unittest.skipIf(skip_this_test, "No Torch found")
class TestLSTM(unittest.TestCase):
class TestRNN(unittest.TestCase):
def test_rnn(self):
h0 = np.random.rand(1, 24, 200).astype(np.float32)
input = np.random.rand(32, 24, 100).astype(np.float32)
t_rnn = tnn.RNN(100, 200)
j_rnn = nn.RNN(100, 200)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.RNN(100, 200, num_layers=4)
j_rnn = nn.RNN(100, 200, num_layers=4)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(2, 1, 200).astype(np.float32)
input = np.random.rand(5, 1, 100).astype(np.float32)
t_rnn = tnn.RNN(100, 200, bidirectional=True)
j_rnn = nn.RNN(100, 200, bidirectional=True)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False)
j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(2, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False)
j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False)
t_rnn.eval()
j_rnn.eval()
check_equal_1(t_rnn, j_rnn, input, h0)
def test_lstm_cell(self):
np_h0 = torch.randn(3, 20).numpy()
np_c0 = torch.randn(3, 20).numpy()
@ -76,7 +126,7 @@ class TestLSTM(unittest.TestCase):
t_rnn = tnn.LSTM(100, 200)
j_rnn = nn.LSTM(100, 200)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
proj_size = 13
h0 = np.random.rand(1, 24, proj_size).astype(np.float32)
@ -84,7 +134,7 @@ class TestLSTM(unittest.TestCase):
input = np.random.rand(32, 24, 100).astype(np.float32)
t_rnn = tnn.LSTM(100, 200, proj_size=proj_size)
j_rnn = nn.LSTM(100, 200, proj_size=proj_size)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
c0 = np.random.rand(4, 4, 200).astype(np.float32)
@ -92,7 +142,7 @@ class TestLSTM(unittest.TestCase):
t_rnn = tnn.LSTM(100, 200, num_layers=4)
j_rnn = nn.LSTM(100, 200, num_layers=4)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
c0 = np.random.rand(2, 4, 20).astype(np.float32)
@ -100,7 +150,7 @@ class TestLSTM(unittest.TestCase):
t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 1, 200).astype(np.float32)
c0 = np.random.rand(2, 1, 200).astype(np.float32)
@ -108,7 +158,7 @@ class TestLSTM(unittest.TestCase):
t_rnn = tnn.LSTM(100, 200, bidirectional=True)
j_rnn = nn.LSTM(100, 200, bidirectional=True)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
proj_size = 13
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
@ -117,7 +167,7 @@ class TestLSTM(unittest.TestCase):
t_rnn = tnn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
c0 = np.random.rand(4, 4, 200).astype(np.float32)
@ -125,7 +175,7 @@ class TestLSTM(unittest.TestCase):
t_rnn = tnn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
h0 = np.random.rand(2, 4, 200).astype(np.float32)
@ -136,7 +186,45 @@ class TestLSTM(unittest.TestCase):
j_rnn = nn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False)
t_rnn.eval()
j_rnn.eval()
check_equal(t_rnn, j_rnn, input, h0, c0)
check_equal_2(t_rnn, j_rnn, input, h0, c0)
def test_gru(self):
h0 = np.random.rand(1, 24, 200).astype(np.float32)
input = np.random.rand(32, 24, 100).astype(np.float32)
t_rnn = tnn.GRU(100, 200)
j_rnn = nn.GRU(100, 200)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.GRU(100, 200, num_layers=4)
j_rnn = nn.GRU(100, 200, num_layers=4)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(2, 1, 200).astype(np.float32)
input = np.random.rand(5, 1, 100).astype(np.float32)
t_rnn = tnn.GRU(100, 200, bidirectional=True)
j_rnn = nn.GRU(100, 200, bidirectional=True)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(4, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False)
j_rnn = nn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False)
check_equal_1(t_rnn, j_rnn, input, h0)
h0 = np.random.rand(2, 4, 200).astype(np.float32)
input = np.random.rand(5, 4, 100).astype(np.float32)
t_rnn = tnn.GRU(100, 200, num_layers=2, dropout=0.5, bias=False)
j_rnn = nn.GRU(100, 200, num_layers=2, dropout=0.5, bias=False)
t_rnn.eval()
j_rnn.eval()
check_equal_1(t_rnn, j_rnn, input, h0)
if __name__ == "__main__":