sklearn agglomerative clustering: dynamically updating the number of clusters

Happy的楠姐 2021-01-15 10:57

The documentation for sklearn.cluster.AgglomerativeClustering mentions that,

when varying the number of clusters and using caching, it may be advant

2 answers
  • 2021-01-15 11:31

    Set a caching directory with the parameter memory='mycachedir'. If you also set compute_full_tree=True, then when you rerun fit with different values of n_clusters it will use the cached tree rather than recomputing it each time. Here is an example of how to do this with sklearn's grid search API:

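    from sklearn.cluster import AgglomerativeClustering
    from sklearn.metrics import adjusted_rand_score
    from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed
    
    # AgglomerativeClustering has no predict(), so the built-in
    # 'adjusted_rand_score' scorer cannot be used; re-cluster each
    # held-out fold with fit_predict instead.
    def ari_scorer(estimator, X, y):
        return adjusted_rand_score(y, estimator.fit_predict(X))
    
    ac = AgglomerativeClustering(memory='mycachedir',
                                 compute_full_tree=True)
    classifier = GridSearchCV(ac,
                              {'n_clusters': range(2, 6)},
                              scoring=ari_scorer,
                              n_jobs=-1, verbose=2)
    classifier.fit(X, y)  # X: feature matrix, y: ground-truth labels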
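
    The same caching idea also works without the grid search API, which may be easier to follow when no ground-truth labels are available. The following is only a minimal sketch under stated assumptions: the data comes from make_blobs, the cache lives in a temporary directory, and the silhouette score is used as the selection criterion. None of these choices come from the answer above.

```python
import tempfile

from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic data, assumed here purely for illustration.
X, _ = make_blobs(n_samples=150, centers=4, cluster_std=0.6, random_state=0)

# Temporary cache directory (an assumption; any writable path works).
with tempfile.TemporaryDirectory() as cache_dir:
    scores = {}
    for k in range(2, 7):
        model = AgglomerativeClustering(
            n_clusters=k,
            memory=cache_dir,        # cache the merge tree on disk
            compute_full_tree=True,  # build the whole tree so every k can reuse it
        )
        labels = model.fit_predict(X)
        scores[k] = silhouette_score(X, labels)

best_k = max(scores, key=scores.get)
print(f"Best n_clusters by silhouette score: {best_k}")
```

    With compute_full_tree=True the tree does not depend on n_clusters, so only the first fit should have to build it; later fits read it from the cache and just cut it at a different level.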
    
  • 2021-01-15 11:38

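    I know it's an old question, but the solution below might be helpful. It builds the linkage tree once with SciPy, cuts it at each candidate number of clusters, and keeps the cut with the best silhouette score: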

    # scores: input data matrix, shape (n_samples, n_features)
    
    from scipy.cluster.hierarchy import linkage
    from scipy.cluster.hierarchy import cut_tree
    from sklearn.metrics import silhouette_score
    from sklearn.metrics.pairwise import euclidean_distances
    
    linkage_mat = linkage(scores, method="ward")
    euc_scores = euclidean_distances(scores)
    
    n_l = 2
    n_h = scores.shape[0]
    
    silh_score = -2
    # Selecting the best number of clusters based on the silhouette score
    for i in range(n_l, n_h):
        local_labels = list(cut_tree(linkage_mat, n_clusters=i).flatten())
        sc = silhouette_score(
            euc_scores,
            metric="precomputed",
            labels=local_labels,
            random_state=42)
        if silh_score < sc:
            silh_score = sc
            labels = local_labels
    
    n_clusters = len(set(labels))
    print(f"Optimal number of clusters: {n_clusters}")
    print(f"Best silhouette score: {silh_score}")
    # ...
    