机器学习--k-近邻算法(kNN)实现手写数字识别

版权声明:本文为博主原创文章,未经作者允许请勿转载。 https://blog.csdn.net/heiheiya https://blog.csdn.net/heiheiya/article/details/82843852

这里的手写数字以0,1的形式存储在文本文件中,大小是32x32.目录trainingDigits有1934个样本。0-9每个数字大约有200个样本,命名规则如下:

下划线前的数字代表是样本0-9的数字,下划线后的数字代表是当前数字的第多少个样本。

目录testDigits下有946个样本。这个数据集可以在网上下载。

首先将32x32的二进制图像矩阵转换为1x1024的向量。

def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect

然后实现kNN的分类器,原理请参考链接: 机器学习--k-近邻算法(kNN)学习笔记

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    #下面的四行代码计算距离
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    #对距离进行排序
    sortedDistIndicies = distances.argsort()
    classCount = {}
    #确定前k个较小距离的类别
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    #获得最大频率的类别
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

现在,可以检测一下kNN分类器的效果了。

def handwritingClassTest():
    hwLabels = []
    #获取目录内容
    trainingFileList = listdir('digits/trainingDigits')
    m = len(trainingFileList)
    traningMat = zeros((m, 1024))
    for i in range(m):
        #从文件名解析分类数字
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        traningMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)
    testFileList = listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, traningMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d\n" % (classifierResult, classNumStr))
        if(classifierResult != classNumStr):
            errorCount += 1.0
    print("the total number of errors is: %d\n" % errorCount)
    print("the total error rate is: %f" % (errorCount/float(mTest)))

将上面的几段代码保存为kNN.py,然后在终端执行如下操作:

最后的输出如下:

猜你喜欢

转载自blog.csdn.net/heiheiya/article/details/82843852