PySpark & MLLib: Class Probabilities of Random Forest Predictions

忘掉有多难 2021-02-03 12:19

I'm trying to extract the class probabilities of a random forest object I have trained using PySpark. However, I do not see an example of it anywhere in the documentation, nor

4 Answers
  •  -上瘾入骨i
    2021-02-03 13:10

    As far as I can tell, this is not supported in the current version (1.2.1). The Python wrapper over the native Scala code (tree.py) defines only 'predict' functions, which in turn call their Scala counterparts (treeEnsembleModels.scala). The latter make decisions by taking a majority vote over the individual trees' hard (binary) decisions. A much cleaner solution would have been to provide a probabilistic prediction, which could be thresholded arbitrarily or used for ROC computation as in sklearn. This feature should be added in future releases!
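    Here is why a score is more useful than a hard label. Hard 0/1 predictions give a single operating point on the ROC curve. A score such as the fraction of trees voting for class 1 can be cut at any threshold. The pure-Python sketch below needs no Spark. It computes ROC AUC by comparing every positive/negative pair. The names `roc_auc` and `threshold` and the toy data are illustrative assumptions. In practice you would use a library routine such as sklearn's `roc_auc_score`.

```python
def roc_auc(labels, scores):
    """ROC AUC as the probability that a randomly chosen positive example
    scores higher than a randomly chosen negative one (ties count as 1/2).
    O(n_pos * n_neg): fine for a sketch, too slow for large data."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def threshold(scores, t):
    """Turn scores into hard 0/1 labels at an arbitrary threshold t."""
    return [1 if s >= t else 0 for s in scores]


labels = [0, 0, 1, 1, 0, 1]
probs = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]  # e.g. fraction of trees voting 1
hard = threshold(probs, 0.5)              # roughly what a plain vote gives

print(roc_auc(labels, probs))   # 8/9 ≈ 0.889
print(roc_auc(labels, hard))    # 7.5/9 ≈ 0.833, ties among hard labels cost AUC
print(threshold(probs, 0.3))    # [0, 1, 1, 1, 0, 1], a different operating point
```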

    As a workaround, I implemented predict_proba as a pure Python function (see example below). It is neither elegant nor very efficient, because it loops over the individual decision trees in the forest. The trick (or rather, the dirty hack) is to access the array of Java decision tree models and wrap each one in its Python counterpart. You can then compute each individual model's predictions over the entire dataset and accumulate their running sum in an RDD using 'zip'. For 0/1 labels, dividing this sum by the number of trees gives the fraction of trees voting for class 1, which is the desired result. Even for large datasets, looping over a small number of decision trees on the driver (master) node should be acceptable.
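    The zip-and-sum pattern can be tried without a cluster. The sketch below uses a hypothetical in-memory `LocalRDD` class that imitates only `map`, `zip` and `collect`. It is not part of PySpark. Plain Python functions stand in for the decision trees. `predict_proba_local` is an illustrative name, and it mirrors the loop structure of the Spark version.

```python
class LocalRDD:
    """Hypothetical in-memory stand-in for a Spark RDD, exposing only the
    methods the workaround relies on. Unlike Spark, it evaluates eagerly."""

    def __init__(self, items):
        self.items = list(items)

    def map(self, f):
        return LocalRDD(f(x) for x in self.items)

    def zip(self, other):
        # Like RDD.zip, pairing only makes sense for equally sized datasets.
        if len(self.items) != len(other.items):
            raise ValueError("can only zip datasets with the same number of elements")
        return LocalRDD(zip(self.items, other.items))

    def collect(self):
        return list(self.items)


def predict_proba_local(tree_predictors, features):
    """Start from the first tree's 0/1 predictions, zip in each further
    tree's predictions, add the pairs, then divide by the number of trees."""
    ntrees = len(tree_predictors)
    scores = features.map(tree_predictors[0])
    for predict in tree_predictors[1:]:
        scores = scores.zip(features.map(predict))
        scores = scores.map(lambda x: x[0] + x[1])
    return scores.map(lambda s: s / float(ntrees))


# Three hypothetical "trees", each mapping a 1-D feature to class 0.0 or 1.0.
trees = [
    lambda x: 1.0 if x > 0.3 else 0.0,
    lambda x: 1.0 if x > 0.5 else 0.0,
    lambda x: 1.0 if x > 0.7 else 0.0,
]
features = LocalRDD([0.2, 0.4, 0.6, 0.8])
print(predict_proba_local(trees, features).collect())  # [0.0, 1/3, 2/3, 1.0]
```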

    The code below is rather tricky because of the difficulties of integrating Python with Spark (which runs on the JVM). One should be very careful not to send any complex data to worker nodes, since doing so causes crashes due to serialization problems. No code referring to the Spark context can run on a worker node. Likewise, nothing that refers to a Java object can be serialized. For example, it may be tempting to use len(trees) instead of ntrees inside the final lambda in the code below. Bang! The closure would then capture the py4j JavaArray, which cannot be shipped to the workers. Writing such a wrapper in Java/Scala would be much more elegant, for example by looping over the decision trees on the worker nodes and thereby reducing communication costs.
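    The len(trees) pitfall comes down to what a closure captures. PySpark serializes functions with its bundled cloudpickle, which also serializes the values a closure refers to. This sketch approximates that with the standard `pickle` module applied to each captured value. `FakeJavaArray` is a hypothetical stand-in for py4j's JavaArray. An unpicklable lock plays the role of the live JVM gateway connection.

```python
import pickle
import threading


class FakeJavaArray:
    """Hypothetical stand-in for a py4j JavaArray: it holds a live resource
    (a lock here, a JVM gateway connection in reality) that cannot be pickled."""

    def __init__(self, n):
        self._gateway_lock = threading.Lock()
        self._n = n

    def __len__(self):
        return self._n


def make_scaler_good(trees):
    ntrees = len(trees)                     # evaluated once, on the driver
    return lambda x: x / float(ntrees)      # closure captures only an int


def make_scaler_bad(trees):
    return lambda x: x / float(len(trees))  # closure captures the proxy itself


def captured_values_picklable(fn):
    """Rough approximation of a closure serializer: try to pickle every
    value the function captured from its enclosing scope."""
    for cell in fn.__closure__ or ():
        try:
            pickle.dumps(cell.cell_contents)
        except (TypeError, AttributeError, pickle.PicklingError):
            return False
    return True


trees = FakeJavaArray(50)
good = make_scaler_good(trees)
bad = make_scaler_bad(trees)

print(good(25.0), bad(25.0))            # both work locally: 0.5 0.5
print(captured_values_picklable(good))  # True
print(captured_values_picklable(bad))   # False
```

    Both functions behave identically on the driver. The difference only shows up when the function has to be shipped to a worker, which is why this bug is easy to miss.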

    The test function below demonstrates that thresholding predict_proba at 0.5 gives the same test error as the predict function used in the original examples.

    from pyspark.mllib.tree import DecisionTreeModel, RandomForest

    def predict_proba(rf_model, data):
        '''
        This wrapper overcomes the "binary" nature of predictions in the native
        RandomForestModel. For a binary classifier with 0/1 labels it returns,
        for each point, the fraction of trees that vote for class 1.
        '''

        # Extract the feature vectors once; every tree is applied to them.
        features = data.map(lambda x: x.features)

        # Collect the individual decision tree models by calling the underlying
        # Java model. These are returned as a JavaArray defined by py4j.
        trees = rf_model._java_model.trees()
        ntrees = rf_model.numTrees()
        scores = DecisionTreeModel(trees[0]).predict(features)

        # For each remaining decision tree, apply its prediction to the entire
        # dataset and accumulate the running sum using 'zip'.
        for i in range(1, ntrees):
            dtm = DecisionTreeModel(trees[i])
            scores = scores.zip(dtm.predict(features))
            scores = scores.map(lambda x: x[0] + x[1])

        # Divide the accumulated scores by the number of trees. Only the plain
        # int ntrees is captured here; len(trees) would pull the py4j object
        # into the closure and break serialization.
        return scores.map(lambda x: x / float(ntrees))

    def testError(lap):
        # lap is an RDD of (label, prediction) pairs. Tuple-unpacking lambdas
        # such as lambda (v, p): ... only work in Python 2, so index instead.
        testErr = lap.filter(lambda vp: vp[0] != vp[1]).count() / float(lap.count())
        print('Test Error = ' + str(testErr))


    def testClassification(trainingData, testData):

        model = RandomForest.trainClassifier(trainingData, numClasses=2,
                                             categoricalFeaturesInfo={},
                                             numTrees=50, maxDepth=30)

        # Compute test error by thresholding probabilistic predictions
        threshold = 0.5
        scores = predict_proba(model, testData)
        pred = scores.map(lambda x: 0 if x < threshold else 1)
        lab_pred = testData.map(lambda lp: lp.label).zip(pred)
        testError(lab_pred)

        # Compute test error by comparing binary predictions
        predictions = model.predict(testData.map(lambda x: x.features))
        labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
        testError(labelsAndPredictions)
    

    All in all, this was a nice exercise for learning Spark!
