山东大学模式识别实验(java)KNN算法

KNN算法就是把待分类数据放在训练集里找出离他最近的K个元素(欧氏距离),然后看看其中哪个类最多,就将这个元素分为这个类。在本实验中,使用数字数据集。每个数字含有一个二维数组表示其中的像素点,可以认为拥有M*N个特征,只不过每个特征只有0和1两种值,表示该像素点是否绘制。

将下载的训练集和测试集放在项目根目录下,因为测试集中每个元素也是已标记数据,所以每次分类后可以判断分类是否正确,从而得出一个正确率。

在拿到待分类元素的K个邻居后,最简单的处理是每个邻居具有相等的投票权,考虑增大离得近的元素的影响力,也就是为他们的投票权设置权值。这里我的权值设置是离得最近的具有K票,第二近的具有K-1票,依次递减,比较容易理解,直接看代码。

1.封装的数字类:

public class Number {
	private int[][] data=new int[32][32];
	private int kind;
	public int[][] getData() {
		return data;
	}
	public void setData(int[][] data) {
		this.data = data;
	}
	public int getKind() {
		return kind;
	}
	public void setKind(int kind) {
		this.kind = kind;
	}
	
}
2.test类:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

public class Test {
	private int k;
	private List<Number> testDatas;
	
	public void putTestData(String path){//放入测试数据
		File folder=new File(path);
		File[] files=folder.listFiles();
		testDatas=new ArrayList<>();
		for(File file:files){
			testDatas.add(txt2Number(file));
		}
	}
	
	public void work(String path){//开始测试
		File folder=new File(path);
		File[] files=folder.listFiles();
		int numberAll=0;//测试总数
		int numCorrect=0;//测试正确数
		double result=0;//正确率
		for(File file:files){
			Number num=txt2Number(file);
			int[] minDistances=new int[k];
			int[] resultKinds=new int[k];
			for (int i = 0; i < k; i++) {
				minDistances[i]=Integer.MAX_VALUE;
				resultKinds[i]=0;
			}
			for(Number nu:testDatas){
				int currentDis=calcu(num, nu);
				int currentKind=nu.getKind();
				for (int i = 0; i < k; i++) {//将当前测试数据与邻居数组中每一个进行比对,看看是否可以替换掉一个
					if (currentDis<minDistances[i]) {
						resultKinds[i]=currentKind;
						minDistances[i]=currentDis;
						break;
					}
				}
			}
			int []kinds=new int[10];//10个类别的个数
			for (int i = 0; i < k; i++) {
				kinds[resultKinds[i]]+=add(minDistances, i);//加权后累加
			}
			int resultKind=0;
			int resultKindNum=0;
			for (int i = 0; i < 10; i++) {
				if (kinds[i]>resultKindNum) {
					resultKind=i;
					resultKindNum=kinds[i];
				}
			}
			numberAll++;
			if (resultKind==num.getKind()) {
				numCorrect++;
			}
		}
		result=((double)(numCorrect*100))/numberAll;
		System.out.println("k是:"+getK()+"   测试总数:"+numberAll+"    "
				+ "正确数:"+numCorrect+"   正确率"+result);
	}
	
	public void workOne(String path){//测试单个
		File fileTest=new File(path);
		Number num=txt2Number(fileTest);
		int[] minDistances=new int[k];
		int[] resultKinds=new int[k];
		for (int i = 0; i < k; i++) {
			minDistances[i]=Integer.MAX_VALUE;
			resultKinds[i]=0;
		}
		for(Number nu:testDatas){
			int currentDis=calcu(num, nu);
			int currentKind=nu.getKind();
			for (int i = 0; i < k; i++) {//将当前测试数据与邻居数组中每一个进行比对,看看是否可以替换掉一个
				if (currentDis<minDistances[i]) {
					resultKinds[i]=currentKind;
					minDistances[i]=currentDis;
					break;
				}
			}
		}
		int []kinds=new int[10];//10个类别的个数
		for (int i = 0; i < k; i++) {
			kinds[resultKinds[i]]++;
		}
		int resultKind=0;
		int resultKindNum=0;
		for (int i = 0; i < 10; i++) {
			if (kinds[i]>resultKindNum) {
				resultKind=i;
				resultKindNum=kinds[i];
			}
		}
		System.out.println("识别文件"+path+"为:"+resultKind+"  实际类型为:"+num.getKind());
	}
	
	public int calcu(Number a,Number b){//计算两张图的欧氏距离,为了简化计算不开根号
		int result=0;
		for(int i=0;i<32;i++){
			for (int j = 0; j < 32; j++) {
				int[][] d1=a.getData();
				int[][] d2=b.getData();
				int dis=d1[i][j]-d2[i][j];
				result+=dis*dis;
			}
		}
		return result;
	}
	
	public int getK() {
		return k;
	}

	public void setK(int k) {
		this.k = k;
	}


	
	public Number txt2Number(File file){//txt文件转Number对象
		Number num=new Number();
		int[][] data=new int[32][32];
		String fileName=file.getName();
		int kind =Integer.valueOf(fileName.substring(0,1));
		num.setKind(kind);
		try {
			BufferedReader reader=new BufferedReader(new FileReader(file));
			String s=null;
			for (int i = 0; i < 32; i++) {
				s=reader.readLine();
				for (int j = 0; j < 32; j++) {
					data[i][j]=Integer.valueOf(s.substring(j, j+1));
				}
			}
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		num.setData(data);
		return num;
	}
	
	public int add(int []a,int index){//获取这个邻居元素的权值
		int re=0;
		for (int i = 0; i < a.length; i++) {
			if (a[index]<=a[i]) {
				re++;
			}
		}
		return re;
	}
	
	public static void main (String[] args) {
		Test test=new Test();
		test.putTestData("testDigits");
		test.setK(10);
		test.work("trainingDigits");
		
		test.setK(1);
		test.work("trainingDigits");
	}
}

实验结果:


猜你喜欢

转载自blog.csdn.net/zhang___yong/article/details/79053539