How to extract data/labels back from TensorFlow dataset

家住魔仙堡 submitted on 2020-12-29 05:30:07

Question


There are plenty of examples of how to create and use TensorFlow datasets, e.g.:

dataset = tf.data.Dataset.from_tensor_slices((images, labels))

My question is: how do I get the data/labels back from a TF dataset in NumPy form? In other words, what would be the reverse of the line above? I have a TF dataset and want to get the images and labels back out of it.


Answer 1:


Suppose your tf.data.Dataset is called train_dataset. With eager execution enabled (the default in TensorFlow 2.x), you can retrieve images and labels like this:

for images, labels in train_dataset.take(1):  # only take first element of dataset
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
  • the .numpy() method converts a tf.Tensor to a NumPy array
  • to retrieve more elements of the dataset, increase the number passed to take(); to get all elements, pass -1
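If you increase the count passed to take(), note that the loop reassigns numpy_images and numpy_labels on every pass, so only the last element survives. A minimal sketch that keeps every element, assuming TF 2.1 or later (where Dataset.as_numpy_iterator() exists) and an unbatched dataset of (image, label) pairs; the toy images/labels arrays are made up for illustration:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for real images/labels
images = np.arange(5 * 2 * 2, dtype=np.float32).reshape(5, 2, 2)
labels = np.array([0, 1, 0, 1, 1], dtype=np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Collect the first n elements into lists instead of overwriting one variable
n = 3
image_list, label_list = [], []
for image, label in train_dataset.take(n).as_numpy_iterator():
    image_list.append(image)  # already a NumPy array, no .numpy() needed
    label_list.append(label)

numpy_images = np.stack(image_list)  # shape (3, 2, 2)
numpy_labels = np.array(label_list)  # shape (3,)

# take(-1) keeps every element, so this recovers the full arrays
all_pairs = list(train_dataset.take(-1).as_numpy_iterator())
all_images = np.stack([img for img, _ in all_pairs])
all_labels = np.array([lbl for _, lbl in all_pairs])
```

np.stack adds a new leading axis, which matches how from_tensor_slices originally split the arrays along axis 0.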



Answer 2:


If your tf.data.Dataset ds is batched and yields (x, y) pairs, the following code retrieves all the y labels by concatenating the batches:

import numpy as np
y = np.concatenate([y for x, y in ds], axis=0)
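Running a second comprehension to fetch x as well would read the dataset twice, and if ds reshuffles on each iteration the two passes would come back in different orders. A minimal sketch that gathers features and labels in a single pass, assuming eager TF 2.x and a dataset of (features, labels) batches; x_data and y_data are made-up toy arrays:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 10 samples with 3 features each, batched by 4
x_data = np.random.rand(10, 3).astype(np.float32)
y_data = np.arange(10, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(4)

# One pass: collect each batch, then join along the batch axis.
# The last batch is smaller (2 rows here); concatenate handles that.
x_batches, y_batches = [], []
for x_batch, y_batch in ds.as_numpy_iterator():
    x_batches.append(x_batch)
    y_batches.append(y_batch)

x = np.concatenate(x_batches, axis=0)  # shape (10, 3)
y = np.concatenate(y_batches, axis=0)  # shape (10,)
```

np.concatenate is used rather than np.stack because each batch already has a leading batch axis. In recent TF 2.x versions, ds.unbatch() is another option if you prefer to work with individual elements again.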



Answer 3:


There is a good example in the TensorFlow Datasets overview notebook:

https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb#scrollTo=BC4pEXtkp4K-

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# mnist_train is a tf.data.Dataset
mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)

mnist_example, = mnist_train.take(1)
image, label = mnist_example["image"], mnist_example["label"]

plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())

So each element of the dataset is a dictionary, and its components are accessed by key. Different datasets use different field names (a Boston housing dataset won't have 'image' and 'label', but might have 'features' and 'target' or 'price'). For example:

cnn = tfds.load(name="cnn_dailymail", split=tfds.Split.TRAIN)
assert isinstance(cnn, tf.data.Dataset)
cnn_ex, = cnn.take(1)
print(cnn_ex)

prints a dict with keys such as 'article' and 'highlights', whose values are scalar string tensors; calling .numpy() on them returns Python bytes.
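For dict-structured datasets like these, Dataset.as_numpy_iterator() (TF 2.1 or later) keeps the dict structure but converts every value to NumPy. A minimal sketch, using a hypothetical in-memory dataset with "image" and "label" keys so it runs without downloading anything:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a tfds dataset whose elements are dicts
ds = tf.data.Dataset.from_tensor_slices({
    "image": np.zeros((4, 28, 28, 1), dtype=np.uint8),
    "label": np.array([3, 1, 4, 1], dtype=np.int64),
})

# Each example is a dict of NumPy values
examples = list(ds.as_numpy_iterator())
first = examples[0]  # {'image': ndarray of shape (28, 28, 1), 'label': scalar}

# Stack each field separately to get one array per key
images = np.stack([ex["image"] for ex in examples])  # shape (4, 28, 28, 1)
labels = np.array([ex["label"] for ex in examples])  # shape (4,)
```

With tfds.load itself, passing as_supervised=True returns (input, label) tuples instead of dicts, and passing batch_size=-1 returns the whole split as tensors that tfds.as_numpy can convert to NumPy arrays in one step.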




Answer 4:


Here is my own solution to the problem:

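import tensorflow as tf

def dataset2numpy(dataset, steps=1):
    "Helper function to get data/labels back from a TF dataset (graph mode)"
    # In TF 2.x, call tf.compat.v1.disable_eager_execution() before building `dataset`
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    next_val = iterator.get_next()
    with tf.compat.v1.Session() as sess:
        for _ in range(steps):
            inputs, labels = sess.run(next_val)
            yield inputs, labels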

Note that this function yields one (inputs, labels) batch at a time; the steps argument controls how many batches are taken from the dataset. It relies on the TensorFlow 1.x session API, so under TF 2.x it only works in graph mode.
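Under eager TF 2.x the same idea needs no iterator or session. A minimal sketch of a hypothetical eager counterpart, dataset2numpy_eager (not part of the original answer), assuming TF 2.1 or later for as_numpy_iterator():

```python
import numpy as np
import tensorflow as tf

def dataset2numpy_eager(dataset, steps=1):
    """Hypothetical eager counterpart of dataset2numpy: yields up to `steps`
    (inputs, labels) batches as NumPy arrays."""
    for inputs, labels in dataset.take(steps).as_numpy_iterator():
        yield inputs, labels

# Toy usage: 6 samples batched by 2, take the first 2 batches
ds = tf.data.Dataset.from_tensor_slices(
    (np.arange(12, dtype=np.float32).reshape(6, 2), np.arange(6))
).batch(2)
batches = list(dataset2numpy_eager(ds, steps=2))
```

One behavioral difference: the session version raises tf.errors.OutOfRangeError if steps exceeds the number of batches, whereas take(steps) simply stops when the dataset runs out.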




Answer 5:


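This worked for me, where ds_test is created as follows:

import numpy as np
import tensorflow_datasets as tfds

# as_supervised=True yields (features, label) tuples; without it each element is a dict
iris, iris_info = tfds.load('iris', as_supervised=True, with_info=True)
ds_orig = iris['train']
# reshuffle_each_iteration=False keeps the order fixed, so the two passes below stay aligned
ds_orig = ds_orig.shuffle(150, reshuffle_each_iteration=False)
ds_train = ds_orig.take(100)
ds_test = ds_orig.skip(100)

features = np.array([x[0].numpy() for x in ds_test])
labels = np.array([x[1].numpy() for x in ds_test])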
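The two comprehensions above iterate over ds_test twice, which is only safe because the shuffle order is fixed. A single pass keeps every feature row paired with its own label even when the dataset reshuffles. A minimal sketch, with a hypothetical toy dataset standing in for ds_test in which label i belongs to row i so the pairing can be checked:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for ds_test; row i starts with 4*i and has label i
x = np.arange(50 * 4, dtype=np.float32).reshape(50, 4)
y = np.arange(50, dtype=np.int64)
ds_test = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(50)  # reshuffles each iteration

# Read the dataset once, then split the (features, label) pairs
pairs = list(ds_test.as_numpy_iterator())
features = np.stack([f for f, _ in pairs])
labels = np.array([lbl for _, lbl in pairs])
```

The order of the rows is random, but each row in features still lines up with the matching entry in labels.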


Source: https://stackoverflow.com/questions/56226621/how-to-extract-data-labels-back-from-tensorflow-dataset
