KNN 算法的 Python 实现

knn原理请见:
https://www.cnblogs.com/pinard/p/6061661.html
https://zhuanlan.zhihu.com/p/22345658?refer=hgjjsms

在特征维度和样本数量较小时，可以直接逐一计算测试样本与所有训练样本的距离（暴力搜索）；数量较大时这样做开销过高，需要在训练阶段建立树形索引结构（如 KD 树），具体做法在上面两篇文章中都有介绍。
下面试着写一下直接计算距离的代码：
import random
import numpy as np
def gen_data(num):
    """Generate a random 3-feature, 3-class toy data set and hold one sample out.

    ``num + 1`` samples are generated; sample ``i`` belongs to class
    ``'ABC'[i % 3]``.  Each of its three features is drawn uniformly from the
    class's range: A -> [0, 3], B -> [2, 5], C -> [4, 8].  The ranges overlap
    on purpose so that classification is not trivial.

    One of the ``num + 1`` samples is then chosen uniformly at random as the
    test sample and removed from the training data.  The test sample and its
    label are printed.

    Args:
        num: number of training samples to return; must be >= 1.

    Returns:
        ``(train_sets, train_labels, test_data, test_lable)`` where
        ``train_sets`` is a numpy array of shape ``(num, 3)``,
        ``train_labels`` is a list of ``num`` labels, ``test_data`` is a numpy
        array of shape ``(3,)`` and ``test_lable`` is its label.

    Raises:
        ValueError: if ``num < 1``.
    """
    if num < 1:
        raise ValueError("num must be at least 1, got %r" % (num,))

    # (label, low, high) per class; the class of sample i is i % 3.
    class_ranges = (('A', 0, 3), ('B', 2, 5), ('C', 4, 8))
    train_labels = []
    train_list = []  # flat list: 3 consecutive features per sample
    for i in range(num + 1):
        label, low, high = class_ranges[i % 3]
        train_labels.append(label)
        for _ in range(3):
            train_list.append(random.uniform(low, high))

    # Hold out one sample, chosen from all num + 1 samples (0..num inclusive),
    # and remove it from the training data.
    test_no = random.randrange(num + 1)
    test_databegin = test_no * 3
    test_data = np.array(train_list[test_databegin:test_databegin + 3])
    test_lable = train_labels[test_no]
    del train_labels[test_no]
    del train_list[test_databegin:test_databegin + 3]

    train_sets = np.array(train_list).reshape([num, 3])
    print(test_data)
    print(test_lable)
    return train_sets, train_labels, test_data, test_lable
   
   
def getdistance(data1, data2):
    """Return the Euclidean distance between two feature vectors.

    Works for vectors of any (equal) length, e.g. lists, tuples or 1-D numpy
    arrays.

    Args:
        data1: first feature vector.
        data2: second feature vector, same length as ``data1``.

    Returns:
        ``sqrt(sum((a - b) ** 2))`` over corresponding components.

    Raises:
        ValueError: if the two vectors have different lengths.
    """
    if len(data1) != len(data2):
        raise ValueError(
            "vectors must have the same length, got %d and %d"
            % (len(data1), len(data2)))
    result = 0.0
    for a, b in zip(data1, data2):
        result += (a - b) ** 2
    return result ** 0.5

def getfreqlabel(sorted_dis, labels):
    """Return the most frequent label among the given nearest neighbours.

    Args:
        sorted_dis: sequence of ``(index, distance)`` pairs, nearest first;
            only element 0 (the training-sample index) of each pair is used.
        labels: labels of the training samples, indexable by those indices.

    Returns:
        The label that occurs most often among the neighbours.  On a tie, the
        label whose first occurrence in ``sorted_dis`` comes earliest wins
        (i.e. the one closest to the query, given nearest-first order).

    Raises:
        ValueError: if ``sorted_dis`` is empty.
    """
    if len(sorted_dis) == 0:
        raise ValueError("sorted_dis is empty; no neighbours to vote on")
    labelcount = {}
    for entry in sorted_dis:
        label = labels[entry[0]]
        labelcount[label] = labelcount.get(label, 0) + 1
    # max() returns the first maximal key in iteration order, and dict order
    # is insertion order, so ties go to the earliest-seen label.
    return max(labelcount, key=labelcount.get)
        
   
def knn(trainsets, labels, testdata, testlable, k):
    """Classify ``testdata`` by majority vote of its ``k`` nearest training samples.

    Distances are computed by brute force against every training row.  A
    message saying whether the prediction matches ``testlable`` is printed.

    Args:
        trainsets: 2-D array-like, one training sample per row.
        labels: labels of the rows of ``trainsets``.
        testdata: feature vector to classify.
        testlable: expected label; used only for the printed correct/wrong
            message.
        k: number of neighbours that vote (>= 1).  If ``k`` exceeds the
            number of training samples, all of them vote.

    Returns:
        The predicted label.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    num = np.shape(trainsets)[0]
    # (training index, distance) pairs; sorted() is stable, so equal
    # distances keep training-set order.
    distances = [(i, getdistance(testdata, trainsets[i])) for i in range(num)]
    sorted_dis = sorted(distances, key=lambda pair: pair[1])[:k]
    predict = getfreqlabel(sorted_dis, labels)
    if testlable == predict:
        print("correct! the prediction is %s " % (predict))
    else:
        print("wrong! the prediction is %s and the real result is %s" % (predict, testlable))
    return predict
   
   
多次运行如下测试代码：
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)

结果如下：
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 1.98740659  2.52245719  1.93773424]
A
correct! the prediction is A

traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 7.69919827  6.15798075  4.72847636]
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 4.75048743  5.35337321  5.94363519]
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 4.17889151  4.12894528  2.75309795]
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 4.46343398  4.84038916  2.70829516]
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 3.6755628   3.27467528  2.65612631]
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 2.93270931  2.82268114  4.67005795]
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 5.93042058  5.4844738   5.31275289]
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 3.52034823  2.34853466  2.80139722]
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 7.46207714  7.96870556  7.34639584]
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 5.47349844  5.23281451  4.51244907]
C
wrong! the prediction is B and the real result is C

可见，预测有时会出错。把 k 设为更小的 3 时，错误率更高（错误主要出现在 B、C 两类之间：它们的特征取值范围 [2,5] 与 [4,8] 有重叠）。
以上代码在 Python 3.6 环境中测试通过，如有错误请不吝指出，谢谢！

猜你喜欢

转载自blog.csdn.net/anthea_luo/article/details/80876651
今日推荐