K-近邻算法的一个简单例子

k近邻算法思想很简单,一个类的数据之间距离较近,单纯比较距离就好,下面注释比较清楚,常犯的错误也已经标记出来了
import numpy as np
import operator
from matplotlib import pyplot as plt
def classify0(inX,dataSet,labels,k):
    dataSetSize=dataSet.shape[0]
    diffMat=np.tile(inX,(dataSetSize,1))-dataSet               #计算输入数据点与训练集数据点的差值
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1)                          #计算与各个训练集数据点的距离
    distance=sqDistances**0.5
    sortedDistIndicies=np.argsort(distance)                    #返回distance按照从小到大排序的序列的索引
    classCount={}                                              #建立一个空的字典
    for i in range(k):
        voteIlabel=labels[sortedDistIndicies[0]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1   #classCount字典存储每个标签的数量,Get函数(key,default)若key不存在则以default为默认值创建该key
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
# DBset=np.matrix([[1,1]
#                  [1,1.1]
#                  [2,2]
#                  [2,2.1]])
DBset=np.array([[1,1],[1,1.1],[2,2],[2,2.1]])
LB=('甲','甲','乙','乙')                                         #训练集的Rt(标记),注意数量要和训练集大小一致
x=(1.8,2.3)          #元组类型
kp=1
mp=classify0(x,DBset,LB,kp)
print(mp)
###############################绘图################################
plt.figure(1)
plt.xlabel('x')            #设置x轴标签
plt.ylabel('y')
plt.xlim([0,5])         #设置x轴显示范围
plt.ylim([0,5])
ax=plt.subplot('111')
ax.set_title('KNN')
plt.scatter(DBset[:2,0],DBset[:2,1],c='g')             #'甲'类    DBset[:3,0]需要说一下啊 (:]左开右闭(划重点)
plt.scatter(DBset[2:,0],DBset[2:,1],c='r')
plt.scatter(x[0],x[1],c='r')                           #这里不能用[:,1]这种形式,因为是一维的,没有切片的概念
plt.show()

输出结果:

猜你喜欢

转载自blog.csdn.net/weixin_37922873/article/details/81292580