kNN原理请见:
https://www.cnblogs.com/pinard/p/6061661.html
https://zhuanlan.zhihu.com/p/22345658?refer=hgjjsms
在特征和样本数量较少时,可以直接计算测试样本到所有训练样本的距离(暴力搜索);数量大时这样做开销太大,需要在训练阶段建立KD树等树结构来加速近邻搜索,具体见上面两篇文章。
下面试着写一下直接计算距离时的代码:
import random
import numpy as np
import numpy as np
def gen_data(num):
    """Generate a toy 3-class, 3-feature dataset and hold out one test sample.

    ``num + 1`` samples are generated; sample ``i`` gets label 'A', 'B' or 'C'
    according to ``i % 3``. Each of its three features is drawn uniformly
    from a label-specific interval (A: [0, 3], B: [2, 5], C: [4, 8]); the
    intervals overlap, so the classes are not perfectly separable.

    One of the ``num + 1`` samples, chosen uniformly at random, is removed
    and returned as the test point. Its features and label are printed.

    Args:
        num: number of samples left in the training set (must be >= 0).

    Returns:
        tuple ``(train_sets, train_labels, test_data, test_lable)``:
        ``train_sets`` is a ``(num, 3)`` numpy array, ``train_labels`` a list
        of ``num`` labels, ``test_data`` a length-3 numpy array and
        ``test_lable`` the held-out sample's label.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if num < 0:
        raise ValueError("num must be >= 0, got %r" % (num,))
    # label -> (low, high) interval that class's features are drawn from
    feature_ranges = {'A': (0, 3), 'B': (2, 5), 'C': (4, 8)}
    train_labels = []
    train_list = []  # flat feature list: 3 consecutive values per sample
    for i in range(num + 1):
        label = 'ABC'[i % 3]
        low, high = feature_ranges[label]
        train_labels.append(label)
        train_list.extend(random.uniform(low, high) for _ in range(3))
    # Any of the num + 1 samples (index 0..num inclusive) may be held out.
    test_no = random.randrange(num + 1)
    begin = test_no * 3
    test_data = np.array(train_list[begin:begin + 3])
    test_lable = train_labels.pop(test_no)
    del train_list[begin:begin + 3]
    train_sets = np.array(train_list).reshape([num, 3])
    print(test_data)
    print(test_lable)
    return train_sets, train_labels, test_data, test_lable
def getdistance(data1, data2):
    """Return the Euclidean distance between two feature vectors.

    All features are compared, so vectors of any (equal) length work.

    Args:
        data1, data2: indexable sequences (lists, tuples, 1-D numpy arrays)
            of numbers with the same length.

    Returns:
        The square root of the sum of squared per-feature differences.

    Raises:
        ValueError: if the two vectors have different lengths.
    """
    if len(data1) != len(data2):
        raise ValueError(
            "vectors differ in length: %d vs %d" % (len(data1), len(data2)))
    return sum((a - b) ** 2 for a, b in zip(data1, data2)) ** 0.5
def gen_data(num):
    """Generate a toy 3-class, 3-feature dataset and hold out one test sample.

    ``num + 1`` samples are generated; sample ``i`` gets label 'A', 'B' or 'C'
    according to ``i % 3``. Each of its three features is drawn uniformly
    from a label-specific interval (A: [0, 3], B: [2, 5], C: [4, 8]); the
    intervals overlap, so the classes are not perfectly separable.

    One of the ``num + 1`` samples, chosen uniformly at random, is removed
    and returned as the test point. Its features and label are printed.

    Args:
        num: number of samples left in the training set (must be >= 0).

    Returns:
        tuple ``(train_sets, train_labels, test_data, test_lable)``:
        ``train_sets`` is a ``(num, 3)`` numpy array, ``train_labels`` a list
        of ``num`` labels, ``test_data`` a length-3 numpy array and
        ``test_lable`` the held-out sample's label.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if num < 0:
        raise ValueError("num must be >= 0, got %r" % (num,))
    # label -> (low, high) interval that class's features are drawn from
    feature_ranges = {'A': (0, 3), 'B': (2, 5), 'C': (4, 8)}
    train_labels = []
    train_list = []  # flat feature list: 3 consecutive values per sample
    for i in range(num + 1):
        label = 'ABC'[i % 3]
        low, high = feature_ranges[label]
        train_labels.append(label)
        train_list.extend(random.uniform(low, high) for _ in range(3))
    # Any of the num + 1 samples (index 0..num inclusive) may be held out.
    test_no = random.randrange(num + 1)
    begin = test_no * 3
    test_data = np.array(train_list[begin:begin + 3])
    test_lable = train_labels.pop(test_no)
    del train_list[begin:begin + 3]
    train_sets = np.array(train_list).reshape([num, 3])
    print(test_data)
    print(test_lable)
    return train_sets, train_labels, test_data, test_lable
def getdistance(data1, data2):
    """Return the Euclidean distance between two feature vectors.

    All features are compared, so vectors of any (equal) length work.

    Args:
        data1, data2: indexable sequences (lists, tuples, 1-D numpy arrays)
            of numbers with the same length.

    Returns:
        The square root of the sum of squared per-feature differences.

    Raises:
        ValueError: if the two vectors have different lengths.
    """
    if len(data1) != len(data2):
        raise ValueError(
            "vectors differ in length: %d vs %d" % (len(data1), len(data2)))
    return sum((a - b) ** 2 for a, b in zip(data1, data2)) ** 0.5
def getfreqlabel(sorted_dis, labels):
    """Return the majority label among the given neighbours.

    Args:
        sorted_dis: sequence of ``(train_index, distance)`` pairs, one per
            neighbour; only the index is used.
        labels: training labels, indexed by ``train_index``.

    Returns:
        The most frequent label. On a tie, the label whose first occurrence
        in ``sorted_dis`` comes earliest wins (the nearer neighbour when
        ``sorted_dis`` is ordered by distance).

    Raises:
        ValueError: if ``sorted_dis`` is empty.
    """
    from collections import Counter

    if not sorted_dis:
        raise ValueError("sorted_dis must contain at least one neighbour")
    votes = Counter(labels[index] for index, _ in sorted_dis)
    # most_common orders equal counts by first insertion, which gives the
    # tie-break documented above.
    return votes.most_common(1)[0][0]
def knn(trainsets, labels, testdata, testlable, k):
    """Classify ``testdata`` by a k-nearest-neighbour vote and report it.

    The distance from ``testdata`` to every row of ``trainsets`` is computed
    with ``getdistance``; the ``k`` closest rows are handed to
    ``getfreqlabel`` for a majority vote. Prints whether the prediction
    matches ``testlable``; returns nothing.
    """
    row_count = np.shape(trainsets)[0]
    # (row index, distance) for every training row, in row order
    pairs = [(row, getdistance(testdata, trainsets[row]))
             for row in range(row_count)]
    nearest = sorted(pairs, key=lambda pair: pair[1])[:k]
    predict = getfreqlabel(nearest, labels)
    if testlable != predict:
        print("wrong! the prediction is %s and the real result is %s" % (predict, testlable))
        return
    print("correct! the prediction is %s " % (predict))
def getfreqlabel(sorted_dis, labels):
    """Return the majority label among the given neighbours.

    Args:
        sorted_dis: sequence of ``(train_index, distance)`` pairs, one per
            neighbour; only the index is used.
        labels: training labels, indexed by ``train_index``.

    Returns:
        The most frequent label. On a tie, the label whose first occurrence
        in ``sorted_dis`` comes earliest wins (the nearer neighbour when
        ``sorted_dis`` is ordered by distance).

    Raises:
        ValueError: if ``sorted_dis`` is empty.
    """
    from collections import Counter

    if not sorted_dis:
        raise ValueError("sorted_dis must contain at least one neighbour")
    votes = Counter(labels[index] for index, _ in sorted_dis)
    # most_common orders equal counts by first insertion, which gives the
    # tie-break documented above.
    return votes.most_common(1)[0][0]
def knn(trainsets, labels, testdata, testlable, k):
    """Predict the label of ``testdata`` from its ``k`` nearest training rows.

    Every training row is scored by ``getdistance`` to ``testdata``; after
    sorting by distance, the first ``k`` entries go to ``getfreqlabel``.
    The verdict (correct/wrong versus ``testlable``) is printed; nothing
    is returned.
    """
    scored = []
    for idx in range(np.shape(trainsets)[0]):
        scored.append((idx, getdistance(testdata, trainsets[idx])))
    scored.sort(key=lambda item: item[1])  # stable: ties keep row order
    predict = getfreqlabel(scored[0:k], labels)
    message = (
        "correct! the prediction is %s " % (predict)
        if testlable == predict
        else "wrong! the prediction is %s and the real result is %s" % (predict, testlable)
    )
    print(message)
多次运行如下测试代码:
# Demo: generate 15 training samples plus one held-out test sample, then
# classify the test sample with k = 5 neighbours. knn is called twice on the
# same data; nothing in knn is random, so both calls print the same verdict.
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
结果如下:
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 1.98740659 2.52245719 1.93773424]
A
correct! the prediction is A
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
A
correct! the prediction is A
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
[ 7.69919827 6.15798075 4.72847636]
C
correct! the prediction is C
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 4.75048743 5.35337321 5.94363519]
C
correct! the prediction is C
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 4.17889151 4.12894528 2.75309795]
B
correct! the prediction is B
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 4.46343398 4.84038916 2.70829516]
B
correct! the prediction is B
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 3.6755628 3.27467528 2.65612631]
B
correct! the prediction is B
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 2.93270931 2.82268114 4.67005795]
B
correct! the prediction is B
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 5.93042058 5.4844738 5.31275289]
C
correct! the prediction is C
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 3.52034823 2.34853466 2.80139722]
B
correct! the prediction is B
B
correct! the prediction is B
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 7.46207714 7.96870556 7.34639584]
C
correct! the prediction is C
C
correct! the prediction is C
traindata,labels,testdata,testlabel =gen_data(15)
knn(traindata,labels,testdata,testlabel,5)
knn(traindata,labels,testdata,testlabel,5)
[ 5.47349844 5.23281451 4.51244907]
C
wrong! the prediction is B and the real result is C
C
wrong! the prediction is B and the real result is C
可见,也有预测错误的时候。当k设为更小的3时,错误率更高(错误主要出现在B、C两类之间)。
以上代码在Python 3.6环境中测试通过,如有错误请不吝指出,谢谢!