How do I set TensorFlow RNN state when state_is_tuple=True?

小蘑菇 2020-11-30 04:25

I have written an RNN language model using TensorFlow. The model is implemented as an RNN class. The graph structure is built in the constructor, while RN

2 answers
  • 2020-11-30 04:53

    One problem with a TensorFlow placeholder is that you can only feed it with a Python list or NumPy array (I think). So you can't save the state between runs as tuples of LSTMStateTuple.

    I solved this by saving the state in a single array like this:

    initial_state = np.zeros((num_layers, 2, batch_size, state_size))

    An LSTM layer has two state components, the cell state and the hidden state; that's where the "2" comes from. (This article is great: https://arxiv.org/pdf/1506.00019.pdf)

    When building the graph you unpack and create the tuple state like this:

    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    # Split along the layer axis; tf.unpack was renamed tf.unstack in TensorFlow 1.0.
    layer_states = tf.unstack(state_placeholder, axis=0)
    rnn_tuple_state = tuple(
        tf.nn.rnn_cell.LSTMStateTuple(layer_states[idx][0], layer_states[idx][1])
        for idx in range(num_layers)
    )
    

    Then you get the new state the usual way

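    # Build one cell object per layer; reusing a single instance via [cell] * num_layers
    # can share weights or raise an error in later TensorFlow 1.x releases.
    cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) for _ in range(num_layers)]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)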
    
    outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)
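
    The answer does not show how the fetched state gets back into the placeholder on the next run. A minimal sketch of that step, assuming `sess.run(state)` returns one `(c, h)` pair of NumPy arrays per layer. The `LSTMStateTuple` namedtuple here is only a stand-in so the snippet runs without TensorFlow, and the sizes are made up for illustration:

```python
import collections

import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple (assumption: same (c, h) field order).
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))

# Illustrative sizes, not taken from the original post.
num_layers, batch_size, state_size = 3, 4, 5

# Mimics the structure of a fetched state: one (c, h) pair per layer.
fetched_state = tuple(
    LSTMStateTuple(
        c=np.full((batch_size, state_size), float(layer)),
        h=np.full((batch_size, state_size), float(layer) + 0.5),
    )
    for layer in range(num_layers)
)

# Pack the nested tuples into one array with the same layout as state_placeholder.
packed_state = np.array(fetched_state, dtype=np.float32)
print(packed_state.shape)  # (3, 2, 4, 5)
```

    The packed array could then be passed as `feed_dict={state_placeholder: packed_state}` on the next call, with index 0 of the second axis holding the cell state and index 1 the hidden state.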
    

    It shouldn't be like this... perhaps they are working on a solution.

  • 2020-11-30 05:05

    A simple way to feed in an RNN state is to feed both components of the state tuple individually.

    # Constructing the graph
    self.state = rnn_cell.zero_state(...)
    self.output, self.next_state = tf.nn.dynamic_rnn(
        rnn_cell,
        self.input,
        initial_state=self.state)
    
    # Running with initial state
    output, state = sess.run([self.output, self.next_state], feed_dict={
        self.input: input
    })
    
    # Running with subsequent state:
    output, state = sess.run([self.output, self.next_state], feed_dict={
        self.input: input,
        self.state[0]: state[0],
        self.state[1]: state[1]
    })
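
    With a stacked cell the state is a tuple of `(c, h)` pairs, one per layer, so the same idea needs one feed entry per component. A rough sketch of a helper for that case; `build_state_feed` is a hypothetical name, and plain strings and lists stand in for the state tensors and NumPy arrays so the snippet runs on its own:

```python
def build_state_feed(state_tensors, state_values):
    """Pair each (c, h) state tensor with the matching value from the previous run."""
    feed = {}
    for layer_tensors, layer_values in zip(state_tensors, state_values):
        for tensor, value in zip(layer_tensors, layer_values):
            feed[tensor] = value
    return feed


# Stand-ins: strings play the role of tensors, lists the role of NumPy arrays.
state_tensors = (("c0", "h0"), ("c1", "h1"))
state_values = (([0.1], [0.2]), ([0.3], [0.4]))

feed = build_state_feed(state_tensors, state_values)
print(len(feed))  # 4
```

    In the code above this would be used roughly as `feed_dict = {self.input: input}` followed by `feed_dict.update(build_state_feed(self.state, state))`.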
    