手写中文文章识别(3)——data feeding

手写中文文章识别(1)——问题描述

https://blog.csdn.net/foreseerwang/article/details/80833749

手写中文文章识别(2)——样本集构建

https://blog.csdn.net/foreseerwang/article/details/80842498


前文提到,样本集构建会形成文章文件(.code/.char/.len)和手写中文图片文件,本文使用这些文件形成模型训练及validation所需的输入数据,即dataset文件。关于tensorflow Dataset使用方法,可参见之前的两篇文章:

https://blog.csdn.net/foreseerwang/article/details/80170210

https://blog.csdn.net/foreseerwang/article/details/80572182


使用Dataset的data feeding代码如下:

import os
import random
import numpy as np
import tensorflow as tf
import io
try:
    import cPickle as pickle
except ImportError:
    import pickle

# 图片augment设置参数
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")

# 图片及训练相关参数
tf.app.flags.DEFINE_integer('image_size', 64, "手写中文图片边长,方形。")
tf.app.flags.DEFINE_integer('image_channel', 1, "手写中文图片通道数,1代表黑白图片")
tf.app.flags.DEFINE_boolean('gray', True, "是否修改为灰度")
tf.app.flags.DEFINE_integer('shuffle_size', 100, '数据集的shuffle缓存大小')
tf.app.flags.DEFINE_integer('sent_len_max', 10,
                            "每句话的最长尺寸,超过这个长度时,逐段形成数据集;同时,训练集中的每句话也都需要padded到这个长度")
tf.app.flags.DEFINE_integer('batch_size', 3, '生成的数据集batch size')
tf.app.flags.DEFINE_integer('eval_steps', 2, 'validation间隔,即每eval_steps次训练batch进行一次validation')

# 是否使用短字典及相关配置。完整字库长度7356,短字库长度4000,已可以涵盖约95%以上的常见字
# 本程序使用完整字典,可不用关心此处的配置
# 抱歉,受限于版权问题,无法上传手写中文字库
tf.app.flags.DEFINE_boolean('short_dict', False, 'whether to use short dict')
tf.app.flags.DEFINE_integer('charset_size_long', 7356, "Long character dictionary size")
tf.app.flags.DEFINE_integer('charset_size_short', 4000, "Short character dictionary size")
tf.app.flags.DEFINE_string('char_dict_long', './article_recog/char_dict_gbk_rvs20180518',
                           'The reversed long character dictionary: code-->char')
tf.app.flags.DEFINE_string('char_dict_short', './article_recog/char_dict_4000_rvs',
                           'The reversed short character dictionary: code-->char')

# 数据储存目录
# 抱歉,受限于版权问题,这部分数据不能上传
tf.app.flags.DEFINE_string('sample_dir', './sample', '样本集存储目录')
tf.app.flags.DEFINE_string('train_hwdb_dir', './hwdb/hwdb_by_char_Train_gbk', '训练用hwdb手写中文图库存储目录')
tf.app.flags.DEFINE_string('test_hwdb_dir', './hwdb/hwdb_by_char_Test_gbk', '测试用hwdb手写中文图库存储目录')

FLAGS = tf.app.flags.FLAGS

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# 根据是否使用短字典,选择不同的字典文件和尺寸
if FLAGS.short_dict:
    charset_size = FLAGS.charset_size_short
    char_dict_file = FLAGS.char_dict_short
else:
    charset_size = FLAGS.charset_size_long
    char_dict_file = FLAGS.char_dict_long



class DataIterator:

    def __init__(self, filenames, istrain=True):

        self.filenames = filenames

        if istrain:
            self.hwdb_dir=FLAGS.train_hwdb_dir
        else:
            self.hwdb_dir=FLAGS.test_hwdb_dir

    # 图片augment处理子程序
    # 只需要针对训练集做augment
    # 应用中发现,这里的augment还远远不够...
    @staticmethod
    def data_augmentation(images, labels, lengths, masks):
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images, labels, lengths, masks

    # 如果使用短字典,需要对输出的数据集做编码转换。
    @staticmethod
    def dataset_convert(char_vec,char_code,sent_len,mask):
        if FLAGS.short_dict:
            char_dict_dir = './article_recog'
            char_dict_map_filename = 'char_dict_map'  # long num --> short num
            char_dict_map_fullfilename = os.path.join(char_dict_dir, char_dict_map_filename)
            fr = open(char_dict_map_fullfilename, 'rb')
            char_dict_map = pickle.load(fr)
            fr.close()

            converted_char_code = []
            for ii in range(sent_len):
                code_long = char_code[ii]
                try:
                    code = char_dict_map[code_long]
                except KeyError:
                    code = charset_size - 1
                converted_char_code.append(code)
            return (np.asarray(char_vec, dtype='float32'),
                    np.asarray(converted_char_code, dtype='int32'),
                    sent_len,
                    np.asarray(mask, dtype='bool'))
        else:
            return (np.asarray(char_vec, dtype='float32'),
                    np.asarray(char_code, dtype='int32'),
                    sent_len,
                    np.asarray(mask, dtype='bool'))

    
    # 读取hwdb手写中文图片库子程序
    # hwdb手写中文图片库按照文字保存,所有书写人写的同一个中文文字,按顺序保存到同一个文件中
    # 读取时需要随机选择其中一个书写人的图片数据输出
    @staticmethod
    def read_hwdb(filename):
        with open(filename, 'rb') as f:
            char_file = np.fromfile(f, dtype='uint8')

        img_byte = FLAGS.image_size*FLAGS.image_size*FLAGS.image_channel

        char_file_len = len(char_file)
        if (char_file_len % (img_byte)) != 0:
            raise ValueError("Characters file %s error" % filename)

        char_num = char_file_len // (img_byte)
        char_idx = np.random.randint(char_num)
        char_mat_uint8 = char_file[char_idx*img_byte:(char_idx+1)*img_byte]
        char_mat = char_mat_uint8.astype(np.float32)

        depth_major = char_mat.reshape([FLAGS.image_channel, FLAGS.image_size, FLAGS.image_size])
        image = depth_major.transpose([1, 2, 0])

        return image

    # dataset generator子程
    # 全python代码,因此可以非常灵活的进行各种文档读取即数据转换处理
    # 读取字符(.char)、编码(.code)、长度(.len)文件,处理并输出dataset所需的image/label/length/mask数组
    # image:本句话里所有文字对应的图片,[length, image_size, image_size, image_channel]
    # label:本句话里所有的图片对应的字符编码,[length]
    # length:本句话的字符长度,标量,非矩阵
    # mask: 本句话长度所对应的True数组,[length],所有元素均为True
    # 后来发现mask数据可以不用,可以使用tf.sequence_mask函数随时从len生成mask
    def file_readline(self):
        for filename in self.filenames:
            char_file = os.path.join(FLAGS.sample_dir, filename + '.char')
            code_file = os.path.join(FLAGS.sample_dir, filename + '.code')
            len_file = os.path.join(FLAGS.sample_dir, filename + '.len')
            frchar = io.open(char_file, 'r', encoding='utf-8')
            frcode = io.open(code_file, 'r', encoding='utf-8')
            frlen = io.open(len_file, 'r', encoding='utf-8')

            try:
                while True:
                    chars = frchar.readline()
                    codes = frcode.readline()
                    len_in_char = frlen.readline()
                    sent_len = int(len_in_char)

                    char_list = chars[:-1]
                    code_list = codes.split()

                    if len(char_list) != sent_len or len(code_list) != sent_len:
                        print(sent_len)
                        print(len(char_list))
                        print(len(code_list))
                        raise ValueError("Characters or labels length error")

                    if sent_len>FLAGS.sent_len_max:
                        while sent_len > FLAGS.sent_len_max:
                            char_code = []
                            char_vec = []
                            for code in code_list[-sent_len:-sent_len+FLAGS.sent_len_max]:
                                char_code.append(int(code))

                                char_filename = os.path.join(self.hwdb_dir, code + '.char')
                                char_vec.append(self.read_hwdb(char_filename))

                            mask = np.ones(FLAGS.sent_len_max, dtype='bool')
                            sent_len -= FLAGS.sent_len_max

                            yield self.dataset_convert(char_vec,char_code,FLAGS.sent_len_max,mask)

                        char_code = []
                        char_vec = []
                        for code in code_list[-sent_len:]:
                            char_code.append(int(code))

                            char_filename = os.path.join(self.hwdb_dir, code + '.char')
                            char_vec.append(self.read_hwdb(char_filename))

                        mask = np.ones(sent_len, dtype='bool')

                        yield self.dataset_convert(char_vec,char_code,sent_len,mask)

                    else:
                        char_code = []
                        char_vec = []
                        for code in code_list:
                            char_code.append(int(code))

                            char_filename = os.path.join(self.hwdb_dir, code+'.char')
                            char_vec.append(self.read_hwdb(char_filename))

                        mask = np.ones(sent_len, dtype='bool')

                        yield self.dataset_convert(char_vec,char_code,sent_len,mask)

            except ValueError:
                pass

            frchar.close()
            frcode.close()
            frlen.close()


    # dataset生成函数
    # 需要注意的是,在输出dataset之前,需要进行padding,把image/label/mask长度pad到FLAGS.sent_len_max
    def input_pipeline(self, batch_size, num_epochs=None, aug=False, shuffle=False):

        char_dataset = tf.data.Dataset.from_generator(self.file_readline,
                           (tf.float32,tf.int32,tf.int32,tf.bool),
                           (tf.TensorShape([None,FLAGS.image_size,FLAGS.image_size,FLAGS.image_channel]),
                            tf.TensorShape([None]),tf.TensorShape([]),tf.TensorShape([None])))

        if aug:
            char_dataset = char_dataset.map(self.data_augmentation)
        char_dataset = char_dataset.repeat(num_epochs)
        if shuffle:
            char_dataset = char_dataset.shuffle(FLAGS.shuffle_size)
        char_dataset = char_dataset.padded_batch(
            batch_size,
            padded_shapes=(tf.TensorShape([FLAGS.sent_len_max,FLAGS.image_size,
                                           FLAGS.image_size,FLAGS.image_channel]),
                           tf.TensorShape([FLAGS.sent_len_max]),
                           tf.TensorShape([]),
                           tf.TensorShape([FLAGS.sent_len_max])),
            padding_values=(0.,0, 0, False))

        iterator = char_dataset.make_one_shot_iterator()
        databatch = iterator.get_next()

        return databatch


# datafeeding测试程序
# 读取相应文件内容,形成batch输出并打印
def datafeeding_test(train_files, valid_files):
    train_feeder = DataIterator(train_files, istrain=True)
    valid_feeder = DataIterator(valid_files, istrain=False)

    # 这里得到的batch,可以直接输入到模型中,而不用使用placehoder
    # 训练集
    trn_dataset = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
    train_image_batch = trn_dataset[0]
    train_label_batch = trn_dataset[1]
    train_len_batch = trn_dataset[2]
    train_mask_batch = trn_dataset[3]
    
    # validation集
    val_dataset = valid_feeder.input_pipeline(batch_size=FLAGS.batch_size)
    valid_image_batch = val_dataset[0]
    valid_label_batch = val_dataset[1]
    valid_len_batch = val_dataset[2]
    valid_mask_batch = val_dataset[3]

    fr = open(char_dict_file, 'rb')
    char_dict = pickle.load(fr)
    fr.close()

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        try:
            training_steps = 0
            while True:
                
                # 实际程序中,此处放置每个batch的训练代码,当前放置打印代码,验证dataset输出正确
                trn_images, trn_labels, trn_lens, trn_masks = sess.run(
                    [train_image_batch, train_label_batch,
                     train_len_batch, train_mask_batch])
                print('!!! Train batch #%d !!!' % training_steps)
                print('--Shape of train images batch: (%d,%d,%d,%d,%d)' % trn_images.shape)
                print('--Corresponding sentence lengths of the batch:')
                print(trn_lens)
                print('--Corresponding labels of the batch:')
                for ii in range(FLAGS.batch_size):
                    for jj in range(trn_lens[ii]):
                        print('%5d' % trn_labels[ii,jj]),
                    print('')
                print('--Corresponding characters of the batch:')
                for ii in range(FLAGS.batch_size):
                    for jj in range(trn_lens[ii]):
                        print(char_dict[trn_labels[ii,jj]]),
                    print('')

                print('')

                training_steps += 1
                
                # 每FLAGS.eval_steps个训练batch后进行一次validation
                if training_steps%FLAGS.eval_steps == 0:
                    val_images, val_labels, val_lens, val_masks = sess.run(
                        [valid_image_batch, valid_label_batch,
                         valid_len_batch, valid_mask_batch])
                    print('### Validation batch #%d ###' % (training_steps//FLAGS.eval_steps-1))
                    print('--Shape of validation images batch: (%d,%d,%d,%d,%d)' % val_images.shape)
                    print('--Corresponding sentence lengths of the batch:')
                    print(val_lens)
                    print('--Corresponding labels of the batch:')
                    for ii in range(FLAGS.batch_size):
                        for jj in range(val_lens[ii]):
                            print('%5d' % val_labels[ii,jj]),
                        print('')
                    print('--Corresponding characters of the batch:')
                    for ii in range(FLAGS.batch_size):
                        for jj in range(val_lens[ii]):
                            print(char_dict[val_labels[ii,jj]]),
                        print('')

                    print('')
                
                # 仅用于测试,因此输出3个测试batch后即终止
                if training_steps == 3:
                    break

        except tf.errors.OutOfRangeError:
            print('==================Finished================')


def main(_):
    train_filelist = ['trn1', 'trn2']
    valid_filelist = ['val']
    datafeeding_test(train_filelist, valid_filelist)


if __name__ == "__main__":
    tf.app.run()

上述代码中用到的trn1.char、trn2.char和val.char文件内容如下(相应的.code和.len文件内容就不贴出来了,都很简单,用于示例):

trn1.char

钻石闪烁的光芒照射着世间人们日益懦弱的心灵。
那成功让钻石增值的一刀似乎在昭示着一个发人深省的道理:

trn2.char

人生需要勇气,
需要平常心。

val.char

手就不会发抖,
心也更会坚定。
因此,
抛开杂念,
用勇气成就人生吧!


输出结果(很抱歉,由于实际手写图片数据不能上传,读者无法直接运行,但可参照文件内容及如下输出理解上述data feeding代码):

!!! Train batch #0 !!!
--Shape of train images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[10 10  2]
--Corresponding labels of the batch:
 6493  4187  6674  3559  4061   528  5115  3607  1631  4131 
  185  6683   278   312  2586  4082  2178  1920  4061  1976 
 3526   158 
--Corresponding characters of the batch:
钻 石 闪 烁 的 光 芒 照 射 着 
世 间 人 们 日 益 懦 弱 的 心 
灵 。 

!!! Train batch #1 !!!
--Shape of train images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[10 10  7]
--Corresponding labels of the batch:
 6301  2189   685  5771  6493  4187  1341   470  4061   169 
  616   353   222  1197  2622  4295  4131   169   198   815 
  278  3342  4106  4061  6273  3822    24 
--Corresponding characters of the batch:
那 成 功 让 钻 石 增 值 的 一 
刀 似 乎 在 昭 示 着 一 个 发 
人 深 省 的 道 理 : 

### Validation batch #0 ###
--Shape of validation images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[7 7 3]
--Corresponding labels of the batch:
 2220  1649   178   334   815  2258    10 
 1976   234  2683   334  1221  1581   158 
 1173  3028    10 
--Corresponding characters of the batch:
手 就 不 会 发 抖 , 
心 也 更 会 坚 定 。 
因 此 , 

!!! Train batch #2 !!!
--Shape of train images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[ 7  6 10]
--Corresponding labels of the batch:
  278  3908  6810  5719   702  3090    10 
 6810  5719  1836  1820  1976   158 
 6493  4187  6674  3559  4061   528  5115  3607  1631  4131 
--Corresponding characters of the batch:
人 生 需 要 勇 气 , 
需 要 平 常 心 。 
钻 石 闪 烁 的 光 芒 照 射 着 



猜你喜欢

转载自blog.csdn.net/foreseerwang/article/details/80914473