How to extract the cell state and hidden state from an RNN model in TensorFlow?

不思量自难忘° 2021-02-09 09:51

I am new to TensorFlow and have difficulties understanding the RNN module. I am trying to extract hidden/cell states from an LSTM. For my code, I am using the implementation fr

3 Answers
  • 2021-02-09 10:32

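    You can fetch the values of the states in the same sess.run call that already collects the accuracy.

    Something like pred_val, states_val, acc = sess.run([pred, states, accuracy], feed_dict={x: batch_x, y: batch_y}) should work. Note that the fetches have to be passed as a single list, and the results are bound to new names so that the pred and states tensors are not overwritten.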

  • 2021-02-09 10:44

    One comment about your assumption: "states" holds only the "hidden state" and "memory cell" values from the last time step.

    "outputs" contains the "hidden state" from every time step (its size is [batch_size, seq_len, hidden_size]), so I assume you want the "outputs" variable, not "states". See the documentation.

  • 2021-02-09 10:47

    I have to disagree with the answer of user3480922. For the code:

    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
    

    To extract the hidden state for each time_step in a prediction, you have to use outputs, because outputs holds the hidden state value for every time_step. However, I am not sure whether there is any way to store the cell state values for each time_step as well, because the states tuple provides the cell state values only for the last time_step.

    For example, in the following sample with 5 time_steps, outputs[time_step,:,:] for time_step = 0,...,4 holds the hidden state values for that time_step, whereas h in the states tuple holds the hidden state values only for time_step=4. c in the states tuple holds the cell state values at time_step=4.

      outputs = [[[ 0.0589103 -0.06925126 -0.01531546 0.06108122]
      [ 0.00861215 0.06067181 0.03790079 -0.04296958]
      [ 0.00597713 0.03916606 0.02355802 -0.0277683 ]]
    
      [[ 0.06252582 -0.07336216 -0.01607122 0.05024602]
      [ 0.05464711 0.03219429 0.06635305 0.00753127]
      [ 0.05385715 0.01259535 0.0524035 0.01696803]]
    
      [[ 0.0853352 -0.06414541 0.02524283 0.05798233]
      [ 0.10790729 -0.05008117 0.03003334 0.07391824]
      [ 0.10205664 -0.04479517 0.03844892 0.0693808 ]]
    
      [[ 0.10556188 0.0516542 0.09162509 -0.02726674]
      [ 0.11425048 -0.00211394 0.06025286 0.03575509]
      [ 0.11338984 0.02839304 0.08105748 0.01564003]]
    
      [[ 0.10072514 0.14767936 0.12387902 -0.07391471]
      [ 0.10510238 0.06321315 0.08100517 -0.00940042]
      [ 0.10553667 0.0984127 0.10094948 -0.02546882]]]

      states = LSTMStateTuple(c=array([[ 0.23870754, 0.24315512, 0.20842518, -0.12798975],
      [ 0.23749796, 0.10797793, 0.14181322, -0.01695861],
      [ 0.2413336 , 0.16692916, 0.17559692, -0.0453596 ]], dtype=float32), h=array([[ 0.10072514, 0.14767936, 0.12387902, -0.07391471],
      [ 0.10510238, 0.06321315, 0.08100517, -0.00940042],
      [ 0.10553667, 0.0984127 , 0.10094948, -0.02546882]], dtype=float32))
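
On the open question of keeping the cell state for every time step: one way to think about it is to step the cell yourself and record both h and c after each step. The sketch below is plain Python, not TensorFlow, and only illustrates the idea. The names lstm_step and unroll, the gate order (i, f, g, o), the random weights, and the sizes are all assumptions made for this example. In TensorFlow the analogous approach would be to call the cell once per time step in your own loop instead of relying on rnn.rnn.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, h, c, W, b):
    """One LSTM step for a single example (hypothetical helper).

    x: list of input_size floats; h, c: lists of hidden_size floats.
    W: 4 * hidden_size rows, each of length input_size + hidden_size.
    b: 4 * hidden_size biases. Gate order i, f, g, o is an assumption.
    """
    n = len(h)
    xh = x + h  # concatenate input and previous hidden state
    z = [sum(w * v for w, v in zip(row, xh)) + bias for row, bias in zip(W, b)]
    i = [sigmoid(v) for v in z[0:n]]              # input gate
    f = [sigmoid(v) for v in z[n:2 * n]]          # forget gate
    g = [math.tanh(v) for v in z[2 * n:3 * n]]    # candidate cell values
    o = [sigmoid(v) for v in z[3 * n:4 * n]]      # output gate
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


def unroll(inputs, W, b, hidden_size):
    """Run the cell over a sequence, recording h and c at every step."""
    h = [0.0] * hidden_size
    c = [0.0] * hidden_size
    hidden_states, cell_states = [], []
    for x in inputs:
        h, c = lstm_step(x, h, c, W, b)
        hidden_states.append(h)  # plays the role of "outputs"
        cell_states.append(c)    # the per-step cell state that "states" lacks
    return hidden_states, cell_states, (c, h)


random.seed(0)
input_size, hidden_size, n_steps = 3, 4, 5
W = [[random.uniform(-0.5, 0.5) for _ in range(input_size + hidden_size)]
     for _ in range(4 * hidden_size)]
b = [0.0] * (4 * hidden_size)
inputs = [[random.uniform(-1.0, 1.0) for _ in range(input_size)]
          for _ in range(n_steps)]

hidden_states, cell_states, (final_c, final_h) = unroll(inputs, W, b, hidden_size)
```

Here hidden_states[-1] is the same as final_h, which mirrors how the last block of outputs matches h in the states tuple above, while cell_states additionally keeps c for time steps 0 to 4.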
    