林轩田 机器学习 PLA算法 15题
代码还没简化,能跑通
数据集来自https://d396qusza40orc.cloudfront.net/ntumlone%2Fhw1%2Fhw1_15_train.dat
我直接复制下来放到txt中保存的。
同时参考了https://blog.csdn.net/devil_bye/article/details/80752529
其中如果不在样本特征中添加一位标签位,则无法跑通。
例如,该样本提供的特征是4个特征,如样本[0.32,0.178,0.156,0.97,1],最后的需要添加一位标志变成
[0.32,0.178,0.156,0.97,1](至于原因为什么,还没想通,希望有朋友能够帮我解惑)
# coding=utf-8
from numpy import *
def loadData(filename):
fr=open(filename)
Xmat=[]
Ymat=[]
for line in fr.readlines():
line=line.strip()
curLine=line.split(' ')
temp=curLine[:-1]
temp2=(curLine[-1].split('\t'))[0]
Ymat.append((curLine[-1].split('\t'))[1])
temp.append(temp2)
temp.append(1)###添加标志位。
Xmat.append(map(float,temp))###txt中的样本特征是str格式的数字字符串,所以需要转换
Ymat=map(float,Ymat)
return mat(Xmat),Ymat
Xmat,Ymat=loadData('hw.txt')
def PLA(Xmat,Ymat):
m,n=shape(Xmat)
wbegin=zeros(n)
count=0
while True :
iter_count = 0
breakflag=1
for X in Xmat:
X=array(X[0])[0]###获得样本特征
Ypre=array(dot(X,wbegin))
if Ypre>0:
Ypre=1
else:
Ypre=-1
k=Ymat[iter_count]
if Ypre!=k:
wbegin=wbegin+X*Ymat[iter_count]
count = count + 1
breakflag=0
iter_count = iter_count+1
if breakflag==1:
break
print 'finish'
return wbegin,count
print PLA(Xmat,Ymat)