Plot dendrogram using sklearn.AgglomerativeClustering

刺人心 2021-01-31 15:17

I'm trying to build a dendrogram using the children_ attribute provided by AgglomerativeClustering, but so far I'm out of luck. I can't use sc

5 answers
  •  醉梦人生
    2021-01-31 15:26

    From the official docs:

    import numpy as np
    
    from matplotlib import pyplot as plt
    from scipy.cluster.hierarchy import dendrogram
    from sklearn.datasets import load_iris
    from sklearn.cluster import AgglomerativeClustering
    
    
    def plot_dendrogram(model, **kwargs):
        # Create linkage matrix and then plot the dendrogram
    
        # create the counts of samples under each node
        counts = np.zeros(model.children_.shape[0])
        n_samples = len(model.labels_)
        for i, merge in enumerate(model.children_):
            current_count = 0
            for child_idx in merge:
                if child_idx < n_samples:
                    current_count += 1  # leaf node
                else:
                    current_count += counts[child_idx - n_samples]
            counts[i] = current_count
    
        linkage_matrix = np.column_stack([model.children_, model.distances_,
                                          counts]).astype(float)
    
        # Plot the corresponding dendrogram
        dendrogram(linkage_matrix, **kwargs)
    
    
    iris = load_iris()
    X = iris.data
    
    # setting distance_threshold=0 ensures we compute the full tree.
    model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
    
    model = model.fit(X)
    plt.title('Hierarchical Clustering Dendrogram')
    # plot the top three levels of the dendrogram
    plot_dendrogram(model, truncate_mode='level', p=3)
    plt.xlabel("Number of points in node (or index of point if no parenthesis).")
    plt.show()
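
To make the layout of the linkage matrix more concrete, here is a minimal sketch on a tiny dataset. The five 1-D points are made up for illustration (they are not part of the docs example), and the check with scipy's is_valid_linkage is only a sanity test, assuming the default ward linkage. Row i of children_ lists the two nodes merged at step i: an index below n_samples is an original sample, and an index n_samples + j refers to the cluster formed at step j.

```python
import numpy as np
from scipy.cluster.hierarchy import is_valid_linkage
from sklearn.cluster import AgglomerativeClustering

# Hypothetical toy data: five 1-D points spaced so that no two merges tie.
X = np.array([[0.0], [1.0], [5.0], [7.0], [20.0]])

# distance_threshold=0 with n_clusters=None builds the full tree and
# populates model.distances_.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None).fit(X)

n_samples = len(model.labels_)
counts = np.zeros(model.children_.shape[0])
for i, merge in enumerate(model.children_):
    for child_idx in merge:
        if child_idx < n_samples:
            counts[i] += 1  # leaf node: one original sample
        else:
            # internal node: reuse the count stored for that earlier merge
            counts[i] += counts[child_idx - n_samples]

# Columns expected by scipy: child a, child b, merge distance, sample count.
linkage_matrix = np.column_stack(
    [model.children_, model.distances_, counts]
).astype(float)

print(model.children_)  # one row per merge
print(counts)           # samples under each merged node
print(is_valid_linkage(linkage_matrix))
```

The last row of the matrix always describes the root, so its count equals the number of samples.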
    

    Note that as of scikit-learn v0.23 this works only when AgglomerativeClustering is called with the distance_threshold parameter. As of v0.24 you will be able to force the calculation of distances by setting compute_distances=True, even when a fixed n_clusters is requested (see nightly build docs).
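
A rough sketch of the compute_distances route mentioned in the note, assuming scikit-learn 0.24 or later is installed. The choice of n_clusters=3 is arbitrary and only for illustration. Passing no_plot=True to scipy's dendrogram returns the layout as a dict without drawing anything; drop it (and call plt.show() as in the example above) to get the actual figure.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

X = load_iris().data

# Assumes scikit-learn >= 0.24: compute_distances=True fills model.distances_
# even though a fixed number of clusters is requested.
model = AgglomerativeClustering(n_clusters=3, compute_distances=True).fit(X)

# Same bookkeeping as plot_dendrogram: count the samples under each merge.
n_samples = len(model.labels_)
counts = np.zeros(model.children_.shape[0])
for i, merge in enumerate(model.children_):
    for child_idx in merge:
        if child_idx < n_samples:
            counts[i] += 1
        else:
            counts[i] += counts[child_idx - n_samples]

linkage_matrix = np.column_stack(
    [model.children_, model.distances_, counts]
).astype(float)

# no_plot=True only computes the dendrogram layout; nothing is drawn.
info = dendrogram(linkage_matrix, truncate_mode="level", p=3, no_plot=True)
print(len(set(model.labels_)))  # number of flat clusters found
print(linkage_matrix.shape)     # one row per merge, four columns
```

This way the flat cluster labels and the dendrogram come from a single fit, instead of fitting once with distance_threshold=0 for the tree and again with n_clusters for the labels.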
