Plot dendrogram using sklearn.AgglomerativeClustering

刺人心 2021-01-31 15:17

I'm trying to build a dendrogram using the children_ attribute provided by AgglomerativeClustering, but so far I'm out of luck. I can't use sc

5 Answers
  • 2021-01-31 15:21

    I came across the exact same problem some time ago. I managed to plot the dendrogram with the ete3 package, which can plot trees flexibly with various options. The only difficulty was converting sklearn's children_ output to the Newick tree format that ete3 can read. Furthermore, I had to compute each dendrite's span manually, because children_ does not provide that information. Here is a snippet of the code I used. It computes the Newick tree and then shows the ete3 Tree data structure. For more details on how to plot, see the ete3 documentation.

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering
    import ete3
    
    def build_Newick_tree(children,n_leaves,X,leaf_labels,spanner):
        """
        build_Newick_tree(children,n_leaves,X,leaf_labels,spanner)
    
        Get a string representation (Newick tree) from the sklearn
        AgglomerativeClustering.fit output.
    
        Input:
            children: AgglomerativeClustering.children_
            n_leaves: AgglomerativeClustering.n_leaves_
            X: parameters supplied to AgglomerativeClustering.fit
            leaf_labels: The label of each parameter array in X
            spanner: Callable that computes the dendrite's span
    
        Output:
            ntree: A str with the Newick tree representation
    
        """
        return go_down_tree(children,n_leaves,X,leaf_labels,len(children)+n_leaves-1,spanner)[0]+';'
    
    def go_down_tree(children,n_leaves,X,leaf_labels,nodename,spanner):
        """
        go_down_tree(children,n_leaves,X,leaf_labels,nodename,spanner)
    
        Recursive function that traverses the subtree that descends from
        nodename and returns the Newick representation of the subtree.
    
        Input:
            children: AgglomerativeClustering.children_
            n_leaves: AgglomerativeClustering.n_leaves_
            X: parameters supplied to AgglomerativeClustering.fit
            leaf_labels: The label of each parameter array in X
            nodename: An int that is the intermediate node name whose
                children are located in children[nodename-n_leaves].
            spanner: Callable that computes the dendrite's span
    
        Output:
            ntree: A str with the Newick tree representation
            samples: 2D array with the samples found under nodename
    
        """
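        if nodename<n_leaves:
            # Leaf: nodename indexes leaf_labels and X directly
            return leaf_labels[nodename],np.array([X[nodename]])
        else:
            # Internal node: its two children are stored in children[nodename-n_leaves]
            node_children = children[nodename-n_leaves]
            # spanner must be passed down, otherwise the recursive call raises a TypeError
            branch0,branch0samples = go_down_tree(children,n_leaves,X,leaf_labels,node_children[0],spanner)
            branch1,branch1samples = go_down_tree(children,n_leaves,X,leaf_labels,node_children[1],spanner)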
            node = np.vstack((branch0samples,branch1samples))
            branch0span = spanner(branch0samples)
            branch1span = spanner(branch1samples)
            nodespan = spanner(node)
            branch0distance = nodespan-branch0span
            branch1distance = nodespan-branch1span
            nodename = '({branch0}:{branch0distance},{branch1}:{branch1distance})'.format(branch0=branch0,branch0distance=branch0distance,branch1=branch1,branch1distance=branch1distance)
            return nodename,node
    
    def get_cluster_spanner(aggClusterer):
        """
        spanner = get_cluster_spanner(aggClusterer)
    
        Input:
            aggClusterer: sklearn.cluster.AgglomerativeClustering instance
    
        Get a callable that computes a given cluster's span. To compute
        a cluster's span, call spanner(cluster)
    
        The cluster must be a 2D numpy array, where the axis=0 holds
        separate cluster members and the axis=1 holds the different
        variables.
    
        """
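        # Note: scikit-learn 1.2 renamed the affinity parameter to metric; on newer
        # releases read aggClusterer.metric instead of aggClusterer.affinity.
        if aggClusterer.linkage=='ward':
            if aggClusterer.affinity=='euclidean':
                # pooling_func was removed from scikit-learn; np.mean was its default
                spanner = lambda x:np.sum((x-np.mean(x,axis=0))**2)
            else:
                raise AttributeError('Unknown affinity attribute value {0}.'.format(aggClusterer.affinity))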
        elif aggClusterer.linkage=='complete':
            if aggClusterer.affinity=='euclidean':
                spanner = lambda x:np.max(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2))
            elif aggClusterer.affinity=='l1' or aggClusterer.affinity=='manhattan':
                spanner = lambda x:np.max(np.sum(np.abs(x[:,None,:]-x[None,:,:]),axis=2))
            elif aggClusterer.affinity=='l2':
                spanner = lambda x:np.max(np.sqrt(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2)))
            elif aggClusterer.affinity=='cosine':
                # cosine distance = 1 - cosine similarity, computed for every pair of rows
                spanner = lambda x:np.max(1-np.sum(x[:,None,:]*x[None,:,:],axis=2)/(np.linalg.norm(x,axis=1)[:,None]*np.linalg.norm(x,axis=1)[None,:]))
            else:
                raise AttributeError('Unknown affinity attribute value {0}.'.format(aggClusterer.affinity))
        elif aggClusterer.linkage=='average':
            if aggClusterer.affinity=='euclidean':
                spanner = lambda x:np.mean(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2))
            elif aggClusterer.affinity=='l1' or aggClusterer.affinity=='manhattan':
                spanner = lambda x:np.mean(np.sum(np.abs(x[:,None,:]-x[None,:,:]),axis=2))
            elif aggClusterer.affinity=='l2':
                spanner = lambda x:np.mean(np.sqrt(np.sum((x[:,None,:]-x[None,:,:])**2,axis=2)))
            elif aggClusterer.affinity=='cosine':
                # cosine distance = 1 - cosine similarity, computed for every pair of rows
                spanner = lambda x:np.mean(1-np.sum(x[:,None,:]*x[None,:,:],axis=2)/(np.linalg.norm(x,axis=1)[:,None]*np.linalg.norm(x,axis=1)[None,:]))
            else:
                raise AttributeError('Unknown affinity attribute value {0}.'.format(aggClusterer.affinity))
        else:
            raise AttributeError('Unknown linkage attribute value {0}.'.format(aggClusterer.linkage))
        return spanner
    
    clusterer = AgglomerativeClustering(n_clusters=2,compute_full_tree=True) # You can set compute_full_tree to 'auto', but I left it this way to get the entire tree plotted
    clusterer.fit(X) # X for whatever you want to fit
    spanner = get_cluster_spanner(clusterer)
    newick_tree = build_Newick_tree(clusterer.children_,clusterer.n_leaves_,X,leaf_labels,spanner) # leaf_labels is a list of labels for each entry in X
    tree = ete3.Tree(newick_tree)
    tree.show()
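
To make the node numbering used above concrete, here is a minimal sketch that ignores branch lengths entirely. It relies only on the convention that values below n_leaves in children_ are original samples, and that node i >= n_leaves was formed by the merge stored in children_[i - n_leaves]. The helper name newick_topology and the toy children list are illustrative assumptions, not part of sklearn or ete3.

```python
def newick_topology(children, n_leaves, leaf_labels):
    """Return a topology-only Newick string for an sklearn-style children array."""
    def subtree(node):
        if node < n_leaves:
            # Leaf: the node id is the index of the original sample
            return str(leaf_labels[node])
        # Internal node: look up the merge that created it
        left, right = children[node - n_leaves]
        return "({},{})".format(subtree(left), subtree(right))

    # The last merge creates the root, whose id is n_leaves + len(children) - 1
    root = n_leaves + len(children) - 1
    return subtree(root) + ";"


# Toy tree: merge (a, b) -> node 4, merge (c, d) -> node 5, merge (4, 5) -> node 6
children = [[0, 1], [2, 3], [4, 5]]
labels = ["a", "b", "c", "d"]
print(newick_topology(children, 4, labels))  # ((a,b),(c,d));
```

Like the snippet above, this recursion can hit Python's recursion limit on very deep (chain-like) trees, so an explicit stack may be preferable for large inputs.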
    
  • 2021-01-31 15:24

    Here is a simple function for taking a hierarchical clustering model from sklearn and plotting it using the scipy dendrogram function. Seems like graphing functions are often not directly supported in sklearn. You can find an interesting discussion of that related to the pull request for this plot_dendrogram code snippet here.

    I'd clarify that the use case you describe (defining number of clusters) is available in scipy: after you've performed the hierarchical clustering using scipy's linkage you can cut the hierarchy to whatever number of clusters you want using fcluster with number of clusters specified in the t argument and criterion='maxclust' argument.
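
A minimal sketch of that scipy workflow, assuming a small made-up data set and Ward linkage (both are illustrative choices, not taken from the question): linkage builds the full hierarchy, and fcluster cuts it into at most t flat clusters when criterion='maxclust'.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

# Toy data (an assumption for illustration): four nearby points and one outlier
X = np.array([[0.0, 0.0], [0.1, -0.1], [1.0, 1.0], [1.1, 1.1], [5.0, 5.0]])

# Build the full hierarchy; each row of Z describes one merge
Z = linkage(X, method="ward")

# Cut the hierarchy so that no more than 2 flat clusters remain
labels = fcluster(Z, t=2, criterion="maxclust")
print(labels)
```

The same Z can be passed directly to scipy's dendrogram, so the plot and the flat clustering come from a single fit.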

  • 2021-01-31 15:26

    From the official docs:

    import numpy as np
    
    from matplotlib import pyplot as plt
    from scipy.cluster.hierarchy import dendrogram
    from sklearn.datasets import load_iris
    from sklearn.cluster import AgglomerativeClustering
    
    
    def plot_dendrogram(model, **kwargs):
        # Create linkage matrix and then plot the dendrogram
    
        # create the counts of samples under each node
        counts = np.zeros(model.children_.shape[0])
        n_samples = len(model.labels_)
        for i, merge in enumerate(model.children_):
            current_count = 0
            for child_idx in merge:
                if child_idx < n_samples:
                    current_count += 1  # leaf node
                else:
                    current_count += counts[child_idx - n_samples]
            counts[i] = current_count
    
        linkage_matrix = np.column_stack([model.children_, model.distances_,
                                          counts]).astype(float)
    
        # Plot the corresponding dendrogram
        dendrogram(linkage_matrix, **kwargs)
    
    
    iris = load_iris()
    X = iris.data
    
    # setting distance_threshold=0 ensures we compute the full tree.
    model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
    
    model = model.fit(X)
    plt.title('Hierarchical Clustering Dendrogram')
    # plot the top three levels of the dendrogram
    plot_dendrogram(model, truncate_mode='level', p=3)
    plt.xlabel("Number of points in node (or index of point if no parenthesis).")
    plt.show()
    

    Note that, as of scikit-learn v0.23, this only works when AgglomerativeClustering is called with the distance_threshold parameter, because distances_ is not computed otherwise. From v0.24 onward you can force the calculation of distances by setting compute_distances=True (see the nightly build docs).
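
As a sketch of the compute_distances route, assuming scikit-learn 0.24 or later and a made-up five-point data set: the model is fitted with a fixed n_clusters, yet distances_ is still populated, so the same linkage matrix can be assembled. Passing no_plot=True to dendrogram is only a convenience here to avoid needing a display.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering

# Toy data (an assumption for illustration)
X = np.array([[0.0, 0.0], [0.1, -0.1], [1.0, 1.0], [1.1, 1.1], [5.0, 5.0]])

# n_clusters is fixed, but compute_distances=True still fills model.distances_
model = AgglomerativeClustering(n_clusters=2, compute_distances=True).fit(X)

# Count the samples under each merge, as in plot_dendrogram above
n_samples = len(model.labels_)
counts = np.zeros(model.children_.shape[0])
for i, merge in enumerate(model.children_):
    counts[i] = sum(
        1 if child < n_samples else counts[child - n_samples] for child in merge
    )

linkage_matrix = np.column_stack(
    [model.children_, model.distances_, counts]
).astype(float)

# no_plot=True returns the dendrogram layout without drawing anything
info = dendrogram(linkage_matrix, no_plot=True)
print(info["ivl"])  # leaf labels in plotting order
```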

  • 2021-01-31 15:28

    Use the scipy implementation of agglomerative clustering instead. Here is an example.

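    from matplotlib import pyplot as plt
    from scipy.cluster.hierarchy import dendrogram, linkage
    
    data = [[0., 0.], [0.1, -0.1], [1., 1.], [1.1, 1.1]]
    
    # linkage defaults to method='single' with Euclidean distances
    Z = linkage(data)
    
    dendrogram(Z)
    plt.show()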
    

    The scipy.cluster.hierarchy documentation covers both linkage and dendrogram.

  • 2021-01-31 15:36

    For those willing to step out of Python and use the robust D3 library, it's not super difficult to use the d3.cluster() (or, I guess, d3.tree()) APIs to achieve a nice, customizable result.

    See the jsfiddle for a demo.

    The children_ array luckily functions easily as a JS array, and the only intermediary step is to use d3.stratify() to turn it into a hierarchical representation. Specifically, we need each node to have an id and a parentId:

    var N = 272;  // Your n_samples/corpus size.
    var root = d3.stratify()
      .id((d,i) => i + N)
      .parentId((d, i) => {
        var parIndex = data.findIndex(e => e.includes(i + N));
        if (parIndex < 0) {
          return; // The root should have an undefined parentId.
        }
        return parIndex + N;
      })(data); // Your children_
    

    You end up with at least O(n^2) behaviour here due to the findIndex line, but it probably doesn't matter until your n_samples becomes huge, in which case, you could precompute a more efficient index.
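
One way to precompute such an index is a single pass over children_ that records each node's parent. The sketch below is in Python rather than JavaScript, on the assumption that the index is built before the data is exported (for example as JSON); the helper name parent_index and the toy children list are hypothetical.

```python
def parent_index(children, n_leaves):
    """Map every child id to its parent id in one O(n) pass over children."""
    parent = {}
    for i, (left, right) in enumerate(children):
        # The merge in row i creates the node with id n_leaves + i
        node_id = n_leaves + i
        parent[left] = node_id
        parent[right] = node_id
    return parent


# Toy tree with 4 leaves: rows create nodes 4, 5 and 6 (the root)
children = [[0, 1], [2, 3], [4, 5]]
parents = parent_index(children, n_leaves=4)
print(parents)  # {0: 4, 1: 4, 2: 5, 3: 5, 4: 6, 5: 6}
```

The root never appears as a key, which matches the requirement that its parentId be undefined.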

    Beyond that, it's pretty much plug and chug use of d3.cluster(). See mbostock's canonical block or my JSFiddle.

    N.B. For my use case, it sufficed merely to show non-leaf nodes; it's a bit trickier to visualise the samples/leaves, since these might not all be in the children_ array explicitly.
