Keras: weighted binary crossentropy

佛祖请我去吃肉 2021-01-31 08:39

I tried to implement a weighted binary crossentropy with Keras, but I am not sure if the code is correct. The training output seems to be a bit confusing. After a few epochs I j

Answers
  • 2021-01-31 09:10

    I think using class_weight in model.fit is not correct here. In {0: 0.11, 1: 0.89}, 0 is the class index, not the 0 class. From the Keras documentation (https://keras.io/models/sequential/): "class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to 'pay more attention' to samples from an under-represented class."

  • 2021-01-31 09:23

    You can use scikit-learn to calculate the weight for each class automatically, like this:

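    # Imports
    import numpy as np
    from sklearn.utils import class_weight
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    
    # Placeholder data for illustration: 100 features, roughly 11% positive labels
    x_train = np.random.rand(1000, 100)
    y_train = (np.random.rand(1000) < 0.11).astype(int)
    
    # Example model
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(1, activation='sigmoid'))
    
    # Use binary crossentropy loss
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    
    # Calculate the weights for each class so that we can balance the data
    # (recent scikit-learn versions require classes and y as keyword arguments)
    classes = np.unique(y_train)
    weights = class_weight.compute_class_weight(class_weight='balanced',
                                                classes=classes,
                                                y=y_train)
    
    # model.fit expects a dict {class_index: weight}, not a NumPy array
    weights = {int(c): w for c, w in zip(classes, weights)}
    
    # Add the class weights to the training
    model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)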
    

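    Note that class_weight.compute_class_weight() returns a NumPy array like this: [2.57569845 0.68250928], while model.fit expects class_weight to be a dict mapping class indices to weights, so convert it first (e.g. dict(enumerate(weights)) for labels 0 and 1).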
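    For reference, scikit-learn documents the 'balanced' mode as n_samples / (n_classes * np.bincount(y)). The minimal NumPy sketch below reproduces that formula and returns the dict that model.fit wants. balanced_class_weights is a hypothetical helper, and it assumes integer labels 0..K-1 that each occur at least once; the 11% positive rate in the example is illustrative.

```python
import numpy as np

def balanced_class_weights(y):
    """Replicate the 'balanced' heuristic: n_samples / (n_classes * bincount(y)).

    Assumes integer labels 0..K-1 that all occur at least once.
    """
    y = np.asarray(y, dtype=int)
    counts = np.bincount(y)
    weights = len(y) / (len(counts) * counts)
    # Keras' class_weight argument wants a {class_index: weight} dict
    return {int(c): float(w) for c, w in enumerate(weights)}

y_example = np.array([0] * 89 + [1] * 11)  # illustrative labels: 11% ones
print(balanced_class_weights(y_example))
```

    The rarer class gets the larger weight, and count * weight is the same for every class, so each class contributes equally to the weighted loss.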

  • 2021-01-31 09:27

    Normally, the minority class should get the higher class weight, so it's better to use one_weight=0.89, zero_weight=0.11 (by the way, you can also pass class_weight={0: 0.11, 1: 0.89} to model.fit, as suggested in the comment).

    Under class imbalance, your model sees many more zeros than ones, so it also learns to predict more zeros than ones, because the training loss can be minimized that way. That's also why you're seeing an accuracy close to the proportion 0.11. If you take an average over the model's predictions, it should be very close to zero.

    The purpose of using class weights is to change the loss function so that the training loss cannot be minimized by the "easy solution" (i.e., predicting zeros), and that's why it'll be better to use a higher weight for ones.
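    To see the "easy solution" effect numerically, consider a degenerate model that predicts the same probability p for every sample. The NumPy sketch below assumes 11% ones (matching the 0.11/0.89 split discussed here; constant_prediction_loss is a made-up name for illustration). Unweighted binary cross entropy is minimized at p = 0.11, the base rate, while the 0.11/0.89 weighting moves the optimum to p = 0.5, so predicting "mostly zeros" no longer pays off.

```python
import numpy as np

def constant_prediction_loss(p, frac_ones, zero_weight=1.0, one_weight=1.0):
    """Mean weighted BCE when the model predicts the same probability p for every sample."""
    return -(one_weight * frac_ones * np.log(p)
             + zero_weight * (1.0 - frac_ones) * np.log(1.0 - p))

frac_ones = 0.11                        # assumed share of ones in the data
grid = np.linspace(0.001, 0.999, 999)   # candidate constant predictions, step 0.001

unweighted = constant_prediction_loss(grid, frac_ones)
weighted = constant_prediction_loss(grid, frac_ones, zero_weight=0.11, one_weight=0.89)

# The unweighted optimum sits at the base rate; the weighted one moves to 0.5
best_unweighted = grid[np.argmin(unweighted)]
best_weighted = grid[np.argmin(weighted)]
print(round(best_unweighted, 3), round(best_weighted, 3))  # 0.11 0.5
```

    In closed form the optimal constant is p = one_weight * n_ones / (one_weight * n_ones + zero_weight * n_zeros), which equals 0.5 whenever the weights exactly offset the class counts.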

    Note that the best weights are not necessarily 0.89 and 0.11. Sometimes you might have to try something like taking logarithms or square roots (or any weights satisfying one_weight > zero_weight) to make it work.
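    One way to read the "logarithms or square roots" suggestion is to start from the inverse class frequencies and compress them, which keeps one_weight > zero_weight but shrinks the ratio between the two. The sketch below assumes that reading; smoothed_class_weights is a hypothetical helper, the 890/110 counts are illustrative, and log1p is used instead of a plain log (an assumption on my part) so that the compressed weights stay positive.

```python
import numpy as np

def smoothed_class_weights(num_zeros, num_ones, mode="inverse"):
    """Return (zero_weight, one_weight) from class counts.

    mode="inverse": total / count (plain inverse frequency)
    mode="sqrt":    square root of the inverse frequency
    mode="log1p":   log(1 + inverse frequency)
    """
    total = num_zeros + num_ones
    inv = np.array([total / num_zeros, total / num_ones], dtype=float)
    if mode == "sqrt":
        inv = np.sqrt(inv)
    elif mode == "log1p":
        inv = np.log1p(inv)
    elif mode != "inverse":
        raise ValueError(f"unknown mode: {mode}")
    return float(inv[0]), float(inv[1])

for mode in ("inverse", "sqrt", "log1p"):
    zero_w, one_w = smoothed_class_weights(890, 110, mode)
    print(mode, round(zero_w, 3), round(one_w, 3), "ratio", round(one_w / zero_w, 2))
```

    Which variant works best is an empirical question; the compressed versions are just gentler than full inverse-frequency weighting.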

  • 2021-01-31 09:35

    If you need a weighted validation loss with different weights than the training loss, you can use the validation_data parameter of tf.keras.Model.fit(): pass your validation set as a tuple of NumPy arrays (x_val, y_val, val_sample_weights) containing the validation data, the labels, and a weight for each sample.

    Note that with this technique you have to map each sample to its weight yourself (here, by class).

    See the TensorFlow documentation: https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
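    Mapping samples to weights by class can be a simple lookup. The NumPy sketch below assumes integer labels 0..K-1 with a weight for every class; per_sample_weights, the example labels, and the {0: 0.5, 1: 2.0} weights are all hypothetical. The fit call is shown only as a comment, using the (x_val, y_val, val_sample_weights) form of validation_data described above.

```python
import numpy as np

def per_sample_weights(labels, class_weights):
    """Turn a {class_index: weight} dict into one weight per sample.

    Assumes integer class labels 0..K-1 and a weight for every class.
    """
    labels = np.asarray(labels).astype(int).ravel()
    lookup = np.array([class_weights[c] for c in range(len(class_weights))],
                      dtype=np.float32)
    return lookup[labels]

# Illustrative validation labels and weights (not from the original answer)
y_val = np.array([0, 1, 0, 0, 1])
val_weights = per_sample_weights(y_val, {0: 0.5, 1: 2.0})
print(val_weights)

# The weights then go in as the third element of validation_data, e.g.
# model.fit(x_train, y_train, class_weight={0: 0.11, 1: 0.89},
#           validation_data=(x_val, y_val, val_weights))
```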

  • 2021-01-31 09:36

    Using class_weight in model.fit works slightly differently: it weights individual samples rather than computing a weighted loss inside the loss function.

    I also found that class_weight, as well as sample_weight, is ignored in TF 2.0.0 when x is passed to model.fit as a tf.data.Dataset or a generator. I believe it's fixed in TF 2.1.0+, though.

    Here is my weighted binary cross entropy function for multi-hot encoded labels.

    import tensorflow as tf
    import tensorflow.keras.backend as K
    
    
    # weighted loss functions
    def weighted_binary_cross_entropy(weights: dict = None, from_logits: bool = False):
        '''
        Return a function for calculating weighted binary cross entropy.
        It should be used for multi-hot encoded labels.
    
        # Example
        y_true = tf.convert_to_tensor([1, 0, 0, 0, 0, 0], dtype=tf.int64)
        y_pred = tf.convert_to_tensor([0.6, 0.1, 0.1, 0.9, 0.1, 0.], dtype=tf.float32)
        weights = {
            0: 1.,
            1: 2.
        }
        # with weights
        loss_fn = weighted_binary_cross_entropy(weights=weights, from_logits=False)
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(0.6067193, shape=(), dtype=float32)
    
        # without weights
        loss_fn = weighted_binary_cross_entropy()
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(0.52158177, shape=(), dtype=float32)
    
        # Another example
        y_true = tf.convert_to_tensor([[0., 1.], [0., 0.]], dtype=tf.float32)
        y_pred = tf.convert_to_tensor([[0.6, 0.4], [0.4, 0.6]], dtype=tf.float32)
        weights = {
            0: 1.,
            1: 2.
        }
        # with weights
        loss_fn = weighted_binary_cross_entropy(weights=weights, from_logits=False)
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(1.0439969, shape=(), dtype=float32)
    
        # without weights
        loss_fn = weighted_binary_cross_entropy()
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(0.81492424, shape=(), dtype=float32)
    
        @param weights A dict setting weights for the 0 and 1 labels, e.g.
            {
                0: 1.,
                1: 8.
            }
            In this case we want to emphasise the true (1) labels,
            because we have many more false (0) labels, e.g.
                [
                    [0 1 0 0 0 0 0 0 0 1]
                    [0 0 0 0 1 0 0 0 0 0]
                    [0 0 0 0 1 0 0 0 0 0]
                ]
            Defaults to {0: 1., 1: 1.}, i.e. unweighted binary cross entropy.
        @param from_logits If True, y_pred is treated as logits (a sigmoid is
            applied internally); if False, y_pred must already be probabilities.
        @return A function to calculate (weighted) binary cross entropy
        '''
        if weights is None:
            weights = {0: 1., 1: 1.}
        assert 0 in weights
        assert 1 in weights
    
        def weighted_cross_entropy_fn(y_true, y_pred):
            tf_y_pred = tf.convert_to_tensor(y_pred)
            tf_y_true = tf.cast(y_true, dtype=tf_y_pred.dtype)
    
            # Per-element weight: weights[1] where the label is 1, weights[0] elsewhere
            weights_v = tf.where(tf.equal(tf_y_true, 1),
                                 tf.cast(weights[1], tf_y_pred.dtype),
                                 tf.cast(weights[0], tf_y_pred.dtype))
            # Element-wise BCE, then the mean of the weighted values over all elements
            ce = K.binary_crossentropy(tf_y_true, tf_y_pred, from_logits=from_logits)
            loss = K.mean(tf.multiply(ce, weights_v))
            return loss
    
        return weighted_cross_entropy_fn
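    If you want to check the docstring numbers without TensorFlow, here is a NumPy re-implementation of the same formula for probability inputs. It is a sketch: weighted_bce_numpy is a hypothetical helper, and clipping predictions to [1e-7, 1 - 1e-7] is an assumption meant to mirror Keras' default epsilon, so the last digits can differ slightly from the float32 TensorFlow results.

```python
import numpy as np

def weighted_bce_numpy(y_true, y_pred, weights, eps=1e-7):
    """NumPy reference for the weighted BCE above (probabilities, not logits)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    p = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    ce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    # weights[1] where the label is 1, weights[0] elsewhere
    w = np.where(y_true == 1, weights[1], weights[0])
    # Plain mean over all elements (not divided by the sum of weights)
    return float(np.mean(ce * w))

y_true = [1, 0, 0, 0, 0, 0]
y_pred = [0.6, 0.1, 0.1, 0.9, 0.1, 0.0]
print(weighted_bce_numpy(y_true, y_pred, {0: 1.0, 1: 2.0}))  # ~0.60672
print(weighted_bce_numpy(y_true, y_pred, {0: 1.0, 1: 1.0}))  # ~0.52158
```

    Because the final step is a plain mean, raising one class's weight also raises the overall loss scale, which effectively changes the learning rate as well.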
    
  • 2021-01-31 09:37

    You can calculate the weights like this and use the weighted binary cross entropy below. For data with 11% ones, this programmatically sets zero_weight to 0.11 and one_weight to 0.89:

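    import numpy as np
    from tensorflow.keras import backend as K
    
    # y_train: your 0/1 training labels
    num_of_ones = np.sum(y_train == 1)
    num_of_zeros = np.sum(y_train == 0)
    total = num_of_ones + num_of_zeros
    
    # Each class gets weight 1 - (its share of the data), so the rarer class weighs more
    one_weight = float(1 - num_of_ones / total)
    zero_weight = float(1 - num_of_zeros / total)
    
    def weighted_binary_crossentropy(zero_weight, one_weight):
    
        def loss(y_true, y_pred):
            y_true = K.cast(y_true, y_pred.dtype)
            b_ce = K.binary_crossentropy(y_true, y_pred)
    
            # weighted calc: one_weight where y_true == 1, zero_weight where y_true == 0
            weight_vector = y_true * one_weight + (1 - y_true) * zero_weight
            weighted_b_ce = weight_vector * b_ce
    
            return K.mean(weighted_b_ce)
    
        return loss
    
    # Usage:
    # model.compile(loss=weighted_binary_crossentropy(zero_weight, one_weight),
    #               optimizer='adam', metrics=['accuracy'])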
    