Question
I am trying to implement attention in Keras over a simple LSTM:
model_2_input = Input(shape=(500,))
#model_2 = Conv1D(100, 10, activation='relu')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2)
model_1_input = Input(shape=(None, 2048))
model_1 = LSTM(64, dropout_U = 0.2, dropout_W = 0.2, return_sequences=True)(model_1_input)
model_1, state_h, state_c = LSTM(16, dropout_U = 0.2, dropout_W = 0.2, return_sequences=True, return_state=True)(model_1) # dropout_U = 0.2, dropout_W = 0.2,
#print(state_c.shape)
match = dot([model_1, state_h], axes=(0, 0))
match = Activation('softmax')(match)
match = dot([match, state_h], axes=(0, 0))
print(match.shape)
merged = concatenate([model_2, match], axis=1)
print(merged.shape)
merged = Dense(4, activation='softmax')(merged)
print(merged.shape)
model = Model(inputs=[model_2_input , model_1_input], outputs=merged)
adam = Adam()
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
I am getting an error on this line:
merged = concatenate([model_2, match], axis=1)
'Got inputs shapes: %s' % (input_shape))
ValueError: A Concatenate layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 64), (16, 1)]
The idea is very simple: take the dot product of the LSTM output with the hidden state, and use the result as a weighting function to recompute the hidden state itself.
How can I resolve the error and, more importantly, how do I get the attention mechanism working?
Answer 1:
You can add a Reshape layer before concatenating so that the shapes are compatible.
See the Keras documentation for the Reshape layer.
It is probably best to reshape the model_2 output, which has shape (None, 64).
EDIT:
Essentially you need to add a Reshape layer with the target shape before concatenating:
model_2 = Reshape(new_shape)(model_2)
This returns a tensor of shape (batch_size,) + new_shape.
You can of course reshape either branch of your network; I am using the model_2 output only because it is the simpler example.
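To make the shape bookkeeping concrete, here is a small NumPy sketch of what a Reshape layer does to one branch. It is not Keras code: `reshape_keep_batch` and `batch_size = 3` are hypothetical names and values chosen for illustration, and the helper only mimics the "keep the batch axis, reshape the rest" behaviour.

```python
import numpy as np

def reshape_keep_batch(x, new_shape):
    """Mimic a Reshape layer: keep axis 0 (the batch) and reshape the rest."""
    return x.reshape((x.shape[0],) + tuple(new_shape))

batch_size = 3  # hypothetical batch size, stands in for None
model_2_out = np.zeros((batch_size, 64))  # same shape as the model_2 branch

# 64 values per sample can become (64, 1), (8, 8), and so on.
reshaped = reshape_keep_batch(model_2_out, (64, 1))
print(reshaped.shape)

# 64 values per sample cannot become (16, 1): the element counts differ.
try:
    reshape_keep_batch(model_2_out, (16, 1))
    mismatch_ok = True
except ValueError:
    mismatch_ok = False
print(mismatch_ok)
```

The second call fails because a reshape can only rearrange the values of each sample, never add or drop them, which is why a branch of shape (None, 64) cannot simply be reshaped to line up with (16, 1).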
That said, it may be worth rethinking your network structure. In particular, this problem stems from the second dot layer, which gives you only 16 scalars, so it is hard to reshape the two branches to match.
Without knowing what the model is trying to predict or what the training data looks like, it is hard to say whether two dot layers are necessary, but restructuring the network may well solve this issue.
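One possible restructuring, sketched in NumPy so the shapes are easy to follow: contract over the feature axis instead of the batch axis, so the batch dimension survives both dot products. This is an assumption about the intended design, not the asker's code. In particular, the second product here weights the LSTM outputs rather than state_h, which is the usual dot-product attention but differs from the question. The names `dot_attention` and `dense_branch` and the sizes are hypothetical. In Keras, the two contractions would presumably correspond to dot([model_1, state_h], axes=(2, 1)) and dot([weights, model_1], axes=(1, 1)), which I have not run against the asker's Keras version.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def dot_attention(outputs, state_h):
    """outputs: (batch, time, units), state_h: (batch, units)."""
    # One score per time step: dot each output vector with the hidden state.
    scores = np.einsum('btu,bu->bt', outputs, state_h)
    # Normalise over time so the weights of each sample sum to 1.
    weights = softmax(scores, axis=1)
    # Weighted sum of the outputs over time gives one context vector per sample.
    context = np.einsum('bt,btu->bu', weights, outputs)
    return weights, context

rng = np.random.default_rng(0)
batch, time, units = 4, 7, 16              # hypothetical sizes
outputs = rng.normal(size=(batch, time, units))
state_h = outputs[:, -1, :]                # stand-in for the LSTM's state_h
dense_branch = rng.normal(size=(batch, 64))  # stand-in for the model_2 branch

weights, context = dot_attention(outputs, state_h)
merged = np.concatenate([dense_branch, context], axis=1)
print(weights.shape, context.shape, merged.shape)
```

With the batch axis preserved, the context vector has shape (batch, 16) and concatenates with the (batch, 64) branch along axis 1 without any reshaping.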
Source: https://stackoverflow.com/questions/55165008/implementing-attention-in-keras