K最近邻算法 【更新ing】

K最近邻算法 KNN

基本原理
离哪个类近,就属于该类
 
【例如:与下方新元素距离最近的三个点中,2个深色,所以新元素分类为深色】

K的含义就是最近邻的个数。在sklearn中,KNN的K值是通过n_neighbors参数来调节的
 

不适用:对数据集认真的预处理、对规模超大的数据集拟合的时间较长、对高维数据集拟合欠佳、对稀疏数据集无能为力
 
KNN用法
1.分类任务中的应用
from sklearn.datasets import make_blobs   #导入数据集生成器
from sklearn.neighbors import KNeighborsClassifier #导入KNN分类器
import matplotlib.pyplot as plt #导入画图工具
from sklearn.model_selection import train_test_split #导入数据集拆分工具
data = make_blobs(n_samples=200, centers = 2,random_state = 8) #生成样本数为200,分类为2的数据
X ,y = data #将生成的数据可视化
plt.scatter(X[:,0],X[:,1],c=y,cmap=plt.cm.spring,edgecolor='k')
plt.show()
 
 

 接下来用KNN拟合这些数据:

import numpy as np
clf = KNeighborsClassifier()
clf.fit(X,y)
 #以下代码用于画图
x_min,x_max = X[:,0].min() -1,X[:,0].max() + 1
y_min,y_max = X[:,1].min() -1,X[:,1].max() + 1
xx,yy = np.meshgrid(np.arange(x_min,x_max, .02),np.arange(y_min,y_max, .02))
Z = clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx,yy,Z,cmap=plt.cm.Pastel1)
plt.scatter(X[:, 0],X[:, 1],c = y,cmap = plt.cm.spring, edgecolor = 'k')
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.title("Classifier:KNN")
plt.show()
 
如果有新的数据输入的话,模型就会自动将新的数据分到对应的类中
 
举例:
一个数据点,两个特征值是6.75,4.82,试验如下:
在上述代码的plt.show()之前加入:
plt.scatter(6.75,4.82,marker='*',c='red',s=200)
 

可见,分到了浅灰色一类

验证代码如下:

 

2.处理多元分类任务

 生成数据集:
 
#生成样本数500,分类数5的数据集
data2 = make_blobs(n_samples=500, centers = 5,random_state = 8)
X2,y2 = data2
#用散点图将数据可视化
plt.scatter(X2[:,0],X2[:,1],c = y2,cmap = plt.cm.spring,edgecolor='k')
plt.show()
 

用KNN建立模型拟合数据:

clf = KNeighborsClassifier()
clf.fit(X2,y2)
x_min,x_max = X2[:,0].min() -1,X2[:,0].max() + 1
y_min,y_max = X2[:,1].min() -1,X2[:,1].max() + 1
xx,yy = np.meshgrid(np.arange(x_min,x_max, .02),np.arange(y_min,y_max, .02))
Z = clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx,yy,Z,cmap=plt.cm.Pastel1)
plt.scatter(X2[:, 0],X2[:, 1],c = y2,cmap = plt.cm.spring, edgecolor = 'k')
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.title("Classifier:KNN")
plt.show()

但仍然有小部分数据进入错误的分类

3.用于回归分析

  pass

 
 

猜你喜欢

转载自www.cnblogs.com/expedition/p/10707123.html