10.优化器

╄→гoц情女王★ 提交于 2019-11-30 17:01:27
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

#每个批次的大小
batch_size = 64
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

#定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#创建一个简单的神经网络
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)

#交叉熵代价函数
# loss = tf.losses.softmax_cross_entropy(y,prediction)
loss = tf.losses.mean_squared_error(y,prediction)
#使用梯度下降法
# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)

#初始化变量
init = tf.global_variables_initializer()

#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
Iter 0,Testing Accuracy 0.9106
Iter 1,Testing Accuracy 0.921
Iter 2,Testing Accuracy 0.9261
Iter 3,Testing Accuracy 0.9277
Iter 4,Testing Accuracy 0.9291
Iter 5,Testing Accuracy 0.9315
Iter 6,Testing Accuracy 0.9293
Iter 7,Testing Accuracy 0.9299
Iter 8,Testing Accuracy 0.9298
Iter 9,Testing Accuracy 0.9315
Iter 10,Testing Accuracy 0.9317
Iter 11,Testing Accuracy 0.9329
Iter 12,Testing Accuracy 0.9324
Iter 13,Testing Accuracy 0.9339
Iter 14,Testing Accuracy 0.9321
Iter 15,Testing Accuracy 0.9322
Iter 16,Testing Accuracy 0.934
Iter 17,Testing Accuracy 0.9326
Iter 18,Testing Accuracy 0.9331
Iter 19,Testing Accuracy 0.9334
Iter 20,Testing Accuracy 0.9334
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!