"""Train a 784-1000-500-10 tanh MLP with dropout on MNIST (TensorFlow 1.x graph API).

After every epoch the script prints the accuracy on the full test set and on
the full training set, both evaluated with dropout disabled (keep_prob=1.0).
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset with one-hot encoded labels (10 classes).
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each mini-batch.
batch_size = 64
# Number of mini-batches drawn per epoch (floor division).
n_batch = mnist.train.num_examples // batch_size
# Number of passes over the training data.
n_epochs = 31
# Step size for plain gradient descent.
learning_rate = 0.5

# Placeholders: 784-pixel input images, one-hot labels, dropout keep probability.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)

# Network: 784 -> 1000 -> 500 -> 10, tanh hidden layers each followed by dropout.
W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
b1 = tf.Variable(tf.zeros([1000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
b2 = tf.Variable(tf.zeros([500]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
# Keep the raw (pre-softmax) scores separately: the loss below needs logits.
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy loss. softmax_cross_entropy applies softmax to its `logits`
# argument internally, so it must receive the raw scores. Passing `prediction`
# (already softmaxed) would apply softmax twice, distorting the loss and
# weakening its gradients.
loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits)
# NOTE(review): learning_rate=0.5 was picked while the loss was double-softmaxed
# (smaller gradients); it may need retuning with the corrected loss - confirm.
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

init = tf.global_variables_initializer()

# Per-example boolean: does the predicted class (argmax) equal the label's class?
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = fraction of True entries.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.5})

        # Evaluate with dropout disabled.
        test_acc = sess.run(
            accuracy,
            feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0},
        )
        train_acc = sess.run(
            accuracy,
            feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0},
        )
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc)
              + ",Training Accuracy " + str(train_acc))
# Sample output:
# Extracting MNIST_data\train-images-idx3-ubyte.gz
# Extracting MNIST_data\train-labels-idx1-ubyte.gz
# Extracting MNIST_data\t10k-images-idx3-ubyte.gz
# Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
# Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
# Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
# Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
# Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
# Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
# Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
# Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
# Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
# Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
# Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
# Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
# Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
# Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
# Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
# Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
# Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
# Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
# Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
# Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
# Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
# Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
# Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
# Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
# Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
# Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
# Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
# Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
# Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
# Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
# Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
# Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454