Online oversampling in Tensorflow input pipeline

Submitted 2019-12-02 05:35:21

After doing a bit more research, I found a solution for what I wanted to do. What I forgot to mention is that the code in my question is followed by a batching method such as tf.train.batch() or tf.train.batch_join().

These functions take an enqueue_many argument. When it is set to True, each tensor you enqueue is treated as a group of examples of arbitrary size rather than as a single example, and the batching op flattens these groups before assembling fixed-size batches.

The following piece of code does the trick for me:

for thread_id in range(num_preprocess_threads):

    # Parse a serialized Example proto to extract the image and metadata.
    image_buffer, label_index = parse_example_proto(
            example_serialized)

    image = image_preprocessing(image_buffer, bbox, False, thread_id)

    # Convert 3D tensor of shape [height, width, channels] to 
    # a 4D tensor of shape [batch_size, height, width, channels]
    image = tf.expand_dims(image, 0)

    # Define the boolean predicate to be true when the class label is 1
    pred = tf.equal(label_index, tf.convert_to_tensor([1]))
    pred = tf.reshape(pred, [])

    # Duplicate the example oversample_factor times when pred is true,
    # otherwise pass it through unchanged. Note: tf.concat takes the list
    # of tensors first and the axis second (TF 1.x signature).
    oversample_factor = 2
    r_image = tf.cond(pred,
                      lambda: tf.concat([image] * oversample_factor, axis=0),
                      lambda: image)
    r_label = tf.cond(pred,
                      lambda: tf.concat([label_index] * oversample_factor, axis=0),
                      lambda: label_index)
    images_and_labels.append([r_image, r_label])

images, label_batch = tf.train.shuffle_batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=2 * num_preprocess_threads * batch_size,
    min_after_dequeue=1 * num_preprocess_threads * batch_size,
    enqueue_many=True)
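The tf.cond above duplicates an example when its label is 1 and passes it through unchanged otherwise. The same per-example logic, sketched in plain Python for clarity (the oversample function is illustrative; oversample_factor of 2 matches the snippet):

```python
def oversample(example, label, oversample_factor=2):
    """Repeat the example oversample_factor times if label == 1, else keep one copy."""
    if label == 1:
        return [example] * oversample_factor, [label] * oversample_factor
    return [example], [label]

# A positive example is emitted twice; a negative one passes through once.
imgs, lbls = oversample('img1', 1)
```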