How to use RandomForest in Spark Pipeline

自闭症患者 2021-02-10 19:31

I want to tune my model with grid search and cross-validation in Spark. In Spark, the base model must be placed in a pipeline; the official pipeline demo uses the Lo

1 Answer
  • 2021-02-10 20:09

    However, the RandomForest model cannot be instantiated with new by client code, so it does not seem possible to use RandomForest in the Pipeline API.

    Well, that is true, but you are simply trying to use the wrong class. Instead of mllib.tree.RandomForest you should use ml.classification.RandomForestClassifier. Here is an example based on the one from the MLlib docs.

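    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.StringIndexer
    // Spark 1.x vector type; in Spark 2.0+ the ml estimators expect
    // org.apache.spark.ml.linalg.Vector instead
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.util.MLUtils
    import sqlContext.implicits._

    // The label is stored as a string so that StringIndexer can index it
    case class Record(category: String, features: Vector)

    // Load the sample data as an RDD[LabeledPoint] and split it 70/30
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainData, testData) = (splits(0), splits(1))

    // Convert each LabeledPoint to a Record and then to a DataFrame
    val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
    val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF

    // Maps the string category to a numeric "label" column and attaches
    // the nominal metadata that RandomForestClassifier needs
    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("label")

    val rf = new RandomForestClassifier()
      .setNumTrees(3)
      .setFeatureSubsetStrategy("auto")
      .setImpurity("gini")
      .setMaxDepth(4)
      .setMaxBins(32)

    // Chain the indexer and the classifier into a single estimator
    val pipeline = new Pipeline()
      .setStages(Array(indexer, rf))

    val model = pipeline.fit(trainDF)

    // Adds prediction columns to the test DataFrame
    model.transform(testDF)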
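
    Since the original goal was grid search with cross-validation, here is a minimal sketch of how the pipeline above could be wrapped in a CrossValidator. It continues from the example and reuses rf, pipeline, trainDF and testDF as defined there. The grid values, the number of folds and the choice of MulticlassClassificationEvaluator (with its default metric) are arbitrary assumptions for illustration, not tuned recommendations.

    ```scala
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    // Candidate hyperparameters for the forest (illustrative values only)
    val paramGrid = new ParamGridBuilder()
      .addGrid(rf.numTrees, Array(3, 10, 20))
      .addGrid(rf.maxDepth, Array(2, 4, 6))
      .build()

    // Scores each candidate on the held-out fold, comparing the indexed
    // "label" column with the "prediction" column produced by the forest
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")

    // The whole pipeline is the estimator, so StringIndexer is refit per fold
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    // Fits every grid combination on every fold, then refits the best one
    // on the full training set
    val cvModel = cv.fit(trainDF)

    cvModel.transform(testDF).select("category", "prediction").show(5)
    ```

    Passing the pipeline rather than rf alone keeps the label indexing inside each fold, at the cost of refitting the indexer for every candidate.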
    

    There is one thing I couldn't figure out here. As far as I can tell, it should be possible to use the labels extracted from LabeledPoints directly, but for some reason that doesn't work, and pipeline.fit raises an IllegalArgumentException:

    RandomForestClassifier was given input with invalid label column label, without the number of classes specified.

    Hence the ugly trick with StringIndexer. After applying it, we get the required attributes ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}), but some classes in ml seem to work just fine without it.
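
    One way to avoid the StringIndexer workaround might be to attach the nominal metadata to the label column by hand. The following is only a sketch, continuing from the example above (it reuses trainData and rf) and assuming the labels are already the doubles 0.0 and 1.0, as in the sample file. The names rawDF, labelMeta, labeledDF and directModel are made up for illustration, and I have not checked this against every Spark version.

    ```scala
    import org.apache.spark.ml.attribute.NominalAttribute
    import org.apache.spark.sql.functions.col

    // RDD[LabeledPoint] converts to a DataFrame with columns
    // "label" (Double) and "features" (Vector)
    val rawDF = trainData.toDF()

    // Describe the label as a nominal attribute with two values, ordered by
    // index, so the number of classes can be read from the column metadata
    val labelMeta = NominalAttribute.defaultAttr
      .withName("label")
      .withValues(Array("0.0", "1.0"))
      .toMetadata()

    // Re-alias the label column with the metadata attached
    val labeledDF = rawDF.select(
      col("label").as("label", labelMeta),
      col("features"))

    // The classifier can now be fit without a StringIndexer stage
    val directModel = rf.fit(labeledDF)
    ```

    This only makes sense when the labels are already zero-based class indices; for arbitrary string or numeric labels StringIndexer remains the safer choice.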
