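Extracting Images and Labels from the MNIST Dataset

The MNIST Dataset

  • Introduction

    The MNIST dataset (http://yann.lecun.com/exdb/mnist/) is a well-known dataset for handwritten digit classification. It consists of the following four files: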

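    • Training set images: train-images.idx3-ubyte, which yields 60000 training images after processing
    • Training set labels: train-labels.idx1-ubyte, which yields 60000 training labels after processing
    • Test set images: t10k-images.idx3-ubyte, which yields 10000 test images after processing
    • Test set labels: t10k-labels.idx1-ubyte, which yields 10000 test labels after processing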
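
    Each of the four files begins with a small big-endian header, and reading it is a quick way to confirm which kind of file you have. The following is a minimal sketch, not part of the original script: the helper name read_idx_header is hypothetical, and the two byte strings are synthetic headers standing in for the real files. The magic numbers 2051 (image file) and 2049 (label file) come from the IDX format description on the MNIST page.

```python
import struct

# Magic numbers defined by the IDX format used for MNIST:
# 2051 (0x00000803) marks an image file, 2049 (0x00000801) a label file.
IMAGE_MAGIC = 2051
LABEL_MAGIC = 2049


def read_idx_header(buf):
    """Return (magic, count, rows, cols); rows and cols are None for label files."""
    # Every IDX file starts with two big-endian unsigned 32-bit integers.
    magic, count = struct.unpack_from('>II', buf, 0)
    if magic == IMAGE_MAGIC:
        # Image files carry two more integers: the image height and width.
        rows, cols = struct.unpack_from('>II', buf, 8)
        return magic, count, rows, cols
    if magic == LABEL_MAGIC:
        return magic, count, None, None
    raise ValueError('unexpected magic number: %d' % magic)


# Synthetic headers (assumption: used here so the sketch runs without the data files).
fake_image_header = struct.pack('>IIII', IMAGE_MAGIC, 60000, 28, 28)
fake_label_header = struct.pack('>II', LABEL_MAGIC, 60000)

print(read_idx_header(fake_image_header))  # (2051, 60000, 28, 28)
print(read_idx_header(fake_label_header))  # (2049, 60000, None, None)
```

    With the real files, the same function could be given the first 16 bytes read from disk in 'rb' mode.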
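  • Processing

    When first working with MNIST, the data must be converted from its binary storage format into individual images and labels. The procedure is as follows: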

    # -*- coding: utf-8 -*-
    """
    Created on Sat Aug 11 10:06:46 2018
    
    @author: yzhang
    """
    
    import os
    import struct
    import numpy as np
    from PIL import Image
    
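    def load_mnist_image(path, filename, kind='train'):
        """Decode an IDX3 image file and save each image as a BMP in ./train or ./test."""
        if kind not in ('train', 'test'):
            raise ValueError("kind must be 'train' or 'test'")
        full_name = os.path.join(path, filename)
        # Read the whole file at once; the with-block closes the handle afterwards.
        with open(full_name, 'rb') as fp:
            buf = fp.read()
        index = 0
        # Header: four big-endian unsigned 32-bit integers
        # (magic number, image count, rows, cols).
        magic, num, rows, cols = struct.unpack_from('>IIII', buf, index)
        index += struct.calcsize('>IIII')
    
        # Create the output directory once, before the loop.
        out_dir = './' + kind
        os.makedirs(out_dir, exist_ok=True)
    
        # Each image is rows * cols unsigned bytes (28 * 28 = 784 for MNIST).
        fmt = '>%dB' % (rows * cols)
        for image in range(num):
            im = struct.unpack_from(fmt, buf, index)
            index += struct.calcsize(fmt)
            im = np.array(im, dtype='uint8').reshape(rows, cols)
            im = Image.fromarray(im)
            # Files are named train_0.bmp, train_1.bmp, ... (or test_*.bmp).
            im.save('%s/%s_%s.bmp' % (out_dir, kind, image), 'bmp')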
    
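    def load_mnist_label(path, filename, kind='train'):
        """Decode an IDX1 label file, save the labels as CSV, and return them."""
        full_name = os.path.join(path, filename)
        # Read the whole file at once; the with-block closes the handle afterwards.
        with open(full_name, 'rb') as fp:
            buf = fp.read()
        index = 0
        # Header: two big-endian unsigned 32-bit integers (magic number, label count).
        magic, num = struct.unpack_from('>II', buf, index)
        index += struct.calcsize('>II')
        Labels = np.zeros(num)
    
        # Each label is a single unsigned byte holding a digit from 0 to 9.
        for i in range(num):
            Labels[i] = struct.unpack_from('>B', buf, index)[0]
            index += struct.calcsize('>B')
    
        # Write one label per line, in the same order as the images.
        if kind == 'train':
            np.savetxt('./train_labels.csv', Labels, fmt='%i', delimiter=',')
        if kind == 'test':
            np.savetxt('./test_labels.csv', Labels, fmt='%i', delimiter=',')
    
        return Labels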
    
    if __name__ == '__main__':
        # Directory holding the four decompressed MNIST files; adjust to your setup.
        path = 'D:/Project/an_python/Minst/'
        # Training set: images, then labels.
        train_images = 'train-images.idx3-ubyte'
        load_mnist_image(path, train_images, 'train')
        train_labels = 'train-labels.idx1-ubyte'
        load_mnist_label(path, train_labels, 'train')
        # Test set: images, then labels.
        test_images = 't10k-images.idx3-ubyte'
        load_mnist_image(path, test_images, 'test')
        test_labels = 't10k-labels.idx1-ubyte'
        load_mnist_label(path, test_labels, 'test')
    

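    After running the Python program above, the extracted images are saved under ./train and ./test and the labels in train_labels.csv and test_labels.csv. These paths are relative to the current working directory, so the output lands in D:/Project/an_python/Minst only when the script is run from that directory.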
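
    Unpacking one image at a time with struct works but is slow for 60000 images. As an alternative, here is a hedged sketch that decodes the whole pixel block in one call with np.frombuffer. It is not the original author's method: the function names decode_idx_images and decode_idx_labels are hypothetical, and the inputs below are tiny synthetic byte strings in the same layout (two 2x2 images, two labels) so that the example runs without the real files.

```python
import struct
import numpy as np


def decode_idx_images(buf):
    """Decode IDX3 bytes into a (count, rows, cols) uint8 array."""
    magic, count, rows, cols = struct.unpack_from('>IIII', buf, 0)
    if magic != 2051:
        raise ValueError('not an IDX3 image file: magic=%d' % magic)
    # Pixel data starts right after the 16-byte header.
    pixels = np.frombuffer(buf, dtype=np.uint8,
                           count=count * rows * cols, offset=16)
    return pixels.reshape(count, rows, cols)


def decode_idx_labels(buf):
    """Decode IDX1 bytes into a (count,) uint8 array."""
    magic, count = struct.unpack_from('>II', buf, 0)
    if magic != 2049:
        raise ValueError('not an IDX1 label file: magic=%d' % magic)
    # Label data starts right after the 8-byte header.
    return np.frombuffer(buf, dtype=np.uint8, count=count, offset=8)


# Synthetic data (assumption): two 2x2 images with pixel values 0..7, labels 5 and 0.
fake_images = struct.pack('>IIII', 2051, 2, 2, 2) + bytes(range(8))
fake_labels = struct.pack('>II', 2049, 2) + bytes([5, 0])

images = decode_idx_images(fake_images)
labels = decode_idx_labels(fake_labels)
print(images.shape)      # (2, 2, 2)
print(labels.tolist())   # [5, 0]
```

    The arrays returned by np.frombuffer are read-only views of the input bytes; call .copy() on them if they need to be modified.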

Reposted from blog.csdn.net/yzhang6_10/article/details/81585572