Question
I have been using "dynamic_rnn" to build a model.
The model is based on a signal of 80 time periods, and I want to zero the "initial_state" before each run, so I set up the following code fragment to accomplish this:
state = cell_L1.zero_state(self.BatchSize,Xinputs.dtype)
outputs, outState = rnn.dynamic_rnn(cell_L1,Xinputs,initial_state=state, dtype=tf.float32)
This works great for training. The problem arises at inference, where my BatchSize = 1: I get an error because the RNN "state" no longer matches the shape of the new Xinputs. So I figured I need to derive "self.BatchSize" from the input batch size rather than hard-code it. I tried many different approaches, and none of them worked. I would rather not pass a bunch of zeros through the feed_dict, since the state is a constant that depends only on the batch size.
Here are some of my attempts. They all generally fail since the input size is unknown upon building the graph:
state = cell_L1.zero_state(Xinputs.get_shape()[0],Xinputs.dtype)
.....
state = tf.zeros([Xinputs.get_shape()[0], self.state_size], Xinputs.dtype, name="RnnInitializer")
I tried another approach, thinking the initializer might not be called until run time, but it still failed at graph build:
init = lambda shape, dtype: np.zeros(*shape)
state = tf.get_variable("state", shape=[Xinputs.get_shape()[0], self.state_size],initializer=init)
Is there a way to create this constant initial state dynamically, or do I need to reset it through the feed_dict in the tensor-serving code? Is there a clever way to do this only once within the graph, maybe with a tf.Variable.assign?
Answer 1:
The solution came down to how the "batch_size" is obtained, so that the value is not hard-coded.
This was the approach from the given example, which does not work:
Xinputs = tf.placeholder(tf.int32, (None, self.sequence_size, self.num_params), name="input")
state = cell_L1.zero_state(Xinputs.get_shape()[0],Xinputs.dtype)
The problem is the use of "get_shape()[0]". This returns the static shape of the tensor and takes the batch_size value at [0]. The documentation isn't that clear, but the static shape is fixed when the graph is constructed, so with a batch dimension of None there is no concrete value to pass to "zero_state", and nothing is updated when the graph is later loaded for inference.
Using the "tf.shape()" function does the trick. It returns not the static shape but a tensor holding the shape, which is evaluated at run time. Using this code fragment solved the problem of training with a batch of 128 and then loading the graph into TensorFlow Serving for inference on a batch of just 1.
Xinputs = tf.placeholder(tf.int32, (None, self.sequence_size, self.num_params), name="input")
batch_size = tf.shape(Xinputs)[0]
state = self.cell_L1.zero_state(batch_size,Xinputs.dtype)
The TensorFlow FAQ describes this approach under 'How do I build a graph that works with variable batch sizes?': https://www.tensorflow.org/resources/faq
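The difference between the static and the dynamic batch dimension can be checked in isolation. The following is a minimal sketch, not the original model: it assumes a TensorFlow version that provides `tf.compat.v1` (1.13 or later, including 2.x), the sizes `SEQUENCE_SIZE`, `NUM_PARAMS` and `STATE_SIZE` and the helper `state_shape_for` are made up for illustration, and a plain `tf.zeros` tensor stands in for the state that `zero_state` would build.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API

# Illustrative sizes, not taken from the original model.
SEQUENCE_SIZE, NUM_PARAMS, STATE_SIZE = 80, 3, 16

graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, (None, SEQUENCE_SIZE, NUM_PARAMS), name="input")

    # Static shape: fixed when the graph is built; the batch dimension is unknown.
    static_batch = x.shape.as_list()[0]

    # Dynamic shape: a scalar int32 tensor, evaluated on every session run.
    batch_size = tf.shape(x)[0]

    # Zero tensor whose first dimension follows whatever batch is fed.
    zero_state = tf.zeros(tf.stack([batch_size, STATE_SIZE]), dtype=x.dtype)


def state_shape_for(batch):
    """Feed a batch of the given size and return the shape of the zero state."""
    feed = np.zeros((batch, SEQUENCE_SIZE, NUM_PARAMS), dtype=np.float32)
    with tf1.Session(graph=graph) as sess:
        return sess.run(zero_state, feed_dict={x: feed}).shape


print(static_batch)          # None: nothing concrete is known at build time
print(state_shape_for(128))  # training-sized batch
print(state_shape_for(1))    # inference-sized batch, same graph
```

The same graph serves both batch sizes because the first dimension of the zero tensor is computed from the fed input on each run rather than baked in at construction.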
Source: https://stackoverflow.com/questions/41668786/how-do-you-create-a-dynamic-rnn-with-dynamic-zero-state-fails-with-inference