Implementing Attention in Keras


Question


I am trying to implement attention in Keras on top of a simple LSTM:

model_2_input = Input(shape=(500,))
#model_2 = Conv1D(100, 10, activation='relu')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2)

model_1_input = Input(shape=(None, 2048))
model_1 = LSTM(64, dropout_U = 0.2, dropout_W = 0.2, return_sequences=True)(model_1_input)
model_1, state_h, state_c = LSTM(16, dropout_U = 0.2, dropout_W = 0.2, return_sequences=True, return_state=True)(model_1) # dropout_U = 0.2, dropout_W = 0.2,


#print(state_c.shape)
match = dot([model_1, state_h], axes=(0, 0))
match = Activation('softmax')(match)
match = dot([match, state_h], axes=(0, 0))
print(match.shape)

merged = concatenate([model_2, match], axis=1)
print(merged.shape)
merged = Dense(4, activation='softmax')(merged)
print(merged.shape)
model = Model(inputs=[model_2_input , model_1_input], outputs=merged)
adam = Adam()
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])

I am getting the error at this line:

merged = concatenate([model_2, match], axis=1)

'Got inputs shapes: %s' % (input_shape)) ValueError: A Concatenate layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 64), (16, 1)]

The implementation is very simple: just take the dot product of the LSTM output with the hidden state and use it as a weighting function to compute the hidden state itself.

How can I resolve this error? And, more importantly, how do I get the attention concept working?


Answer 1:


You can add a Reshape layer before concatenating to ensure compatibility; see the Keras documentation for the Reshape layer. It is probably best to reshape the model_2 output, which has shape (None, 64).

EDIT:

Essentially you need to add a Reshape layer with the target shape before concatenating:

model_2 = Reshape(new_shape)(model_2)

This will return a tensor of shape (batch_size, new_shape). You can of course reshape either branch of your network; the model_2 output is used here only because it is the simpler example.
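
For instance, a minimal sketch of how Reshape is used (the target shape (64, 1) here is just an illustration, not a fix for the full problem): the target shape excludes the batch dimension, which Keras carries along implicitly.

from tensorflow.keras.layers import Reshape

# Target shape excludes the batch dimension: (None, 64) -> (None, 64, 1)
model_2 = Reshape((64, 1))(model_2)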

Having said that, it may be worth rethinking your network structure. In particular, this problem stems from the second dot layer, which gives you only 16 scalars, so it is hard to reshape one branch to match the other.

Without knowing what the model is trying to predict or what the training data looks like, it is hard to say whether the two dot products are necessary, but restructuring the attention computation, as sketched below, will potentially resolve the issue.
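
As one possible restructuring (a sketch only, assuming the tensorflow.keras API and the Keras 2 argument names dropout/recurrent_dropout in place of dropout_U/dropout_W): compute the attention scores by dotting the LSTM sequence output with its final hidden state over the feature axis rather than the batch axis, so the resulting context vector has shape (batch, 16) and can be concatenated with the (batch, 64) output of model_2.

from tensorflow.keras.layers import (Input, Dense, LSTM, Activation,
                                     Dot, Concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Dense branch
model_2_input = Input(shape=(500,))
model_2 = Dense(64, activation='sigmoid')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2)

# LSTM branch
model_1_input = Input(shape=(None, 2048))
model_1 = LSTM(64, dropout=0.2, recurrent_dropout=0.2,
               return_sequences=True)(model_1_input)
seq, state_h, state_c = LSTM(16, dropout=0.2, recurrent_dropout=0.2,
                             return_sequences=True, return_state=True)(model_1)

# Attention scores: dot (batch, T, 16) with (batch, 16) over the
# feature axis -> (batch, T), then softmax over the timesteps
scores = Dot(axes=(2, 1))([seq, state_h])
weights = Activation('softmax')(scores)

# Context vector: attention-weighted sum of the sequence outputs -> (batch, 16)
context = Dot(axes=(1, 1))([weights, seq])

# Both branches are now (batch, features), so concatenation works
merged = Concatenate(axis=1)([model_2, context])
merged = Dense(4, activation='softmax')(merged)

model = Model(inputs=[model_2_input, model_1_input], outputs=merged)
model.compile(loss='categorical_crossentropy', optimizer=Adam(),
              metrics=['accuracy'])

This keeps the intent of the original code (dot the LSTM output with the hidden state, softmax the scores, then weight the outputs) but performs both dot products batch-wise, which is what makes the final concatenation shapes compatible.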



Source: https://stackoverflow.com/questions/55165008/implementing-attention-in-keras
