Question
Consider the problem of creating a dataset by sampling random small image patches from a directory of high-resolution images. The TensorFlow Dataset API makes this very easy: construct a dataset of image filenames, shuffle it, map it to loaded images, and then map those to random cropped patches.
However, this naive implementation is very inefficient, because a separate high-resolution image is loaded and cropped to generate each patch. Ideally, an image would be loaded once and reused to generate many patches.
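For concreteness, the naive pipeline looks roughly like this (a sketch only; the glob pattern, patch size, and batch size are placeholders):
import tensorflow as tf

def load_image(filename):
    """Decode one full-resolution image and convert it to floats in [0, 1]."""
    image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    return tf.image.convert_image_dtype(image, tf.float32)

dataset = (tf.data.Dataset.list_files("images/*.jpg")                # hypothetical directory
           .shuffle(buffer_size=1000)
           .map(load_image)                                          # loads a full image per element
           .map(lambda image: tf.random_crop(image, [32, 32, 3]))    # one patch per loaded image
           .batch(10))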
One simple approach, discussed previously, is to generate multiple patches from each image and flatten them. However, this has the unfortunate effect of biasing the data too much: we want each training batch to contain patches from different images.
Ideally, what I would like is a "random caching filter" transformation that takes an underlying dataset and caches N of its elements in memory. Its iterator returns a random element from the cache, and with a predefined frequency it replaces a random element of the cache with a new one from the underlying dataset. This filter would allow faster data access at the expense of less randomization and higher memory consumption.
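Outside of tf.data, the behaviour I have in mind could be sketched as a plain Python generator (the names and default values here are only illustrative):
import random

def random_cache_iterator(source, cache_size=50, refresh_every=10):
    """Yield random elements from a cache of cache_size items; every
    refresh_every yields, replace one cached element with a fresh one
    drawn from the underlying source."""
    source = iter(source)
    cache = [next(source) for _ in range(cache_size)]
    step = 0
    while True:
        yield random.choice(cache)
        step += 1
        if step % refresh_every == 0:
            try:
                cache[random.randrange(cache_size)] = next(source)
            except StopIteration:
                return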
Is there such functionality available?
If not, should it be implemented as a new dataset transformation or simply as a new iterator? It seems a new iterator is all that is needed. Any pointers on how to create a new dataset iterator, ideally in C++?
Answer 1:
You should be able to use tf.data.Dataset.shuffle to achieve what you want. Here is a quick summary of the objectives:
- load very big images, produce smaller random crops from them, and batch them together
- make the pipeline efficient by creating multiple patches from each big image once it is loaded
- add enough shuffling so that each batch of patches is diverse (all the patches come from different images)
- don't load too many big images into the cache
You can achieve all of this with the tf.data API by following these steps:
- shuffle the filenames of the big images
- read the big images
- generate multiple patches from each image
- shuffle all these patches again, with a big enough buffer size (see this answer on buffer size); adjusting the buffer size is a tradeoff between good shuffling and the size of the patch cache
- batch them
- prefetch one batch
Here is the relevant code:
import tensorflow as tf

filenames = ...  # filenames containing the big images
num_samples = len(filenames)

# Parameters
num_patches = 100               # number of patches to extract from each image
patch_size = 32                 # size of the patches
buffer_size = 50 * num_patches  # shuffle patches from 50 different big images
num_parallel_calls = 4          # number of threads
batch_size = 10                 # size of the batch

get_patches_fn = lambda image: get_patches(image, num_patches=num_patches, patch_size=patch_size)

# Create a Dataset serving batches of random patches from our images
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
           .shuffle(buffer_size=num_samples)  # step 1: putting all the filenames into the buffer ensures good shuffling
           .map(parse_fn, num_parallel_calls=num_parallel_calls)        # step 2
           .map(get_patches_fn, num_parallel_calls=num_parallel_calls)  # step 3
           .apply(tf.contrib.data.unbatch())  # unbatch the patches we just produced
           .shuffle(buffer_size=buffer_size)  # step 4
           .batch(batch_size)                 # step 5
           .prefetch(1)                       # step 6: make sure one batch is always ready to serve
           )

iterator = dataset.make_one_shot_iterator()
patches = iterator.get_next()  # shape [None, patch_size, patch_size, 3]

sess = tf.Session()
res = sess.run(patches)
The functions parse_fn and get_patches are defined as follows:
def parse_fn(filename):
    """Decode the jpeg image from the filename and convert to [0, 1]."""
    image_string = tf.read_file(filename)

    # Don't use tf.image.decode_image, or the output shape will be undefined
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)

    # This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image_decoded, tf.float32)

    return image


def get_patches(image, num_patches=100, patch_size=16):
    """Get `num_patches` random crops from the image."""
    patches = []
    for i in range(num_patches):
        patch = tf.random_crop(image, [patch_size, patch_size, 3])
        patches.append(patch)

    patches = tf.stack(patches)
    assert patches.get_shape().dims == [num_patches, patch_size, patch_size, 3]

    return patches
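Note that the code above targets the TensorFlow 1.x API (tf.contrib.data.unbatch, make_one_shot_iterator, tf.Session). On TensorFlow 2.x, roughly the same pipeline can be sketched as follows, assuming parse_fn is rewritten with tf.io.read_file and get_patches with tf.image.random_crop:

dataset = (tf.data.Dataset.from_tensor_slices(filenames)
           .shuffle(buffer_size=num_samples)
           .map(parse_fn, num_parallel_calls=num_parallel_calls)
           .map(get_patches_fn, num_parallel_calls=num_parallel_calls)
           .unbatch()                         # tf.contrib.data.unbatch() became a Dataset method
           .shuffle(buffer_size=buffer_size)
           .batch(batch_size)
           .prefetch(1))

for patches in dataset:  # eager iteration replaces the one-shot iterator and tf.Session
    ...                  # patches has shape [batch_size, patch_size, patch_size, 3]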
Source: https://stackoverflow.com/questions/48777889/tf-data-api-how-to-efficiently-sample-small-patches-from-images