How to use RandomForest in Spark Pipeline

Submitted by 对着背影说爱祢 on 2019-12-04 13:16:26

Question


I want to tune my model with grid search and cross-validation in Spark. In Spark, the base model must be put into a pipeline; the official pipeline demo uses LogisticRegression as the base model, which can be instantiated as an object. However, the RandomForest model cannot be instantiated by client code, so it does not seem possible to use RandomForest in the pipeline API. I don't want to reinvent the wheel, so can anybody give some advice? Thanks


Answer 1:


However, the RandomForest model cannot be instantiated by client code, so it does not seem possible to use RandomForest in the pipeline API.

Well, that is true, but you are simply trying to use the wrong class. Instead of mllib.tree.RandomForest you should use ml.classification.RandomForestClassifier. Here is an example based on the one from the MLlib docs.

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._  // assumes spark-shell, where sc and sqlContext are predefined

case class Record(category: String, features: Vector)

// Load the LibSVM sample data as an RDD[LabeledPoint] and split it
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))

// Keep the label as a string so StringIndexer can turn it into an indexed "label" column
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("label")

val rf = new RandomForestClassifier()
  .setNumTrees(3)
  .setFeatureSubsetStrategy("auto")
  .setImpurity("gini")
  .setMaxDepth(4)
  .setMaxBins(32)

val pipeline = new Pipeline()
  .setStages(Array(indexer, rf))

val model = pipeline.fit(trainDF)

model.transform(testDF)
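
Since the question also asks about grid search and cross-validation, a pipeline built this way can be wrapped in a CrossValidator together with a ParamGridBuilder. The snippet below is only a minimal sketch and not part of the original answer; it reuses pipeline, rf, trainDF and testDF from above, and the grid values are purely illustrative.

import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Illustrative grid over two RandomForestClassifier parameters
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(3, 10, 30))
  .addGrid(rf.maxDepth, Array(4, 8))
  .build()

// 3-fold cross-validation over the whole pipeline (indexer + rf)
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new MulticlassClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

val cvModel = cv.fit(trainDF)
cvModel.transform(testDF)

cvModel.bestModel is the fitted pipeline for the best-performing parameter combination.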

There is one thing I couldn't figure out in the example above. As far as I can tell, it should be possible to use the labels extracted from the LabeledPoints directly, but for some reason it doesn't work and pipeline.fit raises an IllegalArgumentException:

RandomForestClassifier was given input with invalid label column label, without the number of classes specified.

Hence the ugly trick with StringIndexer. After applying it we get the required attributes ({"vals":["1.0","0.0"],"type":"nominal","name":"label"}), but some classes in ml seem to work just fine without it.
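
As far as I can tell, the same nominal metadata could also be attached by hand with NominalAttribute instead of going through StringIndexer. The snippet below is only a sketch of that possible alternative, not something from the original answer; it reuses trainData and rf from above and assumes the data has exactly two classes.

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.functions.col

// Declare the plain Double label as nominal with two values, which is roughly
// the metadata StringIndexer adds for us in the pipeline above
val labelMeta = NominalAttribute.defaultAttr
  .withName("label")
  .withNumValues(2)
  .toMetadata()

val trainWithMeta = trainData
  .map(lp => (lp.label, lp.features))
  .toDF("label", "features")
  .withColumn("label", col("label").as("label", labelMeta))

// With the metadata in place the classifier can be fit without the indexer stage
val rfModel = rf.fit(trainWithMeta)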



Source: https://stackoverflow.com/questions/32108887/how-to-use-randomforest-in-spark-pipeline
