Reading the CIFAR-10 dataset with Python

Just reading the data took me, a beginner, quite some time, so here is a summary.

First, when you download the data, make sure you get the right version. I learned this the hard way: I had previously used a cifar10 module that auto-downloaded the dataset through someone else's package, and what I got were .bin files. Those are the binary version meant for C, so when I tried to load them with Python's pickle package it kept failing.
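(For what it's worth, the binary version isn't pickled at all: per the website, each record in data_batch_1.bin etc. is 3073 bytes, a 1-byte label followed by 3072 pixel bytes, 1024 per channel in R, G, B order. A minimal NumPy sketch, assuming one of those .bin files sits in the current directory:)

import numpy as np

# each record: 1 label byte + 32*32*3 pixel bytes = 3073 bytes
raw = np.fromfile('data_batch_1.bin', dtype=np.uint8).reshape(-1, 3073)
labels = raw[:, 0]                          # shape (10000,), values 0~9
images = raw[:, 1:].reshape(-1, 3, 32, 32)  # channel-first planes
images = images.transpose(0, 2, 3, 1)       # -> (N, 32, 32, 3) for plotting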

The dataset download page is: http://www.cs.toronto.edu/~kriz/cifar.html

Make sure you download the right version!!
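After downloading cifar-10-python.tar.gz, extract it; you should get a cifar-10-batches-py folder containing data_batch_1 through data_batch_5, test_batch, and batches.meta. A quick sketch, assuming the archive is in the current directory:

import tarfile

with tarfile.open('cifar-10-python.tar.gz') as tar:
    tar.extractall()  # creates ./cifar-10-batches-py/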

Once you have the right files, reading them is straightforward. This is the Python 3 loading function given on the website:

import pickle

def unpick(f):
    with open(f, 'rb') as fo:
        dic = pickle.load(fo, encoding='bytes')
        return dic

The result is a dictionary. Note that its keys are bytes rather than strings: to read the image data, for example, you index dic[b'data']; the b prefix is what makes a string literal a bytes object.
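As a quick sanity check (assuming the folder name the archive extracts to), loading one batch with the function above should show the following keys and shapes:

d = unpick('cifar-10-batches-py/data_batch_1')
print(d.keys())           # includes b'data', b'labels', b'batch_label', b'filenames'
print(d[b'data'].shape)   # (10000, 3072), dtype uint8
print(len(d[b'labels']))  # 10000 labels, each 0~9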

Also see this post: https://blog.csdn.net/u010165147/article/details/54176612

The code in that post has a few errors; I fixed them, tested the result myself, and the working version is below:

import pickle
import numpy as np
import os


class Cifar10DataReader():
    def __init__(self, cifar_folder, onehot=True):
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1
        self.read_next = True
        self.data_label_train = None
        self.data_label_test = None
        self.batch_index = 0

    def unpickle(self, f):
        with open(f, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')

    def next_train_data(self, batch_size=100):
        assert 10000 % batch_size == 0, "batch_size must divide 10000"
        rdata = None
        rlabel = None
        if self.read_next:
            f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
            print('read: %s' % f)
            dic_train = self.unpickle(f)
            self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels']))  # label 0~9
            np.random.shuffle(self.data_label_train)

            self.read_next = False
            if self.data_index == 5:
                self.data_index = 1
            else:
                self.data_index += 1

        if self.batch_index < len(self.data_label_train) // batch_size:
            datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            self.batch_index += 1
            rdata, rlabel = self._decode(datum, self.onehot)
        else:
            self.batch_index = 0
            self.read_next = True
            return self.next_train_data(batch_size=batch_size)

        return rdata, rlabel

    def _decode(self, datum, onehot):
        rdata = list()
        rlabel = list()
        if onehot:
            for d, l in datum:
                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
        else:
            for d, l in datum:
                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            data = dic_test[b'data']
            labels = dic_test[b'labels']  # 0~9
            self.data_label_test = list(zip(data, labels))

        np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]

        return self._decode(datum, self.onehot)


if __name__ == "__main__":
    dr = Cifar10DataReader(cifar_folder="E:\Tensorlow\Project\深度学习练习\cifar-10-batches-py\\")
    import matplotlib.pyplot as plt
    d, l = dr.next_test_data()
    print(np.shape(d), np.shape(l))
    plt.imshow(d[2])
    plt.show()
    # for i in range(600):
    #     d, l = dr.next_train_data(batch_size=100)
    #     print(np.shape(d), np.shape(l))

The resulting image (a 32×32 test sample shown with plt.imshow):
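A side note on the reshape inside _decode: each 3072-byte row stores the whole red channel first, then green, then blue, each as a row-major 32×32 plane. np.reshape(d, [3, 1024]).T lines the three channels up per pixel, and the final reshape to [32, 32, 3] restores the spatial layout. An equivalent, arguably more readable form:

img = np.reshape(d, [3, 32, 32]).transpose(1, 2, 0)  # (channel, row, col) -> (row, col, channel)

Both forms produce the same (32, 32, 3) uint8 array.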

------------------------------------------------------------------------------------------------------------------------------------------------------------

Beep beep! After testing the code above two days ago, I rewrote it yesterday and changed a few things. Parts of the old version were not written well and reuse was poor: for example, the data file was reloaded on every training batch, which seemed like a big efficiency loss. So I made some improvements and made the class more extensible: you can now call the file-reading functions directly and get the raw vector-form data, instead of only being able to get one batch of tensors. I also added comments (my skill is limited, so the comments may only make sense to me; please bear with me). The code is below.

import os
import random
import pickle
import numpy as np


class Cifar10DataReader():
    def __init__(self, cifar_file, one_hot=False, file_number=1):
        self.batch_index = 0  # index of the current batch within the file
        self.file_number = file_number  # index of the currently loaded data_batch file
        self.cifar_file = cifar_file  # directory containing the dataset
        self.one_hot = one_hot
        self.train_data = self.read_train_file()  # training data of one file: a list of 10000 (data, label) pairs
        self.test_data = self.read_test_data()  # the 10000 test examples

    # Read one pickled file, returning a dict
    def unpickle(self, file):
        dicts = None
        with open(file, 'rb') as fo:
            try:
                dicts = pickle.load(fo, encoding='bytes')
            except Exception as e:
                print('load error', e)
            return dicts

    # Read one training file, returning a list of (data, label) pairs
    def read_train_file(self, files=''):
        if files:
            files = os.path.join(self.cifar_file, files)
        else:
            files = os.path.join(self.cifar_file, 'data_batch_%d' % self.file_number)
        dict_train = self.unpickle(files)
        train_data = list(zip(dict_train[b'data'], dict_train[b'labels']))  # pair each image with its label
        np.random.shuffle(train_data)
        print('Loaded training file: data_batch_%d' % self.file_number)
        return train_data

    # Read the test set
    def read_test_data(self):
        files = os.path.join(self.cifar_file, 'test_batch')
        dict_test = self.unpickle(files)
        test_data = list(zip(dict_test[b'data'], dict_test[b'labels']))  # pair each image with its label
        print('Loaded the test set')
        return test_data

    # Decode raw vectors into image tensors, returning data and labels separately
    def encodedata(self, datum):
        rdatas = list()
        rlabels = list()
        for d, l in datum:
            rdatas.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if self.one_hot:
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabels.append(hot)
            else:
                rlabels.append(l)
        return rdatas, rlabels

    # Return batch_size images and labels
    def next_train_data(self, batch_size=100):
        assert 10000 % batch_size == 0, 'batch_size must divide 10000!'  # each file holds 10000 examples

        # take one batch_size slice of the currently loaded file
        if self.batch_index < len(self.train_data) // batch_size:  # still within the current file?
            datum = self.train_data[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            datas, labels = self.encodedata(datum)
            self.batch_index += 1
        else:  # current file exhausted, load the next one
            self.batch_index = 0
            if self.file_number == 5:
                self.file_number = 1
            else:
                self.file_number += 1
            self.train_data = self.read_train_file()
            return self.next_train_data(batch_size=batch_size)
        return datas, labels

    # Randomly sample batch_size test examples
    def next_test_data(self, batch_size=100):
        datum = random.sample(self.test_data, batch_size)  # random sample without replacement
        datas, labels = self.encodedata(datum)
        return datas, labels


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    Cifar10 = Cifar10DataReader(r'E:\Tensorlow\Project\深度学习练习\cifar-10-batches-py', one_hot=True)
    d, l = Cifar10.next_train_data()
    print(len(d))
    print(d)
    plt.imshow(d[0])
    plt.show()
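As a usage sketch with the class above, one epoch over the training set is 5 files × 10000 images, i.e. 500 batches at batch_size=100:

# one full epoch: 5 files x 10000 images = 500 batches of 100
for step in range(500):
    datas, labels = Cifar10.next_train_data(batch_size=100)
    # feed datas and labels to your model here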


Reposted from blog.csdn.net/qq_26593695/article/details/88563487