TensorFlow dynamic_rnn parameters meaning

日久生厌 2021-02-01 09:24

I'm struggling to understand the cryptic RNN docs. Any help with the following will be greatly appreciated.

tf.nn.dynamic_rnn(cell, inputs, sequence_length=Non         


        
1 answer
  •  醉话见心
    2021-02-01 10:22

    tf.nn.dynamic_rnn takes a batch (in the minibatch sense) of unrelated sequences.

    • cell is the actual cell that you want to use (LSTM, GRU,...)
    • inputs has shape batch_size x max_time x input_size (with the default time_major=False), where max_time is the number of steps in the longest sequence (all sequences may also be the same length)
    • sequence_length is a vector of size batch_size in which each element gives the length of the corresponding sequence in the batch (leave it as the default if all your sequences are the same length). This parameter determines how many steps the cell is unrolled for each sequence: past a sequence's length, the outputs are zeroed and the state is copied through unchanged.
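
    The relationship between inputs and sequence_length can be sketched in plain Python. This is only an illustration of the shapes involved, not TensorFlow code: pad_batch is a hypothetical helper, and padding with zeros is an assumption (the padded values are ignored once sequence_length is given).

```python
def pad_batch(sequences, pad_value=0.0):
    """Pad variable-length sequences to a common max_time (hypothetical helper).

    sequences: list of sequences; each sequence is a list of time steps and
    each time step is a list of input_size floats.
    Returns (inputs, sequence_length), where inputs is a nested list of shape
    [batch_size, max_time, input_size].
    """
    input_size = len(sequences[0][0])
    max_time = max(len(seq) for seq in sequences)
    inputs = []
    for seq in sequences:
        padding = [[pad_value] * input_size for _ in range(max_time - len(seq))]
        inputs.append([list(step) for step in seq] + padding)
    # One length per sequence: this is what sequence_length would hold.
    sequence_length = [len(seq) for seq in sequences]
    return inputs, sequence_length


batch = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # 3 time steps
    [[7.0, 8.0]],                          # 1 time step, padded to 3
]
inputs, sequence_length = pad_batch(batch)
# batch_size, max_time, input_size
print(len(inputs), len(inputs[0]), len(inputs[0][0]))
print(sequence_length)
```

    Here the second sequence is padded up to max_time, and sequence_length records that only its first step is real data.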

    Hidden state handling

    The usual way of handling hidden state is to define an initial state tensor before the call to dynamic_rnn, for instance:

    hidden_state_in = cell.zero_state(batch_size, tf.float32) 
    output, hidden_state_out = tf.nn.dynamic_rnn(cell, 
                                                 inputs,
                                                 initial_state=hidden_state_in,
                                                 ...)
    

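    In the above snippet, both hidden_state_in and hidden_state_out have the same shape [batch_size, ...] (the actual shape depends on the type of cell you use, but the first dimension is always the batch size). For an LSTM cell with the default state_is_tuple=True, for example, the state is a pair (c, h) of tensors, each of shape [batch_size, num_units].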

    This way, dynamic_rnn has an initial hidden state for each sequence. It carries the hidden state from time step to time step within each sequence of the inputs parameter on its own, and hidden_state_out contains the final state of each sequence in the batch. Hidden state is passed only between time steps of the same sequence, never between sequences of the same batch.
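
    To make the state flow concrete, here is a minimal pure-Python sketch of these semantics. It is not how TensorFlow implements dynamic_rnn: toy_cell and toy_dynamic_rnn are hypothetical stand-ins, the cell is a scalar tanh unit with made-up weights, and each input step is a single float.

```python
import math


def toy_cell(x, h):
    """Hypothetical scalar RNN cell: returns (output, new_state)."""
    h_new = math.tanh(0.5 * x + 0.8 * h)
    return h_new, h_new


def toy_dynamic_rnn(inputs, sequence_length, initial_state):
    """inputs: [batch_size][max_time] floats; initial_state: [batch_size] floats."""
    outputs, final_state = [], []
    # Each sequence gets its own state; nothing is shared across the batch.
    for seq, length, h in zip(inputs, sequence_length, initial_state):
        seq_outputs = []
        for t, x in enumerate(seq):
            if t < length:
                out, h = toy_cell(x, h)  # state flows from step t to step t + 1
            else:
                out = 0.0                # past the end: zero output, state unchanged
            seq_outputs.append(out)
        outputs.append(seq_outputs)
        final_state.append(h)
    return outputs, final_state


inputs = [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]  # second sequence is padded
outputs, final_state = toy_dynamic_rnn(inputs, [3, 1], [0.0, 0.0])
print(outputs[1])      # only the first entry is non-zero
print(final_state[1])  # state after the single valid step
```

    Running the second sequence on its own gives the same final state, which is the sense in which sequences in a batch are unrelated.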

    When do I need to feed back the hidden state manually?

    Usually, when you're training, the batches are unrelated, so you don't have to feed back the hidden state between session.run(output) calls.

    However, if you're testing and need the output at each time step (i.e. you have to call session.run() once per time step), you'll want to evaluate the output hidden state and feed it back, using something like this:

    # hidden_state holds the state value returned by the previous sess.run call
    output_value, hidden_state = sess.run([output, hidden_state_out],
                                          feed_dict={hidden_state_in: hidden_state})
    

    Otherwise, TensorFlow will just use the default cell.zero_state(batch_size, tf.float32) on every call, which amounts to reinitialising the hidden state at each time step.
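
    The effect of feeding the state back can be checked with a small pure-Python sketch. It makes no TensorFlow calls: toy_cell is a hypothetical scalar tanh cell with made-up weights, and run_steps stands in for a single session.run call over a chunk of time steps.

```python
import math


def toy_cell(x, h):
    """Hypothetical scalar RNN cell: returns (output, new_state)."""
    h_new = math.tanh(0.5 * x + 0.8 * h)
    return h_new, h_new


def run_steps(xs, initial_state):
    """Stands in for one session.run call over a chunk of time steps."""
    h = initial_state
    outputs = []
    for x in xs:
        out, h = toy_cell(x, h)
        outputs.append(out)
    return outputs, h


sequence = [1.0, 2.0, 3.0, 4.0]
zero_state = 0.0

# Reference: the whole sequence processed in a single call.
full_outputs, _ = run_steps(sequence, zero_state)

# One call per time step, feeding the returned state back in.
fed_back = []
state = zero_state
for x in sequence:
    out, state = run_steps([x], state)
    fed_back.extend(out)

# One call per time step, always starting from the zero state.
reset_each_call = []
for x in sequence:
    out, _ = run_steps([x], zero_state)
    reset_each_call.extend(out)

print(fed_back == full_outputs)         # True
print(reset_each_call == full_outputs)  # False
```

    Only the variant that feeds the state back reproduces the single-call outputs; resetting to the zero state on each call discards everything the cell has seen so far.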
