Tensorflow, best way to save state in RNNs?

慢半拍i · Asked 2020-11-27 03:39

I currently have the following code for a series of chained-together RNNs in TensorFlow. I am not using MultiRNNCell, since I want to do something later on with the output of each layer.

3 Answers
  • 2020-11-27 04:10

    Here is code to update the LSTM's initial state when state_is_tuple=True, by keeping the state in variables. It also supports multiple layers.

    We define two functions: one that creates the state variables initialized to the zero state, and one that returns an operation we can pass to session.run to update the state variables with the LSTM's last hidden state.

    import tensorflow as tf

    def get_state_variables(batch_size, cell):
        # For each layer, get the initial state and make a variable out of it
        # to enable updating its value.
        state_variables = []
        for state_c, state_h in cell.zero_state(batch_size, tf.float32):
            state_variables.append(tf.contrib.rnn.LSTMStateTuple(
                tf.Variable(state_c, trainable=False),
                tf.Variable(state_h, trainable=False)))
        # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
        return tuple(state_variables)
    
    
    def get_state_update_op(state_variables, new_states):
        # Add an operation to update the train states with the last state tensors
        update_ops = []
        for state_variable, new_state in zip(state_variables, new_states):
            # Assign the new state to the state variables on this layer
            update_ops.extend([state_variable[0].assign(new_state[0]),
                               state_variable[1].assign(new_state[1])])
        # Return a tuple in order to combine all update_ops into a single operation.
        # The tuple's actual value should not be used.
        return tf.tuple(update_ops)
    

    We can use that to update the LSTM's state after each batch. Note that I use tf.nn.dynamic_rnn for unrolling:

    data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
    # Use LSTM cells, since get_state_variables expects LSTMStateTuples.
    # Each layer needs its own cell instance so weights are not shared.
    cell_layers = [tf.contrib.rnn.LSTMCell(256, state_is_tuple=True)
                   for _ in range(num_layers)]
    cell = tf.contrib.rnn.MultiRNNCell(cell_layers)
    
    # For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
    states = get_state_variables(batch_size, cell)
    
    # Unroll the LSTM
    outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
    
    # Add an operation to update the train states with the last state tensors.
    update_op = get_state_update_op(states, new_states)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run([outputs, update_op], {data: ...})
    

    The key point is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple containing two variables (the cell state and the hidden state) instead of a single variable. Using multiple layers then makes the LSTM's state a tuple of LSTMStateTuples, one per layer.
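    The nested structure that get_state_variables builds (and that dynamic_rnn expects as initial_state) can be sketched without TensorFlow. The sizes below are made up for illustration; each layer contributes one (c, h) pair, analogous to an LSTMStateTuple:

    ```python
    import numpy as np

    # Hypothetical sizes, mirroring cell.zero_state for a 2-layer LSTM
    num_layers, batch_size, num_units = 2, 3, 4

    # One (c, h) pair per layer - a tuple of LSTMStateTuple-like pairs
    states = tuple(
        (np.zeros((batch_size, num_units)),   # cell state c
         np.zeros((batch_size, num_units)))   # hidden state h
        for _ in range(num_layers)
    )

    print(len(states))           # one pair per layer: 2
    print(states[0][0].shape)    # per-layer state is (batch_size, num_units): (3, 4)
    ```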

    Resetting to zero

    When using a trained model for prediction / decoding, you might want to reset the state to zero. Then, you can make use of this function:

    def get_state_reset_op(state_variables, cell, batch_size):
        # Return an operation to set each variable in a list of LSTMStateTuples to zero
        zero_states = cell.zero_state(batch_size, tf.float32)
        return get_state_update_op(state_variables, zero_states)
    

    Used with the example above:

    reset_state_op = get_state_reset_op(states, cell, batch_size)
    # Reset the state to zero before feeding input
    sess.run([reset_state_op])
    sess.run([outputs, update_op], {data: ...})
    
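    To see the carry-and-reset mechanics outside of TensorFlow, here is a NumPy sketch of the same pattern, with a plain tanh RNN cell standing in for the LSTM. All sizes and weights are made up for illustration:

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    batch_size, frame_size, num_units = 2, 3, 4

    # Hypothetical weights of a vanilla RNN cell (stand-in for the LSTM)
    Wx = 0.1 * rng.normal(size=(frame_size, num_units))
    Wh = 0.1 * rng.normal(size=(num_units, num_units))

    def run_batch(x_seq, h):
        # x_seq: (batch, time, frame); h: (batch, units) carried-over state
        for t in range(x_seq.shape[1]):
            h = np.tanh(x_seq[:, t] @ Wx + h @ Wh)
        return h

    h = np.zeros((batch_size, num_units))     # like cell.zero_state
    for _ in range(3):                        # consecutive batches
        x = rng.normal(size=(batch_size, 5, frame_size))
        h = run_batch(x, h)                   # like running update_op

    h = np.zeros((batch_size, num_units))     # like reset_state_op
    ```

    The state variables in the TensorFlow answer play the role of `h` here: they persist the last hidden state between session.run calls, and the reset op sets them back to zero.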
  • 2020-11-27 04:20

    I am now saving the RNN states using tf.control_dependencies. Here is an example.

    import numpy as np
    import tensorflow as tf

    saved_states = [tf.get_variable('saved_state_%d' % i, shape=(BATCH_SIZE, sz),
                                    trainable=False,
                                    initializer=tf.constant_initializer())
                    for i, sz in enumerate(rnn.state_size)]
    W = tf.get_variable('W', shape=(2 * RNN_SIZE, RNN_SIZE),
                        initializer=tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
    b = tf.get_variable('b', shape=(RNN_SIZE,), initializer=tf.constant_initializer())

    rnn_output, states = rnn(last_output, saved_states)
    # Everything inside this block runs only after the states are saved
    with tf.control_dependencies([tf.assign(var, new) for var, new in zip(saved_states, states)]):
        dense_input = tf.concat((last_output, rnn_output), 1)

    dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
    last_output = dense_output + last_output
    

    This way, part of my graph is guaranteed to depend on saving the state, so the assignments always run.
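    The same guarantee can be sketched in plain Python: if saving the state is a side effect that happens before the output is produced, anything that consumes the output implicitly depends on the save. The class and sizes below are made up for illustration:

    ```python
    class StatefulCell:
        """Toy cell that persists its state before returning output,
        mirroring what tf.control_dependencies enforces in the graph."""
        def __init__(self, num_units):
            self.saved_state = [0.0] * num_units

        def __call__(self, x):
            # 'Assign' the new state first ...
            self.saved_state = [s + xi for s, xi in zip(self.saved_state, x)]
            # ... then produce the output that downstream ops consume
            return sum(self.saved_state)

    cell = StatefulCell(3)
    print(cell([1.0, 2.0, 3.0]))  # 6.0
    print(cell([1.0, 1.0, 1.0]))  # 9.0 - state carried over between calls
    ```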

  • 2020-11-27 04:20

    These two GitHub issues are also related and useful for this question:

    https://github.com/tensorflow/tensorflow/issues/2695
    https://github.com/tensorflow/tensorflow/issues/2838
