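需要准备以下三个路径(代码中均为硬编码的绝对路径,需按本机实际目录修改):
1、测试时所用图片所在的路径(即 VOC2007/JPEGImages);
2、检测结果文件 result.txt 的路径;
3、生成的检测结果图像的存放路径。
此外,测试图片名列表 test.txt 以及训练好的模型 ckpt 文件的路径,也需要按实际情况修改。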
#!/usr/bin/env python
"""Run a trained Faster R-CNN model over the VOC2007 test split.

For every image id listed in TEST_LIST this script:
  * runs detection with the restored TensorFlow checkpoint,
  * appends one line per kept detection to RESULT_FILE, formatted as
    ``<image_name> <class_name> <x1> <y1> <x2> <y2>``,
  * saves a copy of the image with the boxes drawn into RESULTS_DIR.

Edit the path constants below to match the local directory layout.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__', 'dan', 'duo')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_2000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',),
            'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

# Project root; every path below is derived from it.
ROOT_DIR = os.path.join('/', 'home', 'bioinfo', 'Documents', 'pathonwork',
                        'lzh', 'tfasterrcnn')
VOC_DIR = os.path.join(ROOT_DIR, 'data', 'VOCdevkit2007', 'VOC2007')
# (1) Directory holding the test images.
IMAGE_DIR = os.path.join(VOC_DIR, 'JPEGImages')
# Test image ids, one per line, without the '.jpg' extension.
TEST_LIST = os.path.join(VOC_DIR, 'ImageSets', 'Main', 'test.txt')
# (2) Text output of the detections. It is opened in append mode, so
# delete it between runs to avoid duplicated lines.
RESULT_FILE = os.path.join(VOC_DIR, 'ImageSets', 'Main', 'result.txt')
# (3) Directory that receives the rendered detection images.
RESULTS_DIR = os.path.join(ROOT_DIR, 'data', 'results')

CONF_THRESH = 0.7  # minimum score for a detection to be kept
NMS_THRESH = 0.3   # IoU threshold passed to nms()


def vis_detections(image_name, class_name, dets, thresh=0.7):
    """Append the detections of one class that score >= ``thresh`` to RESULT_FILE.

    Despite the name (kept from the upstream demo), nothing is drawn here;
    drawing happens in :func:`demo`.

    Args:
        image_name: image file name, written as the first column.
        class_name: class label; only labels present in CLASSES are written.
        dets: rows of ``x1, y1, x2, y2, score`` as built in :func:`demo`.
        thresh: minimum score for a row to be written.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0 or class_name not in CLASSES:
        return
    lines = []
    for i in inds:
        # Box coordinates are truncated to integers, as before.
        x1, y1, x2, y2 = (int(v) for v in dets[i, :4])
        lines.append('{} {} {} {} {} {}\n'.format(
            image_name, class_name, x1, y1, x2, y2))
    # Open the file once per call (not once per box) and always close it.
    with open(RESULT_FILE, 'a') as fw:
        fw.writelines(lines)


def demo(image_name, sess, net):
    """Detect objects in one test image, record them and save a rendered copy.

    Args:
        image_name: file name inside IMAGE_DIR (e.g. ``'000001.jpg'``).
        sess: TensorFlow session holding the restored weights.
        net: network object built with ``create_architecture("TEST", ...)``.

    Raises:
        IOError: if the image cannot be read.
    """
    im_file = os.path.join(IMAGE_DIR, image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise IOError('Cannot read image {}'.format(im_file))

    # Detect all object classes and regress object bounds.
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    im = im[:, :, (2, 1, 0)]  # OpenCV loads BGR; matplotlib expects RGB
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal', alpha=0.5)
    # start=1 skips the background column of scores/boxes.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        # boxes holds 4 regression columns per class.
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(image_name, cls, dets, thresh=CONF_THRESH)
        for i in np.where(dets[:, -1] >= CONF_THRESH)[0]:
            bbox = dets[i, :4]
            score = dets[i, -1]
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=1.5)
            )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(cls, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    out_path = os.path.join(RESULTS_DIR, image_name)
    plt.savefig(out_path)
    # Release the figure; otherwise every processed image stays in memory.
    plt.close(fig)
    print('save image to {}'.format(out_path))


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc')
    args = parser.parse_args()

    return args


def _read_image_names(list_file):
    """Return ``'<id>.jpg'`` for every non-blank line of an ImageSets list file."""
    with open(list_file, 'r') as fi:
        # strip() also removes '\r' from Windows-style line endings.
        return [line.strip() + '.jpg' for line in fi if line.strip()]


def main():
    """Restore the selected checkpoint and run :func:`demo` on every test image."""
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()
    demonet = args.demo_net
    dataset = args.dataset
    # With the defaults this is
    # ROOT_DIR/output/vgg16/voc_2007_trainval/default/vgg16_faster_rcnn_iter_2000.ckpt
    tfmodel = os.path.join(ROOT_DIR, 'output', demonet, DATASETS[dataset][0],
                           'default', NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    # Number of classes (including background) must match training.
    net.create_architecture("TEST", len(CLASSES),
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)
    print('Loaded network {:s}'.format(tfmodel))

    if not os.path.isdir(RESULTS_DIR):
        os.makedirs(RESULTS_DIR)

    im_names = _read_image_names(TEST_LIST)
    print(im_names)
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for {}'.format(os.path.join(IMAGE_DIR, im_name)))
        demo(im_name, sess, net)


if __name__ == '__main__':
    main()
生成结果
参考文章:
1.https://blog.csdn.net/zxs0222/article/details/89605300
2.https://blog.csdn.net/gusui7202/article/details/83240212
作者:舟华520
出处:https://www.cnblogs.com/xfzh193/
本文以学习、分享、研究交流为主,欢迎转载,请标明作者及出处!