Question
There are plenty of examples of how to create and use TensorFlow datasets, e.g.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
My question is how to get the data/labels back from a TF dataset in NumPy form. In other words, what would be the reverse operation of the line above? I.e., I have a TF dataset and want to get the images and labels back from it.
Answer 1:
Supposing our tf.data.Dataset is called train_dataset, with eager execution on (the default in TensorFlow 2.x), you can retrieve images and labels like this:
for images, labels in train_dataset.take(1):  # only take the first element of the dataset
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
The .numpy() method converts a tf.Tensor into a NumPy array. If you want to retrieve more elements of the dataset, just increase the number passed to the take method; if you want all elements, pass -1, i.e. take(-1).
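A minimal end-to-end sketch of this approach, assuming TensorFlow 2.x with eager execution and small in-memory arrays. The names `images` and `labels` mirror the question; the data itself is made up for illustration. It rebuilds the dataset from the question and then recovers every element with `take(-1)`:

```python
import numpy as np
import tensorflow as tf

# Illustrative data: 5 "images" of shape 2x2 and 5 integer labels.
images = np.arange(20, dtype=np.float32).reshape(5, 2, 2)
labels = np.array([0, 1, 0, 1, 1], dtype=np.int64)

dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# take(-1) keeps every element; each element is an (image, label) pair of eager tensors.
image_list, label_list = [], []
for image, label in dataset.take(-1):
    image_list.append(image.numpy())
    label_list.append(label.numpy())

# Stack the per-element arrays back into single NumPy arrays.
recovered_images = np.stack(image_list)
recovered_labels = np.array(label_list)
print(recovered_images.shape, recovered_labels)
```

For an unbatched dataset like this one, np.stack restores the leading sample axis that from_tensor_slices removed.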
Answer 2:
In case your tf.data.Dataset is batched, the following code will retrieve all the y labels:
import numpy as np
y = np.concatenate([y for x, y in ds], axis=0)
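A self-contained sketch of the same idea, assuming a dataset of `(x, y)` tuples batched with `batch(4)`; the arrays are made up for illustration. Each iteration yields a whole batch, so `np.concatenate` along axis 0 joins the batches back together, including the smaller final batch. This variant collects x and y in a single loop:

```python
import numpy as np
import tensorflow as tf

# Illustrative data: 10 feature rows with 3 columns and 10 labels.
x_data = np.arange(30, dtype=np.float32).reshape(10, 3)
y_data = np.arange(10, dtype=np.int64)

# Batches of 4 give x shapes (4, 3), (4, 3) and (2, 3).
ds = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(4)

# One pass over the dataset: keep each batch, then join along the batch axis.
x_batches, y_batches = [], []
for x_batch, y_batch in ds:
    x_batches.append(x_batch.numpy())
    y_batches.append(y_batch.numpy())

x = np.concatenate(x_batches, axis=0)
y = np.concatenate(y_batches, axis=0)
print(x.shape, y.shape)
```

Collecting both in the same loop keeps x and y aligned even if the dataset reshuffles on every iteration, which two separate passes would not guarantee.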
Answer 3:
I think there is a good example here:
https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb#scrollTo=BC4pEXtkp4K-
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# mnist_train is a tf.data.Dataset
mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)
mnist_example, = mnist_train.take(1)
image, label = mnist_example["image"], mnist_example["label"]
plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())
So each individual component of a dataset element can be accessed like a dictionary. Presumably different datasets have different field names (Boston housing won't have 'image' and 'label', but might have 'features' and 'target' or 'price'):
cnn = tfds.load(name="cnn_dailymail", split=tfds.Split.TRAIN)
assert isinstance(cnn, tf.data.Dataset)
cnn_ex, = cnn.take(1)
print(cnn_ex)
This prints a dict with keys such as 'article' and 'highlights', whose values are string tensors.
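For datasets whose elements are dictionaries, like the TFDS ones above, the same extraction works by key. To avoid a download, the sketch below builds a small dict-structured dataset in memory (the keys `image` and `label` mimic MNIST; the data is illustrative) and uses `Dataset.as_numpy_iterator()`, available in TensorFlow 2.1 and later, which yields each element with NumPy values in place of tensors:

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-in for a TFDS split: 4 blank 28x28x1 "images" and their labels.
images = np.zeros((4, 28, 28, 1), dtype=np.uint8)
labels = np.array([3, 1, 4, 1], dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# as_numpy_iterator() yields dicts whose values are NumPy arrays or NumPy scalars.
examples = list(ds.as_numpy_iterator())
image_array = np.stack([ex["image"] for ex in examples])
label_array = np.array([ex["label"] for ex in examples])
print(sorted(examples[0].keys()), image_array.shape, label_array)
```

With a real TFDS split, tfds.as_numpy(mnist_train) serves a similar purpose.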
Answer 4:
Here is my own solution to the problem:
import tensorflow as tf  # TensorFlow 1.x (graph mode) API

def dataset2numpy(dataset, steps=1):
    """Helper function to get data/labels back from a TF dataset."""
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()
    with tf.Session() as sess:
        for _ in range(steps):
            inputs, labels = sess.run(next_val)
            yield inputs, labels
Please note that this function is a generator that yields one batch of inputs/labels per step. The steps argument controls how many batches are taken from the dataset.
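Under TensorFlow 2.x, where tf.Session and make_one_shot_iterator are no longer part of the main API, a comparable helper can be sketched with `take` and `as_numpy_iterator()`. The name `dataset2numpy_v2` is hypothetical, and the sketch assumes each element is an `(inputs, labels)` pair:

```python
import numpy as np
import tensorflow as tf

def dataset2numpy_v2(dataset, steps=1):
    """Hypothetical TF2 helper: yield up to `steps` (inputs, labels) elements as NumPy arrays."""
    # take(steps) limits how many elements (batches, if the dataset is batched) are read.
    for inputs, labels in dataset.take(steps).as_numpy_iterator():
        yield inputs, labels

# Usage with an illustrative batched dataset: 6 samples in batches of 2.
ds = tf.data.Dataset.from_tensor_slices(
    (np.arange(12, dtype=np.float32).reshape(6, 2), np.arange(6))
).batch(2)
batches = list(dataset2numpy_v2(ds, steps=2))
print(len(batches), batches[0][0].shape)
```

If steps exceeds the number of batches, take simply stops at the end of the dataset.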
Answer 5:
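This worked for me (ds_test is created from the iris dataset first, then converted):
import numpy as np
import tensorflow_datasets as tfds

iris, iris_info = tfds.load('iris', with_info=True)
ds_orig = iris['train']
ds_orig = ds_orig.shuffle(150, reshuffle_each_iteration=False)
# Map each dict element to a (features, label) tuple so that x[0] and x[1] work below.
ds_orig = ds_orig.map(lambda x: (x['features'], x['label']))
ds_train = ds_orig.take(100)
ds_test = ds_orig.skip(100)

features = np.array([x[0].numpy() for x in ds_test])
labels = np.array([x[1].numpy() for x in ds_test])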
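The two list comprehensions in this answer iterate over ds_test twice. A single-pass variant is sketched below; to keep it runnable offline, it uses a small in-memory dataset of `(features, label)` tuples as a stand-in for ds_test (4 features per sample, mimicking iris; the values are illustrative):

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-in for ds_test: 5 samples with 4 features each, plus labels.
features_in = np.arange(20, dtype=np.float32).reshape(5, 4)
labels_in = np.array([0, 2, 1, 1, 0], dtype=np.int64)
ds_test = tf.data.Dataset.from_tensor_slices((features_in, labels_in))

# One pass: zip(*...) splits the (features, label) pairs into two tuples.
feature_rows, label_values = zip(*ds_test.as_numpy_iterator())
features = np.stack(feature_rows)
labels = np.array(label_values)
print(features.shape, labels)
```

A single pass matters when a dataset reshuffles on each iteration: the answer avoids the problem with reshuffle_each_iteration=False, but with reshuffling enabled, two separate passes could pair features and labels from different orders. Note that zip(*...) on an empty dataset yields nothing, so the unpacking would fail there.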
Source: https://stackoverflow.com/questions/56226621/how-to-extract-data-labels-back-from-tensorflow-dataset