- 从 PyTorch 模型导出 ONNX 模型,可以参考笔者的前一篇博文:https://blog.csdn.net/ouening/article/details/109245243
- 使用netron查看onnx模型结构,如下图:
注意输入输出的名称(name)以及数据类型和维度(type)。
- 程序:
import numpy as np # we're going to use numpy to process input and output data
import onnxruntime # to inference ONNX models, we use the ONNX Runtime
import onnx
from onnx import numpy_helper
import urllib.request
import json
import time
from imageio import imread
import warnings
warnings.filterwarnings('ignore')
# display images in notebook
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
# Hard-coded locations of the exported ONNX model and of the test picture.
onnx_model = r"D:\Files\python\opencv\调用pytorch-onnx模型\exported.onnx"
img_file = r"C:\Users\LX\Pictures\elephant.jpg"

# Load the model into an ONNX Runtime session (no custom session options).
session = onnxruntime.InferenceSession(onnx_model, None)

# Look up the names of the model's first input and first output; the input
# name is the key of the feed dictionary passed to session.run() later on.
first_input = session.get_inputs()[0]
input_name = first_input.name
first_output = session.get_outputs()[0]
output_name = first_output.name

print('Input Name:', input_name)
print('Output Name:', output_name)
def load_labels(
    class_file=r"E:\ScientificComputing\opencv\sources\samples\data\dnn\classification_classes_ILSVRC2012.txt",
):
    """Read class labels from a text file containing one label per line.

    Args:
        class_file: Path of the label file. Defaults to the location that was
            previously hard-coded in the function body, so existing
            ``load_labels()`` calls keep working.

    Returns:
        A list of label strings in file order. Trailing newlines at the end
        of the file are stripped so they do not produce empty labels.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding keeps the result independent of the platform locale.
    with open(class_file, 'rt', encoding='utf-8') as f:
        return f.read().rstrip('\n').split('\n')
def preprocess(input_data):
    """Normalize a channel-first image and add a leading batch axis.

    Pixel values are scaled by 1/255, then each of the three channels has a
    fixed mean subtracted and is divided by a fixed standard deviation.

    Args:
        input_data: Array-like of shape ``(3, H, W)`` holding pixel values on
            a 0-255 scale. The constants look like the usual ImageNet RGB
            statistics, so RGB channel order is presumably expected --
            confirm against the model that was exported.

    Returns:
        A float32 ``numpy.ndarray`` of shape ``(1, 3, H, W)``.

    Raises:
        ValueError: If ``input_data`` is not a 3-D array with 3 channels.
    """
    # convert the input data into the float32 input
    img_data = np.asarray(input_data).astype('float32')
    if img_data.ndim != 3 or img_data.shape[0] != 3:
        raise ValueError(
            'expected an image of shape (3, H, W), got %s' % (img_data.shape,)
        )

    # Per-channel statistics, shaped (3, 1, 1) so they broadcast over H and W.
    # float32 keeps the whole computation in float32.
    mean_vec = np.array([0.485, 0.456, 0.406], dtype='float32').reshape(3, 1, 1)
    stddev_vec = np.array([0.229, 0.224, 0.225], dtype='float32').reshape(3, 1, 1)

    # normalize all channels at once instead of looping over them
    norm_img_data = (img_data / 255 - mean_vec) / stddev_vec

    # add batch axis without assuming a fixed 224x224 spatial size
    return norm_img_data[np.newaxis, ...]
def softmax(x):
    """Return the softmax of ``x`` computed over all of its elements.

    The input array is flattened to one dimension first, so the result is
    always a 1-D array whose entries sum to 1.
    """
    flat = x.reshape(-1)
    # Shift by the maximum before exponentiating so large scores cannot
    # overflow; the shift cancels out in the ratio below.
    shifted = flat - np.max(flat)
    weights = np.exp(shifted)
    total = weights.sum(axis=0)
    return weights / total
def postprocess(result):
    """Turn raw model output into a flat Python list of softmax probabilities.

    ``result`` is converted to a NumPy array, passed through ``softmax``
    (which flattens it), and returned as a plain list.
    """
    scores = np.array(result)
    probabilities = softmax(scores)
    return probabilities.tolist()
# Load the test picture. convert('RGB') guarantees exactly three channels, so
# grayscale, palette and RGBA files do not break the HWC -> CHW transpose or
# the 3-channel normalization in preprocess().
image = Image.open(img_file).convert('RGB').resize((224, 224))
# image = Image.open('images/plane.jpg')
print("Image size: ", image.size)
plt.axis('off')
display_image = plt.imshow(image)

# PIL yields (height, width, channel); preprocess() expects channel-first.
image_data = np.array(image).transpose(2, 0, 1)
input_data = preprocess(image_data)
#%%
# perf_counter() is monotonic, so it is the appropriate clock for timing a
# short interval.
start = time.perf_counter()
# Request only the first output by name, so postprocess() never mixes the
# scores with any additional outputs the model may expose.
raw_result = session.run([output_name], {input_name: input_data})
end = time.perf_counter()
res = postprocess(raw_result)

inference_time = np.round((end - start) * 1000, 2)
idx = np.argmax(res)
labels = load_labels()

print('========================================')
print('Final top prediction is: ' + labels[idx])
print('========================================')

print('========================================')
print('Inference time: ' + str(inference_time) + " ms")
print('========================================')

# Class indices ordered from most to least probable.
sort_idx = np.flip(np.squeeze(np.argsort(res)))
print('============ Top 5 labels are: ============================')
# print(labels[sort_idx[:5]])
for k in sort_idx[:5]:
    print(labels[k])
print('===========================================================')

plt.axis('off')
display_image = plt.imshow(image)
参考链接:https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
来源:oschina
链接:https://my.oschina.net/u/4258318/blog/4686791