Question
I'm having a tough time using the 'initial_state' argument of the tf.nn.rnn function.
val, _ = tf.nn.rnn(cell1, newBatch, initial_state=stateP, dtype=tf.float32)
newBatch.shape => (1, 1, 11)
stateP.shape => (2, 2, 1, 11)
More generally: I've gone through the training for my LSTM network and now I want to use what it learned. How do I do this? I know that the tf.nn.rnn() function returns the state... but I don't know how to plug it back in.
FYI, stateP.shape => (2, 2, 1, 11)... maybe because I used stacked LSTM cells?
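For context, here is one possible setup consistent with those shapes. This is purely a guess (the original initialization code isn't shown), and every value other than the names cell1, stacked_lstm and newBatch is hypothetical:
import tensorflow as tf

num_units = 11   # guessed from the shapes above
batch_size = 1

# Two stacked LSTM cells: a (2, 2, 1, 11) state is consistent with
# 2 layers x (c, h) x batch size 1 x 11 units when state_is_tuple=True.
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([cell1] * 2, state_is_tuple=True)

# newBatch: one step of one sequence with 11 features -> shape (1, 1, 11)
newBatch = tf.placeholder(tf.float32, [batch_size, 1, num_units])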
I've also tried:
val, _ = tf.nn.dynamic_rnn(stacked_lstm, newBatch, initial_state=stateP, dtype=tf.float32)
but I get the error "AttributeError: 'NoneType' object has no attribute 'op'".
I'm pretty sure the 'NoneType' object it's complaining about is the stateP tuple I passed in, but I'm not sure what to do here.
EDIT: I finally got this running by using:
init_state = cell.zero_state(batch_size, tf.float32)
to determine the exact shape I needed to pass into the 'initial_state' argument. In my case it was a TUPLE of four tensors, each with shape (1, 11). I made it like this:
# Wrap the four saved (1, 11) state arrays (the two values of each stacked layer's state) as tensors
stateP0 = tf.convert_to_tensor(stateP[0][0])
stateP1 = tf.convert_to_tensor(stateP[0][1])
stateP2 = tf.convert_to_tensor(stateP[1][0])
stateP3 = tf.convert_to_tensor(stateP[1][1])
newStateP = stateP0, stateP1, stateP2, stateP3
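For comparison, a closely related sketch, reusing stateP, stacked_lstm and newBatch from above and assuming the cells were built with state_is_tuple=True: the structure initial_state normally expects for two stacked cells is a tuple of per-layer LSTMStateTuple(c, h) pairs rather than a flat 4-tuple, and it can be rebuilt from the same saved values:
import tensorflow as tf

# Assuming the usual ordering: stateP[layer][0] is c and stateP[layer][1] is h,
# each of shape (1, 11).
layer0 = tf.nn.rnn_cell.LSTMStateTuple(
    c=tf.convert_to_tensor(stateP[0][0], dtype=tf.float32),
    h=tf.convert_to_tensor(stateP[0][1], dtype=tf.float32))
layer1 = tf.nn.rnn_cell.LSTMStateTuple(
    c=tf.convert_to_tensor(stateP[1][0], dtype=tf.float32),
    h=tf.convert_to_tensor(stateP[1][1], dtype=tf.float32))

val, new_state = tf.nn.dynamic_rnn(stacked_lstm, newBatch,
                                   initial_state=(layer0, layer1),
                                   dtype=tf.float32)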
Alright! Now the tf.nn.dynamic_rnn() function runs, but it gives me different results every time I run it... so what's the point of passing in the initial state? I want to use the state I trained to find, and I don't want it to change. I want to actually use the results of my training!
Answer 1:
You are probably using the deprecated (or soon-to-be-deprecated) behavior. stateP in your case represents the concatenation of c (the cell state) and h (the output of the LSTM from the final step of unrolling). So you need to slice the state along dimension 1 to get the actual state.
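As a minimal sketch of that slicing, assuming a single cell with num_units units and a non-tuple state of shape (batch_size, 2 * num_units); the names here are hypothetical:
import tensorflow as tf

num_units = 11   # match your cell size
batch_size = 1

# Non-tuple LSTM state: c and h concatenated along dimension 1.
concat_state = tf.placeholder(tf.float32, [batch_size, 2 * num_units])

c = concat_state[:, :num_units]   # cell state
h = concat_state[:, num_units:]   # output from the last unrolled step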
Or, you can initialize your LSTM cell with state_is_tuple=True, which I would recommend, so that you can easily get the final state (if you want to tinker with it) by indexing it, e.g. stateP[0]. Or you can just pass the state tuple directly to rnn (or dynamic_rnn).
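A minimal sketch of that recommendation, assuming two stacked cells of 11 units and a TensorFlow version where tf.nn.rnn_cell and tf.nn.dynamic_rnn are available (all names are placeholders):
import tensorflow as tf

num_units = 11

# Build stacked LSTM cells with tuple states.
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([cell] * 2, state_is_tuple=True)

inputs = tf.placeholder(tf.float32, [1, 1, num_units])  # (batch, time, features)
init_state = stacked_lstm.zero_state(batch_size=1, dtype=tf.float32)

# The returned state is a tuple of LSTMStateTuple(c, h), one per layer;
# state[0] is the final state of the first layer.
outputs, state = tf.nn.dynamic_rnn(stacked_lstm, inputs,
                                   initial_state=init_state,
                                   dtype=tf.float32)
Because the returned state has exactly the structure initial_state expects, you can pass it (or a saved copy of it) straight back in as the initial state of the next call.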
I can't say anything beyond that because you have not provided your initialization code, so I would only be guessing. You can edit your question with more details if you still face problems, and I will update the answer.
Source: https://stackoverflow.com/questions/38966346/tensorflow-tf-nn-rnn-function-how-to-use-the-results-of-your-training-to-do