聚类算法之K-means算法-UCI数据集上的java实现

本文主要分析了K-means聚类算法的基本原理,时间复杂度以及优缺点,最后用UCI数据集进行了测试,包含java实现代码,适合初学者参考。

一.算法原理

输入:聚类个数k,以及包含 n个数据对象的数据库。
输出:满足方差最小标准的k个聚类。
处理流程:       
(1)从 n个数据对象任意选择 k 个对象作为初始聚类中心,即中心点。

(2)根据每个聚类的中心点,计算每个对象与这些中心点的距离;并根据最小距离重新对相应对象进行划分;
(3)重新计算每个(有变化)聚类的均值(中心点)

(4)循环(2)到(3)直到每个聚类不再发生变化为止

二.复杂度

时间复杂度:O(kntd),其中,t为迭代次数,K为簇的数目,n为数据数,d为维数

空间复杂度:O((n+K)d),其中,K为簇的数目,n为数据数,d为维数

三.算法优缺点

优点:

扫描二维码关注公众号,回复: 1907895 查看本文章

1.时间复杂度低,速度快。
2.对于大规模数据集,该算法是相对可扩展的,并且效率较高。

缺点:

1.必须事先给出要生成的簇数k。
2.不适合发现非凸面形状的簇和大小差别很大的簇。
3.对噪声和离群点敏感。
4.只适用于数值型数据。
5.初始点随机选取,可能导致终止于局部最优解。

四.java实现

实验数据用的是UCI上面的iris数据集,数据可以从UCI官网上面下载,下载下来是txt文件,可以自行百度将其导入到mysql数据库中。我的数据库中的字段信息如下图所示:(大家可以自己定义,原理看得明白就行)


接下来直接贴实现代码。代码结构如下:


1.Point类,主要是对应数据库中字段的模型类

package Kmeans;

//模型类,对应数据库中的属性
public class Point
{
	//定义iris数据集的四个属性
	private double x;
	private double y;
	private double z;
	private double w;
	public double getX()
	{
		return x;
	}
    public void setX(double x)
	{
	  	this.x=x;  
	}
    public double getY()
	{
	 	return y;
	}
	public void setY(double y)
	{
		this.y=y;
	}
    public double getZ()
    {
    	return z;
    }
    public void setZ(double z)
    {
    	this.z=z;
    }
    public double getW()
    {
    	return w;
    }
    public void setW(double w)
    {
    	this.w=w;
    }
    public Point()
    {
    }
    public Point(double x,double y,double z,double w)
    {
		super();
		this.x=x;
		this.y=y;
		this.z=z;
		this.w=w;
    }
    public String toString()
    {
    	return "Point [x="+x+",y="+y+",z="+z+",w="+w+"]";
    }
    //重写equals方法和hashCode方法,因为后面需要用到HashMap的containsKey(point)方法,而Point类作为其中的key参数
    /*@Override 
    public boolean equals(Object obj)
    {
    	Point p=(Point)obj;
		if(this.getX()==p.getX()&&this.getY()==p.getY()&&this.getZ()==p.getZ()&&this.getW()==p.getW())
		{
			return true;
		}
		return false;
    }
    @Override
    public int hashCode()
    {
    	return (int)(this.x+this.y+this.z+this.w);
    }*/
	@Override
	public int hashCode() 
	{
		final int prime = 31;
		int result = 1;
		long temp;
		temp = Double.doubleToLongBits(w);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		temp = Double.doubleToLongBits(x);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		temp = Double.doubleToLongBits(y);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		temp = Double.doubleToLongBits(z);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		return result;
	}
	@Override
	public boolean equals(Object obj) 
	{
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		Point other = (Point) obj;
		if (Double.doubleToLongBits(w) != Double.doubleToLongBits(other.w))
			return false;
		if (Double.doubleToLongBits(x) != Double.doubleToLongBits(other.x))
			return false;
		if (Double.doubleToLongBits(y) != Double.doubleToLongBits(other.y))
			return false;
		if (Double.doubleToLongBits(z) != Double.doubleToLongBits(other.z))
			return false;
		return true;
	}
    
}
ps:hashcode方法和equals方法必须重写,否则程序会出问题。

2.SelectData类,主要实现从数据库中读取数据到Arraylist中。

package Kmeans;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;


//从数据库中读取数据的类
public class SelectData
{
	private Connection con; 
	private PreparedStatement ps; 
	private ResultSet rs;
	/*
	 * 从数据库中取数据存放到ArrayList中
	 */
	public ArrayList<Point> getPoints()
	{
		//定义存放数据的列表
		ArrayList<Point> points=new ArrayList<Point>();
		 try 
		  {
			  //连接数据库代码,先要加载mysql驱动
			   Class.forName("com.mysql.jdbc.Driver").newInstance(); 
			   con = DriverManager.getConnection("jdbc:mysql://localhost:3306/uci_dataset","root","asdzxc123");
			   String sql="select sepal_length,sepal_width,petal_length,petal_width from iris";
			   ps = con.prepareStatement(sql);  
			   rs = ps.executeQuery(); 	
			   while(rs.next())
			   {
				   Point p=new Point();
				   p.setX(rs.getDouble("sepal_length"));
				   p.setY(rs.getDouble("sepal_width"));
				   p.setZ(rs.getDouble("petal_length"));
				   p.setW(rs.getDouble("petal_width"));
				   points.add(p);
				   //System.out.println("数据集为: "+p);
			   }
			   /*for(Point pp:points)
			   {
				   System.out.println(pp);
			   }*/
			   //System.out.println("ArrayList数据集: "+points);
			   rs.close();
			   ps.close();
			   con.close();
		  } 
		  catch (Exception e) 
		  {
		   e.printStackTrace();
		   System.out.println("数据库连接失败");
		  }
		 return points;
	}
	/*public static void main(String[] args)
	   {
		   new SelectData().getPoints();
	   }*/
}
3.ManagePoint类,主要实现对中心点的更新等操作,具体函数看代码。

package Kmeans;

import java.util.ArrayList;
import java.util.Map;

public class ManagePoint 
{
	/**
	 * 计算对象点到中心点之间的距离
	 * @param p 对象点
	 * @param q 中心点
	 * @return 两点之间的距离
	 */
	public double getDistance(Point p,Point q)
	{
		double dx=p.getX()-q.getX();
		double dy=p.getY()-q.getY();
		double dz=p.getZ()-q.getZ();
		double dw=p.getW()-q.getW();
		double dist=dx*dx+dy*dy+dz*dz+dw*dw;
		return dist;
	}
	/**
	 * 判断新的中心点是否和前一轮旧的中心的相同
	 * @param lastCenterCluster旧的
	 * @param nowCenterCluster新的
	 * @return 相同则返回true,否则返回false
	 */
	public boolean isEqual(Map<Point,ArrayList<Point>> lastCenterCluster,Map<Point,ArrayList<Point>> nowCenterCluster,int k)
	{
		boolean flag;
		int i=0;
		if(lastCenterCluster==null)
		{
			//System.out.println("11111111");
			return false;
		}
		else
		{
			for(Point point:nowCenterCluster.keySet())
			{
				//System.out.println("222222");
				flag=lastCenterCluster.containsKey(point);
				if(flag)
				{
					i++;
				}
			    
			}
			if(i==k) return true;
		}
		//System.out.println("333333");
		return false;
	}
	/**
	 * 计算新的中心点
	 * @param value  HashMap中的value,为一个ArrayList
	 * @return  返回新的中心点
	 */
	public Point getNewCenter(ArrayList<Point> value)
	{
		double sumX=0,sumY=0,sumZ=0,sumW=0;
		for(Point point:value)
		{
			sumX+=point.getX();
			sumY+=point.getY();
			sumZ+=point.getZ();
			sumW+=point.getW();
		}
		System.out.println("新的中心: ("+sumX/value.size()+","+sumY/value.size()+","+sumZ/value.size()+","+sumW/value.size()+")");
		Point point=new Point();
		point.setX(sumX/value.size());
		point.setY(sumY/value.size());
		point.setZ(sumZ/value.size());
		point.setW(sumW/value.size());
		return point;
	}	
}
4.KmeansMain类,实现聚类。

package Kmeans;


import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

public class KmeansMain
{
    public Map<Point,ArrayList<Point>> executeKmeans(int k)
    {
    	ArrayList<Point> dataList=new ArrayList<Point>();//存放从SelectData类中获取的数据库中的源数据
    	Map<Point,ArrayList<Point>> nowCenterClusterMap=new HashMap<Point,ArrayList<Point>>();//当前中心及其簇内的点
    	Map<Point,ArrayList<Point>> lastCenterClusterMap=null;//上一个中心及其簇内所有点
    	dataList=new SelectData().getPoints();
		// 随机创建K个点作为起始中心
		Random rd=new Random();
		System.out.println("起始中心下标: ");
		for(int i=0;i<k;i++) 
		{
			int index=rd.nextInt(150);//public int nextInt(int n)该方法的作用是生成一个随机的int值,该值介于[0,n)的区间,也就是0到n之间的随机int值,包含0而不包含n。
			System.out.println("第"+(i+1)+"个随机生成的中心 :"+index);
			nowCenterClusterMap.put(dataList.get(index),new ArrayList<Point>());
		}
		// 输出起始中心
		System.out.println("起始中心: ");
		for(Point point:nowCenterClusterMap.keySet())
		{
			System.out.println("key:  "+point);
		}
		// 将数据点point加入配到离其最近的map的value中
		ManagePoint managePoint=new ManagePoint();
		while(true) 
		{
			
			for(Point point:dataList) 
			{
				double shortestDistance = Double.MAX_VALUE;// 初始化最短距离为Double的最大值
				Point key = null;
				for (Entry<Point,ArrayList<Point>> entry:nowCenterClusterMap.entrySet()) 
				{
					// 计算中心与各点间的距离
					double distance=managePoint.getDistance(entry.getKey(),point);
					if(distance<shortestDistance) 
					{
						shortestDistance=distance;
						key=entry.getKey();
					}
				}
				nowCenterClusterMap.get(key).add(point);
			}
			//如果这个判断放到上面while之后,那么return的值变为lastclustermap即可,因为在每次更新中心之后,nowclustermap里面只有key,没有value,只有执行循环之后才有value,才可以返回
			// 如果新的中心与上次的中心相等,则退出整个循环
			if (managePoint.isEqual(lastCenterClusterMap,nowCenterClusterMap,k)) 
			{
				System.out.println("中心相等了,聚类结束!");
				//测试lastCenterClusterMap数据,因为跳出循环时,它的数据应该和nowCenterClusterMap保持一致
				/*for (Entry<Point,ArrayList<Point>> entry:lastCenterClusterMap.entrySet()) 
				{
					System.out.println("\n" + "稳定的中心: "+entry.getKey());
					System.out.println("该簇的大小: "+entry.getValue().size());
					System.out.println("簇里的点:"+entry.getValue());
				}
				System.out.println("中心相等了,聚类结束!!!!");*/
				break;
			}
			// 更新中心
			lastCenterClusterMap=nowCenterClusterMap;
			nowCenterClusterMap=new HashMap<Point, ArrayList<Point>>();
			System.out.println("------------------------------------------------------------------");
			for(Entry<Point,ArrayList<Point>> entry:lastCenterClusterMap.entrySet()) 
			{
				nowCenterClusterMap.put(managePoint.getNewCenter(entry.getValue()),new ArrayList<Point>());
			}
		}
		return nowCenterClusterMap;
	}

	public static void main(String[] args) 
	{
		long start=System.currentTimeMillis();
		int K=3;// 分为三个类
		Map<Point,ArrayList<Point>> result =new KmeansMain().executeKmeans(K);
		// 输出分类结果
		System.out.println("===========聚类结果: ============");
		for (Entry<Point,ArrayList<Point>> entry:result.entrySet()) 
		{
			System.out.println("\n" + "稳定的中心: "+entry.getKey());
			System.out.println("该簇的大小: "+entry.getValue().size());
			System.out.println("簇里的点:"+entry.getValue());
		}
		long end=System.currentTimeMillis();
		System.out.println("执行本段程序所花费的时间为:"+(end-start)+"ms");
	}
}
结果如图所示,可进行多次运行查看聚类效果。


执行时间一般为0.4s左右。我用一百万条数据的数据集测试过,执行速度也是非常快的,聚类效率还是挺不错的。

后期会慢慢写其他聚类算法的实现,尽请期待!

特别说明:k-means算法中的初始中心点是随机选取,但为了程序方便,当然实际应用中也不会随机选取,故在初始点的选取过程中采用的是k-means++的选取方式,即从数据点中随机选取,在这里有个问题,如果选取的中心点有重复了(因为代码中未作判断),就会出bug,但是概率基本为0,但还是做下说明,免得大家运行时偶尔出问题,如果出问题了就重新运行下就ok。


猜你喜欢

转载自blog.csdn.net/qq_20372833/article/details/70877811