Multilayer Seq2Seq model with LSTM in Keras

轻奢々 2021-02-02 01:04

I am building a seq2seq model in Keras. I built a single-layer encoder and decoder and they were working fine, but now I want to extend them to a multi-layer encoder and decoder.
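
For reference, my single-layer setup is along the lines of the official Keras seq2seq example (num_encoder_tokens, num_decoder_tokens and latent_dim defined as in that example):

    from keras.models import Model
    from keras.layers import Input, LSTM, Dense

    # Single-layer encoder: keep only the final (h, c) state
    encoder_inputs = Input(shape=(None, num_encoder_tokens))
    _, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)
    encoder_states = [state_h, state_c]

    # Single-layer decoder, seeded with the encoder state
    decoder_inputs = Input(shape=(None, num_decoder_tokens))
    decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
    decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
    decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(decoder_outputs)

    model = Model([encoder_inputs, decoder_inputs], decoder_outputs)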

2 Answers
  •  小鲜肉 (OP)
     2021-02-02 01:19

    EDIT - Updated to use the Keras functional API Model instead of the earlier RNN/LSTMCell approach.

    from keras.models import Model
    from keras.layers import Input, LSTM, Dense

    # num_encoder_tokens, num_decoder_tokens and latent_dim are defined as in the
    # Keras seq2seq example; both stacked layers use latent_dim units here
    
    encoder_inputs = Input(shape=(None, num_encoder_tokens))
    
    # Two stacked encoder LSTM layers; keep the (h, c) states of both
    e_outputs, h1, c1 = LSTM(latent_dim, return_state=True, return_sequences=True)(encoder_inputs)
    _, h2, c2 = LSTM(latent_dim, return_state=True)(e_outputs)
    encoder_states = [h1, c1, h2, c2]
    
    decoder_inputs = Input(shape=(None, num_decoder_tokens))
    
    # Two stacked decoder LSTM layers; layer i is seeded with encoder layer i's states
    out_layer1 = LSTM(latent_dim, return_sequences=True, return_state=True)
    d_outputs, dh1, dc1 = out_layer1(decoder_inputs, initial_state=[h1, c1])
    out_layer2 = LSTM(latent_dim, return_sequences=True, return_state=True)
    final, dh2, dc2 = out_layer2(d_outputs, initial_state=[h2, c2])
    decoder_dense = Dense(num_decoder_tokens, activation='softmax')
    decoder_outputs = decoder_dense(final)
    
    
    model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
    
    model.summary()
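
    The combined model can then be compiled and trained exactly as in the single-layer Keras seq2seq example; a minimal sketch, assuming encoder_input_data, decoder_input_data and decoder_target_data are prepared as in that example:

    # Same training setup as the single-layer example
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
    model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
              batch_size=64, epochs=100, validation_split=0.2)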
    

    And here is the inference setup:

    # Encoder inference model: maps an input sequence to the four state tensors
    encoder_model = Model(encoder_inputs, encoder_states)

    # Decoder inference model takes four state inputs: (h, c) for each of the two layers
    decoder_state_input_h = Input(shape=(latent_dim,))
    decoder_state_input_c = Input(shape=(latent_dim,))
    decoder_state_input_h1 = Input(shape=(latent_dim,))
    decoder_state_input_c1 = Input(shape=(latent_dim,))
    decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c,
                             decoder_state_input_h1, decoder_state_input_c1]
    d_o, state_h, state_c = out_layer1(
        decoder_inputs, initial_state=decoder_states_inputs[:2])
    d_o, state_h1, state_c1 = out_layer2(
        d_o, initial_state=decoder_states_inputs[-2:])
    decoder_states = [state_h, state_c, state_h1, state_c1]
    decoder_outputs = decoder_dense(d_o)
    decoder_model = Model(
        [decoder_inputs] + decoder_states_inputs,
        [decoder_outputs] + decoder_states)
    
    decoder_model.summary()
    

    Lastly, if you are following the Keras seq2seq example, you will have to change the prediction script, because there are more hidden states to manage than the two in the single-layer example: each layer contributes its own (h, c) pair, so with two layers there are four state tensors.

    import numpy as np

    # Reverse-lookup token index to decode sequences back to
    # something readable.
    reverse_input_char_index = dict(
        (i, char) for char, i in input_token_index.items())
    reverse_target_char_index = dict(
        (i, char) for char, i in target_token_index.items())
    
    def decode_sequence(input_seq):
        # Encode the input as state vectors.
        states_value = encoder_model.predict(input_seq)
    
        # Generate empty target sequence of length 1.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        # Populate the first character of target sequence with the start character.
        target_seq[0, 0, target_token_index['\t']] = 1.
    
        # Sampling loop for a batch of sequences
        # (to simplify, here we assume a batch of size 1).
        stop_condition = False
        decoded_sentence = ''
        while not stop_condition:
            output_tokens, h, c, h1, c1 = decoder_model.predict(
                [target_seq] + states_value)  # note the additional hidden states
    
            # Sample a token
            sampled_token_index = np.argmax(output_tokens[0, -1, :])
            sampled_char = reverse_target_char_index[sampled_token_index]
            decoded_sentence += sampled_char
    
            # Exit condition: either hit max length
            # or find stop character.
            if (sampled_char == '\n' or
               len(decoded_sentence) > max_decoder_seq_length):
                stop_condition = True
    
            # Update the target sequence (of length 1).
            target_seq = np.zeros((1, 1, num_decoder_tokens))
            target_seq[0, 0, sampled_token_index] = 1.
    
            # Update states (h, c for each of the two layers)
            states_value = [h, c, h1, c1]
    
        return decoded_sentence
    
    
    for seq_index in range(100):
        # Take one sequence (part of the training set)
        # for trying out decoding.
        input_seq = encoder_input_data[seq_index: seq_index + 1]
        decoded_sentence = decode_sequence(input_seq)
        print('-')
        print('Input sentence:', input_texts[seq_index])
        print('Target sentence:', target_texts[seq_index])
        print('Decoded sentence:', decoded_sentence)
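
    If you want more than two layers, the same pattern generalizes with a loop; a rough sketch (build_stacked_seq2seq is just an illustrative name, and all layers are assumed to use latent_dim units):

    def build_stacked_seq2seq(num_layers, latent_dim,
                              num_encoder_tokens, num_decoder_tokens):
        # Encoder: stack LSTMs and collect the (h, c) pair of every layer
        encoder_inputs = Input(shape=(None, num_encoder_tokens))
        x = encoder_inputs
        encoder_states = []
        for _ in range(num_layers):
            x, h, c = LSTM(latent_dim, return_sequences=True,
                           return_state=True)(x)
            encoder_states += [h, c]

        # Decoder: stack LSTMs, seeding layer i with encoder layer i's states
        decoder_inputs = Input(shape=(None, num_decoder_tokens))
        x = decoder_inputs
        for i in range(num_layers):
            x, _, _ = LSTM(latent_dim, return_sequences=True,
                           return_state=True)(x, initial_state=encoder_states[2 * i:2 * i + 2])

        decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(x)
        return Model([encoder_inputs, decoder_inputs], decoder_outputs)

    For the inference models you would also keep references to the decoder LSTM layers and the Dense layer (as done with out_layer1, out_layer2 and decoder_dense above) and create one pair of state Inputs per layer, exactly like the two-layer version.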
    
