Tensorflow RNN-LSTM - reset hidden state

佛祖请我去吃肉 2021-01-14 22:09

I'm building a stateful LSTM used for language recognition. Being stateful, I can train the network with smaller files, and a new batch will be like a next sentence in a di…

2 Answers
  •  北海茫月
    2021-01-14 22:36

    A simplified version of AMairesse's post for one LSTM layer:

    # Non-trainable variables that persist the LSTM state between session runs;
    # batch size is 1 and units[-1] is the size of the LSTM layer.
    zero_state = tf.zeros(shape=[1, units[-1]])
    self.c_state = tf.Variable(zero_state, trainable=False)
    self.h_state = tf.Variable(zero_state, trainable=False)
    self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)
    
    self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(
        cell_encoder, layer, initial_state=self.init_encoder)
    
    # After each run, copy the final state back into the variables so the next
    # batch starts where this one ended (run these ops in session.run()).
    self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
    self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]
    

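    A minimal usage sketch of the snippet above (assuming a train_op, an inputs placeholder, and a batches iterable from the surrounding model, none of which are shown in this answer): running self.update_ops alongside the training op is what makes the state carry over from one batch to the next.

    # Hypothetical training loop; train_op, inputs and batches are assumptions
    # taken from the surrounding model, not defined in the snippet above.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for batch in batches:
            # Fetching self.update_ops saves the final LSTM state for the next batch.
            sess.run([train_op, self.update_ops], feed_dict={inputs: batch})
    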
    Or you can use a replacement for init_encoder that resets the states when step == 0 (you need to feed self.step_tf through session.run() as a placeholder):

    # Current step as a placeholder; it defaults to -1, so the saved state is
    # used unless you explicitly feed 0.
    self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64),
                                               shape=[], name="step")
    
    # At step 0 start from zeros, otherwise start from the saved state.
    self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
      true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
      false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))
    

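    A minimal feeding sketch for this variant (again assuming train_op, inputs, a sess session and a per-file list of batches from the surrounding model): feeding step 0 for the first batch of a file zeroes the state, while later batches reuse the saved state.

    # Hypothetical loop; file_batches, train_op and inputs are assumptions.
    for step, batch in enumerate(file_batches):
        # step == 0 resets the state via tf.cond; update_ops then save the new state.
        sess.run([train_op, self.update_ops],
                 feed_dict={inputs: batch, self.step_tf: step})
    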