cifar10数据格式以及读取方式

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DarrenXf/article/details/85471718

cifar10 数据网站
http://www.cs.toronto.edu/~kriz/cifar.html

读取下面的文件

CIFAR-10 binary version (suitable for C programs)	162 MB	c32a1d4ab5d03f1284b67883e8d87530

下载cifar-10-binary.tar.gz 到./data/文件夹下

cd ./data/

解压下载后的文件到./data/

tar -xzvf cifar-10-binary.tar.gz

会出现一个文件夹 ‘cifar-10-batches-bin’
以及文件夹下的这些文件

batches.meta.txt
data_batch_1.bin    
data_batch_2.bin    
data_batch_3.bin    
data_batch_4.bin    
data_batch_5.bin    
readme.html         
test_batch.bin

data_batch_1.bin 到 data_batch_5.bin 是训练集二进制文件
每个文件中有10000张图片和10000个标记
共有50000张图片和50000个标记
每个二进制文件由 10000 条记录依次排列组成,每条记录共 1 + 32x32x3 = 3073 个字节:第一个字节是标记(0–9),后面的 32x32x3 = 3072 个字节是图片。
图片中前 32x32 个字节是 red channel,接着 32x32 个字节是 green channel,然后 32x32 个字节是 blue channel,每个通道内按行优先存储。下一条记录依次类推。
test_batch.bin 是测试集文件

每个二进制文件是 30730000 个字节(10000 条记录 × 每条 3073 字节)。

下面是读取数据集的类文件
cifar10_dataset.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : cifar10_dataset.py
# Create date : 2018-12-24 19:58
# Modified date : 2018-12-31 16:21
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
#http://www.cs.toronto.edu/~kriz/cifar.html
import sys
import os
import struct
import numpy as np
import matplotlib.pyplot as plt

# pylint: disable=bad-continuation
# CIFAR-10 class names; the numeric label byte of each record indexes this list.
meta_lt = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# pylint: enable=bad-continuation
def create_path(path):
    """Create directory ``path`` (including parents) unless it is already a directory."""
    if os.path.isdir(path):
        return
    os.makedirs(path)

def open_file_with_full_name(full_path, open_type):
    """Open ``full_path`` with mode ``open_type``.

    Best effort: any error is printed and ``None`` is returned instead of
    raising, so callers must check for ``None``.
    """
    handle = None
    try:
        handle = open(full_path, open_type)
    except Exception as exc:
        print(exc)
    return handle

def get_file_full_name(path, name):
    """Join directory ``path`` and file ``name`` with exactly one "/".

    ``path`` may be given with or without a trailing "/".  An empty ``path``
    means the current directory, so ``name`` is returned unchanged.
    """
    # endswith() is safe on "" whereas path[-1] would raise IndexError.
    if not path or path.endswith("/"):
        return path + name
    return path + "/" + name

def open_file(path, name, open_type='a'):
    """Open ``name`` inside directory ``path``; returns ``None`` on failure."""
    return open_file_with_full_name(get_file_full_name(path, name), open_type)

def _get_file_header_data(file_obj, header_len, unpack_str):
    raw_header = file_obj.read(header_len)
    header_data = struct.unpack(unpack_str, raw_header)
    return header_data

def _read_a_image(file_object):
    raw_img = file_object.read(32 * 32)
    red_img = struct.unpack(">1024B", raw_img)

    raw_img = file_object.read(32 * 32)
    green_img = struct.unpack(">1024B", raw_img)

    raw_img = file_object.read(32 * 32)
    blue_img = struct.unpack(">1024B", raw_img)

    img = np.zeros(shape=(1024, 3))
    for i in range(1024):
        l = [red_img[i], green_img[i], blue_img[i]]
        img[i] = l
    img = img.reshape(32, 32, 3)
    img = img / 255.
    return img

def _read_one_image(file_object):
    raw_img = file_object.read(32 * 32 * 3)
    img = struct.unpack(">3072B", raw_img)
    return img

def _read_a_label(file_object):
    raw_label = file_object.read(1)
    label = struct.unpack(">B", raw_label)
    return label

def _get_image_full_name(path, label, count):
    """Return ``<path>/<class name>/<count>.jpg`` and create the class directory.

    ``label`` is the 1-tuple returned by ``_read_a_label``; its value indexes
    ``meta_lt``.  ``path`` may be given with or without a trailing "/".
    """
    # os.path.join inserts a separator only when needed, so "./img/train"
    # no longer collapses into "./img/trainairplane".
    class_dir = os.path.join(path, meta_lt[label[0]])
    create_path(class_dir)
    return "%s/%s.jpg" % (class_dir, count)

def save_image(image, full_path_name):
    """Draw ``image`` with ``plt.imshow`` and save the current figure to ``full_path_name``.

    The saved file is the whole matplotlib figure, not the bare pixel array.
    The figure is closed even if drawing or saving fails, so repeated calls
    do not accumulate open figures.
    """
    try:
        plt.imshow(image)
        plt.savefig(full_path_name)
    finally:
        plt.close()

class Cifar10Set(object):
    """Reader for the CIFAR-10 binary distribution.

    ``file_path`` is the directory that contains the extracted
    ``cifar-10-batches-bin`` folder (for example ``"./data/"``).
    """

    # Sub-directory (relative to file_path) holding the .bin batch files.
    _BATCH_DIR = "cifar-10-batches-bin"

    def __init__(self, file_path):
        super(Cifar10Set, self).__init__()
        self._train_file_list = [
            "data_batch_1.bin",
            "data_batch_2.bin",
            "data_batch_3.bin",
            "data_batch_4.bin",
            "data_batch_5.bin",
        ]
        self._test_file_list = ["test_batch.bin"]
        self.file_path = file_path

    def _read_file(self, file_name):
        """Open ``file_name`` (relative to ``file_path``) in binary mode; ``None`` on failure."""
        return open_file(self.file_path, file_name, open_type="rb")

    def _open_batch_file(self, file_name):
        """Open one batch file from the batch sub-directory; ``None`` on failure."""
        return self._read_file("%s/%s" % (self._BATCH_DIR, file_name))

    def _generate_a_batch(self, batch_size, file_list):
        """Yield ``(images, one_hot_labels, status)`` tuples indefinitely.

        ``images`` is a ``(batch_size, 3072)`` float array of raw pixel values
        and ``one_hot_labels`` a ``(batch_size, 10)`` int array.  Records are
        read sequentially through every file in ``file_list``.  ``status`` is
        True while batches are filled completely.  Once every file is
        exhausted, the remaining (possibly partially filled, zero-padded) batch
        is yielded with ``status`` False, and every later batch is all zeros
        with ``status`` False.
        """
        file_index = 0
        current = self._open_batch_file(file_list[0]) if file_list else None
        try:
            while True:
                images = np.zeros(shape=(batch_size, 32 * 32 * 3))
                labels = np.zeros(shape=(batch_size, 10))
                count = 0
                status = True
                while count < batch_size:
                    try:
                        label = _read_a_label(current)
                        image = _read_one_image(current)
                    except (struct.error, AttributeError):
                        # A short read (end of file) raises struct.error; a file
                        # that failed to open is None and raises AttributeError.
                        # Either way, release it and advance to the next file.
                        if current is not None:
                            current.close()
                            current = None
                        file_index += 1
                        if file_index >= len(file_list):
                            status = False
                            break
                        current = self._open_batch_file(file_list[file_index])
                        continue
                    images[count] = image
                    labels[count][label[0]] = 1
                    count += 1
                yield images, labels.astype(int), status
        finally:
            # Runs when the generator is closed or garbage-collected.
            if current is not None:
                current.close()

    def generator_images(self, file_list, path):
        """Decode every record of ``file_list`` and save it as an image file.

        Files are written below ``path`` in one sub-directory per class (see
        ``_get_image_full_name``) and numbered with a counter that starts at 1
        and runs across all files.
        """
        count = 1
        for name in file_list:
            file_name = "%s/%s" % (self._BATCH_DIR, name)
            train_file = self._read_file(file_name)
            try:
                while True:
                    try:
                        label = _read_a_label(train_file)
                        image = _read_a_image(train_file)
                        full_path_name = _get_image_full_name(path, label, count)
                        save_image(image, full_path_name)
                        print("file:%s count:%s" % (file_name, count))
                    except Exception as err:
                        # End of file (or an unreadable/unsaveable record):
                        # report it and continue with the next file.
                        print(err)
                        break
                    count += 1
            finally:
                if train_file is not None:
                    train_file.close()

    def generator_train_images(self, path):
        """Save every training-set image under ``path``."""
        self.generator_images(self._train_file_list, path)

    def generator_test_images(self, path):
        """Save every test-set image under ``path``."""
        self.generator_images(self._test_file_list, path)

    def get_train_data_generator(self, batch_size=128):
        """Return a batch generator over the five training files."""
        return self._generate_a_batch(batch_size, self._train_file_list)

    def get_test_data_generator(self, batch_size=128):
        """Return a batch generator over the test file."""
        return self._generate_a_batch(batch_size, self._test_file_list)

    def get_a_batch_data(self, data_generator):
        """Advance ``data_generator`` once and return ``(images, labels, status)``."""
        # The next() builtin works on both Python 2.6+ and Python 3, so no
        # version sniffing is needed.
        batch_img, batch_labels, status = next(data_generator)
        return batch_img, batch_labels, status

下面是main.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2018-12-23 16:53
# Modified date : 2018-12-31 15:37
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import cifar10_dataset

def test_cifar10_train_set():
    """Drain the training-set batch generator, printing each batch's status and labels."""
    dataset = cifar10_dataset.Cifar10Set("./data/")
    batches = dataset.get_train_data_generator(100)
    batch_no = 1
    status = True
    while status:
        _, batch_labels, status = dataset.get_a_batch_data(batches)
        print("count:%s status:%s " % (batch_no, status))
        if status:
            batch_no += 1
            print(str(batch_labels))

def test_cifar10_test_set():
    """Drain the test-set batch generator, printing each batch's status and labels."""
    dataset = cifar10_dataset.Cifar10Set("./data/")
    batches = dataset.get_test_data_generator(100)
    batch_no = 1
    status = True
    while status:
        _, batch_labels, status = dataset.get_a_batch_data(batches)
        print("count:%s status:%s " % (batch_no, status))
        if status:
            batch_no += 1
            print(str(batch_labels))

def test_generator_images():
    """Export the training images first, then the test images."""
    for export in (test_generator_train_images, test_generator_test_images):
        export()

def test_generator_train_images():
    """Save every training image from ./data/ as an image file under ./img/train/."""
    cifar10_dataset.Cifar10Set("./data/").generator_train_images("./img/train/")

def test_generator_test_images():
    """Save every test image from ./data/ as an image file under ./img/test/."""
    cifar10_dataset.Cifar10Set("./data/").generator_test_images("./img/test/")

def run():
    """Run every demo: batch-read the train and test sets, then export images."""
    # Indented with spaces only; mixing a tab with spaces here is an
    # IndentationError/TabError under Python 3.
    test_cifar10_train_set()
    test_cifar10_test_set()
    test_generator_images()


# Only run the demos when executed as a script, not on import.
if __name__ == "__main__":
    run()

上面的代码可以在python2 以及python3 运行.可以批量读取cifar10的所有训练以及测试数据.
而且通过matplotlib 把二进制的数据转换保存成了图片.使我们可以看到这些图片真实的样子.
下面是一些保存的图片.
第一张图片是青蛙。放大到这么大以后,即使是人也比较难看出来;但离远一点看,或者缩小以后,就有点像青蛙了。
在这里插入图片描述

也有些比较容易辨识的图片,比如第二张卡车图片,就比较容易看出来。
在这里插入图片描述

再放上来一些
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/DarrenXf/article/details/85471718