[深度学习]基于TensorFlow的基本深度学习模型

cifar10训练数据集下载

链接:https://pan.baidu.com/s/1Qlp2G5xlECM6dyvUivWnFg
提取码:s32t

前置配置

引入tensorflow库,和其他辅助库文件。安装方式为pip3 install tensorflow numpy pickle。详细过程不在这里描述。
在这里,训练和测试数据集文件放在该脚本的父文件夹中,因此按照实际情况来对CIFAR_DIR赋值,该参数将在后续过程中被引用。

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
import pickle
import numpy as np

CIFAR_DIR = "./../cifar-10-batches-py"
print(os.listdir(CIFAR_DIR))

整理图片训练集和测试集

数据集解析函数的介绍

在训练模型之前,最基础的是要分析数据集的结构,以及解析的方法。
在训练脚本中定义一个辅助函数。该函数的作用是,将传递过来的参数中的文件打开,并以字节的形式储存在data变量中。解析后的数据表中将存在两个子数据集,一个名为data,另一个名为label,data即单个图片信息本身,而label则是该数据的类别,比如一辆车。

def load_data(filename):
    """read data from data file."""
    with open(filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
        return data[b'data'], data[b'labels']

下面我们用实际数据跑一下该函数,使数据可视化,看看数据究竟是以怎样的形式得到处理的。
实现看一下这个训练集的文件组织形式。首先是5个data_batch,即5个数据组,还有一个test_batch用于对训练好的神经元进行测试。我们拿data_batch_1和test_batch做个实验,将其中的一部分数据解析出来。
在这里插入图片描述
下面是一段测试脚本,来读取data_batch_1中的第一个data和第一个label。
同样,我们还需要读取test_batch中的第一个data和第一个label。

import os
import pickle
import numpy as np

CIFAR_DIR = "./cifar-10-batches-py"
print(os.listdir(CIFAR_DIR))

def load_data(filename):
    """read data from data file."""
    with open(filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
        print(data)#将读取到的生文件直接显示出来,事实证明这里面改是个超大的字典
        return data[b'data'], data[b'labels']

train_filename = os.path.join(CIFAR_DIR, 'data_batch_1')
test_filename = os.path.join(CIFAR_DIR, 'test_batch')

train_data, train_labels = load_data(train_filename)
test_data, test_labels = load_data(test_filename)

print('data_batch_1 的数据个数:{}'.format(len(train_data)))
print('data_batch_1 的标签个数:{}'.format(len(train_labels)))
print('-'*30)
print('test_batch 的数据个数:{}'.format(len(test_data)))
print('test_batch 的标签个数:{}'.format(len(test_labels)))
print('-'*30)
print('train_batch_1 图片的像素:{}'.format(len(train_data[0])))
print('test_batch 图片的像素:{}'.format(len(test_data[0])))
print('-'*30)
print('第一个训练集数据及其标签')
print(train_labels[0])
print(train_data[0])
print('-'*30)
print('第一个测试集数据及其标签')
print(test_labels[0])
print(test_data[0])

首先来看一下没有处理的生数据集,它是一个字典,里面包含该组的信息,labels,data和每一个文件的名字。data中的每一个列表里面其实都代表了一个像素点的只,代表不同的颜色,这个列表中的所有像素点组成了一个小图片,例如第一个data中的信息如列表中所示[ 59, 43, 50, …, 140, 84, 72]。我们看看这个列表有多长。
在这里插入图片描述
接下来我们统计出这个数据集的具体信息,打印信息如下:
在这里插入图片描述
所以可以看到每个batch中都有10000个图片,每个图片的像素数量为3072=1024*3,这里需要澄清,因为图片都是由三原色RGB组成,因此图片实际由1024(32 x 32)个像素点组成,3倍的数据量是由于颜色信息造成的。

对数据集的精加工

代码的作用写在了每一行后面的注释里
这个类主要的功能是,将所有文件中的数据合入到同一对儿列表中,
还提供了_shuffle_data函数,如果有必要的话,可以启用该函数,在不破坏图片数据和对应标签的对应关系的前提下,将所有的列表内容顺序打乱,进而提高训练的有效性。
最后,这个类还提供了一个函数next_batch,该函数允许训练或者测试的过程中以特定的步进值进行训练然后执行测试,这样我们就可以看到阶段性的测试结果。

# tensorflow.Dataset.
class CifarData:
    def __init__(self, filenames, need_shuffle):#初始化函数,需要一个文件名称,和一个是否需要打乱顺序的置位符(选1则打乱)
        all_data = []#将所有图片数据放在这个列表中
        all_labels = []#将所有图片对应的标签放在这个列表中
        for filename in filenames:#例如,训练数据集中有5个文件,则依次循环这5个文件
            data, labels = load_data(filename)#在这里引用了上一节提供的文件加载函数,返回一个数据集 和 一个标签集
            all_data.append(data)#将所有5个文件中的图片数据添加到all_data列表中
            all_labels.append(labels)#将所有5个文件中的标签数据添加到all_labels列表中
        self._data = np.vstack(all_data)
        self._data = self._data / 127.5 - 1
        self._labels = np.hstack(all_labels)
        print(self._data.shape)
        print(self._labels.shape)
        
        self._num_examples = self._data.shape[0]#data shape是(50000,3702),所以self._data.shape[0]就代表50000,
        										#也就是说,_num_examples代表所有训练数据的个数
        self._need_shuffle = need_shuffle#如果置1,则启动打乱程序:_shuffle_data
        self._indicator = 0#这个值的作用是一个游标,每次执行完函数next_batch后,这个值就更新为当前所在的self._data列表的位置
        				#这很重要,因为我们需要判断50000个数据是否都训练到了,如果想再学习一轮以加固训练效果,则需要该游标进行判断
        				#具体的执行过程需要参考 next_batch函数 以及 训练和测试 章节
        if self._need_shuffle:
            self._shuffle_data()#这里需要注意,我们要对实例化后的数据集进行_shuffle_data
            
    def _shuffle_data(self):
        # [0,1,2,3,4,5] -> [5,3,2,4,0,1]
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]
    
    def next_batch(self, batch_size):
        """return batch_size examples as a batch."""
        end_indicator = self._indicator + batch_size#该示例的batch_size为20,也就是每次训练取列表中的20个图片数据,
        											#然而我们有50000个图片,这样的重复需要50000/20次,而每次完成训练后,把本实例中的self._indicator更新为本次训练的end_indicator 
		"""
		如果没啥幺蛾子,下面两段IF是不执行的。需要执行的话有两种可能:
		1. 训练数据集用完了,判断是否需要第二轮,或更多轮的训练
		2. 初始状态下就异常,八成是batch_size设置得太大
		剩余的数据不够batch_size个了,则直接放弃剩余数据,打乱顺序继续第二轮训练
		"""
        if end_indicator > self._num_examples:#如果本次训练end_indicator数值超过50000了,则
        									#说明我们学完了所有50000个图片,需要进行第二轮的学习
            if self._need_shuffle:
                self._shuffle_data()#第二轮学习为了有效,我们再次打乱50000个图片数据
                self._indicator = 0#游标再次置0,这样我们就可以愉快地开始第二轮训练了
                end_indicator = batch_size
            else:#如果当时初始化没有要求打乱,则证明我们不需要第二轮学习。这时中断测试,并抛出异常提示没有更多的训练数据了。
                raise Exception("have no more examples")#不用再看了,训练结束了,放学!
        if end_indicator > self._num_examples:#上一个IF语句中重置了end_indicator为batch_size值,
        									#如果此时还是比_num_examples大的话,则证明这个batch_size只是单纯地设置得太大了。
            raise Exception("batch size is larger than all examples")#抛出异常,中断测试,回去修改batch_size去...
        batch_data = self._data[self._indicator: end_indicator]#获取本次需要的20个图片数据,每次循环都会往后移动20个
        batch_labels = self._labels[self._indicator: end_indicator]#获取本次需要的20个标签数据,每次循环都会往后移动20个
        self._indicator = end_indicator#循环由外部引用来进行,每完成一次next_batch就更新一下self._indicator
        return batch_data, batch_labels#当然这个函数执行完就要交货了,也就是需要用于训练的20个数据,上面的一堆步骤都是前期审核和位置记录

train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i) for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]

train_data = CifarData(train_filenames, True)
test_data = CifarData(test_filenames, False)

创建神经元

x = tf.placeholder(tf.float32, [None, 3072])
# [None], eg: [0,5,6,3]
y = tf.placeholder(tf.int64, [None])

# (3072, 10)
w = tf.get_variable('w', [x.get_shape()[-1], 10],
                   initializer=tf.random_normal_initializer(0, 1))
# (10, )
b = tf.get_variable('b', [10],
                   initializer=tf.constant_initializer(0.0))

# [None, 3072] * [3072, 10] = [None, 10]
y_ = tf.matmul(x, w) + b

# mean square loss
"""
# course: 1 + e^x
# api: e^x / sum(e^x)
# [[0.01, 0.9, ..., 0.03], []]
p_y = tf.nn.softmax(y_)
# 5 -> [0,0,0,0,0,1,0,0,0,0]
y_one_hot = tf.one_hot(y, 10, dtype=tf.float32)
loss = tf.reduce_mean(tf.square(y_one_hot - p_y))
"""

loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)
# y_ -> sofmax
# y -> one_hot
# loss = ylogy_



# indices
predict = tf.argmax(y_, 1)
# [1,0,1,1,1,0,0,0]
correct_prediction = tf.equal(predict, y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))

with tf.name_scope('train_op'):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

训练和测试

init = tf.global_variables_initializer()
batch_size = 20
train_steps = 100000
test_steps = 100

# run 100k: 30.95%
with tf.Session() as sess:
    sess.run(init)
    for i in range(train_steps):
        batch_data, batch_labels = train_data.next_batch(batch_size)
        loss_val, acc_val, _ = sess.run(
            [loss, accuracy, train_op],
            feed_dict={
    
    
                x: batch_data,
                y: batch_labels})
        if (i+1) % 500 == 0:
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' 
                  % (i+1, loss_val, acc_val))
        if (i+1) % 5000 == 0:
            test_data = CifarData(test_filenames, False)
            all_test_acc_val = []
            for j in range(test_steps):
                test_batch_data, test_batch_labels \
                    = test_data.next_batch(batch_size)
                test_acc_val = sess.run(
                    [accuracy],
                    feed_dict = {
    
    
                        x: test_batch_data, 
                        y: test_batch_labels
                    })
                all_test_acc_val.append(test_acc_val)
            test_acc = np.mean(all_test_acc_val)
            print('[Test ] Step: %d, acc: %4.5f' % (i+1, test_acc))

猜你喜欢

转载自blog.csdn.net/qq_33868661/article/details/113881206