Restoring a model trained with tf.estimator and feeding input through feed_dict

Submitted by 风格不统一 on 2019-12-01 23:42:26

Note: This answer will evolve as more information becomes available. I'm not sure this is the most appropriate way to do it, but it feels better than using just comments. Feel free to drop a comment on the answer if this is inappropriate.

About your second attempt:

I don't have much experience with the import_meta_graph method, but if sess.run(logits) runs without complaining, the meta graph probably also contains your input pipeline.

A quick test I just made confirms that the pipeline is indeed restored when you load the metagraph. This means you're not actually passing anything in via feed_dict, because the input is taken from the Dataset-based pipeline that was in use when the checkpoint was saved. From my research, I can't find a way to provide a different input function to the restored graph.
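The effect can be reproduced with a self-contained sketch (the names `inputs`, `w`, and `logits` are illustrative, not from the question): save a tiny graph whose input is baked in, then restore it with tf.train.import_meta_graph and run `logits` without any feed_dict. It is written against the TF 1.x-style API via tf.compat.v1, which is an assumption about your installed TensorFlow version:

```python
import os
import tempfile
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x-style API (assumption: TF >= 1.14)

tf.disable_eager_execution()
ckpt_path = os.path.join(tempfile.mkdtemp(), 'model')

# Build and save a tiny graph whose input is baked in, standing in
# for the Dataset-based pipeline captured at checkpoint time.
with tf.Graph().as_default():
    inputs = tf.constant([[1.0, 2.0]], name='inputs')
    w = tf.get_variable('w', initializer=[[1.0], [1.0]])
    logits = tf.matmul(inputs, w, name='logits')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Restore the metagraph into a fresh graph: the input op comes back too,
# so running 'logits' needs no feed_dict at all.
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        restorer = tf.train.import_meta_graph(ckpt_path + '.meta')
        restorer.restore(sess, ckpt_path)
        logits_op = graph.get_tensor_by_name('logits:0')
        restored_logits = sess.run(logits_op)  # no feed_dict needed
        print(restored_logits)
```

Because the constant input op is part of the restored graph, anything you pass via feed_dict for a different placeholder is simply ignored, which matches the behaviour described above.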

About the first attempt:

Your code looks right to me, so my suspicion is that the checkpoint file being loaded is somehow wrong. I asked for some clarification in a comment and will update this part as soon as that info is available.

If you have the model as a .pb or .pbtxt file (a SavedModel), inference is easy: using the tf.contrib.predictor module, we can run an inference directly. See the TensorFlow SavedModel documentation for more information. For image data it will look similar to the example below. Hope this helps!

Example code:

import numpy as np
import tensorflow as tf

def extract_data(index=0, filepath='data/cifar-10-batches-bin/data_batch_5.bin'):
    # Each CIFAR-10 binary record is 1 label byte + 32*32*3 image bytes.
    bytestream = open(filepath, mode='rb')
    label_bytes_length = 1
    image_bytes_length = (32 ** 2) * 3
    record_bytes_length = label_bytes_length + image_bytes_length
    # Seek to the start of the requested record.
    bytestream.seek(record_bytes_length * index, 0)
    label_bytes = bytestream.read(label_bytes_length)
    image_bytes = bytestream.read(image_bytes_length)
    label = np.frombuffer(label_bytes, dtype=np.uint8)
    image = np.frombuffer(image_bytes, dtype=np.uint8)
    # Stored channel-major (CHW); reshape and transpose to HWC.
    image = np.reshape(image, [3, 32, 32])
    image = np.transpose(image, [1, 2, 0])
    image = image.astype(np.float32)
    result = {
        'image': image,
        'label': label,
    }
    bytestream.close()
    return result
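As a quick sanity check of that record layout (1 label byte followed by 3 × 32 × 32 image bytes, channel-major), the same parsing steps can be run against a synthetic one-record file instead of a real CIFAR-10 batch; the file path and the label/pixel values below are made up for the demonstration:

```python
import os
import tempfile
import numpy as np

# CIFAR-10 binary record layout: 1 label byte + 32*32*3 image bytes.
label_bytes_length = 1
image_bytes_length = (32 ** 2) * 3
record_bytes_length = label_bytes_length + image_bytes_length

# Write one synthetic record: label 7, every pixel byte set to 9.
path = os.path.join(tempfile.mkdtemp(), 'fake_batch.bin')
with open(path, 'wb') as f:
    f.write(bytes([7]) + bytes([9]) * image_bytes_length)

# Parse it back with the same steps as extract_data above.
with open(path, 'rb') as bytestream:
    bytestream.seek(record_bytes_length * 0, 0)  # record index 0
    label = np.frombuffer(bytestream.read(label_bytes_length), dtype=np.uint8)
    image = np.frombuffer(bytestream.read(image_bytes_length), dtype=np.uint8)

# Channel-major on disk, so reshape to CHW then transpose to HWC.
image = np.transpose(np.reshape(image, [3, 32, 32]), [1, 2, 0]).astype(np.float32)
print(label[0], image.shape)  # 7 (32, 32, 3)
```

If the reshape/transpose order were wrong, the shape check would still pass but the channels would be scrambled, so verifying a known pixel value is the more telling check.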


predictor_fn = tf.contrib.predictor.from_saved_model(
    export_dir=saved_model_dir, signature_def_key='predictions')
N = 1000
labels = []
images = []
for i in range(N):
    result = extract_data(i)
    images.append(result['image'])
    labels.append(result['label'][0])
output = predictor_fn(
    {
        'images': images,
    }
)
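Once `output` comes back, scoring it against the collected labels is a plain NumPy comparison. The output key name (`'predictions'`) and the array shape depend entirely on your SavedModel's signature, so the values below are a made-up stand-in for illustration:

```python
import numpy as np

# Hypothetical predictor output and ground-truth labels, for illustration only.
output = {'predictions': np.array([3, 1, 4, 1, 5])}
labels = [3, 1, 4, 2, 5]

predicted = np.asarray(output['predictions'])
accuracy = np.mean(predicted == np.asarray(labels))
print(accuracy)  # 0.8
```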