加权交叉熵损失函数:tf.nn.weighted_cross_entropy_with_logits

天大地大妈咪最大 提交于 2020-08-13 10:39:48

tf.nn.weighted_cross_entropy_with_logits函数

tf.nn.weighted_cross_entropy_with_logits(
    targets,
    logits,
    pos_weight,
    name=None
)

定义在:tensorflow/python/ops/nn_impl.py。

计算加权交叉熵。

类似于sigmoid_cross_entropy_with_logits(),除了pos_weight,允许人们通过向上或向下加权相对于负误差的正误差的成本来权衡召回率和精确度。

通常的交叉熵成本定义为:

targets * -log(sigmoid(logits)) +
    (1 - targets) * -log(1 - sigmoid(logits))

值pos_weights > 1减少了假阴性计数,从而增加了召回率。相反设置pos_weights < 1会减少假阳性计数并提高精度。从一下内容可以看出pos_weight是作为损失表达式中的正目标项的乘法系数引入的:

targets * -log(sigmoid(logits)) * pos_weight +
    (1 - targets) * -log(1 - sigmoid(logits))

为了简便起见,让x = logits,z = targets,q = pos_weight。损失是:

  qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

设置l = (1 + (q - 1) * z),确保稳定性并避免溢出,使用一下内容来实现:

(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

logits和targets必须具有相同的类型和形状。

参数:

  • targets:一个Tensor,与logits具有相同的类型和形状。
  • logits:一个Tensor,类型为float32或float64。
  • pos_weight:正样本中使用的系数。
  • name:操作的名称(可选)。

返回:

与具有分量加权逻辑损失的logits具有相同形状的Tensor。

可能引发的异常:

  • ValueError:如果logits和targets没有相同的形状。
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!