使用python将MNIST数据转换为图片

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/m_buddy/article/details/80964194

1. mnist数据集

mnist数据集是一个很经典的数据集,该数据集在这个地方可以下载到。但是呢下载到的图片并不是图片格式的,而是一种二进制的东西,直接读起来很不直观。这就需要将其转换为图像格式。好在网站上给出了其数据的格式。

1.1 训练集数据

首先来看训练集的样本数据格式
这里写图片描述
可以看到在图像数据的前面4个字节定义了magic number、图像个数以及图像的长和宽。后面的也就是全部的数据格式了,偏置的话是28*28个字节。
对应的训练集标签其样本数据格式为
这里写图片描述
可以从上图中可以看出,除了前面两个字节代表magic number和标签数量之外以后全是以一个字节来表示一个标签值。

1.2 测试数据集

测试数据集和训练数据集是类似的结构,唯一的区别便是数据的数量不一样了而已。下图中是其结构
这里写图片描述
这里写图片描述

2. 转换代码

# -*- coding: utf-8 -*-
import numpy as np
import struct
from PIL import Image
import matplotlib.pyplot as plt
import os

class DataUtils(object):
    def __init__(self, filename=None, outpath=None):
        self._filename = filename
        self._outpath = outpath

        self._tag = '>'  # 大端格式
        self._twoBytes = 'II'
        self._fourBytes = 'IIII'
        self._pictureBytes = '784B'
        self._labelByte = '1B'
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

        self._imgNums = 0
        self._LabelNums = 0

    def getImage(self):
        """
        将MNIST的二进制文件转换成像素特征数据
        """
        binfile = open(self._filename, 'rb') #以二进制方式打开文件
        buf = binfile.read()
        binfile.close()
        index = 0
        numMagic, self._imgNums, numRows, numCols = struct.unpack_from(self._fourBytes2, buf, index)
        index += struct.calcsize(self._fourBytes)
        images = []
        print('image nums: %d' % self ._imgNums)
        for i in range(self._imgNums):
            imgVal = struct.unpack_from(self._pictureBytes2, buf, index)
            index += struct.calcsize(self._pictureBytes2)
            imgVal = list(imgVal)
            for j in range(len(imgVal)):
                if imgVal[j] > 1:
                    imgVal[j] = 1
            images.append(imgVal)
        return np.array(images), self._imgNums

    def getLabel(self):
        """
        将MNIST中label二进制文件转换成对应的label数字特征
        """
        binFile = open(self._filename, 'rb')
        buf = binFile.read()
        binFile.close()
        index = 0
        magic, self._LabelNums = struct.unpack_from(self._twoBytes2, buf, index)
        index += struct.calcsize(self._twoBytes2)
        labels = []
        for x in range(self._LabelNums):
            im = struct.unpack_from(self._labelByte2, buf, index)
            index += struct.calcsize(self._labelByte2)
            labels.append(im[0])
        return np.array(labels)

    def outImg(self, arrX, arrY, imgNums):
        """
        根据生成的特征和数字标号,输出png的图像
        """
        output_txt = self._outpath + '/img.txt'
        output_file = open(output_txt, 'a+')

        m, n = np.shape(arrX)
        # 每张图是28*28=784Byte
        for i in range(imgNums):
            img = np.array(arrX[i])
            img = img.reshape(28, 28)
            outfile = str(i) + "_" + str(arrY[i]) + ".jpg"
            print('saving file: %s' % outfile)

            txt_line = outfile + " " + str(arrY[i]) + '\n'
            output_file.write(txt_line)

            img = Image.fromarray(img, '1')
            img.save(self._outpath + '/' + outfile)
            print('saving file: %s; done' % outfile)

            # plt.figure()
            # plt.imshow(img, cmap='binary')  # 将图像黑白显示
            # plt.savefig(self._outpath + "/" + outfile)
        output_file.close()


if __name__ == '__main__':
    trainfile_X = '../Image/train-images-idx3-ubyte'
    trainfile_y = '../Image/train-labels-idx1-ubyte'
    testfile_X = '../Image/t10k-images-idx3-ubyte'
    testfile_y = '../Image/t10k-labels-idx1-ubyte'

    # 加载mnist数据集
    train_X, train_img_nums = DataUtils(filename=trainfile_X).getImage()
    train_y = DataUtils(filename=trainfile_y).getLabel()
    test_X, test_img_nums = DataUtils(testfile_X).getImage()
    test_y = DataUtils(testfile_y).getLabel()

    # 以下内容是将图像保存到本地文件中
    path_trainset = "../Image/imgs_train"
    path_testset = "../Image/imgs_test"
    if not os.path.exists(path_trainset):
        os.mkdir(path_trainset)
    if not os.path.exists(path_testset):
        os.mkdir(path_testset)
    DataUtils(outpath=path_trainset).outImg(train_X, train_y, train_img_nums)
    DataUtils(outpath=path_testset).outImg(test_X, test_y, test_img_nums)

猜你喜欢

转载自blog.csdn.net/m_buddy/article/details/80964194