[笔记] 使用numpy手写k-means算法

时光怂恿深爱的人放手 提交于 2020-03-30 02:49:08

代码包括数据生成、可视化。

注意:下面代码仅供参考,实际使用还需加上一些约束,如迭代次数需要有个最大值,等等。

import numpy as np
from matplotlib import pyplot as plt
# - generate random data

def generate_data(n_point_per_cate, center_point_list):
    """
    n_point_per_cate:
        point number per category
    center_point_list:
        center point list
    """
    
    points_list = []
    for point in center_point_list:
        points_list.append(np.random.randn(n_point_per_cate, 2) + np.array(point))
    return np.concatenate(points_list, axis=0)
# - generate random data

data = generate_data(100, [[3,4], [10,-4], [-5,0]])
data.shape
(300, 2)
# - visulize data

plt.scatter(data[:,0], data[:,1])

# - k-means function

def kmeans(data, K):
    """
    data: input data
    K: category number
    """
    
    n,d = data.shape
    cate_list = np.zeros(n)
    
    # - random centroid
    centroid_list = np.random.randn(K,d)
    
    is_ok = False
    lr = 0.5
    while not is_ok:
        for j in range(n):
            nearest_centeroid_index = None
            nearest_centeroid_distance = float('inf')
            
            for k in range(K):
                dist = np.linalg.norm(centroid_list[k] - data[j])
                if dist < nearest_centeroid_distance:
                    nearest_centeroid_distance = dist
                    nearest_centeroid_index = k
            cate_list[j] = nearest_centeroid_index
        
        # - update centroid_list
        last_centroid_list = centroid_list.copy()
        for j in range(K):
            new_centroid = np.mean(data[cate_list==j], axis=0)
            centroid_list[j] = centroid_list[j]*lr + new_centroid*(1-lr) 
        print('centroid_list=', centroid_list)
            
        # - visualize
        plt.scatter(data[:,0], data[:,1], c=cate_list)
        plt.plot(centroid_list[:,0], centroid_list[:,1], 'r+')
        plt.show()
        
        # - check if need more update
        diff = np.linalg.norm(np.linalg.norm(centroid_list-last_centroid_list, axis=0))
        print('diff=', diff)
        if diff < 0.1:
            is_ok = True
kmeans(data, K=3)

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!