How to add an attention mechanism in Keras?

南方客 2020-12-02 08:49

I'm currently using this code, which I got from a discussion on GitHub. Here's the code of the attention mechanism:

_input = Input(shape=[max_length], dtype         

4 Answers
  • 2020-12-02 08:59

    An attention mechanism lets the model focus on different parts of the sentence:

    activations = LSTM(units, return_sequences=True)(embedded)

    It determines the contribution of each hidden state of that sentence by:

    1. Computing a scalar score for each hidden state: attention = Dense(1, activation='tanh')(activations)
    2. Assigning weights to the different states: attention = Activation('softmax')(attention)
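
    One thing to watch in step 2: Activation('softmax') normalizes over the last axis. The Dense(1) scores have shape (batch_size, max_length, 1), so without a Flatten() in between, the softmax runs over an axis of size 1 and every weight comes out as exactly 1. The NumPy sketch below illustrates both cases; the sizes and random scores are made-up stand-ins, and softmax here is a small helper written for the example, not a Keras function.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis (helper for this example)
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch_size, max_length = 2, 5
rng = np.random.default_rng(0)
# stand-in for the Dense(1, activation='tanh') output: one score per time step
scores = np.tanh(rng.normal(size=(batch_size, max_length, 1)))

# softmax over the last axis, which has size 1: every "weight" is exactly 1
wrong = softmax(scores, axis=-1)

# flatten to (batch_size, max_length) first, then softmax over the time steps
right = softmax(scores.reshape(batch_size, max_length), axis=-1)

print(wrong[..., 0])       # all ones
print(right.sum(axis=-1))  # each row sums to 1
```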

    Finally, it weights the different states:

    sent_representation = multiply([activations, attention])  # Keras 2 replacement for merge(..., mode='mul')

    I don't quite understand this part:

    sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(units,))(sent_representation)
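
    The Lambda line is the pooling step. After the multiplication, sent_representation has shape (batch_size, max_length, units), and K.sum(xin, axis=-2) adds the weighted hidden states over the second-to-last axis, i.e. the time axis, leaving one vector of shape (batch_size, units) per sentence: the attention-weighted sum of the hidden states. A NumPy sketch of that reduction, with arbitrary sizes and random stand-in tensors:

```python
import numpy as np

batch_size, max_length, units = 2, 4, 3
rng = np.random.default_rng(1)

# stand-ins for the LSTM outputs (return_sequences=True) and the softmax weights
activations = rng.normal(size=(batch_size, max_length, units))
weights = rng.random(size=(batch_size, max_length))
weights /= weights.sum(axis=1, keepdims=True)  # each row sums to 1, like a softmax

# the multiply step: each hidden state scaled by its weight
weighted = activations * weights[:, :, None]   # (batch_size, max_length, units)

# the Lambda step: K.sum(xin, axis=-2) collapses the time axis
sent_representation = weighted.sum(axis=-2)    # (batch_size, units)

# the same thing written as an explicit weighted sum over time steps
expected = np.einsum('bt,btu->bu', weights, activations)
print(sent_representation.shape)  # (2, 3)
```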

    To understand more, you can refer to this and this; this one also gives a good implementation. See if you can understand more on your own.

  • 2020-12-02 09:00

    I think you can try the following code, which adds a self-attention layer from the keras-self-attention package to an LSTM network:

        from keras.layers import Input, Embedding, LSTM, Flatten, Dense
        from keras.models import Model
        from keras.optimizers import Adam
        from keras_self_attention import SeqSelfAttention  # pip install keras-self-attention
    
        # vocab_size, EMBEDDING_DIM, embedding_matrix, MAX_SEQUENCE_LENGTH, num_lstm
        # and the train/validation arrays are assumed to be defined elsewhere
        inputs = Input(shape=(MAX_SEQUENCE_LENGTH,))
        embedding = Embedding(vocab_size, EMBEDDING_DIM, weights=[embedding_matrix],
                              input_length=MAX_SEQUENCE_LENGTH, trainable=False)(inputs)
        # return_sequences=True keeps one hidden state per time step for the attention layer
        lstm = LSTM(num_lstm, return_sequences=True)(embedding)
        attn = SeqSelfAttention(attention_activation='sigmoid')(lstm)
        flat = Flatten()(attn)
        dense = Dense(32, activation='relu')(flat)
        outputs = Dense(3, activation='sigmoid')(dense)
        model = Model(inputs=inputs, outputs=outputs)
        model.compile(loss='binary_crossentropy', optimizer=Adam(0.001), metrics=['accuracy'])
        model.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_val, y_val), shuffle=True)
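
    A self-attention layer of this kind returns one output vector per time step rather than a single sentence vector, so its output should keep the LSTM's (batch_size, max_length, num_lstm) shape; that is why this answer puts Flatten() before the Dense layers. The NumPy sketch below shows the shapes using plain scaled dot-product self-attention with no trainable weights. That is a deliberate simplification and not necessarily the scoring function SeqSelfAttention uses internally; the sizes are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_self_attention(h):
    """h: (batch, steps, features) -> (batch, steps, features).

    Simplified self-attention: every time step attends to every step
    of the same sequence, with no learned projections.
    """
    d = h.shape[-1]
    scores = h @ h.transpose(0, 2, 1) / np.sqrt(d)  # (batch, steps, steps)
    weights = softmax(scores, axis=-1)               # each row sums to 1
    return weights @ h, weights

batch_size, max_length, num_lstm = 2, 6, 8
lstm_out = np.random.default_rng(3).normal(size=(batch_size, max_length, num_lstm))
attn, weights = dot_product_self_attention(lstm_out)

# Flatten() turns (batch, steps, features) into (batch, steps * features)
flat = attn.reshape(batch_size, -1)
print(attn.shape, flat.shape)  # (2, 6, 8) (2, 48)
```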
    
  • 2020-12-02 09:08

    If you want attention along the time dimension, this part of your code looks correct to me:

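    activations = LSTM(units, return_sequences=True)(embedded)  # (batch, max_length, units)
    
    # compute importance for each step
    attention = Dense(1, activation='tanh')(activations)  # (batch, max_length, 1)
    attention = Flatten()(attention)                      # (batch, max_length)
    attention = Activation('softmax')(attention)          # weights over the time steps
    attention = RepeatVector(units)(attention)            # (batch, units, max_length)
    attention = Permute([2, 1])(attention)                # (batch, max_length, units)
    
    # Keras 2 replaces merge(..., mode='mul') with multiply()
    sent_representation = multiply([activations, attention])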
    
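    RepeatVector(units) followed by Permute([2, 1]) only reshapes the (batch_size, max_length) weights to the shape of activations, (batch_size, max_length, units), so that the element-wise multiply scales every unit of hidden state t by the same weight. A NumPy sketch that mimics the two layers with np.repeat and np.transpose (sizes and data are illustrative assumptions) and compares the result with plain broadcasting:

```python
import numpy as np

batch_size, max_length, units = 2, 4, 3
rng = np.random.default_rng(2)
activations = rng.normal(size=(batch_size, max_length, units))
attention = rng.random(size=(batch_size, max_length))  # stand-in softmax weights

# RepeatVector(units): (batch, max_length) -> (batch, units, max_length)
repeated = np.repeat(attention[:, None, :], units, axis=1)
# Permute([2, 1]): swap the two non-batch axes -> (batch, max_length, units)
permuted = np.transpose(repeated, (0, 2, 1))

# multiply([activations, attention]) is then a plain element-wise product
sent_representation = activations * permuted

# equivalent: broadcast the weights over the units axis
broadcast = activations * attention[:, :, None]
```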

    You've worked out the attention vector of shape (batch_size, max_length):

    attention = Activation('softmax')(attention)
    

    I've never seen this code before, so I can't say whether it is actually correct:

    K.sum(xin, axis=-2)
    

    Further reading (you might have a look):

    • https://github.com/philipperemy/keras-visualize-activations

    • https://github.com/philipperemy/keras-attention-mechanism

  • 2020-12-02 09:09

    Recently I was working on applying an attention mechanism to a dense layer; here is one sample implementation:

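    from keras import regularizers
    from keras.layers import Input, Dense, multiply
    from keras.models import Model
    
    def build_model():
        # train_data_X / train_data_Y_ are the poster's own training arrays
        input_dims = train_data_X.shape[1]
        inputs = Input(shape=(input_dims,))
        dense1800 = Dense(1800, activation='relu', kernel_regularizer=regularizers.l2(0.01))(inputs)
        # one sigmoid weight per hidden unit
        attention_probs = Dense(1800, activation='sigmoid', name='attention_probs')(dense1800)
        # element-wise product: scale each hidden unit by its attention weight
        attention_mul = multiply([dense1800, attention_probs], name='attention_mul')
        dense7 = Dense(7, kernel_regularizer=regularizers.l2(0.01), activation='softmax')(attention_mul)
        model = Model(inputs=inputs, outputs=dense7)  # Keras 2 uses inputs/outputs, not input/output
        model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        return model
    
    model = build_model()
    model.summary()  # summary() is a method; print(model.summary) only prints the bound method
    
    model.fit(train_data_X, train_data_Y_, epochs=20, validation_split=0.2, batch_size=600, shuffle=True, verbose=1)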
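
    The attention in this answer is an element-wise gate rather than a softmax over time steps: Dense(1800, activation='sigmoid') produces an independent weight in (0, 1) for each hidden unit, and multiply rescales each unit by its own weight, so the weights do not need to sum to 1. A NumPy sketch of that forward pass with random stand-in weights (a 6-unit layer instead of 1800 and biases omitted, both simplifying assumptions):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(4)
batch_size, hidden = 3, 6
# stand-in for the ReLU output of the first Dense layer
dense_out = np.maximum(rng.normal(size=(batch_size, hidden)), 0.0)
# stand-in kernel of the 'attention_probs' Dense layer (bias omitted)
w_att = rng.normal(size=(hidden, hidden))

attention_probs = sigmoid(dense_out @ w_att)  # one gate per unit, each in (0, 1)
attention_mul = dense_out * attention_probs   # what multiply([...]) computes

print(attention_probs.sum(axis=1))  # generally not 1: the gates are independent
```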
    
