版权声明:本文为博主原创文章,未经作者允许请勿转载。 https://blog.csdn.net/heiheiya https://blog.csdn.net/heiheiya/article/details/82843852
这里的手写数字以0,1的形式存储在文本文件中,大小是32x32.目录trainingDigits有1934个样本。0-9每个数字大约有200个样本,命名规则如下:
下划线前的数字代表是样本0-9的数字,下划线后的数字代表是当前数字的第多少个样本。
目录testDigits下有946个样本。这个数据集可以在网上下载。
首先将32x32的二进制图像矩阵转换为1x1024的向量。
def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
return returnVect
然后实现kNN的分类器,原理请参考链接: 机器学习--k-近邻算法(kNN)学习笔记。
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
#下面的四行代码计算距离
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
#对距离进行排序
sortedDistIndicies = distances.argsort()
classCount = {}
#确定前k个较小距离的类别
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
#获得最大频率的类别
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
现在,可以检测一下kNN分类器的效果了。
def handwritingClassTest():
hwLabels = []
#获取目录内容
trainingFileList = listdir('digits/trainingDigits')
m = len(trainingFileList)
traningMat = zeros((m, 1024))
for i in range(m):
#从文件名解析分类数字
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
traningMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)
testFileList = listdir('digits/testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, traningMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d\n" % (classifierResult, classNumStr))
if(classifierResult != classNumStr):
errorCount += 1.0
print("the total number of errors is: %d\n" % errorCount)
print("the total error rate is: %f" % (errorCount/float(mTest)))
将上面的几段代码保存为kNN.py,然后在终端执行如下操作:
最后的输出如下: