I was building a seq2seq model in Keras. I had built a single-layer encoder and decoder, and they were working fine. Now I want to extend it to a multi-layer encoder and decoder.
EDIT - Updated to use the functional API model in Keras vs. the RNN
from keras.models import Model
from keras.layers import Input, LSTM, Dense, RNN
# Training graph for a two-layer LSTM encoder/decoder built with the functional API.
# NOTE: `layers` is not read anywhere in this block; both LSTM layers below are
# sized by `latent_dim`.
layers = [256, 128]

# Encoder: the first LSTM returns its whole output sequence so that the second
# LSTM can consume it; the (h, c) state pair of each layer is kept.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder_lstm1 = LSTM(latent_dim, return_sequences=True, return_state=True)
encoder_seq, h1, c1 = encoder_lstm1(encoder_inputs)
encoder_lstm2 = LSTM(latent_dim, return_state=True)
_, h2, c2 = encoder_lstm2(encoder_seq)
encoder_states = [h1, c1, h2, c2]

# Decoder: each decoder layer starts from the states of the encoder layer at
# the same depth. The layer objects are kept in named variables because the
# inference graph calls them again.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
out_layer1 = LSTM(latent_dim, return_state=True, return_sequences=True)
decoder_seq, _, _ = out_layer1(decoder_inputs, initial_state=[h1, c1])
out_layer2 = LSTM(latent_dim, return_state=True, return_sequences=True)
decoder_seq, _, _ = out_layer2(decoder_seq, initial_state=[h2, c2])

# Softmax over the decoder vocabulary at every timestep.
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_seq)

model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=decoder_outputs)
model.summary()
And here is the inference setup:
# Inference graphs. The encoder model maps an input sequence to the four state
# tensors; the decoder model takes those states as explicit inputs and returns
# the updated states alongside the token probabilities.
encoder_model = Model(encoder_inputs, encoder_states)

# Four state placeholders, ordered [h, c] for layer 1 followed by [h, c] for layer 2.
decoder_states_inputs = [Input(shape=(latent_dim,)) for _ in range(4)]
layer1_state_inputs = decoder_states_inputs[:2]
layer2_state_inputs = decoder_states_inputs[2:]

# Re-apply the decoder layers created for training, seeded from the placeholders.
step_outputs, state_h, state_c = out_layer1(
    decoder_inputs, initial_state=layer1_state_inputs)
step_outputs, state_h1, state_c1 = out_layer2(
    step_outputs, initial_state=layer2_state_inputs)
decoder_states = [state_h, state_c, state_h1, state_c1]

# NOTE: this rebinds `decoder_outputs`, which previously held the training output.
decoder_outputs = decoder_dense(step_outputs)

decoder_model = Model(
    inputs=[decoder_inputs] + decoder_states_inputs,
    outputs=[decoder_outputs] + decoder_states)
decoder_model.summary()
Lastly, if you are following the Keras seq2seq example, you will have to change the prediction script, because there are now several sets of states to manage instead of just the two states (h and c) of the single-layer example. Each LSTM layer contributes two states, so there will be 2 × (number of layers) state tensors in total.
# Invert the token -> index tables so that predicted indices can be mapped
# back to readable characters.
reverse_input_char_index = {
    index: char for char, index in input_token_index.items()}
reverse_target_char_index = {
    index: char for char, index in target_token_index.items()}
def decode_sequence(input_seq):
    """Greedily decode one encoded input sequence into a string.

    Uses the module-level ``encoder_model`` and ``decoder_model``. The decoder
    is expected to return the token probabilities first, followed by its
    updated state tensors; however many states it returns are fed back on the
    next step, so the function is not tied to a particular number of layers.

    Args:
        input_seq: Encoder input passed straight to ``encoder_model.predict``
            (the sampling loop assumes a batch of size 1).

    Returns:
        The decoded text, including the terminating ``'\\n'`` when one was
        sampled.
    """
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Target sequence of length 1 holding the one-hot start character.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Collect the pieces in a list and join once at the end instead of
    # repeatedly concatenating strings.
    decoded_chars = []
    decoded_length = 0
    while True:
        # First output is the token distribution; everything after it is the
        # decoder state to carry into the next step (2 tensors per layer).
        output_tokens, *states_value = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token (greedy arg-max over the last timestep).
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_chars.append(sampled_char)
        decoded_length += len(sampled_char)

        # Exit condition: either hit max length or find stop character.
        if sampled_char == '\n' or decoded_length > max_decoder_seq_length:
            break

        # Feed the sampled token back in as the next decoder input.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

    return ''.join(decoded_chars)
# Decode up to the first 100 training sequences and print each result next to
# its source and reference text. The range is capped at the number of
# available samples so that smaller datasets do not raise IndexError.
for seq_index in range(min(100, len(input_texts))):
    # Slice (rather than index) to keep the leading batch dimension of size 1.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Target sentence:', target_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
I've generalized Jeremy Wortz's awesome answer to create the model from a list, 'latent_dims', which will be 'len(latent_dims)' deep, as opposed to a fixed 2-deep.
Starting with the 'latent_dims' declaration:
# latent_dims is a list which defines the depth of the encoder/decoder, as well as how large
# the layers should be. Both the encoder loop and the decoder loop that follow walk this list
# from its last entry to its first, so sizes [a,b,c] produce a depth-3 encoder and a depth-3
# decoder whose layers are sized c, b, a in data-flow order.
latent_dims = [1024, 512, 256]
Creating the model for training:
# Encoder: one LSTM per entry of latent_dims, visiting the list from its last
# entry to its first. The (h, c) pair of every layer is collected in order.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
outputs = encoder_inputs
encoder_states = []
for j in reversed(range(len(latent_dims))):
    # Every layer except the final one (j == 0) returns its full sequence,
    # because the next layer needs a sequence as input.
    encoder_layer = LSTM(latent_dims[j], return_sequences=j > 0, return_state=True)
    outputs, h, c = encoder_layer(outputs)
    encoder_states.extend([h, c])

# Decoder: layers are created in the same size order as the encoder layers, and
# the k-th decoder layer starts from the (h, c) pair of the k-th encoder layer,
# so the state sizes always match. The layer objects are kept in
# `output_layers` so they can be called again when building the inference graph.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
outputs = decoder_inputs
output_layers = []
for layer_index, units in enumerate(reversed(latent_dims)):
    decoder_layer = LSTM(units, return_state=True, return_sequences=True)
    output_layers.append(decoder_layer)
    first_state = 2 * layer_index
    outputs, _, _ = decoder_layer(
        outputs, initial_state=encoder_states[first_state:first_state + 2])

# Softmax over the decoder vocabulary at every timestep.
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(outputs)

# The model that turns `encoder_input_data` & `decoder_input_data`
# into `decoder_target_data`.
model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=decoder_outputs)
For inference it's as follows:
# Sampling models for the n-layer network. The encoder model returns every
# layer's (h, c) pair; the decoder model takes one (h, c) pair per layer as
# extra inputs and returns the updated pairs after the token probabilities.
encoder_model = Model(encoder_inputs, encoder_states)

d_outputs = decoder_inputs
decoder_states_inputs = []
decoder_states = []
# output_layers[k] was built with reversed(latent_dims)[k] units, so pairing
# the two gives each layer state placeholders of the matching size.
for decoder_layer, units in zip(output_layers, reversed(latent_dims)):
    layer_state_inputs = [Input(shape=(units,)), Input(shape=(units,))]
    layer_result = decoder_layer(d_outputs, initial_state=layer_state_inputs)
    d_outputs = layer_result[0]
    decoder_states.extend(layer_result[1:])
    decoder_states_inputs.extend(layer_state_inputs)

decoder_outputs = decoder_dense(d_outputs)
decoder_model = Model(
    inputs=[decoder_inputs] + decoder_states_inputs,
    outputs=[decoder_outputs] + decoder_states)
And finally a few modifications to Jeremy Wortz's 'decode_sequence' function are implemented to get the following:
def decode_sequence(input_seq, encoder_model, decoder_model):
    """Greedily decode one encoded input sequence into a string.

    Args:
        input_seq: Encoder input passed to ``encoder_model.predict``
            (the sampling loop assumes a batch of size 1).
        encoder_model: Model whose ``predict`` returns the list of initial
            decoder states.
        decoder_model: Model whose ``predict`` returns the token
            probabilities followed by the updated decoder states.

    Returns:
        The decoded text, including the terminating ``'\\n'`` when one was
        sampled.
    """
    # Initial decoder states come from the encoder.
    states_value = encoder_model.predict(input_seq)

    # Length-1 target sequence holding the one-hot start character.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Pieces are gathered in a list and joined once at the end.
    pieces = []
    while True:
        predictions = decoder_model.predict([target_seq] + states_value)
        output_tokens = predictions[0]
        # Everything after the first output is state to feed back next step.
        states_value = predictions[1:]

        # Greedy sampling: take the most probable token.
        sampled_token_index = np.argmax(output_tokens[0, 0])
        sampled_char = reverse_target_char_index[sampled_token_index]
        pieces.append(sampled_char)

        # Stop on the end-of-sequence character or once the maximum length is exceeded.
        if sampled_char == '\n' or len(pieces) > max_decoder_seq_length:
            break

        # The sampled token becomes the next decoder input.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

    return "".join(pieces)