卜若的代码笔记-Deeplerning-第三章:浅层手写数字识别网络-续>批次喂数据

1 我们需要在数据源里面提供两个接口获得批次的接口

在此之前,我们还是来看一下我们的结构:

执行上下文Context,网络Net,数据DataAnalyser

我们修改一下数据源的供给,采用随机供给方式

然后我们在,我们现在分为50批次:

 

看一下训练效果:

直接喂训练集的数据和标签,其结果为

 我们现在尝试通过测试集去测试,看下准确率:

小小的修改一下数据源:

import sys
sys.path.append("..")
import web
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import time
from scipy import misc
import os
import json
import Dater as dater
from PIL import Image
import random
class getDC():


    def __init__(self):


        self.dc = []
        self.label = []
        mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

        trainSet = mnist.train.images
        labelSet = mnist.train.labels

        self.testData = []
        self.testLabel = []

        for i in range(50000):
            img = trainSet[i]
            img = img.reshape(1, 28*28)
            img = img * 255
            self.dc.append(img)
            self.label.append(labelSet[i])
            pass
        self.dc = np.array(self.dc)
        self.dc = self.dc.reshape([50000,784])
        self.label=np.array(self.label)
        self.label =   self.label.reshape([50000,10])

        for i in range(10000):
            tempImg = mnist.test.images[i]
            tempImg = tempImg.reshape(1,28*28)
            tempImg = tempImg*255
            self.testData.append(tempImg)
            self.testLabel.append(mnist.test.labels[i])

            pass

        self.testData = np.array(self.testData)
        self.testData = self.testData.reshape([10000, 784])
        self.testLabel = np.array(self.testLabel)
        self.testLabel = self.testLabel.reshape([10000, 10])


        pass

    def getBatch(self,scale):

        tempXBatch = []
        index = random.randint(0,(50000/scale) -1)
        for i in range(scale):
            tempXBatch.append(self.dc[index * scale + i])
            pass
        tempXBatch = np.array(tempXBatch)
        tempXBatch = tempXBatch.reshape([scale,784])

        tempXBatch2 = []

        for i in range(scale):
            tempXBatch2.append(self.label[index * scale + i])
            pass
        tempXBatch2 = np.array(tempXBatch2)
        tempXBatch2 = tempXBatch2.reshape([scale, 10])

        return tempXBatch,tempXBatch2







# //dc = getDC()
# #
# img = dc.dc[0].reshape([28,28])
# plt.figure(0)
# plt.imshow(img)
# plt.show()
# print(img)


主要是将测试数据和标签弄进去:

比较核心的是这一串代码,其实很简单,就是获得数据之后,从新组织一下数据结构就ok了!

我们来看一下准确率:

 

能够达到84.6%,其实这个已经算是不错了,当然,依旧可以提高 。

我们之前的浅层神经网络是不是只有一层W,现在我们再添加一层W2。

这个问题我们会在下一章进行讨论。

发布了202 篇原创文章 · 获赞 10 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_37080133/article/details/102325396