Tensorflow - LSTM state reuse within batch

Posted by 孤街浪徒 on 2019-12-25 08:24:36

Question


I am working on a TensorFlow NN that uses an LSTM to track a parameter (a time-series regression problem). A training batch contains batch_size consecutive observations. I would like to use the LSTM state as input to the next sample: given a batch of observations, the state produced by the first observation should feed into the second, and so on. Below I define the LSTM state as a variable sized by batch_size, and I would like to reuse the state within a batch:

import tensorflow as tf

# state_is_tuple=False keeps the state as a single tensor so it can live in one variable
cell = tf.nn.rnn_cell.BasicLSTMCell(100, state_is_tuple=False)
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
output, curr_state = tf.nn.rnn(cell, data, initial_state=state)

The API has a tf.nn.state_saving_rnn, but its documentation is rather vague. My question: how do I reuse curr_state within a training batch?


Answer 1:


You are basically there; you just need to update state with curr_state:

state_update = tf.assign(state, curr_state)

Then, make sure you either run state_update itself or an operation that has state_update as a dependency; otherwise the assignment will not actually happen. For example:

with tf.control_dependencies([state_update]):
    model_output = ...
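
Putting the pieces together, here is a minimal end-to-end sketch. It assumes TF 1.x-era APIs and uses tf.nn.dynamic_rnn instead of the older tf.nn.rnn (which expects a list of per-timestep tensors); the placeholder shape, num_units, and state_is_tuple=False are illustrative assumptions, not from the original question:

# A minimal sketch, assuming TF 1.x; shapes are made up for illustration.
import tensorflow as tf

batch_size, num_steps, num_features, num_units = 4, 10, 3, 100

data = tf.placeholder(tf.float32, [batch_size, num_steps, num_features])

# state_is_tuple=False keeps the LSTM state as one [batch, 2*num_units]
# tensor, so it can be stored in a single non-trainable variable.
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, state_is_tuple=False)
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)

output, curr_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

# Write the final state back into the variable...
state_update = tf.assign(state, curr_state)

# ...and make anything downstream depend on that write, so running
# the model is guaranteed to persist the state.
with tf.control_dependencies([state_update]):
    model_output = tf.identity(output)

Running model_output (or a loss/train op built on top of it) in a session then carries the state from one run to the next; re-run the variable's initializer when you want to reset the state between sequences.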

As suggested in the comments, the typical setup for RNNs is a batch whose first dimension (0) indexes the sequences and whose second dimension (1) is the maximum length of each sequence (if you pass time_major=True when building the RNN, these two are swapped). Ideally, for good performance, you stack multiple sequences into one batch and then split that batch time-wise, but that is really a different topic.
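
To make the two layouts concrete, here is a small sketch (the shapes are illustrative assumptions, not from the original answer):

# batch-major (the default): [batch_size, max_time, features]
batch_major = tf.placeholder(tf.float32, [4, 10, 3])

# time-major (what time_major=True expects): [max_time, batch_size, features]
time_major = tf.transpose(batch_major, perm=[1, 0, 2])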



Source: https://stackoverflow.com/questions/42133661/tensorflow-lstm-state-reuse-within-batch
