Plotting the KMeans Cluster Centers for every iteration in Python

Submitted by 六月ゝ 毕业季﹏ on 2021-01-05 07:22:26

Question


I created a dataset with 6 clusters, visualized it with the code below, and found the cluster centers for each iteration. Now I want to visualize how the KMeans algorithm updates the cluster centroids. The demonstration should cover the first four iterations in a figure with a 2×2 grid of axes. I have the points but can't plot them. Could you look at my code and help me write the scatter-plot code?

Here is my code so far:

import seaborn as sns
import matplotlib.pyplot as plt
# IPython/Jupyter magic: remove this line when running as a plain Python script
%matplotlib inline
from sklearn.datasets import make_blobs
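data = make_blobs(n_samples=200, n_features=8,
                  centers=6, cluster_std=1.8, random_state=101)
print(data[0].shape)  # (200, 8): 200 samples with 8 features each
# Only the first two of the 8 features are plotted; colours are the true blob labels
plt.scatter(data[0][:,0],data[0][:,1],c=data[1],cmap='brg')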

plt.show()
from sklearn.cluster import KMeans

print("First iteration points:")
kmeans = KMeans(n_clusters=6,random_state=0,max_iter=1)
kmeans.fit(data[0])
centroids=kmeans.cluster_centers_
print(kmeans.cluster_centers_)
print("Second iteration points:")
kmeans = KMeans(n_clusters=6,random_state=0,max_iter=2)
kmeans.fit(data[0])
print(kmeans.cluster_centers_)
print("Third iteration points:")
kmeans = KMeans(n_clusters=6,random_state=0,max_iter=3)
kmeans.fit(data[0])
print(kmeans.cluster_centers_)
print("Fourth iteration points:")
kmeans = KMeans(n_clusters=6,random_state=0,max_iter=4)
kmeans.fit(data[0])
print(kmeans.cluster_centers_)
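For reference, each KMeans iteration (Lloyd's algorithm) performs two steps: assign every point to its nearest centroid, then move every centroid to the mean of the points assigned to it. Refitting with max_iter=1, 2, 3, 4 repeats that work from the beginning each time. A minimal NumPy sketch can instead record the centroids after every step of a single run. It assumes randomly chosen data points as the initial centroids and keeps a centroid in place if its cluster becomes empty; scikit-learn uses k-means++ seeding by default and handles empty clusters differently, so the numbers will not match kmeans.cluster_centers_. The function name lloyd_history is illustrative only.

```python
import numpy as np
from sklearn.datasets import make_blobs


def lloyd_history(X, k, n_iter, seed=0):
    """Hypothetical helper: run n_iter Lloyd iterations and return the centroid history.

    history[0] holds the initial centroids, history[i] the centroids after i iterations.
    Assumption: initial centroids are k distinct data points picked at random.
    """
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    history = [centers.copy()]
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid (squared Euclidean distance)
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Update step: move each centroid to the mean of its points;
        # an empty cluster keeps its previous centroid (a simplification)
        new_centers = centers.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:
                new_centers[j] = members.mean(axis=0)
        centers = new_centers
        history.append(centers.copy())
    return history


X, y = make_blobs(n_samples=200, n_features=8, centers=6,
                  cluster_std=1.8, random_state=101)
history = lloyd_history(X, k=6, n_iter=4)
for i, c in enumerate(history[1:], start=1):
    print(f"Iteration {i} centroids (first two features):")
    print(c[:, :2])
```

Since history[i - 1] and history[i] are the centroids before and after iteration i, each of the four panels can show how far the centroids moved in that step.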

Answer 1:


You can create a 2×2 grid of axes with plt.subplots() and draw on each panel with its scatter() method, as follows:

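import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

data = make_blobs(n_samples=200, n_features=8,
                  centers=6, cluster_std=1.8, random_state=101)

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# ax.flat walks the 2x2 grid row by row: (0,0), (0,1), (1,0), (1,1)
for i, panel in enumerate(ax.flat):
    panel.set_title(f"Iteration {i + 1}")
    # Refit from scratch, stopping after i+1 iterations. n_init=1 keeps a single
    # initialisation, so every panel starts from the same initial centroids
    # (with several initialisations the best run could differ between panels).
    kmeans = KMeans(n_clusters=6, random_state=0, n_init=1, max_iter=i + 1)
    kmeans.fit(data[0])
    centroids = kmeans.cluster_centers_
    # Data points coloured by their true blob label (first two features only)
    panel.scatter(data[0][:, 0], data[0][:, 1], c=data[1], cmap='brg')
    # Current centroids as large black dots
    panel.scatter(centroids[:, 0], centroids[:, 1], s=200, c='black')

plt.show()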
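Because every panel above refits KMeans from scratch, an alternative sketch advances one iteration at a time: seed once with sklearn.cluster.kmeans_plusplus (available in scikit-learn 0.24 and later), then repeatedly fit KMeans with init set to the previous centroids, n_init=1 and max_iter=1. Each panel then shows the centroids before (hollow circles) and after (filled dots) that iteration. The random_state=0 seed is an arbitrary choice, and these starting centroids are not guaranteed to equal the ones KMeans(random_state=0) picks internally.

```python
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, kmeans_plusplus
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=200, n_features=8, centers=6,
                  cluster_std=1.8, random_state=101)

# One fixed starting point from k-means++ seeding (random_state=0 is arbitrary)
centers, _ = kmeans_plusplus(X, n_clusters=6, random_state=0)
history = [centers]

# Each fit performs exactly one Lloyd iteration, starting from the previous centroids
for _ in range(4):
    km = KMeans(n_clusters=6, init=centers, n_init=1, max_iter=1)
    km.fit(X)
    centers = km.cluster_centers_
    history.append(centers)

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
for i, ax in enumerate(axes.flat, start=1):
    prev, curr = history[i - 1], history[i]
    ax.set_title(f"Iteration {i}")
    # Data coloured by the true blob label, projected onto the first two features
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap="brg", s=15)
    # Hollow circles: centroids before this iteration; filled dots: after it
    ax.scatter(prev[:, 0], prev[:, 1], s=200, facecolors="none", edgecolors="black")
    ax.scatter(curr[:, 0], curr[:, 1], s=200, c="black")
plt.tight_layout()
plt.show()
```

Chaining one-iteration fits this way should trace the same path as a single longer run from the same starting centroids, up to floating-point rounding, as long as the longer run does not stop early on its tol convergence check.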

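This will produce a 2×2 figure (image not preserved), one panel per iteration, each showing the data points and the current centroids as large black dots.
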
Source: https://stackoverflow.com/questions/65449241/plotting-the-kmeans-cluster-centers-for-every-iteration-in-python
