用 Python 手写 K-means 聚类（简单易懂）

import matplotlib.pyplot as plt
class kmeans:
    """Minimal k-means (Lloyd's algorithm) on sequences of numeric points.

    Parameters
    ----------
    k : int
        Number of clusters. Unless ``center`` is pre-populated, the first
        ``k`` entries of ``node_list`` are used as the initial centers.
    n : int
        Number of nodes. Only used by ``draw`` to guess whether it received
        raw nodes or a list of clusters when ``clustered`` is not given.
    node_list : sequence of sequences of numbers
        The points to cluster, e.g. a list of ``[x, y]`` pairs or a 2-D
        NumPy array with one row per point.
    iterations : int
        Maximum number of center-update steps performed by ``compute``.
    """

    # Cluster colors for draw(); the first five match the original palette.
    # Cluster indexes beyond the palette wrap around instead of raising KeyError.
    _COLORS = ('r', 'b', 'g', 'black', 'yellow', 'c', 'm', 'orange', 'purple', 'brown')

    def __init__(self, k, n, node_list, iterations):
        self.k = k
        self.n = n
        self.all_nodes = node_list
        self.iter = iterations
        self.center = []  # current cluster centers; initialised by compute()
        self.res = []     # clusters (one list of nodes per center) from compute()

    def draw(self, input_list, clustered=None):
        """Scatter-plot nodes with matplotlib and show the figure.

        Args:
            input_list: either the raw nodes, or a list of clusters (one list
                of nodes per cluster). Only the first two coordinates of each
                node are plotted.
            clustered: True if ``input_list`` is a list of clusters, False if
                it is raw nodes. When None, the original heuristic is used: a
                list whose length equals ``self.n`` is treated as raw nodes.
                That heuristic misfires when k == n, so callers should pass
                this explicitly.
        """
        if clustered is None:
            clustered = len(input_list) != self.n
        if not clustered:
            # One scatter() call for all points instead of one call per point.
            plt.scatter([node[0] for node in input_list],
                        [node[1] for node in input_list])
        else:
            for index, cluster in enumerate(input_list):
                color = self._COLORS[index % len(self._COLORS)]
                plt.scatter([node[0] for node in cluster],
                            [node[1] for node in cluster],
                            color=color)
        plt.show()

    def distance(self, x, y):
        """Return the squared Euclidean distance between points ``x`` and ``y``.

        The square root is omitted because it does not change which center is
        nearest. Coordinates are paired with zip(), so they are assumed to
        have equal length.
        """
        return sum((a - b) ** 2 for a, b in zip(x, y))

    def get_mean(self, input_nodelist):
        """Return the coordinate-wise mean of a non-empty list of nodes.

        Raises:
            ValueError: if ``input_nodelist`` is empty.
        """
        length = len(input_nodelist)
        if length == 0:
            raise ValueError("get_mean() requires at least one node")
        # zip(*nodes) yields one tuple per coordinate (all x's, all y's, ...).
        return [sum(column) / length for column in zip(*input_nodelist)]

    def exist(self, target, input_list):
        """Return True if some node in ``input_list`` equals ``target`` in every coordinate."""
        return any(
            len(node) == len(target) and all(a == b for a, b in zip(node, target))
            for node in input_list
        )

    def _assign(self, centers):
        """Group every node of ``self.all_nodes`` under its nearest center.

        Returns one list per center, in center order. Ties go to the
        lowest-index center, because min() returns the first minimal item.
        """
        clusters = [[] for _ in centers]
        for node in self.all_nodes:
            nearest = min(range(len(centers)),
                          key=lambda j: self.distance(node, centers[j]))
            clusters[nearest].append(node)
        return clusters

    def compute(self, plot=True):
        """Run k-means and return the final clusters (also stored in ``self.res``).

        Args:
            plot: when True (the default, matching the original behavior), draw
                the raw nodes before clustering and the colored clusters after.

        Returns:
            A list of ``len(self.center)`` lists. Each inner list holds the
            original node objects assigned to that center. Only real data
            points are included; centers are never inserted as samples.
        """
        if plot:
            self.draw(self.all_nodes, clustered=False)
        # len() instead of truthiness/== [] so a NumPy array of centers also works.
        if len(self.center) == 0:
            # Initial centers: the first k nodes (raises IndexError if k > n).
            self.center = [self.all_nodes[i] for i in range(self.k)]

        # Exactly self.iter update steps (the original ran self.iter + 1).
        for _ in range(self.iter):
            self.res = self._assign(self.center)
            # An empty cluster keeps its previous center instead of dividing by zero.
            new_center = [
                self.get_mean(cluster) if len(cluster) else old
                for cluster, old in zip(self.res, self.center)
            ]
            converged = all(self.distance(old, new) == 0
                            for old, new in zip(self.center, new_center))
            self.center = new_center
            if converged:
                # Unchanged centers mean unchanged assignments from here on.
                break

        # Reassign with the final centers so self.res and self.center agree.
        self.res = self._assign(self.center)
        if plot:
            self.draw(self.res, clustered=True)
        return self.res


# Demo: cluster 300 random 2-D points (standard normal) into 4 groups.
import numpy as np
# A small hand-made sample can be used instead:
# [[1, 2], [3, 5], [6, 2], [7, 3], [11, 0], [9, 3], [1, 6], [7, 2], [2, 12], [3, 8]]
node_list = np.random.randn(300, 2)
a = kmeans(
    k=4,
    n=len(node_list),
    node_list=node_list,
    iterations=10,
)
reslist = a.compute()

初始数据（聚类前）：
（原图缺失：全部数据点的散点图）
聚类后：
（原图缺失：聚类结果散点图，每个簇用不同颜色表示）

猜你喜欢

转载自：https://blog.csdn.net/weixin_41545780/article/details/107588865
今日推荐