背景信息
MNIST数据集简介
MNIST数据集是从 NIST 的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。由于SD-3是由美国人口调查局的员工进行标注,SD-1是由美国高中生进行标注,因此SD-3比SD-1更干净也更容易识别。Yann LeCun等人从SD-1和SD-3中各取一半作为MNIST的训练集(60000条数据)和测试集(10000条数据),其中训练集来自250位不同的标注员,此外还保证了训练集和测试集的标注员是不完全相同的。
本文目的
本文实现MNIST数据集和标签的读取,并转化为Numpy的数组进行输出。
前提条件
以完成MNIST数据集的下载,如下所示:
root@5e3ac72a80f4:~/.cache/paddle/dataset/mnist# ll total 11344 drwxr-xr-x 2 root root 4096 Mar 12 03:22 ./ drwxr-xr-x 13 root root 4096 Apr 1 07:01 ../ -rw-r--r-- 1 root root 1648877 Mar 12 03:22 t10k-images-idx3-ubyte.gz -rw-r--r-- 1 root root 4542 Mar 12 03:22 t10k-labels-idx1-ubyte.gz -rw-r--r-- 1 root root 9912422 Mar 12 03:22 train-images-idx3-ubyte.gz -rw-r--r-- 1 root root 28881 Mar 12 03:22 train-labels-idx1-ubyte.gz
详细代码
#导入所需包 import subprocess import numpy import platform
#定义变量 image_filename='/root/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz' label_filename='/root/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz' buffer_size=100
# 定义函数读取image,并保存为数组 def get_images(image_filename, buffer_size): m = subprocess.Popen(['zcat', image_filename], stdout=subprocess.PIPE) m.stdout.read(16) # skip some magic bytes images=numpy.fromfile(m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape((buffer_size, 28 * 28)).astype('float32') images = images / 255.0 * 2.0 - 1.0 m.terminate() return images
# 定义函数读取labels,并保存为数组 def get_labels(label_filename, buffer_size): l = subprocess.Popen(['zcat', label_filename], stdout=subprocess.PIPE) l.stdout.read(8) # skip some magic bytes labels = numpy.fromfile(l.stdout, 'ubyte', count=buffer_size).astype("int") #print labels.shape l.terminate() return labels
# 创建Paddle中使用的def reader_create(image_filename, label_filename, buffer_size) def mnist_reader(image_filename, label_filename, buffer_size): def reader(): images=get_images(image_filename, buffer_size) labels=get_labels(label_filename, buffer_size) for i in xrange(buffer_size): yield images[i,:], int(labels[i]) return reader
查看结果: