吉吉:
这篇文章实现一个简单的手写数字识别程序。数字的相关信息全部存放在 MNIST 数据集中,可以在网上自行下载,最终需要得到 CSV 格式的文件。在实现功能之前,先跟我看看数据集给了我们哪些信息,come on……
数据初探(运行下文代码前需要先导入依赖:import numpy、import scipy.special、import matplotlib.pyplot)
# Read every record of the small MNIST training CSV into a list of text lines.
# The with-statement closes the file even if reading raises.
with open("mnist_dataset/mnist_train_100.csv", 'r') as data_file:
    data_list = data_file.readlines()
len(data_list)
100
有100行数据,具体看看某一行数据(第二行),data_list[1]
'0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,51,159,253,159,50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,48,238,252,252,252,237,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,54,227,253,252,239,233,252,57,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,60,224,252,253,252,202,84,252,253,122,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,163,252,252,252,253,252,252,96,189,253,167,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,51,238,253,253,190,114,253,228,47,79,255,168,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,48,238,252,252,179,12,75,121,21,0,0,253,243,50,0,0,0,0,0,0,0,0,0,0,0,0,0,38,165,253,233,208,84,0,0,0,0,0,0,253,252,165,0,0,0,0,0,0,0,0,0,0,0,0,7,178,252,240,71,19,28,0,0,0,0,0,0,253,252,195,0,0,0,0,0,0,0,0,0,0,0,0,57,252,252,63,0,0,0,0,0,0,0,0,0,253,252,195,0,0,0,0,0,0,0,0,0,0,0,0,198,253,190,0,0,0,0,0,0,0,0,0,0,255,253,196,0,0,0,0,0,0,0,0,0,0,0,76,246,252,112,0,0,0,0,0,0,0,0,0,0,253,252,148,0,0,0,0,0,0,0,0,0,0,0,85,252,230,25,0,0,0,0,0,0,0,0,7,135,253,186,12,0,0,0,0,0,0,0,0,0,0,0,85,252,223,0,0,0,0,0,0,0,0,7,131,252,225,71,0,0,0,0,0,0,0,0,0,0,0,0,85,252,145,0,0,0,0,0,0,0,48,165,252,173,0,0,0,0,0,0,0,0,0,0,0,0,0,0,86,253,225,0,0,0,0,0,0,114,238,253,162,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,85,252,249,146,48,29,85,178,225,253,223,167,56,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,85,252,252,252,229,215,252,252,252,196,130,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,28,199,252,252,253,252,252,233,145,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,25,128,252,253,252,141,37,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0'
看起来密密麻麻有点头大,不慌,让我来告诉你其中的秘密:第一个数字是0,它是标签,也就是这幅图像对应的数字;其余的784(28×28)个数字是构成图像的各个像素的颜色值(灰度),所以这些值都在0-255之间。你可能还是有点懵,来来,接着看吧。
# Split the second record into its comma-separated fields (label first, then pixels).
all_values = data_list[1].split(',')
# Convert the pixel fields from text to floats and reshape them into a 28x28 image.
# numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float is the replacement.
image_array = numpy.asarray(all_values[1:], dtype=float).reshape((28,28))
matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')
此刻你应该恍然大悟了。下一步将非标签数据(像素值)的范围缩放到0.01-1.0。为啥不缩放到0-1?因为输入为0时,对应权重的更新量也为0,权重就无法得到更新,所以用一个略大于0的数代替0就好啦。
# Rescale the pixel values from 0..255 to 0.01..1.0.
# numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float is the replacement.
scaled_input = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
print(scaled_input)
[0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.208 0.62729412 0.99223529 0.62729412 0.20411765
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.19635294 0.934
0.98835294 0.98835294 0.98835294 0.93011765 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.21964706 0.89129412 0.99223529 0.98835294 0.93788235
0.91458824 0.98835294 0.23129412 0.03329412 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.04882353 0.24294118 0.87964706
0.98835294 0.99223529 0.98835294 0.79423529 0.33611765 0.98835294
0.99223529 0.48364706 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.64282353 0.98835294 0.98835294 0.98835294 0.99223529
0.98835294 0.98835294 0.38270588 0.74376471 0.99223529 0.65835294
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.208 0.934
0.99223529 0.99223529 0.74764706 0.45258824 0.99223529 0.89517647
0.19247059 0.31670588 1. 0.66223529 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.19635294 0.934 0.98835294 0.98835294 0.70494118
0.05658824 0.30117647 0.47976471 0.09152941 0.01 0.01
0.99223529 0.95341176 0.20411765 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.15752941 0.65058824
0.99223529 0.91458824 0.81752941 0.33611765 0.01 0.01
0.01 0.01 0.01 0.01 0.99223529 0.98835294
0.65058824 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.03717647 0.70105882 0.98835294 0.94176471 0.28564706
0.08376471 0.11870588 0.01 0.01 0.01 0.01
0.01 0.01 0.99223529 0.98835294 0.76705882 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.23129412
0.98835294 0.98835294 0.25458824 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.99223529 0.98835294 0.76705882 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.77870588 0.99223529 0.74764706
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 1. 0.99223529
0.77094118 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.30505882 0.96505882 0.98835294 0.44482353 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.99223529 0.98835294 0.58458824 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.34 0.98835294
0.90294118 0.10705882 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.03717647 0.53411765
0.99223529 0.73211765 0.05658824 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.34 0.98835294 0.87576471 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.03717647 0.51858824 0.98835294 0.88352941 0.28564706
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.34 0.98835294 0.57294118 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.19635294 0.65058824
0.98835294 0.68164706 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.34388235 0.99223529
0.88352941 0.01 0.01 0.01 0.01 0.01
0.01 0.45258824 0.934 0.99223529 0.63894118 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.34 0.98835294 0.97670588 0.57682353
0.19635294 0.12258824 0.34 0.70105882 0.88352941 0.99223529
0.87576471 0.65835294 0.22741176 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.34 0.98835294 0.98835294 0.98835294 0.89905882 0.84470588
0.98835294 0.98835294 0.98835294 0.77094118 0.51470588 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.11870588 0.78258824
0.98835294 0.98835294 0.99223529 0.98835294 0.98835294 0.91458824
0.57294118 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.10705882 0.50694118 0.98835294
0.99223529 0.98835294 0.55741176 0.15364706 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 0.01 0.01
0.01 0.01 0.01 0.01 ]
最后一件事情:确定正确的输出值(目标值)。这里的激活函数我用的是sigmoid,它的值域为(0,1),无法取到0和1,所以用0.99代表正确数字所在的位置,其余位置用0.01填充。
# Build the target vector for one record: ten slots, all 0.01 except the slot
# indexed by the record's label, which is set to 0.99.
onodes = 10
targets = numpy.full(onodes, 0.01)
targets[int(all_values[0])] = 0.99
print(targets)
[ 0.99 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]
定义神经网络类
class neuralNetwork:
    """Three-layer (input, hidden, output) feed-forward network with sigmoid activations.

    The network keeps two weight matrices, ``wih`` (input -> hidden) and
    ``who`` (hidden -> output), and updates them one record at a time in
    ``train``.
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        """Store the layer sizes and learning rate and draw random initial weights.

        Args:
            inputnodes: Number of nodes in the input layer.
            hiddennodes: Number of nodes in the hidden layer.
            outputnodes: Number of nodes in the output layer.
            learningrate: Factor applied to every weight update in ``train``.
        """
        # Number of nodes in the input, hidden and output layers.
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # Initial weights are sampled from a normal distribution centred on 0
        # whose standard deviation is (number of nodes feeding the layer) ** -0.5.
        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

        self.lr = learningrate

        # Sigmoid activation. expit is already a callable, so it is stored
        # directly instead of being wrapped in a lambda.
        self.activation_function = scipy.special.expit

    def _forward(self, inputs_list):
        """Run the forward pass shared by ``train`` and ``query``.

        Returns:
            A tuple ``(inputs, hidden_outputs, final_outputs)``; ``inputs`` is
            ``inputs_list`` converted to a 2-D array and transposed.
        """
        inputs = numpy.array(inputs_list, ndmin=2).T
        hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))
        final_outputs = self.activation_function(numpy.dot(self.who, hidden_outputs))
        return inputs, hidden_outputs, final_outputs

    def train(self, inputs_list, targets_list):
        """Update both weight matrices in place from a single training record.

        Args:
            inputs_list: Input values for one record.
            targets_list: Desired output values for the same record.
        """
        inputs, hidden_outputs, final_outputs = self._forward(inputs_list)
        targets = numpy.array(targets_list, ndmin=2).T

        output_errors = targets - final_outputs
        # Propagate the error back through who BEFORE who is modified below.
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # Weight change = learning rate * (error * sigmoid derivative) . previous layer's outputs.
        self.who += self.lr * numpy.dot(
            output_errors * final_outputs * (1.0 - final_outputs),
            numpy.transpose(hidden_outputs),
        )
        self.wih += self.lr * numpy.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs),
            numpy.transpose(inputs),
        )

    def query(self, inputs_list):
        """Return the output-layer activations for ``inputs_list`` without changing any weights."""
        _, _, final_outputs = self._forward(inputs_list)
        return final_outputs
理解起来其实也很简单,前提是把前向传播、反向传播的过程动手推过一遍,我在这里就不多赘述了。
设置训练轮数(epochs)并训练网络(注意:这段代码要在下文创建好网络 n、载入 training_data_list 之后才能运行)
# Number of passes over the whole training set.
epochs = 5
for e in range(epochs):
    for record in training_data_list:
        all_values = record.split(',')
        # Rescale the pixel values from 0..255 to 0.01..1.0.
        # numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float replaces it.
        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
        # Target vector: 0.01 everywhere except 0.99 at the index given by the label.
        targets = numpy.zeros(output_nodes) + 0.01
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)
定义输入层、隐藏层、输出层的节点个数和学习率,并创建网络
# Layer sizes and learning rate used to build the network instance `n`.
input_nodes, hidden_nodes, output_nodes = 784, 200, 10
learning_rate = 0.1
n = neuralNetwork(
    inputnodes=input_nodes,
    hiddennodes=hidden_nodes,
    outputnodes=output_nodes,
    learningrate=learning_rate,
)
载入训练集和测试集
# Load the test and training CSV files as lists of text lines.
# The with-statements close each file even if reading raises.
with open("mnist_dataset/mnist_test_10.csv", 'r') as test_data_file:
    test_data_list = test_data_file.readlines()
with open("mnist_dataset/mnist_train_100.csv", 'r') as training_data_file:
    training_data_list = training_data_file.readlines()
看看我们要预测的值
# Split the first test record into its comma-separated fields and print the
# first field, which is the record's label.
all_values = test_data_list[0].split(',')
print(all_values[0])
7
# Show the first test record as a 28x28 image, then ask the network for its answer.
# numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float is the replacement.
image_array = numpy.asarray(all_values[1:], dtype=float).reshape((28,28))
matplotlib.pyplot.imshow(image_array,cmap='Greys',interpolation='None')
# The query input uses the same 0.01..1.0 rescaling as the training inputs.
n.query((numpy.asarray(all_values[1:], dtype=float)/255.0*0.99)+0.01)
array([[0.09225031],
[0.03115615],
[0.07405099],
[0.08403405],
[0.08896464],
[0.04145529],
[0.0151916 ],
[0.74246098],
[0.08593406],
[0.08722363]])
可以看出,的确是 index=7 处的输出值最大,这个神经网络还不错哟。
接下来写一个得分函数,预测正确得1分,错误为0分,看看前十个数字预测对的百分比
# One entry per test record: 1 when the network's answer matches the label, else 0.
scorecard = []
for record in test_data_list:
    all_values = record.split(',')
    correct_label = int(all_values[0])
    print(correct_label,'correct label')
    # Rescale the pixel values from 0..255 to 0.01..1.0.
    # numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float replaces it.
    inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
    outputs = n.query(inputs)
    # The index of the largest output value is taken as the network's answer.
    label = numpy.argmax(outputs)
    print(label,'network is answer')
    scorecard.append(1 if label == correct_label else 0)
print(scorecard)
7 correct label
7 network is answer
2 correct label
0 network is answer
1 correct label
1 network is answer
0 correct label
0 network is answer
4 correct label
4 network is answer
1 correct label
1 network is answer
4 correct label
4 network is answer
9 correct label
4 network is answer
5 correct label
4 network is answer
9 correct label
7 network is answer
[1, 0, 1, 1, 1, 1, 1, 0, 0, 0]
计算正确率:
# Accuracy = number of correct answers divided by the number of test records.
scorecard_array = numpy.asarray(scorecard)
correct_count = scorecard_array.sum()
print ("performance = ", correct_count / scorecard_array.size)
performance = 0.6