KNN算法原理与实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/potato012345/article/details/52796800

KNN

算法概述

      KNN (K-Nearest Neighbors)算法不需要训练过程,直接利用样本之间的距离进行分类。算法的基本过程是:给定一个测试样例,分别计算它与训练集中所有样本之间的距离,选取距离最小的K个训练样本对测试样例的类别进行投票,最后将得票最多(在K个样本中最多的类别)作为测试样例的预测类别。

      需要注意的是,计算样本与样例距离时可以采用多种距离指标,如欧氏距离、曼哈顿距离等。K 的值也可以根据实际应用背景来灵活确定,一般地,我们选取20以内的值。

      KNN 算法便于理解与解释,且预测准确率高。但当训练集和测试集的样本数较大时,会产生巨大的时间和空间开销。

算法描述

输入:测试样例向量 x,训练样本集D,训练集类别向量L,近邻数目K

输出:预测类别 l

for 训练样本 in D do:
    计算训练样本与 x 之间的距离,存储在向量dis中
将dis从小到大排序,得到sorted_dis
取sorted_dis的前K项,统计每一项对应的训练样本的类别
统计K个类别中,数量最多的那个,为l赋值
return l

Python 实现代码

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# KNN algorithm in MLinAction

from numpy import *
import operator


def classify(inx, dataSet, lables, k):    # 分类器主程序. inx为测试样例向量,dataset为训练集 
    dataSize = dataSet.shape[0]           # lables 是与训练集数据一一对应的类别向量,k 为选取的近邻数目
    diffMat = tile(inx, (dataSize, 1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistance = sqDiffMat.sum(axis=1)
    distance = sqDistance**1 / 2    # 计算测试样例与训练集每个样本的距离(欧氏距离)
    sortIndex = sqDistance.argsort()    # 按距离排序
    classCount = {}
    for i in range(k):    # 取距离最小的前k个训练样本进行投票(voting)
        if lables[sortIndex[i]] not in classCount.keys():
            classCount[lables[sortIndex[i]]] = 0
        classCount[lables[sortIndex[i]]] += 1
    maxCount = 0
    for key in classCount.keys():    # voting, k个训练样本中最多的类别作为测试样例的预测样本 
        if classCount[key] > maxCount:
            maxCount = classCount[key]
            resClass = key
    return resClass


def loadData(filename1, filename2):    # 产生训练集数据和测试集数据的函数
    f1 = open(filename1, 'r')    # 全集文件名
    f2 = open(filename2, 'w')    # 测试集文件名
    lineNum = 1000
    horatio = 0.1
    lableVec = []    # 训练集样本的类别向量
    dataMat = zeros((int(lineNum * (1 - horatio)), 3))    # 训练集样本
    testDataMat = zeros((int(lineNum * horatio), 3))    # 测试集样本
    testLableVec = []    # 测试集样例的类别向量
    index = 0
    for i in range(lineNum):
        line = f1.readline()
        if i < lineNum * horatio:
            f2.write(line)
            lineVec = line.strip().split('\t')
            testDataMat[i, :] = lineVec[0:3]
            testLableVec.append(lineVec[-1])
        else:
            lineVec = line.strip().split('\t')
            dataMat[i - 100, :] = lineVec[0:3]
            lableVec.append(lineVec[-1])
    return dataMat, lableVec, testDataMat, testLableVec


def dataNorm(dataMat, testDataMat):    # 数据标准化函数:(data-min)/(max-min)
    Max1 = dataMat.max(0)
    Max2 = testDataMat.max(0)
    Min1 = dataMat.min(0)
    Min2 = testDataMat.min(0)
    Max_Min1 = Max1 - Min1
    Max_Min2 = Max2 - Min2
    m1 = dataMat.shape[0]
    m2 = testDataMat.shape[0]
    normDataMat = (dataMat - tile(Min1, (m1, 1))) / tile(Max_Min1, (m1, 1))
    normTestDataMat = (testDataMat - tile(Min2, (m2, 1))) / \
        tile(Max_Min2, (m2, 1))
    return normDataMat, normTestDataMat


def testClassify(dataSet, testDataSet, lables, testlables):    # 算法测试函数,返回预测正确率
    numRight = 0
    numTest = testDataSet.shape[0]
    for i in range(numTest):
        lable = classify(testDataSet[i, :], dataSet, lables, 25)
        if lable == testlables[i]:
            numRight += 1
    return float(numRight) / numTest



dataMat, lableVec, testDataMat, testLableVec = loadData(
    'datingTestSet.txt', 'testSet.txt')
dataSet, testDataSet = dataNorm(dataMat, testDataMat)
ratio = testClassify(dataSet, testDataSet, lableVec, testLableVec)
print ratio


猜你喜欢

转载自blog.csdn.net/potato012345/article/details/52796800