Faster RCNN(TensorFlow)在测试时生成 result.txt 文件

こ雲淡風輕ζ 提交于 2019-12-05 16:59:22

需要准备三个路径:

1、测试时所使用的图片所在的路径

2、result.txt所在的路径

3、生成图像的存放路径

  1 #!/usr/bin/env python
  2  
  3 
  4 from __future__ import absolute_import
  5 from __future__ import division
  6 from __future__ import print_function
  7  
  8 import _init_paths
  9 from model.config import cfg
 10 from model.test import im_detect
 11 from model.nms_wrapper import nms
 12  
 13 from utils.timer import Timer
 14 import tensorflow as tf
 15 import matplotlib.pyplot as plt
 16 from PIL import Image
 17 import numpy as np
 18 import os, cv2
 19 import argparse
 20  
 21  
 22 from nets.vgg16 import vgg16
 23 from nets.resnet_v1 import resnetv1
 24  
 25 CLASSES = ('__background__', 'dan','duo')
 26  
 27 NETS = {'vgg16': ('vgg16_faster_rcnn_iter_2000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
 28 DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
 29 
 30 
 31 def vis_detections(image_name, class_name, dets, thresh=0.7):
 32     """Draw detected bounding boxes."""
 33     inds = np.where(dets[:, -1] >= thresh)[0]
 34     if len(inds) == 0:
 35         return
 36     for i in inds:
 37         bbox = dets[i, :4]
 38         score = dets[i, -1]
 39         if(class_name == '__background__'):
 40             fw = open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/result.txt','a')   
 41             fw.write(str(image_name)+' '+class_name+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'\n')
 42             fw.close()
 43         elif(class_name == 'dan'):
 44             fw = open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/result.txt','a')
 45             fw.write(str(image_name)+' '+class_name+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'\n')
 46             fw.close()
 47         elif(class_name == 'duo'):
 48             fw = open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/result.txt','a')
 49             fw.write(str(image_name)+' '+class_name+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'\n')
 50             fw.close()
 51  
 52  
 53 def demo(image_name, sess, net):
 54     im_file = os.path.join('/','home','bioinfo','Documents','pathonwork','lzh','tfasterrcnn', 'data', 'VOCdevkit2007', 'VOC2007', 'JPEGImages',image_name)
 55     #im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
 56     im = cv2.imread(im_file)
 57     # Detect all object classes and regress object bounds
 58     timer = Timer()
 59     timer.tic()
 60     scores, boxes = im_detect(sess, net, im)
 61     timer.toc()
 62     print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))
 63     # Visualize detections for each class
 64     CONF_THRESH = 0.7
 65     thresh=0.7
 66     NMS_THRESH = 0.3
 67     im = im[:, :, (2, 1, 0)]
 68     fig, ax = plt.subplots(figsize=(12, 12))
 69     ax.imshow(im, aspect='equal', alpha=0.5)
 70     for cls_ind, cls in enumerate(CLASSES[1:]):
 71         cls_ind += 1 # because we skipped background
 72         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
 73         cls_scores = scores[:, cls_ind]
 74         dets = np.hstack((cls_boxes,
 75                           cls_scores[:, np.newaxis])).astype(np.float32)
 76         keep = nms(dets, NMS_THRESH)
 77         dets = dets[keep, :]
 78         vis_detections(image_name, cls, dets, thresh=CONF_THRESH)
 79         inds = np.where(dets[:, -1] >= thresh)[0]
 80         if len(inds) == 0:
 81             continue
 82         for i in inds:
 83             bbox = dets[i, :4]
 84             score = dets[i, -1]
 85             ax.add_patch(
 86                 plt.Rectangle((bbox[0], bbox[1]),
 87                               bbox[2] - bbox[0],
 88                               bbox[3] - bbox[1], fill=False,
 89                               edgecolor='red', linewidth=1.5)
 90                 )
 91             ax.text(bbox[0], bbox[1] - 2,
 92                 '{:s} {:.3f}'.format(cls, score),
 93                 bbox=dict(facecolor='blue', alpha=0.5),
 94                 fontsize=14, color='white')
 95     
 96     plt.axis('off')
 97     plt.tight_layout()
 98     plt.draw()
 99     image_name=image_name.replace('jpg','jpg')
100     plt.savefig('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/results/'+image_name)
101     print("save image to /home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/results/{}".format(image_name))
102  
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Two options are accepted: ``--net`` (stored as ``demo_net``, restricted
    to the keys of ``NETS``) and ``--dataset`` (stored as ``dataset``,
    restricted to the keys of ``DATASETS``).
    """
    cli = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')

    # (flag, destination attribute, help text, table of choices, default)
    option_specs = (
        ('--net', 'demo_net', 'Network to use [vgg16 res101]',
         NETS, 'vgg16'),
        ('--dataset', 'dataset', 'Trained dataset [pascal_voc pascal_voc_0712]',
         DATASETS, 'pascal_voc'),
    )
    for flag, target, description, table, fallback in option_specs:
        cli.add_argument(flag, dest=target, help=description,
                         choices=table.keys(), default=fallback)

    return cli.parse_args()
113  
if __name__ == '__main__':
    # Script entry point: restore the trained network, then run demo() on
    # every image listed in the VOC2007 test split.
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    project_root = '/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn'

    # Derive the checkpoint from --net / --dataset so the restored weights
    # always match the architecture built below. With the default arguments
    # this resolves to the same vgg16 checkpoint that was hard-coded before.
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join(project_root, 'output', demonet,
                           DATASETS[dataset][0], 'default', NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    try:
        # load network
        if demonet == 'vgg16':
            net = vgg16()
        elif demonet == 'res101':
            net = resnetv1(num_layers=101)
        else:
            raise NotImplementedError
        # The class count follows CLASSES (background included) instead of
        # a separate magic number that could drift out of sync.
        net.create_architecture("TEST", len(CLASSES),
                                tag='default', anchor_scales=[8, 16, 32])
        saver = tf.train.Saver()
        saver.restore(sess, tfmodel)
        print('Loaded network {:s}'.format(tfmodel))

        # test.txt lists one image stem per line; skip blank lines so they
        # do not turn into a bogus '.jpg' entry.
        test_list = os.path.join(project_root, 'data', 'VOCdevkit2007', 'VOC2007',
                                 'ImageSets', 'Main', 'test.txt')
        with open(test_list, 'r') as fi:
            stems = [line.strip() for line in fi]
        im_names = [stem + '.jpg' for stem in stems if stem]
        print(im_names)

        for im_name in im_names:
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
            print('Demo for data/demo/{}'.format(im_name))
            demo(im_name, sess, net)
    finally:
        sess.close()

生成结果

 

参考文章:

1.https://blog.csdn.net/zxs0222/article/details/89605300

2.https://blog.csdn.net/gusui7202/article/details/83240212

 

作者:舟华520

出处:https://www.cnblogs.com/xfzh193/

本文以学习,分享,研究交流为主,欢迎转载,请标明作者出处!

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!