PySpark & MLLib: Random Forest Feature Importances

情书的邮戳 2020-12-08 22:27

I'm trying to extract the feature importances of a random forest object I have trained using PySpark. However, I do not see an example of doing this anywhere in the documentation.

4 Answers
  • 2020-12-08 22:36

    UPDATE for version 2.0.0 and later

    As of version 2.0.0, featureImportances is available for Random Forest models in PySpark.

    In fact, the spark.ml guide section on tree ensembles states:

    The DataFrame API supports two major tree ensemble algorithms: Random Forests and Gradient-Boosted Trees (GBTs). Both use spark.ml decision trees as their base models.

    Users can find more information about ensemble algorithms in the MLlib Ensemble guide. In this section, we demonstrate the DataFrame API for ensembles.

    The main differences between this API and the original MLlib ensembles API are:

    • support for DataFrames and ML Pipelines
    • separation of classification vs. regression
    • use of DataFrame metadata to distinguish continuous and categorical features
    • more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification.

    If you want feature importance values, you have to work with the ml package (not mllib) and use DataFrames.

    Below is an example taken from the PySpark RandomForestClassifier documentation:

    # IMPORTS
    >>> from pyspark.sql import SparkSession
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> from pyspark.ml.classification import RandomForestClassifier
    
    # START A SESSION (the pyspark shell already provides `spark`)
    >>> spark = SparkSession.builder.getOrCreate()
    
    # PREPARE DATA
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    
    # BUILD THE MODEL
    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
    >>> model = rf.fit(td)
    
    # FEATURE IMPORTANCES
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    
  • 2020-12-08 22:40

    I have to disappoint you, but feature importances are simply not calculated in the MLlib implementation of RandomForest, so you cannot get them anywhere except by implementing the calculation yourself.

    Here's how to find this out:

    You call the function RandomForest.trainClassifier, defined in https://github.com/apache/spark/blob/branch-1.3/python/pyspark/mllib/tree.py

    It calls callMLlibFunc("trainRandomForestModel", ...), which invokes the Scala function RandomForest.trainClassifier or RandomForest.trainRegressor (depending on the algorithm), and that returns a RandomForestModel object.

    This object is described in https://github.com/apache/spark/blob/branch-1.3/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala and extends TreeEnsembleModel, defined in the same source file. Unfortunately, this class stores only the algorithm (regression or classification), the trees themselves, the relative weights of the trees, and the combining strategy (sum, avg, vote). It does not store feature importances, and does not even calculate them (see https://github.com/apache/spark/blob/branch-1.3/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala for the algorithm itself).
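For reference, here is a minimal sketch of what computing the importances yourself could look like. It assumes you have already pulled each tree's internal nodes into a hypothetical list-of-dicts form (split feature index, impurity gain, number of training instances at the node); the old MLlib Python API does not hand you this structure directly, so that extraction step is not shown. The scoring follows the mean-decrease-in-impurity scheme that the later spark.ml featureImportances documentation describes: per tree, sum gain times instance count for each split feature and normalize to 1, then combine the per-tree vectors and normalize again.

```python
# Sketch only: the node format below is a hypothetical stand-in, NOT a Spark class.
# Each internal node: {"feature": split feature index, "gain": impurity gain,
#                      "count": training instances reaching the node}.
# Leaves are omitted because they do not split on any feature.


def tree_importances(internal_nodes, num_features):
    """Per-tree importance: sum of gain * count per split feature, normalized to 1."""
    imp = [0.0] * num_features
    for node in internal_nodes:
        imp[node["feature"]] += node["gain"] * node["count"]
    total = sum(imp)
    if total > 0:
        imp = [v / total for v in imp]
    return imp  # all zeros if the tree has no splits


def forest_importances(trees, num_features):
    """Ensemble importance: add up per-tree vectors, then normalize to sum to 1.

    Summing then normalizing gives the same result as averaging then normalizing.
    """
    agg = [0.0] * num_features
    for nodes in trees:
        for j, v in enumerate(tree_importances(nodes, num_features)):
            agg[j] += v
    total = sum(agg)
    return [v / total for v in agg] if total > 0 else agg
```

Normalizing each tree before combining keeps one deep tree with many splits from dominating the ensemble's scores.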

  • 2020-12-08 22:50

    I believe that this now works. You can call:

    from pyspark.ml.classification import RandomForestClassifier
    rf = RandomForestClassifier()
    # data: a DataFrame with "label" and "features" columns (the default column names)
    model = rf.fit(data)
    print(model.featureImportances)
    

    Running fit on a RandomForestClassifier returns a RandomForestClassificationModel, which has the desired featureImportances calculated. I hope this helps :)
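The featureImportances vector is indexed by feature position, so a common follow-up is attaching names to it. A minimal sketch, assuming the names are available separately (for example, a hypothetical list of the inputCols you passed to a VectorAssembler, in the same order as the feature vector):

```python
def rank_feature_importances(importances, feature_names):
    """Pair each importance with its feature name, most important first.

    importances: any sequence of floats, e.g. model.featureImportances.toArray()
    feature_names: hypothetical list of names, one per vector position
    """
    if len(importances) != len(feature_names):
        raise ValueError("need exactly one name per feature")
    pairs = zip(feature_names, (float(v) for v in importances))
    # sorted() is stable, so features with equal importance keep their input order
    return sorted(pairs, key=lambda p: p[1], reverse=True)
```

For example, rank_feature_importances(model.featureImportances.toArray(), ["age", "income"]) with hypothetical column names. Calling toArray() first works whether the model returns a SparseVector or a DenseVector.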

  • 2020-12-08 23:01

    Feature importance is now implemented in Spark 1.5. See the resolved JIRA issue. You can get a Vector of feature importances with:

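    // Spark 1.x; on Spark 2.x and later use org.apache.spark.ml.linalg.Vector instead
    import org.apache.spark.mllib.linalg.Vector
    val importances: Vector = model.featureImportances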
    