TensorFlow: how to access all the intermediate states of an RNN, not just the last state

小蘑菇 2020-12-30 16:06

My understanding is that tf.nn.dynamic_rnn returns the output of an RNN cell (e.g. LSTM) at each time step as well as the final state. How can I access the cell states at all intermediate time steps, not just the last one?

2 Answers
  • 2020-12-30 16:45

    I would point you to this thread (emphasis mine):

    You can write a variant of the LSTMCell that returns both state tensors as part of the output, if you need both the c and h states for each time step. If you just need the h state, that's the output of each time step.
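
    For illustration, here is a minimal sketch of such a variant, assuming TF 1.x and the default state_is_tuple=True (the class name StatefulLSTMCell is made up for this example):

    import tensorflow as tf

    class StatefulLSTMCell(tf.contrib.rnn.LSTMCell):
        @property
        def output_size(self):
            # Advertise the LSTMStateTuple(c, h) sizes as the output size,
            # so dynamic_rnn stacks both state tensors across time.
            return self.state_size

        def __call__(self, inputs, state):
            _, next_state = super(StatefulLSTMCell, self).__call__(inputs, state)
            # Emit the whole (c, h) state tuple as the per-step output.
            return next_state, next_state

    With a cell like this, dynamic_rnn's outputs come back as an LSTMStateTuple of two tensors, each of shape (batch_size, max_time, num_units).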

    As @jasekp wrote in their comment, the output is really the h part of the state. dynamic_rnn then simply stacks these h parts across time (see the docstring of _dynamic_rnn_loop in this file):

    def _dynamic_rnn_loop(cell,
                          inputs,
                          initial_state,
                          parallel_iterations,
                          swap_memory,
                          sequence_length=None,
                          dtype=None):
      """Internal implementation of Dynamic RNN.
        [...]
      Returns:
        Tuple `(final_outputs, final_state)`.
        final_outputs:
          A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
          `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
          objects, then this returns a (possibly nested) tuple of Tensors matching
          the corresponding shapes.
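
    As a quick sanity check of the claim above (a hypothetical snippet, assuming TF 1.x and some batch of float64 inputs X), for a plain LSTMCell the outputs are exactly the stacked h states:

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)
    outputs, final_state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float64)
    # outputs is batch-major by default: (batch_size, max_time, num_units).
    # For full-length sequences, final_state.h == outputs[:, -1, :].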
    
  • 2020-12-30 16:47

    Something like this should work.

    import tensorflow as tf
    import numpy as np
    
    
    class CustomRNN(tf.contrib.rnn.LSTMCell):
        def __init__(self, *args, **kwargs):
            # Force a concatenated (non-tuple) state so each step's
            # "output" can be a single tensor of size state_size.
            kwargs['state_is_tuple'] = False
            super(CustomRNN, self).__init__(*args, **kwargs)  # create an LSTM cell
            # Report the state size as the output size so dynamic_rnn
            # allocates per-step outputs of the right shape.
            self._output_size = self._state_size

        def __call__(self, inputs, state):
            output, next_state = super(CustomRNN, self).__call__(inputs, state)
            # Return the full state as the per-step output instead of the
            # usual h; dynamic_rnn then stacks the state across time.
            return next_state, next_state
    
    X = np.random.randn(2, 10, 8)  # (batch_size=2, max_time=10, input_dim=8)
    X[1, 6:] = 0                   # zero out the second sequence after step 6
    X_lengths = [10, 10]

    cell = CustomRNN(num_units=64)

    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    states, last_state = sess.run([outputs, last_states], feed_dict=None)
    

    This uses concatenated states, as I don't know whether dynamic_rnn can stack an arbitrary number of tuple states. The states variable has shape (batch_size, max_time, state_size), where state_size = 2 * num_units (c and h concatenated).
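
    To recover the per-step c and h tensors from the concatenated state, split the last axis in half; a minimal sketch, assuming no projection (num_proj), so each state row is laid out as [c; h] with num_units entries each:

    # states: (batch_size, max_time, 2 * num_units), laid out as [c; h].
    c_states, h_states = np.split(states, 2, axis=2)
    # h_states matches what a standard LSTMCell would produce as its outputs.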
