9.正则化

爷,独闯天下 提交于 2019-11-30 17:01:27
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST dataset with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

# Number of examples per mini-batch.
batch_size = 64
# Number of full mini-batches in one pass over the training set.
n_batch = mnist.train.num_examples // batch_size

# Placeholders for the 784-value inputs, the 10-class one-hot labels and
# the dropout keep probability.
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob=tf.placeholder(tf.float32)

# Fully connected network 784-1000-500-10: two tanh hidden layers, each
# followed by dropout, then a linear output layer.
W1 = tf.Variable(tf.truncated_normal([784,1000],stddev=0.1))
b1 = tf.Variable(tf.zeros([1000])+0.1)
L1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
L1_drop = tf.nn.dropout(L1,keep_prob)

W2 = tf.Variable(tf.truncated_normal([1000,500],stddev=0.1))
b2 = tf.Variable(tf.zeros([500])+0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
L2_drop = tf.nn.dropout(L2,keep_prob)

W3 = tf.Variable(tf.truncated_normal([500,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])+0.1)
# Keep the unscaled output-layer activations: the loss below needs them.
logits = tf.matmul(L2_drop,W3)+b3
# Class probabilities, used only for the accuracy metric.
prediction = tf.nn.softmax(logits)

# L2 regularisation term summed over every weight matrix and bias vector.
l2_loss = tf.nn.l2_loss(W1) + tf.nn.l2_loss(b1) + tf.nn.l2_loss(W2) + tf.nn.l2_loss(b2) + tf.nn.l2_loss(W3) + tf.nn.l2_loss(b3)

# Cross-entropy plus the weighted L2 penalty.
# tf.losses.softmax_cross_entropy applies softmax itself, so it must be given
# `logits`; passing `prediction` would apply softmax twice and distort the
# loss and its gradients.
loss = tf.losses.softmax_cross_entropy(y,logits) + 0.0005*l2_loss
# Plain gradient descent.
# NOTE(review): a learning rate of 0.5 is fairly aggressive for SGD on a
# raw-logit cross-entropy - confirm that training still converges well.
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

# Op that initialises all variables.
init = tf.global_variables_initializer()

# Boolean vector: True where the predicted class equals the labelled class
# (argmax returns the index of the largest value along axis 1).
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# Accuracy is the mean of the boolean vector cast to float.
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            # keep_prob of 1.0 keeps every unit, so dropout is inactive here
            # and the L2 penalty is the only regulariser in effect.
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:1.0})

        # Report accuracy on the full test and training sets after each epoch.
        test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        train_acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels,keep_prob:1.0})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) +",Training Accuracy " + str(train_acc))
运行结果(每个周期结束后在测试集和训练集上的准确率):
Iter 0,Testing Accuracy 0.9451,Training Accuracy 0.94643635
Iter 1,Testing Accuracy 0.9529,Training Accuracy 0.9566909
Iter 2,Testing Accuracy 0.96,Training Accuracy 0.96574545
Iter 3,Testing Accuracy 0.9608,Training Accuracy 0.9655455
Iter 4,Testing Accuracy 0.9644,Training Accuracy 0.96776366
Iter 5,Testing Accuracy 0.9644,Training Accuracy 0.96772724
Iter 6,Testing Accuracy 0.9612,Training Accuracy 0.9637455
Iter 7,Testing Accuracy 0.9647,Training Accuracy 0.96952724
Iter 8,Testing Accuracy 0.9635,Training Accuracy 0.9685091
Iter 9,Testing Accuracy 0.9655,Training Accuracy 0.97016364
Iter 10,Testing Accuracy 0.9631,Training Accuracy 0.96703637
Iter 11,Testing Accuracy 0.9649,Training Accuracy 0.96965456
Iter 12,Testing Accuracy 0.9673,Training Accuracy 0.9712909
Iter 13,Testing Accuracy 0.9669,Training Accuracy 0.97174543
Iter 14,Testing Accuracy 0.9644,Training Accuracy 0.9681818
Iter 15,Testing Accuracy 0.9657,Training Accuracy 0.9709273
Iter 16,Testing Accuracy 0.9655,Training Accuracy 0.97154546
Iter 17,Testing Accuracy 0.966,Training Accuracy 0.9701818
Iter 18,Testing Accuracy 0.9635,Training Accuracy 0.96852726
Iter 19,Testing Accuracy 0.9665,Training Accuracy 0.9719818
Iter 20,Testing Accuracy 0.9679,Training Accuracy 0.9732909
Iter 21,Testing Accuracy 0.9683,Training Accuracy 0.9747273
Iter 22,Testing Accuracy 0.9664,Training Accuracy 0.9724
Iter 23,Testing Accuracy 0.9684,Training Accuracy 0.97367275
Iter 24,Testing Accuracy 0.9666,Training Accuracy 0.9719091
Iter 25,Testing Accuracy 0.9655,Training Accuracy 0.97212726
Iter 26,Testing Accuracy 0.9682,Training Accuracy 0.9728
Iter 27,Testing Accuracy 0.9676,Training Accuracy 0.97221816
Iter 28,Testing Accuracy 0.9669,Training Accuracy 0.97238183
Iter 29,Testing Accuracy 0.9675,Training Accuracy 0.97327274
Iter 30,Testing Accuracy 0.9665,Training Accuracy 0.9725091
 
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!