Multilayer Seq2Seq model with LSTM in Keras

轻奢々 2021-02-02 01:04

I'm building a seq2seq model in Keras. I built a single-layer encoder and decoder, and they work fine. Now I want to extend them to a multi-layer encoder and decoder.

2 answers
  •  鱼传尺愫
    2021-02-02 01:22

    I've generalized Jeremy Wortz's excellent answer so that the model is built from a list, 'latent_dims', and is 'len(latent_dims)' layers deep instead of a fixed two layers.

    Starting with the 'latent_dims' declaration:

    # latent_dims defines both the depth of the encoder/decoder and the size of each layer.
    # Both loops below walk the list in reverse, so [a, b, c] produces a depth-3 encoder and a
    # depth-3 decoder whose layers (input side first) have sizes [c, b, a]; the final (h, c)
    # state of the i-th encoder layer initializes the i-th decoder layer.
    latent_dims = [1024, 512,  256]
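The index arithmetic in the two training loops is easy to misread, so here is a pure-Python sketch (no Keras needed) that replays the same indexing. It reports the layer sizes each loop creates, which encoder layer skips return_sequences, and the sizes of the (h, c) pair handed to each decoder layer. The helper name `layer_plan` is hypothetical and not part of the answer's code.

```python
def layer_plan(latent_dims):
    """Replay the encoder/decoder loop indexing (hypothetical helper)."""
    n = len(latent_dims)
    encoder_sizes, return_seqs, encoder_state_sizes = [], [], []
    for j in range(n)[::-1]:                  # same iteration order as the encoder loop
        encoder_sizes.append(latent_dims[j])
        return_seqs.append(bool(j))           # only j == 0 (the last layer built) is False
        encoder_state_sizes += [latent_dims[j], latent_dims[j]]  # (h, c) of this layer
    decoder_sizes, seed_sizes = [], []
    for j in range(n):                        # same iteration order as the decoder loop
        decoder_sizes.append(latent_dims[n - j - 1])
        seed_sizes.append(encoder_state_sizes[2 * j:2 * (j + 1)])
    return encoder_sizes, return_seqs, decoder_sizes, seed_sizes


print(layer_plan([1024, 512, 256]))
```

For `[1024, 512, 256]` both stacks come out as 256, 512, 1024 units from the input side, and every decoder layer receives a state pair of its own size, which is what Keras requires of `initial_state`.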
    

    Creating the model for training:

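    from keras.layers import Input, LSTM, Dense
    from keras.models import Model

    # Define an input sequence and process it with a len(latent_dims)-layer-deep encoder.
    # num_encoder_tokens / num_decoder_tokens are the one-hot vocabulary sizes, as in the single-layer version.
    encoder_inputs = Input(shape=(None, num_encoder_tokens))

    outputs = encoder_inputs
    encoder_states = []
    # j runs len(latent_dims)-1 ... 0. Every layer except the last (j == 0) returns the full
    # sequence so the next layer can consume it; the last layer only needs its final states.
    for j in range(len(latent_dims))[::-1]:
        outputs, h, c = LSTM(latent_dims[j], return_state=True, return_sequences=bool(j))(outputs)
        encoder_states += [h, c]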
    
    # Set up the decoder. Decoder layer j has the same size as the j-th encoder layer built above
    # and is initialized with that layer's final (h, c), i.e. encoder_states[2*j:2*(j+1)].
    decoder_inputs = Input(shape=(None, num_decoder_tokens))

    outputs = decoder_inputs
    # Keep references to the decoder LSTMs so the inference model can reuse their trained weights.
    output_layers = []
    for j in range(len(latent_dims)):
        output_layers.append(
            LSTM(latent_dims[len(latent_dims) - j - 1], return_sequences=True, return_state=True)
        )
        outputs, dh, dc = output_layers[-1](outputs, initial_state=encoder_states[2*j:2*(j+1)])
    
    
    decoder_dense = Dense(num_decoder_tokens, activation='softmax')
    decoder_outputs = decoder_dense(outputs)
    
    # Define the model that will turn
    # `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
    model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
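The training model uses teacher forcing: 'decoder_target_data' is 'decoder_input_data' shifted one step to the left. A minimal NumPy sketch of building the three one-hot arrays follows, assuming character-level data with '\t' as the start token and '\n' as the end token (the convention the 'decode_sequence' function in this answer relies on). The helper `one_hot_pairs` is hypothetical, and padding positions are left as all-zero vectors.

```python
import numpy as np


def one_hot_pairs(input_texts, target_texts):
    """Build teacher-forcing arrays for a character-level seq2seq model (sketch)."""
    # Wrap each target with the start ('\t') and end ('\n') markers.
    target_texts = ['\t' + t + '\n' for t in target_texts]
    input_chars = sorted(set(''.join(input_texts)))
    target_chars = sorted(set(''.join(target_texts)))
    input_token_index = {ch: i for i, ch in enumerate(input_chars)}
    target_token_index = {ch: i for i, ch in enumerate(target_chars)}
    max_enc = max(len(t) for t in input_texts)
    max_dec = max(len(t) for t in target_texts)

    enc = np.zeros((len(input_texts), max_enc, len(input_chars)), dtype='float32')
    dec_in = np.zeros((len(input_texts), max_dec, len(target_chars)), dtype='float32')
    dec_out = np.zeros_like(dec_in)
    for i, (src, tgt) in enumerate(zip(input_texts, target_texts)):
        for t, ch in enumerate(src):
            enc[i, t, input_token_index[ch]] = 1.0
        for t, ch in enumerate(tgt):
            dec_in[i, t, target_token_index[ch]] = 1.0
            if t > 0:
                # The target at step t-1 is the decoder input at step t.
                dec_out[i, t - 1, target_token_index[ch]] = 1.0
    return enc, dec_in, dec_out, target_token_index
```

With arrays like these, training is the usual `model.compile(optimizer='rmsprop', loss='categorical_crossentropy')` followed by `model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=..., epochs=...)`; the multi-layer model does not change this step.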
    

    For inference it's as follows:

    # Define sampling models (modified for n-layer deep network)
    encoder_model = Model(encoder_inputs, encoder_states)
    
    
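    d_outputs = decoder_inputs
    decoder_states_inputs = []
    decoder_states = []
    # Walk the decoder layers in the order they were built (output_layers[0], [1], ...);
    # output_layers[len(latent_dims)-j-1] has latent_dims[j] units, so the state inputs match.
    for j in range(len(latent_dims))[::-1]:
        # One placeholder each for this layer's previous h and c.
        current_state_inputs = [Input(shape=(latent_dims[j],)) for _ in range(2)]

        temp = output_layers[len(latent_dims)-j-1](d_outputs, initial_state=current_state_inputs)

        # temp is [sequence_output, h, c]; keep h and c as this layer's new states.
        d_outputs, cur_states = temp[0], temp[1:]

        decoder_states += cur_states
        decoder_states_inputs += current_state_inputs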
    
    decoder_outputs = decoder_dense(d_outputs)
    decoder_model = Model(
        [decoder_inputs] + decoder_states_inputs,
        [decoder_outputs] + decoder_states)
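'encoder_model' returns, and 'decoder_model' expects after the token input, a flat list of 2 * len(latent_dims) arrays: one (h, c) pair per layer, starting with the layer built first (latent_dims[-1] units). If the list is ever reordered, Keras fails with a shape error that can be hard to trace, so a small NumPy check of that contract may help when debugging. Both helpers below are hypothetical.

```python
import numpy as np


def expected_state_shapes(latent_dims, batch_size=1):
    """Shapes of the flat state list, in the order encoder_model returns them (sketch)."""
    shapes = []
    for units in reversed(latent_dims):   # layers were built from latent_dims[-1] to latent_dims[0]
        shapes += [(batch_size, units), (batch_size, units)]  # (h, c) for this layer
    return shapes


def check_states(states, latent_dims):
    """Raise ValueError if `states` does not match the layout decoder_model expects."""
    batch_size = states[0].shape[0] if states else 1
    expected = expected_state_shapes(latent_dims, batch_size)
    got = [tuple(s.shape) for s in states]
    if got != expected:
        raise ValueError(f"state shapes {got} != expected {expected}")
```

For example, `check_states(encoder_model.predict(input_seq), latent_dims)` could be called once before entering the sampling loop.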
    

    Finally, Jeremy Wortz's 'decode_sequence' function needs a few modifications so that the full list of per-layer states is fed back into the decoder at every step:

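    def decode_sequence(input_seq, encoder_model, decoder_model):
        # Relies on globals from the training script: num_decoder_tokens,
        # target_token_index, reverse_target_char_index and max_decoder_seq_length.
        # Encode the input as state vectors: a flat list [h, c, h, c, ...], one pair per layer.
        states_value = list(encoder_model.predict(input_seq, verbose=0))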
    
        # Generate empty target sequence of length 1.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        # Populate the first character of target sequence with the start character.
        target_seq[0, 0, target_token_index['\t']] = 1.
    
        # Sampling loop for a batch of sequences
        # (to simplify, here we assume a batch of size 1).
        stop_condition = False
        decoded_sentence = []  # Appending to a list and calling "".join() at the end is faster than repeated string concatenation
        while not stop_condition:
            # One decoder step: outputs are [token probabilities, h, c, h, c, ...].
            to_split = decoder_model.predict([target_seq] + states_value, verbose=0)

            output_tokens, states_value = to_split[0], list(to_split[1:])
    
            # Sample a token
            sampled_token_index = np.argmax(output_tokens[0, 0])
            sampled_char = reverse_target_char_index[sampled_token_index]
            decoded_sentence.append(sampled_char)
    
            # Exit condition: either hit max length
            # or find stop character.
            if sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length:
                stop_condition = True
    
            # Update the target sequence (of length 1).
            target_seq = np.zeros((1, 1, num_decoder_tokens))
            target_seq[0, 0, sampled_token_index] = 1.
    
        return "".join(decoded_sentence)
    
    
