How to access individual trees in a model created by RandomForestClassifier (spark.ml-version)?

不思量自难忘° 2021-01-01 04:36

How to access individual trees in a model generated by Spark ML's RandomForestClassifier? I am using the Scala version of RandomForestClassifier.

1 Answer
  • 2021-01-01 05:26

    Actually, the fitted model has a trees attribute:

    import org.apache.spark.ml.attribute.NominalAttribute
    import org.apache.spark.ml.classification.{
      RandomForestClassificationModel, RandomForestClassifier,
      DecisionTreeClassificationModel
    }
    import sqlContext.implicits._  // for the $"..." column syntax (pre-imported in spark-shell)
    
    // Attach nominal (binary) metadata to the label column so the
    // classifier knows the number of classes.
    val meta = NominalAttribute
      .defaultAttr
      .withName("label")
      .withValues("0.0", "1.0")
      .toMetadata
    
    val data = sqlContext.read.format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt")
      .withColumn("label", $"label".as("label", meta))
    
    val rf: RandomForestClassifier = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
    
    // `trees` is not typed as Array[DecisionTreeClassificationModel] in every
    // Spark version, so narrow the element type explicitly.
    val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
      case t: DecisionTreeClassificationModel => t
    }
    

    As you can see, the only complication is getting the types right; after that, each tree can be used as a standalone model:

    trees.head.transform(data).show(3)
    // +-----+--------------------+-------------+-----------+----------+
    // |label|            features|rawPrediction|probability|prediction|
    // +-----+--------------------+-------------+-----------+----------+
    // |  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
    // |  1.0|(692,[158,159,160...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
    // |  1.0|(692,[124,125,126...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
    // +-----+--------------------+-------------+-----------+----------+
    // only showing top 3 rows
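    Each element of trees is a full DecisionTreeClassificationModel, so besides transform it can be inspected directly. The following is a minimal sketch, assuming Spark 2.x or later; ForestInspection, summary and render are hypothetical helper names, not Spark API. It relies on depth, numNodes and rootNode of each tree, treeWeights of the forest, and the public node and split classes in org.apache.spark.ml.tree.

```scala
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, RandomForestClassificationModel}
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode, Node}

// Hypothetical helpers (not part of Spark) for inspecting the trees of a fitted forest.
object ForestInspection {

  // Same type-narrowing trick as in the answer: keep only DecisionTreeClassificationModel instances.
  def classificationTrees(model: RandomForestClassificationModel): Array[DecisionTreeClassificationModel] =
    model.trees.collect { case t: DecisionTreeClassificationModel => t }

  // One summary line per tree: depth, node count and the tree's weight in the ensemble.
  def summary(model: RandomForestClassificationModel): Array[String] =
    classificationTrees(model).zip(model.treeWeights).zipWithIndex.map {
      case ((tree, weight), i) =>
        s"tree $i: depth=${tree.depth}, nodes=${tree.numNodes}, weight=$weight"
    }

  // Recursively renders a tree (start from tree.rootNode) as nested if/else lines.
  // The left child receives the rows that satisfy the split condition.
  def render(node: Node, indent: String = ""): Seq[String] = node match {
    case leaf: LeafNode =>
      Seq(s"${indent}predict ${leaf.prediction}")
    case in: InternalNode =>
      val condition = in.split match {
        case s: ContinuousSplit =>
          s"feature ${s.featureIndex} <= ${s.threshold}"
        case s: CategoricalSplit =>
          s"feature ${s.featureIndex} in ${s.leftCategories.mkString("{", ",", "}")}"
      }
      Seq(s"${indent}if ($condition)") ++ render(in.leftChild, indent + "  ") ++
        Seq(s"${indent}else") ++ render(in.rightChild, indent + "  ")
  }
}
```

    For example, ForestInspection.render(trees.head.rootNode).foreach(println) prints the first tree as nested if/else lines; the built-in trees.head.toDebugString gives a similar view without any custom code.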
    

    Note:

    If you work with pipelines, you can extract the individual trees as well:

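    import org.apache.spark.ml.Pipeline
    
    val model = new Pipeline().setStages(Array(rf)).fit(data)
    
    // There is only one stage and we know its type,
    // but let's be thorough
    val rfModelOption = model.stages.headOption match {
      case Some(m: RandomForestClassificationModel) => Some(m)
      case _ => None
    }
    
    // Narrow the element type as before; an explicitly typed empty array
    // keeps the result an Array[DecisionTreeClassificationModel].
    val trees: Array[DecisionTreeClassificationModel] = rfModelOption.map {
      _.trees.collect { case t: DecisionTreeClassificationModel => t }
    }.getOrElse(Array.empty[DecisionTreeClassificationModel])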
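    The headOption match assumes the forest is the first stage. When it sits behind preprocessing stages (for example a VectorAssembler), collectFirst can search all fitted stages instead. A minimal sketch, assuming Spark 2.x or later; PipelineTrees and forestTrees are hypothetical names, not Spark API:

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, RandomForestClassificationModel}

// Hypothetical helper (not part of Spark): returns the trees of the first
// RandomForestClassificationModel among the fitted stages, wherever it sits
// in the pipeline, or an empty array if there is no such stage.
object PipelineTrees {
  def forestTrees(model: PipelineModel): Array[DecisionTreeClassificationModel] =
    model.stages
      .collectFirst { case m: RandomForestClassificationModel => m }
      .map(_.trees.collect { case t: DecisionTreeClassificationModel => t })
      .getOrElse(Array.empty[DecisionTreeClassificationModel])
}
```

    Returning an empty array rather than throwing keeps the caller's code uniform; if a missing forest should be treated as an error, replace getOrElse with an explicit failure. Nested pipelines are not searched by this sketch.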
    