Lane Line Detection with Semantic Segmentation (TensorFlow version)


I. Background:

      Driven by project requirements, I studied several lane detection papers and their reference implementations, and designed a TensorFlow-based lane line detection pipeline.

II. Basic Structure:

      The model is composed of the following parts:

1. Data source: all of the raw data, plus the same data after splitting into groups;

2. Data preprocessing: data preparation, loading, extraction, and splitting into training and test groups;

3. Configuration file: all parameters and hyperparameters, such as training epochs, training step size, batch size, learning rate, convolution kernel size, fully-connected layer size, the directory for trained models (checkpoint), the directory for summaries (summary), and so on;

4. Base network: the basic network building blocks and the backbone encoder/decoder networks;

5. Training entry script: the main entry point, which builds the graph, creates the session (sess), feeds data into the model for training, configures the GPU, prints training progress, and so on. A minimal sketch of how these parts are wired together follows.
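The sketch below shows one way these pieces plug together. It is illustrative only: the module paths assume the directory layout listed in Section III, and the preprocessing step is reduced to a comment; the project's real training script differs in detail.

# A hedged sketch of the training entry point: wiring together the
# config (cfg), the data provider (DataSet) and the model (LaneNet)
# shown later in this post. Paths and preprocessing are assumptions.
import tensorflow as tf

from config.global_config import cfg
from data_provider.data_processor import DataSet
from merge_model.merge_model import LaneNet


def train():
    dataset = DataSet('train.txt')

    h, w = cfg.TRAIN.IMG_HEIGHT, cfg.TRAIN.IMG_WIDTH
    input_tensor = tf.placeholder(tf.float32, [cfg.TRAIN.BATCH_SIZE, h, w, 3], name='input')
    binary_label = tf.placeholder(tf.int64, [cfg.TRAIN.BATCH_SIZE, h, w, 1], name='binary_label')
    instance_label = tf.placeholder(tf.float32, [cfg.TRAIN.BATCH_SIZE, h, w, 1], name='instance_label')

    net = LaneNet(phase=tf.constant('train', dtype=tf.string))
    ret = net.compute_loss(input_tensor, binary_label, instance_label, name='loss')

    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE, global_step,
                                    cfg.TRAIN.LR_DECAY_STEPS, cfg.TRAIN.LR_DECAY_RATE)
    # batch-norm statistics must be updated alongside the weights
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM).minimize(
            ret['total_loss'], global_step=global_step)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.TRAIN.GPU_MEMORY_FRACTION,
                                allow_growth=cfg.TRAIN.TF_ALLOW_GROWTH)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(cfg.TRAIN.EPOCHS):
            imgs, bins, insts = dataset.next_batch(cfg.TRAIN.BATCH_SIZE)
            # resize each sample to (w, h) and expand labels to [h, w, 1] here
            loss, _ = sess.run([ret['total_loss'], train_op],
                               feed_dict={input_tensor: imgs,
                                          binary_label: bins,
                                          instance_label: insts})
            if step % cfg.TRAIN.DISPLAY_STEP == 0:
                print('step {:d} total_loss: {:.5f}'.format(step, loss))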

 

III. Code Structure

             The original directory layout is as follows:

    ./data

               -- ./InstanceSegmentationClass

               -- ./JPEGImages

               -- ./SegmentationClass

               -- datasets_gen_culane.py  (generates list.txt, train.txt, and val.txt from the three image directories above)

# coding=utf-8
#create date:12/5/2018
#modified date:2/12/2019
#author:jim.chen
import os
import glob
import random
import math
import cv2
import numpy as np


def gen_list_txt(rela_dir,img_dir,img_seg_dir,img_inst_dir):
    list_txt = "list.txt"
    png_list_path = glob.glob(img_seg_dir+'/*.png')
    png_list=[]
    print("gen_list_txt png_list_path:",png_list_path)
    # one line per sample: <image> <binary gt> <instance gt>
    with open(list_txt,"w") as w_f:
        for png in png_list_path:
            base = os.path.splitext(os.path.basename(png))[0]
            w_f.write(rela_dir+img_dir+'/'+base+'.jpg'+' '+rela_dir+png+' '+rela_dir+img_inst_dir+'/'+base+'.png'+'\n')

    with open(list_txt,"r") as r_f:
        for each_line in r_f:
            png_list.append(each_line)

    png_list.sort()
    print("gen_list_txt len(png_list):",len(png_list))
    # random 9:1 train/val split
    train = random.sample(png_list,int(math.floor(len(png_list)*9/10)))
    train.sort()
    val = list(set(png_list).difference(set(train)))
    for split_name,split_lines in {'train':train,'val':val}.items():
        with open(split_name+'.txt','w') as w1_f:
            for line in split_lines:
                w1_f.write(line)
    
def sync_gt_2_img(img_dir,img_seg_dir,img_inst_dir):
    cwd = os.getcwd()
    print("sync_gt_2_img img_dir:",img_dir," img_seg_dir:",img_seg_dir," img_inst_dir:",img_inst_dir)
    img_full_dir = cwd + '/' +img_dir
    img_seg_full_dir = cwd + '/' +img_seg_dir
    img_inst_full_dir = cwd + '/' +img_inst_dir
    img_list = os.listdir(img_full_dir)
    for img in img_list:
        img_basename = os.path.splitext(img)[0]
        print("sync_gt_2_img img_basename:",img_basename)
        img_full_path =  img_full_dir + '/'+ img
        img_seg_full_path =  img_seg_full_dir + '/'+img_basename +'.png'
        img_inst_full_path =  img_inst_full_dir + '/'+img_basename +'.png'
        # if the instance ground truth is missing, drop both the image and its binary gt
        if not os.path.exists(img_inst_full_path):
            print("sync_gt_2_img missing instance gt for:",img_basename)
            if os.path.exists(img_full_path):
                os.remove(img_full_path)
            if os.path.exists(img_seg_full_path):
                os.remove(img_seg_full_path)
    
def sync_seg_2_inst(img_seg_dir,img_inst_dir):
    cwd = os.getcwd()
    print("sync_seg_2_inst img_seg_dir:",img_seg_dir," img_inst_dir:",img_inst_dir)
    img_seg_full_dir = cwd + '/' +img_seg_dir
    img_inst_full_dir = cwd + '/' +img_inst_dir
    img_list = os.listdir(img_seg_full_dir)
    for img in img_list:
        img_basename = os.path.splitext(img)[0]
        print("sync_seg_2_inst img_basename:",img_basename)
        img_seg_full_path =  img_seg_full_dir + '/'+img_basename +'.jpg'
        img_inst_full_path =  img_inst_full_dir + '/'+img_basename +'.png'
        # remove entries that have no matching instance ground truth
        if not os.path.exists(img_inst_full_path):
            if os.path.exists(img_seg_full_path):
                print("sync_seg_2_inst os.remove(img_seg_full_path):",img_seg_full_path)
                os.remove(img_seg_full_path)
    
def gen_seg_color(img_inst_dir,img_seg_dir):
    cwd = os.getcwd()
    inPath = os.path.join(cwd,img_inst_dir)
    print(inPath)
    outPath = os.path.join(cwd,img_seg_dir)
    if not os.path.exists(outPath):
        os.makedirs(outPath)

    # binarize the instance labels: any non-zero pixel becomes white (255)
    for l,file_name in enumerate(os.listdir(inPath)):
        img_instance = cv2.imread(os.path.join(inPath,file_name))
        h,w,c = img_instance.shape
        print("l:",l," img_instance.shape:",img_instance.shape)
        img_instance_gray = np.zeros((h, w), dtype=np.uint8)
        img_instance_gray[img_instance[:,:,0] != 0] = 255
        cv2.imwrite(os.path.join(outPath,file_name), img_instance_gray)
    print("generate segment finished!")
    
def gen_inst_color(img_inst_dir):
    cwd = os.getcwd()
    inPath = os.path.join(cwd,img_inst_dir)
    print(inPath)
    outPath = os.path.join(cwd,"img_inst_new")
    if not os.path.exists(outPath):
        os.makedirs(outPath)

    # map each lane id (2..6) to a distinct gray level so instances stay separable
    gray_map = {2: 20, 3: 70, 4: 120, 5: 170, 6: 220}
    for l,file_name in enumerate(os.listdir(inPath)):
        img_instance = cv2.imread(os.path.join(inPath,file_name))
        h,w,c = img_instance.shape
        print("l:",l," img_instance.shape:",img_instance.shape)
        img_instance_gray = np.zeros((h, w), dtype=np.uint8)
        for label, gray in gray_map.items():
            img_instance_gray[img_instance[:,:,0] == label] = gray
        cv2.imwrite(os.path.join(outPath,file_name), img_instance_gray)
    print("generate instance finished!")
    

def detect_invalid_img(img_path):
    # returns True when the label image contains no lane pixels at all
    img_instance = cv2.imread(img_path)
    print("detect_invalid_img img_instance.shape:",img_instance.shape)
    return not np.any(img_instance[:,:,0])
                     
    
def filter_invalid_img(img_test_dir,img_seg_dir,img_inst_dir):
    cwd = os.getcwd()
    inPath = os.path.join(cwd,img_test_dir)
    inPathDir = os.listdir(inPath)
    print("filter_invalid_img inPathDir:",inPathDir)
    for l,file_name in enumerate(inPathDir):
        img_path = os.path.join(inPath,file_name)
        isdel = detect_invalid_img(img_path)
        if isdel:
            print("filter_invalid_img isdel:",isdel)
            os.remove(os.path.join(inPath,file_name))

def main():
    print("main begin")
    rela_dir = "data/datasets_culane_all/"
    img_dir = "image"
    img_seg_dir = "gt_image_binary"
    img_inst_dir = "gt_image_instance"
    gen_list_txt(rela_dir,img_dir,img_seg_dir,img_inst_dir)
    #sync_gt_2_img(img_dir,img_seg_dir,img_inst_dir)
    #gen_seg_color(img_inst_dir,img_seg_dir)
    #gen_inst_color(img_inst_dir)
    #filter_invalid_img(img_dir,img_seg_dir,img_inst_dir)
    #sync_seg_2_inst(img_dir,img_inst_dir)
    print("main end")
    
if __name__ == '__main__':
    main()
    

               --list.txt

data/datasets_culane/JPEGImages/0000.jpg data/datasets_culane/SegmentationClass\0000.png data/datasets_culane/InstanceSegmentationClass/0000.png
data/datasets_culane/JPEGImages/0001.jpg data/datasets_culane/SegmentationClass\0001.png data/datasets_culane/InstanceSegmentationClass/0001.png
data/datasets_culane/JPEGImages/0002.jpg data/datasets_culane/SegmentationClass\0002.png data/datasets_culane/InstanceSegmentationClass/0002.png
data/datasets_culane/JPEGImages/0003.jpg data/datasets_culane/SegmentationClass\0003.png data/datasets_culane/InstanceSegmentationClass/0003.png
data/datasets_culane/JPEGImages/0004.jpg data/datasets_culane/SegmentationClass\0004.png data/datasets_culane/InstanceSegmentationClass/0004.png
data/datasets_culane/JPEGImages/0005.jpg data/datasets_culane/SegmentationClass\0005.png data/datasets_culane/InstanceSegmentationClass/0005.png

     --train.txt

data/datasets_culane/JPEGImages/0000.jpg data/datasets_culane/SegmentationClass\0000.png data/datasets_culane/InstanceSegmentationClass/0000.png
data/datasets_culane/JPEGImages/0001.jpg data/datasets_culane/SegmentationClass\0001.png data/datasets_culane/InstanceSegmentationClass/0001.png
data/datasets_culane/JPEGImages/0002.jpg data/datasets_culane/SegmentationClass\0002.png data/datasets_culane/InstanceSegmentationClass/0002.png
data/datasets_culane/JPEGImages/0004.jpg data/datasets_culane/SegmentationClass\0004.png data/datasets_culane/InstanceSegmentationClass/0004.png

     --val.txt

data/datasets_culane/JPEGImages/0005.jpg data/datasets_culane/SegmentationClass\0005.png data/datasets_culane/InstanceSegmentationClass/0005.png
data/datasets_culane/JPEGImages/0003.jpg data/datasets_culane/SegmentationClass\0003.png data/datasets_culane/InstanceSegmentationClass/0003.png

            ./data_provider

               --data_processor.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os.path as ops

import cv2
import numpy as np

# some opencv-python builds expose the real bindings as cv2.cv2;
# fall back silently when that nested module is absent
try:
    from cv2 import cv2
except ImportError:
    pass


class DataSet(object):
    def __init__(self, dataset_info_file):
        self._gt_img_list, self._gt_label_binary_list, \
        self._gt_label_instance_list = self._init_dataset(dataset_info_file)
        self._random_dataset()
        self._next_batch_loop_count = 0

    def _init_dataset(self, dataset_info_file):
        gt_img_list = []
        gt_label_binary_list = []
        gt_label_instance_list = []

        assert ops.exists(dataset_info_file), '{:s} not exist'.format(dataset_info_file)

        with open(dataset_info_file, 'r') as file:
            for _info in file:
                info_tmp = _info.strip(' ').split()

                gt_img_list.append(info_tmp[0])
                gt_label_binary_list.append(info_tmp[1])
                gt_label_instance_list.append(info_tmp[2])

        return gt_img_list, gt_label_binary_list, gt_label_instance_list

    def _random_dataset(self):
        assert len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list)

        random_idx = np.random.permutation(len(self._gt_img_list))
        new_gt_img_list = []
        new_gt_label_binary_list = []
        new_gt_label_instance_list = []

        for index in random_idx:
            new_gt_img_list.append(self._gt_img_list[index])
            new_gt_label_binary_list.append(self._gt_label_binary_list[index])
            new_gt_label_instance_list.append(self._gt_label_instance_list[index])

        self._gt_img_list = new_gt_img_list
        self._gt_label_binary_list = new_gt_label_binary_list
        self._gt_label_instance_list = new_gt_label_instance_list

    def next_batch(self, batch_size):
        """

        :param batch_size:
        :return:
        """
        assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \
               == len(self._gt_img_list)

        idx_start = batch_size * self._next_batch_loop_count
        idx_end = batch_size * self._next_batch_loop_count + batch_size

        if idx_start == 0 and idx_end > len(self._gt_label_binary_list):
            raise ValueError('Batch size cant be more than total numbers')

        if idx_end > len(self._gt_label_binary_list):
            self._random_dataset()
            self._next_batch_loop_count = 0
            return self.next_batch(batch_size)
        else:
            gt_img_list = self._gt_img_list[idx_start:idx_end]
            gt_label_binary_list = self._gt_label_binary_list[idx_start:idx_end]
            gt_label_instance_list = self._gt_label_instance_list[idx_start:idx_end]

            gt_imgs = []
            gt_labels_binary = []
            gt_labels_instance = []

            for gt_img_path in gt_img_list:
                gt_imgs.append(cv2.imread(gt_img_path, cv2.IMREAD_COLOR))

            for gt_label_path in gt_label_binary_list:
                label_img = cv2.imread(gt_label_path, cv2.IMREAD_COLOR)
                label_binary = np.zeros([label_img.shape[0], label_img.shape[1]], dtype=np.uint8)
                idx = np.where((label_img[:, :, :] != [0, 0, 0]).all(axis=2))
                label_binary[idx] = 1
                gt_labels_binary.append(label_binary)

            for gt_label_path in gt_label_instance_list:
                label_img = cv2.imread(gt_label_path, cv2.IMREAD_UNCHANGED)
                gt_labels_instance.append(label_img)

            self._next_batch_loop_count += 1
            return gt_imgs, gt_labels_binary, gt_labels_instance


if __name__ == '__main__':
    val = DataSet('/media/baidu/Data/Semantic_Segmentation/TUSimple_Lane_Detection/training/val.txt')
    b1, b2, b3 = val.next_batch(50)
    c1, c2, c3 = val.next_batch(50)
    d1, d2, d3 = val.next_batch(50)
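Note that next_batch returns raw, full-resolution images, so the training script still has to resize them to the network input size from the config. A small sketch of that glue code follows; the helper name and exact preprocessing are my own, not the repository's:

# Glue between DataSet and the network input size (a hedged sketch).
import cv2
import numpy as np

from config.global_config import cfg
from data_provider.data_processor import DataSet

def next_resized_batch(dataset, batch_size):
    gt_imgs, gt_bins, gt_insts = dataset.next_batch(batch_size)
    size = (cfg.TRAIN.IMG_WIDTH, cfg.TRAIN.IMG_HEIGHT)  # cv2 expects (w, h)
    imgs = [cv2.resize(img, size, interpolation=cv2.INTER_LINEAR) for img in gt_imgs]
    # nearest-neighbour for labels so class ids are not interpolated;
    # the channel axis still has to be added before feeding [N, h, w, 1] placeholders
    bins = [cv2.resize(lbl, size, interpolation=cv2.INTER_NEAREST) for lbl in gt_bins]
    insts = [cv2.resize(lbl, size, interpolation=cv2.INTER_NEAREST) for lbl in gt_insts]
    return np.asarray(imgs), np.asarray(bins), np.asarray(insts)

train_set = DataSet('train.txt')
imgs, bins, insts = next_resized_batch(train_set, cfg.TRAIN.BATCH_SIZE)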

               ./config   

                --global_config.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by: from config import cfg

cfg = __C

# Train options
__C.TRAIN = edict()

# Set the number of training epochs
__C.TRAIN.EPOCHS = 200010
# Set the display step
__C.TRAIN.DISPLAY_STEP = 1
# Set the test display step during training process
__C.TRAIN.TEST_DISPLAY_STEP = 1000
# Set the momentum parameter of the optimizer
__C.TRAIN.MOMENTUM = 0.9
# Set the initial learning rate
__C.TRAIN.LEARNING_RATE = 0.0005
# Set the GPU resource used during training process
__C.TRAIN.GPU_MEMORY_FRACTION = 0.85
# Set the GPU allow growth parameter during tensorflow training process
__C.TRAIN.TF_ALLOW_GROWTH = True
# Set the training batch size
__C.TRAIN.BATCH_SIZE = 1

# Set the validation batch size
__C.TRAIN.VAL_BATCH_SIZE = 1
# Set the learning rate decay steps
__C.TRAIN.LR_DECAY_STEPS = 410000
# Set the learning rate decay rate
__C.TRAIN.LR_DECAY_RATE = 0.1
# Set the class numbers
__C.TRAIN.CLASSES_NUMS = 2
# Set the image height
__C.TRAIN.IMG_HEIGHT = 256
# Set the image width
__C.TRAIN.IMG_WIDTH = 512

# Test options
__C.TEST = edict()

# Set the GPU resource used during testing process
__C.TEST.GPU_MEMORY_FRACTION = 0.8
# Set the GPU allow growth parameter during tensorflow testing process
__C.TEST.TF_ALLOW_GROWTH = True
# Set the test batch size
__C.TEST.BATCH_SIZE = 1
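As the comment at the top of the file says, consumers just import cfg. A quick usage sketch, assuming ./config is importable as a package:

# Quick usage sketch; assumes global_config.py sits in the config package above.
from config.global_config import cfg

print(cfg.TRAIN.LEARNING_RATE)                    # 0.0005
print(cfg.TRAIN.IMG_HEIGHT, cfg.TRAIN.IMG_WIDTH)  # 256 512
cfg.TRAIN.BATCH_SIZE = 4                          # EasyDict allows attribute-style overrides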

    ./encoder_decoder_model

    --cnn_basenet.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
The base convolution neural networks mainly implement some useful cnn functions
"""
import tensorflow as tf
import numpy as np


class CNNBaseModel(object):
    """
    Base model for other specific cnn ctpn_models
    """

    def __init__(self):
        pass

    @staticmethod
    def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
               stride=1, w_init=None, b_init=None,
               split=1, use_bias=True, data_format='NHWC', name=None):
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'NHWC' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
            assert in_channel % split == 0
            assert out_channel % split == 0

            padding = padding.upper()

            if isinstance(kernel_size, list):
                # integer division keeps the filter shape integral under Python 3
                filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel // split, out_channel]
            else:
                filter_shape = [kernel_size, kernel_size] + [in_channel // split, out_channel]

            if isinstance(stride, list):
                strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                    else [1, 1, stride[0], stride[1]]
            else:
                strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                    else [1, 1, stride, stride]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_channel], initializer=b_init)

            if split == 1:
                conv = tf.nn.conv2d(inputdata, w, strides, padding, data_format=data_format)
            else:
                inputs = tf.split(inputdata, split, channel_axis)
                kernels = tf.split(w, split, 3)
                outputs = [tf.nn.conv2d(i, k, strides, padding, data_format=data_format)
                           for i, k in zip(inputs, kernels)]
                conv = tf.concat(outputs, channel_axis)

            ret = tf.identity(tf.nn.bias_add(conv, b, data_format=data_format)
                              if use_bias else conv, name=name)

        return ret

    @staticmethod
    def relu(inputdata, name=None):
        return tf.nn.relu(features=inputdata, name=name)

    @staticmethod
    def sigmoid(inputdata, name=None):
        return tf.nn.sigmoid(x=inputdata, name=name)

    @staticmethod
    def maxpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):
        padding = padding.upper()

        if stride is None:
            stride = kernel_size

        if isinstance(kernel_size, list):
            kernel = [1, kernel_size[0], kernel_size[1], 1] if data_format == 'NHWC' else \
                [1, 1, kernel_size[0], kernel_size[1]]
        else:
            kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
                else [1, 1, kernel_size, kernel_size]

        if isinstance(stride, list):
            strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                else [1, 1, stride[0], stride[1]]
        else:
            strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                else [1, 1, stride, stride]

        return tf.nn.max_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def avgpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):
        if stride is None:
            stride = kernel_size

        kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
            else [1, 1, kernel_size, kernel_size]

        strides = [1, stride, stride, 1] if data_format == 'NHWC' else [1, 1, stride, stride]

        return tf.nn.avg_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def globalavgpooling(inputdata, data_format='NHWC', name=None):
        assert inputdata.shape.ndims == 4
        assert data_format in ['NHWC', 'NCHW']

        axis = [1, 2] if data_format == 'NHWC' else [2, 3]

        return tf.reduce_mean(input_tensor=inputdata, axis=axis, name=name)

    @staticmethod
    def layernorm(inputdata, epsilon=1e-5, use_bias=True, use_scale=True,
                  data_format='NHWC', name=None):
        shape = inputdata.get_shape().as_list()
        ndims = len(shape)
        assert ndims in [2, 4]

        mean, var = tf.nn.moments(inputdata, list(range(1, len(shape))), keep_dims=True)

        if data_format == 'NCHW':
            channnel = shape[1]
            new_shape = [1, channnel, 1, 1]
        else:
            channnel = shape[-1]
            new_shape = [1, 1, 1, channnel]
        if ndims == 2:
            new_shape = [1, channnel]

        if use_bias:
            beta = tf.get_variable('beta', [channnel], initializer=tf.constant_initializer())
            beta = tf.reshape(beta, new_shape)
        else:
            beta = tf.zeros([1] * ndims, name='beta')
        if use_scale:
            gamma = tf.get_variable('gamma', [channnel], initializer=tf.constant_initializer(1.0))
            gamma = tf.reshape(gamma, new_shape)
        else:
            gamma = tf.ones([1] * ndims, name='gamma')

        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def instancenorm(inputdata, epsilon=1e-5, data_format='NHWC', use_affine=True, name=None):
        shape = inputdata.get_shape().as_list()
        if len(shape) != 4:
            raise ValueError("Input data of instancebn layer has to be 4D tensor")

        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
        if ch is None:
            raise ValueError("Input of instancebn require known channel!")

        mean, var = tf.nn.moments(inputdata, axis, keep_dims=True)

        if not use_affine:
            return tf.divide(inputdata - mean, tf.sqrt(var + epsilon), name='output')

        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
        gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def dropout(inputdata, keep_prob, noise_shape=None, name=None):
        return tf.nn.dropout(inputdata, keep_prob=keep_prob, noise_shape=noise_shape, name=name)

    @staticmethod
    def fullyconnect(inputdata, out_dim, w_init=None, b_init=None,
                     use_bias=True, name=None):
        shape = inputdata.get_shape().as_list()[1:]
        if None not in shape:
            inputdata = tf.reshape(inputdata, [-1, int(np.prod(shape))])
        else:
            inputdata = tf.reshape(inputdata, tf.stack([tf.shape(inputdata)[0], -1]))

        if w_init is None:
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        ret = tf.layers.dense(inputs=inputdata, activation=lambda x: tf.identity(x, name='output'),
                              use_bias=use_bias, name=name,
                              kernel_initializer=w_init, bias_initializer=b_init,
                              trainable=True, units=out_dim)
        return ret

    @staticmethod
    def layerbn(inputdata, is_training, name):
        return tf.layers.batch_normalization(inputs=inputdata, training=is_training, name=name)

    @staticmethod
    def squeeze(inputdata, axis=None, name=None):
        return tf.squeeze(input=inputdata, axis=axis, name=name)

    @staticmethod
    def deconv2d(inputdata, out_channel, kernel_size, padding='SAME',
                 stride=1, w_init=None, b_init=None,
                 use_bias=True, activation=None, data_format='channels_last',
                 trainable=True, name=None):
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'channels_last' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            ret = tf.layers.conv2d_transpose(inputs=inputdata, filters=out_channel,
                                             kernel_size=kernel_size,
                                             strides=stride, padding=padding,
                                             data_format=data_format,
                                             activation=activation, use_bias=use_bias,
                                             kernel_initializer=w_init,
                                             bias_initializer=b_init, trainable=trainable,
                                             name=name)
        return ret

    @staticmethod
    def dilation_conv(input_tensor, k_size, out_dims, rate, padding='SAME',
                      w_init=None, b_init=None, use_bias=False, name=None):
        with tf.variable_scope(name):
            in_shape = input_tensor.get_shape().as_list()
            in_channel = in_shape[3]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if isinstance(k_size, list):
                filter_shape = [k_size[0], k_size[1]] + [in_channel, out_dims]
            else:
                filter_shape = [k_size, k_size] + [in_channel, out_dims]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_dims], initializer=b_init)

            conv = tf.nn.atrous_conv2d(value=input_tensor, filters=w, rate=rate,
                                       padding=padding, name='dilation_conv')

            if use_bias:
                ret = tf.add(conv, b)
            else:
                ret = conv

        return ret

    @staticmethod
    def spatial_dropout(input_tensor, keep_prob, is_training, name, seed=1234):
        tf.set_random_seed(seed=seed)

        def f1():
            with tf.variable_scope(name):
                return input_tensor

        def f2():
            with tf.variable_scope(name):
                num_feature_maps = [tf.shape(input_tensor)[0], tf.shape(input_tensor)[3]]

                random_tensor = keep_prob
                random_tensor += tf.random_uniform(num_feature_maps,
                                                   seed=seed,
                                                   dtype=input_tensor.dtype)

                binary_tensor = tf.floor(random_tensor)

                binary_tensor = tf.reshape(binary_tensor,
                                           [-1, 1, 1, tf.shape(input_tensor)[3]])
                ret = input_tensor * binary_tensor
                return ret

        output = tf.cond(is_training, f2, f1)
        return output

    @staticmethod
    def lrelu(inputdata, name, alpha=0.2):
        with tf.variable_scope(name):
            return tf.nn.relu(inputdata) - alpha * tf.nn.relu(-inputdata)

    --vgg_scnn_encoder.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from collections import OrderedDict

import tensorflow as tf
import glog as log
import math
import sys

sys.path.append('encoder_decoder_model')
import cnn_basenet


class VGG16Encoder(cnn_basenet.CNNBaseModel):
    def __init__(self, phase):
        super(VGG16Encoder, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):
        return tf.equal(self._phase, self._train_phase)

    def _conv_stage(self, input_tensor, k_size, out_dims, name, stride=1, pad='SAME'):
        with tf.variable_scope(name):
            conv = self.conv2d(inputdata=input_tensor, out_channel=out_dims,
                               kernel_size=k_size, stride=stride,
                               use_bias=False, padding=pad, name='conv')
            bn = self.layerbn(inputdata=conv, is_training=self._is_training, name='bn')
            relu = self.relu(inputdata=bn, name='relu')
            return relu

    def _fc_stage(self, input_tensor, out_dims, name, use_bias=False):
        with tf.variable_scope(name):
            fc = self.fullyconnect(inputdata=input_tensor, out_dim=out_dims, use_bias=use_bias, name='fc')
            bn = self.layerbn(inputdata=fc, is_training=self._is_training, name='bn')
            relu = self.relu(inputdata=bn, name='relu')
        return relu

    def scnn_u2d_d2u(self,input_tensor):
        output_list_old = []
        output_list_new = []
        shape_list = input_tensor.get_shape().as_list()
        log.info("scnn_u2d_d2u shape_list:{:}".format(shape_list))
        h_size = input_tensor.get_shape().as_list()[1]
        log.info("scnn_u2d_d2u h_size:{:}".format(h_size))
        channel_size = input_tensor.get_shape().as_list()[3]
        #up2down conv
        for i in range(h_size):
            output_list_old.append(tf.expand_dims(input_tensor[:,i,:,:],axis=1))
        output_list_new.append(tf.expand_dims(input_tensor[:,0,:,:],axis=1))
        
        w_ud = tf.get_variable('w_ud',[1,9,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*2))))
        with tf.variable_scope("scnn_u2d"):
            scnn_u2d = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[0],w_ud,[1,1,1,1],'SAME')),output_list_old[1])
            output_list_new.append(scnn_u2d)
        
        for i in range(2,h_size):
            with tf.variable_scope("scnn_u2d",reuse=True):
                scnn_u2d = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_ud,[1,1,1,1],'SAME')),output_list_old[i])
                output_list_new.append(scnn_u2d)
        
        #down2up conv
        output_list_old = output_list_new
        output_list_new = []
        length = h_size-1
        output_list_new.append(output_list_old[length])
        w_du = tf.get_variable('w_du',[1,9,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*2))))
        with tf.variable_scope('scnn_d2u'):
            scnn_d2u = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[length],w_du,[1,1,1,1],'SAME')),output_list_old[length-1])
            output_list_new.append(scnn_d2u)
            
        for i in range(2,h_size):
            with tf.variable_scope("scnn_d2u",reuse=True):
                scnn_d2u = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_du,[1,1,1,1],'SAME')),output_list_old[length-i])
                output_list_new.append(scnn_d2u)
                
        output_list_new.reverse()
        #log.info("scnn_u2d_d2u output_list_new:{:}".format(output_list_new))
        out_tensor = tf.stack(output_list_new,axis = 1)
        out_tensor = tf.squeeze(out_tensor,axis=2)
        return out_tensor
    
    def scnn_l2r_r2l(self,input_tensor):
        output_list_old = []
        output_list_new = []
        shape_list = input_tensor.get_shape().as_list()
        log.info("scnn_l2r_r2l shape_list:{:}".format(shape_list))
        w_size = input_tensor.get_shape().as_list()[2]
        log.info("scnn_l2r_r2l w_size:{:}".format(w_size))
        channel_size = input_tensor.get_shape().as_list()[3]
        
        #left2right conv
        for i in range(w_size):
            output_list_old.append(tf.expand_dims(input_tensor[:,:,i,:],axis=2))
        output_list_new.append(tf.expand_dims(input_tensor[:,:,0,:],axis=2))
        
        w_lr = tf.get_variable('w_lr',[9,1,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*5))))
        with tf.variable_scope("scnn_l2r"):
            scnn_l2r = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[0],w_lr,[1,1,1,1],'SAME')),output_list_old[1])
            output_list_new.append(scnn_l2r)
        
        for i in range(2,w_size):
            with tf.variable_scope("scnn_l2r",reuse=True):
                scnn_l2r = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_lr,[1,1,1,1],'SAME')),output_list_old[i])
                output_list_new.append(scnn_l2r)
        #log.info("output_list_new:{:}".format(output_list_new))
        
        #right2left conv
        output_list_old = output_list_new
        output_list_new = []
        length = w_size-1
        output_list_new.append(output_list_old[length])
        w_rl = tf.get_variable('w_rl',[9,1,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*5))))
        with tf.variable_scope('scnn_r2l'):
            scnn_r2l = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[length],w_rl,[1,1,1,1],'SAME')),output_list_old[length-1])
            output_list_new.append(scnn_r2l)
            
        for i in range(2,w_size):
            with tf.variable_scope("scnn_r2l",reuse=True):
                scnn_r2l = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_rl,[1,1,1,1],'SAME')),output_list_old[length-i])
                output_list_new.append(scnn_r2l)
                
        output_list_new.reverse()
        out_tensor = tf.stack(output_list_new,axis = 2)
        out_tensor = tf.squeeze(out_tensor,axis=3)
        return out_tensor
     
    
    def encode(self, input_tensor, name):
        ret = OrderedDict()

        with tf.variable_scope(name):
            # conv stage 1_1
            conv_1_1 = self._conv_stage(input_tensor=input_tensor, k_size=3, out_dims=64, name='conv1_1')
            log.info("encode conv_1_1:{:}".format(conv_1_1.get_shape().as_list()))
            
            # conv stage 1_2
            conv_1_2 = self._conv_stage(input_tensor=conv_1_1, k_size=3, out_dims=64, name='conv1_2')
            log.info("encode conv_1_2:{:}".format(conv_1_2.get_shape().as_list()))
            
            # pool stage 1
            pool1 = self.maxpooling(inputdata=conv_1_2, kernel_size=2, stride=2, name='pool1')
            log.info("encode pool1:{:}".format(pool1.get_shape().as_list()))
            
            # conv stage 2_1
            conv_2_1 = self._conv_stage(input_tensor=pool1, k_size=3,  out_dims=128, name='conv2_1')
            log.info("encode conv_2_1:{:}".format(conv_2_1.get_shape().as_list()))

            # conv stage 2_2
            conv_2_2 = self._conv_stage(input_tensor=conv_2_1, k_size=3, out_dims=128, name='conv2_2')
            log.info("encode conv_2_2:{:}".format(conv_2_2.get_shape().as_list()))
            
            # pool stage 2
            pool2 = self.maxpooling(inputdata=conv_2_2, kernel_size=2, stride=2, name='pool2')
            log.info("encode pool2:{:}".format(pool2.get_shape().as_list()))                        

            # conv stage 3_1
            conv_3_1 = self._conv_stage(input_tensor=pool2, k_size=3, out_dims=256, name='conv3_1')
            log.info("encode conv_3_1:{:}".format(conv_3_1.get_shape().as_list()))                            

            # conv_stage 3_2
            conv_3_2 = self._conv_stage(input_tensor=conv_3_1, k_size=3, out_dims=256, name='conv3_2')
            log.info("encode conv_3_2:{:}".format(conv_3_2.get_shape().as_list()))                              

            # conv stage 3_3
            conv_3_3 = self._conv_stage(input_tensor=conv_3_2, k_size=3, out_dims=256, name='conv3_3')
            log.info("encode conv_3_3:{:}".format(conv_3_3.get_shape().as_list()))                              

            ret['conv_3_3'] = dict()
            ret['conv_3_3']['data'] = conv_3_3
            ret['conv_3_3']['shape'] = conv_3_3.get_shape().as_list()

            # pool stage 3
            pool3 = self.maxpooling(inputdata=conv_3_3, kernel_size=2, stride=2, name='pool3')
            log.info("encode pool3:{:}".format(pool3.get_shape().as_list()))
                                   
            ret['pool3'] = dict()
            ret['pool3']['data'] = pool3
            ret['pool3']['shape'] = pool3.get_shape().as_list()

            # conv stage 4_1
            conv_4_1 = self._conv_stage(input_tensor=pool3, k_size=3, out_dims=512, name='conv4_1')
            log.info("encode conv_4_1:{:}".format(conv_4_1.get_shape().as_list()))                            

            # conv stage 4_2
            conv_4_2 = self._conv_stage(input_tensor=conv_4_1, k_size=3, out_dims=512, name='conv4_2')
            log.info("encode conv_4_2:{:}".format(conv_4_2.get_shape().as_list()))

            # conv stage 4_3
            conv_4_3 = self._conv_stage(input_tensor=conv_4_2, k_size=3,  out_dims=512, name='conv4_3')
            log.info("encode conv_4_3:{:}".format(conv_4_3.get_shape().as_list()))                               

            # pool stage 4
            pool4 = self.maxpooling(inputdata=conv_4_3, kernel_size=2, stride=2, name='pool4')
            log.info("encode pool4:{:}".format(pool4.get_shape().as_list()))
                                    
            ret['pool4'] = dict()
            ret['pool4']['data'] = pool4
            ret['pool4']['shape'] = pool4.get_shape().as_list()

            # conv stage 5_1
            conv_5_1 = self._conv_stage(input_tensor=pool4, k_size=3,
                                        out_dims=512, name='conv5_1')
            log.info("encode conv_5_1:{:}".format(conv_5_1.get_shape().as_list()))                            

            # conv stage 5_2
            conv_5_2 = self._conv_stage(input_tensor=conv_5_1, k_size=3,
                                        out_dims=512, name='conv5_2')
            log.info("encode conv_5_2:{:}".format(conv_5_2.get_shape().as_list()))

            # conv stage 5_3
            conv_5_3 = self._conv_stage(input_tensor=conv_5_2, k_size=3,
                                        out_dims=512, name='conv5_3')
            log.info("encode conv_5_3:{:}".format(conv_5_3.get_shape().as_list()))
            
            # conv stage 6_1
            conv_6_1 = self._conv_stage(input_tensor=conv_5_3, k_size=3,
                          out_dims=128, name='conv6_1')
            log.info("encode conv_6_1:{:}".format(conv_6_1.get_shape().as_list()))

            scnn_ud = self.scnn_u2d_d2u(conv_6_1)
            log.info("encode scnn_ud:{:}".format(scnn_ud.get_shape().as_list()))
            
            scnn_lr = self.scnn_l2r_r2l(scnn_ud)
            log.info("encode scnn_lr:{:}".format(scnn_lr.get_shape().as_list()))
            
            # pool stage 5
            pool5 = self.maxpooling(inputdata=scnn_lr, kernel_size=2,
                                    stride=2, name='pool5')
            log.info("encode pool5:{:}".format(pool5.get_shape().as_list()))

            ret['pool5'] = dict()
            ret['pool5']['data'] = pool5
            ret['pool5']['shape'] = pool5.get_shape().as_list()

            # fc stage 1
            # fc6 = self._fc_stage(input_tensor=pool5, out_dims=4096, name='fc6',
            #                      use_bias=False, flags=flags)

            # fc stage 2
            # fc7 = self._fc_stage(input_tensor=fc6, out_dims=4096, name='fc7',
            #                      use_bias=False, flags=flags)

        return ret

if __name__ == '__main__':
    a = tf.placeholder(dtype=tf.float32, shape=[1, 2048, 2048, 3], name='input')
    encoder = VGG16Encoder(phase=tf.constant('train', dtype=tf.string))
    ret = encoder.encode(a, name='encode')
    for layer_name, layer_info in ret.items():
        print('layer name: {:s} shape: {}'.format(layer_name, layer_info['shape']))
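The two scnn_* methods implement SCNN-style spatial message passing: the feature map is sliced along H (or W), and each slice receives a convolved, ReLU-activated contribution from its already-updated neighbour, first top-to-bottom then bottom-to-up (and analogously left-to-right then right-to-left). A toy NumPy sketch of one directional pass may make the recurrence clearer; the function name, single channel, 1-D rows and kernel are all mine, purely for intuition:

# Toy illustration of the top-to-bottom pass in scnn_u2d_d2u above:
# out[i] = feat[i] + relu(conv(out[i-1])), rows updated sequentially.
import numpy as np

def scnn_pass_down(feat, kernel):
    """feat: [H, W] feature map; kernel: 1-D filter applied along W."""
    out = feat.copy().astype(np.float64)
    for i in range(1, feat.shape[0]):
        prev = np.convolve(out[i - 1], kernel, mode='same')   # conv of the updated row above
        out[i] = feat[i] + np.maximum(prev, 0.0)              # relu(conv(prev)) + current row
    return out

feat = np.zeros((4, 8))
feat[0, 3] = 1.0                                   # a single activation in the top row
out = scnn_pass_down(feat, np.array([0.25, 0.5, 0.25]))
print(out)                                         # the activation spreads as it propagates down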

    --dense_encoder.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import tensorflow as tf
from collections import OrderedDict

#from encoder_decoder_model import cnn_basenet
import cnn_basenet

class DenseEncoder(cnn_basenet.CNNBaseModel):
    """
    基于DenseNet的编码器
    """
    def __init__(self, l, n, growthrate, phase, with_bc=False, bc_theta=0.5):
        super(DenseEncoder, self).__init__()
        self._L = l
        self._block_depth = int((l - n - 1) / n)
        self._N = n
        self._growthrate = growthrate
        self._with_bc = with_bc
        self._phase = phase
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._is_training = self._init_phase()
        self._bc_theta = bc_theta
        return

    def _init_phase(self):
        return tf.equal(self._phase, self._train_phase)

    def __str__(self):
        encoder_info = 'A densenet with net depth: {:d} block nums: ' \
                       '{:d} growth rate: {:d} block depth: {:d}'.\
            format(self._L, self._N, self._growthrate, self._block_depth)
        return encoder_info

    def _composite_conv(self, inputdata, out_channel, name):
        with tf.variable_scope(name):
            bn_1 = self.layerbn(inputdata=inputdata, is_training=self._is_training, name='bn_1')
            relu_1 = self.relu(bn_1, name='relu_1')
            if self._with_bc:
                conv_1 = self.conv2d(inputdata=relu_1, out_channel=out_channel,
                                     kernel_size=1,
                                     padding='SAME', stride=1, use_bias=False,
                                     name='conv_1')

                bn_2 = self.layerbn(inputdata=conv_1, is_training=self._is_training, name='bn_2')

                relu_2 = self.relu(inputdata=bn_2, name='relu_2')
                conv_2 = self.conv2d(inputdata=relu_2, out_channel=out_channel,
                                     kernel_size=3,
                                     stride=1, padding='SAME', use_bias=False,
                                     name='conv_2')
                return conv_2
            else:
                conv_2 = self.conv2d(inputdata=relu_1, out_channel=out_channel,
                                     kernel_size=3,
                                     stride=1, padding='SAME', use_bias=False,
                                     name='conv_2')
                return conv_2

    def _denseconnect_layers(self, inputdata, name):
        with tf.variable_scope(name):
            conv_out = self._composite_conv(inputdata=inputdata, name='composite_conv',  out_channel=self._growthrate)
            concate_cout = tf.concat(values=[conv_out, inputdata], axis=3, name='concatenate')

        return concate_cout

    def _transition_layers(self, inputdata, name):
        """
        Mainly implement the Pooling layer mentioned in DenseNet paper
        :param inputdata:
        :param name:
        :return:
        """
        input_channels = inputdata.get_shape().as_list()[3]

        with tf.variable_scope(name):
            # First batch norm
            bn = self.layerbn(inputdata=inputdata, is_training=self._is_training, name='bn')

            # Second 1*1 conv
            if self._with_bc:
                out_channels = int(input_channels * self._bc_theta)
                conv = self.conv2d(inputdata=bn, out_channel=out_channels,
                                   kernel_size=1, stride=1, use_bias=False,
                                   name='conv')
                # Third average pooling
                avgpool_out = self.avgpooling(inputdata=conv, kernel_size=2,
                                              stride=2, name='avgpool')
                return avgpool_out
            else:
                conv = self.conv2d(inputdata=bn, out_channel=input_channels,
                                   kernel_size=1, stride=1, use_bias=False,
                                   name='conv')
                # Third average pooling
                avgpool_out = self.avgpooling(inputdata=conv, kernel_size=2,
                                              stride=2, name='avgpool')
                return avgpool_out

    def _dense_block(self, inputdata, name):
        """
        Mainly implement the dense block mentioned in DenseNet figure 1
        :param inputdata:
        :param name:
        :return:
        """
        block_input = inputdata
        with tf.variable_scope(name):
            for i in range(self._block_depth):
                block_layer_name = '{:s}_layer_{:d}'.format(name, i + 1)
                block_input = self._denseconnect_layers(inputdata=block_input,
                                                        name=block_layer_name)
        return block_input

    def encode(self, input_tensor, name):
        """
        DenseNet编码
        :param input_tensor:
        :param name:
        :return:
        """
        encode_ret = OrderedDict()

        # First apply a 3*3 16 out channels conv layer
        # mentioned in DenseNet paper Implementation Details part
        with tf.variable_scope(name):
            conv1 = self.conv2d(inputdata=input_tensor, out_channel=16,
                                kernel_size=3, use_bias=False, name='conv1')
            dense_block_input = conv1

            # Second apply dense block stage
            for dense_block_nums in range(self._N):
                dense_block_name = 'Dense_Block_{:d}'.format(dense_block_nums + 1)

                # dense connectivity
                dense_block_out = self._dense_block(inputdata=dense_block_input,
                                                    name=dense_block_name)
                # apply the transition part
                dense_block_out = self._transition_layers(inputdata=dense_block_out,
                                                          name=dense_block_name)
                dense_block_input = dense_block_out
                encode_ret[dense_block_name] = dict()
                encode_ret[dense_block_name]['data'] = dense_block_out
                encode_ret[dense_block_name]['shape'] = dense_block_out.get_shape().as_list()

        return encode_ret


if __name__ == '__main__':
    input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 384, 1248, 3], name='input_tensor')
    encoder = DenseEncoder(l=100, growthrate=16, with_bc=True, phase=tf.constant('train'), n=5)
    ret = encoder.encode(input_tensor=input_tensor, name='Dense_Encode')
    for layer_name, layer_info in ret.items():
        print('layer_name: {:s} shape: {}'.format(layer_name, layer_info['shape']))
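It can help to work through the channel bookkeeping for the __main__ example above (l=100, n=5, growthrate=16, with_bc=True, default theta=0.5); the snippet below is pure arithmetic, not part of the repository:

# Channel bookkeeping for DenseEncoder(l=100, n=5, growthrate=16, with_bc=True)
l, n, growthrate, theta = 100, 5, 16, 0.5
block_depth = int((l - n - 1) / n)        # 18 composite convs per dense block
channels = 16                             # after the initial 3x3 conv
for block in range(1, n + 1):
    channels += block_depth * growthrate  # dense connectivity: +growthrate per layer
    channels = int(channels * theta)      # BC transition layer compresses by theta
    print('Dense_Block_{:d} output channels: {:d}'.format(block, channels))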

    --fcn_decoder.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf

#from encoder_decoder_model import cnn_basenet
#from encoder_decoder_model import vgg_encoder
#from encoder_decoder_model import dense_encoder
import cnn_basenet
import vgg_encoder
import dense_encoder

class FCNDecoder(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):
        """

        """
        super(FCNDecoder, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):
        """

        :return:
        """
        return tf.equal(self._phase, self._train_phase)

    def decode(self, input_tensor_dict, decode_layer_list, name):
        """
        解码特征信息反卷积还原
        :param input_tensor_dict:
        :param decode_layer_list: 需要解码的层名称需要由深到浅顺序写
                                  eg. ['pool5', 'pool4', 'pool3']
        :param name:
        :return:
        """
        ret = dict()

        with tf.variable_scope(name):
            # score stage 1
            input_tensor = input_tensor_dict[decode_layer_list[0]]['data']

            score = self.conv2d(inputdata=input_tensor, out_channel=64,
                                kernel_size=1, use_bias=False, name='score_origin')
            ret['score'] = dict()                    
            ret['score']['data'] = score
            ret['score']['shape'] = score.get_shape().as_list()      
                          
            decode_layer_list = decode_layer_list[1:]
            print("len(decode_layer_list):",len(decode_layer_list))
            for i in range(len(decode_layer_list)):
                deconv = self.deconv2d(inputdata=score, out_channel=64, kernel_size=4,
                                       stride=2, use_bias=False, name='deconv_{:d}'.format(i + 1))
                input_tensor = input_tensor_dict[decode_layer_list[i]]['data']
                score = self.conv2d(inputdata=input_tensor, out_channel=64,
                                    kernel_size=1, use_bias=False, name='score_{:d}'.format(i + 1))
                fused = tf.add(deconv, score, name='fuse_{:d}'.format(i + 1))
                score = fused
                ret['fuse_{:d}'.format(i + 1)] = dict()
                ret['fuse_{:d}'.format(i + 1)]['data'] = fused
                ret['fuse_{:d}'.format(i + 1)]['shape'] = fused.get_shape().as_list()
               
            deconv_final = self.deconv2d(inputdata=score, out_channel=64, kernel_size=16,
                                         stride=8, use_bias=False, name='deconv_final')

            score_final = self.conv2d(inputdata=deconv_final, out_channel=2,
                                      kernel_size=1, use_bias=False, name='score_final')
              
            ret['logits'] = dict()
            ret['logits']['data'] = score_final
            ret['logits']['shape'] = score_final.get_shape().as_list() 
            
            ret['deconv'] = dict()
            ret['deconv']['data'] = deconv_final
            ret['deconv']['shape'] = deconv_final.get_shape().as_list() 
        return ret


if __name__ == '__main__':

    vgg_encoder = vgg_encoder.VGG16Encoder(phase=tf.constant('train', tf.string))
    dense_encoder = dense_encoder.DenseEncoder(l=40, growthrate=12,
                                               with_bc=True, phase='train', n=5)
    decoder = FCNDecoder(phase='train')

    in_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 256, 512, 3],
                               name='input')

    vgg_encode_ret = vgg_encoder.encode(in_tensor, name='vgg_encoder')
    dense_encode_ret = dense_encoder.encode(in_tensor, name='dense_encoder')
    decode_ret = decoder.decode(vgg_encode_ret, name='decoder',
                                decode_layer_list=['pool5',
                                                   'pool4',
                                                   'pool3'])
                                                   
    for layer_name, layer_info in decode_ret.items():
        print('layer name: {:s} shape: {}'.format(layer_name, layer_info['shape']))                                               
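For the 256x512 inputs used throughout this post, the spatial sizes inside decode() work out as follows; this is a hand-computed sanity check, not code from the repository:

# Spatial sizes in decode() for a 256x512 input,
# decode_layer_list=['pool5', 'pool4', 'pool3']:
#   pool5 (8 x 16)  --1x1 conv-->  score             8 x  16 x 64
#   deconv_1 (k=4, s=2): 16 x 32,  fused with 1x1-conv'd pool4 (16 x 32)
#   deconv_2 (k=4, s=2): 32 x 64,  fused with 1x1-conv'd pool3 (32 x 64)
#   deconv_final (k=16, s=8):     256 x 512 x 64
#   score_final (1x1 conv):       256 x 512 x 2   -> binary logits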

     ./merge_model

     --merge_model.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import tensorflow as tf

from encoder_decoder_model import vgg_encoder
from encoder_decoder_model import fcn_decoder
from encoder_decoder_model import dense_encoder
from encoder_decoder_model import cnn_basenet
from lanenet_model import lanenet_discriminative_loss
from encoder_decoder_model import vgg_scnn_encoder
import glog

class LaneNet(cnn_basenet.CNNBaseModel):
    """
    实现语义分割模型
    """
    def __init__(self, phase, net_flag='vgg'):
        """

        """
        super(LaneNet, self).__init__()
        self._net_flag = net_flag
        self._phase = phase
        if self._net_flag == 'vgg':
            self._encoder = vgg_encoder.VGG16Encoder(phase=phase)
        elif self._net_flag == 'vgg_scnn':
            self._encoder = vgg_scnn_encoder.VGG16Encoder(phase=phase)
        elif self._net_flag == 'dense':
            self._encoder = dense_encoder.DenseEncoder(l=20, growthrate=8,
                                                       with_bc=True,
                                                       phase=phase,
                                                       n=5)
        self._decoder = fcn_decoder.FCNDecoder(phase=phase)
        return

    def __str__(self):
        """

        :return:
        """
        info = 'Semantic Segmentation use {:s} as basenet to encode'.format(self._net_flag)
        return info

    def _build_model(self, input_tensor, name):
        """
        前向传播过程
        :param input_tensor:
        :param name:
        :return:
        """
        with tf.variable_scope(name):
            # first encode
            encode_ret = self._encoder.encode(input_tensor=input_tensor,
                                              name='encode')

            # second decode
            if self._net_flag.lower() == 'vgg':
                decode_ret = self._decoder.decode(input_tensor_dict=encode_ret,
                                                  name='decode',
                                                  decode_layer_list=['pool5',
                                                                     'pool4',
                                                                     'pool3'])
                return decode_ret
            if self._net_flag.lower() == 'vgg_scnn':
                decode_ret = self._decoder.decode(input_tensor_dict=encode_ret,
                                                  name='decode',
                                                  decode_layer_list=['pool5',
                                                                     'pool4',
                                                                     'pool3'])
                return decode_ret
            elif self._net_flag.lower() == 'dense':
                decode_ret = self._decoder.decode(input_tensor_dict=encode_ret,
                                                  name='decode',
                                                  decode_layer_list=['Dense_Block_5',
                                                                     'Dense_Block_4',
                                                                     'Dense_Block_3'])
                return decode_ret

    def compute_loss(self, input_tensor, binary_label, instance_label, name):
        """
        计算LaneNet模型损失函数
        :param input_tensor:
        :param binary_label:
        :param instance_label:
        :param name:
        :return:
        """
        with tf.variable_scope(name):
            # forward pass to obtain the logits
            inference_ret = self._build_model(input_tensor=input_tensor, name='inference')
            glog.info('compute_loss inference_ret:{:}'.format(inference_ret)) 
            # binary segmentation loss
            decode_logits = inference_ret['logits']
            binary_label_plain = tf.reshape(
                binary_label,
                shape=[binary_label.get_shape().as_list()[0] *
                       binary_label.get_shape().as_list()[1] *
                       binary_label.get_shape().as_list()[2]])
            glog.info('compute_loss binary_label_plain:{:}'.format(binary_label_plain))            
            # apply inverse-frequency class weights
            unique_labels, unique_id, counts = tf.unique_with_counts(binary_label_plain)
            counts = tf.cast(counts, tf.float32)
            glog.info('compute_loss counts:{:}'.format(counts)) 
            inverse_weights = tf.divide(1.0,
                                        tf.log(tf.add(tf.divide(tf.constant(1.0), counts),
                                                      tf.constant(1.02))))
            glog.info('compute_loss inverse_weights:{:}'.format(inverse_weights))                                           
            inverse_weights = tf.gather(inverse_weights, binary_label)
            glog.info('compute_loss gather inverse_weights:{:}'.format(inverse_weights))      
            binary_segmenatation_loss = tf.losses.sparse_softmax_cross_entropy(
                labels=binary_label, logits=decode_logits, weights=inverse_weights)
            glog.info('compute_loss binary_segmenatation_loss:{:}'.format(binary_segmenatation_loss))    
            binary_segmenatation_loss = tf.reduce_mean(binary_segmenatation_loss)
            glog.info('compute_loss reduce_mean binary_segmenatation_loss:{:}'.format(binary_segmenatation_loss))    
            # discriminative loss for the instance branch
            decode_deconv = inference_ret['deconv']
            # pixel embedding
            pix_embedding = self.conv2d(inputdata=decode_deconv, out_channel=4, kernel_size=1,
                                        use_bias=False, name='pix_embedding_conv')
            pix_embedding = self.relu(inputdata=pix_embedding, name='pix_embedding_relu')
            # compute the discriminative loss
            image_shape = (pix_embedding.get_shape().as_list()[1], pix_embedding.get_shape().as_list()[2])
            glog.info('compute_loss image_shape:{:}'.format(image_shape)) 
            disc_loss, l_var, l_dist, l_reg = \
                lanenet_discriminative_loss.discriminative_loss(
                    pix_embedding, instance_label, 4, image_shape, 0.5, 3.0, 1.0, 1.0, 0.001)
            glog.info('compute_loss disc_loss:{:}'.format(disc_loss))
            # Combine the losses, adding L2 regularization over non-BN weights
            l2_reg_loss = tf.constant(0.0, tf.float32)
            for vv in tf.trainable_variables():
                if 'bn' in vv.name:
                    continue
                else:
                    l2_reg_loss = tf.add(l2_reg_loss, tf.nn.l2_loss(vv))
            l2_reg_loss *= 0.001
            total_loss = 0.5 * binary_segmentation_loss + 0.5 * disc_loss + l2_reg_loss

            ret = {
                'total_loss': total_loss,
                'binary_seg_logits': decode_logits,
                'instance_seg_logits': pix_embedding,
                'binary_seg_loss': binary_segmentation_loss,
                'discriminative_loss': disc_loss
            }

            return ret

    def inference(self, input_tensor, name):
        """

        :param input_tensor:
        :param name:
        :return:
        """
        with tf.variable_scope(name):
            # Forward pass to obtain the logits
            inference_ret = self._build_model(input_tensor=input_tensor, name='inference')
            # Binary segmentation branch
            decode_logits = inference_ret['logits']
            binary_seg_ret = tf.nn.softmax(logits=decode_logits)
            binary_seg_ret = tf.argmax(binary_seg_ret, axis=-1)
            # Pixel embedding branch
            decode_deconv = inference_ret['deconv']
            pix_embedding = self.conv2d(inputdata=decode_deconv, out_channel=4, kernel_size=1,
                                        use_bias=False, name='pix_embedding_conv')
            pix_embedding = self.relu(inputdata=pix_embedding, name='pix_embedding_relu')

            return binary_seg_ret, pix_embedding


if __name__ == '__main__':
    model = LaneNet(tf.constant('train', dtype=tf.string))
    input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input')
    binary_label = tf.placeholder(dtype=tf.int64, shape=[1, 256, 512, 1], name='binary_label')
    instance_label = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 1], name='instance_label')
    ret = model.compute_loss(input_tensor=input_tensor, binary_label=binary_label,
                             instance_label=instance_label, name='loss')
    for vv in tf.trainable_variables():
        if 'bn' in vv.name:
            continue
        print(vv.name)
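A side note on the class weighting used in compute_loss above: it is the bounded-log form w_c = 1 / ln(1/count_c + 1.02), so no weight can exceed 1 / ln(1.02) ≈ 50.5. The snippet below is a standalone NumPy sketch of that formula with made-up pixel counts, purely for illustration:

import numpy as np

# Bounded-log class weights, mirroring the inverse_weights computation above.
# The 1.02 offset caps every weight at 1 / ln(1.02) ~= 50.5.
counts = np.array([120000.0, 8000.0])  # hypothetical background / lane pixel counts
weights = 1.0 / np.log(1.0 / counts + 1.02)
print(weights)  # both values sit just below the ~50.5 cap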

     --lanenet_discriminative_loss.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf
import glog

def discriminative_loss_single(
        prediction,
        correct_label,
        feature_dim,
        label_shape,
        delta_v,
        delta_d,
        param_var,
        param_dist,
        param_reg):
    """
    Instance segmentation loss from Eq. (1) of the discriminative-loss paper
    :param prediction: inference of network
    :param correct_label: instance label
    :param feature_dim: feature dimension of prediction
    :param label_shape: shape of label
    :param delta_v: cut off variance distance
    :param delta_d: cut off cluster distance
    :param param_var: weight for intra cluster variance
    :param param_dist: weight for inter cluster distances
    :param param_reg: weight regularization
    """

    # Flatten the pixels into a single row
    correct_label = tf.reshape(
        correct_label, [
            label_shape[1] * label_shape[0]])
    reshaped_pred = tf.reshape(
        prediction, [
            label_shape[1] * label_shape[0], feature_dim])

    # Count the instances
    unique_labels, unique_id, counts = tf.unique_with_counts(correct_label)
    counts = tf.cast(counts, tf.float32)
    num_instances = tf.size(unique_labels)
    glog.info('discriminative_loss_single counts:{:} num_instances:{:}'.format(counts,num_instances))
    # Mean pixel-embedding vector per instance
    segmented_sum = tf.unsorted_segment_sum(
        reshaped_pred, unique_id, num_instances)  
    mu = tf.div(segmented_sum, tf.reshape(counts, (-1, 1)))
    mu_expand = tf.gather(mu, unique_id)

    # Variance term l_var: pull pixels toward their instance mean
    distance = tf.norm(tf.subtract(mu_expand, reshaped_pred), axis=1)
    distance = tf.subtract(distance, delta_v)
    distance = tf.clip_by_value(distance, 0., distance)
    distance = tf.square(distance)

    l_var = tf.unsorted_segment_sum(distance, unique_id, num_instances)
    l_var = tf.div(l_var, counts)
    l_var = tf.reduce_sum(l_var)
    l_var = tf.divide(l_var, tf.cast(num_instances, tf.float32))

    # Distance term l_dist: push different instance means apart
    mu_interleaved_rep = tf.tile(mu, [num_instances, 1])
    mu_band_rep = tf.tile(mu, [1, num_instances])
    mu_band_rep = tf.reshape(
        mu_band_rep,
        (num_instances *
         num_instances,
         feature_dim))

    mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)

    # Remove the zero rows (each mean paired with itself)
    intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff), axis=1)
    zero_vector = tf.zeros(1, dtype=tf.float32)
    bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
    mu_diff_bool = tf.boolean_mask(mu_diff, bool_mask)

    mu_norm = tf.norm(mu_diff_bool, axis=1)
    mu_norm = tf.subtract(2. * delta_d, mu_norm)
    mu_norm = tf.clip_by_value(mu_norm, 0., mu_norm)
    mu_norm = tf.square(mu_norm)

    l_dist = tf.reduce_mean(mu_norm)

    # Regularization term from the original Discriminative Loss paper
    l_reg = tf.reduce_mean(tf.norm(mu, axis=1))

    # Combine the terms using the weights from the original Discriminative Loss paper
    param_scale = 1.
    l_var = param_var * l_var
    l_dist = param_dist * l_dist
    l_reg = param_reg * l_reg

    loss = param_scale * (l_var + l_dist + l_reg)

    return loss, l_var, l_dist, l_reg


def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
                        delta_v, delta_d, param_var, param_dist, param_reg):
    """
    Iterate over the batch and compute the loss sample by sample, as in the paper
    :return: discriminative loss and its three components
    """

    def cond(label, batch, out_loss, out_var, out_dist, out_reg, i):
        return tf.less(i, tf.shape(batch)[0])

    def body(label, batch, out_loss, out_var, out_dist, out_reg, i):
        disc_loss, l_var, l_dist, l_reg = discriminative_loss_single(
            prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist, param_reg)

        out_loss = out_loss.write(i, disc_loss)
        out_var = out_var.write(i, l_var)
        out_dist = out_dist.write(i, l_dist)
        out_reg = out_reg.write(i, l_reg)

        return label, batch, out_loss, out_var, out_dist, out_reg, i + 1

    # TensorArray supports dynamic writes inside tf.while_loop
    output_ta_loss = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True)
    output_ta_var = tf.TensorArray(dtype=tf.float32,
                                   size=0,
                                   dynamic_size=True)
    output_ta_dist = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True)
    output_ta_reg = tf.TensorArray(dtype=tf.float32,
                                   size=0,
                                   dynamic_size=True)

    _, _, out_loss_op, out_var_op, out_dist_op, out_reg_op, _ = tf.while_loop(
        cond, body, [
            correct_label, prediction, output_ta_loss, output_ta_var, output_ta_dist, output_ta_reg, 0])
    out_loss_op = out_loss_op.stack()
    out_var_op = out_var_op.stack()
    out_dist_op = out_dist_op.stack()
    out_reg_op = out_reg_op.stack()

    disc_loss = tf.reduce_mean(out_loss_op)
    l_var = tf.reduce_mean(out_var_op)
    l_dist = tf.reduce_mean(out_dist_op)
    l_reg = tf.reduce_mean(out_reg_op)

    return disc_loss, l_var, l_dist, l_reg
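To make the two hinge terms concrete, here is a small NumPy sketch that recomputes l_var and l_dist for four hand-picked 2-D embeddings forming two instances, with delta_v=0.5 and delta_d=3.0 as used by compute_loss (the averaging is simplified to per-pixel here, whereas the TF code above averages within each instance first):

import numpy as np

emb = np.array([[0.0, 0.0], [0.4, 0.0],    # instance 0: a tight pair
                [5.0, 0.0], [5.8, 0.0]])   # instance 1
lab = np.array([0, 0, 1, 1])
mus = np.stack([emb[lab == k].mean(axis=0) for k in (0, 1)])  # cluster means

# l_var: hinge on each pixel's distance to its own instance mean
d = np.linalg.norm(emb - mus[lab], axis=1)
l_var = np.mean(np.maximum(d - 0.5, 0.0) ** 2)   # 0.0 -> both clusters already tight

# l_dist: hinge that pushes the two means at least 2 * delta_d apart
gap = np.linalg.norm(mus[0] - mus[1])            # 5.2 here
l_dist = np.maximum(2 * 3.0 - gap, 0.0) ** 2     # (6.0 - 5.2)^2 = 0.64
print(l_var, l_dist)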

    --postprocess.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LaneNet post-processing
"""
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glog

try:
    from cv2 import cv2
except ImportError:
    pass


class LaneNetPoseProcessor(object):
    """

    """
    def __init__(self):
        """

        """
        pass

    @staticmethod
    def _morphological_process(image, kernel_size=5):
        """

        :param image:
        :param kernel_size:
        :return:
        """
        if image.dtype is not np.uint8:
            image = np.array(image, np.uint8)
        glog.info("_morphological_process image shape len:{:d}".format(len(image.shape)))
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        glog.info("_morphological_process image shape len:{:d}".format(len(image.shape)))
        kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
        
        # the closing operation fills small holes
        closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)

        return closing

    @staticmethod
    def _connect_components_analysis(image):
        """

        :param image:
        :return:
        """
        glog.info("_connect_components_analysis image shape len:{:d}".format(len(image.shape)))
        if len(image.shape) == 3:
            gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray_image = image

        return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)

    def postprocess(self, image, minarea_threshold=15):
        """

        :param image:
        :param minarea_threshold: 连通域分析阈值
        :return:
        """
        # Morphological closing first
        morphological_ret = self._morphological_process(image, kernel_size=5)
        glog.info("postprocess image shape len:{:d}".format(len(image.shape)))

        # Connected-component analysis
        connect_components_analysis_ret = self._connect_components_analysis(image=morphological_ret)
        glog.info("postprocess connect_components_analysis_ret:{:}".format(connect_components_analysis_ret))
        # Drop connected components smaller than the area threshold
        labels = connect_components_analysis_ret[1]
        stats = connect_components_analysis_ret[2]
        glog.info("postprocess labels:{:}".format(labels))
        glog.info("postprocess stats:{:}".format(stats))
        for index, stat in enumerate(stats):
            if stat[4] <= minarea_threshold:
                idx = np.where(labels == index)
                morphological_ret[idx] = 0

        return morphological_ret


if __name__ == '__main__':
    processor = LaneNetPoseProcessor()

    image = cv2.imread('D:/Code/github/tf_lanenet/data/training_data_example/gt_image_binary/0000.png', cv2.IMREAD_UNCHANGED) #IMREAD_GRAYSCALE

    postprocess_ret = processor.postprocess(image)

    plt.figure('src')
    plt.imshow(image)
    plt.figure('post')
    plt.imshow(postprocess_ret)
    plt.show()
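For a quick smoke test that needs no image files, the sketch below feeds the post-processor a synthetic mask (values chosen for illustration): the 3-pixel gap in the dashed line should be bridged by the 5x5 closing kernel, while the isolated speckle falls under the default area threshold and is removed:

import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[30:33, 5:25] = 255        # left dash of a lane marking
mask[30:33, 28:48] = 255       # right dash, 3-pixel gap in between
mask[5, 5] = 255               # isolated speckle noise

out = LaneNetPoseProcessor().postprocess(mask, minarea_threshold=15)
print(out[31, 26], out[5, 5])  # expected: gap filled (255), speckle removed (0)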

         --cluster.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Clustering stage of the LaneNet instance segmentation
"""
import numpy as np
import glog as log
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.cluster import DBSCAN
import time
import warnings
import cv2
import glog

try:
    from cv2 import cv2
except ImportError:
    pass


class LaneNetCluster(object):
    """
    Instance-segmentation clusterer
    """

    def __init__(self):
        """

        """
        self._color_map = [np.array([255, 0, 0]),
                           np.array([0, 255, 0]),
                           np.array([0, 0, 255]),
                           np.array([125, 125, 0]),
                           np.array([0, 125, 125]),
                           np.array([125, 0, 125]),
                           np.array([50, 100, 50]),
                           np.array([100, 50, 100])]

    @staticmethod
    def _cluster(prediction, bandwidth):
        """
        MeanShift clustering, Section II of the LaneNet paper
        :param prediction: pixel embedding features
        :param bandwidth: MeanShift bandwidth
        :return: number of clusters, per-pixel labels, cluster centers
        """
        ms = MeanShift(bandwidth, bin_seeding=True)
        # log.info('Starting MeanShift clustering ...')
        tic = time.time()
        try:
            ms.fit(prediction)
        except ValueError as err:
            log.error(err)
            return 0, [], []
        # log.info('MeanShift took: {:.5f}s'.format(time.time() - tic))
        labels = ms.labels_
        cluster_centers = ms.cluster_centers_

        num_clusters = cluster_centers.shape[0]

        # log.info('Number of clusters: {:d}'.format(num_clusters))

        return num_clusters, labels, cluster_centers

    @staticmethod
    def _cluster_v2(prediction):
        """
        dbscan cluster
        :param prediction:
        :return:
        """
        db = DBSCAN(eps=0.7, min_samples=200).fit(prediction)
        db_labels = db.labels_
        unique_labels = np.unique(db_labels)
        unique_labels = [tmp for tmp in unique_labels if tmp != -1]
        log.info('Number of clusters: {:d}'.format(len(unique_labels)))

        num_clusters = len(unique_labels)
        cluster_centers = db.components_

        return num_clusters, db_labels, cluster_centers

    @staticmethod
    def _get_lane_area(binary_seg_ret, instance_seg_ret):
        """
        Gather the embedding features of all lane pixels selected by the binary mask
        :param binary_seg_ret: binary segmentation mask
        :param instance_seg_ret: pixel embedding image
        :return: lane embedding features and their pixel coordinates
        """
        idx = np.where(binary_seg_ret == 1)
        
        print("_get_lane_area idx:",idx)
        print("_get_lane_area idx len:",len(idx))
        print("_get_lane_area idx len[0]:",len(idx[0]))
        print("_get_lane_area idx len[1]:",len(idx[1]))
        lane_embedding_feats = []
        lane_coordinate = []
        for i in range(len(idx[0])):
            lane_embedding_feats.append(instance_seg_ret[idx[0][i], idx[1][i]])
            #print("_get_lane_area instance_seg_ret[idx[0][i], idx[1][i]]:",instance_seg_ret[idx[0][i], idx[1][i]])
            lane_coordinate.append([idx[0][i], idx[1][i]])
            #print("_get_lane_area idx[0][i]:",idx[0][i]," , idx[1][i]:", idx[1][i])

        return np.array(lane_embedding_feats, np.float32), np.array(lane_coordinate, np.int64)

    @staticmethod
    def _thresh_coord(coord):
        """
        Filter lane coordinate points: a lane is assumed continuous, so its coordinates should vary smoothly without jumps
        :param coord: [(x, y)]
        :return:
        """
        pts_x = coord[:, 0]
        mean_x = np.mean(pts_x)

        idx = np.where(np.abs(pts_x - mean_x) < mean_x)

        return coord[idx[0]]

    @staticmethod
    def _lane_fit(lane_pts):
        """
        Cubic polynomial fit of a lane line
        :param lane_pts:
        :return:
        """
        if not isinstance(lane_pts, np.ndarray):
            lane_pts = np.array(lane_pts, np.float32)

        x = lane_pts[:, 0]
        y = lane_pts[:, 1]
        x_fit = []
        y_fit = []
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            try:
                f1 = np.polyfit(x, y, 3)
                p1 = np.poly1d(f1)
                x_min = int(np.min(x))
                x_max = int(np.max(x))
                x_fit = []
                for i in range(x_min, x_max + 1):
                    x_fit.append(i)
                y_fit = p1(x_fit)
            except Warning as e:
                x_fit = x
                y_fit = y
            finally:
                return zip(x_fit, y_fit)

    def get_lane_mask(self, binary_seg_ret, instance_seg_ret):
        """

        :param binary_seg_ret:
        :param instance_seg_ret:
        :return:
        """
        lane_embedding_feats, lane_coordinate = self._get_lane_area(binary_seg_ret, instance_seg_ret)
        
        num_clusters, labels, cluster_centers = self._cluster(lane_embedding_feats, bandwidth=1.5)

        # If more than eight clusters are found, keep only the four with the most member pixels
        if num_clusters > 8:
            cluster_sample_nums = []
            for i in range(num_clusters):
                cluster_sample_nums.append(len(np.where(labels == i)[0]))
            sort_idx = np.argsort(-np.array(cluster_sample_nums, np.int64))
            cluster_index = np.array(range(num_clusters))[sort_idx[0:4]]
        else:
            cluster_index = range(num_clusters)

        mask_image = np.zeros(shape=[binary_seg_ret.shape[0], binary_seg_ret.shape[1], 3], dtype=np.uint8)

        for index, i in enumerate(cluster_index):
            idx = np.where(labels == i)
            coord = lane_coordinate[idx]
            # coord = self._thresh_coord(coord)
            coord = np.flip(coord, axis=1)
            # coord = (coord[:, 0], coord[:, 1])
            color = (int(self._color_map[index][0]),
                     int(self._color_map[index][1]),
                     int(self._color_map[index][2]))
            coord = np.array([coord])
            cv2.polylines(img=mask_image, pts=coord, isClosed=False, color=color, thickness=2)
            # mask_image[coord] = color

        return mask_image


if __name__ == '__main__':
    binary_seg_image = cv2.imread('D:/Code/github/tf_lanenet/data/training_data_example/gt_image_binary/0000.png', cv2.IMREAD_GRAYSCALE)
    print("binary_seg_image shape:",binary_seg_image.shape)
    binary_seg_image[np.where(binary_seg_image == 255)] = 1
    print("binary_seg_image np.where(binary_seg_image == 255):",np.where(binary_seg_image == 255))
    instance_seg_image = cv2.imread('D:/Code/github/tf_lanenet/data/training_data_example/gt_image_instance/0000.png', cv2.IMREAD_UNCHANGED)
    glog.info("__name__ instance_seg_image shape len:{:d}".format(len(instance_seg_image.shape)))
    instance_seg_image = cv2.cvtColor(instance_seg_image, cv2.COLOR_GRAY2BGR)
    glog.info("__name__ instance_seg_image shape len:{:d}".format(len(instance_seg_image.shape)))
    #print("instance_seg_image shape:",instance_seg_image.shape)
    ele_mex = np.max(instance_seg_image, axis=(0,1))
    print("ele_mex:",ele_mex)
    for i in range(3):
        if ele_mex[i] == 0:
            scale = 1
        else:
            scale = 255 / ele_mex[i]
        instance_seg_image[:, :, i] *= int(scale)
    embedding_image = np.array(instance_seg_image, np.uint8)
    cluster = LaneNetCluster()
    mask_image = cluster.get_lane_mask(binary_seg_ret=binary_seg_image,instance_seg_ret=instance_seg_image)
    det_img = embedding_image+mask_image
    plt.figure('det_img')
    plt.imshow(det_img[:, :, (2, 1, 0)])         
    #plt.figure('embedding')
    #plt.imshow(embedding_image[:, :, (2, 1, 0)])
    #plt.figure('mask_image')
    #plt.imshow(mask_image[:, :, (2, 1, 0)])
    plt.show()
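As a sanity check of the MeanShift step in isolation, the sketch below clusters two well-separated synthetic blobs in the 4-D embedding space; with bandwidth=1.5 it should report exactly two clusters (blob positions and spreads are arbitrary):

import numpy as np

rng = np.random.RandomState(0)
feats = np.vstack([rng.normal(0.0, 0.3, (200, 4)),   # fake embeddings of lane 1
                   rng.normal(5.0, 0.3, (200, 4))])  # fake embeddings of lane 2
num, labels, centers = LaneNetCluster._cluster(feats, bandwidth=1.5)
print(num)  # expected: 2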

             --train_lanenet_scnn.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse
import math
import os
import os.path as ops
import time

import cv2
import glog as log
import numpy as np
import tensorflow as tf

from config import global_config
from lanenet_model import lanenet_merge_model
from data_provider import lanenet_data_processor

CFG = global_config.cfg
VGG_MEAN = [103.939, 116.779, 123.68]


def init_args():
    """

    :return:
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_dir', type=str, default='data/datasets_culane', help='The training dataset dir path')
    parser.add_argument('--net', type=str, default='vgg', help='Which base network to use')
    parser.add_argument('--weights_path', type=str, default='model/lanenet_culane_vgg_2019-02-02-14-05-16.ckpt-200000', help='The pretrained weights path')

    return parser.parse_args()


def minmax_scale(input_arr):
    """

    :param input_arr:
    :return:
    """
    min_val = np.min(input_arr)
    max_val = np.max(input_arr)

    output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

    return output_arr


def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')
    print('train_dataset_file:',train_dataset_file)
    print('val_dataset_file:',val_dataset_file)

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:1'):
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                             CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                                    CFG.TRAIN.IMG_WIDTH, 1],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                                      CFG.TRAIN.IMG_WIDTH],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor, binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor, name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # calculate the accuracy (recall over ground-truth lane pixels)
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        out = tf.expand_dims(out_logits_out, axis=-1)

        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE, global_step,
                                                   100000, 0.1, staircase=True)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9).minimize(loss=total_loss,
                                                                    var_list=tf.trainable_variables(),
                                                                    global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/lanenet_culane'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'lanenet_culane_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/lanenet_culane/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss', tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([train_accuracy_scalar, train_cost_scalar,
                                               learning_rate_scalar, train_binary_seg_loss_scalar,
                                               train_instance_seg_loss_scalar])
    val_merge_summary_op = tf.summary.merge([val_accuracy_scalar, val_cost_scalar,
                                             val_binary_seg_loss_scalar, val_instance_seg_loss_scalar])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                             name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load pretrained VGG weights
        log.info('train_net net_flag: {:s}'.format(net_flag))
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load(
                './data/vgg16.npy',
                encoding='latin1').item()
            log.info('Initializing the vgg backbone from ./data/vgg16.npy')
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
                    continue

        train_cost_time_mean = []
        val_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            with tf.device('/cpu:0'):
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)
                gt_imgs = [cv2.resize(tmp,
                                      dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                      dst=tmp,
                                      interpolation=cv2.INTER_LINEAR)
                           for tmp in gt_imgs]

                gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
                binary_gt_labels = [cv2.resize(tmp,
                                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                               dst=tmp,
                                               interpolation=cv2.INTER_NEAREST)
                                    for tmp in binary_gt_labels]
                binary_gt_labels = [np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels]
                instance_gt_labels = [cv2.resize(tmp,
                                                 dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                                 dst=tmp,
                                                 interpolation=cv2.INTER_NEAREST)
                                      for tmp in instance_gt_labels]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

            if math.isnan(c) or math.isnan(instance_loss) or math.isnan(binary_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)

                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            # validation part
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [cv2.resize(tmp,
                                          dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                          dst=tmp,
                                          interpolation=cv2.INTER_LINEAR)
                               for tmp in gt_imgs_val]
                gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                binary_gt_labels_val = [cv2.resize(tmp,
                                                   dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                                   dst=tmp,
                                                   interpolation=cv2.INTER_NEAREST)
                                        for tmp in binary_gt_labels_val]
                binary_gt_labels_val = [np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels_val]
                instance_gt_labels_val = [cv2.resize(tmp,
                                                     dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                                     dst=tmp,
                                                     interpolation=cv2.INTER_NEAREST)
                                          for tmp in instance_gt_labels_val]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                cv2.imwrite('test_image.png', gt_imgs_val[0] + VGG_MEAN)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info('Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                         ' mean_cost_time= {:5f}s '.
                         format(epoch + 1, c, binary_loss, instance_loss, train_accuracy,
                                np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info('Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                         'instance_seg_loss= {:6f} accuracy= {:6f} '
                         'mean_cost_time= {:5f}s '.
                         format(epoch + 1, c_val, val_binary_seg_loss, val_instance_seg_loss, val_accuracy,
                                np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    sess.close()

    return


if __name__ == '__main__':
    # init args
    args = init_args()

    # train lanenet
    train_net(args.dataset_dir, args.weights_path, net_flag=args.net)
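One detail worth calling out in train_net: the decay step passed to tf.train.exponential_decay is hard-coded to 100000 (the config's LR_DECAY_STEPS is not read here), so with staircase=True the learning rate drops by a factor of 10 every 100k global steps. A quick sketch of the resulting schedule:

# lr(step) = 0.0005 * 0.1 ** floor(step / 100000), per the staircase schedule above
for step in (0, 99999, 100000, 250000):
    print(step, 0.0005 * 0.1 ** (step // 100000))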

            The following folders are generated during training:

    ./summary

              ./figure

              ./checkpoint

     

       In the project root, run python train_lanenet_scnn.py; if everything is set up correctly, training will start...
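Note that --weights_path has a non-empty default in init_args, so saver.restore is attempted on startup; to train from scratch (global initialization plus the VGG .npy weights), the default has to be edited to None. Hypothetical invocations, with placeholder paths:

python train_lanenet_scnn.py --dataset_dir data/datasets_culane --net vgg
python train_lanenet_scnn.py --net vgg --weights_path model/lanenet_culane/<your_checkpoint>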

 


               --list.txt

data/datasets_culane/JPEGImages/0000.jpg data/datasets_culane/SegmentationClass\0000.png data/datasets_culane/InstanceSegmentationClass/0000.png
data/datasets_culane/JPEGImages/0001.jpg data/datasets_culane/SegmentationClass\0001.png data/datasets_culane/InstanceSegmentationClass/0001.png
data/datasets_culane/JPEGImages/0002.jpg data/datasets_culane/SegmentationClass\0002.png data/datasets_culane/InstanceSegmentationClass/0002.png
data/datasets_culane/JPEGImages/0003.jpg data/datasets_culane/SegmentationClass\0003.png data/datasets_culane/InstanceSegmentationClass/0003.png
data/datasets_culane/JPEGImages/0004.jpg data/datasets_culane/SegmentationClass\0004.png data/datasets_culane/InstanceSegmentationClass/0004.png
data/datasets_culane/JPEGImages/0005.jpg data/datasets_culane/SegmentationClass\0005.png data/datasets_culane/InstanceSegmentationClass/0005.png
View Code

     --train.txt

data/datasets_culane/JPEGImages/0000.jpg data/datasets_culane/SegmentationClass\0000.png data/datasets_culane/InstanceSegmentationClass/0000.png
data/datasets_culane/JPEGImages/0001.jpg data/datasets_culane/SegmentationClass\0001.png data/datasets_culane/InstanceSegmentationClass/0001.png
data/datasets_culane/JPEGImages/0002.jpg data/datasets_culane/SegmentationClass\0002.png data/datasets_culane/InstanceSegmentationClass/0002.png
data/datasets_culane/JPEGImages/0004.jpg data/datasets_culane/SegmentationClass\0004.png data/datasets_culane/InstanceSegmentationClass/0004.png
View Code

     --val.txt

data/datasets_culane/JPEGImages/0005.jpg data/datasets_culane/SegmentationClass\0005.png data/datasets_culane/InstanceSegmentationClass/0005.png
data/datasets_culane/JPEGImages/0003.jpg data/datasets_culane/SegmentationClass\0003.png data/datasets_culane/InstanceSegmentationClass/0003.png
View Code

            ./data_provider

               --data_processor.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os.path as ops

import cv2
import numpy as np

try:
    from cv2 import cv2
except ImportError:
    pass


class DataSet(object):
    def __init__(self, dataset_info_file):
        self._gt_img_list, self._gt_label_binary_list, \
        self._gt_label_instance_list = self._init_dataset(dataset_info_file)
        self._random_dataset()
        self._next_batch_loop_count = 0

    def _init_dataset(self, dataset_info_file):
        gt_img_list = []
        gt_label_binary_list = []
        gt_label_instance_list = []

        assert ops.exists(dataset_info_file), '{:s} not exist'.format(dataset_info_file)

        with open(dataset_info_file, 'r') as file:
            for _info in file:
                info_tmp = _info.strip(' ').split()

                gt_img_list.append(info_tmp[0])
                gt_label_binary_list.append(info_tmp[1])
                gt_label_instance_list.append(info_tmp[2])

        return gt_img_list, gt_label_binary_list, gt_label_instance_list

    def _random_dataset(self):
        assert len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list)

        random_idx = np.random.permutation(len(self._gt_img_list))
        new_gt_img_list = []
        new_gt_label_binary_list = []
        new_gt_label_instance_list = []

        for index in random_idx:
            new_gt_img_list.append(self._gt_img_list[index])
            new_gt_label_binary_list.append(self._gt_label_binary_list[index])
            new_gt_label_instance_list.append(self._gt_label_instance_list[index])

        self._gt_img_list = new_gt_img_list
        self._gt_label_binary_list = new_gt_label_binary_list
        self._gt_label_instance_list = new_gt_label_instance_list

    def next_batch(self, batch_size):
        """

        :param batch_size:
        :return:
        """
        assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \
               == len(self._gt_img_list)

        idx_start = batch_size * self._next_batch_loop_count
        idx_end = batch_size * self._next_batch_loop_count + batch_size

        if idx_start == 0 and idx_end > len(self._gt_label_binary_list):
            raise ValueError('Batch size cant be more than total numbers')

        if idx_end > len(self._gt_label_binary_list):
            self._random_dataset()
            self._next_batch_loop_count = 0
            return self.next_batch(batch_size)
        else:
            gt_img_list = self._gt_img_list[idx_start:idx_end]
            gt_label_binary_list = self._gt_label_binary_list[idx_start:idx_end]
            gt_label_instance_list = self._gt_label_instance_list[idx_start:idx_end]

            gt_imgs = []
            gt_labels_binary = []
            gt_labels_instance = []

            for gt_img_path in gt_img_list:
                gt_imgs.append(cv2.imread(gt_img_path, cv2.IMREAD_COLOR))

            for gt_label_path in gt_label_binary_list:
                label_img = cv2.imread(gt_label_path, cv2.IMREAD_COLOR)
                label_binary = np.zeros([label_img.shape[0], label_img.shape[1]], dtype=np.uint8)
                idx = np.where((label_img[:, :, :] != [0, 0, 0]).all(axis=2))
                label_binary[idx] = 1
                gt_labels_binary.append(label_binary)

            for gt_label_path in gt_label_instance_list:
                label_img = cv2.imread(gt_label_path, cv2.IMREAD_UNCHANGED)
                gt_labels_instance.append(label_img)

            self._next_batch_loop_count += 1
            return gt_imgs, gt_labels_binary, gt_labels_instance


if __name__ == '__main__':
    val = DataSet('/media/baidu/Data/Semantic_Segmentation/TUSimple_Lane_Detection/training/val.txt')
    b1, b2, b3 = val.next_batch(50)
    c1, c2, c3 = val.next_batch(50)
    dd, d2, d3 = val.next_batch(50)
View Code

               ./config   

                --global_config.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by: from config import cfg

cfg = __C

# Train options
__C.TRAIN = edict()

# Set the shadownet training epochs
__C.TRAIN.EPOCHS = 200010
# Set the display step
__C.TRAIN.DISPLAY_STEP = 1
# Set the test display step during training process
__C.TRAIN.TEST_DISPLAY_STEP = 1000
# Set the momentum parameter of the optimizer
__C.TRAIN.MOMENTUM = 0.9
# Set the initial learning rate
__C.TRAIN.LEARNING_RATE = 0.0005
# Set the GPU resource used during training process
__C.TRAIN.GPU_MEMORY_FRACTION = 0.85
# Set the GPU allow growth parameter during tensorflow training process
__C.TRAIN.TF_ALLOW_GROWTH = True
# Set the shadownet training batch size
__C.TRAIN.BATCH_SIZE = 1

# Set the shadownet validation batch size
__C.TRAIN.VAL_BATCH_SIZE = 1
# Set the learning rate decay steps
__C.TRAIN.LR_DECAY_STEPS = 410000
# Set the learning rate decay rate
__C.TRAIN.LR_DECAY_RATE = 0.1
# Set the class numbers
__C.TRAIN.CLASSES_NUMS = 2
# Set the image height
__C.TRAIN.IMG_HEIGHT = 256
# Set the image width
__C.TRAIN.IMG_WIDTH = 512

# Test options
__C.TEST = edict()

# Set the GPU resource used during testing process
__C.TEST.GPU_MEMORY_FRACTION = 0.8
# Set the GPU allow growth parameter during tensorflow testing process
__C.TEST.TF_ALLOW_GROWTH = True
# Set the test batch size
__C.TEST.BATCH_SIZE = 1
View Code

    ./encoder_decoder_model

    --cnn_basenet.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
The base convolution neural networks mainly implement some useful cnn functions
"""
import tensorflow as tf
import numpy as np


class CNNBaseModel(object):
    """
    Base model for other specific cnn ctpn_models
    """

    def __init__(self):
        pass

    @staticmethod
    def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
               stride=1, w_init=None, b_init=None,
               split=1, use_bias=True, data_format='NHWC', name=None):
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'NHWC' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
            assert in_channel % split == 0
            assert out_channel % split == 0

            padding = padding.upper()

            if isinstance(kernel_size, list):
                filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel / split, out_channel]
            else:
                filter_shape = [kernel_size, kernel_size] + [in_channel / split, out_channel]

            if isinstance(stride, list):
                strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                    else [1, 1, stride[0], stride[1]]
            else:
                strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                    else [1, 1, stride, stride]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_channel], initializer=b_init)

            if split == 1:
                conv = tf.nn.conv2d(inputdata, w, strides, padding, data_format=data_format)
            else:
                inputs = tf.split(inputdata, split, channel_axis)
                kernels = tf.split(w, split, 3)
                outputs = [tf.nn.conv2d(i, k, strides, padding, data_format=data_format)
                           for i, k in zip(inputs, kernels)]
                conv = tf.concat(outputs, channel_axis)

            ret = tf.identity(tf.nn.bias_add(conv, b, data_format=data_format)
                              if use_bias else conv, name=name)

        return ret

    @staticmethod
    def relu(inputdata, name=None):
        return tf.nn.relu(features=inputdata, name=name)

    @staticmethod
    def sigmoid(inputdata, name=None):
        return tf.nn.sigmoid(x=inputdata, name=name)

    @staticmethod
    def maxpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):
        padding = padding.upper()

        if stride is None:
            stride = kernel_size

        if isinstance(kernel_size, list):
            kernel = [1, kernel_size[0], kernel_size[1], 1] if data_format == 'NHWC' else \
                [1, 1, kernel_size[0], kernel_size[1]]
        else:
            kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
                else [1, 1, kernel_size, kernel_size]

        if isinstance(stride, list):
            strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \
                else [1, 1, stride[0], stride[1]]
        else:
            strides = [1, stride, stride, 1] if data_format == 'NHWC' \
                else [1, 1, stride, stride]

        return tf.nn.max_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def avgpooling(inputdata, kernel_size, stride=None, padding='VALID',
                   data_format='NHWC', name=None):
        if stride is None:
            stride = kernel_size

        kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \
            else [1, 1, kernel_size, kernel_size]

        strides = [1, stride, stride, 1] if data_format == 'NHWC' else [1, 1, stride, stride]

        return tf.nn.avg_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                              data_format=data_format, name=name)

    @staticmethod
    def globalavgpooling(inputdata, data_format='NHWC', name=None):
        assert inputdata.shape.ndims == 4
        assert data_format in ['NHWC', 'NCHW']

        axis = [1, 2] if data_format == 'NHWC' else [2, 3]

        return tf.reduce_mean(input_tensor=inputdata, axis=axis, name=name)

    @staticmethod
    def layernorm(inputdata, epsilon=1e-5, use_bias=True, use_scale=True,
                  data_format='NHWC', name=None):
        shape = inputdata.get_shape().as_list()
        ndims = len(shape)
        assert ndims in [2, 4]

        mean, var = tf.nn.moments(inputdata, list(range(1, len(shape))), keep_dims=True)

        if data_format == 'NCHW':
            channnel = shape[1]
            new_shape = [1, channnel, 1, 1]
        else:
            channnel = shape[-1]
            new_shape = [1, 1, 1, channnel]
        if ndims == 2:
            new_shape = [1, channnel]

        if use_bias:
            beta = tf.get_variable('beta', [channnel], initializer=tf.constant_initializer())
            beta = tf.reshape(beta, new_shape)
        else:
            beta = tf.zeros([1] * ndims, name='beta')
        if use_scale:
            gamma = tf.get_variable('gamma', [channnel], initializer=tf.constant_initializer(1.0))
            gamma = tf.reshape(gamma, new_shape)
        else:
            gamma = tf.ones([1] * ndims, name='gamma')

        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def instancenorm(inputdata, epsilon=1e-5, data_format='NHWC', use_affine=True, name=None):
        shape = inputdata.get_shape().as_list()
        if len(shape) != 4:
            raise ValueError("Input data of instancebn layer has to be 4D tensor")

        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
        if ch is None:
            raise ValueError("Input of instancebn require known channel!")

        mean, var = tf.nn.moments(inputdata, axis, keep_dims=True)

        if not use_affine:
            return tf.divide(inputdata - mean, tf.sqrt(var + epsilon), name='output')

        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
        gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
        return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)

    @staticmethod
    def dropout(inputdata, keep_prob, noise_shape=None, name=None):
        return tf.nn.dropout(inputdata, keep_prob=keep_prob, noise_shape=noise_shape, name=name)

    @staticmethod
    def fullyconnect(inputdata, out_dim, w_init=None, b_init=None,
                     use_bias=True, name=None):
        shape = inputdata.get_shape().as_list()[1:]
        if None not in shape:
            inputdata = tf.reshape(inputdata, [-1, int(np.prod(shape))])
        else:
            inputdata = tf.reshape(inputdata, tf.stack([tf.shape(inputdata)[0], -1]))

        if w_init is None:
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        ret = tf.layers.dense(inputs=inputdata, activation=lambda x: tf.identity(x, name='output'),
                              use_bias=use_bias, name=name,
                              kernel_initializer=w_init, bias_initializer=b_init,
                              trainable=True, units=out_dim)
        return ret

    @staticmethod
    def layerbn(inputdata, is_training, name):
        return tf.layers.batch_normalization(inputs=inputdata, training=is_training, name=name)

    @staticmethod
    def squeeze(inputdata, axis=None, name=None):
        return tf.squeeze(input=inputdata, axis=axis, name=name)

    @staticmethod
    def deconv2d(inputdata, out_channel, kernel_size, padding='SAME',
                 stride=1, w_init=None, b_init=None,
                 use_bias=True, activation=None, data_format='channels_last',
                 trainable=True, name=None):
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'channels_last' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            ret = tf.layers.conv2d_transpose(inputs=inputdata, filters=out_channel,
                                             kernel_size=kernel_size,
                                             strides=stride, padding=padding,
                                             data_format=data_format,
                                             activation=activation, use_bias=use_bias,
                                             kernel_initializer=w_init,
                                             bias_initializer=b_init, trainable=trainable,
                                             name=name)
        return ret

    @staticmethod
    def dilation_conv(input_tensor, k_size, out_dims, rate, padding='SAME',
                      w_init=None, b_init=None, use_bias=False, name=None):
        with tf.variable_scope(name):
            in_shape = input_tensor.get_shape().as_list()
            in_channel = in_shape[3]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"

            padding = padding.upper()

            if isinstance(k_size, list):
                filter_shape = [k_size[0], k_size[1]] + [in_channel, out_dims]
            else:
                filter_shape = [k_size, k_size] + [in_channel, out_dims]

            if w_init is None:
                w_init = tf.contrib.layers.variance_scaling_initializer()
            if b_init is None:
                b_init = tf.constant_initializer()

            w = tf.get_variable('W', filter_shape, initializer=w_init)
            b = None

            if use_bias:
                b = tf.get_variable('b', [out_dims], initializer=b_init)

            conv = tf.nn.atrous_conv2d(value=input_tensor, filters=w, rate=rate,
                                       padding=padding, name='dilation_conv')

            if use_bias:
                ret = tf.add(conv, b)
            else:
                ret = conv

        return ret

    @staticmethod
    def spatial_dropout(input_tensor, keep_prob, is_training, name, seed=1234):
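        # Spatial (channel-wise) dropout: when training, zero out entire
        # feature maps with probability 1 - keep_prob rather than single
        # activations. Note this variant does not rescale by 1/keep_prob.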
        tf.set_random_seed(seed=seed)

        def f1():
            with tf.variable_scope(name):
                return input_tensor

        def f2():
            with tf.variable_scope(name):
                num_feature_maps = [tf.shape(input_tensor)[0], tf.shape(input_tensor)[3]]

                random_tensor = keep_prob
                random_tensor += tf.random_uniform(num_feature_maps,
                                                   seed=seed,
                                                   dtype=input_tensor.dtype)

                binary_tensor = tf.floor(random_tensor)

                binary_tensor = tf.reshape(binary_tensor,
                                           [-1, 1, 1, tf.shape(input_tensor)[3]])
                ret = input_tensor * binary_tensor
                return ret

        output = tf.cond(is_training, f2, f1)
        return output

    @staticmethod
    def lrelu(inputdata, name, alpha=0.2):
        with tf.variable_scope(name):
            return tf.nn.relu(inputdata) - alpha * tf.nn.relu(-inputdata)
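A quick sanity check on the lrelu helper above: relu(x) - alpha*relu(-x) is algebraically identical to the usual leaky ReLU max(x, alpha*x). A minimal NumPy sketch (the helper names here are mine, for illustration only):

# Minimal NumPy sketch verifying the lrelu identity used above:
# relu(x) - alpha*relu(-x) == max(x, alpha*x) for alpha in (0, 1).
import numpy as np

def lrelu_as_coded(x, alpha=0.2):
    relu = lambda t: np.maximum(t, 0.0)
    return relu(x) - alpha * relu(-x)

def lrelu_reference(x, alpha=0.2):
    return np.maximum(x, alpha * x)

x = np.linspace(-3.0, 3.0, 13)
assert np.allclose(lrelu_as_coded(x), lrelu_reference(x))
print(lrelu_as_coded(np.array([-2.0, 0.0, 2.0])))  # [-0.4  0.   2. ]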

    --vgg_scnn_encoder.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from collections import OrderedDict

import tensorflow as tf
import glog as log
import math
import sys

sys.path.append('encoder_decoder_model')
import cnn_basenet


class VGG16Encoder(cnn_basenet.CNNBaseModel):
    def __init__(self, phase):
        super(VGG16Encoder, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):
        return tf.equal(self._phase, self._train_phase)

    def _conv_stage(self, input_tensor, k_size, out_dims, name, stride=1, pad='SAME'):
        with tf.variable_scope(name):
            conv = self.conv2d(inputdata=input_tensor, out_channel=out_dims,
                               kernel_size=k_size, stride=stride,
                               use_bias=False, padding=pad, name='conv')
            bn = self.layerbn(inputdata=conv, is_training=self._is_training, name='bn')
            relu = self.relu(inputdata=bn, name='relu')
            return relu

    def _fc_stage(self, input_tensor, out_dims, name, use_bias=False):
        with tf.variable_scope(name):
            fc = self.fullyconnect(inputdata=input_tensor, out_dim=out_dims, use_bias=use_bias, name='fc')
            bn = self.layerbn(inputdata=fc, is_training=self._is_training, name='bn')
            relu = self.relu(inputdata=bn, name='relu')
        return relu

    def scnn_u2d_d2u(self,input_tensor):
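        # SCNN vertical message passing: split the feature map into H row
        # slices, then sweep top-to-bottom and bottom-to-top, adding a shared
        # 1x9 convolution of the previously updated slice to each row.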
        output_list_old = []
        output_list_new = []
        shape_list = input_tensor.get_shape().as_list()
        log.info("scnn_u2d_d2u shape_list:{:}".format(shape_list))
        h_size = input_tensor.get_shape().as_list()[1]
        log.info("scnn_u2d_d2u h_size:{:}".format(h_size))
        channel_size = input_tensor.get_shape().as_list()[3]
        #up2down conv
        for i in range(h_size):
            output_list_old.append(tf.expand_dims(input_tensor[:,i,:,:],axis=1))
        output_list_new.append(tf.expand_dims(input_tensor[:,0,:,:],axis=1))
        
        w_ud = tf.get_variable('w_ud',[1,9,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*2))))
        with tf.variable_scope("scnn_u2d"):
            scnn_u2d = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[0],w_ud,[1,1,1,1],'SAME')),output_list_old[1])
            output_list_new.append(scnn_u2d)
        
        for i in range(2,h_size):
            with tf.variable_scope("scnn_u2d",reuse=True):
                scnn_u2d = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_ud,[1,1,1,1],'SAME')),output_list_old[i])
                output_list_new.append(scnn_u2d)
        
        #down2up conv
        output_list_old = output_list_new
        output_list_new = []
        length = h_size-1
        output_list_new.append(output_list_old[length])
        w_du = tf.get_variable('w_du',[1,9,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*2))))
        with tf.variable_scope('scnn_d2u'):
            scnn_d2u = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[length],w_du,[1,1,1,1],'SAME')),output_list_old[length-1])
            output_list_new.append(scnn_d2u)
            
        for i in range(2,h_size):
            with tf.variable_scope("scnn_d2u",reuse=True):
                scnn_d2u = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_du,[1,1,1,1],'SAME')),output_list_old[length-i])
                output_list_new.append(scnn_d2u)
                
        output_list_new.reverse()
        #log.info("scnn_u2d_d2u output_list_new:{:}".format(output_list_new))
        out_tensor = tf.stack(output_list_new,axis = 1)
        out_tensor = tf.squeeze(out_tensor,axis=2)
        return out_tensor
    
    def scnn_l2r_r2l(self,input_tensor):
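        # SCNN horizontal message passing: the same recurrence as above, but
        # over the W column slices with a shared 9x1 kernel, sweeping
        # left-to-right and then right-to-left.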
        output_list_old = []
        output_list_new = []
        shape_list = input_tensor.get_shape().as_list()
        log.info("scnn_l2r_r2l shape_list:{:}".format(shape_list))
        w_size = input_tensor.get_shape().as_list()[2]
        log.info("scnn_l2r_r2l w_size:{:}".format(w_size))
        channel_size = input_tensor.get_shape().as_list()[3]
        
        #left2right conv
        for i in range(w_size):
            output_list_old.append(tf.expand_dims(input_tensor[:,:,i,:],axis=2))
        output_list_new.append(tf.expand_dims(input_tensor[:,:,0,:],axis=2))
        
        w_lr = tf.get_variable('w_lr',[9,1,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*5))))
        with tf.variable_scope("scnn_l2r"):
            scnn_l2r = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[0],w_lr,[1,1,1,1],'SAME')),output_list_old[1])
            output_list_new.append(scnn_l2r)
        
        for i in range(2,w_size):
            with tf.variable_scope("scnn_l2r",reuse=True):
                scnn_l2r = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_lr,[1,1,1,1],'SAME')),output_list_old[i])
                output_list_new.append(scnn_l2r)
        #log.info("output_list_new:{:}".format(output_list_new))
        
        #right2left conv
        output_list_old = output_list_new
        output_list_new = []
        length = w_size-1
        output_list_new.append(output_list_old[length])
        w_rl = tf.get_variable('w_rl',[9,1,channel_size,channel_size],initializer=tf.random_normal_initializer(0,math.sqrt(2.0/(9*channel_size*channel_size*5))))
        with tf.variable_scope('scnn_r2l'):
            scnn_r2l = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_old[length],w_rl,[1,1,1,1],'SAME')),output_list_old[length-1])
            output_list_new.append(scnn_r2l)
            
        for i in range(2,w_size):
            with tf.variable_scope("scnn_r2l",reuse=True):
                scnn_r2l = tf.add(tf.nn.relu(tf.nn.conv2d(output_list_new[i-1],w_rl,[1,1,1,1],'SAME')),output_list_old[length-i])
                output_list_new.append(scnn_r2l)
                
        output_list_new.reverse()
        out_tensor = tf.stack(output_list_new,axis = 2)
        out_tensor = tf.squeeze(out_tensor,axis=3)
        return out_tensor
     
    
    def encode(self, input_tensor, name):
        ret = OrderedDict()

        with tf.variable_scope(name):
            # conv stage 1_1
            conv_1_1 = self._conv_stage(input_tensor=input_tensor, k_size=3, out_dims=64, name='conv1_1')
            log.info("encode conv_1_1:{:}".format(conv_1_1.get_shape().as_list()))
            
            # conv stage 1_2
            conv_1_2 = self._conv_stage(input_tensor=conv_1_1, k_size=3, out_dims=64, name='conv1_2')
            log.info("encode conv_1_2:{:}".format(conv_1_2.get_shape().as_list()))
            
            # pool stage 1
            pool1 = self.maxpooling(inputdata=conv_1_2, kernel_size=2, stride=2, name='pool1')
            log.info("encode pool1:{:}".format(pool1.get_shape().as_list()))
            
            # conv stage 2_1
            conv_2_1 = self._conv_stage(input_tensor=pool1, k_size=3,  out_dims=128, name='conv2_1')
            log.info("encode conv_2_1:{:}".format(conv_2_1.get_shape().as_list()))

            # conv stage 2_2
            conv_2_2 = self._conv_stage(input_tensor=conv_2_1, k_size=3, out_dims=128, name='conv2_2')
            log.info("encode conv_2_2:{:}".format(conv_2_2.get_shape().as_list()))
            
            # pool stage 2
            pool2 = self.maxpooling(inputdata=conv_2_2, kernel_size=2, stride=2, name='pool2')
            log.info("encode pool2:{:}".format(pool2.get_shape().as_list()))                        

            # conv stage 3_1
            conv_3_1 = self._conv_stage(input_tensor=pool2, k_size=3, out_dims=256, name='conv3_1')
            log.info("encode conv_3_1:{:}".format(conv_3_1.get_shape().as_list()))                            

            # conv_stage 3_2
            conv_3_2 = self._conv_stage(input_tensor=conv_3_1, k_size=3, out_dims=256, name='conv3_2')
            log.info("encode conv_3_2:{:}".format(conv_3_2.get_shape().as_list()))                              

            # conv stage 3_3
            conv_3_3 = self._conv_stage(input_tensor=conv_3_2, k_size=3, out_dims=256, name='conv3_3')
            log.info("encode conv_3_3:{:}".format(conv_3_3.get_shape().as_list()))                              

            ret['conv_3_3'] = dict()
            ret['conv_3_3']['data'] = conv_3_3
            ret['conv_3_3']['shape'] = conv_3_3.get_shape().as_list()

            # pool stage 3
            pool3 = self.maxpooling(inputdata=conv_3_3, kernel_size=2, stride=2, name='pool3')
            log.info("encode pool3:{:}".format(pool3.get_shape().as_list()))
                                   
            ret['pool3'] = dict()
            ret['pool3']['data'] = pool3
            ret['pool3']['shape'] = pool3.get_shape().as_list()

            # conv stage 4_1
            conv_4_1 = self._conv_stage(input_tensor=pool3, k_size=3, out_dims=512, name='conv4_1')
            log.info("encode conv_4_1:{:}".format(conv_4_1.get_shape().as_list()))                            

            # conv stage 4_2
            conv_4_2 = self._conv_stage(input_tensor=conv_4_1, k_size=3, out_dims=512, name='conv4_2')
            log.info("encode conv_4_2:{:}".format(conv_4_2.get_shape().as_list()))

            # conv stage 4_3
            conv_4_3 = self._conv_stage(input_tensor=conv_4_2, k_size=3,  out_dims=512, name='conv4_3')
            log.info("encode conv_4_3:{:}".format(conv_4_3.get_shape().as_list()))                               

            # pool stage 4
            pool4 = self.maxpooling(inputdata=conv_4_3, kernel_size=2, stride=2, name='pool4')
            log.info("encode pool4:{:}".format(pool4.get_shape().as_list()))
                                    
            ret['pool4'] = dict()
            ret['pool4']['data'] = pool4
            ret['pool4']['shape'] = pool4.get_shape().as_list()

            # conv stage 5_1
            conv_5_1 = self._conv_stage(input_tensor=pool4, k_size=3,
                                        out_dims=512, name='conv5_1')
            log.info("encode conv_5_1:{:}".format(conv_5_1.get_shape().as_list()))                            

            # conv stage 5_2
            conv_5_2 = self._conv_stage(input_tensor=conv_5_1, k_size=3,
                                        out_dims=512, name='conv5_2')
            log.info("encode conv_5_2:{:}".format(conv_5_2.get_shape().as_list()))

            # conv stage 5_3
            conv_5_3 = self._conv_stage(input_tensor=conv_5_2, k_size=3,
                                        out_dims=512, name='conv5_3')
            log.info("encode conv_5_3:{:}".format(conv_5_3.get_shape().as_list()))
            
            # conv stage 6_1
            conv_6_1 = self._conv_stage(input_tensor=conv_5_3, k_size=3,
                          out_dims=128, name='conv6_1')
            log.info("encode conv_6_1:{:}".format(conv_6_1.get_shape().as_list()))

            scnn_ud = self.scnn_u2d_d2u(conv_6_1)
            log.info("encode scnn_ud:{:}".format(scnn_ud.get_shape().as_list()))
            
            scnn_lr = self.scnn_l2r_r2l(scnn_ud)
            log.info("encode scnn_lr:{:}".format(scnn_lr.get_shape().as_list()))
            
            # pool stage 5
            pool5 = self.maxpooling(inputdata=scnn_lr, kernel_size=2,
                                    stride=2, name='pool5')
            log.info("encode pool5:{:}".format(pool5.get_shape().as_list()))

            ret['pool5'] = dict()
            ret['pool5']['data'] = pool5
            ret['pool5']['shape'] = pool5.get_shape().as_list()

            # fc stage 1
            # fc6 = self._fc_stage(input_tensor=pool5, out_dims=4096, name='fc6',
            #                      use_bias=False, flags=flags)

            # fc stage 2
            # fc7 = self._fc_stage(input_tensor=fc6, out_dims=4096, name='fc7',
            #                      use_bias=False, flags=flags)

        return ret

if __name__ == '__main__':
    a = tf.placeholder(dtype=tf.float32, shape=[1, 2048, 2048, 3], name='input')
    encoder = VGG16Encoder(phase=tf.constant('train', dtype=tf.string))
    ret = encoder.encode(a, name='encode')
    for layer_name, layer_info in ret.items():
        print('layer name: {:s} shape: {}'.format(layer_name, layer_info['shape']))
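The two scnn_* methods above implement SCNN-style spatial message passing: the feature map is cut into slices, and each slice is updated by adding a shared-kernel convolution of the previously updated neighbour, sweeping in both directions. A toy 1-D NumPy sketch of the same recurrence (a scalar weight stands in for the shared 1x9 kernel; values are made up):

# Toy sketch of SCNN slice-wise message passing on a 1-D "feature map".
# Each slice receives relu(w * previous_updated_slice) before being used,
# mirroring scnn_u2d_d2u above; w stands in for the shared 1x9 kernel.
import numpy as np

def scnn_pass(slices, w=0.5):
    out = [slices[0]]                      # first slice is passed through
    for i in range(1, len(slices)):
        out.append(slices[i] + np.maximum(w * out[i - 1], 0.0))
    return out

rows = [np.float32(v) for v in [1.0, 0.0, 0.0, 2.0]]
down = scnn_pass(rows)                     # top-to-bottom sweep
up = scnn_pass(down[::-1])[::-1]           # bottom-to-top sweep on the result
print(down)  # information from row 0 propagates downward
print(up)    # and back upward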

    --dense_encoder.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import tensorflow as tf
from collections import OrderedDict

#from encoder_decoder_model import cnn_basenet
import cnn_basenet

class DenseEncoder(cnn_basenet.CNNBaseModel):
    """
    DenseNet-based encoder
    """
    def __init__(self, l, n, growthrate, phase, with_bc=False, bc_theta=0.5):
        super(DenseEncoder, self).__init__()
        self._L = l
        self._block_depth = int((l - n - 1) / n)
        self._N = n
        self._growthrate = growthrate
        self._with_bc = with_bc
        self._phase = phase
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._is_training = self._init_phase()
        self._bc_theta = bc_theta
        return

    def _init_phase(self):
        return tf.equal(self._phase, self._train_phase)

    def __str__(self):
        encoder_info = 'A densenet with net depth: {:d} block nums: ' \
                       '{:d} growth rate: {:d} block depth: {:d}'.\
            format(self._L, self._N, self._growthrate, self._block_depth)
        return encoder_info

    def _composite_conv(self, inputdata, out_channel, name):
        with tf.variable_scope(name):
            bn_1 = self.layerbn(inputdata=inputdata, is_training=self._is_training, name='bn_1')
            relu_1 = self.relu(bn_1, name='relu_1')
            if self._with_bc:
                conv_1 = self.conv2d(inputdata=relu_1, out_channel=out_channel,
                                     kernel_size=1,
                                     padding='SAME', stride=1, use_bias=False,
                                     name='conv_1')

                bn_2 = self.layerbn(inputdata=conv_1, is_training=self._is_training, name='bn_2')

                relu_2 = self.relu(inputdata=bn_2, name='relu_2')
                conv_2 = self.conv2d(inputdata=relu_2, out_channel=out_channel,
                                     kernel_size=3,
                                     stride=1, padding='SAME', use_bias=False,
                                     name='conv_2')
                return conv_2
            else:
                conv_2 = self.conv2d(inputdata=relu_1, out_channel=out_channel,
                                     kernel_size=3,
                                     stride=1, padding='SAME', use_bias=False,
                                     name='conv_2')
                return conv_2

    def _denseconnect_layers(self, inputdata, name):
        with tf.variable_scope(name):
            conv_out = self._composite_conv(inputdata=inputdata, name='composite_conv',  out_channel=self._growthrate)
            concate_cout = tf.concat(values=[conv_out, inputdata], axis=3, name='concatenate')

        return concate_cout

    def _transition_layers(self, inputdata, name):
        """
        Mainly implement the Pooling layer mentioned in DenseNet paper
        :param inputdata:
        :param name:
        :return:
        """
        input_channels = inputdata.get_shape().as_list()[3]

        with tf.variable_scope(name):
            # First batch norm
            bn = self.layerbn(inputdata=inputdata, is_training=self._is_training, name='bn')

            # Second 1*1 conv
            if self._with_bc:
                out_channels = int(input_channels * self._bc_theta)
                conv = self.conv2d(inputdata=bn, out_channel=out_channels,
                                   kernel_size=1, stride=1, use_bias=False,
                                   name='conv')
                # Third average pooling
                avgpool_out = self.avgpooling(inputdata=conv, kernel_size=2,
                                              stride=2, name='avgpool')
                return avgpool_out
            else:
                conv = self.conv2d(inputdata=bn, out_channel=input_channels,
                                   kernel_size=1, stride=1, use_bias=False,
                                   name='conv')
                # Third average pooling
                avgpool_out = self.avgpooling(inputdata=conv, kernel_size=2,
                                              stride=2, name='avgpool')
                return avgpool_out

    def _dense_block(self, inputdata, name):
        """
        Mainly implement the dense block mentioned in DenseNet figure 1
        :param inputdata:
        :param name:
        :return:
        """
        block_input = inputdata
        with tf.variable_scope(name):
            for i in range(self._block_depth):
                block_layer_name = '{:s}_layer_{:d}'.format(name, i + 1)
                block_input = self._denseconnect_layers(inputdata=block_input,
                                                        name=block_layer_name)
        return block_input

    def encode(self, input_tensor, name):
        """
        DenseNet encoding
        :param input_tensor:
        :param name:
        :return:
        """
        encode_ret = OrderedDict()

        # First apply a 3*3 16 out channels conv layer
        # mentioned in DenseNet paper Implementation Details part
        with tf.variable_scope(name):
            conv1 = self.conv2d(inputdata=input_tensor, out_channel=16,
                                kernel_size=3, use_bias=False, name='conv1')
            dense_block_input = conv1

            # Second apply dense block stage
            for dense_block_nums in range(self._N):
                dense_block_name = 'Dense_Block_{:d}'.format(dense_block_nums + 1)

                # dense connectivity
                dense_block_out = self._dense_block(inputdata=dense_block_input,
                                                    name=dense_block_name)
                # apply the transition part
                dense_block_out = self._transition_layers(inputdata=dense_block_out,
                                                          name=dense_block_name)
                dense_block_input = dense_block_out
                encode_ret[dense_block_name] = dict()
                encode_ret[dense_block_name]['data'] = dense_block_out
                encode_ret[dense_block_name]['shape'] = dense_block_out.get_shape().as_list()

        return encode_ret


if __name__ == '__main__':
    input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 384, 1248, 3], name='input_tensor')
    encoder = DenseEncoder(l=100, growthrate=16, with_bc=True, phase=tf.constant('train'), n=5)
    ret = encoder.encode(input_tensor=input_tensor, name='Dense_Encode')
    for layer_name, layer_info in ret.items():
        print('layer_name: {:s} shape: {}'.format(layer_name, layer_info['shape']))
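One way to sanity-check the DenseEncoder is to track its channel counts: each dense layer concatenates growthrate new channels, so a block of depth d turns c channels into c + d*growthrate, and a BC transition then compresses to floor(theta*c). A small sketch of that bookkeeping (the helper name is mine), mirroring the l=100, growthrate=16, n=5 configuration in the __main__ above:

# Sketch of DenseNet channel bookkeeping matching DenseEncoder's parameters.
# block_depth follows the constructor: int((l - n - 1) / n).
def dense_channels(l=100, n=5, growthrate=16, theta=0.5, c0=16):
    block_depth = int((l - n - 1) / n)
    c = c0                                   # channels after the first 3x3 conv
    for block in range(1, n + 1):
        c = c + block_depth * growthrate     # dense block: concat k channels/layer
        c = int(c * theta)                   # BC transition layer compression
        print('Dense_Block_{:d}: {:d} channels'.format(block, c))

dense_channels()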

    --fcn_decoder.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf

#from encoder_decoder_model import cnn_basenet
#from encoder_decoder_model import vgg_encoder
#from encoder_decoder_model import dense_encoder
import cnn_basenet
import vgg_encoder
import dense_encoder

class FCNDecoder(cnn_basenet.CNNBaseModel):

    def __init__(self, phase):
        """

        """
        super(FCNDecoder, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()

    def _init_phase(self):
        """

        :return:
        """
        return tf.equal(self._phase, self._train_phase)

    def decode(self, input_tensor_dict, decode_layer_list, name):
        """
        Decode the encoded features and restore resolution via deconvolution
        :param input_tensor_dict:
        :param decode_layer_list: names of the layers to decode, ordered from
                                  deep to shallow, e.g. ['pool5', 'pool4', 'pool3']
        :param name:
        :return:
        """
        ret = dict()

        with tf.variable_scope(name):
            # score stage 1
            input_tensor = input_tensor_dict[decode_layer_list[0]]['data']

            score = self.conv2d(inputdata=input_tensor, out_channel=64,
                                kernel_size=1, use_bias=False, name='score_origin')
            ret['score'] = dict()                    
            ret['score']['data'] = score
            ret['score']['shape'] = score.get_shape().as_list()      
                          
            decode_layer_list = decode_layer_list[1:]
            print("len(decode_layer_list):",len(decode_layer_list))
            for i in range(len(decode_layer_list)):
                deconv = self.deconv2d(inputdata=score, out_channel=64, kernel_size=4,
                                       stride=2, use_bias=False, name='deconv_{:d}'.format(i + 1))
                input_tensor = input_tensor_dict[decode_layer_list[i]]['data']
                score = self.conv2d(inputdata=input_tensor, out_channel=64,
                                    kernel_size=1, use_bias=False, name='score_{:d}'.format(i + 1))
                fused = tf.add(deconv, score, name='fuse_{:d}'.format(i + 1))
                score = fused
                ret['fuse_{:d}'.format(i + 1)] = dict()
                ret['fuse_{:d}'.format(i + 1)]['data'] = fused
                ret['fuse_{:d}'.format(i + 1)]['shape'] = fused.get_shape().as_list()
               
            deconv_final = self.deconv2d(inputdata=score, out_channel=64, kernel_size=16,
                                         stride=8, use_bias=False, name='deconv_final')

            score_final = self.conv2d(inputdata=deconv_final, out_channel=2,
                                      kernel_size=1, use_bias=False, name='score_final')
              
            ret['logits'] = dict()
            ret['logits']['data'] = score_final
            ret['logits']['shape'] = score_final.get_shape().as_list()

            ret['deconv'] = dict()
            ret['deconv']['data'] = deconv_final
            ret['deconv']['shape'] = deconv_final.get_shape().as_list()
        return ret


if __name__ == '__main__':

    vgg_encoder = vgg_encoder.VGG16Encoder(phase=tf.constant('train', tf.string))
    dense_encoder = dense_encoder.DenseEncoder(l=40, growthrate=12,
                                               with_bc=True, phase='train', n=5)
    decoder = FCNDecoder(phase='train')

    in_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 256, 512, 3],
                               name='input')

    vgg_encode_ret = vgg_encoder.encode(in_tensor, name='vgg_encoder')
    dense_encode_ret = dense_encoder.encode(in_tensor, name='dense_encoder')
    decode_ret = decoder.decode(vgg_encode_ret, name='decoder',
                                decode_layer_list=['pool5',
                                                   'pool4',
                                                   'pool3'])
                                                   
    for layer_name, layer_info in decode_ret.items():
        print('layer name: {:s} shape: {}'.format(layer_name, layer_info['shape']))                                               
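The decoder follows the familiar FCN skip-fusion pattern: score the deepest map with a 1x1 conv, repeatedly deconv by stride 2 and add the 1x1-scored shallower map, then deconv by stride 8 back to input resolution. A quick sketch of the spatial shape arithmetic for a 256x512 input (assuming the five stride-2 poolings of the VGG encoder above):

# Sketch of the FCNDecoder shape arithmetic for a 256x512 input.
# pool3/pool4/pool5 come from 3/4/5 stride-2 max-poolings in the encoder.
h, w = 256, 512
pools = {'pool3': (h // 8, w // 8),
         'pool4': (h // 16, w // 16),
         'pool5': (h // 32, w // 32)}

cur = pools['pool5']                       # score stage starts at pool5
for layer in ['pool4', 'pool3']:
    cur = (cur[0] * 2, cur[1] * 2)         # deconv2d with stride 2
    assert cur == pools[layer]             # shapes must match to tf.add
    print('fuse with {:s}: {}'.format(layer, cur))

final = (cur[0] * 8, cur[1] * 8)           # deconv_final with stride 8
print('deconv_final: {}'.format(final))    # back to (256, 512)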

     ./lanenet_model

     --lanenet_merge_model.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import tensorflow as tf

from encoder_decoder_model import vgg_encoder
from encoder_decoder_model import fcn_decoder
from encoder_decoder_model import dense_encoder
from encoder_decoder_model import cnn_basenet
from lanenet_model import lanenet_discriminative_loss
from encoder_decoder_model import vgg_scnn_encoder
import glog

class LaneNet(cnn_basenet.CNNBaseModel):
    """
    Semantic segmentation model
    """
    def __init__(self, phase, net_flag='vgg'):
        """

        """
        super(LaneNet, self).__init__()
        self._net_flag = net_flag
        self._phase = phase
        if self._net_flag == 'vgg':
            self._encoder = vgg_encoder.VGG16Encoder(phase=phase)
        elif self._net_flag == 'vgg_scnn':
            self._encoder = vgg_scnn_encoder.VGG16Encoder(phase=phase)
        elif self._net_flag == 'dense':
            self._encoder = dense_encoder.DenseEncoder(l=20, growthrate=8,
                                                       with_bc=True,
                                                       phase=phase,
                                                       n=5)
        self._decoder = fcn_decoder.FCNDecoder(phase=phase)
        return

    def __str__(self):
        """

        :return:
        """
        info = 'Semantic segmentation uses {:s} as the base network to encode'.format(self._net_flag)
        return info

    def _build_model(self, input_tensor, name):
        """
        Forward pass
        :param input_tensor:
        :param name:
        :return:
        """
        with tf.variable_scope(name):
            # first encode
            encode_ret = self._encoder.encode(input_tensor=input_tensor,
                                              name='encode')

            # second decode
            if self._net_flag.lower() == 'vgg':
                decode_ret = self._decoder.decode(input_tensor_dict=encode_ret,
                                                  name='decode',
                                                  decode_layer_list=['pool5',
                                                                     'pool4',
                                                                     'pool3'])
                return decode_ret
            if self._net_flag.lower() == 'vgg_scnn':
                decode_ret = self._decoder.decode(input_tensor_dict=encode_ret,
                                                  name='decode',
                                                  decode_layer_list=['pool5',
                                                                     'pool4',
                                                                     'pool3'])
                return decode_ret
            elif self._net_flag.lower() == 'dense':
                decode_ret = self._decoder.decode(input_tensor_dict=encode_ret,
                                                  name='decode',
                                                  decode_layer_list=['Dense_Block_5',
                                                                     'Dense_Block_4',
                                                                     'Dense_Block_3'])
                return decode_ret

    def compute_loss(self, input_tensor, binary_label, instance_label, name):
        """
        Compute the LaneNet loss
        :param input_tensor:
        :param binary_label:
        :param instance_label:
        :param name:
        :return:
        """
        with tf.variable_scope(name):
            # forward pass to obtain the logits
            inference_ret = self._build_model(input_tensor=input_tensor, name='inference')
            glog.info('compute_loss inference_ret:{:}'.format(inference_ret))
            # binary segmentation loss (the decoder returns a dict per output,
            # so take the 'data' entry)
            decode_logits = inference_ret['logits']['data']
            binary_label_plain = tf.reshape(
                binary_label,
                shape=[binary_label.get_shape().as_list()[0] *
                       binary_label.get_shape().as_list()[1] *
                       binary_label.get_shape().as_list()[2]])
            glog.info('compute_loss binary_label_plain:{:}'.format(binary_label_plain))            
            # apply class weights
            unique_labels, unique_id, counts = tf.unique_with_counts(binary_label_plain)
            counts = tf.cast(counts, tf.float32)
            glog.info('compute_loss counts:{:}'.format(counts)) 
            inverse_weights = tf.divide(1.0,
                                        tf.log(tf.add(tf.divide(tf.constant(1.0), counts),
                                                      tf.constant(1.02))))
            glog.info('compute_loss inverse_weights:{:}'.format(inverse_weights))                                           
            inverse_weights = tf.gather(inverse_weights, binary_label)
            glog.info('compute_loss gather inverse_weights:{:}'.format(inverse_weights))      
            binary_segmenatation_loss = tf.losses.sparse_softmax_cross_entropy(
                labels=binary_label, logits=decode_logits, weights=inverse_weights)
            glog.info('compute_loss binary_segmenatation_loss:{:}'.format(binary_segmenatation_loss))    
            binary_segmenatation_loss = tf.reduce_mean(binary_segmenatation_loss)
            glog.info('compute_loss reduce_mean binary_segmenatation_loss:{:}'.format(binary_segmenatation_loss))    
            # discriminative loss
            decode_deconv = inference_ret['deconv']['data']
            # pixel embedding
            pix_embedding = self.conv2d(inputdata=decode_deconv, out_channel=4, kernel_size=1,
                                        use_bias=False, name='pix_embedding_conv')
            pix_embedding = self.relu(inputdata=pix_embedding, name='pix_embedding_relu')
            # compute the discriminative loss
            image_shape = (pix_embedding.get_shape().as_list()[1], pix_embedding.get_shape().as_list()[2])
            glog.info('compute_loss image_shape:{:}'.format(image_shape)) 
            disc_loss, l_var, l_dist, l_reg = \
                lanenet_discriminative_loss.discriminative_loss(
                    pix_embedding, instance_label, 4, image_shape, 0.5, 3.0, 1.0, 1.0, 0.001)
            glog.info('compute_loss disc_loss:{:}'.format(disc_loss))
            # combine the losses
            l2_reg_loss = tf.constant(0.0, tf.float32)
            for vv in tf.trainable_variables():
                if 'bn' in vv.name:
                    continue
                else:
                    l2_reg_loss = tf.add(l2_reg_loss, tf.nn.l2_loss(vv))
            l2_reg_loss *= 0.001
            total_loss = 0.5 * binary_segmenatation_loss + 0.5 * disc_loss + l2_reg_loss

            ret = {
                'total_loss': total_loss,
                'binary_seg_logits': decode_logits,
                'instance_seg_logits': pix_embedding,
                'binary_seg_loss': binary_segmenatation_loss,
                'discriminative_loss': disc_loss
            }

            return ret

    def inference(self, input_tensor, name):
        """

        :param input_tensor:
        :param name:
        :return:
        """
        with tf.variable_scope(name):
            # forward pass to obtain the logits
            inference_ret = self._build_model(input_tensor=input_tensor, name='inference')
            # binary segmentation branch
            decode_logits = inference_ret['logits']['data']
            binary_seg_ret = tf.nn.softmax(logits=decode_logits)
            binary_seg_ret = tf.argmax(binary_seg_ret, axis=-1)
            # pixel embedding branch
            decode_deconv = inference_ret['deconv']['data']
            pix_embedding = self.conv2d(inputdata=decode_deconv, out_channel=4, kernel_size=1,
                                        use_bias=False, name='pix_embedding_conv')
            pix_embedding = self.relu(inputdata=pix_embedding, name='pix_embedding_relu')

            return binary_seg_ret, pix_embedding


if __name__ == '__main__':
    model = LaneNet(tf.constant('train', dtype=tf.string))
    input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input')
    binary_label = tf.placeholder(dtype=tf.int64, shape=[1, 256, 512, 1], name='label')
    instance_label = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 1], name='label')
    ret = model.compute_loss(input_tensor=input_tensor, binary_label=binary_label,
                             instance_label=instance_label, name='loss')
    for vv in tf.trainable_variables():
        if 'bn' in vv.name:
            continue
        print(vv.name)
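The inverse_weights computation in compute_loss is a bounded inverse-frequency class weighting in the spirit of ENet's w_c = 1 / ln(1.02 + p_c), where p_c is the class frequency. Note that, as written, the code feeds 1/count rather than count/total into the log; whether that was intentional I cannot say, so the sketch below (with made-up pixel counts) simply contrasts the two forms:

# NumPy sketch of the class weighting in compute_loss. The counts below
# are made up for illustration.
import numpy as np

counts = np.array([120000.0, 11072.0])        # background vs. lane pixels
freq = counts / counts.sum()

as_coded = 1.0 / np.log(1.0 / counts + 1.02)  # formula as written above
enet_style = 1.0 / np.log(freq + 1.02)        # ENet-style, using frequencies
print(as_coded)    # ~[50.5, 50.3]: almost uniform for large raw counts
print(enet_style)  # ~[1.5, 10.1]: rare lane class is strongly up-weighted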

     --lanenet_discriminative_loss.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf
import glog

def discriminative_loss_single(
        prediction,
        correct_label,
        feature_dim,
        label_shape,
        delta_v,
        delta_d,
        param_var,
        param_dist,
        param_reg):
    """
    Instance segmentation loss from Eq. (1) of the paper
    :param prediction: inference of network
    :param correct_label: instance label
    :param feature_dim: feature dimension of prediction
    :param label_shape: shape of label
    :param delta_v: cut off variance distance
    :param delta_d: cut off cluster distance
    :param param_var: weight for intra cluster variance
    :param param_dist: weight for inter cluster distances
    :param param_reg: weight regularization
    """

    # flatten the pixels into a single row
    correct_label = tf.reshape(
        correct_label, [
            label_shape[1] * label_shape[0]])
    reshaped_pred = tf.reshape(
        prediction, [
            label_shape[1] * label_shape[0], feature_dim])

    # count the number of instances
    unique_labels, unique_id, counts = tf.unique_with_counts(correct_label)
    counts = tf.cast(counts, tf.float32)
    num_instances = tf.size(unique_labels)
    glog.info('discriminative_loss_single counts:{:} num_instances:{:}'.format(counts,num_instances))
    # compute the mean embedding vector of each instance
    segmented_sum = tf.unsorted_segment_sum(
        reshaped_pred, unique_id, num_instances)  
    mu = tf.div(segmented_sum, tf.reshape(counts, (-1, 1)))
    mu_expand = tf.gather(mu, unique_id)

    # variance term l_var of the loss
    distance = tf.norm(tf.subtract(mu_expand, reshaped_pred), axis=1)
    distance = tf.subtract(distance, delta_v)
    distance = tf.clip_by_value(distance, 0., distance)
    distance = tf.square(distance)

    l_var = tf.unsorted_segment_sum(distance, unique_id, num_instances)
    l_var = tf.div(l_var, counts)
    l_var = tf.reduce_sum(l_var)
    l_var = tf.divide(l_var, tf.cast(num_instances, tf.float32))

    # distance term l_dist of the loss
    mu_interleaved_rep = tf.tile(mu, [num_instances, 1])
    mu_band_rep = tf.tile(mu, [1, num_instances])
    mu_band_rep = tf.reshape(
        mu_band_rep,
        (num_instances *
         num_instances,
         feature_dim))

    mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)

    # remove the zero rows (differences of each mean with itself)
    intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff), axis=1)
    zero_vector = tf.zeros(1, dtype=tf.float32)
    bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
    mu_diff_bool = tf.boolean_mask(mu_diff, bool_mask)

    mu_norm = tf.norm(mu_diff_bool, axis=1)
    mu_norm = tf.subtract(2. * delta_d, mu_norm)
    mu_norm = tf.clip_by_value(mu_norm, 0., mu_norm)
    mu_norm = tf.square(mu_norm)

    l_dist = tf.reduce_mean(mu_norm)

    # regularization term from the original Discriminative Loss paper
    l_reg = tf.reduce_mean(tf.norm(mu, axis=1))

    # combine the terms with the weights given in the original paper
    param_scale = 1.
    l_var = param_var * l_var
    l_dist = param_dist * l_dist
    l_reg = param_reg * l_reg

    loss = param_scale * (l_var + l_dist + l_reg)

    return loss, l_var, l_dist, l_reg


def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
                        delta_v, delta_d, param_var, param_dist, param_reg):
    """
    Iterate over the batch and compute the loss per sample, as in the paper
    :return: discriminative loss and its three components
    """

    def cond(label, batch, out_loss, out_var, out_dist, out_reg, i):
        return tf.less(i, tf.shape(batch)[0])

    def body(label, batch, out_loss, out_var, out_dist, out_reg, i):
        disc_loss, l_var, l_dist, l_reg = discriminative_loss_single(
            prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist, param_reg)

        out_loss = out_loss.write(i, disc_loss)
        out_var = out_var.write(i, l_var)
        out_dist = out_dist.write(i, l_dist)
        out_reg = out_reg.write(i, l_reg)

        return label, batch, out_loss, out_var, out_dist, out_reg, i + 1

    # TensorArray is a data structure that supports dynamic writing
    output_ta_loss = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True)
    output_ta_var = tf.TensorArray(dtype=tf.float32,
                                   size=0,
                                   dynamic_size=True)
    output_ta_dist = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True)
    output_ta_reg = tf.TensorArray(dtype=tf.float32,
                                   size=0,
                                   dynamic_size=True)

    _, _, out_loss_op, out_var_op, out_dist_op, out_reg_op, _ = tf.while_loop(
        cond, body, [
            correct_label, prediction, output_ta_loss, output_ta_var, output_ta_dist, output_ta_reg, 0])
    out_loss_op = out_loss_op.stack()
    out_var_op = out_var_op.stack()
    out_dist_op = out_dist_op.stack()
    out_reg_op = out_reg_op.stack()

    disc_loss = tf.reduce_mean(out_loss_op)
    l_var = tf.reduce_mean(out_var_op)
    l_dist = tf.reduce_mean(out_dist_op)
    l_reg = tf.reduce_mean(out_reg_op)

    return disc_loss, l_var, l_dist, l_reg
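To make the three loss terms concrete, here is a NumPy re-computation of discriminative_loss_single on a toy 2-D embedding with two instances (values made up; the hinges mirror the clip_by_value calls above, and delta_v/delta_d match the call from the merge model):

# NumPy sketch of the discriminative loss terms for a toy 2-D embedding.
import numpy as np

emb = np.array([[0.0, 0.0], [0.2, 0.0],      # pixels of instance 0
                [3.0, 0.0], [3.2, 0.0]])     # pixels of instance 1
labels = np.array([0, 0, 1, 1])
delta_v, delta_d = 0.5, 3.0

mus = np.stack([emb[labels == k].mean(axis=0) for k in (0, 1)])
# l_var: pull pixels to within delta_v of their instance mean
d = np.linalg.norm(emb - mus[labels], axis=1)
l_var = np.mean(np.maximum(d - delta_v, 0.0) ** 2)
# l_dist: push instance means at least 2*delta_d apart
mu_dist = np.linalg.norm(mus[0] - mus[1])
l_dist = np.maximum(2.0 * delta_d - mu_dist, 0.0) ** 2
# l_reg: keep embeddings near the origin
l_reg = np.mean(np.linalg.norm(mus, axis=1))
print(l_var, l_dist, l_reg)   # 0.0, 9.0, 1.6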

    --postprocess.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LaneNet post-processing
"""
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glog

try:
    from cv2 import cv2
except ImportError:
    pass


class LaneNetPoseProcessor(object):
    """

    """
    def __init__(self):
        """

        """
        pass

    @staticmethod
    def _morphological_process(image, kernel_size=5):
        """

        :param image:
        :param kernel_size:
        :return:
        """
        if image.dtype != np.uint8:
            image = np.array(image, np.uint8)
        glog.info("_morphological_process image shape len:{:d}".format(len(image.shape)))
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        glog.info("_morphological_process image shape len:{:d}".format(len(image.shape)))
        kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
        
        # closing operation to fill holes
        closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)

        return closing

    @staticmethod
    def _connect_components_analysis(image):
        """

        :param image:
        :return:
        """
        glog.info("_connect_components_analysis image shape len:{:d}".format(len(image.shape)))
        if len(image.shape) == 3:
            gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray_image = image

        return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)

    def postprocess(self, image, minarea_threshold=15):
        """

        :param image:
        :param minarea_threshold: minimum connected-component area to keep
        :return:
        """
        # morphological processing first
        morphological_ret = self._morphological_process(image, kernel_size=5)
        glog.info("postprocess image shape len:{:d}".format(len(image.shape)))

        # connected component analysis
        connect_components_analysis_ret = self._connect_components_analysis(image=morphological_ret)
        glog.info("postprocess connect_components_analysis_ret:{:}".format(connect_components_analysis_ret))
        # remove connected components that are too small
        labels = connect_components_analysis_ret[1]
        stats = connect_components_analysis_ret[2]
        glog.info("postprocess labels:{:}".format(labels))
        glog.info("postprocess stats:{:}".format(stats))
        for index, stat in enumerate(stats):
            if stat[4] <= minarea_threshold:
                idx = np.where(labels == index)
                morphological_ret[idx] = 0

        return morphological_ret


if __name__ == '__main__':
    processor = LaneNetPoseProcessor()

    image = cv2.imread('D:/Code/github/tf_lanenet/data/training_data_example/gt_image_binary/0000.png', cv2.IMREAD_UNCHANGED) #IMREAD_GRAYSCALE

    postprocess_ret = processor.postprocess(image)

    plt.figure('src')
    plt.imshow(image)
    plt.figure('post')
    plt.imshow(postprocess_ret)
    plt.show()
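The area test in postprocess uses column 4 of the stats array, which is cv2.CC_STAT_AREA. A self-contained sketch of the same filtering on a synthetic mask (assuming OpenCV is installed; the blob sizes are made up):

# Sketch of the connected-component area filtering in postprocess,
# run on a synthetic binary mask instead of a lane segmentation.
import cv2
import numpy as np

mask = np.zeros((40, 40), dtype=np.uint8)
mask[5:25, 5:10] = 255     # large blob, area 100
mask[30:32, 30:32] = 255   # small blob, area 4 -> should be removed

_, labels, stats, _ = cv2.connectedComponentsWithStats(
    mask, connectivity=8, ltype=cv2.CV_32S)
for index, stat in enumerate(stats):
    if index == 0:
        continue           # label 0 is the background component
    if stat[cv2.CC_STAT_AREA] <= 15:          # minarea_threshold
        mask[labels == index] = 0

print(np.count_nonzero(mask))   # 100: only the large blob survives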

         --cluster.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Clustering stage of the LaneNet instance segmentation
"""
import numpy as np
import glog as log
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.cluster import DBSCAN
import time
import warnings
import cv2
import glog

try:
    from cv2 import cv2
except ImportError:
    pass


class LaneNetCluster(object):
    """
    Instance segmentation clusterer
    """

    def __init__(self):
        """

        """
        self._color_map = [np.array([255, 0, 0]),
                           np.array([0, 255, 0]),
                           np.array([0, 0, 255]),
                           np.array([125, 125, 0]),
                           np.array([0, 125, 125]),
                           np.array([125, 0, 125]),
                           np.array([50, 100, 50]),
                           np.array([100, 50, 100])]
        pass

    @staticmethod
    def _cluster(prediction, bandwidth):
        """
        Clustering step described in Section II of the paper
        :param prediction:
        :param bandwidth:
        :return:
        """
        ms = MeanShift(bandwidth, bin_seeding=True)
        # log.info('Starting mean shift clustering ...')
        tic = time.time()
        try:
            ms.fit(prediction)
        except ValueError as err:
            log.error(err)
            return 0, [], []
        # log.info('Mean shift took {:.5f}s'.format(time.time() - tic))
        labels = ms.labels_
        cluster_centers = ms.cluster_centers_

        num_clusters = cluster_centers.shape[0]

        # log.info('Number of clusters: {:d}'.format(num_clusters))

        return num_clusters, labels, cluster_centers

    @staticmethod
    def _cluster_v2(prediction):
        """
        dbscan cluster
        :param prediction:
        :return:
        """
        db = DBSCAN(eps=0.7, min_samples=200).fit(prediction)
        db_labels = db.labels_
        unique_labels = np.unique(db_labels)
        unique_labels = [tmp for tmp in unique_labels if tmp != -1]
        log.info('Number of clusters: {:d}'.format(len(unique_labels)))

        num_clusters = len(unique_labels)
        cluster_centers = db.components_

        return num_clusters, db_labels, cluster_centers

    @staticmethod
    def _get_lane_area(binary_seg_ret, instance_seg_ret):
        """
        Use the binary segmentation mask to gather the feature vectors of all
        lane pixels from the instance segmentation map
        :param binary_seg_ret:
        :param instance_seg_ret:
        :return:
        """
        idx = np.where(binary_seg_ret == 1)
        
        print("_get_lane_area idx:",idx)
        print("_get_lane_area idx len:",len(idx))
        print("_get_lane_area idx len[0]:",len(idx[0]))
        print("_get_lane_area idx len[1]:",len(idx[1]))
        lane_embedding_feats = []
        lane_coordinate = []
        for i in range(len(idx[0])):
            lane_embedding_feats.append(instance_seg_ret[idx[0][i], idx[1][i]])
            #print("_get_lane_area instance_seg_ret[idx[0][i], idx[1][i]]:",instance_seg_ret[idx[0][i], idx[1][i]])
            lane_coordinate.append([idx[0][i], idx[1][i]])
            #print("_get_lane_area idx[0][i]:",idx[0][i]," , idx[1][i]:", idx[1][i])

        return np.array(lane_embedding_feats, np.float32), np.array(lane_coordinate, np.int64)

    @staticmethod
    def _thresh_coord(coord):
        """
        Filter the lane coordinate points: assuming lanes are continuous, the
        coordinates should vary smoothly without sudden jumps
        :param coord: [(x, y)]
        :return:
        """
        pts_x = coord[:, 0]
        mean_x = np.mean(pts_x)

        idx = np.where(np.abs(pts_x - mean_x) < mean_x)

        return coord[idx[0]]

    @staticmethod
    def _lane_fit(lane_pts):
        """
        Polynomial fitting of the lane points
        :param lane_pts:
        :return:
        """
        if not isinstance(lane_pts, np.ndarray):
            lane_pts = np.array(lane_pts, np.float32)

        x = lane_pts[:, 0]
        y = lane_pts[:, 1]
        x_fit = []
        y_fit = []
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            try:
                f1 = np.polyfit(x, y, 3)
                p1 = np.poly1d(f1)
                x_min = int(np.min(x))
                x_max = int(np.max(x))
                x_fit = []
                for i in range(x_min, x_max + 1):
                    x_fit.append(i)
                y_fit = p1(x_fit)
            except Warning as e:
                x_fit = x
                y_fit = y
            finally:
                return zip(x_fit, y_fit)

    def get_lane_mask(self, binary_seg_ret, instance_seg_ret):
        """

        :param binary_seg_ret:
        :param instance_seg_ret:
        :return:
        """
        lane_embedding_feats, lane_coordinate = self._get_lane_area(binary_seg_ret, instance_seg_ret)
        
        num_clusters, labels, cluster_centers = self._cluster(lane_embedding_feats, bandwidth=1.5)

        # if there are more than eight clusters, keep the eight with the most samples
        if num_clusters > 8:
            cluster_sample_nums = []
            for i in range(num_clusters):
                cluster_sample_nums.append(len(np.where(labels == i)[0]))
            sort_idx = np.argsort(-np.array(cluster_sample_nums, np.int64))
            cluster_index = np.array(range(num_clusters))[sort_idx[0:8]]
        else:
            cluster_index = range(num_clusters)

        mask_image = np.zeros(shape=[binary_seg_ret.shape[0], binary_seg_ret.shape[1], 3], dtype=np.uint8)

        for index, i in enumerate(cluster_index):
            idx = np.where(labels == i)
            coord = lane_coordinate[idx]
            # coord = self._thresh_coord(coord)
            coord = np.flip(coord, axis=1)
            # coord = (coord[:, 0], coord[:, 1])
            color = (int(self._color_map[index][0]),
                     int(self._color_map[index][1]),
                     int(self._color_map[index][2]))
            coord = np.array([coord])
            cv2.polylines(img=mask_image, pts=coord, isClosed=False, color=color, thickness=2)
            # mask_image[coord] = color

        return mask_image


if __name__ == '__main__':
    binary_seg_image = cv2.imread('D:/Code/github/tf_lanenet/data/training_data_example/gt_image_binary/0000.png', cv2.IMREAD_GRAYSCALE)
    print("binary_seg_image shape:",binary_seg_image.shape)
    binary_seg_image[np.where(binary_seg_image == 255)] = 1
    print("binary_seg_image np.where(binary_seg_image == 255):",np.where(binary_seg_image == 255))
    instance_seg_image = cv2.imread('D:/Code/github/tf_lanenet/data/training_data_example/gt_image_instance/0000.png', cv2.IMREAD_UNCHANGED)
    glog.info("__name__ instance_seg_image shape len:{:d}".format(len(instance_seg_image.shape)))
    instance_seg_image = cv2.cvtColor(instance_seg_image, cv2.COLOR_GRAY2BGR)
    glog.info("__name__ instance_seg_image shape len:{:d}".format(len(instance_seg_image.shape)))
    #print("instance_seg_image shape:",instance_seg_image.shape)
    ele_mex = np.max(instance_seg_image, axis=(0,1))
    print("ele_mex:",ele_mex)
    for i in range(3):
        if ele_mex[i] == 0:
            scale = 1
        else:
            scale = 255 / ele_mex[i]
        instance_seg_image[:, :, i] *= int(scale)
    embedding_image = np.array(instance_seg_image, np.uint8)
    cluster = LaneNetCluster()
    mask_image = cluster.get_lane_mask(binary_seg_ret=binary_seg_image,instance_seg_ret=instance_seg_image)
    det_img = embedding_image+mask_image
    plt.figure('det_img')
    plt.imshow(det_img[:, :, (2, 1, 0)])         
    #plt.figure('embedding')
    #plt.imshow(embedding_image[:, :, (2, 1, 0)])
    #plt.figure('mask_image')
    #plt.imshow(mask_image[:, :, (2, 1, 0)])
    plt.show()
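The clustering itself is plain MeanShift over the 4-D pixel embeddings, one cluster per lane instance. A toy sklearn sketch with two well-separated synthetic "lanes" (the data is made up; the bandwidth matches get_lane_mask):

# Toy sketch of the MeanShift clustering used in get_lane_mask:
# synthetic 4-D embeddings standing in for pix_embedding outputs.
import numpy as np
from sklearn.cluster import MeanShift

rng = np.random.RandomState(0)
lane_a = rng.normal(loc=[0, 0, 0, 0], scale=0.1, size=(50, 4))
lane_b = rng.normal(loc=[5, 5, 5, 5], scale=0.1, size=(50, 4))
feats = np.vstack([lane_a, lane_b]).astype(np.float32)

ms = MeanShift(bandwidth=1.5, bin_seeding=True)
ms.fit(feats)
print('clusters:', ms.cluster_centers_.shape[0])   # 2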

             --train_lane_scnn.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse
import math
import os
import os.path as ops
import time

import cv2
import glog as log
import numpy as np
import tensorflow as tf

from config import global_config
from lanenet_model import lanenet_merge_model
from data_provider import lanenet_data_processor

CFG = global_config.cfg
VGG_MEAN = [103.939, 116.779, 123.68]


def init_args():
    """

    :return:
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_dir', type=str,default='data/datasets_culane', help='The training dataset dir path')
    parser.add_argument('--net', type=str, default='vgg',  help='Which base net work to use')
    parser.add_argument('--weights_path', type=str,default='model/lanenet_culane_vgg_2019-02-02-14-05-16.ckpt-200000',help='The pretrained weights path')

    return parser.parse_args()


def minmax_scale(input_arr):
    """

    :param input_arr:
    :return:
    """
    min_val = np.min(input_arr)
    max_val = np.max(input_arr)

    output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

    return output_arr


def train_net(dataset_dir, weights_path=None, net_flag='vgg'):
    """

    :param dataset_dir:
    :param net_flag: choose which base network to use
    :param weights_path:
    :return:
    """
    train_dataset_file = ops.join(dataset_dir, 'train.txt')
    val_dataset_file = ops.join(dataset_dir, 'val.txt')
    print('train_dataset_file:',train_dataset_file)
    print('val_dataset_file:',val_dataset_file)

    assert ops.exists(train_dataset_file)

    train_dataset = lanenet_data_processor.DataSet(train_dataset_file)
    val_dataset = lanenet_data_processor.DataSet(val_dataset_file)

    with tf.device('/gpu:1'):
        input_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                             CFG.TRAIN.IMG_WIDTH, 3],
                                      name='input_tensor')
        binary_label_tensor = tf.placeholder(dtype=tf.int64,
                                             shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                                    CFG.TRAIN.IMG_WIDTH, 1],
                                             name='binary_input_label')
        instance_label_tensor = tf.placeholder(dtype=tf.float32,
                                               shape=[CFG.TRAIN.BATCH_SIZE, CFG.TRAIN.IMG_HEIGHT,
                                                      CFG.TRAIN.IMG_WIDTH],
                                               name='instance_input_label')
        phase = tf.placeholder(dtype=tf.string, shape=None, name='net_phase')

        net = lanenet_merge_model.LaneNet(net_flag=net_flag, phase=phase)

        # calculate the loss
        compute_ret = net.compute_loss(input_tensor=input_tensor, binary_label=binary_label_tensor,
                                       instance_label=instance_label_tensor, name='lanenet_model')
        total_loss = compute_ret['total_loss']
        binary_seg_loss = compute_ret['binary_seg_loss']
        disc_loss = compute_ret['discriminative_loss']
        pix_embedding = compute_ret['instance_seg_logits']

        # calculate the accuracy
        out_logits = compute_ret['binary_seg_logits']
        out_logits = tf.nn.softmax(logits=out_logits)
        out_logits_out = tf.argmax(out_logits, axis=-1)
        #out = tf.argmax(out_logits, axis=-1)
        #out = tf.expand_dims(out, axis=-1)        
        out = tf.expand_dims(out_logits_out,axis=-1)


        idx = tf.where(tf.equal(binary_label_tensor, 1))
        pix_cls_ret = tf.gather_nd(out, idx)
        accuracy = tf.count_nonzero(pix_cls_ret)
        accuracy = tf.divide(accuracy, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(CFG.TRAIN.LEARNING_RATE, global_step,
                                                   100000, 0.1, staircase=True)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9).minimize(loss=total_loss,
                                                                    var_list=tf.trainable_variables(),
                                                                    global_step=global_step)

    # Set tf saver
    saver = tf.train.Saver()
    model_save_dir = 'model/lanenet_culane'
    if not ops.exists(model_save_dir):
        os.makedirs(model_save_dir)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'lanenet_culane_{:s}_{:s}.ckpt'.format(net_flag, str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set tf summary
    tboard_save_path = 'tboard/lanenet_culane/{:s}'.format(net_flag)
    if not ops.exists(tboard_save_path):
        os.makedirs(tboard_save_path)
    train_cost_scalar = tf.summary.scalar(name='train_cost', tensor=total_loss)
    val_cost_scalar = tf.summary.scalar(name='val_cost', tensor=total_loss)
    train_accuracy_scalar = tf.summary.scalar(name='train_accuracy', tensor=accuracy)
    val_accuracy_scalar = tf.summary.scalar(name='val_accuracy', tensor=accuracy)
    train_binary_seg_loss_scalar = tf.summary.scalar(name='train_binary_seg_loss', tensor=binary_seg_loss)
    val_binary_seg_loss_scalar = tf.summary.scalar(name='val_binary_seg_loss', tensor=binary_seg_loss)
    train_instance_seg_loss_scalar = tf.summary.scalar(name='train_instance_seg_loss', tensor=disc_loss)
    val_instance_seg_loss_scalar = tf.summary.scalar(name='val_instance_seg_loss', tensor=disc_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate', tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge([train_accuracy_scalar, train_cost_scalar,
                                               learning_rate_scalar, train_binary_seg_loss_scalar,
                                               train_instance_seg_loss_scalar])
    val_merge_summary_op = tf.summary.merge([val_accuracy_scalar, val_cost_scalar,
                                             val_binary_seg_loss_scalar, val_instance_seg_loss_scalar])

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
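    # BFC ("best-fit with coalescing") reduces GPU memory fragmentation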
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_path)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    with sess.as_default():

        tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                             name='{:s}/lanenet_model.pb'.format(model_save_dir))

        if weights_path is None:
            log.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
            saver.restore(sess=sess, save_path=weights_path)

        # Load the pretrained VGG-16 weights when training the vgg backbone from scratch
        log.info('train_net net_flag: {:s}'.format(net_flag))
        if net_flag == 'vgg' and weights_path is None:
            pretrained_weights = np.load(
                './data/vgg16.npy',
                encoding='latin1', allow_pickle=True).item()
            log.info('Loading pretrained vgg16 weights')
            for vv in tf.trainable_variables():
                weights_key = vv.name.split('/')[-3]
                try:
                    weights = pretrained_weights[weights_key][0]
                    _op = tf.assign(vv, weights)
                    sess.run(_op)
                except Exception as e:
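                    # variables with no matching entry in vgg16.npy
                    # (e.g. decoder/embedding layers) keep their initial values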
                    continue

        train_cost_time_mean = []
        val_cost_time_mean = []
        for epoch in range(train_epochs):
            # training part
            t_start = time.time()

            with tf.device('/cpu:0'):
                gt_imgs, binary_gt_labels, instance_gt_labels = train_dataset.next_batch(CFG.TRAIN.BATCH_SIZE)
                gt_imgs = [cv2.resize(tmp,
                                      dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                      dst=tmp,
                                      interpolation=cv2.INTER_LINEAR)
                           for tmp in gt_imgs]

                gt_imgs = [tmp - VGG_MEAN for tmp in gt_imgs]
                binary_gt_labels = [cv2.resize(tmp,
                                               dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                               dst=tmp,
                                               interpolation=cv2.INTER_NEAREST)
                                    for tmp in binary_gt_labels]
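                # add a channel axis so the labels match the [N, H, W, 1] placeholder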
                binary_gt_labels = [np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels]
                instance_gt_labels = [cv2.resize(tmp,
                                                 dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                                 dst=tmp,
                                                 interpolation=cv2.INTER_NEAREST)
                                      for tmp in instance_gt_labels]
            phase_train = 'train'

            _, c, train_accuracy, train_summary, binary_loss, instance_loss, embedding, binary_seg_img = \
                sess.run([optimizer, total_loss,
                          accuracy,
                          train_merge_summary_op,
                          binary_seg_loss,
                          disc_loss,
                          pix_embedding,
                          out_logits_out],
                         feed_dict={input_tensor: gt_imgs,
                                    binary_label_tensor: binary_gt_labels,
                                    instance_label_tensor: instance_gt_labels,
                                    phase: phase_train})

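            # abort training and dump the offending batch to disk if any loss is NaN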
            if math.isnan(c) or math.isnan(instance_loss) or math.isnan(binary_loss):
                log.error('cost is: {:.5f}'.format(c))
                log.error('binary cost is: {:.5f}'.format(binary_loss))
                log.error('instance cost is: {:.5f}'.format(instance_loss))
                cv2.imwrite('nan_image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('nan_instance_label.png', instance_gt_labels[0])
                cv2.imwrite('nan_binary_label.png', binary_gt_labels[0] * 255)
                return

            if epoch % 100 == 0:
                cv2.imwrite('image.png', gt_imgs[0] + VGG_MEAN)
                cv2.imwrite('binary_label.png', binary_gt_labels[0] * 255)
                cv2.imwrite('instance_label.png', instance_gt_labels[0])
                cv2.imwrite('binary_seg_img.png', binary_seg_img[0] * 255)

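                # rescale each embedding channel to a displayable range before saving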
                for i in range(4):
                    embedding[0][:, :, i] = minmax_scale(embedding[0][:, :, i])
                embedding_image = np.array(embedding[0], np.uint8)
                cv2.imwrite('embedding.png', embedding_image)

            cost_time = time.time() - t_start
            train_cost_time_mean.append(cost_time)
            summary_writer.add_summary(summary=train_summary, global_step=epoch)

            # validation part
            with tf.device('/cpu:0'):
                gt_imgs_val, binary_gt_labels_val, instance_gt_labels_val \
                    = val_dataset.next_batch(CFG.TRAIN.VAL_BATCH_SIZE)
                gt_imgs_val = [cv2.resize(tmp,
                                          dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                          dst=tmp,
                                          interpolation=cv2.INTER_LINEAR)
                               for tmp in gt_imgs_val]
                gt_imgs_val = [tmp - VGG_MEAN for tmp in gt_imgs_val]
                binary_gt_labels_val = [cv2.resize(tmp,
                                                   dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                                   dst=tmp,
                                                   interpolation=cv2.INTER_NEAREST)
                                        for tmp in binary_gt_labels_val]
                binary_gt_labels_val = [np.expand_dims(tmp, axis=-1) for tmp in binary_gt_labels_val]
                instance_gt_labels_val = [cv2.resize(tmp,
                                                     dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                                     dst=tmp,
                                                     interpolation=cv2.INTER_NEAREST)
                                          for tmp in instance_gt_labels_val]
            phase_val = 'test'

            t_start_val = time.time()
            c_val, val_summary, val_accuracy, val_binary_seg_loss, val_instance_seg_loss = \
                sess.run([total_loss, val_merge_summary_op, accuracy, binary_seg_loss, disc_loss],
                         feed_dict={input_tensor: gt_imgs_val,
                                    binary_label_tensor: binary_gt_labels_val,
                                    instance_label_tensor: instance_gt_labels_val,
                                    phase: phase_val})

            if epoch % 100 == 0:
                cv2.imwrite('test_image.png', gt_imgs_val[0] + VGG_MEAN)

            summary_writer.add_summary(val_summary, global_step=epoch)

            cost_time_val = time.time() - t_start_val
            val_cost_time_mean.append(cost_time_val)

            if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                log.info('Epoch: {:d} total_loss= {:6f} binary_seg_loss= {:6f} instance_seg_loss= {:6f} accuracy= {:6f}'
                         ' mean_cost_time= {:5f}s '.
                         format(epoch + 1, c, binary_loss, instance_loss, train_accuracy,
                                np.mean(train_cost_time_mean)))
                train_cost_time_mean.clear()

            if epoch % CFG.TRAIN.TEST_DISPLAY_STEP == 0:
                log.info('Epoch_Val: {:d} total_loss= {:6f} binary_seg_loss= {:6f} '
                         'instance_seg_loss= {:6f} accuracy= {:6f} '
                         'mean_cost_time= {:5f}s '.
                         format(epoch + 1, c_val, val_binary_seg_loss, val_instance_seg_loss, val_accuracy,
                                np.mean(val_cost_time_mean)))
                val_cost_time_mean.clear()

            if epoch % 2000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    sess.close()

    return


if __name__ == '__main__':
    # init args
    args = init_args()

    # train lanenet
    train_net(args.dataset_dir, args.weights_path, net_flag=args.net)
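The __main__ block above relies on init_args(), which is defined earlier in the training script. For reference, here is a minimal sketch of what that parser could look like; the flag names --dataset_dir, --weights_path and --net are assumptions inferred from the attribute accesses in __main__, not the original definitions:

# Hypothetical sketch of init_args(); flag names are inferred from
# args.dataset_dir, args.weights_path and args.net used in __main__.
import argparse

def init_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str,
                        help='directory containing train.txt and val.txt')
    parser.add_argument('--weights_path', type=str, default=None,
                        help='optional checkpoint to resume training from')
    parser.add_argument('--net', type=str, default='vgg',
                        help='backbone selector, e.g. vgg')
    return parser.parse_args()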

            The following folders are generated during model training:

    ./summary

              -- ./figure

              -- ./checkpoint

       In the project root, run python train_lanenet_scnn.py; if everything is set up correctly, training will start...
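       While it runs, training can be monitored with TensorBoard by pointing it at the summary directory created above, e.g. tensorboard --logdir tboard/lanenet_culane/vgg for net_flag=vgg; per the code above, checkpoints are saved to model/lanenet_culane every 2000 epochs.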

 
