tensorflow中手写识别笔记

微笑、不失礼 提交于 2019-12-24 13:56:58

教程链接:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-c1ov28so.html

tensorflow 常用的函数:

# 导入tensorflow,使用tf代替

import tensorflow as tf

# 计算x,和w的乘积,这里计算x矩阵和w矩阵的乘积

tf.matmul(x, w)      

#  先计算labels和logits的交叉熵(区别),在对结果进行归一化处理,softmax参考

tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)

# 然后求交叉熵的平均值

cross_entrony = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

# 以梯度下降法,0.5的幅度,减小交叉熵

tf.train.GradientDescentOptimizer(0.5).minimize(cross_entrony)

# 初始化变量(tf,Variable())

tf.global_variables_initializer().run()

# 获取一行最大值的索引

tf.argmax(y, 1)

# 比较a和b对应位置是否是相同的,返回结果是bool类型

tf.equal(a, b)

# 把x的值转化为另一种y类型

tf.cast(x, y)

代码整体解读:

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  # 下载数据集,存储在/home/msl/Downloads

# 构建回归模型
x = tf.placeholder(tf.float32, [None, 784])  # None * 784   测试集[60000, 784]
w = tf.Variable(tf.zeros([784, 10]))  # 784 * 10            和每个像素相乘,得到[None, 10],即为labels
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, w) + b  # 预测值

# 使用梯度下降法最小化交叉熵

y_ = tf.placeholder(tf.float32, [None, 10])
cross_entrony = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))  # 计算预测值和真实值的区别,并求均值
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entrony)

# 初始化变量
# init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 开始训练
old = time.time()
with tf.device("/gpu:0"):   # 使用gpu为:/gpu:0
    for i in range(1000):
        batch_xs, batch_xy = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_xy})
print(time.time() - old)

# 评估模型
# tf.argmax(y, 1)返回y中每行的最大值的索引
# tf.equal(x, y)判断x和y的值是否一致,返回值为bool类型
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # tf.cast(a, b)把a转化为b类型, 再求平均值
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # 使用测试集评估模型

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!