python机器学习手写算法系列——KNN分类

现实当中,富人的邻居可能也是富人,穷人的邻居可能也是穷人。如果你的邻居都是富人,那么,你很可能也是富人。基于此,我们有了KNN算法。KNN的全名是K-Nearest Neighbors,即K个最近的邻居。他通过距离被预测点最近的K个邻居来预测被预测点。

如下图所示,绿色的圆形是被预测点。它周围有红色三角和蓝色正方形。如果我们取K为3,那么,它的三个邻居是两个红色三角和一个蓝色正方形。因为它的邻居里面,最多的就是红色三角,所以我们预测它也是红色三角。

在这里插入图片描述
(图一)

首先,我们载入数据。这里我们以iris数据集为例。

import math
from collections import Counter 
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Create color maps
cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue'])
cmap_bold = ListedColormap(['darkorange', 'c', 'darkblue'])
from sklearn import neighbors, datasets

n_neighbors = 15

# import some data to play with
iris = datasets.load_iris()

# we only take the first two features. We could avoid this ugly
# slicing by using a two-dim dataset
X = iris.data[:, :2]
y = iris.target

接着,我们写一个KNNClassifier类。这里,最关键的参数当然是k,或者叫n_neighbors。接着,我们写一个fit方法。这个方法除了“记住”数据外,不做任何事情。因为这个特点,我们叫KNN为lazy algorithm。在预测方法predict_one里,我们首先计算被预测点到数据集中每个点的距离,我们得到了 d i s t a n c e _ a r r a y distance\_array distance_array。我们用numpy里面的argsort函数,把这些距离从小到大排序,并得到他们的坐标。我们取前k个坐标,得到相应的标签 n e i g h b o u r _ l a b e l s neighbour\_labels neighbour_labels。我们用python里面的Counter得到每个标签的出现次数,并选出出现频率最高的标签 m o s t _ f r e q u e n t most\_frequent most_frequent。至此,一个点的预测完成了。我们map一下predict_one,得到predict方法。

class KNNClassifier():
    X=None
    y=None
    n_neighbors=0
    
    def __init__(self, n_neighbors=15):
        self.n_neighbors=n_neighbors
        
    def fit(self, X, y):
        self.X=np.array(X)
        self.y=np.array(y)
    
    def predict_one(self, p):
        distance_array=np.array(list(map(lambda o: math.dist(p, o), self.X)))
        argsorted=np.argsort(distance_array)
        neighbours = argsorted[:self.n_neighbors]
        neighbour_labels = y[neighbours]
        occurence_count = Counter(neighbour_labels)
        most_frequent = occurence_count.most_common(1)[0][0]
        return most_frequent
        
    def predict(self, X):
        y_hat = np.array(list(map(self.predict_one, X)))
        return y_hat

文字总是苍白的,我们做个图。

knn = KNNClassifier()
knn.fit(X, y)

x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
h = .02  # step size in the mesh
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
            edgecolor='k', s=20)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("My KNN (k = %i)"
          % (n_neighbors))

plt.show()

在这里插入图片描述
(图二)

我们的算法对么?比较一下Scikit-learn吧

# we create an instance of Neighbours Classifier and fit the data.
clf = neighbors.KNeighborsClassifier(n_neighbors)
clf.fit(X, y)

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
h = .02  # step size in the mesh
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
            edgecolor='k', s=20)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("Scikit-learn KNN (k = %i)"
          % (n_neighbors))

plt.show()

在这里插入图片描述
(图三)

目测二者几乎一样。

其他

K的选择
从图一上看,如果我们把k从3改成5,结果就变了。所以,k的选择很重要。k必须足够大,才具有参考意义。k又必须足够小,否则就退化为平均值了。

距离
我们这里用的欧式距离。也可以用曼哈顿距离,或者自己写一个距离函数。距离也可以用来做权重,距离越近的,权重越大。

数据去冗余
有些数据,有他没他,对结果是没有影响,或者影响很小的。比如,k=5,有100个点据集在一起,那么,只有边缘点才有意义。这时,可以去掉一些点,使得knn预测的时候速度大幅提高。这时KNN就变成了CNN(Condensed nearest neighbors)

无监督
除了分类,回归,KNN也可以用来做异常检测。到距离最远的邻居k的距离为k-distance,直接比较这个值,越大越有可能是异常点。

源代码

https://github.com/EricWebsmith/machine_learning_from_scrach

参考文献

https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html

https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html

https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm

猜你喜欢

转载自blog.csdn.net/juwikuang/article/details/108565458