Class weights for balancing data in TensorFlow Object Detection API

Submitted by Deadly on 2019-12-18 12:38:25

Question


I'm fine-tuning an SSD object detector using the TensorFlow Object Detection API on the Open Images Dataset. My training data contains imbalanced classes, e.g.

  1. top (5K images)
  2. dress (50K images)
  3. etc...

I would like to add class weights to the classification loss to improve performance. How do I do that? The following section of the config file seems relevant:

loss {
  classification_loss {
    weighted_sigmoid {
    }
  }
  localization_loss {
    weighted_smooth_l1 {
    }
  }
 ...
  classification_weight: 1.0
  localization_weight: 1.0
}

How can I change the config file to add classification loss weights per class? If not through a config file, what's a recommended way of doing that?


Answer 1:


The API expects a weight for each object (bbox) directly in the annotation files. Given this requirement, the options for applying class weights seem to be:

1) If you have a custom dataset, modify the annotations of each object (bbox) to include the weight field as 'object/weight'.

2) If you don't want to modify the annotations, regenerate only the tf_records file so that it includes the bbox weights.

3) Modify the code of the API (this seemed quite tricky to me).
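
Options 1 and 2 both need a weight per box, derived from the box's class. A minimal sketch of one possible choice, inverse-frequency weights scaled so that the rarest class gets 1.0. The helper name is hypothetical and the counts are the image counts quoted in the question; other weighting schemes are equally valid:

```python
def inverse_frequency_weights(counts):
    """Return {class_name: weight}, with the rarest class at weight 1.0."""
    rarest = min(counts.values())
    # A class that is k times more frequent than the rarest one gets weight 1/k.
    return {name: rarest / n for name, n in counts.items()}


# Image counts taken from the question (assumed to reflect box counts as well).
counts = {'top': 5000, 'dress': 50000}
class_weights = inverse_frequency_weights(counts)
print(class_weights)
```

With these counts the result is 1.0 for "top" and 0.1 for "dress", so every "dress" box contributes a tenth as much as a "top" box.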

I went with #2. Below is the code (written against the TensorFlow 1.x API) that generates such a weighted tf_records file for a custom dataset with two classes ("top", "dress") and weights (1.0, 0.1), given a folder of XML annotations:

import io
import glob
import hashlib
import random
import xml.etree.ElementTree as ET
import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util

# Define the class names and their weight
class_names = ['top', 'dress', ...]
class_weights = [1.0, 0.1, ...]

def create_example(xml_file):

        tree = ET.parse(xml_file)
        root = tree.getroot()
        image_name = root.find('filename').text
        image_path = root.find('path').text
        file_name = image_name.encode('utf8')
        size=root.find('size')
        width = int(size[0].text)
        height = int(size[1].text)
        xmin = []
        ymin = []
        xmax = []
        ymax = []
        classes = []
        classes_text = []
        truncated = []
        poses = []
        difficult_obj = []
        weights = [] # Important line

        for member in root.findall('object'):

           # Resolve the class first so that an unknown label is skipped
           # before anything is appended and all lists stay the same length.
           class_name = member[0].text
           if class_name not in class_names:
               print('E: class not recognized: {}'.format(class_name))
               continue
           class_id = class_names.index(class_name)

           # Box corners normalized to [0, 1]; member[4] is the bounding-box element
           xmin.append(float(member[4][0].text) / width)
           ymin.append(float(member[4][1].text) / height)
           xmax.append(float(member[4][2].text) / width)
           ymax.append(float(member[4][3].text) / height)
           difficult_obj.append(0)

           weights.append(class_weights[class_id])
           classes_text.append(class_name.encode('utf8'))
           classes.append(class_id + 1)  # label ids start at 1 ('top' -> 1, 'dress' -> 2)

           truncated.append(0)
           poses.append('Unspecified'.encode('utf8'))

        full_path = image_path 
        with tf.gfile.GFile(full_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
        if image.format != 'JPEG':
           raise ValueError('Image format not JPEG')
        key = hashlib.sha256(encoded_jpg).hexdigest()

        #create TFRecord Example
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(file_name),
            'image/source_id': dataset_util.bytes_feature(file_name),
            'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
            'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
            'image/object/truncated': dataset_util.int64_list_feature(truncated),
            'image/object/view': dataset_util.bytes_list_feature(poses),
            'image/object/weight': dataset_util.float_list_feature(weights) # Important line
        })) 
        return example  

def main(_):

    weighted_tf_records_output = 'name_of_records_file.record' # output file
    annotations_path = '/path/to/annotations/folder/*.xml' # input annotations

    # A plain glob is enough to list the annotation files; no TF session is needed
    xml_files = glob.glob(annotations_path)
    random.shuffle(xml_files)

    writer_train = tf.python_io.TFRecordWriter(weighted_tf_records_output)
    for xml_file in xml_files:
      print('-> Processing {}'.format(xml_file))
      example = create_example(xml_file)
      writer_train.write(example.SerializeToString())

    writer_train.close()
    print('-> Successfully converted dataset to TFRecord.')


if __name__ == '__main__':
    tf.app.run()

This script only handles XML annotations laid out as above. For other kinds of annotations the code will be very similar, but the parsing part has to be adapted.
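
The script reads fields by position (for example member[4][0]), which ties it to one exact element order. A sketch of the same parsing step done by tag name instead, assuming Pascal VOC-style tag names (name, bndbox, xmin, ...); the sample XML and the parse_boxes helper are made up for illustration:

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation with Pascal VOC-style tag names.
SAMPLE_XML = """
<annotation>
  <filename>img1.jpg</filename>
  <size><width>200</width><height>100</height><depth>3</depth></size>
  <object>
    <name>top</name>
    <bndbox><xmin>20</xmin><ymin>10</ymin><xmax>100</xmax><ymax>50</ymax></bndbox>
  </object>
</annotation>
"""


def parse_boxes(xml_text):
    """Return one dict per object, with box corners normalized to [0, 1]."""
    root = ET.fromstring(xml_text)
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    boxes = []
    for obj in root.findall('object'):
        box = obj.find('bndbox')
        boxes.append({
            'name': obj.find('name').text,
            'xmin': float(box.find('xmin').text) / width,
            'ymin': float(box.find('ymin').text) / height,
            'xmax': float(box.find('xmax').text) / width,
            'ymax': float(box.find('ymax').text) / height,
        })
    return boxes


boxes = parse_boxes(SAMPLE_XML)
print(boxes)
```

Looking fields up by name keeps working if an annotation tool writes the child elements in a different order, at the cost of assuming the tag names.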




Answer 2:


The Object Detection API losses are defined in: https://github.com/tensorflow/models/blob/master/research/object_detection/core/losses.py

In particular, the following loss classes have been implemented:

Classification losses:

  1. WeightedSigmoidClassificationLoss
  2. SigmoidFocalClassificationLoss
  3. WeightedSoftmaxClassificationLoss
  4. WeightedSoftmaxClassificationAgainstLogitsLoss
  5. BootstrappedSigmoidClassificationLoss

Localization losses:

  1. WeightedL2LocalizationLoss
  2. WeightedSmoothL1LocalizationLoss
  3. WeightedIOULocalizationLoss

The weights passed to these losses are used to balance anchors (prior boxes) and have shape [batch_size, num_anchors]; hard negative mining is available in addition. Alternatively, the focal loss down-weights well-classified examples and focuses on the hard ones.
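
To make the down-weighting concrete, here is a minimal pure-Python sketch of the usual sigmoid focal loss formula for a single logit. It illustrates the idea only and is not a copy of the API's SigmoidFocalClassificationLoss; the function name and the default values of alpha and gamma are assumptions chosen for the example:

```python
import math


def sigmoid_focal_loss(logit, target, alpha=0.25, gamma=2.0):
    """Focal loss for one binary prediction; target is 1 (positive) or 0 (negative)."""
    p = 1.0 / (1.0 + math.exp(-logit))           # predicted probability of the positive class
    p_t = p if target == 1 else 1.0 - p          # probability assigned to the true label
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma shrinks the loss of examples that are already well classified.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = sigmoid_focal_loss(4.0, 1)    # confident and correct
hard = sigmoid_focal_loss(-4.0, 1)   # confident and wrong
print(easy, hard)
```

With gamma set to 0 the modulating factor disappears and the expression reduces to an alpha-weighted cross-entropy, so gamma controls how strongly easy examples are suppressed.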

The primary class imbalance is due to many more negative examples (bounding boxes without objects of interest) in comparison to very few positive examples (bounding boxes with object classes). That seems to be the reason why class imbalance within positive examples (i.e. unequal distribution of positive class labels) is not implemented as part of object detection losses.



Source: https://stackoverflow.com/questions/51862997/class-weights-for-balancing-data-in-tensorflow-object-detection-api
