RandomForestClassifier was given input with invalid label column error in Apache Spark

忘了有多久 2021-01-13 05:24

I am trying to compute accuracy with 5-fold cross-validation using a Random Forest classifier model in Scala, but I am getting the following error when I run it:

1 Answer
  • 2021-01-13 05:42

    RandomForestClassifier, like many other ML algorithms, requires specific metadata to be set on the label column, and requires the label values to be integral values from [0, 1, 2, ..., #classes) represented as doubles. Typically this is handled by an upstream Transformer such as StringIndexer. Since you convert the labels manually, the metadata fields are not set and the classifier cannot confirm that these requirements are satisfied.

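    import org.apache.spark.ml.classification.RandomForestClassifier
    // Spark 2.x and later; Spark 1.x uses org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.ml.linalg.Vectors
    // toDF and $"..." need spark.implicits._ (imported automatically in spark-shell)
    
    val df = Seq(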
      (0.0, Vectors.dense(1, 0, 0, 0)),
      (1.0, Vectors.dense(0, 1, 0, 0)),
      (2.0, Vectors.dense(0, 0, 1, 0)),
      (2.0, Vectors.dense(0, 0, 0, 1))
    ).toDF("label", "features")
    
    val rf = new RandomForestClassifier()
      .setFeaturesCol("features")
      .setNumTrees(5)
    
    rf.setLabelCol("label").fit(df)
    // java.lang.IllegalArgumentException: RandomForestClassifier was given input ...
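
    The value requirement described above is separate from the metadata requirement, and it can be checked without Spark. The following is a minimal sketch in plain Python; the helper name `labels_are_valid` is hypothetical and is not part of any Spark API. Note that the example data above passes this check: it fails only because the metadata is missing.

```python
import math
from typing import Iterable


def labels_are_valid(labels: Iterable[float], num_classes: int) -> bool:
    """Return True if every label is an integral value in [0, num_classes)."""
    for value in labels:
        if not math.isfinite(value):      # rejects NaN and infinities
            return False
        if value != math.floor(value):    # must be integral, e.g. 2.0 but not 2.5
            return False
        if not 0 <= value < num_classes:  # must lie in [0, num_classes)
            return False
    return True


print(labels_are_valid([0.0, 1.0, 2.0, 2.0], 3))  # True
print(labels_are_valid([1.0, 2.0, 3.0], 3))       # False: 3.0 is out of range
print(labels_are_valid([0.0, 0.5, 1.0], 2))       # False: 0.5 is not integral
```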
    

    You can either re-encode the label column using StringIndexer:

    import org.apache.spark.ml.feature.StringIndexer
    
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("label_idx")
      .fit(df)
    
    rf.setLabelCol("label_idx").fit(indexer.transform(df))
    

    or set the required metadata manually:

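    import org.apache.spark.ml.attribute.NominalAttribute
    
    // Declare "label" as a nominal attribute with three classes
    val meta = NominalAttribute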
      .defaultAttr
      .withName("label")
      .withValues("0.0", "1.0", "2.0")
      .toMetadata
    
    rf.setLabelCol("label_meta").fit(
      df.withColumn("label_meta", $"label".as("", meta))
    )
    

    Note:

    Labels created using StringIndexer are ordered by frequency, not by value:

    indexer.labels
    // Array[String] = Array(2.0, 0.0, 1.0)
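
    To see why 2.0 comes first, here is a rough sketch of frequency-based ordering in plain Python (no Spark needed). The function name `frequency_ordered_labels` is hypothetical, and breaking ties by string order is an assumption made for the sketch; Spark's actual tie-breaking may differ between versions.

```python
from collections import Counter
from typing import List


def frequency_ordered_labels(values: List[float]) -> List[str]:
    """Order distinct labels by descending frequency, as strings."""
    counts = Counter(str(v) for v in values)
    # Most frequent first; ties broken by string order (an assumption here)
    return sorted(counts, key=lambda label: (-counts[label], label))


labels = frequency_ordered_labels([0.0, 1.0, 2.0, 2.0])
print(labels)  # ['2.0', '0.0', '1.0']

# The position in the list becomes the new (double) label index
index_of = {label: float(i) for i, label in enumerate(labels)}
print(index_of)  # {'2.0': 0.0, '0.0': 1.0, '1.0': 2.0}
```

    So the original label 2.0 is mapped to index 0.0, which matters if you later interpret predictions as the original values.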
    

    PySpark:

    In Python, the metadata fields can be set directly on the schema:

    from pyspark.sql.types import StructField, DoubleType
    
    StructField(
        "label", DoubleType(), False,
        {"ml_attr": {
            "name": "label",
            "type": "nominal", 
            "vals": ["0.0", "1.0", "2.0"]
        }}
    )
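
    If the set of classes is not fixed in advance, the metadata dictionary can be built programmatically. This is a small sketch in plain Python; `nominal_label_metadata` is a hypothetical helper, and it only reproduces the dictionary layout shown in the snippet above.

```python
from typing import Dict, List


def nominal_label_metadata(name: str, values: List[str]) -> Dict[str, dict]:
    """Build the "ml_attr" metadata dict for a nominal label column."""
    return {"ml_attr": {"name": name, "type": "nominal", "vals": list(values)}}


meta = nominal_label_metadata("label", ["0.0", "1.0", "2.0"])
print(meta)

# With PySpark available, the dict is passed as the fourth argument of
# StructField, exactly as in the snippet above:
#   StructField("label", DoubleType(), False, meta)
```

    As far as I know, in Spark 2.2 and later `Column.alias` also accepts a `metadata` keyword argument, so the same dictionary should be attachable to an existing column with `col("label").alias("label", metadata=meta)`.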
    