Unable to read from Tensorflow tfrecord file

試著忘記壹切 提交于 2019-12-11 12:15:09

问题


I can create the TFRecord file with the code below:

def _int64_feature(value):
    """Wrap a single integer *value* in a ``tf.train.Feature`` (Int64List)."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a single byte string *value* in a ``tf.train.Feature`` (BytesList)."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

def convert_to_tfrecord(images, labels, file_name):
    """Write images and labels to a TFRecord file, one ``tf.train.Example`` per image.

    Args:
        images: numpy array of shape (num_images, channel, rows, column).
            Each image is cast to float32 and serialized as raw bytes.
        labels: numpy array of shape (num_images,); each entry is converted
            with ``int()`` and stored as an int64 feature.
        file_name: path of the TFRecord file to write.

    Each Example carries the features 'height' (rows), 'width' (cols),
    'depth' (channel), 'label' and 'image_raw'.
    """
    num_images, depth, rows, cols = np.shape(images)
    # The context manager closes the writer even if an exception is raised
    # part-way through the loop, so the file handle is never leaked.
    with tf.python_io.TFRecordWriter(file_name) as writer:
        for index in range(num_images):
            # tobytes() is the supported spelling of the deprecated tostring()
            # (removed in NumPy 2.0); it yields the same raw bytes, which the
            # reader decodes back with tf.decode_raw(..., tf.float32).
            image_raw = images[index].astype(np.float32).tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())

However, when I read the data back from the TFRecord file with the functions below,

def read_and_decode(filename_queue, image_shape=None):
    """Read one serialized Example from *filename_queue* and decode it.

    Args:
        filename_queue: queue of TFRecord file names, e.g. the result of
            ``tf.train.string_input_producer``.
        image_shape: optional static image shape as Python ints, in the same
            order the record's dimensions are stacked below
            (depth, height, width). ``tf.decode_raw`` returns a tensor of
            unknown length, and reshaping it with the per-record shape tensors
            does not give it a fully defined static shape. Batching ops such
            as ``tf.train.shuffle_batch`` then fail with "All shapes must be
            fully defined". Passing the known shape here avoids that.

    Returns:
        (image, label): a float32 image tensor reshaped to
        [depth, height, width] and an int32 scalar label.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    # Decode with the same dtype the images were serialized with.
    image = tf.decode_raw(img_features['image_raw'], tf.float32)
    label = tf.cast(img_features['label'], tf.int32)
    height = tf.cast(img_features['height'], tf.int32)
    width = tf.cast(img_features['width'], tf.int32)
    depth = tf.cast(img_features['depth'], tf.int32)
    image = tf.reshape(image, tf.stack([depth, height, width]))
    if image_shape is not None:
        # set_shape only accepts a static shape (Python ints / None), never
        # tensors; it records the shape that downstream batching ops need.
        image.set_shape(image_shape)
    return image, label

def inputs(batch_size, num_epochs, image_shape=None):
    """Build shuffled (image, label) batches from ``set1.tfrecords``.

    Args:
        batch_size: number of examples per batch.
        num_epochs: number of passes over the file, forwarded to
            ``tf.train.string_input_producer``.
        image_shape: optional static image shape as Python ints, e.g.
            ``(192, 81, 2)``. It is applied with ``set_shape`` so that
            ``tf.train.shuffle_batch`` sees a fully defined shape.

    Returns:
        (images, sparse_labels): batch tensors from ``tf.train.shuffle_batch``.

    Raises:
        ValueError: if the decoded image's static shape is not fully defined.
    """
    import os

    # dir_path is a global variable defined elsewhere in the module;
    # os.path.join works whether or not it ends with a separator.
    file_path = os.path.join(dir_path, 'set1.tfrecords')
    # Forward num_epochs instead of hard-coding 1. A non-None value creates a
    # local epoch counter, so run tf.local_variables_initializer() before
    # starting the queue runners.
    filename_queue = tf.train.string_input_producer(
        [file_path], num_epochs=num_epochs)
    image, label = read_and_decode(filename_queue)
    if image_shape is not None:
        image.set_shape(image_shape)
    if not image.get_shape().is_fully_defined():
        raise ValueError(
            'Image shape %s is not fully defined; pass image_shape as Python '
            'ints so tf.train.shuffle_batch can batch the images.'
            % image.get_shape())
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size, min_after_dequeue=1000)
    return images, sparse_labels

I consistently get the following error:

 images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10) 

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 1225, in shuffle_batch
name=name)

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 781, in _shuffle_batch
dtypes=types, shapes=shapes, shared_name=shared_name)

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 641, in __init__
shapes = _as_shape_list(shapes, dtypes)

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 77, in _as_shape_list
raise ValueError("All shapes must be fully defined: %s" % shapes)

ValueError: All shapes must be fully defined: [TensorShape([Dimension(None)]), TensorShape([])]

What causes this error, and how can I fix it? Note that I can read the TFRecord file without problems by iterating over it with tf.python_io.tf_record_iterator(path=filename).


回答1:


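The error is raised because tf.train.shuffle_batch needs to know the static shape of your tensors in order to batch them: all items in a batch must have the same shape. Raw byte strings, however, can in principle have different lengths, so tf.decode_raw does not set a static shape on its output; its shape is [None], as the error message shows. Reshaping it with the height/width/depth tensors read from each record does not make the static shape fully defined either.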

In the comments, you mention that all your images have shape (192,81,2). You therefore only need to give the image tensor that static shape, as Python integers, before returning from read_and_decode:

def read_and_decode(filename_queue):
    # rest of your code here
    # (parse the example, decode `image`, and cast `label`, `height`,
    # `width` and `depth` exactly as in the question)
    # Keep the writer's dimension order: it stored each image in the order
    # given by np.shape, i.e. (depth, rows, cols).
    image_shape = tf.stack([depth, height, width])
    image = tf.reshape(image, image_shape)
    # set_shape() needs a static shape made of Python ints; passing the
    # height/width/depth tensors fails. Per the question's comments every
    # image is (192, 81, 2).
    image.set_shape([192, 81, 2])  # <<<<<<<<<<<<<<<
    return image, label


来源:https://stackoverflow.com/questions/47809776/unable-to-read-from-tensorflow-tfrecord-file

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!