Question
I am doing multi-class classification with an RNN, and here is the main code for the RNN:
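import tensorflow as tf
from tensorflow.contrib import rnn  # TensorFlow 1.x

# X, Y, weights, biases, input_size, num_unit, lstm_size and learning_rate
# are defined elsewhere in the script.
def RNN(x, weights, biases):
    # Split x into a list of input_size tensors along axis 1 (time), as static_rnn expects.
    x = tf.unstack(x, input_size, 1)
    # Build a separate LSTM cell per layer: [lstm_cell] * lstm_size reuses one
    # cell object for every layer, which TF 1.1+ rejects with a variable-reuse error.
    stacked_lstm = rnn.MultiRNNCell(
        [rnn.BasicLSTMCell(num_unit, forget_bias=1.0, state_is_tuple=True)
         for _ in range(lstm_size)],
        state_is_tuple=True)
    outputs, states = tf.nn.static_rnn(stacked_lstm, x, dtype=tf.float32)
    # Classify from the output of the last time step.
    return tf.matmul(outputs[-1], weights) + biases

logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(cost)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))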
I have to classify all inputs into 6 classes, and each class has a one-hot label as follows:
happy = [1, 0, 0, 0, 0, 0]
angry = [0, 1, 0, 0, 0, 0]
neutral = [0, 0, 1, 0, 0, 0]
excited = [0, 0, 0, 1, 0, 0]
embarrassed = [0, 0, 0, 0, 1, 0]
sad = [0, 0, 0, 0, 0, 1]
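These six labels are the rows of a 6×6 identity matrix, so each class is fully identified by the position of its 1, and that position is the integer index tf.confusion_matrix needs. A minimal plain-Python sketch of the mapping, assuming the class order shown above (CLASS_NAMES, one_hot and class_name are hypothetical names for illustration, not part of TensorFlow):

```python
# Class order taken from the one-hot codes above: position i holds the 1.
CLASS_NAMES = ["happy", "angry", "neutral", "excited", "embarrassed", "sad"]

def one_hot(name):
    # Build the one-hot vector for a class name, e.g. "neutral" -> [0, 0, 1, 0, 0, 0].
    vec = [0] * len(CLASS_NAMES)
    vec[CLASS_NAMES.index(name)] = 1
    return vec

def class_name(one_hot_vec):
    # Inverse mapping: the position of the 1 is the integer class index.
    return CLASS_NAMES[one_hot_vec.index(1)]

print(one_hot("excited"))              # [0, 0, 0, 1, 0, 0]
print(class_name([0, 0, 0, 0, 0, 1]))  # sad
```

The same CLASS_NAMES list can later be used to label the rows and columns of a printed confusion matrix.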
The problem is that I cannot print a confusion matrix with the tf.confusion_matrix() function.
Is there any way to print a confusion matrix using these labels?
If not, how can I convert the one-hot codes to integer indices only when I need to print the confusion matrix?
Answer 1:
You cannot generate a confusion matrix by passing one-hot vectors as the labels and predictions arguments. You have to supply 1-D tensors that contain the class indices directly.
To convert your one-hot vectors to integer labels, use the argmax function:
label = tf.argmax(one_hot_tensor, axis=1)
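tf.argmax(..., axis=1) returns, for every row, the index of its largest entry. For a one-hot row that is the position of the 1; for a row of softmax probabilities it is the most likely class. A plain-Python sketch of the same row-wise operation (argmax_rows is a hypothetical helper for illustration, not a TensorFlow function):

```python
# Row-wise argmax: for each row, return the index of its largest entry.
def argmax_rows(rows):
    return [max(range(len(row)), key=lambda i: row[i]) for row in rows]

# One-hot labels in the question's class order (happy, angry, neutral, excited, embarrassed, sad).
one_hot_labels = [
    [1, 0, 0, 0, 0, 0],  # happy
    [0, 0, 0, 0, 0, 1],  # sad
    [0, 0, 1, 0, 0, 0],  # neutral
]
# Hypothetical softmax output for one example.
predicted_probs = [[0.10, 0.60, 0.10, 0.10, 0.05, 0.05]]

print(argmax_rows(one_hot_labels))   # [0, 5, 2]
print(argmax_rows(predicted_probs))  # [1]
```

The same conversion applies to both sides, just as in the question's correct_pred: argmax of the one-hot labels Y gives the true indices, and argmax of the softmax output gives the predicted indices.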
After that you can print your confusion matrix like this:
import tensorflow as tf

num_classes = 2
prediction_arr = tf.constant([1, 1, 1, 1, 0, 0, 0, 0, 1, 1])
labels_arr = tf.constant([0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
confusion_matrix = tf.confusion_matrix(labels_arr, prediction_arr, num_classes)
with tf.Session() as sess:
    print(confusion_matrix.eval())
Output:
[[0 3]
 [4 3]]
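tf.confusion_matrix puts the true labels on the rows and the predictions on the columns, adding 1 at [label, prediction] for each pair. A minimal plain-Python sketch, assuming integer class indices in range(num_classes), that reproduces the output above and then applies the same idea to the question's six classes using made-up example data (confusion_matrix and argmax_rows here are hypothetical helpers, not the TensorFlow API):

```python
# Rows = true labels, columns = predicted labels; each pair adds 1.
def confusion_matrix(labels, predictions, num_classes):
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_idx, pred_idx in zip(labels, predictions):
        matrix[true_idx][pred_idx] += 1
    return matrix

# Same data as the TensorFlow example above.
labels_arr = [0, 1, 1, 1, 1, 1, 1, 1, 0, 0]
prediction_arr = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1]
print(confusion_matrix(labels_arr, prediction_arr, 2))  # [[0, 3], [4, 3]]

# Six-class case: convert one-hot labels and softmax-like outputs to indices first.
def argmax_rows(rows):
    return [max(range(len(row)), key=lambda i: row[i]) for row in rows]

one_hot_labels = [
    [1, 0, 0, 0, 0, 0],  # happy
    [0, 1, 0, 0, 0, 0],  # angry
    [0, 0, 0, 0, 0, 1],  # sad
]
predicted_probs = [  # hypothetical model outputs
    [0.70, 0.10, 0.05, 0.05, 0.05, 0.05],  # predicted happy
    [0.20, 0.30, 0.40, 0.04, 0.03, 0.03],  # predicted neutral
    [0.05, 0.05, 0.05, 0.05, 0.10, 0.70],  # predicted sad
]
six_class = confusion_matrix(argmax_rows(one_hot_labels), argmax_rows(predicted_probs), 6)
for row in six_class:
    print(row)
```

In the question's TensorFlow 1.x graph, the analogous call would presumably be tf.confusion_matrix(tf.argmax(Y, 1), tf.argmax(prediction, 1), num_classes=6), evaluated with sess.run and the same feed_dict used for accuracy; this is an untested suggestion.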
Source: https://stackoverflow.com/questions/46810573/tensorflow-confusion-matrix-using-one-hot-code