Question
I have an input pipeline similar to the one in the Convolutional Neural Network tutorial. My dataset is imbalanced and I want to use minority oversampling to try to deal with this. Ideally, I want to do this "online", i.e. I don't want to duplicate data samples on disk.
Essentially, what I want to do is duplicate individual examples (with some probability) based on their label. I have been reading a bit about control flow in TensorFlow, and it seems that tf.cond(pred, fn1, fn2) is the way to go. I am struggling to find the right parameterisation, though: fn1 and fn2 must return lists of tensors of the same length, whereas I want one branch to return more copies of the example than the other.
This is roughly what I have so far:
image = image_preprocessing(image_buffer, bbox, False, thread_id)
pred = tf.reshape(tf.equal(label, tf.convert_to_tensor([2])), [])
r_image = tf.cond(pred, lambda: [tf.identity(image), tf.identity(image)], lambda: [tf.identity(image),])
r_label = tf.cond(pred, lambda: [tf.identity(label), tf.identity(label)], lambda: [tf.identity(label),])
However, as mentioned above, this raises an error:
ValueError: fn1 and fn2 must return the same number of results.
Any ideas?
P.S.: this is my first Stack Overflow question. Any feedback on my question is appreciated.
Answer 1:
After doing a bit more research, I found a solution for what I wanted to do. What I forgot to mention is that the code in my question is followed by a batching function such as batch() or batch_join().
These functions take an argument, enqueue_many, that makes each enqueued tensor represent a batch of examples (indexed by its first dimension) rather than a single example. With enqueue_many set to True, a tensor can therefore hold one example or several, and the queue splits it into individual examples.
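The effect of enqueue_many=True can be modelled with a toy FIFO in plain Python. This is a hypothetical class (`ManyQueue`) for illustration only, not TensorFlow's queue implementation, and it ignores shuffling and threading. It shows that enqueuing a batch of one or a batch of two is just a matter of how many examples get added, and that dequeued batches always have a fixed size.

```python
from collections import deque
from typing import Deque, List, Sequence


class ManyQueue:
    """Hypothetical toy model of enqueue_many=True semantics."""

    def __init__(self) -> None:
        self._items: Deque[object] = deque()

    def enqueue_many(self, batch: Sequence[object]) -> None:
        # Split along the first dimension: a length-2 batch adds two examples.
        self._items.extend(batch)

    def dequeue_many(self, n: int) -> List[object]:
        # Return a fixed-size batch regardless of how examples were enqueued.
        if n > len(self._items):
            raise ValueError("not enough elements queued")
        return [self._items.popleft() for _ in range(n)]


q = ManyQueue()
q.enqueue_many(["img0"])          # not oversampled: batch of 1
q.enqueue_many(["img1", "img1"])  # oversampled: batch of 2
q.enqueue_many(["img2"])
print(q.dequeue_many(4))  # ['img0', 'img1', 'img1', 'img2']
```

This is why tf.cond only needs to return a single tensor per output: the branches can differ in the size of the first dimension, and the queue takes care of unpacking it.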
The following piece of code does the trick for me:
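import tensorflow as tf

# parse_example_proto, image_preprocessing, example_serialized, bbox,
# num_preprocess_threads and batch_size come from the tutorial's input
# pipeline code and are assumed to be defined already.
images_and_labels = []
for thread_id in range(num_preprocess_threads):
    # Parse a serialized Example proto to extract the image and metadata.
    image_buffer, label_index = parse_example_proto(
        example_serialized)
    image = image_preprocessing(image_buffer, bbox, False, thread_id)
    # Convert the 3D tensor of shape [height, width, channels] to a
    # 4D tensor of shape [1, height, width, channels], i.e. a batch of one,
    # so that it can be enqueued with enqueue_many=True.
    image = tf.expand_dims(image, 0)

    # Define the boolean predicate to be true when the class label is 1.
    # label_index is assumed to have shape [1] (already a batch of one label).
    pred = tf.equal(label_index, tf.convert_to_tensor([1]))
    pred = tf.reshape(pred, [])

    oversample_factor = 2
    # Each branch returns a single tensor; only its first dimension differs.
    # tf.concat(values, axis) is the TF >= 1.0 signature; earlier releases
    # used tf.concat(axis, values), e.g. tf.concat(0, [image] * 2).
    r_image = tf.cond(pred,
                      lambda: tf.concat([image] * oversample_factor, 0),
                      lambda: image)
    r_label = tf.cond(pred,
                      lambda: tf.concat([label_index] * oversample_factor, 0),
                      lambda: label_index)
    images_and_labels.append([r_image, r_label])

images, label_batch = tf.train.shuffle_batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=2 * num_preprocess_threads * batch_size,
    min_after_dequeue=1 * num_preprocess_threads * batch_size,
    enqueue_many=True)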
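When choosing oversample_factor, it can help to check how far it actually shifts the class balance. The sketch below is plain Python with hypothetical helper names; it assumes a two-way split (the oversampled class versus everything else) and that every other example is emitted exactly once. If the class makes up a fraction p of the data and each of its examples is emitted f times, its share of the stream becomes f·p / (f·p + 1 − p).

```python
def class_share_after_oversampling(p: float, factor: float) -> float:
    """Hypothetical helper: fraction of the stream in the oversampled class.

    p is the class's original fraction; each of its examples is emitted
    `factor` times (in expectation) and every other example once.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if factor <= 0:
        raise ValueError("factor must be positive")
    return factor * p / (factor * p + (1.0 - p))


def factor_for_target_share(p: float, target: float) -> float:
    """Hypothetical helper: factor that moves a class from fraction p to `target`."""
    if not (0.0 < p < 1.0 and 0.0 < target < 1.0):
        raise ValueError("p and target must be in (0, 1)")
    # Solve target = f p / (f p + 1 - p) for f.
    return target * (1.0 - p) / (p * (1.0 - target))


print(round(class_share_after_oversampling(0.1, 2), 4))  # 0.1818
print(round(factor_for_target_share(0.1, 0.5), 6))  # 9.0
```

For example, doubling a class that makes up 10% of the data only raises it to about 18% of the examples entering the queue; reaching a 50/50 balance would take a factor of 9. Larger factors also mean more exact duplicates in each shuffled batch.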
Source: https://stackoverflow.com/questions/38484075/online-oversampling-in-tensorflow-input-pipeline