Is RNN initial state reset for subsequent mini-batches?

Asked 2020-12-04 12:58

Could someone please clarify whether the initial state of the RNN in TF is reset for subsequent mini-batches, or whether the last state of the previous mini-batch is used instead?

2 Answers
  • 2020-12-04 13:31

    The tf.nn.dynamic_rnn() or tf.nn.rnn() operations allow you to specify the initial state of the RNN using the initial_state parameter. If you don't specify this parameter, the hidden states will be initialized to zero vectors at the beginning of each training batch.
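
    For illustration (this sketch is not part of the original answer), the default behaviour looks like this; no initial_state is passed, so every session run starts from zeros. The concrete sizes below are hypothetical:

    import tensorflow as tf

    # Hypothetical sizes, purely for illustration.
    batch_size, max_length, frame_size = 32, 20, 10

    data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
    cell = tf.nn.rnn_cell.GRUCell(256)

    # No initial_state given: dynamic_rnn falls back to cell.zero_state(...),
    # so the hidden state is reset to zeros for every batch that is fed in.
    output, last_state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)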

    In TensorFlow, you can wrap tensors in tf.Variable() to keep their values in the graph between multiple session runs. Just make sure to mark them as non-trainable because the optimizers tune all trainable variables by default.

    import tensorflow as tf

    # batch_size, max_length and frame_size are assumed to be defined elsewhere.
    data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))

    cell = tf.nn.rnn_cell.GRUCell(256)
    # Keep the last hidden state in a non-trainable variable so that it
    # persists across session runs, i.e. across mini-batches.
    state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
    output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

    # Write the new state back into the variable before the output is returned.
    with tf.control_dependencies([state.assign(new_state)]):
        output = tf.identity(output)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(output, {data: ...})
    

    I haven't tested this code, but it should give you a hint in the right direction. There is also tf.nn.state_saving_rnn(), to which you can provide a state saver object, but I haven't used it yet.
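
    One follow-up that the snippet above does not show: if you want to start from a clean state again (for example at the beginning of a new epoch or of a new set of sequences), you can assign the zero state back to the variable. A minimal sketch, reusing the state, cell, batch_size and sess names from the code above:

    # Assign the all-zeros state back into the persistent state variable.
    reset_state_op = state.assign(cell.zero_state(batch_size, tf.float32))

    # Run it whenever the carried-over state should be cleared.
    sess.run(reset_state_op)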

  • 2020-12-04 13:35

    In addition to danijar's answer, here is the code for an LSTM whose state is a tuple (state_is_tuple=True). It also supports multiple layers.

    We define two functions: one that creates the state variables with an initial zero state, and one that returns an operation we can pass to session.run in order to update the state variables with the LSTM's last hidden state.

    import tensorflow as tf

    def get_state_variables(batch_size, cell):
        # For each layer, get the initial state and make a variable out of it
        # to enable updating its value.
        state_variables = []
        for state_c, state_h in cell.zero_state(batch_size, tf.float32):
            state_variables.append(tf.contrib.rnn.LSTMStateTuple(
                tf.Variable(state_c, trainable=False),
                tf.Variable(state_h, trainable=False)))
        # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
        return tuple(state_variables)
    
    
    def get_state_update_op(state_variables, new_states):
        # Add an operation to update the train states with the last state tensors
        update_ops = []
        for state_variable, new_state in zip(state_variables, new_states):
            # Assign the new state to the state variables on this layer
            update_ops.extend([state_variable[0].assign(new_state[0]),
                               state_variable[1].assign(new_state[1])])
        # Return a tuple in order to combine all update_ops into a single operation.
        # The tuple's actual value should not be used.
        return tf.tuple(update_ops)
    

    Similar to danijar's answer, we can use that to update the LSTM's state after each batch:

    data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
    cells = [tf.contrib.rnn.LSTMCell(256, state_is_tuple=True) for _ in range(num_layers)]
    cell = tf.contrib.rnn.MultiRNNCell(cells)
    
    # For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
    states = get_state_variables(batch_size, cell)
    
    # Unroll the LSTM
    outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
    
    # Add an operation to update the train states with the last state tensors.
    update_op = get_state_update_op(states, new_states)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run([outputs, update_op], {data: ...})
    

    The main difference is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple containing two variables (the cell state and the hidden state) instead of a single variable. Using multiple layers then makes the LSTM's state a tuple of LSTMStateTuples, one per layer.
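
    The same idea extends to clearing the per-layer state between independent sequences (this is not part of the answer above): assign the zero states back to all state variables. A sketch reusing get_state_update_op and the cell, states, batch_size and sess names from the code above:

    def get_state_reset_op(state_variables, cell, batch_size):
        # Reuse the update op, but feed it the all-zeros state of every layer,
        # so that both c and h of each LSTMStateTuple are reset.
        zero_states = cell.zero_state(batch_size, tf.float32)
        return get_state_update_op(state_variables, zero_states)

    # Clear the carried-over state, e.g. before starting a new epoch.
    reset_op = get_state_reset_op(states, cell, batch_size)
    sess.run(reset_op)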
