代码包括数据生成、可视化。
注意:下面的代码仅供参考,实际使用时还需加上一些约束,例如为迭代次数设置上限等。
import numpy as np
from matplotlib import pyplot as plt
# - generate random data
def generate_data(n_point_per_cate, center_point_list):
    """Sample one Gaussian blob per center and stack them into one array.

    For every entry of ``center_point_list`` this draws ``n_point_per_cate``
    rows from ``np.random.randn`` and shifts them by that center. The number
    of columns follows the length of the center, so centers with more than
    two coordinates work as well.

    Args:
        n_point_per_cate: number of points to draw around each center.
        center_point_list: sequence of center coordinates; all centers must
            have the same length so the blobs can be concatenated.

    Returns:
        Array with ``n_point_per_cate * len(center_point_list)`` rows, the
        blobs appearing in the same order as the centers.

    Raises:
        ValueError: propagated from ``np.concatenate`` when
            ``center_point_list`` is empty or the centers disagree in length.
    """
    blobs = []
    for point in center_point_list:
        center = np.atleast_1d(np.asarray(point))
        # A scalar or length-1 center used to be broadcast over two columns;
        # keep that behaviour, otherwise follow the center's own length.
        n_dim = center.shape[-1] if center.shape[-1] > 1 else 2
        blobs.append(np.random.randn(n_point_per_cate, n_dim) + center)
    return np.concatenate(blobs, axis=0)
# - build the sample set: 100 points around each of three centers
centers = [[3, 4], [10, -4], [-5, 0]]
data = generate_data(100, centers)
data.shape  # notebook output: (300, 2)
# - visualize data
plt.scatter(data[:, 0], data[:, 1])
# - k-means function
def kmeans(data, K, *, lr=0.5, tol=0.1, max_iter=100, verbose=True):
    """Cluster ``data`` into ``K`` groups with a damped k-means iteration.

    Centroids start from ``np.random.randn(K, d)``. Each iteration assigns
    every point to its nearest centroid (Euclidean distance, ties go to the
    lowest index) and then moves each centroid part of the way towards the
    mean of its assigned points::

        centroid = centroid * lr + cluster_mean * (1 - lr)

    A centroid that receives no points is left where it is; taking the mean
    of an empty cluster would yield NaN and the stop test could never pass.

    Args:
        data: array-like of shape ``(n, d)`` holding ``n`` points.
        K: number of clusters, at least 1.
        lr: weight kept from the previous centroid in each update.
        tol: iteration stops once the Frobenius norm of the centroid change
            drops below this value.
        max_iter: hard cap on the number of iterations, at least 1.
        verbose: when true, print the centroids and the change per iteration
            and show a scatter plot through the module-level ``plt``.

    Returns:
        ``(centroid_list, cate_list)``: the ``(K, d)`` centroid array and an
        integer array of length ``n`` with each point's cluster index. The
        labels are the assignment made at the start of the final iteration.

    Raises:
        ValueError: if ``data`` is not a non-empty 2-D array, or if ``K`` or
            ``max_iter`` is smaller than 1.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim != 2 or data.shape[0] == 0:
        raise ValueError("data must be a non-empty 2-D array of shape (n, d)")
    if K < 1:
        raise ValueError("K must be at least 1")
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    n, d = data.shape
    cate_list = np.zeros(n, dtype=int)
    # - random centroid
    centroid_list = np.random.randn(K, d)

    for _ in range(max_iter):
        # - assign each point to its nearest centroid; argmin returns the
        #   first minimum, matching a strict "<" scan over the centroids
        distances = np.linalg.norm(
            data[:, None, :] - centroid_list[None, :, :], axis=2
        )
        cate_list = np.argmin(distances, axis=1)

        # - update centroid_list
        last_centroid_list = centroid_list.copy()
        for j in range(K):
            members = data[cate_list == j]
            if members.shape[0] == 0:
                # Empty cluster: keep the old centroid instead of writing NaN.
                continue
            new_centroid = members.mean(axis=0)
            centroid_list[j] = centroid_list[j] * lr + new_centroid * (1 - lr)

        # - check if need more update (Frobenius norm of the total movement)
        diff = np.linalg.norm(centroid_list - last_centroid_list)

        if verbose:
            print('centroid_list=', centroid_list)
            # - visualize
            plt.scatter(data[:, 0], data[:, 1], c=cate_list)
            plt.plot(centroid_list[:, 0], centroid_list[:, 1], 'r+')
            plt.show()
            print('diff=', diff)

        if diff < tol:
            break

    return centroid_list, cate_list
# Run k-means on the generated data, asking for three clusters.
kmeans(data, K=3)
来源:https://www.cnblogs.com/journeyonmyway/p/12596287.html