TensorFlow tf.nn.rnn function … how to use the results of your training to do a single forward-pass through the RNN

荒凉一梦 submitted on 2019-12-12 01:42:40

Question


I'm having a tough time using the 'initial_state' argument of the tf.nn.rnn function.

    val, _ = tf.nn.rnn(cell1, newBatch, initial_state=stateP, dtype=tf.float32)

    newBatch.shape => (1, 1, 11)
    stateP.shape => (2, 2, 1, 11)

I've finished training my LSTM neural net and now I want to use the trained model. How do I do this? I know that tf.nn.rnn() returns the final state, but I don't know how to feed it back in.

FYI, stateP.shape => (2, 2, 1, 11). Maybe that's because I used stacked LSTM cells?
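That shape is consistent with two stacked LSTM layers of 11 units each, a batch of 1, and one (c, h) pair per layer. The NumPy sketch below assumes exactly that layout (axis 0 = layer, axis 1 = c or h, axis 2 = batch, axis 3 = units); the random array is only a placeholder for the real evaluated state, and the sizes are assumptions taken from the shapes quoted in the question.

```python
import numpy as np

# Assumed sizes, chosen to match the shapes quoted in the question
num_layers, batch_size, num_units = 2, 1, 11

# Placeholder for the evaluated final state of a 2-layer stacked LSTM:
# axis 0 = layer, axis 1 = (c, h), axis 2 = batch, axis 3 = units
stateP = np.random.rand(num_layers, 2, batch_size, num_units).astype(np.float32)

for layer in range(num_layers):
    c, h = stateP[layer][0], stateP[layer][1]  # cell state and output for this layer
    print(f"layer {layer}: c {c.shape}, h {h.shape}")
# layer 0: c (1, 11), h (1, 11)
# layer 1: c (1, 11), h (1, 11)
```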

I've also tried:

    val, _ = tf.nn.dynamic_rnn(stacked_lstm, newBatch, initial_state=stateP, dtype=tf.float32)

but I get the error "AttributeError: 'NoneType' object has no attribute 'op'".
I'm fairly sure the 'NoneType' object in question is the stateP tuple I passed, but I'm not sure what to do about it.

EDIT: I finally got this running by using:

    init_state = cell.zero_state(batch_size, tf.float32)

I used that to determine the exact structure I need to pass as the 'initial_state' argument. In my case it was a tuple of 4 tensors, each with shape (1, 11), so I built it like this:

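    import tensorflow as tf

    # stateP: evaluated final state, shape (2, 2, 1, 11); assuming it is
    # indexed as [layer][c or h], each slice below has shape (1, 11)
    stateP0 = tf.convert_to_tensor(stateP[0][0])  # layer 0, c
    stateP1 = tf.convert_to_tensor(stateP[0][1])  # layer 0, h
    stateP2 = tf.convert_to_tensor(stateP[1][0])  # layer 1, c
    stateP3 = tf.convert_to_tensor(stateP[1][1])  # layer 1, h
    newStateP = (stateP0, stateP1, stateP2, stateP3)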
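The four hand-written conversions above only cover two layers. A minimal NumPy sketch of the same flattening for any number of layers, assuming the (num_layers, 2, batch, units) layout; flatten_state is a hypothetical helper, and in the TensorFlow code each element could still be wrapped with tf.convert_to_tensor as above.

```python
import numpy as np

def flatten_state(stateP):
    """Hypothetical helper: turn a (num_layers, 2, batch, units) array into a
    flat tuple (c0, h0, c1, h1, ...), the same order as stateP0..stateP3."""
    # Iterating over axis 0 yields one layer at a time; iterating over that
    # layer yields its c slice and then its h slice.
    return tuple(part for layer in stateP for part in layer)

stateP = np.zeros((2, 2, 1, 11), dtype=np.float32)
newStateP = flatten_state(stateP)
print(len(newStateP), newStateP[0].shape)  # 4 (1, 11)
```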

Alright! Now the tf.nn.dynamic_rnn() function is working, but it gives me different results every time I run it, so what's the point of passing in the initial state? I want to use the state I found during training, and I don't want it to change. I want to actually use the results of my training!


Answer 1:


You are probably relying on the deprecated (or soon-to-be-deprecated) behavior in which the LSTM state is not a tuple. In that case stateP holds the concatenation of c (the cell state) and h (the LSTM output from the final unrolled step), so you need to slice the state along dimension 1 to recover the actual state.
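A minimal NumPy sketch of that slicing for a single LSTM cell, assuming the non-tuple state is c followed by h along dimension 1 (shape (batch, 2 * num_units)); the sizes and the constant-filled c_true and h_true arrays are placeholders, and the same slicing applies to a TensorFlow tensor of that shape.

```python
import numpy as np

num_units = 11  # assumed cell size

# Placeholder for a non-tuple LSTM state: c and h concatenated along dimension 1
c_true = np.full((1, num_units), 0.5, dtype=np.float32)
h_true = np.full((1, num_units), -0.5, dtype=np.float32)
concat_state = np.concatenate([c_true, h_true], axis=1)  # shape (1, 22)

# Slice along dimension 1 into two equal halves: first c, then h
c, h = np.split(concat_state, 2, axis=1)
# Equivalent explicit slicing:
# c, h = concat_state[:, :num_units], concat_state[:, num_units:]
print(c.shape, h.shape)  # (1, 11) (1, 11)
```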

Alternatively, you can create your LSTM cell with state_is_tuple=True, which I would recommend. Then you can easily get at the final state (if you want to tinker with it) by indexing it, e.g. stateP[0], or you can simply pass the state tuple directly to rnn (or dynamic_rnn).
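With state_is_tuple=True, a stacked LSTM's state is a tuple holding one (c, h) pair per layer, so stateP[0] is the first layer's pair. A sketch in plain NumPy, assuming the (2, 2, 1, 11) layout from the question; the namedtuple is a hypothetical stand-in for TensorFlow's LSTMStateTuple (fields c and h), and to_state_tuple is a hypothetical helper. With real tensors the same nesting would be what you pass as initial_state.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for TensorFlow's LSTMStateTuple (fields c and h)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def to_state_tuple(stateP):
    """Hypothetical helper: rebuild ((c0, h0), (c1, h1), ...) from a
    (num_layers, 2, batch, units) array, one pair per layer."""
    return tuple(LSTMStateTuple(c=layer[0], h=layer[1]) for layer in stateP)

stateP = np.random.rand(2, 2, 1, 11).astype(np.float32)
state = to_state_tuple(stateP)

# state[0] is the first layer's (c, h) pair
print(len(state), state[0].c.shape, state[0].h.shape)  # 2 (1, 11) (1, 11)
```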

I can't say more than that because you haven't provided your initialization code, so anything further would be a guess.

If you still run into problems, edit your question to add more details and I will update the answer.



Source: https://stackoverflow.com/questions/38966346/tensorflow-tf-nn-rnn-function-how-to-use-the-results-of-your-training-to-do
