Getting the leaf probabilities of a tree model in spark

前端 未结 1 721
[愿得一人]
[愿得一人] 2021-01-18 03:18

I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifiers) in such a way that it can be exported to environments without Spark. The toDebugS…

相关标签:
1条回答
  • 2021-01-18 03:28

    I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifiers) in such a way that it can be exported to environments without Spark. The …

    Given the growing number of tools designed for real-time serving of Spark (and other) models, that's probably reinventing the wheel.

    However, if you want to access model internals from plain Python, it is best to load its serialized form.

    Let's say you have:

    from pyspark.ml.classification import RandomForestClassificationModel
    
    # Placeholders for the objects used below (annotations only; no values are bound here).
    rf_model: RandomForestClassificationModel  # the trained forest model
    path: str  # Absolute path
    

    And you save the model:

    # Persist the model; the saved directory contains the "data" and
    # "treesMetadata" Parquet outputs that are loaded afterwards.
    rf_model.write().save(path)
    

    You can load it back using a Parquet reader that supports nested struct and list types. The model writer saves both the node data:

    # Per-node records of every tree, stored under "<path>/data".
    node_data = spark.read.parquet(f"{path}/data")

    node_data.printSchema()
    
    root
     |-- treeID: integer (nullable = true)
     |-- nodeData: struct (nullable = true)
     |    |-- id: integer (nullable = true)
     |    |-- prediction: double (nullable = true)
     |    |-- impurity: double (nullable = true)
     |    |-- impurityStats: array (nullable = true)
     |    |    |-- element: double (containsNull = true)
     |    |-- rawCount: long (nullable = true)
     |    |-- gain: double (nullable = true)
     |    |-- leftChild: integer (nullable = true)
     |    |-- rightChild: integer (nullable = true)
     |    |-- split: struct (nullable = true)
     |    |    |-- featureIndex: integer (nullable = true)
     |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
     |    |    |    |-- element: double (containsNull = true)
     |    |    |-- numCategories: integer (nullable = true)
    

    and tree metadata:

    # One record per tree, stored under "<path>/treesMetadata".
    tree_meta = spark.read.parquet(f"{path}/treesMetadata")

    tree_meta.printSchema()
    root
     |-- treeID: integer (nullable = true)
     |-- metadata: string (nullable = true)
     |-- weights: double (nullable = true)
    

    where the former provides all the information you need, as the prediction process is basically an aggregation of impurityStats*.

    You could also access this data directly using the underlying Java objects:

    from  collections import namedtuple
    import numpy as np
    
    # Spark-free tree representation. The `impurity` field holds the node's
    # impurity statistics array (impurityStats), not a scalar impurity value.
    LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
    InternalNode = namedtuple(
        "InternalNode", ("left", "right", "prediction", "impurity", "split"))
    CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
    ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))

    def jtree_to_python(jtree):
        """Convert a py4j handle to a Spark ML decision tree into namedtuples.

        Walks the tree from ``jtree.rootNode()`` and returns nested
        ``InternalNode`` / ``LeafNode`` tuples. Internal nodes carry a
        ``ContinuousSplit`` or ``CategoricalSplit`` in their ``split`` field;
        every node carries its impurity statistics as a NumPy array.

        Relies on Spark-internal members (e.g. ``impurityStats``) that are
        reachable through py4j but are not public API, so it may break between
        Spark versions.
        """
        def jsplit_to_python(jsplit):
            if jsplit.getClass().toString().endswith(".ContinuousSplit"):
                return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
            # Categorical split: fetch the left-going categories as one array
            # through the public accessor, instead of indexing a Scala List
            # element by element (one py4j round trip and O(i) JVM work each).
            return CategoricalSplit(
                jsplit.featureIndex(), list(jsplit.leftCategories()))

        def jnode_to_python(jnode):
            prediction = jnode.prediction()
            stats = np.array(list(jnode.impurityStats().stats()))

            if jnode.numDescendants() != 0:  # InternalNode
                return InternalNode(
                    jnode_to_python(jnode.leftChild()),
                    jnode_to_python(jnode.rightChild()),
                    prediction,
                    stats,
                    jsplit_to_python(jnode.split()),
                )
            return LeafNode(prediction, stats)

        return jnode_to_python(jtree.rootNode())
    

    which can be applied to a RandomForestClassificationModel like this:

    # Convert every tree of the forest via the model's private py4j handle.
    nodes = list(map(jtree_to_python, rf_model._java_obj.trees()))
    

    Furthermore, such a structure can easily be used to make predictions, both for individual trees (warning: Python 3.7+ ahead; for legacy versions please refer to the functools documentation):

    from functools import singledispatch
    
    @singledispatch
    def should_go_left(split, vector):
        """Return True if *vector* follows the left branch of *split*.

        Dispatches on the split's type. Unsupported types raise TypeError
        rather than returning None: callers treat any falsy result as
        "go right", so a silent None would yield wrong predictions.
        """
        raise TypeError(f"unsupported split type: {type(split).__name__}")

    @should_go_left.register
    def _(split: CategoricalSplit, vector):
        # The split stores the categories that go to the left child.
        return vector[split.feature_index] in split.categories

    @should_go_left.register
    def _(split: ContinuousSplit, vector):
        # Values at or below the threshold go to the left child.
        return vector[split.feature_index] <= split.threshold
    
    @singledispatch
    def predict(node, vector):
        """Walk a converted tree for *vector*; return ``(prediction, stats)``.

        ``stats`` is the ``impurity`` field of the leaf that is reached (its
        impurity statistics array). Unsupported node types raise TypeError
        instead of silently returning None.
        """
        raise TypeError(f"unsupported node type: {type(node).__name__}")

    @predict.register
    def _(node: LeafNode, vector):
        return node.prediction, node.impurity

    @predict.register
    def _(node: InternalNode, vector):
        # Descend into whichever child the node's split selects.
        return predict(
            node.left if should_go_left(node.split, vector) else node.right,
            vector
        )
    

    and forests:

    from typing import Iterable, Union
    
    def predict_probability(nodes: Iterable[Union[InternalNode, LeafNode]], vector):
        """Return class probabilities for *vector* from a forest of converted trees.

        Each tree votes with the statistics of the leaf that *vector* reaches,
        normalized to sum to one. The votes are summed and the total is
        normalized again. Leaves whose statistics sum to zero are skipped
        (as Spark's random forest does) instead of injecting NaN.

        Raises:
            ValueError: if no tree contributes a vote, i.e. ``nodes`` is empty
                or every reached leaf has all-zero statistics.
        """
        total = None
        for node in nodes:
            _, stats = predict(node, vector)
            leaf_total = stats.sum()
            if leaf_total == 0:
                continue
            votes = stats / leaf_total
            total = votes if total is None else total + votes
        if total is None:
            raise ValueError("no tree produced a leaf with non-zero statistics")
        return total / total.sum()
    

    That, however, depends on the internal API (and the weakness of Scala's package-scoped access modifiers) and might break in the future.


    * The DataFrame loaded from the data path can easily be transformed into a structure compatible with the predict and predict_probability functions defined above.

    from pyspark.sql.dataframe import DataFrame 
    from itertools import groupby
    from operator import itemgetter
    
    
    def model_data_to_tree(tree_data: DataFrame):
        """Rebuild Python trees from the node DataFrame saved by the model writer.

        Accepts two layouts:
        - the forest layout: a ``treeID`` column plus a nested ``nodeData`` struct;
        - a single-tree layout whose columns are the node fields themselves.

        Always returns a list of root nodes, ordered by ``treeID`` in the
        forest case. Node id ``0`` is taken as each tree's root.
        """
        def build(node_id, by_id):
            row = by_id[node_id]
            prediction = row.prediction
            stats = np.array(row.impurityStats)

            # A node without children on either side is a leaf.
            if row.leftChild == -1 and row.rightChild == -1:
                return LeafNode(prediction, stats)

            left = build(row.leftChild, by_id)
            right = build(row.rightChild, by_id)
            split_info = row.split
            feature = split_info.featureIndex
            values = split_info.leftCategoriesOrThreshold
            # numCategories == -1 marks a continuous split whose single stored
            # value is the threshold; otherwise the values are left categories.
            if split_info.numCategories == -1:
                split = ContinuousSplit(feature, values[0])
            else:
                split = CategoricalSplit(feature, values)
            return InternalNode(left, right, prediction, stats, split)

        rows = tree_data.collect()
        if "treeID" not in tree_data.columns:
            return [build(0, {row.id: row for row in rows})]

        by_tree = itemgetter("treeID")
        return [
            build(0, {row.nodeData.id: row.nodeData for row in group})
            for _, group in groupby(sorted(rows, key=by_tree), key=by_tree)
        ]
    
    0 讨论(0)
提交回复
热议问题