Tensorflow RNN-LSTM - reset hidden state

佛祖请我去吃肉 2021-01-14 22:09

I'm building a stateful LSTM used for language recognition. Being stateful, I can train the network with smaller files, and a new batch will be like a next sentence in a di…

2 Answers
  • 2021-01-14 22:36

    A simplified version of AMairesse's answer (below) for a single LSTM layer:

    # Non-trainable variables that persist the cell (c) and hidden (h) states
    zero_state = tf.zeros(shape=[1, units[-1]])
    self.c_state = tf.Variable(zero_state, trainable=False)
    self.h_state = tf.Variable(zero_state, trainable=False)
    self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)
    
    self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder)
    
    # Save the final state back into the variables; run these ops to carry
    # the state over to the next batch
    self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
    self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]
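
    A minimal usage sketch (assuming a tf.Session named sess; the names batches and inputs_placeholder stand in for your own input pipeline): run the assign ops together with the forward pass so the final state carries over to the next batch:

    sess.run(tf.global_variables_initializer())
    for batch in batches:
        # Evaluating self.update_ops persists c/h for the next run
        sess.run([self.output_encoder, self.update_ops],
                 feed_dict={inputs_placeholder: batch})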
    

    Alternatively, you can use a replacement for init_encoder that resets the states when step == 0 (you need to feed self.step_tf through session.run() as a placeholder):

    # Defaults to -1, so the saved state is kept unless step 0 is fed
    self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step")
    
    # At step 0, start from zeros; otherwise reuse the saved state
    self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
      true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
      false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))
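
    With this variant the reset happens implicitly: feed 0 as the step on the first batch of a new sequence, and any other value (or the -1 default) to keep the saved state. A sketch, with sess and inputs_placeholder assumed as above:

    for step, batch in enumerate(batches):
        # step == 0 re-initializes the state to zeros via the tf.cond above
        sess.run([self.output_encoder, self.update_ops],
                 feed_dict={self.step_tf: step, inputs_placeholder: batch})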
    
  • 2021-01-14 22:45

    Thanks to this answer to another question, I was able to find a way to get complete control over whether (and when) the internal state of the RNN should be reset to zeros.

    First, you need to define some variables to store the state of the RNN; this way you will have control over it:

    with tf.variable_scope('Hidden_state'):
        state_variables = []
        for state_c, state_h in cell.zero_state(self.batch_size, tf.float32):
            state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
                tf.Variable(state_c, trainable=False),
                tf.Variable(state_h, trainable=False)))
        # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
        rnn_tuple_state = tuple(state_variables)
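
    For reference, the loop above assumes that cell is a stacked cell whose zero_state returns one LSTMStateTuple per layer. A minimal sketch of such a cell (num_units and num_layers are assumed names):

    # A stacked LSTM: zero_state(batch_size, tf.float32) yields a tuple of
    # LSTMStateTuple(c, h) pairs, one per layer, as unpacked in the loop above
    cells = [tf.nn.rnn_cell.LSTMCell(num_units) for _ in range(num_layers)]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells)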
    

    Note that this version defines the variables used by the LSTM directly; this is much better than the version in my question, because you don't have to unstack and rebuild the tuple, which adds ops to the graph that you cannot run explicitly.

    Second, build the RNN and retrieve the final state:

    # Build the RNN
    with tf.name_scope('LSTM'):
        rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                                   sequence_length=input_seq_lengths,
                                                   initial_state=rnn_tuple_state,
                                                   time_major=True)
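
    Note that with time_major=True, dynamic_rnn expects rnn_inputs shaped [max_time, batch_size, depth]. A sketch of matching placeholders (the names and feature_size are assumptions):

    # Placeholders matching the time-major layout used above
    rnn_inputs = tf.placeholder(tf.float32, [None, None, feature_size],
                                name="rnn_inputs")  # [time, batch, features]
    input_seq_lengths = tf.placeholder(tf.int32, [None], name="input_seq_lengths")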
    

    So now you have the new internal state of the RNN. You can define two ops to manage it.

    The first one updates the variables for the next batch, so that in the next batch the "initial_state" of the RNN is fed with the final state of the previous batch:

    # Define an op to keep the hidden state between batches
    update_ops = []
    for state_variable, new_state in zip(rnn_tuple_state, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    rnn_keep_state_op = tf.tuple(update_ops)
    

    You should add this op to your session run whenever you want to process a batch and keep the internal state.
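
    For example (a sketch; sess, train_op, and the placeholder feeds are assumptions), group it with your training op so it runs on every batch:

    # Running the keep-state op with the training step means the next batch
    # will start from this batch's final state
    sess.run([train_op, rnn_keep_state_op],
             feed_dict={rnn_inputs: batch_x, input_seq_lengths: batch_lens})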

    Beware: if you run batch 1 with this op called, batch 2 will start with batch 1's final state; but if you don't call it again when running batch 2, then batch 3 will also start with batch 1's final state. My advice is to add this op every time you run the RNN.

    The second op will be used to reset the internal state of the RNN to zeros:

    # Define an op to reset the hidden state to zeros
    update_ops = []
    for state_variable in rnn_tuple_state:
        # Reset the state variables on this layer to zeros
        update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])),
                           state_variable[1].assign(tf.zeros_like(state_variable[1]))])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    rnn_state_zero_op = tf.tuple(update_ops)
    

    You can call this op whenever you want to reset the internal state.
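
    For instance (a sketch, reusing the assumed names above), reset at the start of each independent sequence, e.g. each new file or epoch:

    for epoch in range(num_epochs):
        sess.run(rnn_state_zero_op)  # start the epoch from a zero state
        for batch_x, batch_lens in get_batches():  # assumed data iterator
            sess.run([train_op, rnn_keep_state_op],
                     feed_dict={rnn_inputs: batch_x,
                                input_seq_lengths: batch_lens})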
