TensorFlow: how to access all the intermediate states of an RNN, not just the last state

小蘑菇 2020-12-30 16:06

My understanding is that tf.nn.dynamic_rnn returns the output of an RNN cell (e.g. LSTM) at each time step as well as the final state. How can I access cell sta

2 answers
  •  生来不讨喜
    2020-12-30 16:45

    I would point you to this thread (emphasis mine):

    You can write a variant of the LSTMCell that returns both state tensors as part of the output, if you need both c and h state for each time step. If you just need the h state, that's the output of each time step.
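To make the idea of such a variant concrete, here is a minimal sketch in plain Python rather than TensorFlow. Everything in it is an assumption made for illustration: `ToyLSTMCell` is a hypothetical scalar cell with one shared weight set for all gates, `StateAsOutputCell` is a hypothetical wrapper, and `unroll` is only a stand-in for the loop that collects one cell output per time step. It is not the TensorFlow API.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class ToyLSTMCell:
    """Hypothetical scalar LSTM cell, a stand-in for a real LSTMCell.

    Contract: cell(x, state) -> (output, new_state), where state is the
    pair (c, h) and the output is only h.
    """

    def __init__(self, w_x=0.5, w_h=0.25, bias=0.0):
        # One shared weight set for every gate keeps the sketch short.
        self.w_x, self.w_h, self.bias = w_x, w_h, bias

    def __call__(self, x, state):
        c, h = state
        pre = self.w_x * x + self.w_h * h + self.bias
        i = f = o = sigmoid(pre)  # input, forget and output gates
        g = math.tanh(pre)        # candidate cell value
        new_c = f * c + i * g
        new_h = o * math.tanh(new_c)
        return new_h, (new_c, new_h)


class StateAsOutputCell:
    """Hypothetical wrapper: the per-step output becomes the full (c, h) state."""

    def __init__(self, cell):
        self.cell = cell

    def __call__(self, x, state):
        _, new_state = self.cell(x, state)
        return new_state, new_state


def unroll(cell, inputs, initial_state):
    """Stand-in for the RNN loop: keep the cell output of every time step."""
    outputs, state = [], initial_state
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state


inputs = [1.0, -0.5, 2.0]
zero_state = (0.0, 0.0)

# Plain cell: only h is available per step, plus the final (c, h).
h_per_step, final_state = unroll(ToyLSTMCell(), inputs, zero_state)

# Wrapped cell: both c and h are available for every step.
states_per_step, _ = unroll(StateAsOutputCell(ToyLSTMCell()), inputs, zero_state)
c_per_step = [c for c, _ in states_per_step]
```

With the wrapper, `states_per_step[t]` holds the `(c, h)` pair after step `t`, and its last entry equals the final state returned by the loop. In TensorFlow itself, a variant along these lines would presumably also need to declare a matching tuple-valued `output_size`.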

    As @jasekp wrote in their comment, the output is really the h part of the state. The dynamic_rnn method then simply stacks the h parts across time (see the docstring of _dynamic_rnn_loop in this file):

    def _dynamic_rnn_loop(cell,
                          inputs,
                          initial_state,
                          parallel_iterations,
                          swap_memory,
                          sequence_length=None,
                          dtype=None):
      """Internal implementation of Dynamic RNN.
        [...]
        Returns:
        Tuple `(final_outputs, final_state)`.
        final_outputs:
          A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
          `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
          objects, then this returns a (possibly nested) tuple of Tensors matching
          the corresponding shapes.
    
