Scala: how to know which probability corresponds to which class?

旧街凉风 submitted on 2020-01-26 04:54:44

Question


I trained a random forest classifier to predict something. The label is either "yes" (= 1.0) or "no" (= 0.0).

I apply my model to a test set. Here is my code and a sample of the output:

import org.apache.spark.ml.tuning.CrossValidatorModel
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

// Load the fitted cross-validated model and the test table
val modelrf = CrossValidatorModel.load("modelSupervise/newModel")
val test = spark.sql("""select * from dc.newTest""")

// Score the test set
val predictions = modelrf.transform(test)

predictions.select("id", "label", "rawPrediction", "probability", "prediction").show(20, false)


+--------+--------------+----------------------------------------+-----------------------------------------+----------+
|id      |label         |rawPrediction                           |probability                              |prediction|
+--------+--------------+----------------------------------------+-----------------------------------------+----------+
|1       |0             |[18.954508743604,1.0454912563959982]    |[0.9477254371802001,0.05227456281979992] |0.0       |
|2       |0             |[19.396893651115214,0.6031063488847838] |[0.9698446825557608,0.030155317444239195]|0.0       |
|3       |0             |[19.562942473138747,0.4370575268612524] |[0.9781471236569373,0.02185287634306262] |0.0       |
|4       |0             |[19.072030495384865,0.9279695046151306] |[0.9536015247692434,0.04639847523075654] |0.0       |
|5       |0             |[19.43338228765314,0.5666177123468583]  |[0.9716691143826571,0.02833088561734292] |0.0       |
|6       |0             |[19.696154641398266,0.3038453586017339] |[0.9848077320699133,0.015192267930086694]|0.0       |
|7       |0             |[19.561887703818552,0.4381122961814507] |[0.9780943851909274,0.02190561480907253] |0.0       |
|8       |0             |[19.670868420870097,0.32913157912990343]|[0.9835434210435048,0.01645657895649517] |0.0       |
|9       |0             |[19.31258444658832,0.6874155534116762]  |[0.9656292223294163,0.034370777670583816]|0.0       |
|10      |1             |[19.324118365007614,0.6758816349923846] |[0.9662059182503807,0.03379408174961923] |0.0       |
|11      |0             |[19.671923190190295,0.32807680980970505]|[0.9835961595095147,0.016403840490485253]|0.0       |
|12      |0             |[5.549867107480572,14.450132892519427]  |[0.2774933553740286,0.7225066446259714]  |1.0       |
|13      |0             |[8.302734500577003,11.697265499422995]  |[0.41513672502885013,0.5848632749711498] |1.0       |
|14      |0             |[3.719926021010336,16.280073978989666]  |[0.1859963010505168,0.8140036989494831]  |1.0       |
|15      |1             |[4.9810130629790486,15.018986937020955] |[0.2490506531489524,0.7509493468510476]  |1.0       |
|16      |1             |[7.575144612227263,12.424855387772734]  |[0.37875723061136324,0.6212427693886368] |1.0       |
|17      |0             |[9.763210063340546,10.236789936659454]  |[0.4881605031670273,0.5118394968329727]  |1.0       |
|18      |0             |[9.475787091640768,10.524212908359234]  |[0.4737893545820384,0.5262106454179617]  |1.0       |
|19      |1             |[4.236097613170449,15.763902386829551]  |[0.21180488065852243,0.7881951193414776] |1.0       |
|20      |0             |[8.748700591583557,11.251299408416445]  |[0.43743502957917785,0.5625649704208222] |1.0       |
|21      |0             |[8.908800090849974,11.091199909150026]  |[0.4454400045424987,0.5545599954575013]  |1.0       |
|22      |1             |[9.726530070446398,10.273469929553602]  |[0.4863265035223199,0.5136734964776801]  |1.0       |
|23      |1             |[8.908800090849974,11.091199909150026]  |[0.4454400045424987,0.5545599954575013]  |1.0       |
+--------+--------------+----------------------------------------+-----------------------------------------+----------+

Here is what I understand so far:

For id=1, 18.95 trees predict the value "0.0" and 1.045 trees predict the value "1.0". I assumed that Scala orders the entries of the "rawPrediction" vector by class value: the first entry refers to class "0" and the second to class "1".

But if that were true, and the labels were "yes" and "no" instead of 0 and 1, what order would Scala use? Alphabetical order?

I did some research and found this question: Random Forest Classifier :To which class corresponds the probabilities

That question is the same, but for the "probability" vector: which element corresponds to the probability of predicting "0", and which to the probability of predicting "1"?

I do not understand the answer...

How can I tell, for each row, the probability that the model predicts "yes" (or 1)? Does Scala order the probabilities numerically or alphabetically, depending on the type of the label?

Thank you in advance!!


Answer 1:


Here is the answer! In my question I load a model, but the answer lies in what happens before that.

To fit the model I used a labelIndexer on my target. This label indexer transforms the target into an index, ordered by descending frequency.

For example, if my target contains 20% "aa" and 80% "bb", the label indexer creates a "label" column that takes the value 0 for "bb" and 1 for "aa" (because "bb" is more frequent than "aa").
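The ordering rule in that example can be mimicked in plain Scala, without Spark. This is only a sketch of the idea: the object name LabelIndexSketch and the function indexByFrequency are hypothetical, and breaking ties alphabetically is an assumption of the sketch, not something stated above.

```scala
object LabelIndexSketch {
  // Give index 0.0 to the most frequent label, 1.0 to the next one, and so on.
  // Ties are broken alphabetically (an assumption made for this sketch).
  def indexByFrequency(labels: Seq[String]): Map[String, Double] = {
    val counts = labels.groupBy(identity).map { case (label, xs) => label -> xs.size }
    counts.toSeq
      .sortBy { case (label, n) => (-n, label) } // descending count, then label
      .zipWithIndex
      .map { case ((label, _), idx) => label -> idx.toDouble }
      .toMap
  }
}
```

With a target made of eight "bb" and two "aa", LabelIndexSketch.indexByFrequency(Seq.fill(8)("bb") ++ Seq.fill(2)("aa")) maps "bb" to 0.0 and "aa" to 1.0, matching the 80% / 20% case.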

When we fit a random forest, the entries of the probability vector follow that index order, that is, the order of frequency.
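The table in the question is consistent with this: each "probability" vector is the "rawPrediction" vector divided by its sum, so both vectors share the same index order. A minimal sketch of that normalization in plain Scala (the names NormalizeSketch and normalize are hypothetical, chosen only for illustration):

```scala
object NormalizeSketch {
  // Divide every entry by the vector's sum so that the entries add up to 1.
  def normalize(raw: Seq[Double]): Seq[Double] = {
    val total = raw.sum
    require(total > 0.0, "raw scores must have a positive sum")
    raw.map(_ / total)
  }
}
```

For the row with id=1, NormalizeSketch.normalize(Seq(18.954508743604, 1.0454912563959982)) gives values that agree with the "probability" column of that row up to floating-point rounding.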

In binary classification:

  • first proba = probability that the class is the most frequent class in the training set
  • second proba = probability that the class is the least frequent class in the training set
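To get the probability of one class as its own column, one option is to pick a single entry of the "probability" vector with a UDF. This is a sketch, not the only way to do it: the object ProbabilityOfClass, the function withClassProbability and the output column name "p_yes" are hypothetical names, and it assumes the predictions DataFrame has the "probability" column shown in the question.

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

object ProbabilityOfClass {
  // Add a column holding the entry of the "probability" vector at labelIndex.
  // Index 0 is the most frequent training class, index 1 the next one.
  def withClassProbability(predictions: DataFrame, labelIndex: Int, outputCol: String): DataFrame = {
    val pick = udf((v: Vector) => v(labelIndex))
    predictions.withColumn(outputCol, pick(col("probability")))
  }
}
```

For example, ProbabilityOfClass.withClassProbability(predictions, 1, "p_yes").select("id", "label", "p_yes").show(20, false) would list the second entry of the vector per row. Whether index 1 really means "yes" depends on how the labels were indexed at training time. If the label indexer was a Spark StringIndexer (an assumption here), the labels array of the fitted StringIndexerModel lists the original values in index order, which is the safest way to check.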


Source: https://stackoverflow.com/questions/57142056/scala-how-to-know-which-probability-correspond-to-which-class
