How to use Dataset API to read TFRecords file of lists of variant length?

Posted by 余生颓废 on 2019-11-30 02:27:38

After hours of searching and experimenting, I believe I have found the answer. My code is below.

import numpy as np
import tensorflow as tf  # This code uses the TensorFlow 1.x API.

def _int64_feature(value):
    # value must be a NumPy array of integers.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value.flatten()))

# Write an array to a TFRecord file.
# a holds lists of varying length, so it is created as an object array.
a = np.array([[0, 54, 91, 153, 177],
              [0, 50, 89, 147, 196],
              [0, 38, 79, 157],
              [0, 49, 89, 147, 177],
              [0, 32, 73, 145]], dtype=object)

writer = tf.python_io.TFRecordWriter('file')

for i in range(a.shape[0]): # i = 0 ~ 4
    x_train = np.array(a[i])
    feature = {'i': _int64_feature(np.array([i])),
               'data': _int64_feature(x_train)}

    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    # Serialize to string and write on the file
    writer.write(example.SerializeToString())

writer.close()

# Check the TFRecord file.
record_iterator = tf.python_io.tf_record_iterator(path='file')
for string_record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(string_record)

    i = (example.features.feature['i'].int64_list.value)
    data = (example.features.feature['data'].int64_list.value)
    print(i, data)

# Use Dataset API to read the TFRecord file.
filenames = ["file"]
dataset = tf.data.TFRecordDataset(filenames)

def _parse_function(example_proto):
    # Both features are parsed as variable-length int64 lists.
    keys_to_features = {'i': tf.VarLenFeature(tf.int64),
                        'data': tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    # VarLenFeature yields SparseTensors; convert them to dense tensors.
    return (tf.sparse_tensor_to_dense(parsed_features['i']),
            tf.sparse_tensor_to_dense(parsed_features['data']))

# Parse the record into tensors.
dataset = dataset.map(_parse_function)
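# Shuffle the dataset (buffer_size=1 keeps the original order;
# use a larger buffer to actually shuffle).
dataset = dataset.shuffle(buffer_size=1)
# Repeat the input indefinitely
dataset = dataset.repeat()
# Generate batches (batch size 1, because rows of different
# lengths cannot be stacked into one tensor).
dataset = dataset.batch(1)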
# Create a one-shot iterator
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))
    print(sess.run([i, data]))
    print(sess.run([i, data]))

There are a few things to note.
1. This SO question helped a lot.
2. tf.VarLenFeature returns a SparseTensor, so tf.sparse_tensor_to_dense is needed to convert it to a dense tensor.
3. In my code, parse_single_example() can't be replaced with parse_example(), which bugged me for a day. I don't know why parse_example() doesn't work. If anyone knows the reason, please enlighten me.
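
To make note 2 more concrete, here is a minimal pure-Python sketch of what a sparse-to-dense conversion does. It is not TensorFlow code: the function name sparse_to_dense_2d is made up for illustration, and the sample indices and values are simply two rows of the array above laid out by hand. The idea is that a sparse representation stores only the positions that hold a value, and the conversion fills every other position with a default (0 here), which is how rows of different length end up in one rectangular block.

```python
def sparse_to_dense_2d(indices, values, dense_shape, default_value=0):
    """Illustrative model only: build a dense 2-D list from sparse pieces."""
    rows, cols = dense_shape
    # Start with a grid that contains only the default value.
    dense = [[default_value] * cols for _ in range(rows)]
    # Drop each stored value into its (row, column) position.
    for (r, c), v in zip(indices, values):
        dense[r][c] = v
    return dense

# Two rows of different length: 5 values in row 0, 4 values in row 1.
indices = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4),
           (1, 0), (1, 1), (1, 2), (1, 3)]
values = [0, 54, 91, 153, 177, 0, 38, 79, 157]

dense = sparse_to_dense_2d(indices, values, dense_shape=(2, 5))
print(dense)
```

The shorter row is padded with the default value at the end, so a padded 0 cannot be told apart from a real 0 in the data unless the original lengths are kept separately.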

The error is simple: your data is a VarLenFeature, not a FixedLenFeature. Replace the line:

 'data':tf.FixedLenFeature([], tf.int64)}

with

 'data':tf.VarLenFeature(tf.int64)}

Also, when you call print(i.eval()) and print(data.eval()), you advance the iterator twice. The first print outputs 0, but the second outputs the data of the second row, [0, 50, 89, 147, 196]. Use print(sess.run([i, data])) to get i and data from the same row.
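
The double-advance effect can be imitated with a plain Python iterator. This is only an analogy under stated assumptions, not TensorFlow code: make_run and run are hypothetical names standing in for a session call, and the rows are the first three records from the example above. Every call to run pulls one new record, no matter which fields are requested.

```python
rows = [(0, [0, 54, 91, 153, 177]),
        (1, [0, 50, 89, 147, 196]),
        (2, [0, 38, 79, 157])]

def make_run(records):
    """Return a function that mimics one evaluation step per call."""
    it = iter(records)

    def run(fetches):
        # Each call consumes exactly one record, whatever is fetched.
        i, data = next(it)
        element = {'i': i, 'data': data}
        return [element[name] for name in fetches]

    return run

run = make_run(rows)

first_i = run(['i'])[0]         # consumes record 0, returns its index
second_data = run(['data'])[0]  # consumes record 1, returns its data
paired = run(['i', 'data'])     # consumes record 2, both fields together

print(first_i, second_data)  # index and data come from different records
print(paired)                # index and data come from the same record
```

Fetching both fields in a single call is what keeps them aligned, which is the reason for preferring sess.run([i, data]) over two separate eval() calls.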
