[TensorFlow] argmax, softmax_cross_entropy_with_logits, sparse_softmax_cross_entropy_with_logits函数详解

夙愿已清 提交于 2019-12-03 12:59:05

写在前面

tensorFlow版本:1.8.0

一、tf.argmax()

tf.argmax(
    input, 
    axis=None, 
    name=None, 
    dimension=None, 
    output_type=tf.int64
)

1、argmax()的作用是返回数组某一行或某一列最大数值所在的下标
2、input:输入矩阵
3、 axis:指定按行操作还是按列操作

此函数在使用时最重要的两个参数是input和axis,axis可选的取值是0和1。axis=0表示对矩阵按列进行argmax操作,axis=1表示对矩阵按行进行argmax操作。
举例如下:

import tensorflow as tf

data = tf.constant([[1, 5, 4], [2, 3, 6]], dtype=tf.float32)   # 定义一个2*3的矩阵
argmax_axis_0 = tf.argmax(data, 0)    # 按列
argmax_axis_1 = tf.argmax(data, 1)    # 按行
with tf.Session() as sess:
    value_axis_0 = sess.run(argmax_axis_0)
    value_axis_1 = sess.run(argmax_axis_1)
    print("value_axis_0: %s" % value_axis_0)
    print("value_axis_1: %s" % value_axis_1)

代码执行结果:

value_axis_0: [1 0 1]
value_axis_1: [1 2]

二、tf.nn.softmax_cross_entropy_with_logits()

tf.nn.softmax_cross_entropy_with_logits(
    _sentinel=None, 
    labels=None, 
    logits=None, 
    dim=-1, 
    name=None
)

1、tf.nn.softmax_cross_entropy_with_logits()函数的作用是先对神经网络的输出(logits)进行softmax操作转化成概率值,再将其与labels进行交叉熵计算。
2、labels:真实标签,使用one-hot编码表示。假设共n个样本,样本类别数为m,则输入矩阵的形状为(n, m)
3、logits:神经网络的输出,大小和labels一致。

(1)Softmax计算
softmax的计算公式如下:

softmax(xi)=exp(xi)jexp(xj)

其中xi 表示一个样本的输出。

代码示例:

import tensorflow as tf

# 输出样例,假设有2个样本,类别数为10,则输出结构:(2, 10)
logits = tf.constant([[1, 2, 3, 4, 2, 1, 0, 2, 1, 1], [1, 2, 4, 1, 0, 5, 0, 2, 1, 3]], dtype=tf.float32)
y = tf.nn.softmax(logits)
with tf.Session() as sess:
    print(sess.run(y))

输出结果:

[[0.02500168 0.0679616  0.1847388  0.5021721  0.0679616  0.02500168
  0.0091976  0.0679616  0.02500168 0.02500168]
 [0.0109595  0.029791   0.22012737 0.0109595  0.00403177 0.5983682
  0.00403177 0.029791   0.0109595  0.08098033]]

softmax处理之后的概率值表此样本属于某类别的概率,如0.02500168表示第一个样本属于第一个类别的概率是0.02500168。在最终判断具体是哪个类别时,应该取对应概率值最大的类别。

(2)tf.nn.softmax_cross_entropy_with_logits()
此函数便是先对输出结果进行softmax处理,得到概率值之后再将其与labels进行交叉熵计算。其计算公式为:

L=j=1Tyjlogsj

其中,
y表示真实标签,其维度为1*m。比如对于MNIST手写体识别的数据,标签有0~9共10种可能,则m=10,如[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]表示数字属于第4个类别,即数字3。yj 表示真实标签中的第i个的值。
sj 表示softmax输出向量的第j个值。

举例:
yj=[0,0,0,1,0,0,0,0,0,0]
sj=[0.02500168,0.0679616,0.1847388,0.5021721,0.0679616,0.02500168,0.0091976,0.0679616,0.02500168,0.02500168]
则,L=(0log(0.02500168)+0log(0.0679616)+0log(0.1847388)+1log(0.5021721)+....)=log(0.5021721)
因为当x<1时,logx<0且单调递增,因此,预测sj 越准确,L的值越小。

注意:我们对每一个样本都可以求出一个交叉熵,但是最后的损失值应该是一个数值,它是所有样本交叉熵的和,因此我们需要将所有样本的交叉熵求和,使用tf.reduce_sum()函数

代码举例:

import tensorflow as tf

# 输出样例,假设有2个样本,类别数为10,则输出结构:(2, 10)
logits = tf.constant([[1, 2, 3, 4, 2, 1, 0, 2, 1, 1], [1, 2, 4, 1, 0, 5, 0, 2, 1, 3]], dtype=tf.float32)

# 真实标签
y_ = tf.constant([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.float32)

# 使用softmax_cross_entropy_with_logits计算交叉熵值
loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))

with tf.Session() as sess:
    value = sess.run(loss)
    print("loss is: ", value)

输出结果:

loss is:  2.2023613

三、tf.nn.sparse_softmax_cross_entropy_with_logits()

tf.nn.sparse_softmax_cross_entropy_with_logits(
    _sentinel=None, 
    labels=None, 
    logits=None, 
    dim=-1, 
    name=None
)

1、tf.nn.sparse_softmax_cross_entropy_with_logits()功能和tf.nn.softmax_cross_entropy_with_logits()一样
2、logits的输入格式不变,形状仍然是(n, m),n表示样本数,m表示类别个数。
3、labels的输入格式改变,因在tf.nn.softmax_cross_entropy_with_logits()中,每一个样本都需要使用一个m维的向量表示,这m维的向量使用one-hot编码表示,有一位个1,其余都是0。而在tf.nn.sparse_softmax_cross_entropy_with_logits()中,直接使用类别值代替one-hot向量,比如使用one-hot表示一个样本的真实标签是[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],则在tf.nn.sparse_softmax_cross_entropy_with_logits()中用4表示这个样本的真实值。

代码举例:

import tensorflow as tf

# 输出样例,假设有2个样本,类别数为10,则输出结构:(2, 10)
logits = tf.constant([[1, 2, 3, 4, 2, 1, 0, 2, 1, 1], [1, 2, 4, 1, 0, 5, 0, 2, 1, 3]], dtype=tf.float32)

# 真实标签
y_ = tf.constant([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.float32)

# 使用softmax_cross_entropy_with_logits计算交叉熵值
loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=tf.argmax(y_, 1)))

with tf.Session() as sess:
    value = sess.run(loss)
    print("loss is: ", value)

输出结果:

loss is:  2.2023613

输出结果和softmax_cross_entropy_with_logits的输出结果一样。


参考文章:
https://blog.csdn.net/mao_xiao_feng/article/details/53382790
https://blog.csdn.net/qq575379110/article/details/70538051
https://blog.csdn.net/u014380165/article/details/77284921

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!