How to extract the cell state of an LSTM model through model.fit()?

徘徊边缘 submitted on 2021-01-29 05:47:37

Question


My LSTM model looks like this, and I would like to get state_c:

from tensorflow.keras.layers import Input, LSTM, Dense, Activation
from tensorflow.keras.models import Model

def _get_model(input_shape, latent_dim, num_classes):
  inputs = Input(shape=input_shape)
  lstm_lyr,state_h,state_c = LSTM(latent_dim,dropout=0.1,return_state = True)(inputs)
  fc_lyr = Dense(num_classes)(lstm_lyr)
  soft_lyr = Activation('relu')(fc_lyr)
  model = Model(inputs, [soft_lyr,state_c])
  model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
  return model

model =_get_model((n_steps_in, n_features),latent_dim ,n_steps_out)
history = model.fit(X_train,Y_train)

But I cannot extract state_c from the history. How can I return it?


Answer 1:


I am unsure what you mean by "how to get state_c", because your LSTM layer already returns state_c when you set return_state=True. I assume you are trying to train a multi-output model here. Currently your model is defined with two outputs ([soft_lyr, state_c]), but you pass only a single target (Y_train) to fit().
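As a quick check of what return_state=True gives you, here is a minimal sketch (assuming TensorFlow 2.x; the sizes of 15 time steps, 5 features and 7 units are arbitrary toy values). It wraps the LSTM in a small "probe" model so all three returned tensors can be inspected. With the default return_sequences=False, the layer returns the last output, the final hidden state h and the final cell state c, and the last output is the same as h.

```python
import numpy as np
from tensorflow.keras import layers, Model

# Toy input: sequences of 15 time steps with 5 features each (assumed sizes)
inputs = layers.Input(shape=(15, 5))

# return_state=True makes the layer return three tensors:
# the last output, the final hidden state h and the final cell state c
lstm_out, state_h, state_c = layers.LSTM(7, return_state=True)(inputs)

# A probe model exposing all three tensors as outputs (no compile needed for predict)
probe = Model(inputs, [lstm_out, state_h, state_c])

x = np.random.random((2, 15, 5)).astype("float32")
out, h, c = probe.predict(x, verbose=0)

print(out.shape, h.shape, c.shape)  # (2, 7) (2, 7) (2, 7)
# With return_sequences=False the layer output is the final hidden state
print(np.allclose(out, h))          # True
```

So state_c is already a tensor in your graph; the only question is how you get its values out after training.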

Here is how you work with multi-output models.

import numpy as np
from tensorflow.keras import layers, Model

def _get_model(input_shape, latent_dim, num_classes):
    inputs = layers.Input(shape=input_shape)
    lstm_lyr,state_h,state_c = layers.LSTM(latent_dim,dropout=0.1,return_state = True)(inputs)
    fc_lyr = layers.Dense(num_classes)(lstm_lyr)
    soft_lyr = layers.Activation('relu')(fc_lyr)
    model = Model(inputs, [soft_lyr,state_c])   #<------- One input, 2 outputs
    model.compile(optimizer='adam', loss='mse')
    return model


#Dummy data
X = np.random.random((100,15,5))
y1 = np.random.random((100,4))
y2 = np.random.random((100,7))

model = _get_model((15, 5), 7, 4)
model.fit(X, [y1, y2], epochs=4)  #<--------- One input, 2 outputs
Epoch 1/4
4/4 [==============================] - 2s 6ms/step - loss: 0.6978 - activation_9_loss: 0.2388 - lstm_9_loss: 0.4591
Epoch 2/4
4/4 [==============================] - 0s 6ms/step - loss: 0.6615 - activation_9_loss: 0.2367 - lstm_9_loss: 0.4248
Epoch 3/4
4/4 [==============================] - 0s 7ms/step - loss: 0.6349 - activation_9_loss: 0.2392 - lstm_9_loss: 0.3957
Epoch 4/4
4/4 [==============================] - 0s 8ms/step - loss: 0.6053 - activation_9_loss: 0.2392 - lstm_9_loss: 0.3661
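Note that fit() cannot hand you state_c itself: the History object it returns only maps loss and metric names to per-epoch values. To get the cell state you call predict() on the trained model, which returns one array per output in the order given to Model(...). The sketch below (assuming TensorFlow 2.x and the same dummy shapes as above) shows this for the two-output model. It also shows an alternative that is not part of the original answer: train only the prediction output and read state_c from a second Model built on the same layers, so the cell state is not pulled toward a dummy target such as y2. The names train_model and state_model are illustrative only.

```python
import numpy as np
from tensorflow.keras import layers, Model

# Dummy data with the same shapes as above
X = np.random.random((100, 15, 5))
y1 = np.random.random((100, 4))
y2 = np.random.random((100, 7))

# (a) The two-output model from the answer, trained against [y1, y2]
inputs = layers.Input(shape=(15, 5))
lstm_out, state_h, state_c = layers.LSTM(7, dropout=0.1, return_state=True)(inputs)
preds = layers.Activation('relu')(layers.Dense(4)(lstm_out))
model = Model(inputs, [preds, state_c])
model.compile(optimizer='adam', loss='mse')
history = model.fit(X, [y1, y2], epochs=4, verbose=0)

# history.history only holds per-epoch loss/metric values, never layer outputs
print(sorted(history.history))

# predict() returns one array per output, in the order passed to Model(...)
y_pred, cell_state = model.predict(X, verbose=0)
print(y_pred.shape, cell_state.shape)  # (100, 4) (100, 7)

# (b) Alternative: only the prediction gets a loss; a second Model built on
# the same layers shares the trained LSTM weights and outputs state_c
inputs_b = layers.Input(shape=(15, 5))
lstm_b, _, state_c_b = layers.LSTM(7, dropout=0.1, return_state=True)(inputs_b)
preds_b = layers.Activation('relu')(layers.Dense(4)(lstm_b))
train_model = Model(inputs_b, preds_b)    # trained on y1 only
state_model = Model(inputs_b, state_c_b)  # reuses the same LSTM layer
train_model.compile(optimizer='adam', loss='mse')
train_model.fit(X, y1, epochs=4, verbose=0)

cell_state_b = state_model.predict(X, verbose=0)
print(cell_state_b.shape)  # (100, 7)
```

Dropout is only active during training, so the predict() calls above return deterministic cell states for a given set of weights.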


Source: https://stackoverflow.com/questions/65654159/how-to-extract-cell-state-of-lstm-model-through-model-fit
