Restoring a model trained with tf.estimator and feeding input through feed_dict

既然无缘 2021-01-20 10:17

I trained a ResNet with tf.estimator; the model was saved during the training process. The saved files consist of .data, .index and .meta files.

2 Answers
  • 2021-01-20 11:03

    If you have the model as a .pb or .pbtxt (i.e. a SavedModel), inference is easy. Using the tf.contrib.predictor module, we can run inference. Check out here for more information. For image data it will look something like the example below. Hope this helps!

    Example code:

    import numpy as np
    import tensorflow as tf

    def extract_data(index=0, filepath='data/cifar-10-batches-bin/data_batch_5.bin'):
        # Read one CIFAR-10 record (1 label byte + 32*32*3 image bytes) from the binary batch file.
        bytestream = open(filepath, mode='rb')
        label_bytes_length = 1
        image_bytes_length = (32 ** 2) * 3
        record_bytes_length = label_bytes_length + image_bytes_length
        bytestream.seek(record_bytes_length * index, 0)
        label_bytes = bytestream.read(label_bytes_length)
        image_bytes = bytestream.read(image_bytes_length)
        label = np.frombuffer(label_bytes, dtype=np.uint8)
        image = np.frombuffer(image_bytes, dtype=np.uint8)
        image = np.reshape(image, [3, 32, 32])
        image = np.transpose(image, [1, 2, 0])  # CHW -> HWC
        image = image.astype(np.float32)
        result = {
            'image': image,
            'label': label,
        }
        bytestream.close()
        return result


    # saved_model_dir should point to your exported SavedModel directory.
    predictor_fn = tf.contrib.predictor.from_saved_model(
        export_dir=saved_model_dir, signature_def_key='predictions')

    N = 1000
    labels = []
    images = []
    for i in range(N):
        result = extract_data(i)
        images.append(result['image'])
        labels.append(result['label'][0])

    output = predictor_fn(
        {
            'images': images,
        }
    )
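    The keys accepted and returned by predictor_fn depend on the signature_def that was exported; you can inspect them with saved_model_cli show --dir <saved_model_dir> --all. As a rough sketch, assuming the signature exposes a 'probabilities' output (verify the name for your own model), you could check accuracy over the N samples like this:

    probs = output['probabilities']            # assumed output key; check your signature
    predicted = np.argmax(probs, axis=1)       # most likely class per image
    accuracy = np.mean(predicted == np.array(labels))
    print('accuracy over %d samples: %.3f' % (N, accuracy))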
    
  • 2021-01-20 11:13

    Note: This answer will evolve as more information becomes available. I'm not sure this is the most appropriate way to do it, but it feels better than using just comments. Feel free to drop a comment on the answer if this is inappropriate.

    About your second attempt:

    I don't have much experience with the import_meta_graph method, but if sess.run(logits) runs without complaining, I think the meta graph also contains your input pipeline.

    A quick test I just made confirms that the pipeline is indeed restored too when you load the meta graph. This means you're not actually passing anything in via feed_dict, because the input is taken from the Dataset-based pipeline that was in place when the checkpoint was saved. From my research, I can't find a way to provide a different input function to the restored graph.
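
    For reference, here is roughly the kind of quick test I mean; the checkpoint path and the logits tensor name below are placeholders you'd have to adapt to your own graph, and an initializable iterator would additionally need its initializer run:

    import tensorflow as tf

    checkpoint = 'model_dir/model.ckpt-10000'        # placeholder checkpoint prefix
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    graph = tf.get_default_graph()
    # Placeholder tensor name -- look it up with graph.get_operations() for your model.
    logits = graph.get_tensor_by_name('resnet/logits/BiasAdd:0')

    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        # No feed_dict here: the values come from the Dataset-based input pipeline
        # that was serialized into the meta graph, not from anything we pass in.
        print(sess.run(logits))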

    About the first attempt:

    Your code looks right to me, so my suspicion is that the checkpoint file that gets loaded is somehow wrong. I asked for some clarification in a comment; I'll update this part as soon as that info is available.
