广告推荐算法(group auc)评价指标及Spark实现代码

江枫思渺然 提交于 2019-11-30 00:52:11

我们曾经有这样的疑惑,那就是训练样本,AUC得到提升。当将新模型放到线上后,却发现实际效果却没有老模型好,这时候很多人就开始疑惑了。
​ 在机器学习算法中,很多情况我们都是把auc当成最常用的一个评价指标,而auc反映整体样本间的排序能力,但是有时候auc这个指标可能并不能完全说明问题,有可能auc并不能真正反映模型的好坏,以CTR预估算法(推荐算法一般把这个作为一个很重要的指标)为例,把用户点击的样本当作正样本,没有点击的样本当作负样本,把这个任务当成一个二分类进行处理,最后模型输出的是样本是否被点击的概率。
​ 举个很简单的例子,假如有两个用户,分别是甲和乙,一共有5个样本,其中+表示正样本,-表示负样本,我们把5个样本按照模型A预测的score从小到大排序,得到 甲-,甲+,乙-,甲+,乙+. 那么实际的auc应该是 (1+2+2)/(32)=0.833, 那假如有另一个模型B,把这5个样本根据score从小到大排序后,得到 甲-,甲+,甲+,乙-,乙+, 那么该模型预测的auc是(1+1+2)/(32)=0.667。
 那么根据auc的表现来看,模型A的表现优于模型B,但是从实际情况来看,对于用户甲,模型B把其所有的负样本的打分都比正样本低,故,对于用户甲,模型B的auc是1, 同理对于用户乙,模型B的auc也应该是1,同样,对于用户甲和乙,模型A的auc也是1,所以从实际情况来看,模型B的效果和模型A应该是一样好的,这和实际的auc的结果矛盾。
  可能auc这个指标失真了,因为用户广告之间的排序是个性化的,不同用户的排序结果不太好比较,这可能导致全局auc并不能反映真实情况。
​ 因为auc反映的是整体样本间的一个排序能力,而在计算广告领域,我们实际要衡量的是不同用户对不同广告之间的排序能力, 实际更关注的是同一个用户对不同广告间的排序能力,为此,参考了阿里妈妈团队之前有使用的group auc的评价指标 group auc实际是计算每个用户的auc,然后加权平均,最后得到group auc,这样就能减少不同用户间的排序结果不太好比较这一影响。group auc具体公式如下:

gauc计ç®å¬å¼

实际处理时权重一般可以设为每个用户view的次数,或click的次数,而且一般计算时,会过滤掉单个用户全是正样本或负样本的情况。
 但是实际上一般还是主要看auc这个指标,但是当发现auc不能很好的反映模型的好坏(比如auc增加了很多,实际效果却变差了),这时候可以看一下gauc这个指标。

​ 由于Spark计算AUC需要将数据转为RDD,且非常慢。如果计算每个用户的AUC将会超级花费时间。下面是我实现的一种快速计算GAUC的方式,供大家参考。


 
  1. object AUCUtil {
  2.  
  3. /**
  4. * 获取ROC曲线
  5. * scoreAndLabels : _._1 is positive probability,_._2 is true label
  6. */
  7. def roc(scoreAndLabels: Array[(Double, Double)]): Array[(Double, Double)] = {
  8. val results = scala.collection.mutable.ArrayBuffer[(Double, Double)]()
  9. val scoreAndLabelsCount = scoreAndLabels.length
  10. val scoreAndLabelsWithIndex = scoreAndLabels.seq.sortBy(_._1).zipWithIndex.map(row => {
  11. (row._2 + 1, row._1._1, row._1._2)
  12. })
  13. val num = 20
  14. //阀值
  15. val thresholds = scala.collection.mutable.Set[Double]()
  16. for (a = scoreAndLabelsCount) index = scoreAndLabelsCount
  17. val threshold = scoreAndLabelsWithIndex.filter(_._1 == index)(0)._2
  18. thresholds += threshold
  19. }
  20.  
  21. //正样本的数量
  22. val positiveCount = scoreAndLabels.filter(_._2 == 1.0).length
  23. //负样本的数量
  24. val negativeCount = scoreAndLabels.filter(_._2 == 0.0).length
  25. results += ((0, 0))
  26. //全是正样本和全是负样本不处理
  27. if (positiveCount != 0 && negativeCount != 0) {
  28. val thrsholdsSorted = thresholds.toSeq.sortWith(_ > _)
  29. for (threshold row._2 == 1.0 && row._1 >= threshold).length
  30. //预测为正样本的数量
  31. val P = scoreAndLabels.filter(_._1 >= threshold).length
  32. //负样本中预测错误的数量
  33. val FP = scoreAndLabels.filter(row => row._2 == 0.0 && row._1 >= threshold).length
  34. //FPR
  35. val FPR = (FP.toDouble / negativeCount.toDouble).toDouble.formatted("%.5f").toDouble
  36. //TPR,召回率
  37. val TPR = (TP.toDouble / positiveCount.toDouble).toDouble.formatted("%.5f").toDouble
  38. results += ((FPR, TPR))
  39. }
  40. }
  41. results += ((1, 1))
  42. results.distinct.toArray
  43. }
  44.  
  45. /**
  46. * 使用梯形法则计算连接两个输入点的线下面积。
  47. * (上底+下底)*高/2
  48. */
  49. def trapezoid(points: Seq[(Double, Double)]): Double = {
  50. require(points.length == 2)
  51. val x = points.head
  52. val y = points.last
  53. (y._1 - x._1) * (y._2 + x._2) / 2.0
  54. }
  55.  
  56. /**
  57. * 计算曲线下的面积
  58. * curve : _1 is FPR,_2 is TPR
  59. */
  60. def areaUnderROC(curve: Iterable[(Double, Double)]): Double = {
  61. curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
  62. seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
  63. combop = _ + _)
  64. }
  65.  
  66. def group_auc(predictions: DataFrame, userId: String): Double = {
  67. group_auc(predictions, userId, "probability")
  68. }
  69.  
  70. /**
  71. * 二分类模型GAUC评估,使用曝光数加权
  72. * @param predictions 测试集返回的预测结果
  73. * @param userId 用户ID列名
  74. * @param probabilityCol 概率值的列名
  75. */
  76. def group_auc(predictions: DataFrame, userId: String, probabilityCol: String): Double = {
  77. import predictions.sparkSession.implicits._
  78. val positiveUser = predictions.where("label = 1").select(userId).distinct()
  79. val scoreAndLabels = predictions.join(positiveUser, Seq[String](userId), "leftsemi").select(col(userId).cast(StringType), col(probabilityCol), col("label").cast(DoubleType)).rdd.map(row => {
  80. val label = row.getAs[Double]("label")
  81. val id = row.getAs[String](userId)
  82. var score = row.getAs[Vector](probabilityCol)(1)
  83. (id, score, label)
  84. }).toDF(userId, "score", "label")
  85.  
  86. val userCount = scoreAndLabels.select(userId).distinct().count().toInt
  87. println(s"userCount = ${userCount}")
  88.  
  89. var result = scoreAndLabels.repartitionByRange(1000, col(userId)).rdd.mapPartitions({
  90. val results = ArrayBuffer[(Long,Double)]()
  91. iter =>
  92. {
  93. val scoreAndLabelMap = HashMap[String, Array[(Double, Double)]]()
  94. for (row scoreAndLabelArray)
  95. } else {
  96. var scoreAndLabelArray = Array[(Double, Double)]()
  97. scoreAndLabelArray :+= ((score, label))
  98. scoreAndLabelMap += (id -> scoreAndLabelArray)
  99. }
  100. }
  101. val users = scoreAndLabelMap.keys
  102. for (user ((x._1 + y._1),(x._2 + y._2)))
  103.  
  104. val totalImpression = result._1
  105. val totalAUC = result._2
  106. println(s"totalImpression = ${totalImpression}")
  107. println(s"totalAUC = ${totalAUC}")
  108. val gauc = if (totalImpression != 0) (totalAUC / totalImpression.toDouble).formatted("%.8f").toDouble else 0.0d
  109. println(s"GAUC = ${gauc}")
  110. gauc
  111. }
  112.  
  113. def main(args: Array[String]): Unit = {
  114. val scoreAndLabels = scala.collection.mutable.ArrayBuffer[(Double, Double)]()
  115. scoreAndLabels += ((0.4, 1))
  116. scoreAndLabels += ((0.4, 0))
  117. scoreAndLabels += ((0.3, 0))
  118. scoreAndLabels += ((0.35, 0))
  119.  
  120. scoreAndLabels.toArray
  121.  
  122. val roc_line = roc(scoreAndLabels.toArray)
  123.  
  124. println(roc_line.mkString(","))
  125.  
  126. val auc_value = areaUnderROC(roc_line)
  127. println(auc_value)
  128. }

代码已经测试可用,希望对大家有用。

python版的代码可以参考:

https://github.com/qiaoguan/deep-ctr-prediction/blob/master/DeepCross/metric.py

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!