I trained the CIFAR-10 example model from TensorFlow's repository with batch_size 128 and it worked fine. Then I froze the graph and managed to run it with C++ just like they do it in t
It sounds like the problem is that TensorFlow is "baking in" the batch size to other tensors in the graph (e.g. if the graph contains tf.shape(t) for some tensor t whose shape depends on the batch size, the batch size might be stored in the graph as a constant). The solution is to change your program slightly so that tf.train.batch() returns tensors with a variable batch size.
The tf.train.batch() method accepts a tf.Tensor for the batch_size argument. Perhaps the simplest way to modify your program for variable-sized batches would be to define a placeholder for the batch size:
# Define a scalar tensor for the batch size, so that you can alter it at
# Session.run()-time.
batch_size_tensor = tf.placeholder(tf.int32, shape=[])
input_tensors = tf.train.batch(..., batch_size=batch_size_tensor, ...)
This would prevent the batch size from being baked into your GraphDef, so you should be able to feed values of any batch size in C++. However, this modification would require you to feed a value for the batch size on every step, which is slightly tedious.
Assuming that you always want to train with batch size 128, but retain the flexibility to change the batch size later, you could use a tf.placeholder_with_default() to specify that the batch size should be 128 when you don't feed an alternative value:
# Define a scalar tensor for the batch size, so that you can alter it at
# Session.run()-time.
batch_size_tensor = tf.placeholder_with_default(128, shape=[])
input_tensors = tf.train.batch(..., batch_size=batch_size_tensor, ...)
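To see the default and the override side by side, here is a minimal sketch. It assumes graph mode through tensorflow.compat.v1, and it replaces tf.train.batch() with a tf.zeros() stand-in so that it runs without an input pipeline. The 24x24x3 image shape and the names images and images_shape are illustrative assumptions, not part of the original program.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF 1.x-style graphs and sessions

graph = tf.Graph()
with graph.as_default():
    # Batch size defaults to 128 but can be overridden through feed_dict.
    batch_size_tensor = tf.placeholder_with_default(128, shape=[])
    # Stand-in for tf.train.batch(): an op whose leading dimension follows
    # batch_size_tensor. The 24x24x3 image shape is an assumption.
    images = tf.zeros(tf.stack([batch_size_tensor, 24, 24, 3]))
    images_shape = tf.shape(images)

with tf.Session(graph=graph) as sess:
    # Nothing fed: the default of 128 is used.
    default_shape = sess.run(images_shape).tolist()
    # Feeding a value replaces the default for this run only.
    single_shape = sess.run(
        images_shape, feed_dict={batch_size_tensor: 1}).tolist()

print(default_shape)
print(single_shape)
```

If this behaves as expected, the first run reports a leading dimension of 128 and the second a leading dimension of 1, without rebuilding the graph.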
Is there a reason you need a fixed batch size in the graph?
I think a good approach is to build the graph with a variable batch size by putting None as the first dimension of the input shape. During training, you can then pass the batch size flag to your data provider, so that it feeds the desired amount of data in each iteration.
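As a rough illustration of the None-first-dimension idea, the sketch below builds a tiny graph and feeds it batches of two different sizes. It assumes graph mode through tensorflow.compat.v1; the placeholder name images, the 24x24x3 shape, and the toy scores op are hypothetical stand-ins for the real model.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF 1.x-style graphs and sessions

graph = tf.Graph()
with graph.as_default():
    # None leaves the batch dimension unspecified in the graph.
    images = tf.placeholder(tf.float32, shape=[None, 24, 24, 3], name="images")
    # Toy "model" (an assumption): flatten each image and sum its values.
    flat = tf.reshape(images, [-1, 24 * 24 * 3])
    scores = tf.reduce_sum(flat, axis=1, name="scores")

with tf.Session(graph=graph) as sess:
    train_batch = np.ones((128, 24, 24, 3), dtype=np.float32)
    single_image = np.ones((1, 24, 24, 3), dtype=np.float32)
    # The same graph accepts both batch sizes.
    train_scores = sess.run(scores, feed_dict={images: train_batch})
    single_scores = sess.run(scores, feed_dict={images: single_image})

print(images.shape.as_list(), train_scores.shape, single_scores.shape)
```

Using -1 in tf.reshape keeps the reshape independent of the batch size as well; a hard-coded 128 there would reintroduce the fixed size.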
After the model is trained, you can export the graph using tf.train.Saver(), which exports the metagraph. To do inference, you can load the exported files and evaluate with any number of examples, even just one.
Note, this is different from the frozen graph.
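The export-and-reload route could look roughly like the sketch below. It assumes graph mode through tensorflow.compat.v1 and uses a checkpoint plus its .meta file rather than a frozen GraphDef. The tensor names x, w, and y, the toy matmul model, and the temporary directory are all hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF 1.x-style graphs and sessions

checkpoint_prefix = os.path.join(tempfile.mkdtemp(), "model")

# "Training" graph: variable batch size, one variable, then export.
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
    w = tf.get_variable("w", shape=[4, 2], initializer=tf.ones_initializer())
    y = tf.matmul(x, w, name="y")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes the checkpoint files and checkpoint_prefix + ".meta".
        saver.save(sess, checkpoint_prefix)

# Inference graph: rebuild from the metagraph and restore the weights.
infer_graph = tf.Graph()
with infer_graph.as_default():
    with tf.Session() as sess:
        restorer = tf.train.import_meta_graph(checkpoint_prefix + ".meta")
        restorer.restore(sess, checkpoint_prefix)
        x_in = infer_graph.get_tensor_by_name("x:0")
        y_out = infer_graph.get_tensor_by_name("y:0")
        one = sess.run(y_out, feed_dict={x_in: np.ones((1, 4), np.float32)})
        many = sess.run(y_out, feed_dict={x_in: np.ones((5, 4), np.float32)})

print(one.shape, many.shape)
```

Because the placeholder was built with None as its first dimension, the restored graph should accept a single example as readily as a batch of five.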