Spark MLlib Decision Trees: Probability of labels by features?

Asked by 深忆病人 · 2020-12-19 23:34

I managed to display the total probabilities of my labels. For example, after displaying my decision tree, I get a table:

Total Predictions:

Is there a way to display the probability of each label at each feature split, rather than only these totals?
1 Answer

时光说笑 · 2020-12-19 23:58

    Note: the following solution is for Scala only. I didn't find a way to do it in Python.

    Assuming you just want a visual representation of the tree, as in your example, one option is to adapt the subtreeToString method from Node.scala in Spark's source on GitHub so that it includes the probabilities at each node split, as in the following snippet:

    import org.apache.spark.mllib.tree.configuration.FeatureType.{Categorical, Continuous}
    import org.apache.spark.mllib.tree.model.{Node, Split}

    def subtreeToString(rootNode: Node, indentFactor: Int = 0): String = {
      // Render one side of a split as a human-readable condition.
      def splitToString(split: Split, left: Boolean): String = {
        split.featureType match {
          case Continuous => if (left) {
            s"(feature ${split.feature} <= ${split.threshold})"
          } else {
            s"(feature ${split.feature} > ${split.threshold})"
          }
          case Categorical => if (left) {
            s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})"
          } else {
            s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})"
          }
        }
      }
      val prefix: String = " " * indentFactor
      if (rootNode.isLeaf) {
        prefix + s"Predict: ${rootNode.predict.predict}\n"
      } else {
        // Probability of this node's prediction, as a percentage.
        val prob = rootNode.predict.prob * 100d
        prefix + s"If ${splitToString(rootNode.split.get, left = true)} " + f"(Prob: $prob%04.2f %%)" + "\n" +
          subtreeToString(rootNode.leftNode.get, indentFactor + 1) +
          prefix + s"Else ${splitToString(rootNode.split.get, left = false)} " + f"(Prob: ${100 - prob}%04.2f %%)" + "\n" +
          subtreeToString(rootNode.rightNode.get, indentFactor + 1)
      }
    }
    

    I tested it on a model trained on the Iris dataset and got the following result:

    scala> println(subtreeToString(model.topNode))
    
    If (feature 2 <= -0.762712) (Prob: 35.35 %)
     Predict: 1.0
    Else (feature 2 > -0.762712) (Prob: 64.65 %)
     If (feature 3 <= 0.333333) (Prob: 52.24 %)
      If (feature 0 <= -0.666667) (Prob: 92.11 %)
       Predict: 3.0
      Else (feature 0 > -0.666667) (Prob: 7.89 %)
       If (feature 2 <= 0.322034) (Prob: 94.59 %)
        Predict: 2.0
       Else (feature 2 > 0.322034) (Prob: 5.41 %)
        If (feature 3 <= 0.166667) (Prob: 50.00 %)
         Predict: 3.0
        Else (feature 3 > 0.166667) (Prob: 50.00 %)
         Predict: 2.0
     Else (feature 3 > 0.333333) (Prob: 47.76 %)
      Predict: 3.0
    

    A similar approach could be used to build a tree data structure carrying this information. The main difference is that instead of printing the values (split.feature, split.threshold, predict.prob, and so on), you would store them as vals and use them to build the structure; see the sketch below.
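    For instance, here is a minimal sketch against the same old MLlib API (org.apache.spark.mllib.tree.model) used above. The names ProbTree, ProbLeaf, ProbSplit, and toProbTree are hypothetical, not part of Spark:

    import org.apache.spark.mllib.tree.model.{Node, Split}

    // Hypothetical ADT mirroring the decision tree; not part of Spark's API.
    sealed trait ProbTree
    case class ProbLeaf(prediction: Double) extends ProbTree
    case class ProbSplit(
        split: Split,
        probPercent: Double, // the value printed for the "If" branch above
        left: ProbTree,
        right: ProbTree) extends ProbTree

    // Recursively mirror the MLlib tree, storing the same values that
    // subtreeToString prints instead of formatting them into a string.
    def toProbTree(node: Node): ProbTree =
      if (node.isLeaf) {
        ProbLeaf(node.predict.predict)
      } else {
        ProbSplit(
          node.split.get,
          node.predict.prob * 100d,
          toProbTree(node.leftNode.get),
          toProbTree(node.rightNode.get))
      }

    The resulting tree can then be traversed or pattern-matched as needed, e.g. starting from toProbTree(model.topNode).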
