I trained a ResNet with tf.estimator, and the model was saved during training. The saved files consist of .data, .index and .meta files.
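If you have the model exported as a SavedModel (a saved_model.pb or saved_model.pbtxt file plus its variables), inference is easy: the predictor module (tf.contrib.predictor) can run it directly. See the tf.contrib.predictor documentation for more information. For image data it will look something like the example below. Hope this helps!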
Example code:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf  # TF 1.x: tf.contrib was removed in TF 2.x

def extract_data(index=0, filepath='data/cifar-10-batches-bin/data_batch_5.bin'):
    """Read one record from a CIFAR-10 binary batch file.

    Each record is 1 label byte followed by 32*32*3 image bytes (channel-major).
    """
    label_bytes_length = 1
    image_bytes_length = (32 ** 2) * 3
    record_bytes_length = label_bytes_length + image_bytes_length
    with open(filepath, mode='rb') as bytestream:
        # Jump to the start of the requested record
        bytestream.seek(record_bytes_length * index, 0)
        label_bytes = bytestream.read(label_bytes_length)
        image_bytes = bytestream.read(image_bytes_length)
    label = np.frombuffer(label_bytes, dtype=np.uint8)
    image = np.frombuffer(image_bytes, dtype=np.uint8)
    # Stored as [channels, height, width]; convert to [height, width, channels]
    image = np.reshape(image, [3, 32, 32])
    image = np.transpose(image, [1, 2, 0])
    image = image.astype(np.float32)
    return {
        'image': image,
        'label': label,
    }
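# Hypothetical path: point this at your exported SavedModel directory
saved_model_dir = 'path/to/exported_model'
# The signature key and the input key ('images') depend on how the model was exported
predictor_fn = tf.contrib.predictor.from_saved_model(
    export_dir=saved_model_dir, signature_def_key='predictions')

N = 1000
labels = []
images = []
for i in range(N):
    result = extract_data(i)
    images.append(result['image'])
    labels.append(result['label'][0])

# Stack the list into a single [N, 32, 32, 3] float32 batch before feeding it
output = predictor_fn({'images': np.stack(images)})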
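Every CIFAR-10 record has the same fixed size (1 label byte followed by 3072 image bytes), so the N records do not have to be read one seek at a time. The whole file can be loaded with a single np.fromfile call and then reshaped. Below is a minimal numpy-only sketch that assumes the standard CIFAR-10 binary layout extract_data relies on. extract_batch is a hypothetical helper name, and the synthetic file in the demo just stands in for data_batch_5.bin.

```python
import os
import tempfile

import numpy as np

# CIFAR-10 binary layout: each record is 1 label byte followed by
# 32*32*3 = 3072 image bytes stored channel-major (all R, then G, then B).
LABEL_BYTES = 1
IMAGE_BYTES = 32 * 32 * 3
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES  # 3073


def extract_batch(filepath, n=None):
    """Hypothetical helper: read the first n records (or all of them) in one call.

    Returns float32 images of shape [n, 32, 32, 3] and uint8 labels of shape [n].
    """
    raw = np.fromfile(filepath, dtype=np.uint8)
    records = raw.reshape(-1, RECORD_BYTES)
    if n is not None:
        records = records[:n]
    labels = records[:, 0]
    # [n, 3, 32, 32] -> [n, 32, 32, 3], the same transpose extract_data does per image
    images = records[:, LABEL_BYTES:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images.astype(np.float32), labels


if __name__ == '__main__':
    # Build a tiny synthetic file with 4 records so the sketch runs without the dataset
    rng = np.random.default_rng(0)
    fake = rng.integers(0, 256, size=(4, RECORD_BYTES)).astype(np.uint8)
    fake[:, 0] = [3, 1, 4, 1]  # labels
    path = os.path.join(tempfile.mkdtemp(), 'fake_batch.bin')
    fake.tofile(path)

    images, labels = extract_batch(path, n=3)
    print(images.shape, labels.tolist())  # prints: (3, 32, 32, 3) [3, 1, 4]
```

The returned images array is already a single [N, 32, 32, 3] batch, so it can be passed to the predictor as is instead of being built from a Python list.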
Note: This answer will evolve as more information becomes available. I'm not sure this is the most appropriate way to do it, but it feels better than using just comments. Feel free to leave a comment on the answer if this is inappropriate.
I don't have much experience with the import_meta_graph method, but if sess.run(logits) runs without complaining, I think the meta graph also contains your input pipeline.
A quick test I just made confirms that the pipeline is indeed restored too when you load the meta graph. This means you're not actually passing anything in via feed_dict, because the input is taken from the Dataset-based pipeline that was in place when the checkpoint was saved. From my research, I can't find a way to provide a different input function to the graph.
Your code looks right to me, so my suspicion is that the checkpoint file being loaded is somehow wrong. I've asked for some clarifications in a comment and will update this part as soon as that information is available.
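One thing worth checking here is which checkpoint actually gets loaded. tf.estimator keeps several checkpoints in model_dir, named model.ckpt-<global_step> followed by .index, .meta and .data-XXXXX-of-YYYYY, so an older one can easily be picked up by mistake. tf.train.import_meta_graph expects the path of the .meta file, while Saver.restore expects the bare prefix without any extension. The sketch below is a hypothetical standard-library helper (latest_checkpoint_prefix is not a TensorFlow function) that assumes the default model.ckpt basename and compares global steps numerically. In practice, tf.train.latest_checkpoint(model_dir), which reads the checkpoint state file, is the usual way to get the latest prefix.

```python
import os
import re
import tempfile

# Assumes the default tf.estimator basename: model.ckpt-<global_step>.index / .meta / .data-...
# Restoring needs the common prefix (no extension); import_meta_graph needs the .meta file.
_CKPT_RE = re.compile(r'^(model\.ckpt-(\d+))\.index$')


def latest_checkpoint_prefix(model_dir):
    """Hypothetical helper: return (prefix, meta_path) for the highest global step, or None."""
    best = None
    for name in os.listdir(model_dir):
        m = _CKPT_RE.match(name)
        if m is None:
            continue
        step = int(m.group(2))  # numeric comparison, so step 10 beats step 9
        if best is None or step > best[0]:
            best = (step, m.group(1))
    if best is None:
        return None
    prefix = os.path.join(model_dir, best[1])
    return prefix, prefix + '.meta'


if __name__ == '__main__':
    # Fake a model_dir with three checkpoints written out of order
    d = tempfile.mkdtemp()
    for step in (500, 1000, 200):
        for ext in ('.index', '.meta', '.data-00000-of-00001'):
            open(os.path.join(d, 'model.ckpt-%d%s' % (step, ext)), 'w').close()
    print(latest_checkpoint_prefix(d))
```

The printed prefix is what would go to Saver.restore, and the .meta path is what would go to import_meta_graph.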