之前介绍的几种算法,都是监督学习算法,我们需要对数据进行预处理,也就是在使用数据前,需要对数据集的样本数据进行标记。今天我们看一种无监督学习算法——k-means。
k-means算法用来实现聚类,什么是聚类?打一个比方,我们在袋子中放着各种水果,我们事先并不知道有哪几种,通过一些算法,我们可以借助于特性将水果聚集为几个类别,然后我们再去看这几个类别分别代表了什么水果。
k-means算法的思想非常简单,假设有m条数据,n个特性:
随机选取k个点作为起始中心(k行n列的矩阵,每个特征都有自己的中心);
遍历数据集中的每一条数据,计算它与每个中心的距离;
将数据分配到距离最近的中心所在的簇;
使用每个簇中的数据的均值作为新的簇中心
如果簇的组成点发生变化,则跳转执行第2步;否则,结束聚类。
影响k-means的因素主要是k的选取,比如,数据可以分为三类,但是我们的k选择为2,那么就会有一个类被划分进了一个错误的类。所以,我们需要多尝试一些k值。另外,初始k个中心的选择,也会影响算法的执行。下面看看《机器学习实战》中的算法实现。
首先是选取初始随机中心的函数,需要注意的是我们需要对每个中心的n个特性分别计算中心值:
def rand_cent(data_set, k): n = np.shape(data_set)[1] centroids = np.mat(np.zeros((k, n))) for j in range(n): min_j = np.min(data_set[:, j]) range_j = float(np.max(data_set[:, j]) - min_j) centroids[:, j] = min_j + range_j * np.random.rand(k, 1) return centroids
接下来,我们采用欧式距离计算中心的距离:
def dist_eclud(vec_a, vec_b): return np.sqrt(np.sum(np.power(vec_a - vec_b, 2)))
下面是k-means算法的核心:
def kmeans(data_set, k, dist_meas=dist_eclud, create_cent=rand_cent): m = np.shape(data_set)[0] cluster_assment = np.mat(np.zeros((m, 2))) centroids = create_cent(data_set, k) cluster_changed = True while cluster_changed: cluster_changed = False for i in range(m): min_dist = np.inf min_index = -1 for j in range(k): dist_ji = dist_meas(centroids[j, :], data_set[i, :]) if dist_ji < min_dist: min_dist = dist_ji min_index = j if cluster_assment[i, 0] != min_index: cluster_changed = True cluster_assment[i, :] = min_index, min_dist ** 2 #print(centroids) for cent in range(k): pts_in_cluster = data_set[np.nonzero(cluster_assment[:, 0].A == cent)[0]] centroids[cent, :] = np.mean(pts_in_cluster, axis=0) return centroids, cluster_assment
centroids返回中心的信息,cluster_assment返回了簇的信息,m行2列,m行对应m条样本数据,第一列保存了该行数据所属簇的index,第二列保存了该行到中心的距离,也就是偏离中心的误差。while中的内容就是上面2-5步骤做的事情。
下面看看使用,我伪造了一组数据,这些数据实际上可以被分到4类,边界也比较清晰,主要目的是为了看看算法的作用:
if __name__ == '__main__': data_set = np.mat([ [0.5, 0.3], [0.2, 0.7], [0.8, 0.9], [9.5, 0.3], [9.2, 0.7], [9.8, 0.9], [0.5, 9.3], [0.2, 9.7], [0.8, 9.9], [9.5, 9.3], [9.2, 9.7], [9.8, 9.9], ]) centroids, cluster_assment = kmeans(data_set, 4) import matplotlib.pyplot as plt print(centroids.A[:, 0]) plt.scatter(centroids.A[:, 0], centroids.A[:, 1], marker='x') plt.scatter(data_set.A[:, 0], data_set.A[:, 1]) plt.show()
运行结果如下:
运行结果和我们的预期相符合。