【分类算法】K-NN

↘锁芯ラ 提交于 2020-01-21 01:53:37

K-NN的定义

今天,我们来分享一个“街知巷闻”入门级别的分类算法 —— K-NN。相信提到这个词的时候小伙伴们都有:噢~是它。这里题外话一下,为什么说K-NN是“街知巷闻”入门级别呢?其实他有如下特点:

  • 算法思想非常简单
  • 对数学以来少,非常适合初学者
  • 虽然它的体量小,却可以走完监督学习的整个流程
  • 可以通过它来入门监督学习,然后扩展其他算法

首先说一下,K-NN属于分类算法,那么分类算法是属于监督学习。所以K-NN是监督学习算法,它需要带label的数据。关于分类和聚类,我在过去的文章有讲过,大家可以回看一下。什么是分类算法? / 什么是聚类算法?
那么到底什么是K-NN呢?看过过去文章的都知道,我比较喜欢通过拆解算法名字来初步理解算法是做什么的。K-NN,全称叫做:K-Nearest Neighboors,中文翻译:K-最近的 邻居们。假设我们的数据如下图,有很多个点。首先解释Nearest的,看到远近就应该想到距离。那么常用的距离计算方式我们有欧式距离、马氏距离、名氏距离甚至绝对值距离也是OK的。而Neighboors就是邻居们的意思。但在数据里面,我们一般用点来表示。所以也可以理解为是附近的点。那么K是一个正数,它代表多少个点。
所以整个K-NN的理解就是:K个最近的邻居点。就像图中的绿色点那样,现在已经有蓝色与红色两种分类

  • 如果K=2,那么绿色点应该属于红色集合
  • 如果K=5,那么绿色点应该属于蓝色集合

在这里插入图片描述
上面提到了K-NN的基本思想。其实就是运用适合的距离算法,找到最近的K个点,然后对其进行分类。那么接下来将讲一下K-NN模型的基本流程。整体来讲其实和传统监督学习差不多,也是fit+predict。

  1. 输入数据:数据必须带label的
  2. 整理数据:实际项目中总有乱七八糟的数据,但我们学习一般使用干净的数据,所以这步在学习demo中忽略
  3. 训练模型:使用整理好的数据训练模型
  4. 传入数据:传入一些新的数据,又或者新的点
  5. 计算距离:计算新数据到个点的距离,并找出最近的K个
  6. 分类:观察K个点大多数属于哪一类,那么新数据就分到该类

K-NN的注意事项

其实,讲到这里K-NN的基本思想已经讲完了。下面将分享一些我在调试K-NN的经验。
首先是距离公式的使用。在K-NN里面,距离公式其实有很多种。其中最常用的就是欧式距离。当然我们也可以运用自己的数学知识去构建一个适合我们自己的距离公式。每一个距离公式都有自己的特点,所以在距离公式上我们需要有初步的了解才能更好的完成项目。一般距离公式会用p定义形容的。
综合统称为Lp距离公式:
Lp(xi,xj)=(l=1nxi(l)xj(l))1pL_p(x_i,x_j) = (\sum_{l=1}^n|x_i^{(l)} - x_j^{(l)}|)^{\frac{1}{p}}
比如当p=1时,就是曼哈顿距离:
L1(xi,xj)=l=1nxi(l)xj(l)L_1(x_i,x_j) = \sum_{l=1}^n|x_i^{(l)} - x_j^{(l)}|
当p=2时,就是欧式距离:
Lp(xi,xj)=(l=1nxi(l)xj(l)2)12L_p(x_i,x_j) = (\sum_{l=1}^n|x_i^{(l)} - x_j^{(l)}|^2)^{\frac{1}{2}}

接下来就是K的问题了。在K-NN算法里面,K的选择往往对结果影响非常大。就像上面的图一样,K=2和K=5完全是连各种不一样的结果。一般K的错误选择会有如下规律:

  • K过小,这时候会造成过拟合
  • K过大,这时候会造成欠拟合

所以,一般情况下我比较建议选择合适的K值。而经验告诉我,一般可以先定一个比较低的值,然后再慢慢提上去。当然还有一种思想叫:网格搜索。它主要是帮助我们选出最优超参数。这个有机会我会跟他家分享。

K-NN 与K-Means

很多小伙伴其实都会把K-NN 与K-Means搞混。其实按照我“自创”的名字拆解法,其实是非常容易的理解的。下面我列一下他们的区别:

K-NN K-Means
监督学习 非监督学习
分类算法 聚类算法
需要带标签的数据 不需要带标签的数据
工作原理:寻找最近的邻居点,并分到类别最多的簇中 工作原理:把最近几个邻居点聚为一个簇
K的定义:以新数据为原点,寻找K个最近邻居点 K的定义:规定一个簇中只允许有K个点

代码实现

由于算法过于简单,我就不自己去创造数据然后写一个代码了。这里直接饮用sklearn的官方案例给大家好了。然后我会代码中加以中文注释,让大家更容易看明白。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets
#初始化K=15
n_neighbors = 15

# 加载iris数据
iris = datasets.load_iris()

# 切分数据及
X = iris.data[:, :2]
y = iris.target
#定义测试集和训练集的比例
h = .02 

# 创建颜色地图,主要用于展示
cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue'])
cmap_bold = ListedColormap(['darkorange', 'c', 'darkblue'])

for weights in ['uniform', 'distance']:
    # 创建K-NN模型并训练模型
    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
    clf.fit(X, y)

    # 拟定图片边界,也是为了展示
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    # 输入新数据并输出结果                     
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # 绘制图片
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

    
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
                edgecolor='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("3-Class classification (k = %i, weights = '%s')"
              % (n_neighbors, weights))
# 展示结果
plt.show()

总结

到目前为止,K-NN就已经讲完了。大家有木有感觉非常简单。白话点说就是:寻找该数据点的K个最近邻居点,然后判断这些邻居点最多属于那一个类,就把该数据分为哪一类。其实算法很简单,但使用的时候还是需要大家多加调试的。
下一篇文章我将分享SVM(支持向量机)。这可是一个大工程啊!!反正我反复看了很多资料,很多东西才明白是干啥的,所以我也准备像决策树一样开几章来分享。
点我阅读更多算法分享

参考文献
scikit-learn 官方案例

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!