TensorFlow decode_csv shape error

做~自己de王妃 submitted on 2020-08-10 00:24:02

Question


I read a *.csv file using tf.data.TextLineDataset and apply map to it:

dataset = tf.data.TextLineDataset(os.path.join(data_dir, subset, 'label.txt'))
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
                          num_parallel_calls=num_parallel_calls)

The parse function parse_record_fn looks like this:

def parse_record(raw_record, is_training):
    default_record = ["./", -1]
    filename, label = tf.decode_csv([raw_record], default_record)
    # do something
    return image, label

But tf.decode_csv in the parse function raises a ValueError:

ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV' (op: 'DecodeCSV') with input shapes: [1], [], [].

An example of my *.csv file:

/data/1.png, 5
/data/2.png, 7

Question:

  1. What is going wrong?
  2. What does "input shapes: [1], [], []" mean?

Reproduce

This error can be reproduced with the following code:

import tensorflow as tf
import os

def parse_record(raw_record, is_training):
    default_record = ["./", -1]
    filename, label = tf.decode_csv([raw_record], default_record)

    # do something

    return image, label

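with tf.Session() as sess:
    csv_path = './labels.txt'
    dataset = tf.data.TextLineDataset(csv_path)
    # The ValueError is raised here, when map() traces parse_record
    dataset = dataset.map(lambda value: parse_record(value, True))
    # A Dataset cannot be passed to sess.run directly; fetch an element instead
    next_element = dataset.make_one_shot_iterator().get_next()
    sess.run(next_element)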

Answer 1:


The documentation of tf.decode_csv says this about the default records:

record_defaults: A list of Tensor objects with specific types. Acceptable types are float32, float64, int32, int64, string. One tensor per column of the input record, with either a scalar default value for that column or empty if the column is required.

I believe the error you are getting originates from how you define default_record. Your default_record is a list of objects convertible to tensors, but I think the error message is telling you that each of them should be a rank-1 tensor, not a rank-0 tensor (a scalar) as in your case.
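The documented contract (one default per column, given as a list holding either one value or nothing) can be mimicked in plain Python to make the rank-1 convention concrete. This is a simplified model, not TensorFlow's implementation: decode_line is a hypothetical helper, it ignores quoting, na_value and the other options of the real op, and it treats required columns (empty defaults) as strings because an empty Python list carries no type.

```python
# Simplified, pure-Python model of how tf.decode_csv uses record_defaults.
# decode_line is hypothetical and for illustration only.

def decode_line(line, record_defaults, field_delim=","):
    """Split one CSV line and convert each field using its column default.

    Each entry of record_defaults must be a list of length 0 or 1, mirroring
    the rank-1 tensors decode_csv expects: [value] sets the column's type and
    the value used for an empty field; [] marks the column as required.
    """
    # Validate the defaults before touching any data, loosely mirroring the
    # op's shape check that runs when the graph is built.
    for default in record_defaults:
        if not isinstance(default, (list, tuple)):
            raise ValueError("Shape must be rank 1 but is rank 0")
        if len(default) > 1:
            raise ValueError("A default must have length 0 or 1")

    fields = line.split(field_delim)
    if len(fields) != len(record_defaults):
        raise ValueError(
            f"Expected {len(record_defaults)} fields but got {len(fields)}")

    decoded = []
    for field, default in zip(fields, record_defaults):
        if field == "":
            if not default:
                raise ValueError("Field is required but empty")
            decoded.append(default[0])  # empty field: use the column default
        elif default and isinstance(default[0], int):
            decoded.append(int(field))  # int() tolerates surrounding spaces
        elif default and isinstance(default[0], float):
            decoded.append(float(field))
        else:
            decoded.append(field)  # string (or required) column: kept verbatim
    return decoded


print(decode_line("/data/1.png, 5", [["./"], [-1]]))  # ['/data/1.png', 5]
print(decode_line(",", [["./"], [-1]]))               # ['./', -1]
```

With the flat defaults from the question, ["./", -1], the model rejects the first default before any line is split, which is roughly analogous to the ValueError appearing while the map function is traced rather than when data is read.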

You can fix the issue by making the default records rank-1 tensors, as in the following toy example:

import tensorflow as tf

my_line = 'filename.png, 10'
default_record_1 = [['./'], [-1]] # do this!
default_record_2 = ['./', -1] # this is what you do now

decoded_1 = tf.decode_csv(my_line, default_record_1)
with tf.Session() as sess:
    d = sess.run(decoded_1)
    print(d)

# This will cause an error
decoded_2 = tf.decode_csv(my_line, default_record_2)

The error produced on the last line is familiar:

ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV_1' (op: 'DecodeCSV') with input shapes: [], [], [].

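In the message, the bracketed shapes are the shapes of the op's input tensors, in order: records, followed by one tensor per entry of record_defaults (field_delim is an attribute of the op, not an input tensor). In the toy example, my_line and both defaults are scalars, hence [], [], []. In your case the first shape is [1] because you pass [raw_record], and the two [] are your scalar defaults "./" and -1, which are the inputs that fail the rank-1 check. I agree that the message in this case is not very informative...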
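The shapes in the message can be reproduced by hand. The sketch below is illustrative: shape_of and describe_inputs are hypothetical helpers, shape_of assumes a rectangular (non-ragged) nested list and reports the shape TensorFlow would give it when converted to a tensor, and describe_inputs lists the DecodeCSV inputs as records followed by one entry per column default.

```python
# Hypothetical helpers that compute the "input shapes" listed in the error.

def shape_of(value):
    """Shape of a rectangular nested list; strings and numbers are scalars."""
    shape = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return shape


def describe_inputs(records, record_defaults):
    # Inputs of the op: records first, then one tensor per column default.
    return [shape_of(records)] + [shape_of(d) for d in record_defaults]


raw_record = "/data/1.png, 5"  # stands in for one scalar line of the dataset
print(describe_inputs([raw_record], ["./", -1]))      # [[1], [], []]
print(describe_inputs([raw_record], [["./"], [-1]]))  # [[1], [1], [1]]
```

The first call matches the question's error; in the second, every default is rank 1, which is what the shape check asks for.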



Source: https://stackoverflow.com/questions/49473963/tensorflow-decode-csv-shape-error
