TensorFlow: error from tf.argmax when computing accuracy

匿名 (未验证) 提交于 2019-12-03 02:38:01

问题:

I am running a simple script to measure accuracy after training. I load the model I saved, but when I try to compute the accuracy I get an error. Why?

# coding=utf-8
from color_1 import read_and_decode, get_batch, get_test_batch
import AlexNet
import cv2
import os
import time
import numpy as np
import tensorflow as tf
import AlexNet_train
import math

batch_size = 128
num_examples = 1000
crop_size = 56


def evaluate(test_x, test_y):
    """Restore the latest checkpoint and print the accuracy on one test batch.

    Args:
        test_x: tensor producing a batch of images; it is evaluated with
            ``sess.run`` and fed into a float32 placeholder of shape
            ``[batch_size, crop_size, crop_size, 3]``.
        test_y: tensor producing the matching labels; it is fed into an int32
            placeholder of shape ``[batch_size]``, i.e. one integer class
            index per example (not one-hot vectors).
    """
    image_holder = tf.placeholder(tf.float32,
                                  [batch_size, crop_size, crop_size, 3],
                                  name='x-input')
    label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')

    # NOTE(review): the second argument is the function object `evaluate`
    # itself, which is always truthy. If AlexNet.inference treats this
    # argument as a "training" flag (e.g. to enable dropout), evaluation runs
    # in training mode -- confirm against AlexNet.inference and pass False
    # if so.
    y = AlexNet.inference(image_holder, evaluate, None)

    # label_holder is rank 1 (one class index per example), so it must NOT go
    # through tf.argmax(label_holder, 1): axis 1 does not exist and TensorFlow
    # raises "Expected dimension in the range [-1, 1), but got 1". Compare the
    # predicted class with the label directly instead. tf.argmax returns
    # int64, so cast the int32 labels to match before tf.equal.
    correct_prediction = tf.equal(tf.argmax(y, 1),
                                  tf.cast(label_holder, tf.int64))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        # Start the queue runners so that test_x / test_y can be evaluated.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            ckpt = tf.train.get_checkpoint_state(AlexNet_train.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                # Assumes the checkpoint name ends in "-<global_step>"
                # (e.g. saved via saver.save(..., global_step=...)).
                global_step = ckpt_name.split('-')[-1]
                saver.restore(sess, os.path.join(AlexNet_train.MODEL_SAVE_PATH,
                                                 ckpt_name))
                print('Loading success, global_step is %s' % global_step)

                image_batch, label_batch = sess.run([test_x, test_y])
                accuracy_score = sess.run(accuracy,
                                          feed_dict={image_holder: image_batch,
                                                     label_holder: label_batch})
                print("After %s training step(s),validation "
                      "precision=%g" % (global_step, accuracy_score))
            else:
                # Previously this case was silently ignored.
                print('No checkpoint file found')
        finally:
            # Always stop and join the queue-runner threads, even when
            # restoring or evaluating raised, so the process can exit cleanly.
            coord.request_stop()
            coord.join(threads)


def main(argv=None):
    """Build the test input pipeline from 'val.tfrecords' and evaluate it."""
    test_image, test_label = read_and_decode('val.tfrecords')
    test_images, test_labels = get_test_batch(test_image, test_label,
                                              batch_size, crop_size)
    evaluate(test_images, test_labels)


if __name__ == '__main__':
    tf.app.run()

Here is the error. It points to this line of my code: "correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))"

Traceback (most recent call last):
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
    evaluate(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 45, in evaluate
    label_holder: label_batch})
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected dimension in the range [-1, 1), but got 1
     [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]

Caused by op u'ArgMax_1', defined at:
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
    evaluate(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 22, in evaluate
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 263, in argmax
    return gen_math_ops.arg_max(input, axis, name)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 168, in arg_max
    name=name)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Expected dimension in the range [-1, 1), but got 1
     [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]

How can I fix it?

回答1:

Quoting the part of another answer that is relevant to this problem:

tf.argmax's definition states:

axis: A Tensor. Must be one of the following types: int32, int64. int32, 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across.

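Because of the "strictly less than" sign in that definition, `axis` must be smaller than the rank of the input. A negative value such as `axis=-1` counts from the end and always selects the last axis.

In your code, `label_holder` has shape `[batch_size]`, which is rank 1. So `tf.argmax(label_holder, 1)` asks for an axis that does not exist, and that is exactly what "Expected dimension in the range [-1, 1), but got 1" means.

Switching to `axis=-1` would avoid this particular error, but it is not the fix you want. On a rank-1 tensor, it reduces over the whole batch and returns a single index.

Your labels are already integer class indices, not one-hot vectors, so they should not go through `argmax` at all. Compare them with the predicted class directly, casting to int64 because `tf.argmax` returns int64:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.cast(label_holder, tf.int64))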



易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!