Question
I'm using Keras 1.0.1. I'm trying to add an attention layer on top of an LSTM. This is what I have so far, but it doesn't work:
input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1)(lstm))
att = Reshape((-1, input_length))(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
merge = Merge([att, lstm], "mul")
hid = Merge("sum")(merge)
last = Dense(self.HID_DIM, activation="relu")(hid)
The network should apply an LSTM (here a GRU) over the input sequence. Each hidden state should then be fed into a fully connected layer, over which a softmax is applied. The softmax output is replicated for each hidden dimension and multiplied elementwise by the LSTM hidden states. Finally, the resulting vectors should be averaged over the time dimension.
EDIT: This compiles, but I'm not sure if it does what I think it should do.
input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length=input_length, return_sequences=True)(input_)  # (batch, input_length, HID_DIM)
att = TimeDistributed(Dense(1))(lstm)                   # one score per timestep: (batch, input_length, 1)
att = Flatten()(att)                                    # (batch, input_length)
att = Activation(activation="softmax")(att)             # attention weights over timesteps
att = RepeatVector(self.HID_DIM)(att)                   # (batch, HID_DIM, input_length)
att = Permute((2, 1))(att)                              # (batch, input_length, HID_DIM)
mer = merge([att, lstm], "mul")                         # weight each hidden state
hid = AveragePooling1D(pool_length=input_length)(mer)   # average over time: (batch, 1, HID_DIM)
hid = Flatten()(hid)                                    # (batch, HID_DIM)
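A minimal way to check whether the shapes come out as intended, assuming the input_ and hid tensors defined above (this is just a sanity-check sketch, not part of the attention mechanism itself):

from keras.models import Model

# Wrap the graph above in a Model and print the layer-by-layer output shapes
# (Keras 1.x uses the `input`/`output` keyword names).
model = Model(input=input_, output=hid)
model.summary()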
Answer 1:
Here is an implementation of an attention LSTM with Keras, and an example of its instantiation. I haven't tried it myself, though.
Answer 2:
The first piece of code you have shared is incorrect. The second piece looks correct, except for one thing: do not use TimeDistributed, as the weights will be the same. Use a regular Dense layer with a non-linear activation.
input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = Dense(1, activation='tanh')(lstm)   # per-timestep attention score
att = Flatten()(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
att = Permute((2,1))(att)
mer = merge([att, lstm], "mul")
Now you have the weight-adjusted states. How you use them is up to you. Most versions of attention I have seen simply sum these over the time axis and then use the result as the context vector.
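As a minimal sketch of that last step (assuming HID_DIM is the GRU hidden size, i.e. self.HID_DIM in the code above), the sum over the time axis can be done with a Lambda layer:

from keras import backend as K
from keras.layers import Lambda

# mer has shape (batch, input_length, HID_DIM); summing over the time axis
# gives a single context vector of shape (batch, HID_DIM).
context = Lambda(lambda x: K.sum(x, axis=1), output_shape=(HID_DIM,))(mer)

The context vector can then be fed into whatever comes next in your model, e.g. a Dense layer as in the first snippet of the question.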
Source: https://stackoverflow.com/questions/36812351/keras-attention-layer-over-lstm