TensorFlow Notes: The VGG Network


This time we use slim to build a somewhat larger network, VGG16. VGG16 and VGG19 are essentially the same architecture, so this walkthrough uses VGG16 for a 5-class flower classification task.

Compared with earlier networks such as LeNet and AlexNet, VGG introduces the following ideas:

1. Stacks of 3×3 convolution kernels replace the larger 5×5 and 7×7 kernels.

A 5×5 kernel has a larger receptive field, but it also has more parameters. Two stacked 3×3 convolutions cover the same receptive field as one 5×5 convolution while applying two nonlinear transformations instead of one. In short, compared with large kernels, stacked small kernels reduce the parameter count and add more nonlinear mappings, increasing the network's expressive power; a quick numeric check follows this list.

2. The network is deeper. Setting aside the drawbacks for now (deep networks are harder to train and can suffer from vanishing gradients), in terms of feature abstraction and expressive power, a deeper network fits the data distribution better than a shallow one.

3. The original VGG work also introduced tricks such as data augmentation and image preprocessing.
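To put point 1 in numbers, here is a quick back-of-the-envelope check (my own sketch, with an arbitrary channel count): two stacked 3×3 convolutions mapping c channels to c channels use 18c² weights, while a single 5×5 convolution uses 25c², so the stack saves 28% of the parameters.

```python
# Illustrative parameter count: stacked 3x3 convs vs a single 5x5 conv.
# Both map c input channels to c output channels; biases are ignored.
def conv_params(kernel_size, in_channels, out_channels):
    return kernel_size * kernel_size * in_channels * out_channels

c = 256
stacked_3x3 = 2 * conv_params(3, c, c)  # 18 * c^2 = 1,179,648
single_5x5 = conv_params(5, c, c)       # 25 * c^2 = 1,638,400
print(stacked_3x3 / single_5x5)         # 0.72 -> 28% fewer parameters
```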

Now on to the code. The project is split into three files:

vgg.py: builds the 16-layer VGG network.

```python
import tensorflow as tf
import tensorflow.contrib.slim as slim


def build_vgg(rgb, num_classes, keep_prob, train=True):
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm):
        # block_1
        net = slim.repeat(rgb, 2, slim.conv2d, 64, [3, 3], padding='SAME', scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')

        # block_2
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], padding='SAME', scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')

        # block_3
        net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], padding='SAME', scope='conv3')
        net = slim.max_pool2d(net, [2, 2], scope='pool3')

        # block_4
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv4')
        net = slim.max_pool2d(net, [2, 2], scope='pool4')

        # block_5
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], padding='SAME', scope='conv5')
        net = slim.max_pool2d(net, [2, 2], scope='pool5')

        # flatten the pool5 feature map before the fully connected layers
        feature_shape = net.get_shape()
        flattened_shape = feature_shape[1].value * feature_shape[2].value * feature_shape[3].value
        pool5_flatten = tf.reshape(net, [-1, flattened_shape])

        # fc6
        net = slim.fully_connected(pool5_flatten, 4096, scope='fc6')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout6')

        # fc7
        net = slim.fully_connected(net, 4096, scope='fc7')
        if train:
            net = slim.dropout(net, keep_prob=keep_prob, scope='dropout7')

        # fc8: return raw logits (no softmax here), because train.py uses
        # tf.nn.softmax_cross_entropy_with_logits, which applies softmax itself
        net = slim.fully_connected(net, num_classes, activation_fn=None,
                                   normalizer_fn=None, scope='fc8')
    return net
```
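Before wiring it into training, a quick shape sanity check helps (a sketch I added, assuming TensorFlow 1.x with contrib.slim available): feed a dummy 224×224 batch and confirm fc8 emits one logit per class.

```python
import tensorflow as tf
from model.vgg import build_vgg  # module path assumed from the imports in train.py

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits = build_vgg(images, num_classes=5, keep_prob=0.5, train=False)
print(logits.get_shape())  # expect (?, 5): one logit per flower class
```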

tfrecords.py: encodes and decodes the data. Unlike earlier posts, which fed data to the network through feed_dict, this example encodes the dataset in TensorFlow's own TFRecord format.

```python
import tensorflow as tf
import numpy as np
import os
import glob
from PIL import Image

path_tfrecord = '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/'


def convert_to_tfrecord(images, labels, filename):
    """Encode a list of image paths and integer labels into a TFRecord file."""
    print("Converting data into %s ..." % filename)
    writer = tf.python_io.TFRecordWriter(path_tfrecord + filename)
    for index, img in enumerate(images):
        img_raw = Image.open(img)
        if img_raw.mode != "RGB":  # skip grayscale/CMYK images
            continue
        img_raw = img_raw.resize((256, 256))
        img_raw = img_raw.tobytes()
        label = int(labels[index])
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())
    writer.close()


def read_and_decode(filename, is_train=None):
    """Decode one example from a TFRecord file and apply preprocessing."""
    filename_queue = tf.train.string_input_producer([filename], num_epochs=400)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])
    # keep pixel values in [0, 255]: per_image_standardization below normalizes anyway,
    # and random_brightness(max_delta=63) assumes the [0, 255] range
    img = tf.cast(img, tf.float32)

    if is_train == True:
        img = tf.random_crop(img, [224, 224, 3])
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=63)
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        img = tf.image.per_image_standardization(img)
    else:
        img = tf.image.resize_image_with_crop_or_pad(img, 224, 224)
        img = tf.image.per_image_standardization(img)

    label = tf.cast(features['label'], tf.int32)
    return img, label


def get_file(path):
    """Collect image paths and labels (one class per subfolder), shuffle, and split 80/20."""
    cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    images = []
    labels = []
    for idx, folder in enumerate(cate):
        for img in glob.glob(folder + '/*.jpg'):
            print('reading the images: %s' % img)
            images.append(img)
            labels.append(idx)
    image_list = np.asarray(images, np.string_)
    label_list = np.asarray(labels, np.int32)

    # shuffle
    num_example = image_list.shape[0]
    arr = np.arange(num_example)
    np.random.shuffle(arr)
    image_list = image_list[arr]
    label_list = label_list[arr]

    # divide into train_data and val_data
    split = np.int(num_example * 0.8)
    train_images = image_list[:split]
    train_labels = label_list[:split]
    val_images = image_list[split:]
    val_labels = label_list[split:]
    return train_images, train_labels, val_images, val_labels


if __name__ == '__main__':
    train_images, train_labels, val_images, val_labels = get_file('/home/danny/chenwei/CSDN_blog/VGG/datasets/')
    convert_to_tfrecord(images=train_images, labels=train_labels, filename="train.tfrecords")
    convert_to_tfrecord(images=val_images, labels=val_labels, filename="test.tfrecords")
```
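As a side note, queue runners were deprecated in later TensorFlow 1.x releases in favor of tf.data. For reference, a roughly equivalent reader could look like the sketch below (my addition, not part of the original project; the training script here sticks with queue runners):

```python
import tensorflow as tf

def make_dataset(filename, batch_size, is_train=True):
    """tf.data sketch of read_and_decode: parse, preprocess, shuffle, batch."""
    def _parse(serialized):
        features = tf.parse_single_example(serialized, features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
        img = tf.decode_raw(features['img_raw'], tf.uint8)
        img = tf.reshape(img, [256, 256, 3])
        img = tf.cast(img, tf.float32)
        if is_train:
            img = tf.random_crop(img, [224, 224, 3])
            img = tf.image.random_flip_left_right(img)
        else:
            img = tf.image.resize_image_with_crop_or_pad(img, 224, 224)
        img = tf.image.per_image_standardization(img)
        return img, tf.cast(features['label'], tf.int32)

    dataset = tf.data.TFRecordDataset(filename).map(_parse)
    if is_train:
        dataset = dataset.shuffle(buffer_size=3000)
    return dataset.batch(batch_size).repeat()
```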

train.py: the training script. The difference from before is that data is pulled for training by multiple threads through a queue.

```python
# -*- coding: utf-8 -*-
import tensorflow as tf
from utils.tfrecords import *
from model.vgg import *

tf.app.flags.DEFINE_integer('num_classes', 5, 'classification number.')
# read_and_decode crops/pads every image to 224x224, so the input size must match
tf.app.flags.DEFINE_integer('crop_width', 224, 'width of input image.')
tf.app.flags.DEFINE_integer('crop_height', 224, 'height of input image.')
tf.app.flags.DEFINE_integer('channels', 3, 'channel number of image.')
tf.app.flags.DEFINE_integer('batch_size', 2, 'num of each batch')
tf.app.flags.DEFINE_integer('num_epochs', 400, 'number of epochs')
tf.app.flags.DEFINE_bool('continue_training', False, 'whether to continue training')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('dataset_path', './datasets/', 'path of dataset')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_string('train_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/train.tfrecords', 'train tfrecord')
tf.app.flags.DEFINE_string('test_tfrecords', '/home/danny/chenwei/CSDN_blog/VGG/tfrecords/test.tfrecords', 'test tfrecord')

FLAGS = tf.app.flags.FLAGS


def main(_):
    # data pipeline: decode TFRecords and assemble shuffled batches
    train_images, train_labels = read_and_decode(FLAGS.train_tfrecords, True)
    val_images, val_labels = read_and_decode(FLAGS.test_tfrecords, False)

    train_labels = tf.one_hot(indices=tf.cast(train_labels, tf.int32), depth=FLAGS.num_classes)
    train_images_batch, train_labels_batch = tf.train.shuffle_batch(
        [train_images, train_labels], batch_size=FLAGS.batch_size,
        capacity=20000, min_after_dequeue=3000, num_threads=16)  # num_threads sets the reader thread count

    val_labels = tf.one_hot(indices=tf.cast(val_labels, tf.int32), depth=FLAGS.num_classes)
    val_images_batch, val_labels_batch = tf.train.shuffle_batch(
        [val_images, val_labels], batch_size=FLAGS.batch_size,
        capacity=20000, min_after_dequeue=3000, num_threads=16)

    # define network input; labels are one-hot vectors, hence float32
    input = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width, FLAGS.channels], name='input')
    output = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.num_classes], name='output')

    # control GPU resource utilization
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # build network; build_vgg returns raw logits
    logits = build_vgg(input, FLAGS.num_classes, 0.5, True)

    # loss
    cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))
    regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    loss = cross_entropy_loss + regularization_loss

    # optimizer (for exact batch_norm moving averages one would also run tf.GraphKeys.UPDATE_OPS)
    train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)

    # accuracy
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with sess.as_default():

        # init all parameters (local variables back string_input_producer's epoch counter)
        saver = tf.train.Saver(max_to_keep=1000)
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # restore weights
        if FLAGS.continue_training:
            saver.restore(sess, FLAGS.checkpoints)

        # start the queue-runner threads that fill the input pipeline
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        step = 0  # counts training steps (batches), not full epochs
        try:
            while not coord.should_stop():
                train_images, train_labels = sess.run([train_images_batch, train_labels_batch])
                _, err, acc = sess.run([train_op, loss, accuracy], feed_dict={input: train_images, output: train_labels})
                print("[Train] Step: %d, loss: %.4f, accuracy: %.4f" % (step, err, acc))
                step += 1

                if step % 10 == 0 or (step + 1) == FLAGS.num_epochs:
                    val_images, val_labels = sess.run([val_images_batch, val_labels_batch])
                    val_err, val_acc = sess.run([loss, accuracy], feed_dict={input: val_images, output: val_labels})
                    print("[validation] Step: %d, loss: %.4f, accuracy: %.4f" % (step, val_err, val_acc))

                if (step + 1) == FLAGS.num_epochs:
                    saver.save(sess, save_path=FLAGS.checkpoints, global_step=step)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()


if __name__ == '__main__':
    tf.app.run()
```
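Once training finishes, the saved checkpoint can be restored for inference. The sketch below is my own addition and purely illustrative: the checkpoint suffix depends on the step at which saver.save actually ran, and the zero tensor is a stand-in for a real preprocessed 224×224 image.

```python
import numpy as np
import tensorflow as tf
from model.vgg import build_vgg

# build the same graph in inference mode (no dropout)
image = tf.placeholder(tf.float32, [1, 224, 224, 3])
logits = build_vgg(image, num_classes=5, keep_prob=1.0, train=False)
probs = tf.nn.softmax(logits)
prediction = tf.argmax(probs, 1)

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, './checkpoints/model.ckpt-399')  # assumed checkpoint path/suffix
    dummy = np.zeros((1, 224, 224, 3), np.float32)  # replace with a real standardized image
    print(sess.run(prediction, feed_dict={image: dummy}))
```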

Training results: validation accuracy ends up around 96%. Note that the accuracy values in the log are fractions; the trailing percent sign is an artifact of the print format.

```
[Train] Step: 19985, loss: 1.1098, accuracy: 1.0000%
[Train] Step: 19986, loss: 1.1302, accuracy: 1.0000%
[Train] Step: 19987, loss: 1.1232, accuracy: 1.0000%
[Train] Step: 19988, loss: 1.1299, accuracy: 1.0000%
[Train] Step: 19989, loss: 1.1220, accuracy: 1.0000%
[validation] Step: 19990, loss: 1.1634, accuracy: 0.9688%
```
