实现简单的神经网络(mnist数据集)

吉吉:

这篇文章是实现简单手写数字的识别,数字的相关信息全部存在mnist数据集中,可以在网上自行下载,最后应得到的是csv格式的文件,实现功能之前先跟我看看数据集给了我们哪些信息,come on.........

数据初探

# Load the 100-record MNIST training subset; each line of the CSV is one
# handwritten-digit record (label + 784 pixel values).
with open("mnist_dataset/mnist_train_100.csv", 'r') as data_file:
    data_list = data_file.readlines()
len(data_list)
100

有100行数据,具体看看某一行数据(第二行),data_list[1]

'0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,51,159,253,159,50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,48,238,252,252,252,237,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,54,227,253,252,239,233,252,57,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,60,224,252,253,252,202,84,252,253,122,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,163,252,252,252,253,252,252,96,189,253,167,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,51,238,253,253,190,114,253,228,47,79,255,168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,48,238,252,252,179,12,75,121,21,0,0,253,243,50,0,0,0,0,0,0,0,0,0,0,0,0,0,38,165,253,233,208,84,0,0,0,0,0,0,253,252,165,0,0,0,0,0,0,0,0,0,0,0,0,7,178,252,240,71,19,28,0,0,0,0,0,0,253,252,195,0,0,0,0,0,0,0,0,0,0,0,0,57,252,252,63,0,0,0,0,0,0,0,0,0,253,252,195,0,0,0,0,0,0,0,0,0,0,0,0,198,253,190,0,0,0,0,0,0,0,0,0,0,255,253,196,0,0,0,0,0,0,0,0,0,0,0,76,246,252,112,0,0,0,0,0,0,0,0,0,0,253,252,148,0,0,0,0,0,0,0,0,0,0,0,85,252,230,25,0,0,0,0,0,0,0,0,7,135,253,186,12,0,0,0,0,0,0,0,0,0,0,0,85,252,223,0,0,0,0,0,0,0,0,7,131,252,225,71,0,0,0,0,0,0,0,0,0,0,0,0,85,252,145,0,0,0,0,0,0,0,48,165,252,173,0,0,0,0,0,0,0,0,0,0,0,0,0,0,86,253,225,0,0,0,0,0,0,114,238,253,162,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,85,252,249,146,48,29,85,178,225,253,223,167,56,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,85,252,252,252,229,215,252,252,252,196,130,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,28,199,252,252,253,252,252,233,145,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,25,128,252,253,252,141,37,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0'

看起来密密麻麻有点头大,不慌,让我来告诉你其中的秘密。其实呢,第一个数字是0,这代表着标签,其余的784个数字是构成图像像素的颜色值,所以这些值处在0-255之间。你可能还是有点蒙,来来,接着看吧。

# Split the second record into its comma-separated fields: the first field
# is the digit label, the remaining 784 are pixel values in 0-255.
all_values = data_list[1].split(',')
# asfarray converts the numeric strings to floats and builds an array;
# dropping the label, the 784 pixels reshape into a 28x28 image.
image_array = numpy.asfarray(all_values[1:]).reshape((28,28))
# Render the image (darker = higher pixel value).
matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')

此刻你应该恍然大悟了。下一步将非标签数据的范围缩放到0.01-1.00。为啥不直接用0-1呢?因为输入为0时,对应的权重得不到更新,所以选一个略大于0的数代替0就好啦。

# Rescale the 784 pixel values from 0-255 into 0.01-1.00 so that no input
# is exactly zero (a zero input would freeze its weight updates).
pixels = numpy.asfarray(all_values[1:])
scaled_input = pixels / 255.0 * 0.99 + 0.01
print(scaled_input)
[0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.208      0.62729412 0.99223529 0.62729412 0.20411765
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.19635294 0.934
 0.98835294 0.98835294 0.98835294 0.93011765 0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.21964706 0.89129412 0.99223529 0.98835294 0.93788235
 0.91458824 0.98835294 0.23129412 0.03329412 0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.04882353 0.24294118 0.87964706
 0.98835294 0.99223529 0.98835294 0.79423529 0.33611765 0.98835294
 0.99223529 0.48364706 0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.64282353 0.98835294 0.98835294 0.98835294 0.99223529
 0.98835294 0.98835294 0.38270588 0.74376471 0.99223529 0.65835294
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.208      0.934
 0.99223529 0.99223529 0.74764706 0.45258824 0.99223529 0.89517647
 0.19247059 0.31670588 1.         0.66223529 0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.19635294 0.934      0.98835294 0.98835294 0.70494118
 0.05658824 0.30117647 0.47976471 0.09152941 0.01       0.01
 0.99223529 0.95341176 0.20411765 0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.15752941 0.65058824
 0.99223529 0.91458824 0.81752941 0.33611765 0.01       0.01
 0.01       0.01       0.01       0.01       0.99223529 0.98835294
 0.65058824 0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.03717647 0.70105882 0.98835294 0.94176471 0.28564706
 0.08376471 0.11870588 0.01       0.01       0.01       0.01
 0.01       0.01       0.99223529 0.98835294 0.76705882 0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.23129412
 0.98835294 0.98835294 0.25458824 0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.99223529 0.98835294 0.76705882 0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.77870588 0.99223529 0.74764706
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       1.         0.99223529
 0.77094118 0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.30505882 0.96505882 0.98835294 0.44482353 0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.99223529 0.98835294 0.58458824 0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.34       0.98835294
 0.90294118 0.10705882 0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.03717647 0.53411765
 0.99223529 0.73211765 0.05658824 0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.34       0.98835294 0.87576471 0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.03717647 0.51858824 0.98835294 0.88352941 0.28564706
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.34       0.98835294 0.57294118 0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.19635294 0.65058824
 0.98835294 0.68164706 0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.34388235 0.99223529
 0.88352941 0.01       0.01       0.01       0.01       0.01
 0.01       0.45258824 0.934      0.99223529 0.63894118 0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.34       0.98835294 0.97670588 0.57682353
 0.19635294 0.12258824 0.34       0.70105882 0.88352941 0.99223529
 0.87576471 0.65835294 0.22741176 0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.34       0.98835294 0.98835294 0.98835294 0.89905882 0.84470588
 0.98835294 0.98835294 0.98835294 0.77094118 0.51470588 0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.11870588 0.78258824
 0.98835294 0.98835294 0.99223529 0.98835294 0.98835294 0.91458824
 0.57294118 0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.10705882 0.50694118 0.98835294
 0.99223529 0.98835294 0.55741176 0.15364706 0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01       0.01       0.01
 0.01       0.01       0.01       0.01      ]

最后一件事情:确定正确的输出值。sigmoid(阈值函数这里我用的是sigmoid哈)的值域为(0,1),但无法到达0和1,所以可以用0.99代表正确数字所在的位置,其余位置用0.01填充(与下面的代码一致)。

# Build the target vector for one record: 0.01 everywhere except 0.99 at
# the index of the true digit (sigmoid can never reach exactly 0 or 1).
onodes = 10
targets = numpy.full(onodes, 0.01)
targets[int(all_values[0])] = 0.99
print(targets)
[ 0.99  0.01  0.01  0.01  0.01  0.01  0.01  0.01  0.01  0.01]

定义实类

class neuralNetwork:
    """A simple 3-layer (input, hidden, output) feed-forward neural network.

    Trained one example at a time with stochastic gradient descent and a
    sigmoid activation on both the hidden and the output layer.
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        """Create the network.

        inputnodes / hiddennodes / outputnodes: node counts per layer.
        learningrate: gradient-descent step size.
        """
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # Weight matrices sampled from N(0, 1/sqrt(fan_in)) so the initial
        # weighted sums stay in the sigmoid's responsive range.
        # wih: (hidden, input); who: (output, hidden).
        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

        self.lr = learningrate

        # Sigmoid activation; expit(x) = 1 / (1 + exp(-x)).
        # (Direct reference instead of a needless lambda wrapper.)
        self.activation_function = scipy.special.expit

    def train(self, inputs_list, targets_list):
        """Run one forward pass and one backpropagation weight update for a
        single (input, target) example."""
        # Column vectors: (inodes, 1) and (onodes, 1).
        inputs = numpy.array(inputs_list, ndmin=2).T
        targets = numpy.array(targets_list, ndmin=2).T

        # Forward pass.
        hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))
        final_outputs = self.activation_function(numpy.dot(self.who, hidden_outputs))

        # Backpropagation: output error, then hidden error split across the
        # output weights.
        output_errors = targets - final_outputs
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # Gradient-descent updates (sigmoid derivative is y * (1 - y)).
        self.who += self.lr * numpy.dot(output_errors * final_outputs * (1.0 - final_outputs), hidden_outputs.T)
        self.wih += self.lr * numpy.dot(hidden_errors * hidden_outputs * (1.0 - hidden_outputs), inputs.T)

    def query(self, inputs_list):
        """Forward-propagate inputs_list and return the (onodes, 1) array of
        output activations."""
        inputs = numpy.array(inputs_list, ndmin=2).T
        hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))
        final_outputs = self.activation_function(numpy.dot(self.who, hidden_outputs))
        return final_outputs

理解起来其实也很简单,前提是前向传播、反向传播的过程自己动手推过一遍,我在这就不多赘述了。

定义训练次数函数

# Train the network: several full passes (epochs) over the training set.
epochs = 5

for epoch in range(epochs):
    for record in training_data_list:
        # One CSV record = label followed by 784 pixel values.
        all_values = record.split(',')
        # Scale pixels from 0-255 into 0.01-1.00 so no input is exactly 0.
        inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
        # Target vector: 0.99 at the true digit, 0.01 everywhere else.
        targets = numpy.zeros(output_nodes) + 0.01
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)

定义输出输入节点个数,学习率

# Network geometry: 28x28 = 784 input pixels, 200 hidden nodes, and one
# output node per digit 0-9.
input_nodes = 784
hidden_nodes = 200
output_nodes = 10

# Gradient-descent step size.
learning_rate = 0.1

# Build the network instance used for training and querying below.
n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

载入训练集和测试集

# Load the small MNIST subsets: 10 test records and 100 training records,
# one CSV line per digit image.
with open("mnist_dataset/mnist_test_10.csv", 'r') as test_data_file:
    test_data_list = test_data_file.readlines()

with open("mnist_dataset/mnist_train_100.csv", 'r') as training_data_file:
    training_data_list = training_data_file.readlines()

看看我们要预测的值

# Inspect the first test record; its first field is the true label.
all_values = test_data_list[0].split(',')
print(all_values[0])
7
# Reshape the 784 pixel values into 28x28 and display the digit image.
image_array = numpy.asfarray(all_values[1:]).reshape((28,28))
matplotlib.pyplot.imshow(image_array,cmap='Greys',interpolation='None')

# Ask the trained network for its output activations on this image
# (pixels scaled into 0.01-1.00 exactly as during training).
n.query((numpy.asfarray(all_values[1:])/255.0*0.99)+0.01)
array([[0.09225031],
       [0.03115615],
       [0.07405099],
       [0.08403405],
       [0.08896464],
       [0.04145529],
       [0.0151916 ],
       [0.74246098],
       [0.08593406],
       [0.08722363]])

可以看出的确是index=7数字最大,这个神经网络还不错哟。

接下来写一个得分函数,预测正确得1分,错误为0分,看看前十个数字预测对的百分比

# Score the network on the test set: 1 point per correct prediction.
scorecard = []

for line in test_data_list:
    values = line.split(',')
    # First field of the record is the ground-truth digit.
    correct_label = int(values[0])
    print(correct_label,'correct label')
    # Scale the 784 pixels into 0.01-1.00 and query the network.
    inputs = (numpy.asfarray(values[1:]) / 255.0 * 0.99) + 0.01
    outputs = n.query(inputs)
    # The predicted digit is the index of the strongest output node.
    label = numpy.argmax(outputs)
    print(label,'network is answer')
    # Record a hit (1) or a miss (0).
    scorecard.append(1 if label == correct_label else 0)

print(scorecard)
7 correct label
7 network is answer
2 correct label
0 network is answer
1 correct label
1 network is answer
0 correct label
0 network is answer
4 correct label
4 network is answer
1 correct label
1 network is answer
4 correct label
4 network is answer
9 correct label
4 network is answer
5 correct label
4 network is answer
9 correct label
7 network is answer
[1, 0, 1, 1, 1, 1, 1, 0, 0, 0]

计算正确率:

# Overall accuracy = correct predictions / total test records.
scorecard_array = numpy.asarray(scorecard)
print ("performance = ", scorecard_array.mean())
performance =  0.6

猜你喜欢

转载自blog.csdn.net/weixin_41503009/article/details/83420189