How to extract rules from decision tree spark MLlib

离开以前 2021-01-12 13:31

I am using Spark MLlib 1.4.1 to create a decision tree model. Now I want to extract the rules from the decision tree.

How can I extract the rules?

3 Answers
  •  时光说笑
    2021-01-12 13:42

    import networkx as nx
    

    Load the model data. It is present in Hadoop at `location` if you previously called model.save(location):

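    # Read the node table that model.save(location) wrote under <location>/data
    modeldf = spark.read.parquet(location+"/data/*")
    # Collect only the columns needed to rebuild the tree structure
    noderows = modeldf.select("id","prediction","leftChild","rightChild","split").collect()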

    Create a dummy array of feature names (replace it with your real feature names if you have them):

    features = ["feature"+str(i) for i in range(0,700)]
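
    The hard-coded 700 has to be at least the number of features the model was trained on, otherwise the lookups by featureIndex further down will fail. As a minimal sketch, assuming rows shaped like noderows, the placeholder names could instead be sized from the largest featureIndex that any split actually uses. The helper placeholder_feature_names and the sample_rows data are hypothetical and only for illustration:

```python
def placeholder_feature_names(rows, prefix="feature"):
    """Build prefix0..prefixN, where N is the largest featureIndex used by a split."""
    split_indices = [
        row["split"]["featureIndex"]
        for row in rows
        if row["leftChild"] >= 0 and row["rightChild"] >= 0  # internal nodes only
    ]
    size = max(split_indices) + 1 if split_indices else 0
    return [prefix + str(i) for i in range(size)]


# Hypothetical rows mimicking the columns selected into noderows
sample_rows = [
    {"id": 0, "leftChild": 1, "rightChild": 2, "split": {"featureIndex": 3}},
    {"id": 1, "leftChild": -1, "rightChild": -1, "split": {"featureIndex": -1}},
    {"id": 2, "leftChild": -1, "rightChild": -1, "split": {"featureIndex": -1}},
]
features = placeholder_feature_names(sample_rows)
print(features)
```

    This only guarantees that every index used by a split has a name. It does not recover the real column names.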
    

    Initialize the graph, add one node per tree node, then connect each split node to its children:

    G = nx.DiGraph()
    for rw in noderows:
        # Leaf nodes have no children (both child ids are negative)
        if rw['leftChild'] < 0 and rw['rightChild'] < 0:
            G.add_node(rw['id'], cat="Prediction", predval=rw['prediction'])
        else:
            G.add_node(rw['id'], cat="splitter", featureIndex=rw['split']['featureIndex'], thresh=rw['split']['leftCategoriesOrThreshold'], leftChild=rw['leftChild'], rightChild=rw['rightChild'], numCat=rw['split']['numCategories'])

    # Add an edge from every split node to each child, labelled with the split condition.
    # thresh is stored as a list, so it prints as e.g. [0.0]. For continuous splits Spark
    # sends values <= threshold to the left child, so "less than" here means "at most".
    for rw in modeldf.where("leftChild > 0 and rightChild > 0").collect():
        # Attribute dict of this node (networkx 2.x; on 1.x use G.node[rw['id']])
        tempnode = G.nodes[rw['id']]
        G.add_edge(rw['id'], rw['leftChild'], reason="{0} less than {1}".format(features[tempnode['featureIndex']],tempnode['thresh']))
        G.add_edge(rw['id'], rw['rightChild'], reason="{0} greater than {1}".format(features[tempnode['featureIndex']],tempnode['thresh']))

    The code above converts the tree into a directed graph. To print the rules in if/else form, find the path from the root (node 0) to every leaf node and join the edge reasons along that path:

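    # Leaves: no outgoing edges and exactly one incoming edge
    nodes = [x for x in G.nodes() if G.out_degree(x)==0 and G.in_degree(x)==1]
    for n in nodes:
        # In a tree the shortest path from the root is the only path
        p = nx.shortest_path(G,0,n)
        print("Rule No:",n)
        print(" & ".join([G.get_edge_data(p[i],p[i+1])['reason'] for i in range(0,len(p)-1)]))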

    The output looks something like this:

    ('Rule No:', 5)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 less than [1.0]

    ('Rule No:', 8)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 less than [0.0]

    ('Rule No:', 9)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 greater than [0.0]

    ('Rule No:', 11)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 less than [0.0]

    ('Rule No:', 12)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 greater than [0.0]

    ('Rule No:', 16)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 less than [1.0]

    ('Rule No:', 17)

    feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 greater than [1.0]

    Modified the initial code present here
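
    The same walk can also be sketched without networkx. The version below is a minimal sketch, not the code from this answer: it recurses over rows with the same field names as noderows and distinguishes continuous from categorical splits. It assumes that numCategories == -1 marks a continuous split and that leftCategoriesOrThreshold then holds a single threshold; extract_rules and toy_rows are hypothetical names used only for illustration.

```python
def extract_rules(rows, feature_names):
    """Return one (leaf_id, prediction, conditions) tuple per leaf, left to right."""
    nodes = {row["id"]: row for row in rows}
    rules = []

    def describe(split, go_left):
        name = feature_names[split["featureIndex"]]
        values = split["leftCategoriesOrThreshold"]
        if split["numCategories"] == -1:  # assumed marker for a continuous split
            op = "<=" if go_left else ">"
            return "{0} {1} {2}".format(name, op, values[0])
        # Categorical split: the listed categories go to the left child
        op = "in" if go_left else "not in"
        return "{0} {1} {2}".format(name, op, sorted(values))

    def walk(node_id, conditions):
        node = nodes[node_id]
        if node["leftChild"] < 0 and node["rightChild"] < 0:  # leaf
            rules.append((node_id, node["prediction"], conditions))
            return
        walk(node["leftChild"], conditions + [describe(node["split"], True)])
        walk(node["rightChild"], conditions + [describe(node["split"], False)])

    walk(0, [])  # node 0 is taken to be the root, as in the code above
    return rules


# Hypothetical five-node tree: one continuous split, one categorical split
toy_rows = [
    {"id": 0, "prediction": 0.0, "leftChild": 1, "rightChild": 2,
     "split": {"featureIndex": 0, "leftCategoriesOrThreshold": [0.5], "numCategories": -1}},
    {"id": 1, "prediction": 0.0, "leftChild": -1, "rightChild": -1, "split": None},
    {"id": 2, "prediction": 1.0, "leftChild": 3, "rightChild": 4,
     "split": {"featureIndex": 1, "leftCategoriesOrThreshold": [0.0, 2.0], "numCategories": 3}},
    {"id": 3, "prediction": 1.0, "leftChild": -1, "rightChild": -1, "split": None},
    {"id": 4, "prediction": 0.0, "leftChild": -1, "rightChild": -1, "split": None},
]

for leaf_id, prediction, conditions in extract_rules(toy_rows, ["feature0", "feature1"]):
    print("Rule No:", leaf_id, "predicts", prediction)
    print(" & ".join(conditions))
```

    Because the function only indexes rows by field name, it should also accept the Row objects collected into noderows, though that has not been checked against a saved model here.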
