python: 神经网络实现MNIST图像识别

神经网络输入层神经单元个数:784      (图像大小28*28)

               输出层                      :10        (10个类别分类,即10个数字)

              隐藏层个数                 :2

     第1个隐藏层的神经单元数   :50

     第2个隐藏层的神经单元数   :100

先定义get_data()、init_network()predict()这3个函数:

 1 def get_data():
 2     (x_train,t_train),(x_test,t_test)=load_mnist(nomalize=True,flatten=True,one_hot_label=False)
 3     return x_test,t_test
 4 
 5 def init_natwork():
 6     with open("sample_weight.pkl",'rb') as f:
 7         network=pickle.load(f)
 8     return network
 9 
10 def predict(network,x):
11     W1,W2,W3=network['W1'],network['W2'],network['W3']
12     b1,b2,b3=network['b1'],network['b2'],network['b3']
13     a1 = np.dot(x, W1) + b1
14     z1 = sigmoid(a1)
15     a2 = np.dot(z1, W2) + b2
16     z2 = sigmoid(a2)
17     a3 = np.dot(z2, W3) + b3
18     y = softmax(a3)
19     return y

init_network() 会读入保存在 pickle 文件 sample_weight.pkl 中的学习到的权重参数 {8[因为之前我们假设学习已经完成,所以学习到的参数被保存下来。假设保存在 sample_weight.pkl 文件中,在推理阶段,我们直接加载这些已经学习到的参数。——译者注]}。这个文件中以字典变量的形式保存了权重和偏置参数。剩余的 2 个函数,和前面介绍的代码实现基本相同,无需再解释。现在,我们用这 3 个函数来实现神经网络的推理处理。然后,评价它的识别精度(accuracy),即能在多大程度上正确分类。

 1 x, t = get_data()
 2 network = init_network()
 3 
 4 accuracy_cnt = 0
 5 for i in range(len(x)):
 6     y = predict(network, x[i])
 7     p = np.argmax(y) # 获取概率最高的元素的索引
 8     if p == t[i]:
 9         accuracy_cnt += 1
10 
11 print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

首先获得 MNIST 数据集,生成网络。接着,用 for 语句逐一取出保存在 x 中的图像数据,用 predict() 函数进行分类。predict() 函数以 NumPy 数组的形式输出各个标签对应的概率。比如输出 [0.1, 0.3, 0.2, ..., 0.04]的数组,该数组表示“0”的概率为 0.1,“1”的概率为 0.3,等等。然后,我们取出这个概率列表中的最大值的索引(第几个元素的概率最高),作为预测结果。可以用 np.argmax(x) 函数取出数组中的最大值的索引,np.argmax(x) 将获取被赋给参数 x 的数组中的最大值元素的索引。最后,比较神经网络所预测的答案和正确解标签,将回答正确的概率作为识别精度。

下面我们进行基于批处理的代码实现。这里用粗体显示与之前的实现的不同之处。

 1 x, t = get_data()
 2 network = init_network()
 3 
 4 batch_size = 100 # 批数量
 5 accuracy_cnt = 0
 6 
 7 for i in range(0, len(x), batch_size):
 8     x_batch = x[i:i+batch_size]
 9     y_batch = predict(network, x_batch)
10     p = np.argmax(y_batch, axis=1)
11     accuracy_cnt += np.sum(p == t[i:i+batch_size])

我们来逐个解释粗体的代码部分。首先是 range() 函数。range() 函数若指定为 range(start, end),则会生成一个由 start 到 end-1 之间的整数构成的列表。若像 range(start, end, step) 这样指定 3 个整数,则生成的列表中的下一个元素会增加 step 指定的值。我们来看一个例子。

>>> list( range(0, 10) )
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> list( range(0, 10, 3) )
[0, 3, 6, 9]

 range() 函数生成的列表的基础上通过 x[i:i+batch_size] 从输入数据中抽出批数据x[i:i+batch_n] 会取出从第 i 个到第 i+batch_n 个之间的数据。本例中是像 x[0:100]x[100:200]……这样,从头开始以 100 为单位将数据提取为批数据。

然后,通过 argmax() 获取值最大的元素的索引。不过这里需要注意的是,我们给定了参数 axis=1。这指定了在 100 × 10 的数组中,沿着第 1 维方向(以第 1 维为轴)找到值最大的元素的索引(第 0 维对应第 1 个维度){9[矩阵的第 0 维是列方向,第 1 维是行方向。——译者注]}。这里也来看一个例子。

>>> x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6],
...     [0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])
>>> y = np.argmax(x, axis=1)
>>> print(y)
[1 2 1 0]

最后,我们比较一下以批为单位进行分类的结果和实际的答案。为此,需要在 NumPy 数组之间使用比较运算符(==)生成由 True/False 构成的布尔型数组,并计算 True 的个数。我们通过下面的例子进行确认。

>>> y = np.array([1, 2, 1, 0])
>>> t = np.array([1, 2, 0, 0])
>>> print(y==t)
[True True False True]
>>> np.sum(y==t)
3

仔细看过去,有很多内容之前不懂,看懂以后心里豁然开朗,不得不说真的是讲得很棒,清晰又容易理解。
 

猜你喜欢

转载自www.cnblogs.com/taoyuxin/p/11440494.html