K-近邻算法(kNN)
工作原理:
存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
优点:精度高、对异常值不敏感、无数据输入假定。
缺点:计算复杂度高、空间复杂度高。
适用数据范围:数值型和标称型。
k-近邻算法的一般流程:
(1) 收集数据:可以使用任何方法。
(2) 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
(3) 分析数据:可以使用任何方法。
(4) 训练算法:此步骤不适用于k-近邻算法。
(5) 测试算法:计算错误率。
(6) 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输
入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。
下面代码使用欧氏距离公式,计算两个向量点xA和xB之间的距离:
例如,点(0, 0)与(1, 2)之间的距离计算为:
如果数据集存在4个特征值,则点(1, 0, 0, 1)与(7, 6, 9, 4)之间的距离计算为:
计算完所有点之间的距离后,可以对数据按照从小到大的次序排序。然后,确定前k个距离最小元素所在的主要分类,输入k总是正整数;最后,将classCount字典分解为元组列表,然后使用程序第二行导入运算符模块的itemgetter方法,按照第二个元素的次序对元组进行排序。此处的排序为逆序,即按照从最大到最小次序排序,最后返回发生频率最高的元素标签。
参考代码:
数据集下载:https://download.csdn.net/download/bashendixie5/13678689
00000000000010000000000000000000
00000000001111000000000000000000
00000000001111111000000000000000
00000000011111111110000000000000
00000000011111111111000000000000
00000000111111111111100000000000
00000000111111111111110000000000
00000000111111110111110000000000
00000000111111000011110000000000
00000001111110000001111000000000
00000001111110000000011110000000
00000001111110000000001110000000
00000001111110000000001110000000
00000001111110000000001111000000
00000001111100000000001111000000
00000001111100000000000111000000
00000011111000000000000111000000
00000011111000000000000111000000
00000011110000000000001111000000
00000011110000000000001111000000
00000011110000000000001111000000
00000011110000000000001110000000
00000001110000000000011110000000
00000000111000000000011111000000
00000000111000000001111111000000
00000000111100000011111100000000
00000000111100000111111000000000
00000000111111111111111000000000
00000000011111111111110000000000
00000000001111111110000000000000
00000000001111111110000000000000
00000000000011110000000000000000
Python代码:
# K近邻算法
from numpy import *
import operator
from os import listdir
# 分类算法
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
# 计算距离
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
# 选择距离最小的k个点
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
# 文件转矩阵
def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines()) # get the number of lines in the file
returnMat = zeros((numberOfLines, 3)) # prepare matrix to return
classLabelVector = [] # prepare labels return
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
# 归一化数值
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1)) # element wise divide
return normDataSet, ranges, minVals
# 约会分类测试
def datingClassTest():
hoRatio = 0.50 # hold out 10%
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') # load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print
"the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
if (classifierResult != datingLabels[i]): errorCount += 1.0
print
"the total error rate is: %f" % (errorCount / float(numTestVecs))
print
errorCount
# 图片转向量
def img2vector(filename):
# 创建一个1024长度的向量
returnVect = zeros((1, 1024))
# 打开文件
fr = open(filename)
# 循环行列
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
# 手写分类测试
def handwritingClassTest():
hwLabels = []
# 加载训练数据
trainingFileList = listdir('D:/Project/DeepLearn/2trainlog/9/trainingDigits')
# 获取训练集数量
m = len(trainingFileList)
# 创建矩阵
trainingMat = zeros((m, 1024))
# 循环训练集数量
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] # take off .txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i, :] = img2vector('D:/Project/DeepLearn/2trainlog/9/trainingDigits/%s' % fileNameStr)
testFileList = listdir('D:/Project/DeepLearn/2trainlog/9/testDigits') # iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] # take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('D:/Project/DeepLearn/2trainlog/9/testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
handwritingClassTest()
查看运行结果,就test数据集来说,错误率仅为0.010571
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 7, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the total number of errors is: 10
the total error rate is: 0.010571