TensorFlow: Remember LSTM state for next batch (stateful LSTM)

遇见更好的自我 2020-11-28 04:48

Given a trained LSTM model I want to perform inference for single timesteps, i.e. seq_length = 1 in the example below. After each timestep the internal LSTM (memory and hidden) states need to be remembered for the next batch.

2 Answers
  • 2020-11-28 05:27

    Tensorflow, best way to save state in RNNs? was actually my original question. The code below is how I use the state tuples.

    with tf.variable_scope('decoder') as scope:
        rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
            tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
            tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True)
        ], state_is_tuple=True)

        # Zero initial state: a list of [c, h] pairs, one per layer.
        state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer]
                 for sz_outer in rnn_cell.state_size]

        for t in range(TIME_STEPS):
            if t:
                # Feed the ground truth while training, the previous prediction otherwise.
                last = y_[t - 1] if TRAINING else y[t - 1]
            else:
                last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))

            y[t] = tf.concat([y[t], last], axis=1)
            y[t], state = rnn_cell(y[t], state)

            scope.reuse_variables()
    

    Rather than using tf.nn.rnn_cell.LSTMStateTuple I just create a list of lists, which works fine. In this example I am not saving the state. However, you could easily store the state in variables and use assign to save the values between runs.
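    The "variables + assign" idea above can be illustrated without TensorFlow at all. Below is a framework-free numpy sketch (all names, shapes, and the tiny LSTM step are illustrative, not the answer's actual model): persistent buffers play the role of tf.Variable, and writing the new (c, h) back into them after each step plays the role of assign.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c, h, W, b):
    """One LSTM step: gates computed from [x, h]; returns the new (c, h)."""
    z = np.concatenate([x, h], axis=1) @ W + b      # (batch, 4 * units)
    i, f, g, o = np.split(z, 4, axis=1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return c_new, h_new

batch, in_dim, units = 2, 3, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(in_dim + units, 4 * units)) * 0.1
b = np.zeros(4 * units)

# Persistent state buffers -- the numpy analogue of tf.Variable.
state = {'c': np.zeros((batch, units)), 'h': np.zeros((batch, units))}

for _ in range(5):                          # five single-timestep "batches"
    x = rng.normal(size=(batch, in_dim))
    c, h = lstm_step(x, state['c'], state['h'], W, b)
    state['c'], state['h'] = c, h           # the "assign": save for the next step
```

    Because the state survives in `state` between iterations, each call continues from where the previous one left off, which is exactly the stateful behaviour the question asks for.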

  • 2020-11-28 05:51

    I found out it was easiest to save the whole state for all layers in a placeholder.

    # Shape: (num_layers, 2, batch_size, state_size); axis 1 holds (c, h) per layer.
    init_state = np.zeros((num_layers, 2, batch_size, state_size))

    ...

    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    

    Then unpack it and create a tuple of LSTMStateTuples before using the native TensorFlow RNN API.

    # tf.unpack was renamed tf.unstack in TensorFlow 1.0.
    l = tf.unstack(state_placeholder, axis=0)
    rnn_tuple_state = tuple(
        tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
        for idx in range(num_layers)
    )
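    To make the packed layout concrete, here is a framework-free numpy sketch of the same unpacking (the dimension values are illustrative): the (num_layers, 2, batch_size, state_size) array is split per layer, and each layer's slice yields its (c, h) pair, mirroring tf.unstack followed by LSTMStateTuple.

```python
import numpy as np

num_layers, batch_size, state_size = 3, 2, 4

# Packed layout from the answer: axis 0 indexes layers, axis 1 holds (c, h).
init_state = np.zeros((num_layers, 2, batch_size, state_size))

# Analogue of tf.unstack(state_placeholder, axis=0) ...
l = [init_state[idx] for idx in range(num_layers)]

# ... followed by building an (c, h) tuple per layer, like LSTMStateTuple.
rnn_tuple_state = tuple((l[idx][0], l[idx][1]) for idx in range(num_layers))
```

    Stacking the per-layer (c, h) pairs back together reproduces the original packed array, so the round trip between the placeholder layout and the tuple layout is lossless.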
    

    The tuple is then passed as the initial state to the RNN API:

    # Build a fresh cell per layer; reusing one cell object ([cell] * num_layers)
    # would make every layer share the same weights.
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
         for _ in range(num_layers)],
        state_is_tuple=True)
    outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)
    

    The evaluated state can then be fed back in through the placeholder for the next batch.
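    The effect of feeding the state back in is that several short batches behave like one long sequence. A framework-free numpy sketch (a minimal tanh RNN standing in for dynamic_rnn; all names and sizes are illustrative) shows the carry-over: running two half-sequences while passing the final state along gives the same final state as running the full sequence in one pass.

```python
import numpy as np

def run_rnn(x_seq, h, W, U):
    """Minimal tanh RNN over a (time, batch, in_dim) sequence; returns the final state."""
    for x in x_seq:
        h = np.tanh(x @ W + h @ U)
    return h

batch, in_dim, hidden = 2, 3, 4
rng = np.random.default_rng(1)
W = rng.normal(size=(in_dim, hidden)) * 0.1
U = rng.normal(size=(hidden, hidden)) * 0.1
seq = rng.normal(size=(8, batch, in_dim))

# Stateful processing: the final state of batch 1 becomes the initial
# state of batch 2 -- the placeholder feed described above.
h = np.zeros((batch, hidden))
h = run_rnn(seq[:4], h, W, U)     # first batch
h = run_rnn(seq[4:], h, W, U)     # second batch, state carried over

# One pass over the whole sequence reaches the same final state.
h_full = run_rnn(seq, np.zeros((batch, hidden)), W, U)
```

    If instead the state were reset to zeros between batches, the two results would diverge, which is exactly the bug statefulness is meant to avoid.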
