I am using Spark MLlib 1.4.1 to create a decision tree model. Now I want to extract the rules from the decision tree.
How can I extract the rules?
We can extract the rules from the fitted model's toDebugString attribute. A full example follows:
Note: for more details on the code below, see https://medium.com/@dipaweshpawar/decoding-decision-tree-in-pyspark-bdd98dcd1ddf
from pyspark.sql.functions import to_date,datediff,lit,udf,sum,avg,col,count,lag
from pyspark.sql.types import StringType,LongType,StructType,StructField,DateType,IntegerType,DoubleType
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline
import pandas as pd
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, lit, avg, max, min
from pyspark.sql.types import StringType, ArrayType, DoubleType
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
import operator
import ast
# Lookup table from an operator token (as text) to the function from the
# ``operator`` module that implements it.
operators = {
    # comparisons
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
    # combinators (bitwise and/or from the operator module)
    "and": operator.and_,
    "or": operator.or_,
}
# Toy training set: five integer feature columns plus a multi-class label.
data = pd.DataFrame({
    'ball': [0, 1, 1, 3, 1, 0, 1, 3],
    'keep': [4, 5, 6, 7, 7, 4, 6, 7],
    'hall': [8, 9, 10, 11, 2, 6, 10, 11],
    'fall': [12, 13, 14, 15, 15, 12, 14, 15],
    'mall': [16, 17, 18, 10, 10, 16, 18, 10],
    'label': [21, 31, 41, 51, 51, 51, 21, 31],
})
# Reuse the active session when one exists (pyspark shell / notebook) and
# create one otherwise, so the example also runs as a stand-alone script.
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data)
# The order of this list fixes each feature's index in the assembled vector.
f_list = ['ball', 'keep', 'mall', 'hall', 'fall']
assemble_numerical_features = VectorAssembler(
    inputCols=f_list,
    outputCol='features',
    handleInvalid='skip',
)
dt = DecisionTreeClassifier(featuresCol='features', labelCol='label')
pipeline = Pipeline(stages=[assemble_numerical_features, dt])
model = pipeline.fit(df)
df = model.transform(df)
# The decision tree is the last stage of the pipeline, so its fitted model
# is the last stage of the fitted pipeline.
dt_m = model.stages[-1]
# Step 1: convert the model's toDebugString output into a nested dictionary of nodes and children
def parse_debug_string_lines(lines):
    """Recursively convert stripped debug-string lines into a list of nodes.

    Each ``If (...)`` / ``Else (...)`` line becomes
    ``{'name': <condition without parentheses>, 'children': [...]}`` and
    every other line (e.g. ``Predict: 51.0``) becomes ``{'name': <line>}``.

    Args:
        lines: List of whitespace-stripped lines. It is consumed in place:
            every line that is parsed is popped from the front.

    Returns:
        The nodes found at the current nesting level. Parsing of a level
        stops at the first ``Else`` line that belongs to an enclosing ``If``.
    """
    def pop_condition():
        # Drop the leading 'If'/'Else' keyword and the surrounding parentheses.
        words = lines.pop(0).split()[1:]
        return ' '.join(words).replace('(', '').replace(')', '')

    block = []
    while lines:
        if lines[0].startswith('If'):
            if_name = pop_condition()
            block.append({'name': if_name,
                          'children': parse_debug_string_lines(lines)})
            # The recursive call may have consumed every remaining line
            # (truncated input), so check before peeking at the next one.
            if lines and lines[0].startswith('Else'):
                else_name = pop_condition()
                block.append({'name': else_name,
                              'children': parse_debug_string_lines(lines)})
        elif lines[0].startswith('Else'):
            # This 'Else' pairs with an 'If' one level up; let the caller
            # handle it.
            break
        else:
            block.append({'name': lines.pop(0)})
    return block
def debug_str_to_json(debug_string):
    """Build a nested node dictionary from a decision tree debug string.

    Lines are stripped and collected up to (not including) the first blank
    line. The first collected line is skipped as the model header and the
    rest are handed to ``parse_debug_string_lines``.

    Args:
        debug_string: Multi-line text such as the value of a fitted tree
            model's ``toDebugString``.

    Returns:
        ``{'name': 'Root', 'children': [...]}`` where ``children`` holds the
        parsed top-level nodes.
    """
    data = []
    for raw_line in debug_string.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            # The first blank line ends the tree description.
            break
        data.append(stripped)
    return {'name': 'Root', 'children': parse_debug_string_lines(data[1:])}
# Step 2: using the metadata stored on the features column, build a dictionary that maps each index in the feature vector to the name of the feature at that position
f_type_to_flist_dict = df.schema['features'].metadata["ml_attr"]["attrs"]
# Only the per-type attribute lists are needed, not the type keys. A
# comprehension keeps its loop variables local, so the module-level
# ``f_list`` is no longer overwritten.
f_index_to_name_dict = {
    attr['idx']: attr['name']
    for attrs_of_type in f_type_to_flist_dict.values()
    for attr in attrs_of_type
}
def generate_explanations(dt_as_json, df:DataFrame, f_index_to_name_dict, operators):
    """Add a human-readable decision path to every row of ``df``.

    Two columns are appended: ``features_list`` (the ``features`` vector
    converted to a list of doubles) and ``explanation`` (the conditions the
    row satisfies while walking the tree from the root to a ``Predict:``
    node, joined with ``', '``). A tree that consists of a single
    ``Predict:`` node yields an empty explanation.

    Args:
        dt_as_json: Nested ``{'name': ..., 'children': [...]}`` dictionary
            (or its ``str()`` form) describing the tree.
        df: DataFrame containing a ``features`` vector column.
        f_index_to_name_dict: Maps a feature index to the feature's name.
        operators: Maps an operator token such as ``"<="`` to a
            two-argument callable.

    Returns:
        ``df`` with the ``features_list`` and ``explanation`` columns added.
    """
    # Accept the string form as well as the dictionary itself.
    if isinstance(dt_as_json, str):
        tree = ast.literal_eval(dt_as_json)
    else:
        tree = dt_as_json

    # Errors that mean "this condition could not be evaluated for this row".
    cond_errors = (IndexError, KeyError, TypeError, ValueError)

    df = df.withColumn(
        'features' + '_list',
        udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))(df['features'])
    )

    # Step 3: parse a node's condition and check whether the current
    # instance satisfies it
    def matched_condition(cond: str, f_vector: list):
        """Return the readable form of ``cond`` if ``f_vector`` meets it, else None."""
        # Expected layout: 'feature <index> <operator> <threshold>'
        cond_parts = cond.split()
        f_index = int(cond_parts[1])
        op = cond_parts[2]
        threshold = float(cond_parts[3])
        f_value = f_vector[f_index]
        f_name = (f_index_to_name_dict[f_index]
                  .replace('numerical_features_', '')
                  .replace('encoded_numeric_', '')
                  .lower())
        if operators[op](f_value, threshold):
            return f_name + ' ' + op + ' ' + str(round(threshold, 2))
        return None

    # Step 4: extract the rule for one instance by following, from the root,
    # the child whose condition the instance satisfies until a prediction
    # node is reached
    def extract_rule(f_vector: list) -> str:
        conditions = []
        node = tree
        while node is not None:
            next_node = None
            for child in node.get('children', []):
                name = child['name'].strip()
                if name.startswith('Predict:'):
                    return ', '.join(conditions)
                if not name.startswith('feature'):
                    continue
                try:
                    cond = matched_condition(name, f_vector)
                except cond_errors:
                    # Treat an unparseable condition as "not satisfied".
                    cond = None
                if cond is not None:
                    conditions.append(cond)
                    next_node = child
                    break
            node = next_node
        # No child matched (or the node had no children): return what was
        # collected so far.
        return ', '.join(conditions)

    def has_malformed_node(root) -> bool:
        """Tell whether any 'feature ...' node cannot be parsed."""
        pending = [root]
        while pending:
            current = pending.pop()
            name = current['name'].strip()
            if name.startswith('feature'):
                parts = name.split()
                try:
                    well_formed = (int(parts[1]) in f_index_to_name_dict
                                   and parts[2] in operators)
                    float(parts[3])
                except (IndexError, ValueError):
                    well_formed = False
                if not well_formed:
                    return True
            pending.extend(current.get('children', []))
        return False

    # The tree reaches the workers through the function's closure, so it is
    # not re-parsed for every row.
    df = df.withColumn(
        'explanation',
        udf(extract_rule, StringType())(df['features' + '_list'])
    )

    # The UDF runs lazily and on the workers, so a flag set inside it would
    # not be visible here; validate the node format up front on the driver
    # instead, and log when a node's condition cannot be parsed.
    if has_malformed_node(tree):
        print('some node in decision tree has unexpected format')
    return df
# Annotate every row with the decision path behind its prediction, then
# collect the columns of interest.
df = generate_explanations(
    debug_str_to_json(dt_m.toDebugString),
    df,
    f_index_to_name_dict,
    operators,
)
rows = df.select(
    ['ball', 'keep', 'mall', 'hall', 'fall', 'explanation', 'prediction']
).collect()
Output (the collected rows):
-----------------------
[Row(ball=0, keep=4, mall=16, hall=8, fall=12, explanation='hall > 7.0, mall > 13.0, ball <= 0.5', prediction=21.0),
Row(ball=1, keep=5, mall=17, hall=9, fall=13, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep <= 5.5', prediction=31.0),
Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0),
Row(ball=1, keep=7, mall=10, hall=2, fall=15, explanation='hall <= 7.0', prediction=51.0),
Row(ball=0, keep=4, mall=16, hall=6, fall=12, explanation='hall <= 7.0', prediction=51.0),
Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0)]
output of dt_m.toDebugString:
-----------------------------------
'DecisionTreeClassificationModel (uid=DecisionTreeClassifier_2a17ae7633b9) of depth 4 with 9 nodes\n If (feature 3 <= 7.0)\n Predict: 51.0\n Else (feature 3 > 7.0)\n If (feature 2 <= 13.0)\n Predict: 31.0\n Else (feature 2 > 13.0)\n If (feature 0 <= 0.5)\n Predict: 21.0\n Else (feature 0 > 0.5)\n If (feature 1 <= 5.5)\n Predict: 31.0\n Else (feature 1 > 5.5)\n Predict: 21.0\n'
output of debug_str_to_json(dt_m.toDebugString):
------------------------------------
{'name': 'Root',
'children': [{'name': 'feature 3 <= 7.0',
'children': [{'name': 'Predict: 51.0'}]},
{'name': 'feature 3 > 7.0',
'children': [{'name': 'feature 2 <= 13.0',
'children': [{'name': 'Predict: 31.0'}]},
{'name': 'feature 2 > 13.0',
'children': [{'name': 'feature 0 <= 0.5',
'children': [{'name': 'Predict: 21.0'}]},
{'name': 'feature 0 > 0.5',
'children': [{'name': 'feature 1 <= 5.5',
'children': [{'name': 'Predict: 31.0'}]},
{'name': 'feature 1 > 5.5',
'children': [{'name': 'Predict: 21.0'}]}]}]}]}]}