Question
I want to prepare the Omniglot dataset for n-shot learning, so I need 5 samples from each of 10 classes (alphabets).
Code to reproduce
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

builder = tfds.builder("omniglot")
# assert builder.info.splits['train'].num_examples == 60000
builder.download_and_prepare()

# Load data from disk as tf.data.Datasets
datasets = builder.as_dataset()
dataset, test_dataset = datasets['train'], datasets['test']

def resize(example):
    image = example['image']
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image)
    image = image / 255
    one_hot_label = np.zeros((51, 10))
    return image, one_hot_label, example['alphabet']

def stack(image, label, alphabet):
    return (image, label), label[-1]

def filter_func(image, label, alphabet):
    # goal: keep images whose alphabet is in arr, not just alphabet 2
    arr = np.array([2, 3, 4, 5])
    result = tf.reshape(tf.equal(alphabet, 2), [])
    return result

# correct size
dataset = dataset.map(resize)
# now filter the dataset for the batch
dataset = dataset.filter(filter_func)
# infinite stream of batches (classes*samples + 1)
dataset = dataset.repeat().shuffle(1024).batch(51)
# stack the images together
dataset = dataset.map(stack)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)

for i, (image, label) in enumerate(tfds.as_numpy(dataset)):
    print(i, image[0].shape)
Now I want to filter the images in the dataset using the filter function. tf.equal only lets me filter by a single class; I want something like "tensor in array" (keep an example if its alphabet is any of several values).
Do you see a way to do this with the filter function? Or is this the wrong approach, and is there a much simpler way?
I want to create a batch of 51 images and corresponding labels drawn from the same N=10 classes. From every class I need K=5 different images, plus one additional image (the one I need to classify). Every batch of N*K+1 (51) images should come from 10 new random classes.
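The episode layout described above (N = 10 classes, K = 5 images each, plus one extra image to classify) can be sketched outside of tf.data as plain index sampling. This is a swapped-in approach, not something taken from the question or the answer: it assumes the per-example labels (for example, alphabet ids) are available as a Python list, and sample_episode is a hypothetical helper. The returned indices could then be used to gather images from an in-memory array.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way=10, k_shot=5, rng=random):
    """Sample indices for one N-way K-shot episode of n_way * k_shot + 1 items.

    Returns (support, query): `support` holds n_way lists of k_shot indices
    (one list per sampled class); `query` is an (index, class_position) pair
    whose example belongs to the class of support[class_position] but is not
    one of its support indices.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Simplification: every class needs k_shot + 1 examples to be eligible.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + 1]
    classes = rng.sample(eligible, n_way)   # n_way distinct random classes
    query_pos = rng.randrange(n_way)        # which class supplies the query

    support, query = [], None
    for pos, c in enumerate(classes):
        if pos == query_pos:
            picked = rng.sample(by_class[c], k_shot + 1)
            query = (picked[-1], pos)       # extra example, held out of support
            picked = picked[:k_shot]
        else:
            picked = rng.sample(by_class[c], k_shot)
        support.append(picked)
    return support, query


# Usage on toy labels: 20 classes with 20 examples each.
labels = [c for c in range(20) for _ in range(20)]
support, (query_idx, query_pos) = sample_episode(labels, rng=random.Random(0))
flat = [i for group in support for i in group]
print(len(flat) + 1)  # 51
```

Sampling indices per episode, rather than filtering a shuffled stream and calling batch(51), guarantees that each batch contains exactly K examples from each of N freshly drawn classes; the filter-then-batch pipeline in the question does not enforce that.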
Thank you very much in advance.
Answer 1:
tf.equal() supports broadcasting, so it can compare a scalar with a tensor of rank > 0.
To keep only specific labels, use this predicate:
dataset = datasets['train']

def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
    label = x['label']
    # broadcast: compare the scalar label with every allowed label
    isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
    # count matches; keep the example if there is at least one
    reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
    return tf.greater(reduced, tf.constant(0.))

dataset = dataset.filter(predicate).batch(20)

for i, x in enumerate(tfds.as_numpy(dataset)):
    print(x['label'])
# [1 0 0 1 2 1 1 2 1 0 0 1 2 0 1 0 2 2 0 1]
# [1 0 2 2 0 2 1 2 1 2 2 2 0 2 0 2 1 2 1 1]
# [2 1 2 1 0 1 1 0 1 2 2 0 2 0 1 0 0 0 0 0]
allowed_labels specifies the labels you want to keep. All labels that are not in this tensor are filtered out.
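For the question's case (keeping images whose alphabet is one of several ids), the same broadcasting idea can be written a little more directly with tf.reduce_any, which is true if any element of the comparison is true. A minimal sketch, assuming TensorFlow 2.x with eager execution; the allowed ids [2, 3, 4, 5] come from the question's arr, ALLOWED_ALPHABETS and keep_allowed_alphabets are illustrative names, and the tiny in-memory dataset is a synthetic stand-in for the Omniglot split.

```python
import tensorflow as tf

# Alphabet ids to keep (taken from the question's arr; illustrative).
ALLOWED_ALPHABETS = tf.constant([2, 3, 4, 5], dtype=tf.int64)


def keep_allowed_alphabets(example):
    # tf.equal broadcasts the scalar alphabet id against every allowed id,
    # producing a boolean vector; reduce_any collapses it to the scalar
    # tf.bool that Dataset.filter expects.
    alphabet = tf.cast(example['alphabet'], tf.int64)
    return tf.reduce_any(tf.equal(ALLOWED_ALPHABETS, alphabet))


# Synthetic stand-in for the Omniglot split: only an 'alphabet' feature.
toy = tf.data.Dataset.from_tensor_slices(
    {'alphabet': tf.constant([0, 2, 7, 3, 5, 1, 4], dtype=tf.int64)})
kept = [int(x['alphabet']) for x in toy.filter(keep_allowed_alphabets)]
print(kept)  # [2, 3, 5, 4]
```

In the question's pipeline, where filter_func receives (image, label, alphabet) after resize, its body could become return tf.reduce_any(tf.equal(ALLOWED_ALPHABETS, tf.cast(alphabet, tf.int64))).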
Source: https://stackoverflow.com/questions/55731774/filter-dataset-to-get-just-images-from-specific-class