转载请注明出处:http://www.cnblogs.com/willnote/p/6758953.html
官方API定义
tf.argmax(input, axis=None, name=None, dimension=None)
Returns the index with the largest value across axes of a tensor.
Args:
- input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half.
- axis: A Tensor. Must be one of the following types: int32, int64. int32, 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
- name: A name for the operation (optional).
Returns:
- A Tensor of type int64.
关于axis
定义中的axis与numpy中的axis是一致的,下面通过代码进行解释
import numpy as np import tensorflow as tf sess = tf.session() m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) ) print type(m) print m ------------------------------------------------------------------------------- <type 'numpy.ndarray'> [[ 0.09957541 -0.0965599 0.06064715 -0.03011306 0.05533558 0.17263047 -0.02660419 0.08313394 -0.07225946 0.04916157] [ 0.11304571 0.02099175 0.03591062 0.01287777 -0.11302195 0.04822164 -0.06853487 0.0800944 -0.1155676 -0.01168544] [ 0.15760773 0.05613248 0.04839646 -0.0218203 0.02233066 0.00929849 -0.0942843 -0.05943 0.08726917 -0.059653 ] [ 0.02553608 0.07298559 -0.06958302 0.02948747 0.00232073 0.11875584 -0.08325859 -0.06616175 0.15124641 0.09522969] [-0.04616683 0.01816062 -0.10866459 -0.12478453 0.01195056 0.0580056 -0.08500613 0.00635608 -0.00108647 0.12054099]]
m是一个5行10列的矩阵,类型为numpy.ndarray
#使用tensorflow中的tf.argmax() col_max = sess.run(tf.argmax(m, 0) ) #当axis=0时返回每一列的最大值的位置索引 print col_max row_max = sess.run(tf.argmax(m, 1) ) #当axis=1时返回每一行中的最大值的位置索引 print row_max array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4]) array([5, 0, 0, 8, 9]) ------------------------------------------------------------------------------- #使用numpy中的numpy.argmax row_max = m.argmax(0) print row_max col_max = m.argmax(1) print col_max array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4]) array([5, 0, 0, 8, 9])
可以看到tf.argmax()与numpy.argmax()方法的用法是一致的
- axis = 0的时候返回每一列最大值的位置索引
- axis = 1的时候返回每一行最大值的位置索引
- axis = 2、3、4...,即为多维张量时,同理推断
参考
来源:https://www.cnblogs.com/willnote/p/6758953.html