TensorFlow dynamic_rnn state

鱼传尺愫 2020-12-29 17:05

My question is about the TensorFlow method tf.nn.dynamic_rnn. It returns the output of every time step and the final state.

I would like to know if the

1 answer
  • 2020-12-29 17:39

    tf.nn.dynamic_rnn returns two tensors: outputs and states.

    The outputs tensor holds the outputs of all cells for all sequences in a batch. If a particular sequence is shorter than the others (padded with zeros) and you pass its true length through sequence_length, the outputs for the time steps past that length will be zero.
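    The zeroing described above can be sketched in NumPy as a mask over the time axis. This is only an illustration of the effect, not TensorFlow's internal code; the helper name mask_outputs is hypothetical, and it assumes outputs shaped [batch, max_time, units] as in the example.

```python
import numpy as np

def mask_outputs(raw_outputs, seq_length):
    """Zero every time step at or beyond each sequence's length (hypothetical helper).

    raw_outputs: array of shape [batch, max_time, units]
    seq_length:  int array of shape [batch]
    """
    max_time = raw_outputs.shape[1]
    # mask[b, t] is True while t < seq_length[b]
    mask = np.arange(max_time)[None, :] < seq_length[:, None]
    # Broadcast the [batch, max_time] mask over the units axis
    return raw_outputs * mask[:, :, None]

raw = np.ones((2, 2, 3))                     # pretend every cell emitted ones
masked = mask_outputs(raw, np.array([2, 1]))
print(masked[1, 1])                          # instance 1 has length 1, so t = 1 is zeroed
```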

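    The states tensor holds the final cell state of each sequence. For a BasicRNNCell the state is the same as the output, so this is equivalently the output at the last valid time step (the last non-zero output) of each sequence.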
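    To see why states matches the last valid output, here is a minimal NumPy sketch of the unrolling, assuming the vanilla RNN update BasicRNNCell uses (output = new state = tanh([x_t, h_prev] @ kernel + bias)) and the sequence_length behaviour described above: past a sequence's end the output is zero and the state is carried through unchanged. The function names and the random weights are hypothetical placeholders, not TensorFlow APIs.

```python
import numpy as np

def basic_rnn_step(x_t, h_prev, kernel, bias):
    """One vanilla RNN step: tanh([x_t, h_prev] @ kernel + bias)."""
    return np.tanh(np.concatenate([x_t, h_prev], axis=1) @ kernel + bias)

def dynamic_rnn_numpy(X, seq_length, kernel, bias):
    """Unroll over X (shape [batch, max_time, n_inputs]), honouring seq_length."""
    batch, max_time, _ = X.shape
    n_units = bias.shape[0]
    h = np.zeros((batch, n_units))
    outputs = np.zeros((batch, max_time, n_units))
    for t in range(max_time):
        h_new = basic_rnn_step(X[:, t, :], h, kernel, bias)
        valid = (t < seq_length)[:, None]              # shape [batch, 1]
        h = np.where(valid, h_new, h)                  # freeze finished sequences
        outputs[:, t, :] = np.where(valid, h_new, 0.0) # zero outputs past the end
    return outputs, h

# Same shapes as the example; the weights are random placeholders.
rng = np.random.default_rng(0)
n_inputs, n_neurons = 3, 5
kernel = rng.normal(scale=0.5, size=(n_inputs + n_neurons, n_neurons))
bias = np.zeros(n_neurons)

X_batch = np.array([
    [[0, 1, 2], [9, 8, 7]],  # instance 0, length 2
    [[3, 4, 5], [0, 0, 0]],  # instance 1, length 1 (padded)
], dtype=np.float64)
seq_length_batch = np.array([2, 1])

outputs, states = dynamic_rnn_numpy(X_batch, seq_length_batch, kernel, bias)
# states[b] equals outputs[b, seq_length[b] - 1] for every instance b
print(np.allclose(states[0], outputs[0, 1]), np.allclose(states[1], outputs[1, 0]))  # True True
```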

    Here's an example:

    # TensorFlow 1.x API (placeholders and sessions)
    import numpy as np
    import tensorflow as tf
    
    n_steps = 2     # time steps per sequence
    n_inputs = 3    # features per time step
    n_neurons = 5   # units in the RNN cell
    
    X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
    seq_length = tf.placeholder(tf.int32, [None])  # true length of each sequence
    
    basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
    outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)
    
    X_batch = np.array([
      # t = 0      t = 1
      [[0, 1, 2], [9, 8, 7]], # instance 0
      [[3, 4, 5], [0, 0, 0]], # instance 1 (padded at t = 1)
    ], dtype=np.float32)
    seq_length_batch = np.array([2, 1])
    
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Fetch both the per-step outputs and the final states
      outputs_val, states_val = sess.run([outputs, states], 
                                         feed_dict={X: X_batch, seq_length: seq_length_batch})
    
      print('outputs:')
      print(outputs_val)
      print('\nstates:')
      print(states_val)
    

    This prints something like:

    outputs:
    [[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]
      [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]]
    
     [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]
      [ 0.          0.          0.          0.          0.        ]]]  # because len=1
    
    states:
    [[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]
     [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]
    

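    Note that states holds the same vectors as outputs: for instance 0 it is the output at t = 1, and for instance 1 (length 1) it is the output at t = 0. In other words, states contains the last non-zero output of each batch instance.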
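    The same vectors can also be pulled out of outputs directly by indexing each instance at seq_length - 1, which is handy when the cell's state is not simply its output (for an LSTMCell, states is a (c, h) tuple and only h matches the output). Below is a NumPy sketch using the values printed above; last_relevant_output is a hypothetical helper. In TensorFlow 1.x the equivalent lookup should be expressible with tf.gather_nd on index pairs built from tf.range and seq_length - 1, but treat that as an untested suggestion.

```python
import numpy as np

# Values copied from the printed outputs/states above.
outputs_val = np.array([
    [[-0.85381496, -0.19517037, 0.36011398, -0.18617202, 0.39162001],
     [-0.99998015, -0.99461144, -0.82241321, 0.93778896, 0.90737367]],
    [[-0.99849552, -0.88643843, 0.20635395, 0.157896, 0.76042926],
     [0.0, 0.0, 0.0, 0.0, 0.0]],
])
states_val = np.array([
    [-0.99998015, -0.99461144, -0.82241321, 0.93778896, 0.90737367],
    [-0.99849552, -0.88643843, 0.20635395, 0.157896, 0.76042926],
])
seq_length_batch = np.array([2, 1])

def last_relevant_output(outputs, seq_length):
    """Pick outputs[b, seq_length[b] - 1] for each batch instance b."""
    batch_idx = np.arange(outputs.shape[0])
    return outputs[batch_idx, seq_length - 1]

last = last_relevant_output(outputs_val, seq_length_batch)
print(np.allclose(last, states_val))  # True
```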
