算法 聚类(划分)K-means

什么是k-means 算法?

K-means 是聚类算法的一种,属于聚类算法中的划分式聚类算法,是一种无监督学习方法。

算法思想?

1、首先确定划分群体的个数k

2、随机从数据集中选取k个点作为初始中心点,使用欧几里德距离公式计算数据集中每一个点到各中心点的距离,把每个点划分到距离它最近的中心点所在的类。

3、对上一步得到的每个群体重新计算中心点(取群体内所有点坐标的平均值),再按照步骤2的方法重新进行群体划分。

4、重复步骤2、3,直到中心点不再变化(数据收敛)或达到最大迭代次数。

----

举例:假设我们有一组二维数据

P0(0,0) P1(1,3) P2(3,1) P3(8,9) P4(10,8) P5(9,10)

首先取 k=2,假设随机选取的两个初始中心点为 P0、P3。

计算每个点与中心点的距离。

min(P1→P0, P1→P3) 取距离近的中心点:P1 到 P0 的距离为 √10≈3.16,到 P3 的距离为 √85≈9.22,因此 P1 划入 P0 所在的群体。其他点依次类推。

----

package com.fandong.algorithm;

import scala.Tuple2;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * K-means clustering over two-dimensional points.
 *
 * <p>Each point is a {@code Tuple2<Double, Double>} holding (x, y). Call
 * {@link #converage(int)} to run the clustering and then
 * {@link #displayResult()} to print the resulting groups.
 */
public class KMeans {

    // Number of clusters (k) the data is partitioned into.
    private final int cluster_k;

    // Clustering result: cluster index -> points assigned to that cluster.
    // Stays null until converage(int) has been run.
    private Map<Integer,ArrayList<Tuple2<Double,Double>>> result;

    // Original input points.
    private final ArrayList<Tuple2<Double,Double>> scData;

    /**
     * Creates a clusterer for the given data set.
     *
     * @param cluster_k number of clusters, must be at least 1
     * @param sc        the points to cluster, must not be null
     * @throws IllegalArgumentException if {@code cluster_k < 1} or {@code sc} is null
     */
    public KMeans(int cluster_k,ArrayList<Tuple2<Double,Double>> sc){
        if(cluster_k < 1){
            throw new IllegalArgumentException("cluster_k must be >= 1, got " + cluster_k);
        }
        if(sc == null){
            throw new IllegalArgumentException("data set must not be null");
        }
        this.cluster_k = cluster_k;
        this.scData = sc;
    }

    /**
     * Picks {@code cluster_k} initial centers at random from the data set.
     *
     * <p>The centers are drawn without replacement (partial Fisher-Yates
     * shuffle over the indices), so every data point - including the last
     * one - can be chosen and no index is chosen twice.
     *
     * @param cluster_k number of centers to pick; values {@code <= 0} yield an empty list
     * @return the chosen centers, in selection order
     * @throws IllegalArgumentException if {@code cluster_k} exceeds the number of data points
     */
    public ArrayList<Tuple2<Double,Double>> getInitClusterK(int cluster_k){
        ArrayList<Tuple2<Double,Double>> centers = new ArrayList<Tuple2<Double, Double>>();
        if(cluster_k <= 0){
            return centers;
        }
        int n = this.scData.size();
        if(cluster_k > n){
            throw new IllegalArgumentException(
                    "cluster_k (" + cluster_k + ") must not exceed the number of data points (" + n + ")");
        }
        int[] order = new int[n];
        for(int i=0;i<n;i++){
            order[i] = i;
        }
        for(int i=0;i<cluster_k;i++){
            // Math.random() is in [0,1), so pick falls in [i, n-1].
            int pick = i + (int)(Math.random()*(n-i));
            int swap = order[i];
            order[i] = order[pick];
            order[pick] = swap;
            centers.add(this.scData.get(order[i]));
        }
        return centers;
    }

    /**
     * Euclidean distance between a data point and a cluster center.
     *
     * @param s1            data point
     * @param cluster_point cluster center
     * @return the straight-line distance between the two points
     */
    private double getDistance(Tuple2<Double,Double> s1, Tuple2<Double,Double> cluster_point){
        double dx = s1._1() - cluster_point._1();
        double dy = s1._2() - cluster_point._2();
        return Math.sqrt(dx*dx + dy*dy);
    }

    /**
     * Returns the index of the smallest value; ties go to the lowest index.
     *
     * @param ds distances, must contain at least one element
     * @return index of the minimum distance
     */
    private int getMinIndex(double[] ds){
        int index = 0;
        for(int i=1;i<ds.length;i++){
            if(ds[i] < ds[index]){
                index = i;
            }
        }
        return index;
    }

    /**
     * Computes the centroid (coordinate-wise mean) of a group of points.
     *
     * @param ds the points of one cluster, must not be empty
     * @return the centroid
     * @throws IllegalArgumentException if {@code ds} is empty
     */
    private Tuple2<Double,Double> getClusterPoint(ArrayList<Tuple2<Double,Double>> ds){
        int size = ds.size();
        if(size == 0){
            throw new IllegalArgumentException("除数不能为零");
        }
        double x=0D,y=0D;
        for(Tuple2<Double,Double> t: ds){
            x += t._1();
            y += t._2();
        }
        return new Tuple2<Double, Double>(x/size,y/size);
    }

    /**
     * Tells whether two center lists are identical, cluster by cluster.
     * When the centers stop changing the algorithm has converged.
     *
     * <p>The comparison is positional because index i always denotes
     * cluster i in both lists.
     *
     * @param ck1 centers of the previous iteration
     * @param ck2 centers of the current iteration
     * @return true if both lists have the same size and equal centers at every index
     */
    private boolean equalsArray(ArrayList<Tuple2<Double,Double>> ck1,ArrayList<Tuple2<Double,Double>> ck2){
        if(ck1.size() != ck2.size()){
            return false;
        }
        for(int i=0;i<ck1.size();i++){
            if(!ck1.get(i).equals(ck2.get(i))){
                return false;
            }
        }
        return true;
    }

    /**
     * Runs k-means until the centers stop moving or the iteration limit is hit,
     * then stores the final grouping for {@link #displayResult()}.
     *
     * <p>A cluster that receives no points keeps its previous center instead
     * of being averaged, so an empty cluster no longer aborts the run.
     *
     * @param iterationNums maximum number of iterations; {@code <= 0} stores an empty result
     * @throws IllegalArgumentException if the data set is empty or holds fewer points than clusters
     * @throws Exception declared for backward compatibility with existing callers
     */
    public void converage(int iterationNums) throws Exception {
        int size = this.scData.size();
        if(size == 0){
            throw new IllegalArgumentException("data set must not be empty");
        }
        if(this.cluster_k > size){
            throw new IllegalArgumentException(
                    "cluster_k (" + this.cluster_k + ") must not exceed the number of data points (" + size + ")");
        }

        Map<Integer,ArrayList<Tuple2<Double,Double>>> tmpResult = new HashMap<Integer, ArrayList<Tuple2<Double, Double>>>();
        ArrayList<Tuple2<Double,Double>> clusterPoints = null;

        for(int iter=0; iter<iterationNums;iter++){

            if(clusterPoints == null){
                // First pass: start from randomly chosen data points.
                clusterPoints = this.getInitClusterK(this.cluster_k);
            }else{
                ArrayList<Tuple2<Double,Double>> prevClusterPoints = clusterPoints;
                clusterPoints = new ArrayList<Tuple2<Double, Double>>();
                for(int cluster=0;cluster<cluster_k;cluster++){
                    ArrayList<Tuple2<Double,Double>> members = tmpResult.get(cluster);
                    if(members == null || members.isEmpty()){
                        // Empty cluster: nothing to average, keep the old center.
                        clusterPoints.add(prevClusterPoints.get(cluster));
                    }else{
                        clusterPoints.add(this.getClusterPoint(members));
                    }
                }
                // Centers unchanged -> converged; tmpResult still holds the
                // assignment that produced them, so it is the final answer.
                if(this.equalsArray(prevClusterPoints,clusterPoints)){
                    break;
                }
                // Start a fresh assignment for the new centers.
                tmpResult = new HashMap<Integer, ArrayList<Tuple2<Double, Double>>>();
            }

            System.out.println("迭代的次数"+(iter+1));
            this.displayClusterPoints(clusterPoints);
            System.out.println("开始聚类。。。。");

            // Assign every point to its nearest center.
            for(int sc=0;sc<size;sc++){
                Tuple2<Double,Double> curPoint = this.scData.get(sc);
                double[] distances = new double[this.cluster_k];
                for(int cluster=0;cluster<cluster_k;cluster++){
                    distances[cluster] = this.getDistance(curPoint,clusterPoints.get(cluster));
                }
                int clusterKey = this.getMinIndex(distances);
                ArrayList<Tuple2<Double,Double>> members = tmpResult.get(clusterKey);
                if(members == null){
                    members = new ArrayList<Tuple2<Double, Double>>();
                    tmpResult.put(clusterKey,members);
                }
                members.add(curPoint);
            }
        }

        this.result = tmpResult;
    }

    /**
     * Prints every cluster and the points assigned to it, one cluster per line.
     *
     * @throws IllegalStateException if {@link #converage(int)} has not been run yet
     */
    public void displayResult(){
        if(this.result == null){
            throw new IllegalStateException("converage(int) must be called before displayResult()");
        }
        for(Map.Entry<Integer,ArrayList<Tuple2<Double,Double>>> entry: this.result.entrySet()){
            StringBuilder line = new StringBuilder();
            line.append("第").append(entry.getKey()).append("组: {");
            for(Tuple2<Double,Double> t: entry.getValue()){
                line.append("(").append(t._1()).append(",").append(t._2()).append(") ");
            }
            line.append(" }");
            System.out.println(line.toString());
        }
    }

    /**
     * Prints the given cluster centers on a single line.
     *
     * @param r the centers to print
     */
    private void displayClusterPoints(ArrayList<Tuple2<Double,Double>> r){
        StringBuilder stringBuilder = new StringBuilder();
        for(Tuple2<Double,Double> t: r){
            stringBuilder.append("(").append(t._1()).append(",").append(t._2()).append(") ");
        }
        System.out.println(stringBuilder.toString());
    }
}
package com.fandong;

import com.fandong.algorithm.KMeans;
import scala.Tuple2;

import java.util.ArrayList;

/**
 * Small demo: clusters six sample 2-D points into two groups and prints them.
 */
public class KMeansDemo {
    public static void main(String[] args) throws Exception {
        // Sample coordinates, one {x, y} pair per point.
        double[][] coordinates = {
                {0D, 0D},
                {1D, 3D},
                {3D, 1D},
                {8D, 9D},
                {10D, 8D},
                {9D, 10D}
        };

        ArrayList<Tuple2<Double,Double>> points = new ArrayList<Tuple2<Double, Double>>();
        for (double[] xy : coordinates) {
            points.add(new Tuple2<Double, Double>(Double.valueOf(xy[0]), Double.valueOf(xy[1])));
        }

        // k = 2 clusters, at most 1000 iterations.
        KMeans clusterer = new KMeans(2, points);
        clusterer.converage(1000);
        clusterer.displayResult();
    }
}

Result:

发布了61 篇原创文章 · 获赞 1 · 访问量 663

猜你喜欢

转载自blog.csdn.net/u012842247/article/details/103466682