I'm struggling to understand the cryptic RNN docs. Any help with the following will be greatly appreciated.
tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None)
tf.nn.dynamic_rnn takes in a batch (with the minibatch meaning) of unrelated sequences.

- cell is the actual cell that you want to use (LSTM, GRU, ...).
- inputs has a shape of batch_size x max_time x input_size, in which max_time is the number of steps in the longest sequence (but all sequences could be of the same length).
- sequence_length is a vector of size batch_size in which each element gives the length of each sequence in the batch (leave it as the default if all your sequences are of the same size). This is the parameter that determines how far the cell is unrolled for each sequence; see the sketch just after this list.
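To make the shapes concrete, here is a minimal sketch. It assumes TensorFlow 1.x (where tf.nn.dynamic_rnn lives), and all the sizes are made-up values for illustration:

import tensorflow as tf  # assumes TensorFlow 1.x

batch_size, max_time, input_size, num_units = 4, 7, 3, 10  # made-up sizes

cell = tf.nn.rnn_cell.GRUCell(num_units)
inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
# one entry per sequence: how many of the max_time steps are real data
seq_len = tf.placeholder(tf.int32, [batch_size])

# dtype is mandatory here because no initial_state is given
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs,
                                         sequence_length=seq_len,
                                         dtype=tf.float32)
# outputs:     [batch_size, max_time, num_units]
# final_state: [batch_size, num_units] for a GRU cell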
The usual way of handling hidden state is to define an initial state tensor before the dynamic_rnn call, like this for instance:
# one zero-filled initial state per sequence in the batch
hidden_state_in = cell.zero_state(batch_size, tf.float32)
output, hidden_state_out = tf.nn.dynamic_rnn(cell,
                                             inputs,
                                             initial_state=hidden_state_in,
                                             ...)
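For instance, with an LSTM cell the snippet could look like this (a concrete version reusing the made-up sizes and placeholders from the sketch above; note that for an LSTM the state is an LSTMStateTuple rather than a single tensor):

cell = tf.nn.rnn_cell.LSTMCell(num_units)

hidden_state_in = cell.zero_state(batch_size, tf.float32)
output, hidden_state_out = tf.nn.dynamic_rnn(cell,
                                             inputs,
                                             initial_state=hidden_state_in,
                                             sequence_length=seq_len)
# hidden_state_out is an LSTMStateTuple of two tensors, c and h,
# each of shape [batch_size, num_units]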
In the above snippet, both hidden_state_in and hidden_state_out have the same shape [batch_size, ...] (the actual shape depends on the type of cell you use, but the important thing is that the first dimension is the batch size).

This way, dynamic_rnn has an initial hidden state for each sequence. It will pass on the hidden state from time step to time step for each sequence in the inputs parameter on its own, and hidden_state_out will contain the final output state for each sequence in the batch. No hidden state is passed between sequences of the same batch, only between time steps of the same sequence.
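You can convince yourself of these shapes by evaluating the graph once (continuing the sketch above; the sequence lengths fed here are made up):

import numpy as np

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_val, state_val = sess.run(
        [output, hidden_state_out],
        feed_dict={inputs: np.random.randn(batch_size, max_time, input_size),
                   seq_len: [7, 5, 7, 2]})  # two sequences end early
    print(out_val.shape)      # (4, 7, 10): one output per step per sequence
    print(state_val.h.shape)  # (4, 10): one final state per sequence
    # outputs past a sequence's length are zeroed, e.g. out_val[3, 2:] is 0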
Usually, when you're training, every batch is unrelated, so you don't have to feed back the hidden state when doing a session.run(output).
However, if you're testing and you need the output at each time step (i.e. you have to do a session.run() at every time step), you'll want to evaluate and feed back the output hidden state using something like this:
# fetch into a new name: rebinding the Python name 'output' to a numpy
# array would break the next sess.run, which needs the original tensor
output_val, hidden_state = sess.run([output, hidden_state_out],
                                    feed_dict={hidden_state_in: hidden_state})
otherwise TensorFlow will just use the default cell.zero_state(batch_size, tf.float32) at each time step, which amounts to reinitialising the hidden state at every step.
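Put together, a step-by-step loop could look like the following. This is a self-contained sketch in a fresh graph, built with max_time == 1 so that each session.run advances every sequence by exactly one time step; next_input is a hypothetical data-source helper, not a TensorFlow function:

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x

batch_size, input_size, num_units, num_steps = 4, 3, 10, 20  # made-up sizes

cell = tf.nn.rnn_cell.GRUCell(num_units)  # GRU: the state is a single tensor
step_input = tf.placeholder(tf.float32, [batch_size, 1, input_size])
hidden_state_in = cell.zero_state(batch_size, tf.float32)
output, hidden_state_out = tf.nn.dynamic_rnn(cell, step_input,
                                             initial_state=hidden_state_in)

def next_input(t):
    # hypothetical per-step data source; replace with your real inputs
    return np.random.randn(batch_size, 1, input_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    hidden_state = sess.run(hidden_state_in)  # start from the zero state
    for t in range(num_steps):
        # feed the previous output state back in as the next input state
        output_val, hidden_state = sess.run(
            [output, hidden_state_out],
            feed_dict={step_input: next_input(t),
                       hidden_state_in: hidden_state})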