Online oversampling in Tensorflow input pipeline

愿得一人 asked 2021-01-24 22:00

I have an input pipeline similar to the one in the Convolutional Neural Network tutorial. My dataset is imbalanced and I want to use minority oversampling to try to deal with this, ideally within the input pipeline itself.

1 Answer
  • 2021-01-24 22:31

    After doing a bit more research, I found a solution for what I wanted to do. What I forgot to mention is that the code mentioned in my question is followed by a batch method, such as batch() or batch_join().

    These functions take an argument, enqueue_many, that lets you enqueue tensors that already contain a (variable-sized) batch of examples rather than a single example. It should be set to True.
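    As a minimal sketch of what enqueue_many changes (TF 1.x queue API; the image and label tensors here are toy placeholders, not the ones from my pipeline), batch() accepts inputs whose first dimension already indexes examples:

    import tensorflow as tf

    image = tf.random_uniform([32, 32, 3])     # one preprocessed example
    label = tf.constant(1, dtype=tf.int32)

    # Add a leading batch dimension of size 1 so that each enqueue operation
    # can carry a variable number of examples.
    images = tf.expand_dims(image, 0)          # shape [1, 32, 32, 3]
    labels = tf.expand_dims(label, 0)          # shape [1]

    # With enqueue_many=True, batch() treats the first dimension of each
    # input tensor as the example axis instead of enqueuing the tensor whole.
    image_batch, label_batch = tf.train.batch(
        [images, labels],
        batch_size=16,
        capacity=64,
        enqueue_many=True)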

    The following piece of code does the trick for me:

    images_and_labels = []
    for thread_id in range(num_preprocess_threads):

        # Parse a serialized Example proto to extract the image and metadata.
        image_buffer, label_index = parse_example_proto(
                example_serialized)

        image = image_preprocessing(image_buffer, bbox, False, thread_id)

        # Convert the 3D tensor of shape [height, width, channels] into a
        # 4D tensor of shape [1, height, width, channels] so that a variable
        # number of examples can be enqueued at once.
        image = tf.expand_dims(image, 0)

        # Define the boolean predicate to be true when the class label is 1.
        pred = tf.equal(label_index, tf.convert_to_tensor([1]))
        pred = tf.reshape(pred, [])

        # Duplicate minority-class examples along the batch dimension;
        # majority-class examples pass through unchanged.
        oversample_factor = 2
        r_image = tf.cond(pred,
                          lambda: tf.concat([image] * oversample_factor, 0),
                          lambda: image)
        r_label = tf.cond(pred,
                          lambda: tf.concat([label_index] * oversample_factor, 0),
                          lambda: label_index)
        images_and_labels.append([r_image, r_label])

    images, label_batch = tf.train.shuffle_batch_join(
        images_and_labels,
        batch_size=batch_size,
        capacity=2 * num_preprocess_threads * batch_size,
        min_after_dequeue=1 * num_preprocess_threads * batch_size,
        enqueue_many=True)
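
    For anyone who wants to sanity-check the duplication step in isolation, here is a small self-contained sketch (again TF 1.x, with made-up toy tensors standing in for the parsed example) showing that a minority-class example comes out oversample_factor times:

    import tensorflow as tf

    oversample_factor = 2
    image = tf.random_uniform([1, 32, 32, 3])        # one example, with a batch dim
    label_index = tf.constant([1], dtype=tf.int32)   # minority class in this toy case

    # True when the example belongs to the minority class (label 1).
    pred = tf.reshape(tf.equal(label_index, 1), [])

    # Duplicate minority examples along the batch dimension, pass others through.
    r_image = tf.cond(pred,
                      lambda: tf.concat([image] * oversample_factor, 0),
                      lambda: image)
    r_label = tf.cond(pred,
                      lambda: tf.concat([label_index] * oversample_factor, 0),
                      lambda: label_index)

    with tf.Session() as sess:
        img_val, lbl_val = sess.run([r_image, r_label])
        print(img_val.shape, lbl_val)                # (2, 32, 32, 3) [1 1]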
    