How to extract rules from decision tree spark MLlib

离开以前 2021-01-12 13:31

I am using Spark MLlib 1.4.1 to train a decision tree model. Now I want to extract the rules from the trained tree.

How can I extract the rules?

3 answers
  •  执笔经年
    2021-01-12 13:40

    You can get the full model as a string by calling model.toDebugString(), or persist it by calling model.save(sc, filePath), which stores the model metadata as JSON.

    The documentation is here. It contains an example with a small sample dataset, so you can inspect the output format on the command line. I have formatted the script below so that you can paste and run it directly.

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    # Assumes an existing SparkContext named sc (as in the pyspark shell).
    data = [
        LabeledPoint(0.0, [0.0]),
        LabeledPoint(1.0, [1.0]),
        LabeledPoint(1.0, [2.0]),
        LabeledPoint(1.0, [3.0]),
    ]

    # Two classes, no categorical features
    model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
    print(model)  # one-line summary

    print(model.toDebugString())  # summary plus the full If/Else tree
    

    The output is:

    DecisionTreeModel classifier of depth 1 with 3 nodes
    DecisionTreeModel classifier of depth 1 with 3 nodes
      If (feature 0 <= 0.0)
       Predict: 0.0
      Else (feature 0 > 0.0)
       Predict: 1.0 
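
    Because the question asks for rules rather than the raw dump, one option is to turn the debug string into one rule per leaf. The following is only a minimal sketch, assuming the text keeps the indented If/Else/Predict layout shown above. extract_rules is a hypothetical helper, not part of MLlib, and sample is simply the output above pasted in as a string so that the snippet runs without Spark.

```python
def extract_rules(debug_string):
    """Parse toDebugString()-style text into (conditions, prediction) pairs."""
    rules = []
    stack = []  # (indent, condition) for every If/Else branch enclosing the current line
    for line in debug_string.splitlines():
        text = line.strip()
        indent = len(line) - len(line.lstrip())
        if text.startswith(("If (", "Else (")):
            # A branch at this indent replaces any sibling or deeper branch on the path.
            while stack and stack[-1][0] >= indent:
                stack.pop()
            condition = text[text.index("(") + 1:text.rindex(")")]
            stack.append((indent, condition))
        elif text.startswith("Predict:"):
            while stack and stack[-1][0] >= indent:
                stack.pop()
            prediction = float(text.split(":", 1)[1])
            rules.append(([cond for _, cond in stack], prediction))
        # Any other line (header, "...") is ignored.
    return rules


# Hardcoded copy of the debug output shown above (assumption: same layout in your version).
sample = """DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0
"""

for conditions, prediction in extract_rules(sample):
    print("IF " + (" AND ".join(conditions) or "TRUE") + " THEN " + str(prediction))
```

    With a real model you would pass model.toDebugString() instead of sample. Each rule is the list of conditions on the path from the root to one leaf, so a tree with many leaves yields just as many rules.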
    

    In a real application the model can be very large and consist of many lines, so printing dtModel.toDebugString() directly can cause an IPython notebook to hang. I suggest writing it to a text file instead.

    Here is an example of how to export a model dtModel to a text file. Suppose we train dtModel like this and then write it out:

    import os

    dtModel = DecisionTree.trainClassifier(parsedTrainData, numClasses=7, categoricalFeaturesInfo={},
                                           impurity='gini', maxDepth=20, maxBins=24)

    # open() does not expand "~", so resolve the home directory explicitly
    modelFile = os.path.expanduser("~/decisionTreeModel.txt")
    with open(modelFile, "w") as f:
        f.write(dtModel.toDebugString())
    

    Here is an excerpt of the output of the above script for my dtModel:

    DecisionTreeModel classifier of depth 20 with 20031 nodes
      If (feature 0 <= -35.0)
       If (feature 24 <= 176.0)
        If (feature 0 <= -200.0)
         If (feature 29 <= 109.0)
          If (feature 6 <= -156.0)
           If (feature 9 <= 0.0)
            If (feature 20 <= -116.0)
             If (feature 16 <= 203.0)
              If (feature 11 <= 163.0)
               If (feature 5 <= 384.0)
                If (feature 15 <= 325.0)
                 If (feature 13 <= -248.0)
                  If (feature 20 <= -146.0)
                   Predict: 0.0
                  Else (feature 20 > -146.0)
                   If (feature 19 <= -58.0)
                    Predict: 6.0
                   Else (feature 19 > -58.0)
                    Predict: 0.0
                 Else (feature 13 > -248.0)
                  If (feature 9 <= -26.0)
                   Predict: 0.0
                  Else (feature 9 > -26.0)
                   If (feature 10 <= 218.0)
    ...
    ...
    ...
    ...
    
