onnxruntime加载pytorch图像分类模型

自作多情 提交于 2020-10-24 07:31:53
  • 从 PyTorch 模型导出 ONNX 模型的方法,可以参考笔者的前一篇博文:https://blog.csdn.net/ouening/article/details/109245243
  • 使用 Netron 查看 ONNX 模型结构,如下图:
    (原文此处为模型结构截图,图片未能保留)
    注意输入、输出节点的名称(name)以及数据类型和维度(type)

  • 程序
import numpy as np    # we're going to use numpy to process input and output data
import onnxruntime    # to inference ONNX models, we use the ONNX Runtime
import onnx
from onnx import numpy_helper
import urllib.request
import json
import time
from imageio import imread
import warnings
warnings.filterwarnings('ignore')
# display images in notebook
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

# Path of the exported ONNX model file (an absolute Windows path).
onnx_model = r"D:\Files\python\opencv\调用pytorch-onnx模型\exported.onnx"

# Create the ONNX Runtime inference session for the model file; None is passed
# as the second positional argument.
session = onnxruntime.InferenceSession(onnx_model, None)

# Look up the names of the model's first input and first output from the
# session metadata; the input name is the key of the feed dict at run time.
input_name = session.get_inputs()[0].name  
output_name = session.get_outputs()[0].name  
# print(len(session.get_outputs()))
print('Input Name:', input_name)
print('Output Name:', output_name)

# Image file that is opened and classified later in the script.
img_file = r"C:\Users\LX\Pictures\elephant.jpg"
def load_labels(class_file=None, encoding=None):
    """Read class names from a text file that holds one label per line.

    Args:
        class_file: Path of the label file. When omitted, the
            ILSVRC2012 class list inside the OpenCV samples directory
            (the path this script has always used) is read.
        encoding: Text encoding handed to ``open``. ``None`` keeps the
            platform default, matching the previous behaviour.

    Returns:
        A list of label strings in file order. An empty file yields an
        empty list.
    """
    if class_file is None:
        class_file = r"E:\ScientificComputing\opencv\sources\samples\data\dnn\classification_classes_ILSVRC2012.txt"
    with open(class_file, 'rt', encoding=encoding) as f:
        # Trailing newlines would otherwise produce empty labels at the end.
        content = f.read().rstrip('\n')
    if not content:
        # ''.split('\n') is [''], i.e. one bogus empty label.
        return []
    return content.split('\n')

def preprocess(input_data):
    """Normalize a channel-first 3-channel image and add a batch axis.

    Every value is divided by 255, then the per-channel mean is subtracted
    and the result is divided by the per-channel standard deviation.

    Args:
        input_data: Array-like of shape (3, H, W). Any spatial size is
            accepted; the output keeps H and W.

    Returns:
        A float32 array of shape (1, 3, H, W).

    Raises:
        ValueError: If ``input_data`` is not 3-D with 3 channels on the
            first axis.
    """
    # convert the input data into the float32 input
    img_data = np.asarray(input_data, dtype=np.float32)
    if img_data.ndim != 3 or img_data.shape[0] != 3:
        raise ValueError(
            "expected an image of shape (3, H, W), got shape "
            + str(img_data.shape)
        )

    # normalize
    # NOTE(review): these look like the usual ImageNet RGB statistics;
    # confirm they match what the model was trained with.
    # Shaped (3, 1, 1) so they broadcast over height and width; float32
    # keeps the whole computation in the model's input dtype.
    mean_vec = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    stddev_vec = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
    norm_img_data = (img_data / 255 - mean_vec) / stddev_vec

    # add batch channel (a leading axis of length 1, whatever H and W are)
    return norm_img_data[np.newaxis, ...]

def softmax(x):
    """Return the softmax of all elements of ``x`` as a 1-D array.

    ``x`` is flattened first, so the result always has a single axis and
    sums to 1 over every element of the input.
    """
    logits = x.reshape(-1)
    # Subtracting the maximum does not change the result mathematically but
    # keeps the largest exponent at exp(0).
    exponentials = np.exp(logits - np.max(logits))
    total = exponentials.sum(axis=0)
    return exponentials / total

def postprocess(result):
    """Turn raw model output into a flat Python list of softmax scores.

    ``result`` is converted to a NumPy array, passed through ``softmax``
    and returned as a plain list.
    """
    scores = np.array(result)
    probabilities = softmax(scores)
    return probabilities.tolist()

# Open the image and force 3-channel RGB before resizing: grayscale, palette
# or RGBA files would otherwise not give the (H, W, 3) array that the
# transpose below requires.
image = Image.open(img_file).convert('RGB').resize((224, 224))

print("Image size: ", image.size)
plt.axis('off')
display_image = plt.imshow(image)
# np.array(image) is height x width x channel; move the channel axis to the
# front because preprocess() indexes channels on axis 0.
image_data = np.array(image).transpose(2, 0, 1)
input_data = preprocess(image_data)

#%%
# Time one forward pass. perf_counter() is a monotonic high-resolution clock,
# so the measurement is not disturbed by system clock adjustments.
start = time.perf_counter()
# Ask only for the output whose name was looked up earlier; an empty name list
# returns every model output, which would be mixed together in postprocess().
raw_result = session.run([output_name], {input_name: input_data})
end = time.perf_counter()
res = postprocess(raw_result)

# Elapsed time in milliseconds, rounded to two decimals.
inference_time = np.round((end - start) * 1000, 2)
idx = np.argmax(res)
labels = load_labels()
print('========================================')
print('Final top prediction is: ' + labels[idx])
print('========================================')

print('========================================')
print('Inference time: ' + str(inference_time) + " ms")
print('========================================')

# argsort is ascending, so flip it to list the highest scores first.
sort_idx = np.flip(np.squeeze(np.argsort(res)))
print('============ Top 5 labels are: ============================')
# print(labels[sort_idx[:5]])
for k in sort_idx[:5]:
    print(labels[k])
print('===========================================================')

plt.axis('off')
display_image = plt.imshow(image)

参考链接:https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!