【原创】python实现RBF神经网络识别Mnist数据集

版权声明:本文为博主ExcelMann的原创文章,未经博主允许不得转载。

python实现RBF神经网络识别Mnist数据集

作者:ExcelMann,转载需注明。

话不多说,直接贴代码,代码有注释。

# Author:Xuangan, Xu
# Data:2020-11-11

"""
RBF神经网络
-----------------
设计RBF神经网络实现手写数字的识别问题
数据集:Mnist数据集
"""

import os
import struct
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy import *
from scipy.linalg import norm, pinv
from sklearn.cluster import KMeans

class RBFNet:

    def __init__(self,sample_num,input_nodes,hidden_nodes,output_nodes):
        # 初始化网络各层结点个数
        self.i_nodes = input_nodes  # 即每个样本的维数
        self.h_nodes = hidden_nodes # 即中心点个数
        self.o_nodes = output_nodes

        # 初始化中心节点、方差、权重以及矩阵G(G[i][j]表示第i个样本在第j个中心节点的高斯函数值)
        self.centers = [random.uniform(-1,1,input_nodes) for i in range(hidden_nodes)]
        self.beta = 0
        self.w = random.uniform(-0.5,0.5,(hidden_nodes,output_nodes))
        self.G = np.zeros((sample_num,hidden_nodes))

    def culMse(self,pre_y,y):
        """
        计算均方误差
        :param pre_y: 预测值(60000 X 10)
        :param y: 期望值(60000 X 10)
        """
        totalError = 0
        for i in range(len(y)):
            for j in range(10):
                totalError += (y[i][j]-pre_y[i][j])**2
        return totalError/len(y)

    def calBasis(self,X,Xi):
        """
        计算隐含层结点值
        :param X:输入值,大小为维度
        :param Xi:第i个中心点,大小为维度
        :return:返回该隐含层结点的值
        """
        return exp((-np.linalg.norm(X-Xi)**2)/2*self.beta**2)

    def calHiddenValue(self,X):
        """
        计算矩阵G
        :param X:输入样本X
        """
        for ci,c in enumerate(self.centers):
            for xi,x in enumerate(X):
                self.G[xi][ci] = self.calBasis(c, x)

    def train(self,X,Y):
        """
        训练网络,由闭式解得到网络权重w
        :param X:输入样本X,大小为(样本个数, 784)
        :param Y:标签数据Y,大小为(样本个数,10)
        """
        # 选择中心点
        estimator = KMeans(n_clusters=1000)  # 划分为1000类
        estimator.fit(X)
        self.centers = estimator.cluster_centers_  # 得到1000个类别的各自中心点
        # 更新方差beta
        max_distance = 0  # 所选取中心之间的最大距离
        for i in range(self.h_nodes):
            for j in range(self.h_nodes-i):
                dis = np.linalg.norm(self.centers[i]-self.centers[j])
                if(dis > max_distance):
                    max_distance = dis
        self.beta = max_distance/sqrt(2*self.h_nodes)
        # 计算矩阵G
        self.calHiddenValue(X)
        # 由闭式解,得到网络权重w
        self.w = dot(pinv(dot(self.G.T,self.G)), self.G.T).dot(Y)
        # 计算loss
        pre_y = self.G.dot(self.w)
        print(f"loss:{self.culMse(pre_y,Y)}")

    def estimate(self,X,Y):
        """
        测试数据,返回预测准确率
        :param X: 测试数据,200 X 784维
        :param Y: 测试数据的标签,200 X 10维
        """
        hidden_v = np.zeros(self.h_nodes)   # 隐含层结点值
        correct_num = 0
        # 遍历测试样本
        for i in range(X.shape[0]):
            # 计算该测试样本对应的隐藏层结点值
            for j in range(self.h_nodes):
                hidden_v[j] = self.calBasis(X[i],self.centers[j])
            output = hidden_v.dot(self.w)   # 网络输出值ouput
            pre_y = np.argmax(output)   # 预测标签
            y = np.argmax(Y[i])    # 期望标签
            if y == pre_y:
                correct_num = correct_num+1
        return correct_num/X.shape[0]


if __name__ == "__main__":
    # 通过tensorflow读取mnist数据,并对读取到的数据进行处理
    mnist = tf.keras.datasets.mnist
    (train_x,train_y),(test_x,test_y) = mnist.load_data()
    # 将图像数据转为0-1的范围
    train_data = np.zeros((60000, 784))
    train_label = np.zeros((60000, 10))
    test_data = np.zeros((10000, 784))
    test_label = np.zeros((10000, 10))
    for i in range(60000):  # 处理训练数据
        train_data[i] = (np.array(train_x[i]).flatten())/255
        temp = np.zeros(10)
        temp[train_y[i]] = 1
        train_label[i] = temp
    for i in range(10000):  # 处理测试数据
        test_data[i] = (np.array(test_x[i]).flatten())/255
        temp = np.zeros(10)
        temp[test_y[i]] = 1
        test_label[i] = temp

    # RBF网络的输入、隐含、输出层结点个数
    input_nodes = 784
    hidden_nodes = 1000
    output_nodes = 10
    sample_num = 60000

    rbfNet = RBFNet(sample_num,input_nodes,hidden_nodes,output_nodes)
    rbfNet.train(train_data,train_label)

    accuracy = rbfNet.estimate(test_data,test_label)
    print(f"准确率:{accuracy}")

猜你喜欢

转载自blog.csdn.net/a602389093/article/details/109613455