Class weights for balancing data in TensorFlow Object Detection API

醉梦人生 2020-12-29 13:07

I'm fine-tuning an SSD object detector using the TensorFlow Object Detection API on the Open Images Dataset. My training data contains imbalanced classes, e.g.

  1. top (5K i
2 answers
  • 2020-12-29 13:17

    The Object Detection API losses are defined in: https://github.com/tensorflow/models/blob/master/research/object_detection/core/losses.py

    In particular, the following loss classes have been implemented:

    Classification losses:

    1. WeightedSigmoidClassificationLoss
    2. SigmoidFocalClassificationLoss
    3. WeightedSoftmaxClassificationLoss
    4. WeightedSoftmaxClassificationAgainstLogitsLoss
    5. BootstrappedSigmoidClassificationLoss

    Localization losses:

    1. WeightedL2LocalizationLoss
    2. WeightedSmoothL1LocalizationLoss
    3. WeightedIOULocalizationLoss

    The weight parameters are used to balance anchors (prior boxes) and have shape [batch_size, num_anchors]; they are applied in addition to hard negative mining. Alternatively, the focal loss down-weights well-classified examples and focuses on the hard examples.
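
    To make the down-weighting concrete, here is a minimal pure-Python sketch of a sigmoid focal loss for a single logit. It is not the implementation from losses.py; the function names are hypothetical, and gamma=2.0 and alpha=0.25 are assumed as commonly used values.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_cross_entropy(logit, target):
    """Plain binary cross-entropy for one logit and a 0/1 target."""
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p
    return -math.log(p_t)

def sigmoid_focal_loss(logit, target, gamma=2.0, alpha=0.25):
    """Cross-entropy scaled by alpha_t * (1 - p_t) ** gamma."""
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p              # probability of the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha  # positive/negative balance
    return alpha_t * (1.0 - p_t) ** gamma * -math.log(p_t)

easy = 4.0   # logit of a well-classified positive (p ~ 0.98)
hard = -1.0  # logit of a misclassified positive (p ~ 0.27)
for name, logit in (("easy", easy), ("hard", hard)):
    ce = sigmoid_cross_entropy(logit, 1)
    fl = sigmoid_focal_loss(logit, 1)
    print("{}: cross-entropy={:.4f} focal={:.6f} ratio={:.6f}".format(name, ce, fl, fl / ce))
```

    The ratio column shows how much of the cross-entropy survives: the easy example is scaled down by several orders of magnitude, while the hard example keeps a much larger share of its loss.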

    The primary class imbalance comes from the many negative examples (bounding boxes without objects of interest) compared with the very few positive examples (bounding boxes with object classes). That seems to be why class imbalance within the positive examples (i.e. an unequal distribution of positive class labels) is not handled by the object detection losses.

  • 2020-12-29 13:22

    The API expects a weight for each object (bbox) directly in the annotation files. Given this requirement, the options for using class weights seem to be:

    1) If you have a custom dataset, you can modify the annotations of each object (bbox) to include the weight field as 'object/weight'.

    2) If you don't want to modify the annotations, you can recreate only the TFRecord file so that it includes the weights of the bboxes.

    3) Modify the code of the API (this seemed quite tricky to me).

    I went with #2, so here is the code to generate such a weighted TFRecord file for a custom dataset with two classes ("top", "dress") and weights (1.0, 0.1), given a folder of XML annotations:

    import io
    import glob
    import hashlib
    import random
    import xml.etree.ElementTree as ET
    import tensorflow as tf
    from PIL import Image
    from object_detection.utils import dataset_util
    
    # Define the class names and their weight
    class_names = ['top', 'dress', ...]
    class_weights = [1.0, 0.1, ...]
    
    def create_example(xml_file):
    
            tree = ET.parse(xml_file)
            root = tree.getroot()
            image_name = root.find('filename').text
            image_path = root.find('path').text
            file_name = image_name.encode('utf8')
            size=root.find('size')
            width = int(size[0].text)
            height = int(size[1].text)
            xmin = []
            ymin = []
            xmax = []
            ymax = []
            classes = []
            classes_text = []
            truncated = []
            poses = []
            difficult_obj = []
            weights = [] # Important line
    
            for member in root.findall('object'):
    
                # Assumes the annotation layout used here: child 0 is the class name, child 4 is the bounding box.
                class_name = member[0].text
                if class_name not in class_names:
                    # Skip unknown classes so that all per-object lists stay the same length.
                    print('E: class not recognized: {}'.format(class_name))
                    continue
                class_id = class_names.index(class_name)
    
                # Box coordinates are normalized by the image size.
                xmin.append(float(member[4][0].text) / width)
                ymin.append(float(member[4][1].text) / height)
                xmax.append(float(member[4][2].text) / width)
                ymax.append(float(member[4][3].text) / height)
    
                classes_text.append(class_name.encode('utf8'))
                classes.append(class_id + 1)  # label ids start at 1: 'top' -> 1, 'dress' -> 2
                weights.append(class_weights[class_id])  # per-box weight taken from the class
    
                difficult_obj.append(0)
                truncated.append(0)
                poses.append('Unspecified'.encode('utf8'))
    
            # Read the raw image bytes; the 'path' field of the XML must point to the image file.
            # tf.io.gfile.GFile exists in TF 2.x and in later TF 1.x releases.
            with tf.io.gfile.GFile(image_path, 'rb') as fid:
                encoded_jpg = fid.read()
            encoded_jpg_io = io.BytesIO(encoded_jpg)
            image = Image.open(encoded_jpg_io)
            if image.format != 'JPEG':
                raise ValueError('Image format not JPEG')
            key = hashlib.sha256(encoded_jpg).hexdigest()
    
            #create TFRecord Example
            example = tf.train.Example(features=tf.train.Features(feature={
                'image/height': dataset_util.int64_feature(height),
                'image/width': dataset_util.int64_feature(width),
                'image/filename': dataset_util.bytes_feature(file_name),
                'image/source_id': dataset_util.bytes_feature(file_name),
                'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
                'image/encoded': dataset_util.bytes_feature(encoded_jpg),
                'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
                'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
                'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
                'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
                'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
                'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
                'image/object/class/label': dataset_util.int64_list_feature(classes),
                'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
                'image/object/truncated': dataset_util.int64_list_feature(truncated),
                'image/object/view': dataset_util.bytes_list_feature(poses),
                'image/object/weight': dataset_util.float_list_feature(weights) # Important line
            })) 
            return example  
    
    def main():
    
        weighted_tf_records_output = 'name_of_records_file.record' # output file
        annotations_path = '/path/to/annotations/folder/*.xml' # input annotations
    
        # Collect the annotation files with glob (no TF session needed) and shuffle them.
        xml_files = glob.glob(annotations_path)
        random.shuffle(xml_files)
    
        # tf.io.TFRecordWriter exists in TF 2.x and in later TF 1.x releases.
        with tf.io.TFRecordWriter(weighted_tf_records_output) as writer_train:
            for xml_file in xml_files:
                print('-> Processing {}'.format(xml_file))
                example = create_example(xml_file)
                writer_train.write(example.SerializeToString())
    
        print('-> Successfully converted dataset to TFRecord.')
    
    
    if __name__ == '__main__':
        main()
    

    If your annotations are in another format, the code will be very similar, but this script will not work as-is.
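
    The weights (1.0, 0.1) in the script above are set by hand. As a rough sketch of one way to derive them instead, the snippet below counts objects per class and uses inverse frequency, so the rarest class gets 1.0. This heuristic, the helper names, and the toy XML strings are all assumptions for illustration; it also assumes each object has a `name` child, as in Pascal VOC style annotations.

```python
import xml.etree.ElementTree as ET
from collections import Counter

def count_classes(xml_strings):
    """Count <object><name> occurrences across XML annotation documents."""
    counts = Counter()
    for xml_text in xml_strings:
        root = ET.fromstring(xml_text)
        for obj in root.findall('object'):
            counts[obj.find('name').text] += 1
    return counts

def inverse_frequency_weights(counts):
    """Weight each class by min_count / count, so the rarest class gets 1.0."""
    rarest = min(counts.values())
    return {name: rarest / n for name, n in counts.items()}

# Hypothetical toy annotations: one 'top' box and three 'dress' boxes.
samples = [
    "<annotation><object><name>top</name></object>"
    "<object><name>dress</name></object></annotation>",
    "<annotation><object><name>dress</name></object>"
    "<object><name>dress</name></object></annotation>",
]
counts = count_classes(samples)
weights = inverse_frequency_weights(counts)
print(counts)
print(weights)
```

    With real data you would parse the files on disk with ET.parse, as the script above does, and copy the resulting values into class_weights in the same order as class_names.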
