什么是k-means 算法?
是聚类算法的一种,聚类算法中的划分聚类算法,属于无监督学习方法。
算法思想?
1、首先确定划分群体的个数k
2、随机从数据集中选取k个中心点,这里使用欧几里德距离公式计算数据集任意一点到中心点的距离,把距离中心点近的划分一类。
3、对上一步得到的每个群体重新计算中心点(取该群体内所有点坐标的平均值),再按步骤2的方法重新进行群体划分。
4、重复步骤2、3 直到 数据收敛或满足迭代次数。
----
举例:假设我们有一组二维数据
P0(0,0) P1(1,3) P2(3,1) P3(8,9) P4(10,8) P5(9,10)
首先取 k=2,即划分为两个群体,假设随机选出的初始中心点为 p0、p3。
计算每个点与中心点的距离。
min(p1->p0, p1->p3) 取距离近的:p1 到 p0 的距离约为 3.16,到 p3 的距离约为 9.22,因此 p1 划入 p0 所在的群体。其他点依次类推。
----
package com.fandong.algorithm;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * K-means clustering for two-dimensional points stored as
 * {@code Tuple2<Double, Double>} (x in {@code _1}, y in {@code _2}).
 *
 * <p>Typical use: construct with the number of clusters and the data set,
 * call {@link #converage(int)} to run the algorithm, then
 * {@link #displayResult()} to print the resulting groups. Progress
 * (iteration number and current centers) is printed to standard output
 * while the algorithm runs.
 */
public class KMeans {
    // Number of clusters (k) the data is partitioned into.
    private final int cluster_k;
    // Final assignment: cluster index -> points assigned to that cluster.
    // Stays null until converage(int) has been called.
    private Map<Integer,ArrayList<Tuple2<Double,Double>>> result;
    // Input data points (held by reference, not copied).
    private final ArrayList<Tuple2<Double,Double>> scData;

    /**
     * Creates a clusterer for the given data.
     *
     * @param cluster_k number of clusters; must be positive
     * @param sc        data points to cluster; must not be null
     * @throws IllegalArgumentException if {@code cluster_k <= 0} or {@code sc} is null
     */
    public KMeans(int cluster_k, ArrayList<Tuple2<Double,Double>> sc) {
        if (cluster_k <= 0) {
            throw new IllegalArgumentException("cluster_k must be positive: " + cluster_k);
        }
        if (sc == null) {
            throw new IllegalArgumentException("sc must not be null");
        }
        this.cluster_k = cluster_k;
        this.scData = sc;
    }

    /**
     * Picks {@code cluster_k} initial centers at random from the data set.
     *
     * <p>The centers are taken from distinct positions of the data list
     * (partial Fisher-Yates shuffle of the index range), and every position,
     * including the last one, can be chosen. Picking the same position twice
     * would leave one cluster without any members.
     *
     * @param cluster_k number of centers to pick
     * @return the chosen centers, in pick order; empty if {@code cluster_k <= 0}
     * @throws IllegalArgumentException if {@code cluster_k} exceeds the number of data points
     */
    public ArrayList<Tuple2<Double,Double>> getInitClusterK(int cluster_k) {
        int total = this.scData.size();
        if (cluster_k > total) {
            throw new IllegalArgumentException(
                    "cannot pick " + cluster_k + " distinct centers from " + total + " points");
        }
        ArrayList<Tuple2<Double,Double>> centers = new ArrayList<Tuple2<Double, Double>>();
        int[] positions = new int[total];
        for (int i = 0; i < total; i++) {
            positions[i] = i;
        }
        for (int i = 0; i < cluster_k; i++) {
            // Math.random() is in [0, 1), so pick is in [i, total - 1].
            int pick = i + (int) (Math.random() * (total - i));
            int swapped = positions[i];
            positions[i] = positions[pick];
            positions[pick] = swapped;
            centers.add(this.scData.get(positions[i]));
        }
        return centers;
    }

    /**
     * Computes the Euclidean distance between a data point and a center.
     *
     * @param s1            data point
     * @param cluster_point cluster center
     * @return the straight-line distance between the two points
     */
    private double getDistance(Tuple2<Double,Double> s1, Tuple2<Double,Double> cluster_point) {
        double dx = s1._1() - cluster_point._1();
        double dy = s1._2() - cluster_point._2();
        return Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Returns the index of the smallest value; the first one wins on ties.
     *
     * @param ds distances, one per cluster; must contain at least one element
     * @return index of the minimum
     */
    private int getMinIndex(double[] ds) {
        int index = 0;
        double minValue = ds[0];
        for (int i = 1; i < ds.length; i++) {
            if (ds[i] < minValue) {
                minValue = ds[i];
                index = i;
            }
        }
        return index;
    }

    /**
     * Computes the center (arithmetic mean of x and of y) of a group of points.
     *
     * @param ds members of one cluster; must not be empty
     * @return the mean point
     * @throws IllegalArgumentException if {@code ds} is empty
     */
    private Tuple2<Double,Double> getClusterPoint(ArrayList<Tuple2<Double,Double>> ds) {
        int size = ds.size();
        if (size == 0) {
            throw new IllegalArgumentException("除数不能为零");
        }
        double x = 0D, y = 0D;
        for (Tuple2<Double,Double> t : ds) {
            x += t._1();
            y += t._2();
        }
        return new Tuple2<Double, Double>(x / size, y / size);
    }

    /**
     * Assigns every data point to its nearest center.
     *
     * @param centers current cluster centers, indexed by cluster
     * @return cluster index -> points assigned to it; clusters that received
     *         no point have no entry
     */
    private Map<Integer,ArrayList<Tuple2<Double,Double>>> assignPoints(
            ArrayList<Tuple2<Double,Double>> centers) {
        Map<Integer,ArrayList<Tuple2<Double,Double>>> groups =
                new HashMap<Integer, ArrayList<Tuple2<Double, Double>>>();
        double[] distances = new double[centers.size()];
        for (Tuple2<Double,Double> point : this.scData) {
            for (int cluster = 0; cluster < distances.length; cluster++) {
                distances[cluster] = this.getDistance(point, centers.get(cluster));
            }
            int clusterKey = this.getMinIndex(distances);
            ArrayList<Tuple2<Double,Double>> members = groups.get(clusterKey);
            if (members == null) {
                members = new ArrayList<Tuple2<Double, Double>>();
                groups.put(clusterKey, members);
            }
            members.add(point);
        }
        return groups;
    }

    /**
     * Derives the next set of centers from the current assignment.
     *
     * @param groups   cluster index -> assigned points
     * @param previous centers used to produce {@code groups}
     * @return new centers, indexed by cluster
     */
    private ArrayList<Tuple2<Double,Double>> recomputeCenters(
            Map<Integer,ArrayList<Tuple2<Double,Double>>> groups,
            ArrayList<Tuple2<Double,Double>> previous) {
        ArrayList<Tuple2<Double,Double>> next =
                new ArrayList<Tuple2<Double, Double>>(previous.size());
        for (int cluster = 0; cluster < previous.size(); cluster++) {
            ArrayList<Tuple2<Double,Double>> members = groups.get(cluster);
            if (members == null || members.isEmpty()) {
                // A cluster can end up empty (e.g. duplicate data points chosen
                // as centers). Keep its old center instead of dividing by zero.
                next.add(previous.get(cluster));
            } else {
                next.add(this.getClusterPoint(members));
            }
        }
        return next;
    }

    /**
     * Runs k-means until the centers stop changing or the iteration limit
     * is reached, then stores the final assignment for {@link #displayResult()}.
     *
     * <p>Each iteration prints its number and the centers in use. Centers are
     * compared position by position, so convergence means every cluster's
     * center is unchanged.
     *
     * @param iterationNums maximum number of iterations; with a value of 0 or
     *                      less nothing is clustered and the result is empty
     * @throws IllegalArgumentException if there are fewer data points than clusters
     */
    public void converage(int iterationNums) throws Exception {
        Map<Integer,ArrayList<Tuple2<Double,Double>>> assignment =
                new HashMap<Integer, ArrayList<Tuple2<Double, Double>>>();
        ArrayList<Tuple2<Double,Double>> centers = null;
        for (int iter = 0; iter < iterationNums; iter++) {
            if (centers == null) {
                centers = this.getInitClusterK(this.cluster_k);
            } else {
                ArrayList<Tuple2<Double,Double>> updated = this.recomputeCenters(assignment, centers);
                // Unchanged centers would reproduce the same assignment: converged.
                if (updated.equals(centers)) {
                    break;
                }
                centers = updated;
            }
            System.out.println("迭代的次数"+(iter+1));
            this.displayClusterPoints(centers);
            System.out.println("开始聚类。。。。");
            assignment = this.assignPoints(centers);
        }
        this.result = assignment;
    }

    /**
     * Prints each cluster and its member points, one cluster per line.
     *
     * @throws IllegalStateException if {@link #converage(int)} has not been called yet
     */
    public void displayResult() {
        if (this.result == null) {
            throw new IllegalStateException("converage(int) must be called before displayResult()");
        }
        for (Map.Entry<Integer,ArrayList<Tuple2<Double,Double>>> group : this.result.entrySet()) {
            StringBuilder line = new StringBuilder();
            line.append("第"+group.getKey()+"组: {");
            for (Tuple2<Double,Double> t : group.getValue()) {
                line.append("("+t._1()+","+t._2()+") ");
            }
            line.append(" }");
            System.out.println(line.toString());
        }
    }

    /**
     * Prints the given centers on a single line.
     *
     * @param r centers to print
     */
    private void displayClusterPoints(ArrayList<Tuple2<Double,Double>> r) {
        StringBuilder stringBuilder = new StringBuilder();
        for (Tuple2<Double,Double> t : r) {
            stringBuilder.append("("+t._1()+","+t._2()+") ");
        }
        System.out.println(stringBuilder.toString());
    }
}
package com.fandong;
import com.fandong.algorithm.KMeans;
import scala.Tuple2;
import java.util.ArrayList;
/**
 * Small driver that clusters six sample 2-D points into two groups
 * and prints the outcome.
 */
public class KMeansDemo {
    public static void main(String[] args) throws Exception {
        // Sample data: three points near the origin, three near (9, 9).
        double[][] coordinates = {
                {0D, 0D},
                {1D, 3D},
                {3D, 1D},
                {8D, 9D},
                {10D, 8D},
                {9D, 10D},
        };

        ArrayList<Tuple2<Double,Double>> points = new ArrayList<Tuple2<Double, Double>>();
        for (double[] xy : coordinates) {
            points.add(new Tuple2<Double, Double>(xy[0], xy[1]));
        }

        // Two clusters, at most 1000 iterations.
        KMeans clusterer = new KMeans(2, points);
        clusterer.converage(1000);
        clusterer.displayResult();
    }
}
Result: