到此为止关于超分重建的理论部分八成已经作结,关于这个tensorflow版本的SRCNN的代码解读不知道究竟需要写到什么程度才可以完美收官。大家也都明白,这个东西若写太细,略显冗杂;若写太粗,略显不够明析。反正吧,尽可能的写清楚写明细。下面是我的GitHub代码仓库:https://github.com/XiaoYunChaos,关于这篇的代码随后完整作结后我会上传至仓库,供大家讨论学习,欢迎star哦!
SRCNN(tensorflow)详解分析
- 【1】首先,介绍一下项目结构:
main.py 定义训练和测试参数,此后由设定的参数进行训练或测试。
model.py是模型文件以类的方式实现
utils.py是用来封装项目中的函数作为函数池
psnr.py是用来做评价函数的,功能就是进行计算评价指标
checkpoint文件夹是用来保训练模型,即chekpoint的路径
sample文件夹是样本路径
Train文件夹是训练集路径
Test文件夹是测试集路径,包含Set5与Set14
在看懂代码前,一定要明白一件事就是我们每一次训练实际上是训练图片的大小和输出图片等的大小等参数的设置。项目除了一般的预处理操作,还需要将图片分割,最后的训练完还做实验的时候还需要将图片结合起来。
- 【2】main.py
功能:定义训练和测试参数,包括:batchSize、学习率、步长stride、训练、测试等。
函数运行开启:
if __name__ == '__main__':
# main()
tf.app.run()
随后tf.app运行,此时涉及相关参数:
flags = tf.app.flags
#第一个是参数名称,第二个参数是默认值,第三个是参数描述
flags.DEFINE_integer("epoch", 15000, "训练多少波Number of epoch [15000]")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("batch_size", 128, "batch size")
#一开始将batch size设为128和64,不仅参数初始loss很大,而且往往一段时间后训练就发散
#batch中每个样本产生梯度竞争可能比较激烈,所以导致了收敛过慢
#后来就改回了128
flags.DEFINE_integer("image_size", 33, "图像使用的尺寸 The size of image to use [33]")
flags.DEFINE_integer("label_size", 21, "label_制作的尺寸 The size of label to produce [21]")
#学习率文中设置为 前两层1e-4 第三层1e-5
#SGD+指数学习率10-2作为初始
flags.DEFINE_float("learning_rate", 1e-4, "学习率 The learning rate of gradient descent algorithm [1e-4]")
flags.DEFINE_integer("c_dim", 1, "图像维度 Dimension of image color. [1]")
flags.DEFINE_integer("scale", 3, "sample的scale大小 The size of scale factor for preprocessing input image [3]")
#stride训练采用14,测试采用21
flags.DEFINE_integer("stride", 14, "步长为14或者21 The size of stride to apply input image [14]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "名字 Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("sample_dir", "sample", "名字 Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", True, "True for training, False for testing [True]")#训练
#flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#测试
FLAGS = flags.FLAGS
#第一句是赋值,将前面的一系列参数赋值给FLAGS。
#第二句是创建了一个打印的类,这样就可以调用pp的函数了。
pp = pprint.PrettyPrinter()
此时需要注意这些参数:
- epoch:迭代次数
- batch_size:批处理参数
- image_size:图像大小
- label_size:高分辨率图像大小,即真实标签的大小
- learning_rate:学习率
- c_dim:图像颜色维度
- scale:缩放倍数
- stride:卷积步长
- checkpoint_dir:模型保存路径
- sample_dir:样本路径
- is_train:是否训练
- 【3】main函数
CPU版本:
def main(_): #CPU版本
pp.pprint(flags.FLAGS.__flags)
#路径检查,没有则创建
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
#tf的相关参数传入及srcnn模型训练或测试
with tf.Session() as sess:
#new出一个类对象,这个对象你可以理解为这个三层神经网络
srcnn = SRCNN(sess,
image_size=FLAGS.image_size,
label_size=FLAGS.label_size,
batch_size=FLAGS.batch_size,
c_dim=FLAGS.c_dim,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
#训练模型
srcnn.train(FLAGS)
GPU版本:
def main(_): #GPU版本:
pp.pprint(flags.FLAGS.__flags)
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#主函数验证路径是否存在,如果不存在就创造一个
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
#sess = tf.Session()
srcnn = SRCNN(sess,
image_size=FLAGS.image_size,
label_size=FLAGS.label_size,
batch_size=FLAGS.batch_size,
c_dim=FLAGS.c_dim,
#图像维度
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
srcnn.train(FLAGS)
print(srcnn.train(FLAGS))
GPU版本与CPU版本代码理解无多大区别,就是在项目部署上可能不一样,GPU的存在有什么好处呢,说白了就是模型训练加速器,可以更快更高效的将模型训练出来,对于GPU的相关笔记随后再做解释吧,你只要把CPU代码理解了,其他的都是锦上添花。
上述main函数可以说是已经将项目框架跑完了,随后就是一些细节上的理解和处理了。
-
【4】model.py
from utils import (
read_data,
input_setup,
imsave,
merge
)
import time
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
try:
xrange
except:
xrange = range
class SRCNN(object):
def __init__(self,
sess,
image_size=33,
label_size=21,
batch_size=128,
c_dim=1,
checkpoint_dir=None,
sample_dir=None):
self.sess = sess
self.is_grayscale = (c_dim == 1)
self.image_size = image_size
self.label_size = label_size
self.batch_size = batch_size
self.c_dim = c_dim
self.checkpoint_dir = checkpoint_dir
self.sample_dir = sample_dir
self.build_model()
#搭建网络
def build_model(self): #三层网络结构
self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
#第一层CNN:对输入图片的特征提取。(9 x 9 x 64卷积核)
#第二层CNN:对第一层提取的特征的非线性映射(1 x 1 x 32卷积核)
#第三层CNN:对映射后的特征进行重建,生成高分辨率图像(5 x 5 x 1卷积核)
#权重
self.weights = {
#论文中为提高训练速度的设置 n1=32 n2=16
'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
}
self.biases = {
'b1': tf.Variable(tf.zeros([64]), name='b1'),
'b2': tf.Variable(tf.zeros([32]), name='b2'),
'b3': tf.Variable(tf.zeros([1]), name='b3')
}
self.pred = self.model()
# Loss function (MSE)以MSE为损失函数
self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
#主函数调用(训练或测试)
self.saver = tf.train.Saver()
#训练
def train(self, config):
if config.is_train:#判断是否为训练(main传入)
input_setup(self.sess, config)
else:
nx, ny = input_setup(self.sess, config)
#训练为checkpoint下train.h5
#测试为checkpoint下test.h5
if config.is_train:
data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
else:
data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")
#训练数据标签
train_data, train_label = read_data(data_dir)
#读取.h5文件(由测试和训练决定)
# Stochastic gradient descent with the standard backpropagation
self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)
tf.global_variables_initializer().run()
counter = 0
start_time = time.time()
if self.load(self.checkpoint_dir):
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
#训练
if config.is_train:
print("Training...")
for ep in xrange(config.epoch):#迭代次数的循环
#以batch为单元
# Run by batch images
batch_idxs = len(train_data) // config.batch_size
for idx in xrange(0, batch_idxs):
batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]
counter += 1
_, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})
if counter % 10 == 0:#10的倍数的step显示
print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
% ((ep+1), counter, time.time()-start_time, err))
if counter % 500 == 0:#500的倍数step储存
self.save(config.checkpoint_dir, counter)
#测试
else:
print("Testing...")
result = self.pred.eval({self.images: train_data, self.labels: train_label})
result = merge(result, [nx, ny])
result = result.squeeze()#除去size为1的维度
#result= exposure.adjust_gamma(result, 1.07)#调暗一些
image_path = os.path.join(os.getcwd(), config.sample_dir)
image_path = os.path.join(image_path, "test_image.png")
imsave(result, image_path)
def model(self):
#strides在官方定义中是一个一维具有四个元素的张量,其规定前后必须为1,所以我们可以改的是中间两个数,中间两个数分别代表了水平滑动和垂直滑动步长值。
conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
return conv3
def save(self, checkpoint_dir, step):
model_name = "SRCNN.model"
model_dir = "%s_%s" % ("srcnn", self.label_size)
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)#再一次确定路径为 checkpoint->srcnn_21下
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
self.saver.save(self.sess,
os.path.join(checkpoint_dir, model_name), #文件名为SRCNN.model-迭代次数
global_step=step)
def load(self, checkpoint_dir):
print(" [*] Reading checkpoints...")
model_dir = "%s_%s" % ("srcnn", self.label_size)
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
#路径为checkpoint->srcnn_labelsize(21)
#加载路径下的模型(.meta文件保存当前图的结构;
#.index文件保存当前参数名; .data文件保存当前参数值)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
#saver.restore()函数给出model.-n路径后会自动寻找参数名-值文件进行加载
return True
else:
return False
训练方式:SGD的效果更好
- 【5】utils.py
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""
import os
import glob#导入glob库,作用是类似于系统的文件路径匹配查询
import h5py#h5py库,主要用于读取或创建datasets或groups
import random
import matplotlib.pyplot as plt
from PIL import Image # for loading images as YCbCr format
import scipy.misc#该库主要用于将数组保存成图像形式
import scipy.ndimage#该库用于图像处理
import numpy as np
import tensorflow as tf
try:
xrange#处理异常中断
except:
xrange = range
FLAGS = tf.app.flags.FLAGS#命令行参数传递
def read_data(path):#读取.h5文件的data和label数据,转化np.array格式
"""
Read h5 format data file
读取h5格式数据文件,用于训练或者测试
参数:
路径: 文件
data.h5 包含训练输入
label.h5 包含训练输出
Args:
path: file path of desired file
data: '.h5' file format that contains train data values
label: '.h5' file format that contains train label values
"""
with h5py.File(path, 'r') as hf:#读取h5格式数据文件(用于训练或测试)
data = np.array(hf.get('data'))
label = np.array(hf.get('label'))
return data, label
def preprocess(path, scale=3):#定义预处理函数
#(1)读取灰度图像;
#(2)modcrop;
#(3)归一化;
#(4)两次bicubic interpolation
返回input_ ,label_
make_data(sess,data,label)**
作用:将data(checkpoint下的train.h5 或test.h5)利用h5的create_dataset 写入
"""
#对路径下的image裁剪成scale整数倍,再对image缩小1/scale倍后,放大scale倍以得到低分辨率图input_,调整尺寸后的image为高分辨率图label_
#image = imread(path, is_grayscale=True)
#label_ = modcrop(image, scale)
Preprocess single image file
(1) Read original image as YCbCr format (and grayscale as default)
(2) Normalize
(3) Apply image file with bicubic interpolation
Args:
path: file path of desired file
input_: image applied bicubic interpolation (low-resolution)
label_: image with original resolution (high-resolution)
"""
image = imread(path, is_grayscale=True)
label_ = modcrop(image, scale)
# Must be normalized
image = image / 255.
label_ = label_ / 255.
input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)
input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)
return input_, label_
def prepare_data(sess, dataset):#作用:返回data是训练集或测试集bmp格式的图像
#(1)参数说明:dataset是train dataset 或 test dataset
#(2)glob.glob得到所有的训练集或是测试集图像
"""
Args:
dataset: choose train dataset or test dataset
For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
"""
if FLAGS.is_train:
filenames = os.listdir(dataset)
data_dir = os.path.join(os.getcwd(), dataset)
data = glob.glob(os.path.join(data_dir, "*.bmp"))
#(2)glob.glob得到所有的训练集或是测试集图像
else:
#确定测试数据集合的文件夹为Set5
data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
data = glob.glob(os.path.join(data_dir, "*.bmp"))
return data
def make_data(sess, data, label):
"""
Make input data as h5 file format
Depending on 'is_train' (flag value), savepath would be changed.
"""
#把数据保存成.h5格式
if FLAGS.is_train:
savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
else:
savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')
with h5py.File(savepath, 'w') as hf:
hf.create_dataset('data', data=data)
hf.create_dataset('label', data=label)
def imread(path, is_grayscale=True):#目的:读取指定路径的图像
"""
Read image using its path.
Default value is gray-scale, and image is read by YCbCr format as the paper said.
"""
#读指定路径的图像
if is_grayscale:
return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
else:
return scipy.misc.imread(path, mode='YCbCr').astype(np.float)
def modcrop(image, scale=3):
#把图像的长和宽都变成scale的倍数
"""
To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
We need to find modulo of height (and width) and scale factor.
Then, subtract the modulo from height (and width) of original image size.
There would be no remainder even after scaling operation.
"""
if len(image.shape) == 3:
h, w, _ = image.shape
h = h - np.mod(h, scale)
w = w - np.mod(w, scale)
image = image[0:h, 0:w, :]
else:
h, w = image.shape
h = h - np.mod(h, scale)
w = w - np.mod(w, scale)
image = image[0:h, 0:w]
return image
#把result变为和origin一样的大小
def input_setup(sess, config):#功能:读取train set or test set ;做sub-images;保存成h5文件
"""
Read image files and make their sub-images and saved them as a h5 file format.
"""
#global nx#后加
#global ny#后加
#读图像集,制作子图并保存为h5文件格式
# 读取数据路径
# Load data path
if config.is_train:
data = prepare_data(sess, dataset="Train")
else:
data = prepare_data(sess, dataset="Test")
sub_input_sequence = []
sub_label_sequence = []
padding = abs(config.image_size - config.label_size) / 2 # 6
#padding=0;#修改padding值,测试效果
#训练
if config.is_train:
for i in xrange(len(data)):#一幅图作为一个data
input_, label_ = preprocess(data[i], config.scale)
#得到data[]的LR和HR图input_和label_
if len(input_.shape) == 3:
if len(input_.shape) == 3:
h, w, _ = input_.shape
else:
h, w = input_.shape
#把input_和label_分割成若干自图sub_input和sub_label
for x in range(0, h-config.image_size+1, config.stride):
for y in range(0, w-config.image_size+1, config.stride):
sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]
# Make channel value
sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
#按image size大小重排 因此 imgae_size应为33 而label_size应为21
sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
sub_input_sequence.append(sub_input)
#在sub_input_sequence末尾加sub_input中元素 但考虑为空
sub_label_sequence.append(sub_label)
sub_label_sequence.append(sub_label)
else:
#测试
input_, label_ = preprocess(data[2], config.scale)#测试图片
if len(input_.shape) == 3:
h, w, _ = input_.shape
else:
h, w = input_.shape
# Numbers of sub-images in height and width of image are needed to compute merge operation.
nx = ny = 0
#自图需要进行合并操作
for x in range(0, h-config.image_size+1, config.stride):#x从0到h-33+1 步长stride(21)
nx += 1; ny = 0
for y in range(0, w-config.image_size+1, config.stride):#y从0到w-33+1 步长stride(21)
ny += 1
sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]
sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
sub_input_sequence.append(sub_input)
sub_label_sequence.append(sub_label)
"""
len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
(sub_input_sequence[0]).shape : (33, 33, 1)
"""
# Make list to numpy array. With this transform
# 上面的部分和训练是一样的
arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]
make_data(sess, arrdata, arrlabel)
if not config.is_train:#存成h5格式
return nx, ny
def imsave(image, path):
return scipy.misc.imsave(path, image)
def merge(images, size):
h, w = images.shape[1], images.shape[2]#觉得下标应该是0,1
img = np.zeros((h*size[0], w*size[1], 1))
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
img[j*h:j*h+h, i*w:i*w+w, :] = image
return img
utils.py说明了就是一个函数池,注意下面函数就可以:
- prepare_data(sess,dataset):返回data,data是训练集或测试集中bmp格式的图像。
- input_setup(sess,config):读取train set or test set ;做sub-images;保存成h5文件。
- read_data(path):读取.h5文件的data和label数据,转化np.array格式。
- preprocess(path,scale=3):(1)读取灰度图像;(2)modcrop;(3)归一化;(4)两次bicubic interpolation,返回input_ ,label_。即对路径下的image裁剪成scale整数倍,再对image缩小1/scale倍后,放大scale倍以得到低分辨率图input_,调整尺寸后的image为高分辨率图label_。
- make_data(sess,data,label):将data保存为h5格式的数据,保存到指定路径,是通过create_dataset函数写入的。
- imread(path,is_grayscale=True):读取指定路径的图像。
- modcrop(image, scale=3) #把图像的长和宽都变成scale的倍数。
- modcrop_small(image) #把result变为和origin一样的大小(需要自己写或参考其他)。
- imsave(image,path):将scipy.misc.imsave封装到imsave供自己使用。
- merge(image,size):合并分割后的图片。
到这里差不多,代码解读基本完成。相信你看完之后也可以自己完成运行测试啦!
- 【6】最后,再附一个项目运行基本流程:
- 准备数据集(训练集、测试集);
- 训练模型
- 利用模型测试数据
- 模型评价
来源:CSDN
作者:yunxiaoMr
链接:https://blog.csdn.net/weixin_41297324/article/details/104043845