I'm building a stateful LSTM for language recognition. Because it is stateful, I can train the network on smaller files, and each new batch will be like the next sentence in a di
A simplified version of AMairesse's post, for a single LSTM layer:
import tensorflow as tf  # TensorFlow 1.x API
# Non-trainable variables that hold the LSTM state between session.run() calls (batch size 1)
zero_state = tf.zeros(shape=[1, units[-1]])
self.c_state = tf.Variable(zero_state, trainable=False)
self.h_state = tf.Variable(zero_state, trainable=False)
# Start each run from the stored state instead of from zeros
self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)
self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder)
# Ops that save the final state into the variables; run them together with each batch
self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]
Alternatively, define init_encoder as follows to reset the states when step == 0 (feed self.step_tf through feed_dict in session.run(); it defaults to -1, so the saved states are used whenever it is not fed):
self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step")
self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
                            true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
                            false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))
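To make the save-and-reset logic easier to follow, here is a framework-free sketch of the same idea in plain Python. It is not TensorFlow code: the class StatefulToyRNN, its weights, and the tanh update are hypothetical stand-ins for the LSTM cell, and only a single hidden value is tracked instead of the (c, h) pair. The step argument plays the role of self.step_tf, with the same default of -1.

```python
import math


class StatefulToyRNN:
    """Toy recurrent cell that keeps its hidden state between batches."""

    def __init__(self, w_in=0.5, w_rec=0.9):
        self.w_in = w_in      # input weight (arbitrary illustrative value)
        self.w_rec = w_rec    # recurrent weight (arbitrary illustrative value)
        self.h_state = 0.0    # stands in for the state variables above

    def run_batch(self, inputs, step=-1):
        # Mirrors the tf.cond: start from zero when step == 0,
        # otherwise from the state saved by the previous batch.
        h = 0.0 if step == 0 else self.h_state
        outputs = []
        for x in inputs:
            h = math.tanh(self.w_in * x + self.w_rec * h)
            outputs.append(h)
        # Mirrors the assign ops: save the final state for the next batch.
        self.h_state = h
        return outputs


# One long sequence processed in a single batch.
full = StatefulToyRNN().run_batch([1.0, -0.5, 0.25, 2.0], step=0)

# The same sequence split into two consecutive batches.
rnn = StatefulToyRNN()
first = rnn.run_batch([1.0, -0.5], step=0)
second = rnn.run_batch([0.25, 2.0], step=1)

# Because the state is carried over, the split run matches the long run.
assert first + second == full
```

In the TensorFlow version, the equivalent of run_batch is a session.run() call that includes the ops in self.update_ops and feeds self.step_tf, so that the first batch of each new file starts from zeros and later batches continue from the saved state.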