TensorFlow: getting all states from an RNN

说谎 2021-02-07 23:07

How do you get all the hidden states from tf.nn.rnn() or tf.nn.dynamic_rnn() in TensorFlow? The API only gives me the final state.


2 Answers
  •  小鲜肉 (OP)
    2021-02-07 23:27

    tf.nn.dynamic_rnn (and also tf.nn.static_rnn) has two return values, "outputs" and "state" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn).

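    As you said, "state" is the final state of the RNN, but "outputs" holds the RNN's hidden output at every time step. For dynamic_rnn with the default time_major=False, its shape is [batch_size, max_time, cell.output_size]; static_rnn instead returns a list with one [batch_size, cell.output_size] tensor per time step.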
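    To make the (outputs, state) convention concrete without depending on a particular TensorFlow version, here is a minimal pure-Python sketch of a vanilla RNN unrolled over one sequence. The toy weights and the helper names basic_rnn_cell and unroll are hypothetical illustrations, not TensorFlow API; they only mimic the return convention described above (no batching, padding or time_major handling).

```python
import math
from typing import List, Tuple

# Hypothetical toy weights (not from the answer): hidden size 2, input size 1.
W_x = [0.5, -0.3]            # input-to-hidden weight for each hidden unit
W_h = [[0.1, 0.2],           # hidden-to-hidden weights: W_h[i][j] maps h[j] -> unit i
       [-0.4, 0.3]]
b = [0.0, 0.1]               # bias for each hidden unit


def basic_rnn_cell(x: float, h: List[float]) -> Tuple[List[float], List[float]]:
    """One vanilla RNN step: h_new = tanh(W_x * x + W_h @ h + b).

    Like a basic RNN cell, it returns the same vector as output and as new state.
    """
    h_new = [
        math.tanh(W_x[i] * x + sum(W_h[i][j] * h[j] for j in range(len(h))) + b[i])
        for i in range(len(h))
    ]
    return h_new, h_new  # (output, state)


def unroll(inputs: List[float], h0: List[float]):
    """Run the cell over every step, collecting each output like "outputs" does."""
    outputs, state = [], h0
    for x in inputs:
        out, state = basic_rnn_cell(x, state)
        outputs.append(out)
    return outputs, state


outputs, state = unroll([1.0, 0.5, -1.0], [0.0, 0.0])
print(len(outputs))          # 3: one hidden state per time step
print(outputs[-1] == state)  # True: the final state is just the last output
```

    One caveat worth checking in your TensorFlow version: when sequence_length is passed to dynamic_rnn, outputs past each sequence's length are zeroed and the state is carried through, so the final state should match each sequence's last valid step rather than outputs[:, -1].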

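    You can use "outputs" as the RNN's hidden states because, in most RNNCells provided by the library, the "output" and the "state" are the same. The exception is LSTMCell: its state is the pair (c, h) (an LSTMStateTuple when state_is_tuple=True), while its output is only h, so "outputs" gives you h at every step but not the cell state c.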

    • BasicRNNCell: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L347
    • GRUCell: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L441
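    The LSTM exception can be illustrated the same way. This is a minimal pure-Python sketch with hypothetical scalar weights (hidden size 1) and hypothetical helper names lstm_cell and unroll_lstm, not TensorFlow code: the cell's state is (c, h) and its output is only h, so collecting outputs yields every h, while c has to be recorded separately if you need it.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical scalar weights (w_x, w_h) for the input, forget, output and candidate gates.
W = {"i": (0.5, 0.1), "f": (0.4, -0.2), "o": (0.3, 0.2), "g": (0.6, -0.1)}


def lstm_cell(x: float, state):
    """One LSTM step. The state is the pair (c, h); the output is only h."""
    c, h = state
    pre = {k: wx * x + wh * h for k, (wx, wh) in W.items()}
    i, f, o = sigmoid(pre["i"]), sigmoid(pre["f"]), sigmoid(pre["o"])
    g = math.tanh(pre["g"])
    c_new = f * c + i * g          # new cell state
    h_new = o * math.tanh(c_new)   # new hidden state, also the output
    return h_new, (c_new, h_new)   # (output, new state)


def unroll_lstm(inputs, state):
    outputs, cell_states = [], []
    for x in inputs:
        out, state = lstm_cell(x, state)
        outputs.append(out)           # what "outputs" would hold: h at every step
        cell_states.append(state[0])  # c at every step: not part of "outputs"
    return outputs, cell_states, state


outputs, cell_states, final_state = unroll_lstm([1.0, 0.5, -1.0], (0.0, 0.0))
print(final_state[1] == outputs[-1])  # True: the final h equals the last output
```

    Since h = o * tanh(c) with o in (0, 1), h is generally a different number from c, which is why the c values cannot be recovered from "outputs" alone.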
