基于Java实现机器学习的感知机(可视化界面)

前言:本人也是刚刚入门机器学习,就像入门很多语言一样,第一个程序总是Hello World 。机器学习也不然,入门机器学习的第一个程序就是感知机啦。感知机是二类线性分类模型,输出的值为{+1, -1}两种类型,感知机是利用超平面将两类分离,多个不同的感知机就可以组成一张神经元网络,再往上就是人工智能系统然后就是终结者阿诺......越说越离题了。

好啦步入正题吧。下面我从三个方面简单阐述一下这个感知机到底是个什么妖怪:感知机模型,感知机学习策略,感知机学习算法。(部分代码以及资料引用了网上的)

一、感知机模型:

假设输入空间(特征空间)是X∈Rn,输出空间是Y={+1, -1},仅有两种结果,就好比一条线,位于线上方的点带入该线的方程得到的y值总是大于0,所以感知机是一种线性分类模型,属于判别模型。输入x∈X表示实例特征向量。对应于输出空间(特征空间)的点:输出y∈Y表示实例类别,由输入空间到输出空间的如下函数:

f(x)=sign(w*x+b)={-1,+1}。

线性方程w*x+b 其中w称为权值,b称为偏置。咱们的感知机呢正是通过很多训练集来训练自己,从而不断更新w和b,直到找到一个最优的分类位置。

对应于特征空间Rn 中的一个超平面S,其中w是超平面的法向量,b是超平面的截距。根据这个原理我们可以推导出计算距离的公式:

能够将数据集的正实例和负实例完全正确的划分到超平面的两侧,则称数据集T是线性可分数据集,否则称线性不可分数据集。

二、感知机学习策略:

就像咱们有自己的学习方法一样,感知机也有自己的学习方法。而感知机的学习方法我们常称为损失函数。同时我们要将这个损失函数极小化,这就要求它是连续可导的。损失函数有两种选择:一、误分类点的总数;二、误分类点到超平面S的距离;第一种不易于优化,因此我们通常选择第二种。什么叫误分类点呢?如图:

我画了这样子的一条线,意在将两种颜色的圆分类,但蓝色类里面多了一个红色的,这个红色的就称之为误分类点啦。

对于误分类点来说,它到超平面的距离计算就相当于蓝色圈到超平面的距离取反,因为它代入超平面方程得到的y值应该是负值。也就是这个:

那么总距离就是:

因此,感知机sign(wx+b)的损失函数可以简写为:

三、感知机学习算法:

感知机学习问题转化为求解损失函数最优化问题,最优化的方法是随机梯度下降法。感知机学习算法有两种形式:原始形式和对偶形式。在训练数据线性可分的条件下,感知机学习算法是收敛的。

原始形式

我的感知机采用的正是原始形式,原始形式是通过给定的训练数据集T={(x1,y1), {x2,y2},…..,{xN,yN}}去求解参数w和b使得: 

这个损失函数极小。其中M是误分类点的集合。

感知机学习算法是误分类数据驱动的,采用随机梯度下降法,即随机选取一个超平面w0和b0,使用梯度下降法对损失函数进行极小化。极小化不是一次使得所有M集合误分类点梯度下降,而是一次随机选取一个点使其梯度下降。 
假设M集合是固定,那么损失函数的梯度为: 

然后呢再随机选取一个误分类点(xi,yi)对wi和b进行更新。更新方程如下:(w和b的初始值可以随便给,但尽量不要太大,否则会影响计算的时间。)

η(0<=η<=1)是步长,统计学习中称为学习率。通过不断迭代,损失函数不断减小,直到为0。 

所以对原始类的总结如下:

1、 随机选取w0和b0 ;
2、 在训练数据中选取(xi,yi) ;
3、 如果yi(w*xi+b) <= 0;

4、 转2,直到训练数据中,没有误分类点。 

对偶形式我就简单提一下吧,毕竟我还没代码实践。

对偶形似的基本想法是:将w和b表示实例xi和标记yi的线性组合形式,通过求解器系数的到w和b。在原始形式中,通过 w和b的更新方程不断修改w和b,假设修改了n次,则w和b关于(xi, yi)的增量分别是aixiyi和aiyi,这里的ai=niη。最后学习到的w和b是: 

其中ai > =0, i =1,2,….,N。当时,表示第i个实例有误分类而进行更新的次数。实例点更新次数越多,则它里超平面的距离就越近,也就越难正确分类。 

最后附上完整的Demo代码:

//PerceptronClassifier类:

package 感知机;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.TextField;
import java.util.ArrayList;
import java.util.Arrays;

import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;

import java.awt.Graphics;

public class PerceptronClassifier extends JFrame{
	
	//分类器参数
		private double[]w;//权值数组
		private double b = 0 ; //阈值
		private double eta = 1;//学习率
		ArrayList<Point>arrayList;
		
		public double getW(int i)
		{
			return w[i];
		}
		public double getB()
		{
			return b;
		}
	//初始化分类器,读入我们要分类的数据
		public  PerceptronClassifier(ArrayList<Point>arrayList,double eta)
		{
			this.arrayList = arrayList;
			w = new double[arrayList.get(0).x.length];
			this.eta = eta;
		}
	 // 分类器初始化

		public PerceptronClassifier(ArrayList<Point> arrayList) 
		{
			this.arrayList = arrayList;
			w = new double[arrayList.get(0).x.length];
			this.eta = 1;
		}
	/********************************************************/
		/*开始分类计算*/
		public boolean Classify()
		{
			boolean flag = false;
			while(!flag)//遍历所有的样本
			{
				
				for(int i=0;i<arrayList.size();i++)//所有的训练集
				{
					if(LearnAnswer(arrayList.get(i))<=0)//分类错误的点
					{
						UpdateWAndB(arrayList.get(i));//更新需要学习的点
						this.paint(this.getGraphics());  //动态更新,一旦出错马上重新遍历
					    try {
							Thread.sleep(300);
							} catch (InterruptedException e) {
							e.printStackTrace();
						}
						break;
					}
					if(i==arrayList.size()-1)//已经遍历到最后一个训练集
					{
						flag = true;
					}
				}
				
			}
			System.out.println("学习后:");
			System.out.println(Arrays.toString(w));//输出一轮学习后找到的权值和阈值
			System.out.println(b);
			return true;

		}
		
		private double LearnAnswer(Point point) //计算结果,用于判断分类是否正确
		{
			System.out.println(Arrays.toString(w));
			System.out.println(b);
			return point.y * (DotProduct(w, point.x) + b);
		}
		private void UpdateWAndB(Point point) //更新w 和 b 的值(随机梯度下降)
		{
			System.out.println("分类出错!更新w、b!");
			for (int i = 0; i < w.length; i++) {
				w[i] += eta * point.y * point.x[i];
			}
			b += eta * point.y;
			return;

		}
		
		private double DotProduct(double[] x1, double[] x2) //点乘函数
		{
			int len = x1.length;
			double sum = 0;
			for (int i = 0; i < len; i++) {
				sum += x1[i] * x2[i];
			}
			return sum;

		}
		
		public void InitUI()
		{
			this.setTitle("机器学习感知机");
			this.setSize(800, 600);
			this.setDefaultCloseOperation(3);
			this.setLocationRelativeTo(null);//窗口居中
			this.setResizable(false);//禁止最小化
			this.setLayout(null);//关闭流式布局
			

			JButton butstart = new JButton("开始训练");
		//	butstart.setPreferredSize(new Dimension(100,60));//设置按钮样式
			butstart.setBounds(150, 480, 100, 60);
			butstart.setContentAreaFilled(false);  //消除按钮背景颜色
			butstart.setOpaque(false); //除去边框
			butstart.setFocusPainted(false);//出去突起
			this.add(butstart);
			
			JButton butpro = new JButton("预测颜色");
			//	butstart.setPreferredSize(new Dimension(100,60));//设置按钮样式
			butpro.setBounds(300, 480, 100, 60);
			butpro.setContentAreaFilled(false);  //消除按钮背景颜色
			butpro.setOpaque(false); //除去边框
			butpro.setFocusPainted(false);//出去突起
			this.add(butpro);
			
			JLabel label1 = new JLabel("X:");
			label1.setFont(new Font("宋体",Font.BOLD,30));
			label1.setBounds(400, 500, 50, 50);
			this.add(label1);
			

			JLabel label2 = new JLabel("Y:");
			label2.setFont(new Font("宋体",Font.BOLD,30));
			label2.setBounds(540, 500, 50, 50);
			this.add(label2);
			
			
			TextField text1 = new TextField("1");
			text1.setFont(new Font("宋体",Font.BOLD,30));
			text1.setBounds(450, 500, 70, 40);
			this.add(text1);
			
			TextField text2 = new TextField("1");
			text2.setFont(new Font("宋体",Font.BOLD,30));
			text2.setBounds(590, 500, 70, 40);
			this.add(text2);
			this.setVisible(true);//设置窗体可见
			
			//添加监听
			ButtonListener BL = new ButtonListener(this,text1,text2);
			butstart.addActionListener(BL);
			butpro.addActionListener(BL);
		}
		
		//为了更形象,先画个坐标轴吧。重写paint函数就可以了。
		public void paint( Graphics g)
		{
			super.paint(g);
			//绘制坐标轴
			g.setColor(Color.black);
			g.drawLine(100, 100, 100, 480);
			g.drawLine(100, 480, 700, 480);
			
			//接下来从arrayList里面取点画出来,此时要注意颜色的设置,比如y为1设置蓝色,y为-1设置红色
			for(int i=0;i<arrayList.size();i++)
			{
				if(arrayList.get(i).y==1)
				{
					g.setColor(Color.BLUE);
				}
				else
				{
					g.setColor(Color.RED);
				}
				//位置可能需要进行适当的放大处理
//				g.drawLine((int)arrayList.get(i).x[0]*200+10,(int) arrayList.get(i).x[1]*200,
//						(int)arrayList.get(i).x[0]*200+10,(int)arrayList.get(i).x[1]*200+10);
				g.drawOval((int)arrayList.get(i).x[0]*100+200, (int)arrayList.get(i).x[1]*100+200, 15, 15);
			}
			//接下来是区分线
			//说白了就是计算点到直线的距离
			int x1=0,y2=0;
			System.out.println(this.getB()+" "+this.getW(1));
			int y1 = (int)((-1)*this.getB()/this.getW(1));
			int x2 = (int)((-1)*this.getB()/this.getW(0));
			System.out.println("开始画标准线!");
			g.setColor(Color.BLACK);
			System.out.println(x1*100+200+","+y1*100+200+","+x2*100+200+","+y2*100+200);
			g.drawLine(x1*100+200, y1*100+200, x2*100+200, y2*100+200);//跟上面保持一样的放大比例
		}
		
		public static void main(String[] args)
		{

			Point p1 = new Point(new double[] { 0,1.1 }, -1);//训练集
			Point p2 = new Point(new double[] { 1.2,0 }, -1);
			Point p3 = new Point(new double[] { 2.16,1 }, -1);
			Point p4 = new Point(new double[] { 1,2.64 }, -1);
			Point p5 = new Point(new double[] { 3.14,1.2 }, 1);
			Point p6 = new Point(new double[] { 1.32,3.4 }, 1);
			Point p7 = new Point(new double[] { 3.32,2.23 }, 1);
			Point p8 = new Point(new double[] { 2.71,2.4 }, 1);

			ArrayList<Point> list = new ArrayList<Point>();
			list.add(p1);
			list.add(p2);
			list.add(p3);
			list.add(p4);
			list.add(p5);
			list.add(p6);
			list.add(p7);
			list.add(p8);

			PerceptronClassifier classifier = new PerceptronClassifier(list);
		//	classifier.Classify();
			classifier.InitUI();

		}
	
		
}

//Point类:

package 感知机;

public class Point {
	
	double[] x = new double[2];
	double y =0;
	Point(double[]x ,double y)
	{
		this.x = x;
		this.y = y;
	}
	
	Point()
	{
		
	}
	

}

//ButtonListener:

package 感知机;

import java.awt.TextField;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;

import javax.swing.JOptionPane;

public class ButtonListener implements ActionListener{

	public PerceptronClassifier classifier;
	public TextField text1,text2;
	
	
	public ButtonListener(PerceptronClassifier classifier,TextField text1,TextField text2) {
		this.classifier =classifier;
		this.text1 = text1;
		this.text2 = text2;
	}
	public void actionPerformed(ActionEvent e) {
		if(e.getActionCommand().equals("开始训练"))
		{
			classifier.Classify();//启动训练方法
		}
		else if(e.getActionCommand().equals("预测颜色"))
		{
			//首先拿到文本框输入的坐标
			String x1 = text1.getText();
			String x2 = text2.getText();
			float xx1,xx2;
			//由于是string,我们需要强制转换为数字
			if(x1==""|| x2=="")
			{
				xx1=(float) 1.0;
				xx2=(float) 1.0;
			}
			 xx1 = new Float(x1);
			 xx2 = new Float(x2);
			System.out.println("拿到的XY为:"+xx1+" "+xx2);
			//将坐标点带入我们得到的方程,不同的结果代表不同的颜色,结果只有1和-1.
			//xx1*w1+xx2*w2+b
			if(xx1*classifier.getW(0)+xx2*classifier.getW(1)+classifier.getB()>=0)
			{
				JOptionPane.showMessageDialog(null,"该图形为蓝色");//消息框弹出
			}
			else
			{
				JOptionPane.showMessageDialog(null,"该图形为红色");
			}
		}
		
	}
	

}

有不妥之处欢迎指出!!Demo使用说明:点击开始训练即可进行分类。在文本框X和Y处输入相关点的坐标就会预测它在这个平面内的类别,也就是属于哪种颜色。

猜你喜欢

转载自blog.csdn.net/weixin_42294984/article/details/82465325
今日推荐