使用pycaffe进行mnist手写数字识别

官方引导教程

运行环境

win10+python3.5+gpu版本的caffe

步骤

  1. 下载数据集
  2. 将数据集转为lmdb
  3. 训练
  4. 测试训练的出来的模型

下载数据集

mnist官网下载下面4个文件

t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
train-images.idx3-ubyte
train-labels.idx1-ubyte

它们的结构在mnist网站上有说明

训练图片集的打标值文件 (train-labels-idx1-ubyte):

[offset]  [type]               [value]                     [description] 
0000     32 bit integer   0x00000801(2049) magic number (MSB first) 
0004     32 bit integer   10000                     标签值总数 
0008     unsigned byte  ??                          标签值
0009     unsigned byte  ??                          标签值
........ 
xxxx     unsigned byte   ??                          标签值

训练图片集文件 (train-images-idx3-ubyte):

[offset]   [type]               [value]                     [description] 
0000     32 bit integer   0x00000803(2051)  magic number 
0004     32 bit integer   10000                      图片总数 
0008     32 bit integer    28                           单张图片的长度像素值数量
0012     32 bit integer    28                           单张图片的高度像素值数量
0016     unsigned byte   ??                          单像素值
0017     unsigned byte   ??                          单像素值 
........ 
xxxx     unsigned byte    ??                          单像素值

测试图片集的打标值文件 (t10k-labels-idx1-ubyte):

[offset]  [type]               [value]                     [description] 
0000     32 bit integer   0x00000801(2049) magic number (MSB first) 
0004     32 bit integer   10000                     标签值总数 
0008     unsigned byte  ??                          标签值
0009     unsigned byte  ??                          标签值
........ 
xxxx     unsigned byte   ??                          标签值

标签值的范围是0-9

测试图片集文件 (t10k-images-idx3-ubyte):

[offset]   [type]               [value]                     [description] 
0000     32 bit integer   0x00000803(2051)  magic number 
0004     32 bit integer   10000                      图片总数 
0008     32 bit integer    28                           单张图片的长度像素值数量
0012     32 bit integer    28                           单张图片的高度像素值数量
0016     unsigned byte   ??                          单像素值
0017     unsigned byte   ??                          单像素值 
........ 
xxxx     unsigned byte    ??                          单像素值

数据集转换为lmdb

下载的数据集有两对,一对是训练数据图片集和它对应的标签值, 另一对是测试图片集和它对应的标签值,这两对文件的结构是一样的,因此转换为lmdb文件时,可以使用同样的方法

def orgin_to_lmdb(image_file, label_file, lmdb_save_path, force_update=False):
    mean_file = '{}.binaryproto'.format(lmdb_save_path)

    if os.path.exists(mean_file) and os.path.exists(lmdb_save_path) and force_update == False:
        return

    try:
        shutil.rmtree(lmdb_save_path)
    except:
        pass
    try:
        shutil.rmtree(mean_file)
    except:
        pass

    with open(image_file, 'rb') as image_f:
        with open(label_file, 'rb') as label_f:
            # 读取标签文件头的4个整型
            size = struct.calcsize('>2I')
            magic, num_items = struct.unpack_from('>2I', label_f.read(size))
            print(magic, num_items)

            # 读取图片文件头的4个整型
            size = struct.calcsize('>4I')
            magic, num_images, num_rows, num_columns = struct.unpack_from('>4I', image_f.read(size))
            print(magic, num_images, num_rows, num_columns)

            map_size = num_images*num_rows*num_columns * 1.5

            # 遍历所有图片,将文件列表写入到lmdb中
            with lmdb.open(lmdb_save_path,map_size=map_size) as in_db:
                with in_db.begin(write=True) as in_txn:
                    im_size = num_rows * num_columns
                    label_size = struct.calcsize('>B')
                    im_idx = 0
                    while im_idx < num_images:
                        img_item = struct.unpack_from('>B', label_f.read(label_size))[0]
                        img_buf = image_f.read(im_size)

                        datum = caffe_pb2.Datum(
                            channels=1,  # 数据集里面的图片是灰度图,因此通道数设置为1
                            width=num_columns,
                            height=num_rows,
                            label=int(img_item),
                            data=img_buf
                        )
                        in_txn.put('{:0>8d}'.format(im_idx).encode('utf8'), datum.SerializeToString())
                        im_idx += 1

    # 生成mean文件
    cmd = '{0} {1} {2}'.format(compute_image_mean, lmdb_save_path, mean_file)
    print(cmd)
    os.system(cmd)

以下代码可以打开lmdb查看第一张图片

# 查看lmdb的第一张图片
def show_lmdb_first_image(lmdb_save_path):
    with lmdb.open(lmdb_save_path, readonly=True) as lmdb_env:
        lmdb_txn = lmdb_env.begin()
        lmdb_cursor = lmdb_txn.cursor()
        datum = caffe_pb2.Datum()

        lmdb_cursor.first()
        key, value = lmdb_cursor.item()
        datum.ParseFromString(value)

        label = datum.label
        data = caffe.io.datum_to_array(datum)
        print(label, datum.channels, data.shape)
        image = data.transpose(1, 2, 0)
        cv2.imshow('cv2.png', image)
        cv2.waitKey(0)

        cv2.destroyAllWindows()

使用数据集进行训练

使用caffe代码目录下的examples\mnist\lenet_solver.prototxtexamples\mnist\lenet_train_test.prototxt, 需要修改lenet_solver.prototxt中的网络文件地址为新的lenet_train_test.prototxt 需要修改lenet_train_test.prototxt的数据层为刚才生成的lmdb地址

        solver = caffe.SGDSolver('lenet_solver.prototxt')
        solver.solve()

完成之后会产生两个模型文件lenet_iter_5000.caffemodellenet_iter_10000.caffemodel

测试训练的出来的模型

需要先生成一个网络配置文件, 一般是改动训练时用的网络配置文件,这里直接使用examples\mnist\lenet.prototxt

        net = caffe.Net(
            'lenet.prototxt', # 网络配置文件
            caffe.TEST,
            weights='lenet_iter_10000.caffemodel'  # 训练产生的模型
        )

        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
        transformer.set_transpose('data', (2,0,1))
        transformer.set_raw_scale('data', 255)
        # transformer.set_channel_swap('data', (2, 1, 0))  # minist用的是灰度图 channel只有1,因此无需转换

        # 因为minist的channel是1, 所以需要转为灰度图color=False
        im = caffe.io.load_image('3.jpg', color=False)  # 打开测试图片
        net.blobs['data'].data[0] = transformer.preprocess('data', im)
        res = net.forward()
        print(res['prob'].argmax())

测试图片是用windows测试工具写的几个数字,需要黑底白字,并且图片大小要改为28*28

  • 输入图片说明
  • 输入图片说明
  • 输入图片说明

有的识别会出错。。。

猜你喜欢

转载自my.oschina.net/u/111188/blog/1615636
今日推荐