1.Introduction
利用 MNIST 数据集训练 DCGAN 网络,生成手写数字图像。
2.Source code
#encoding:utf-8
""" Deep Convolutional Generative Adversarial Network (DCGAN).
Using deep convolutional generative adversarial networks (DCGAN) to generate
digit images from a noise distribution.
References:
- Unsupervised representation learning with deep convolutional generative
adversarial networks. A Radford, L Metz, S Chintala. arXiv:1511.06434.
Links:
- [DCGAN Paper](https://arxiv.org/abs/1511.06434).
- [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""
from __future__ import division, print_function, absolute_import

import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import scipy.misc
import tensorflow as tf
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
# Training hyper-parameters.
batch_size = 32
num_steps = 100  # short demo run; the inline note in the original suggests 20000

# Network hyper-parameters.
noise_dim = 200  # length of each noise vector fed to the generator
image_dim = 28 * 28 * 1  # 784 values per image: 28x28 pixels, 1 channel
gen_hidden_dim = 256
disc_hidden_dim = 256

# Directory that receives the TensorBoard event files.
log_dir = "mnist_logs"

# Download (if needed) and load MNIST; data files are cached under /tmp/data/.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
# Generator Network
# Input: Noise, Output: Image
def generator(x, reuse=False):
    """Turn a batch of noise vectors into a batch of 28x28x1 images.

    Args:
        x: 2-D noise tensor, one row per sample.
        reuse: Forwarded to ``tf.variable_scope`` so a second call can share
            the variables created under the 'Generator' scope.

    Returns:
        A tensor of shape (batch, 28, 28, 1) whose values lie in (0, 1),
        because the last op is a sigmoid.
    """
    with tf.variable_scope('Generator', reuse=reuse):
        # Fully connected projection to 6 * 6 * 128 = 4608 units, squashed
        # with the hyperbolic tangent. TensorFlow Layers create the variables
        # and infer their shapes from the input.
        projected = tf.nn.tanh(tf.layers.dense(x, units=6 * 6 * 128))
        # View the flat activations as a 4-D batch of feature maps:
        # (batch, height, width, channels) = (batch, 6, 6, 128).
        feature_maps = tf.reshape(projected, shape=[-1, 6, 6, 128])
        # First deconvolution: (batch, 6, 6, 128) -> (batch, 14, 14, 64).
        upsampled = tf.layers.conv2d_transpose(feature_maps, 64, 4, strides=2)
        # Second deconvolution: (batch, 14, 14, 64) -> (batch, 28, 28, 1).
        image_logits = tf.layers.conv2d_transpose(upsampled, 1, 2, strides=2)
        # Sigmoid maps the pixel logits into the (0, 1) range.
        return tf.nn.sigmoid(image_logits)
# Discriminator Network
# Input: Image, Output: Prediction Real/Fake Image
def discriminator(x, reuse=False):
    """Score a batch of images as real or fake.

    Shape comments below assume the [None, 28, 28, 1] input used by the
    graph-building code in this file.

    Args:
        x: 4-D image tensor (batch, height, width, channels).
        reuse: Forwarded to ``tf.variable_scope`` so later calls share the
            variables created under the 'Discriminator' scope.

    Returns:
        A [None, 2] tensor of unnormalised logits, one per class
        (real / fake).
    """
    with tf.variable_scope('Discriminator', reuse=reuse):
        # Typical convolutional classifier: two conv/tanh/avg-pool stages
        # followed by two dense layers.
        conv1 = tf.nn.tanh(tf.layers.conv2d(x, 64, 5))       # [None, 24, 24, 64]
        pool1 = tf.layers.average_pooling2d(conv1, 2, 2)     # [None, 12, 12, 64]
        conv2 = tf.nn.tanh(tf.layers.conv2d(pool1, 128, 5))  # [None, 8, 8, 128]
        pool2 = tf.layers.average_pooling2d(conv2, 2, 2)     # [None, 4, 4, 128]
        flat = tf.contrib.layers.flatten(pool2)              # [None, 2048] (4*4*128)
        hidden = tf.nn.tanh(tf.layers.dense(flat, 1024))     # [None, 1024]
        # Output 2 classes: Real and Fake images.
        return tf.layers.dense(hidden, 2)                    # [None, 2]
# Build Networks
# Network Inputs
noise_input = tf.placeholder(tf.float32, shape=[None, noise_dim])
real_image_input = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])

# Build Generator Network
gen_sample = generator(noise_input)  # shape: [None, 28, 28, 1]

# Build 2 Discriminator Networks sharing one set of weights
# (one fed with real images, one fed with generated samples).
disc_real = discriminator(real_image_input)        # shape: [None, 2]
disc_fake = discriminator(gen_sample, reuse=True)  # shape: [None, 2]
# Real logits first, fake logits second; the targets fed for disc_target
# must follow the same order.
disc_concat = tf.concat([disc_real, disc_fake], axis=0)  # shape: [2*None, 2]

# Stacked generator/discriminator. This is exactly the discriminator applied
# to the generated samples, so reuse disc_fake instead of building a third,
# identical discriminator tower over the same input and weights.
stacked_gan = disc_fake

# Build Targets (real or fake images): integer class ids, one per sample.
disc_target = tf.placeholder(tf.int32, shape=[None])
gen_target = tf.placeholder(tf.int32, shape=[None])

# Build Loss: mean softmax cross-entropy over the 2-class logits.
disc_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=disc_concat, labels=disc_target))
gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=stacked_gan, labels=gen_target))

# Build Optimizers
optimizer_gen = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=0.001)

# Training Variables for each optimizer
# By default in TensorFlow, all variables are updated by each optimizer, so
# each optimizer is given the specific list of variables it may update.
# Generator Network Variables
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')
# Discriminator Network Variables
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')

# Create training operations
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
# Start training
with tf.Session() as sess:
    # Dump the graph definition to <log_dir>/train for TensorBoard.
    train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)

    # Run the initializer
    sess.run(init)

    # Prepare Targets (Real image: 1, Fake image: 0).
    # The first half of the logits scored by the discriminator come from real
    # images, the other half from generated images. The targets are the same
    # at every step, so they are built once outside the loop.
    batch_disc_y = np.concatenate(
        [np.ones([batch_size]), np.zeros([batch_size])], axis=0)
    # Generator tries to fool the discriminator, thus targets are 1.
    batch_gen_y = np.ones([batch_size])

    for step in range(1, num_steps + 1):
        # Get the next batch of MNIST data (only images are needed, not
        # labels) and restore the 28x28x1 image layout.
        batch_x, _ = mnist.train.next_batch(batch_size)
        batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])
        # Generate noise to feed to the generator
        z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])

        # One joint update of generator and discriminator.
        feed_dict = {real_image_input: batch_x, noise_input: z,
                     disc_target: batch_disc_y, gen_target: batch_gen_y}
        _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                                feed_dict=feed_dict)
        if step % 100 == 0 or step == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (step, gl, dl))

    # Flush and release the event file now that the graph has been recorded.
    train_writer.close()

    # Make sure the output directory exists before saving any image.
    out_dir = './gen_samp/'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Generate images from noise, using the generator network:
    # a 4 x 10 grid, 4 fresh samples per column.
    f, a = plt.subplots(4, 10, figsize=(10, 4))
    for col in range(10):
        # Noise input.
        z = np.random.uniform(-1., 1., size=[4, noise_dim])
        g = sess.run(gen_sample, feed_dict={noise_input: z})
        print('g.size: ', g.size)
        for row in range(4):
            # Extend the single grey channel to 3 channels for matplotlib.
            img = np.reshape(np.repeat(g[row][:, :, np.newaxis], 3, axis=2),
                             newshape=(28, 28, 3))
            a[row][col].imshow(img)
            print("save image")
            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; on a
            # newer SciPy this call needs a replacement (e.g. via PIL).
            scipy.misc.imsave(out_dir + str(col) + str(row) + '.jpg', img)
    f.show()
    plt.draw()
    plt.waitforbuttonpress()
来源:CSDN
作者:Rosun_
链接:https://blog.csdn.net/Mr_KkTian/article/details/77967201