KMeans均值算法

余生颓废 提交于 2019-11-29 11:24:43

K均值聚类算法

算法思想

k-means聚类算法

  1. 随机选取K个对象作为初始的聚类中心;
  2. 计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心;
  3. 根据分配结果更新聚类中心。均值,顾名思义,对类中的所有样本点求均值,即为新的中心点;
  4. 循环2,3两步,终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

代码实现

样本点存储样例(数据之间以\t分隔)
测试文件
因为这里测试样本点很少,所以设置当每个点所属类不再变化时即结束循环,此时得到最终聚类结果,如果样本数较多可设置迭代轮数来控制
才开始学python,还不熟练,参照着一些代码,在上面进行了修改,最后的实现如下:

# yyf date:19/09/05
import numpy as np
import matplotlib.pyplot as plt

# 读取数据
def readfile(filename):
    data = np.loadtxt(filename, delimiter='\t')
    return data

# 选取k个初始中心
def initCenterp(data,k):
    m = np.shape(data)[0]
    centerpoints = np.zeros((k, 2))
    list = np.random.choice(range(m), k, replace=False) # replace=False避免选取中心点重复的情况

    n = 0
    for i in range(m):
        if i in list:
            centerpoints[n, ] = data[i, :]
            n+=1
    return centerpoints

# 计算样本与中心之间的距离
def countDist(a,b):
    dist = np.sqrt(np.sum((a-b)**2))
    return dist

def KMeans(data,k):
    flag = True
    m = np.shape(data)[0]
    # n = np.shape(data)[1] 此处默认为2
    # 初始化clusterSet
    clusterSet = np.zeros((m, 2))
    #初始化中心点
    centerpoints = initCenterp(data, k)
    '''存储中心点数据'''

    # 重复步骤
    while flag:
        flag = False

        # 对每个样本计算与各个中心之间的距离,选取距离最短的加入该群组
        for i in range(m):
            mindist = 100000.0
            minIndex = 0

            for j in range(k):
                dist = countDist(data[i, :], centerpoints[j, :])
                if dist < mindist:
                    minIndex = j
                    mindist = dist
            #若所有样本点均未更改分组,则结束循环
            if clusterSet[i, 0] != minIndex:
                flag = True
                clusterSet[i, 0] = minIndex
                clusterSet[i, 1] = mindist

        # 根据群中所有样本更新中心
        for i in range(k):
            pointsInCluster = data[np.nonzero(clusterSet[:, 0] == i)]  # 获取i簇所有的点
            centerpoints[i, :] = np.mean(pointsInCluster, axis=0)
            '''axis=0表示对各列求均值'''
    # 直到所有样本所属群组均不再变化停止
    return centerpoints, clusterSet

# 显示聚类的结果
def showCluster(data,clusterSet):
    m = np.shape(data)[0]
    mark = ['or', 'ob', '^g', '*k', '*y']
    for i in range(m):
        markIndex = int(clusterSet[i, 0])  # 不转换成int会报错TypeError: list indices must be integers or slices, not numpy.float64
        plt.plot(data[i, 0], data[i, 1], mark[markIndex])
    plt.show()

data = readfile('test.txt')
'''data存储样本点'''
k = 3 #类的数量
centerpoints, clusterSet = KMeans(data, k)
'''clusterSet存储聚类结果,第一行存所属类号,第二行存储到中心的距离'''
showCluster(data, clusterSet)

运行结果

聚类结果

注意事项

  • 在选取k个出事中心点时,注意防止出现重复的情况
  • k的选取很重要,通过观察选取合适值,值过大则失去了聚类的意义,过小则聚类效果不明显
  • 注意测试文件中的分隔符
    在使用pycharm编辑txt文件时,敲制表符来分割样本坐标,保存后txt文件中存储的却是空格符,所以如果数据读取部分报错,最好查看一下是否文件样式的问题
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!