绑定完请刷新页面
取消
刷新

分享好友

×
取消 复制
机器学习之 k 近邻
2017-04-28 13:11:58

核心思想

KNN算法假设给定的训练集中的实例都已经分好类了,对于新的实例,根据离它近的k个训练实例的类别来预测它的类别。即这k个实例大多数属于某个类别则该实例就属于某个类别。比如k为5,离新实例a近的5个样本的情况为,3个样本属于A类,1个样本属于B类,一个样本属于C类,那么新实例a属于A类。

常用距离

  • 欧氏距离
  • 曼哈顿距离
  • 切比雪夫距离

k值的影响

k值的选取可能会影响到分类结果,如下图,k=3和k=5时的分类结果是不同的。

  1. k值小可能会导致预测结果对近邻的样本点敏感,如果刚好是噪音则会导致预测结果出错,容易发生过拟合。近似误差小,估计误差大。
  2. k值大可能会导致较远的样本也影响预测,也可能会导致预测错误。近似误差大,估计误差小。
  3. k值一般先取较小的数,再用交叉验证方法选择优k值。

算法实现

两种方式:线性扫描和kd树。

线性扫描

KNN的简单朴素的方法即直接线性扫描,大致步骤如下:

  1. 计算待预测数据与各训练样本之间的距离;
  2. 按照距离递增排序;
  3. 选择距离小的k个点;
  4. 计算这k个点类别的频率,高的即为待预测数据的类别。

代码实现

from numpy import *
import pylab as pl

dataSet = array([[11, 12], [12, 12], [11, 11], [11, 16], [12, 16], [17, 11], [17, 12]])
classes = ['A', 'A', 'A', 'B', 'B', 'C', 'C']
k = 3
dot = [13, 13]
type
r = 0
dataSize = dataSet.shape[0]
diff = tile(dot, (dataSize, 1)) - dataSet
sqdiff = diff ** 2
squareDist = sum(sqdiff, axis=1)
dist = squareDist ** 0.5
sortedDistIndex = argsort(dist)
classCount = {}
for i in range(k):
    label = classes[sortedDistIndex[i]]
    classCount[label] = classCount.get(label,0) + 1
    if dist[i] > r:
        r = dist[i]
maxCount = 0
for key, value in classCount.items():
    if value > maxCount:
        maxCount = value
        type = key
pl.plot(dot[0], dot[1], 'ok')
circle = [i*pi/180 for i in range(0,360)]
x = cos(circle)*r+dot[0]
y = sin(circle)*r+dot[1]
pl.plot(x, y, 'r')
pl.plot([point[0] for point in dataSet[0:3]], [point[1] for point in dataSet[0:3]], 'og')
pl.plot([point[0] for point in dataSet[3:5]], [point[1] for point in dataSet[3:5]], 'or')
pl.plot([point[0] for point in dataSet[5:7]], [point[1] for point in dataSet[5:7]], 'oy')
pl.show()复制代码

kd树

线性扫描非常耗时,为了减少计算距离的次数提高效率,使用kd树方法,它能快速地找到查询点近邻。

可以通过将搜索空间进行层次划分建立索引树以加快检索速度。

对于二维空间,它终要划分的空间类似如下,

决定在哪个维度上进行分割是由所有数据在各个维度的方差决定的,方差越大说明该维度上的数据波动越大,更应该再该维度上对点进行划分。例如x维度方差较大,所以以x维度方向划分。

分割时一般取分割维度上的所有值的中值的点,比如下图,次计算方差较大的维度为x维度,中值点为A,以x=Ax分割,接着对分割后的点分别又继续分割,计算方差并寻找中值,以y=Cy、y=By分割,以此类推。

kd树查找

从根节点开始查找,直到叶子节点,整个过程将短距离d和相应的点记录下来。

回溯,通过计算待预测的点到分割平面的距离l与短距离d比较,看是否要进入节点的相邻空间去查找。回溯的过程是为了确认是否有必要进入相邻子空间去搜索,当待预测点到近点的距离d大于待预测点到分割面的距离l时,则需要到相邻子空间查找,否则则没必要,直接往上一层回溯。

欢迎关注:

分享好友

分享这个小栈给你的朋友们,一起进步吧。

远洋号
创建时间:2020-05-19 15:46:16
《图解数据结构与算法》《Tomcat内核设计剖析》书籍作者,公众号:《远洋号》,笔名:seaboat,擅长工程算法、人工智能算法、自然语言处理、架构、分布式、高并发、大数据、搜索引擎等方面的技术,大多数编程语言都会使用但更擅长Java、Python、C++。平时喜欢看书、写作、运动,擅长的项目有篮球、跑步、游泳、健身、羽毛球。崇尚开源,崇尚技术自由,更崇尚思想自由。
展开
订阅须知

• 所有用户可根据关注领域订阅专区或所有专区

• 付费订阅:虚拟交易,一经交易不退款;若特殊情况,可3日内客服咨询

• 专区发布评论属默认订阅所评论专区(除付费小栈外)

栈主、嘉宾

查看更多
  • seaboat
    栈主

小栈成员

查看更多
  • 小雨滴
  • Tester9456
  • 栈栈
  • dkl187788
戳我,来吐槽~