How to extract rules from decision tree spark MLlib

后端 未结 3 531
离开以前
离开以前 2021-01-12 13:31

I am using Spark MLlib 1.4.1 to create a DecisionTree model. Now I want to extract the rules from the decision tree.

How can I extract the rules?

相关标签:
3条回答
  • 2021-01-12 13:39

    We can extract rules using the model's toDebugString attribute. A full example follows:

    Note: for a detailed explanation of the code below, see https://medium.com/@dipaweshpawar/decoding-decision-tree-in-pyspark-bdd98dcd1ddf

    from pyspark.sql.functions import to_date,datediff,lit,udf,sum,avg,col,count,lag
    from pyspark.sql.types import StringType,LongType,StructType,StructField,DateType,IntegerType,DoubleType
    from datetime import datetime
    from pyspark.sql import SparkSession
    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml import Pipeline
    import pandas as pd
    from pyspark.sql import DataFrame
    from pyspark.sql.functions import udf, lit, avg, max, min
    from pyspark.sql.types import StringType, ArrayType, DoubleType
    from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.sql import SparkSession
    from pyspark.ml import Pipeline
    import operator
    
    import ast
    
    # Maps a comparison token, as it appears in a split condition, to a predicate.
    operators = {
                ">=": operator.ge,
                "<=": operator.le,
                ">": operator.gt,
                "<": operator.lt,
                "==": operator.eq,
                'and': operator.and_,
                'or': operator.or_
            }

    # Toy data: five numeric feature columns and an integer class label.
    data = pd.DataFrame({
        'ball': [0, 1, 1, 3, 1, 0, 1, 3],
        'keep': [4, 5, 6, 7, 7, 4, 6, 7],
        'hall': [8, 9, 10, 11, 2, 6, 10, 11],
        'fall': [12, 13, 14, 15, 15, 12, 14, 15],
        'mall': [16, 17, 18, 10, 10, 16, 18, 10],
        'label': [21, 31, 41, 51, 51, 51, 21, 31]
    })
    # A SparkSession must exist before createDataFrame; getOrCreate() reuses the one a
    # pyspark shell/notebook already provides, or starts a new one.
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(data)

    f_list = ['ball','keep','mall','hall','fall']
    assemble_numerical_features = VectorAssembler(inputCols=f_list, outputCol='features',
                                                  handleInvalid='skip')

    dt = DecisionTreeClassifier(featuresCol='features', labelCol='label')

    pipeline = Pipeline(stages=[assemble_numerical_features, dt])
    model = pipeline.fit(df)
    # transform() adds the assembled `features` vector column and the `prediction` column.
    df = model.transform(df)
    # The fitted tree model is the last stage of the pipeline.
    dt_m = model.stages[-1]
    
    # Step 1: convert model.debugString output to dictionary of nodes and children
    def parse_debug_string_lines(lines):
        """Recursively turn stripped ``toDebugString`` lines into a nested node list.

        ``lines`` is consumed in place: lines are popped from the front, so the
        recursive calls share their progress through the same list.

        Each ``If (...)`` / ``Else (...)`` line becomes
        ``{'name': <condition>, 'children': [...]}``, with the keyword and all
        parentheses removed. Any other line (e.g. ``Predict: 1.0``) becomes a leaf
        ``{'name': <line>}``. Parsing of a sub-list stops at an ``Else`` line, which
        belongs to the enclosing ``If`` and is left for the caller to consume.

        Args:
            lines: whitespace-stripped lines of the debug string, without the header.

        Returns:
            list of node dicts for this level of the tree.
        """
        block = []
        while lines:
            if lines[0].startswith('If'):
                bl = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
                block.append({'name': bl, 'children': parse_debug_string_lines(lines)})

                # The list may already be exhausted if the debug string was truncated.
                if lines and lines[0].startswith('Else'):
                    be = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
                    block.append({'name': be, 'children': parse_debug_string_lines(lines)})
            elif lines[0].startswith('Else'):
                # This Else closes the If-branch being parsed by the caller.
                break
            else:
                block.append({'name': lines.pop(0)})

        return block
    
    def debug_str_to_json(debug_string):
        """Convert a model's ``toDebugString`` text into a nested node dict.

        The first non-blank line is the model header (e.g. "DecisionTreeClassificationModel
        ... of depth 4 with 9 nodes") and is skipped. The remaining lines, stripped
        of indentation and read up to the first blank line, are parsed by
        ``parse_debug_string_lines``.

        Args:
            debug_string: the text of ``toDebugString``.

        Returns:
            ``{'name': 'Root', 'children': [...]}``.
        """
        data = []
        for line in debug_string.splitlines():
            line = line.strip()
            if not line:
                break
            data.append(line)
        tree = {'name': 'Root', 'children': parse_debug_string_lines(data[1:])}
        return tree
    
    # Step 2 : Using metadata stored in features column, build dictionary which maps each feature in features column of df to its index in feature vector
    # The ML attribute metadata groups attribute entries by feature type; each entry
    # carries the slot index ('idx') and source column name ('name').
    f_type_to_flist_dict = df.schema['features'].metadata["ml_attr"]["attrs"]
    f_index_to_name_dict = {}
    # Distinct loop names so the `f_list` column list defined earlier is not overwritten.
    for attrs_of_type in f_type_to_flist_dict.values():
        for attr in attrs_of_type:
            f_index_to_name_dict[attr['idx']] = attr['name']
    
    
    def generate_explanations(dt_as_json, df:DataFrame, f_index_to_name_dict, operators):
        """Add an ``explanation`` column describing each row's path through the tree.

        For every row, the tree produced by ``debug_str_to_json`` is walked from the
        root. At each split, the child whose condition the row satisfies is followed.
        Each satisfied condition is rendered as ``"<feature> <op> <threshold>"``, and
        the conditions are joined with ``", "`` until a ``Predict:`` leaf is reached.

        Args:
            dt_as_json: nested ``{'name': ..., 'children': [...]}`` dict for the tree.
            df: DataFrame with a vector column named ``features``.
            f_index_to_name_dict: maps a feature-vector slot index to its column name.
            operators: maps a comparison token (``"<="``, ``">"``, ...) to a
                two-argument predicate.

        Returns:
            ``df`` with two extra columns: ``features_list`` (the feature vector as an
            array of doubles) and ``explanation`` (a string; ``""`` when the tree is
            a single leaf).

        Prints a warning when some split condition in the tree cannot be parsed
        (e.g. categorical ``in {...}`` splits). Rows reaching such a node get no
        further conditions appended.
        """
        df = df.withColumn('features'+'_list',
                           udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))
                           (df['features'])
                           )

        def split_cond(cond: str):
            """Return ``(feature_index, op_token, threshold)`` for ``"feature <i> <op> <v>"``.

            Raises IndexError/ValueError when the condition is not of that shape.
            """
            cond_parts = cond.split()
            return int(cond_parts[1]), cond_parts[2], float(cond_parts[3])

        # Step 3 : parse and check whether the current instance satisfies the condition of a node
        def parse_validate_cond(cond: str, f_vector: list):
            condition_f_index, condition_op, condition_value = split_cond(cond)

            f_value = f_vector[condition_f_index]
            f_name = f_index_to_name_dict[condition_f_index].replace('numerical_features_', '').replace('encoded_numeric_', '').lower()

            if operators[condition_op](f_value, condition_value):
                return True, f_name + ' ' + condition_op + ' ' + str(round(condition_value,2))

            return False, ''

        def iter_conditions(node):
            """Yield the (stripped) name of every split node below ``node``."""
            for child in node.get('children', []):
                name = child['name'].strip()
                if name.startswith('feature'):
                    yield name
                    yield from iter_conditions(child)

        # Validate the tree on the driver. The UDF below runs lazily on executors, so a
        # flag set inside it would never be visible here.
        cond_parsing_exception_occured = False
        for cond in iter_conditions(dt_as_json):
            try:
                f_index, op, _ = split_cond(cond)
            except (IndexError, ValueError):
                cond_parsing_exception_occured = True
                break
            if op not in operators or f_index not in f_index_to_name_dict:
                cond_parsing_exception_occured = True
                break

        # Step 4 : extract the rule for one instance by following the satisfied
        # condition at each level until a prediction node is reached
        def extract_rule(f_vector: list):
            conditions = []
            node = dt_as_json
            while node is not None:
                next_node = None
                for child in node.get('children', []):
                    name = child['name'].strip()
                    if name.startswith('Predict:'):
                        return ', '.join(conditions)
                    if name.startswith('feature'):
                        try:
                            res, cond = parse_validate_cond(name, f_vector)
                        except (IndexError, KeyError, ValueError):
                            # Unparseable/unsupported condition: treat as not satisfied.
                            res = False
                        if res:
                            conditions.append(cond)
                            next_node = child
                            break
                # No satisfied child (e.g. a NaN feature value) ends the walk here.
                node = next_node
            return ', '.join(conditions)

        # The tree dict is captured by the UDF closure, so it is not re-parsed per node.
        explain_udf = udf(extract_rule, StringType())
        df = df.withColumn('explanation', explain_udf(df['features'+'_list']))

        # log that some condition in a decision tree node could not be parsed
        if cond_parsing_exception_occured:
            print('some node in decision tree has unexpected format')

        return df
    
    # Parse the fitted model's debug string into a tree, then explain every row.
    debug_tree = debug_str_to_json(dt_m.toDebugString)
    df = generate_explanations(debug_tree, df, f_index_to_name_dict, operators)
    shown_columns = ['ball', 'keep', 'mall', 'hall', 'fall', 'explanation', 'prediction']
    rows = df.select(shown_columns).collect()
    
    output :
    -----------------------
    [Row(ball=0, keep=4, mall=16, hall=8, fall=12, explanation='hall > 7.0, mall > 13.0, ball <= 0.5', prediction=21.0),
     Row(ball=1, keep=5, mall=17, hall=9, fall=13, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep <= 5.5', prediction=31.0),
     Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
     Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0),
     Row(ball=1, keep=7, mall=10, hall=2, fall=15, explanation='hall <= 7.0', prediction=51.0),
     Row(ball=0, keep=4, mall=16, hall=6, fall=12, explanation='hall <= 7.0', prediction=51.0),
     Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
     Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0)]
    
    output of dt_m.toDebugString:
    -----------------------------------
    'DecisionTreeClassificationModel (uid=DecisionTreeClassifier_2a17ae7633b9) of depth 4 with 9 nodes\n  If (feature 3 <= 7.0)\n   Predict: 51.0\n  Else (feature 3 > 7.0)\n   If (feature 2 <= 13.0)\n    Predict: 31.0\n   Else (feature 2 > 13.0)\n    If (feature 0 <= 0.5)\n     Predict: 21.0\n    Else (feature 0 > 0.5)\n     If (feature 1 <= 5.5)\n      Predict: 31.0\n     Else (feature 1 > 5.5)\n      Predict: 21.0\n'
    
    output of debug_str_to_json(dt_m.toDebugString):
    ------------------------------------
    {'name': 'Root',
    'children': [{'name': 'feature 3 <= 7.0',
       'children': [{'name': 'Predict: 51.0'}]},
      {'name': 'feature 3 > 7.0',
       'children': [{'name': 'feature 2 <= 13.0',
         'children': [{'name': 'Predict: 31.0'}]},
        {'name': 'feature 2 > 13.0',
         'children': [{'name': 'feature 0 <= 0.5',
           'children': [{'name': 'Predict: 21.0'}]},
          {'name': 'feature 0 > 0.5',
           'children': [{'name': 'feature 1 <= 5.5',
             'children': [{'name': 'Predict: 31.0'}]},
            {'name': 'feature 1 > 5.5',
             'children': [{'name': 'Predict: 21.0'}]}]}]}]}]}
    
    0 讨论(0)
  • 2021-01-12 13:40

    You can get the full model as a string by calling model.toDebugString(), or persist it by calling model.save(sc, filePath), which writes the model metadata as JSON and the tree nodes as Parquet.

    The documentation is here; it contains an example with a small sample dataset, so you can inspect the output format on the command line. I have formatted the script below so you can paste it and run it directly.

    from numpy import array
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree
    
    # Four one-feature training points: label 0.0 for x == 0.0, label 1.0 otherwise.
    data = [
        LabeledPoint(label, [x])
        for label, x in [(0.0, 0.0), (1.0, 1.0), (1.0, 2.0), (1.0, 3.0)]
    ]

    # Binary classifier with no categorical features.
    model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2, categoricalFeaturesInfo={})

    # str(model) is the one-line summary; toDebugString() adds the full if/else tree.
    print(model)
    print(model.toDebugString())
    

    the output is:

    DecisionTreeModel classifier of depth 1 with 3 nodes
    DecisionTreeModel classifier of depth 1 with 3 nodes
      If (feature 0 <= 0.0)
       Predict: 0.0
      Else (feature 0 > 0.0)
       Predict: 1.0 
    

    In real applications the model can be very large and span many lines, so printing dtModel.toDebugString() directly can make an IPython notebook hang. I suggest writing it to a text file instead.

    Here is example code showing how to export a model dtModel to a text file. Suppose we obtain dtModel like this:

    import os

    dtModel = DecisionTree.trainClassifier(parsedTrainData, numClasses=7, categoricalFeaturesInfo={},impurity='gini', maxDepth=20, maxBins=24)

    # open() does not expand "~", so resolve the home directory explicitly.
    modelFile = os.path.expanduser("~/decisionTreeModel.txt")
    # `with` closes the file even if toDebugString() or write() raises.
    with open(modelFile, "w") as f:
        f.write(dtModel.toDebugString())
    

    Here is an example of the output the above script produced for my dtModel:

    DecisionTreeModel classifier of depth 20 with 20031 nodes
      If (feature 0 <= -35.0)
       If (feature 24 <= 176.0)
        If (feature 0 <= -200.0)
         If (feature 29 <= 109.0)
          If (feature 6 <= -156.0)
           If (feature 9 <= 0.0)
            If (feature 20 <= -116.0)
             If (feature 16 <= 203.0)
              If (feature 11 <= 163.0)
               If (feature 5 <= 384.0)
                If (feature 15 <= 325.0)
                 If (feature 13 <= -248.0)
                  If (feature 20 <= -146.0)
                   Predict: 0.0
                  Else (feature 20 > -146.0)
                   If (feature 19 <= -58.0)
                    Predict: 6.0
                   Else (feature 19 > -58.0)
                    Predict: 0.0
                 Else (feature 13 > -248.0)
                  If (feature 9 <= -26.0)
                   Predict: 0.0
                  Else (feature 9 > -26.0)
                   If (feature 10 <= 218.0)
    ...
    ...
    ...
    ...
    
    0 讨论(0)
  • 2021-01-12 13:42
    import networkx as nx
    

    Load the model data. It is stored in Hadoop (HDFS) under `location` if you previously saved the model with model.save(location):

    # Node records written by model.save(location) live under <location>/data.
    modeldf = spark.read.parquet(location + "/data/*")
    node_columns = ["id", "prediction", "leftChild", "rightChild", "split"]
    noderows = modeldf.select(*node_columns).collect()
    
     
    

    Create a dummy list of feature names (feature0, feature1, ...):

    features = ["feature{}".format(i) for i in range(700)]
    

    Initialize the graph

    G = nx.DiGraph()
    # One graph node per tree node; a node whose child ids are negative is a leaf.
    for rw in noderows:
        if rw['leftChild'] < 0 and rw['rightChild'] < 0:
            G.add_node(rw['id'], cat="Prediction", predval=rw['prediction'])
        else:
            G.add_node(rw['id'], cat="splitter", featureIndex=rw['split']['featureIndex'],
                       thresh=rw['split']['leftCategoriesOrThreshold'],
                       leftChild=rw['leftChild'], rightChild=rw['rightChild'],
                       numCat=rw['split']['numCategories'])

    # One edge per split outcome, labelled with the condition leading to that child.
    # nodes(data=True) yields (id, attrs) pairs on networkx 1.x and 2.x+ alike, and the
    # splitter attributes already hold everything needed, so modeldf is not queried again.
    for node_id, attrs in list(G.nodes(data=True)):
        if attrs['cat'] != "splitter":
            continue
        fname = features[attrs['featureIndex']]
        if attrs['numCat'] < 0:
            # Continuous split (Spark saves numCategories = -1 and the threshold as a
            # one-element list): values <= threshold go left, the rest go right.
            threshold = attrs['thresh'][0]
            left_reason = "{0} <= {1}".format(fname, threshold)
            right_reason = "{0} > {1}".format(fname, threshold)
        else:
            # Categorical split: the listed categories go left, all others go right.
            categories = "{" + ",".join(str(c) for c in attrs['thresh']) + "}"
            left_reason = "{0} in {1}".format(fname, categories)
            right_reason = "{0} not in {1}".format(fname, categories)
        G.add_edge(node_id, attrs['leftChild'], reason=left_reason)
        G.add_edge(node_id, attrs['rightChild'], reason=right_reason)
    
     
    
     
    

    The code above converts all the rules into a graph. To print every rule in if/else form, find the path from the root to each leaf node and list the edge reasons along it; together they form the final rule.

    # Leaves: no outgoing edge and exactly one parent (a lone root is excluded).
    nodes = [x for x in G.nodes() if G.out_degree(x) == 0 and G.in_degree(x) == 1]

    for n in nodes:
        # In a tree, the shortest path from the root (id 0) is the unique decision path.
        p = nx.shortest_path(G, 0, n)
        print("Rule No:", n)
        reasons = [G.get_edge_data(parent, child)['reason'] for parent, child in zip(p, p[1:])]
        print(" & ".join(reasons))
    

    The output looks something like this:

    ('Rule No:', 5)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 less than [1.0]

    ('Rule No:', 8)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 less than [0.0]

    ('Rule No:', 9)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 greater than [0.0]

    ('Rule No:', 11)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 less than [0.0]

    ('Rule No:', 12)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 greater than [0.0]

    ('Rule No:', 16)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 less than [1.0]

    ('Rule No:', 17)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 greater than [1.0]

    Modified the initial code present here

    0 讨论(0)
提交回复
热议问题