Get data set as numpy array from TFRecordDataset

Submitted by 无人久伴 on 2020-04-10 18:10:11

Question


I'm using the new tf.data API to create an iterator for the CIFAR10 dataset. I'm reading the data from two .tfrecord files: one holds the training data (train.tfrecords) and the other holds the test data (test.tfrecords). This all works fine. At some point, however, I need both data sets (training data and test data) as NumPy arrays.

Is it possible to retrieve a data set as a NumPy array from a tf.data.TFRecordDataset object?


Answer 1:


You can use the tf.data.Dataset.batch() transformation and tf.contrib.data.get_single_element() to do this. As a refresher, dataset.batch(n) takes up to n consecutive elements of dataset and combines them into one element by stacking each component along a new leading dimension. This requires every element to have the same fixed shape per component. If n doesn't divide the number of elements exactly (including when n is larger than the number of elements in dataset), the last batch will be smaller. Therefore, you can choose a large value for n, so that the whole dataset ends up in a single batch, and do the following:

import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)
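
To make the batching behaviour concrete, here is a minimal pure-Python sketch of what batch(n) does to a sequence of elements. The helper `batch_elements` is hypothetical and is not the TensorFlow implementation: it groups plain Python values into lists, whereas tf.data stacks tensors. It only illustrates why a sufficiently large n yields a single batch that holds every element.

```python
from itertools import islice


def batch_elements(elements, n):
    """Yield lists of up to n consecutive elements (simplified stand-in for Dataset.batch)."""
    iterator = iter(elements)
    while True:
        chunk = list(islice(iterator, n))
        if not chunk:
            # Source exhausted: no empty trailing batch is produced.
            return
        yield chunk


records = list(range(10))

# n does not divide the number of elements, so the last batch is smaller.
print(list(batch_elements(records, 4)))

# n is at least as large as the dataset, so everything lands in one batch.
print(list(batch_elements(records, 10**6)))
```

The second call is the case the answer relies on: with one batch in the dataset, extracting its single element returns all of the data at once.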

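The snippet above targets TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.0. As a sketch, assuming TensorFlow 2.6 or later with eager execution enabled, the same idea can be written with Dataset.get_single_element(). The helper name `dataset_to_numpy` and the toy in-memory dataset are illustrative assumptions that stand in for a mapped TFRecordDataset.

```python
import numpy as np
import tensorflow as tf


def dataset_to_numpy(dataset):
    """Return every element of `dataset` as NumPy arrays (assumes TF >= 2.6, eager mode)."""
    # Same trick as above: one batch large enough to hold the whole dataset.
    max_elems = np.iinfo(np.int64).max
    batched = dataset.batch(max_elems)
    # The batched dataset has exactly one element; this raises if it is empty.
    tensors = batched.get_single_element()
    # Convert each component (e.g. images, labels) from tf.Tensor to np.ndarray.
    return tf.nest.map_structure(lambda t: t.numpy(), tensors)


# Toy stand-in for a parsed TFRecordDataset: 6 "images" of shape (2, 2) with labels.
images = np.arange(24, dtype=np.float32).reshape(6, 2, 2)
labels = np.arange(6, dtype=np.int64)
toy_dataset = tf.data.Dataset.from_tensor_slices((images, labels))

images_out, labels_out = dataset_to_numpy(toy_dataset)
print(images_out.shape, labels_out.shape)
```

Either variant materializes the whole data set in memory, so it is only practical when the data fits in RAM.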

Source: https://stackoverflow.com/questions/48871438/get-data-set-as-numpy-array-from-tfrecorddataset
