Keras: weighted binary crossentropy

佛祖请我去吃肉 2021-01-31 08:39

I tried to implement a weighted binary cross entropy with Keras, but I am not sure whether the code is correct. The training output seems a bit confusing. After a few epochs I j…

6 Answers
  •  广开言路
    2021-01-31 09:36

    Using class_weights in model.fit is slightly different: it weights each sample's loss contribution (effectively turning the class weights into per-sample weights) rather than computing a weighted loss inside the loss function itself; a minimal sketch of that usage follows.
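
    As a quick illustration (the model, x_train and y_train names here are hypothetical, but class_weight is the standard Keras argument):

    model.fit(x_train, y_train, epochs=5, class_weight={0: 1., 1: 8.})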

    I also found that class_weights, as well as sample_weights, are ignored in TF 2.0.0 when x is passed to model.fit as a tf.data.Dataset or a generator. I believe this is fixed in TF 2.1.0+.

    Here is my weighted binary cross entropy function for multi-hot encoded labels:

    import tensorflow as tf
    import tensorflow.keras.backend as K


    # weighted loss function
    def weighted_binary_cross_entropy(weights: dict, from_logits: bool = False):
        '''
        Return a function for calculating weighted binary cross entropy.
        It should be used for multi-hot encoded labels.
    
        # Example
        y_true = tf.convert_to_tensor([1, 0, 0, 0, 0, 0], dtype=tf.int64)
        y_pred = tf.convert_to_tensor([0.6, 0.1, 0.1, 0.9, 0.1, 0.], dtype=tf.float32)
        weights = {
            0: 1.,
            1: 2.
        }
        # with weights
        loss_fn = weighted_binary_cross_entropy(weights=weights, from_logits=False)
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(0.6067193, shape=(), dtype=float32)
    
        # without weights (i.e. weighting both classes equally)
        loss_fn = weighted_binary_cross_entropy(weights={0: 1., 1: 1.})
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(0.52158177, shape=(), dtype=float32)
    
        # Another example
        y_true = tf.convert_to_tensor([[0., 1.], [0., 0.]], dtype=tf.float32)
        y_pred = tf.convert_to_tensor([[0.6, 0.4], [0.4, 0.6]], dtype=tf.float32)
        weights = {
            0: 1.,
            1: 2.
        }
        # with weights
        loss_fn = weighted_binary_cross_entropy(weights=weights, from_logits=False)
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(1.0439969, shape=(), dtype=float32)
    
        # without weights (i.e. weighting both classes equally)
        loss_fn = weighted_binary_cross_entropy(weights={0: 1., 1: 1.})
        loss = loss_fn(y_true, y_pred)
        print(loss)
        # tf.Tensor(0.81492424, shape=(), dtype=float32)
    
        @param weights A dict setting the weights for the 0 and 1 labels, e.g.
            {
                0: 1.,
                1: 8.
            }
            In this case we want to emphasise the true (1) labels,
            because there are many more false (0) labels, e.g.
                [
                    [0 1 0 0 0 0 0 0 0 1]
                    [0 0 0 0 1 0 0 0 0 0]
                    [0 0 0 0 1 0 0 0 0 0]
                ]

        @param from_logits If False, a sigmoid is applied to each prediction first
        @return A function that calculates the (weighted) binary cross entropy
        '''
        assert 0 in weights
        assert 1 in weights
    
        def weighted_cross_entropy_fn(y_true, y_pred):
            tf_y_true = tf.cast(y_true, dtype=y_pred.dtype)

            # pick the weight for each element according to its true label
            weights_v = tf.where(tf.equal(tf_y_true, 1), weights[1], weights[0])
            weights_v = tf.cast(weights_v, dtype=y_pred.dtype)
            # element-wise cross entropy, weighted, then averaged into a scalar
            ce = K.binary_crossentropy(tf_y_true, y_pred, from_logits=from_logits)
            loss = K.mean(tf.multiply(ce, weights_v))
            return loss
    
        return weighted_cross_entropy_fn
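
    To use it, build the loss once and pass it to model.compile (the model and training data names below are hypothetical):

    loss_fn = weighted_binary_cross_entropy(weights={0: 1., 1: 8.}, from_logits=False)
    model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=10)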
    
