Tensorflow版Faster RCNN源码解析(TFFRCNN) (15) VGGnet_train.py

徘徊边缘 提交于 2019-11-26 23:30:57

本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记

---------------个人学习笔记---------------

----------------本文作者吴疆--------------

------点击此处链接至博客园原文------

 

与VGGnet_test.py相比,VGGnet_train.py需要馈入更多的变量,与train.py中train_model(...)函数定义的feed_dict相照应,此外,还增加了name为rpn-data、roi-data、drop6和drop7的网络处理层,keep_prob为dropout的比例

# train.py中train_model(...)函数定义的feed_dict
     feed_dict={
                self.net.data: blobs['data'],
                self.net.im_info: blobs['im_info'],
                self.net.keep_prob: 0.5,
                self.net.gt_boxes: blobs['gt_boxes'],
                self.net.gt_ishard: blobs['gt_ishard'],
                self.net.dontcare_areas: blobs['dontcare_areas']    
            }

VGGnet_train.py代码及注释如下:

import tensorflow as tf
from network import Network
from ..fast_rcnn.config import cfg

class VGGnet_train(Network):
    # 基类为Network,重构了__init__()
    def __init__(self, trainable=True):
        # 定义的变量比VGGnet_test.py中要多
        # 下一层输入(列表)
        self.inputs = []
        self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data')
        self.im_info = tf.placeholder(tf.float32, shape=[None, 3], name='im_info')
        # 与train.py中train_model(...)函数定义的feed_dict照应
        self.gt_boxes = tf.placeholder(tf.float32, shape=[None, 5], name='gt_boxes')
        self.gt_ishard = tf.placeholder(tf.int32, shape=[None], name='gt_ishard')
        self.dontcare_areas = tf.placeholder(tf.float32, shape=[None, 4], name='dontcare_areas')
        self.keep_prob = tf.placeholder(tf.float32)        # 定义dropout的比例!!!
        # 各层输出(字典)
        self.layers = dict({'data': self.data, 'im_info': self.im_info, 'gt_boxes': self.gt_boxes,\
                            'gt_ishard': self.gt_ishard, 'dontcare_areas': self.dontcare_areas})
        self.trainable = trainable
        self.setup()

    def setup(self):
        # n_classes = 2 #2018.1.30
        n_classes = cfg.NCLASSES
        # anchor_scales = [8, 16, 32]
        anchor_scales = cfg.ANCHOR_SCALES
        _feat_stride = [16, ]

        (self.feed('data')    #feed最后会返回self,下一层可以直接.xxx
             # conv3_1后卷积核参数才被更新,之前层权值不变
             .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False)
             .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=False)
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool1')
             .conv(3, 3, 128, 1, 1, name='conv2_1', trainable=False)
             .conv(3, 3, 128, 1, 1, name='conv2_2', trainable=False)
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool2')   
             .conv(3, 3, 256, 1, 1, name='conv3_1')
             .conv(3, 3, 256, 1, 1, name='conv3_2')
             .conv(3, 3, 256, 1, 1, name='conv3_3')
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool3')
             .conv(3, 3, 512, 1, 1, name='conv4_1')
             .conv(3, 3, 512, 1, 1, name='conv4_2')
             .conv(3, 3, 512, 1, 1, name='conv4_3')
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool4')
             .conv(3, 3, 512, 1, 1, name='conv5_1')
             .conv(3, 3, 512, 1, 1, name='conv5_2')
             .conv(3, 3, 512, 1, 1, name='conv5_3'))

        #========= RPN ============ 
        (self.feed('conv5_3')
             .conv(3, 3, 512, 1, 1,name='rpn_conv/3x3'))
        # (1, H, W, A x 4)
        (self.feed('rpn_conv/3x3')
             .conv(1, 1, len(anchor_scales) * 3 * 4, 1, 1, padding='VALID', relu=False, name='rpn_bbox_pred'))
        # (1, H, W, A x 2)
        (self.feed('rpn_conv/3x3')
             .conv(1, 1, len(anchor_scales) * 3 * 2, 1, 1, padding='VALID', relu=False, name='rpn_cls_score'))

        # generating training labels on the fly 飞速写入
        # output: rpn_labels(HxWxA, 2) rpn_bbox_targets(HxWxA, 4) rpn_bbox_inside_weights rpn_bbox_outside_weights
        # 相比于VGGnet_test.py多的网络层次!!!
        # Produces anchor classification labels and bounding-box regression targets.
        (self.feed('rpn_cls_score', 'gt_boxes', 'gt_ishard', 'dontcare_areas', 'im_info')
             .anchor_target_layer(_feat_stride, anchor_scales, name='rpn-data' ))

        # 先reshape后softmax再reshape回来
        # shape is (1, H, W, Ax2) -> (1, H, WxA, 2)
        (self.feed('rpn_cls_score')
             .spatial_reshape_layer(2, name='rpn_cls_score_reshape')
             .spatial_softmax(name='rpn_cls_prob'))
        # shape is (1, H, WxA, 2) -> (1, H, W, Ax2)
        (self.feed('rpn_cls_prob')
             .spatial_reshape_layer(len(anchor_scales)*3*2, name='rpn_cls_prob_reshape'))

        # ========= RoI Proposal ============
        # add the delta(output) to anchors then
        # choose some reasonabel boxes, considering scores, ratios, size and iou
        # rpn_rois <- (1 x H x W x A, 5) e.g. [0, x1, y1, x2, y2]
        # 回归后并经过一些后处理得到的proposal,见proposal_layer_tf.py
        # 默认_feat_stride = [16, ]、anchor_scales = cfg.ANCHOR_SCALES = [8, 16, 32]、TEST模式
        (self.feed('rpn_cls_prob_reshape', 'rpn_bbox_pred', 'im_info')
             .proposal_layer(_feat_stride, anchor_scales, 'TRAIN', name='rpn_rois'))

        # 相比于VGGnet_test.py多的网络层次!!!
        # matching boxes and groundtruth and randomly sample some rois and labels for RCNN
        (self.feed('rpn_rois', 'gt_boxes', 'gt_ishard', 'dontcare_areas')
             .proposal_target_layer(n_classes, name='roi-data'))

        # ========= RCNN ============
        (self.feed('conv5_3', 'rois')
             .roi_pool(7, 7, 1.0/16, name='pool_5')
             .fc(4096, name='fc6')
             .dropout(0.5, name='drop6')  # 相比于VGGnet_test.py多的网络层次!!!
             .fc(4096, name='fc7')
             .dropout(0.5, name='drop7')  # 相比于VGGnet_test.py多的网络层次!!!
             .fc(n_classes, relu=False, name='cls_score')
             .softmax(name='cls_prob'))

        (self.feed('drop7')
            .fc(n_classes*4, relu=False, name='bbox_pred'))
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!