题目:编程实现线性判别分析LDA,给出西瓜数据集 3.0a上的结果
clc clear all [num,txt]=xlsread('D:\机器学习\WaterMelon_3.0.xlsx'); %提取有效数据 data=num(1:end,[1,8,9]); label_txt=txt([2:end],10); label=ismember(label_txt,'是'); %整理所需数据 data=[data,label]; class1=data(find(label==1),[2,3]); class2=data(find(label==0),[2,3]); %核心代码 mu1=mean(class1); mu2=mean(class2); s1=cov(class1); s2=cov(class2); sw=s1+s2; sb=(mu1-mu2)'*(mu1-mu2); [V,D]=eig(inv(sw)*sb); %取较大特征值对应的特征向量 w=V(:,2); pre_value1=class1*w; pre_value2=class2*w; pre_value=[pre_value1;pre_value2]; %offset相当于y=w^Tx+b中的b值 offset=(mean(pre_value1)+mean(pre_value2))/2; for i=1:length(pre_value) pre_value(i)=pre_value(i)-offset; %采用sigmod函数进行类别判断 pre_label(i)=~round(1/(1+exp(- pre_value(i)))); end data_out=[data,pre_value,pre_label']; xlswrite('D:\机器学习\LDA数据输出.xls',data_out); figure('NumberTitle', 'on', 'Name','给西瓜分个家_马存诗'); hold on; grid on; plot(class1(:,1),class1(:,2),'b*'), plot(class2(:,1),class2(:,2),'r+'), plot([0,-w(1)],[0,-w(2)]); %求垂足画垂线 for i=1:length(label) proj_point = ProjPoint( [data(i,2),data(i,3)],[0,0,-w(1),-w(2)]); if (i<=length(class1)) plot(proj_point(1),proj_point(2),'b.'); plot([data(i,2),proj_point(1)],[data(i,3),proj_point(2)],'--'); else plot(proj_point(1),proj_point(2),'r.'); plot([data(i,2),proj_point(1)],[data(i,3),proj_point(2)],'--'); end end axis([0 0.8 0 0.6]); title('LDA图示结果'); % ProjPoint函数:求投影垂足的函数 function proj_point = ProjPoint( point,line ) x1 = line(1); y1 = line(2); x2 = line(3); y2 = line(4); x3 = point(1); y3 = point(2); yk = ((x3-x2)*(x1-x2)*(y1-y2) + y3*(y1-y2)^2 + y2*(x1-x2)^2) / (norm([x1-x2,y1-y2])^2); xk = ((x1-x2)*x2*(y1-y2) + (x1-x2)*(x1-x2)*(yk-y2)) / ((x1-x2)*(y1-y2)); if x1 == x2 xk = x1; end if y1 == y2 xk = x3; end proj_point = [xk,yk]; end
WaterMelon_3.0.xlsx数据集:
数据文件可去这儿下载:https://pan.baidu.com/s/1O_ZYNPvCudC97SdrQIOCQQ
data_out输出结果:
offset的求取简化了,应该求两个正态的交点,我直接对两个均值求的平均值。
LDA图示结果: