TF_Record格式数据制作、读取 (基于猫狗大战、cifar10数据)

1、制作TF_Record数据集

  使用 Image.open() 打开图片会占用很大的内存,在这里我使用的是 tf.gfile.GFile,所以建议大家使用 tf.gfile.GFile(path, 'rb') 打开图片。

def create_tf_example(img_list, label, sess=None):
    """Build a ``tf.train.Example`` from one image file and its label.

    The file is read as raw, still-encoded bytes and stored unchanged, so no
    decoding or resizing happens here; the reader has to decode the
    ``'img_raw'`` feature with the matching image decoder (for example
    ``tf.image.decode_png`` or ``tf.image.decode_jpeg``).

    Decoding or resizing every image at write time (with PIL, or with TF
    image ops evaluated through a session) was tried by the original author
    and found slow and memory hungry; if images need a common size, convert
    them once up front.  Storing decoded pixels via ``ndarray.tobytes()``
    would instead require ``tf.decode_raw`` on the reading side.

    Args:
        img_list: Path of a single image file (despite the plural name).
        label: Integer class label.
        sess: Unused.  Kept (now optional) so that existing callers which
            pass a session keep working.

    Returns:
        A ``tf.train.Example`` with the features ``'label'`` and
        ``'img_raw'``, ready to be serialized and written to a TFRecord.
    """
    # tf.gfile.GFile is the non-deprecated equivalent of tf.gfile.FastGFile.
    with tf.gfile.GFile(img_list, 'rb') as fid:
        encoded_image = fid.read()

    return tf.train.Example(features=tf.train.Features(
        feature={
            'label': _int64_feature(label),
            'img_raw': _bytes_feature(encoded_image),
        }))

2、读取Cat_vs_Dogs数据,并生成record:

filepath = '/Users/***/Git_Mac/Cat_Vs_Dog/train/'     # root directory of the Cats-vs-Dogs training images
out_dir = 'Record/cat_dogs_2.record'   # output TFRecord path; presumably passed as out_dir to Creat_Cats_Vs_Dogs

def Creat_Cats_Vs_Dogs(file_dir, out_dir):
    """Write the cat/dog images found in ``file_dir`` to one TFRecord file.

    A file is classified by the part of its name before the first ``.``:
    ``cat`` gets label 0 and ``dog`` gets label 1.  Any other directory entry
    (hidden files such as ``.DS_Store``, sub-directories, ...) is skipped
    rather than being labelled as a dog.  The samples are shuffled with
    ``np.random`` before writing so the two classes are interleaved in the
    record file.

    Args:
        file_dir: Directory containing the images; a trailing path separator
            is optional.
        out_dir: Path of the TFRecord file to create (a file path, despite
            the name).
    """
    label_by_prefix = {'cat': 0, 'dog': 1}

    samples = []  # (image path, label) pairs
    for entry in os.listdir(file_dir):
        label = label_by_prefix.get(entry.split('.')[0])
        if label is None:
            continue  # neither cat.* nor dog.*: not part of the dataset
        samples.append((os.path.join(file_dir, entry), label))

    n_cats = sum(1 for _, label in samples if label == 0)
    print('There are %d cats\nThere are %d dogs' % (n_cats, len(samples) - n_cats))

    # Shuffle through an index permutation: paths and labels stay paired and
    # labels stay ints (no round trip through a numpy string array).
    order = np.random.permutation(len(samples))

    write = tf.python_io.TFRecordWriter(out_dir)
    try:
        # create_tf_example takes a session argument, so one is still
        # provided; the with-block closes it once writing is done.
        with tf.Session() as sess:
            for count, i in enumerate(order, 1):
                img, lbe = samples[i]
                example = create_tf_example(img, lbe, sess)
                write.write(example.SerializeToString())
                if count % 1000 == 0:
                    print(count)  # progress
    finally:
        # Close the writer even on error so buffered records reach the disk.
        write.close()

3、读取Cifar10数据,并生成record:

path = 'Git_Mac/cifar10'   # root directory of the CIFAR-10 data
def Create_Cifar10_Record(path, out_dir):
    """Write the CIFAR-10 PNG images under ``path`` to one TFRecord file.

    ``path`` must contain one sub-directory per class, named as in
    ``class_names`` below; the position of a name in that list is the label
    stored with each of its images.

    NOTE: records are written class by class (all airplanes first, then all
    automobiles, ...), so a reader should shuffle with a buffer large enough
    to span several classes.

    Args:
        path: Root directory of the per-class image folders.
        out_dir: Path of the TFRecord file to create (a file path, despite
            the name).
    """
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    write = tf.python_io.TFRecordWriter(out_dir)
    try:
        # The session is only handed through to create_tf_example; the
        # with-block closes it even if writing fails half-way.
        with tf.Session() as sess:
            for index, directory in enumerate(class_names):
                # All image paths of this class.
                img_list = glob.glob(os.path.join(path, '{}/*.png'.format(directory)))

                # The counter restarts for every class, so the progress
                # output is per class.
                for count, img in enumerate(img_list, 1):
                    example = create_tf_example(img, index, sess=sess)
                    write.write(example.SerializeToString())
                    if count % 100 == 0:
                        print(count)
    finally:
        # Close the writer even on error so buffered records reach the disk.
        write.close()

4、从TF_Record格式中读取数据

   从record格式中读取数据并解码  

def read_and_decode(tfrecords_file, batch_size, shuffle, n_class, one_hot=False,
                    image_shape=(32, 32, 3)):
    """Build a TF1 queue-based input pipeline for one TFRecord file.

    Each record must carry an int64 ``'label'`` feature and a string
    ``'img_raw'`` feature holding a PNG-encoded image (switch the decoder to
    ``tf.image.decode_jpeg`` for JPEG data).  Images are decoded, given the
    static shape ``image_shape``, standardized per image and batched.

    The returned tensors are fed by queue runners: start them with
    ``tf.train.start_queue_runners`` before evaluating the tensors.

    Args:
        tfrecords_file: Path of the TFRecord file to read.
        batch_size: Number of examples per batch.
        shuffle: If truthy, batch with ``tf.train.shuffle_batch``; otherwise
            with ``tf.train.batch``.
        n_class: Number of classes; only used when ``one_hot`` is set.
        one_hot: If truthy, labels are returned one-hot encoded.
        image_shape: ``(height, width, channels)`` of every stored image.
            Defaults to the previously hard-coded CIFAR-10 shape; the images
            are not resized, so this must match the stored data.

    Returns:
        ``(image_batch, label_batch)``: float32 images of shape
        ``[batch_size] + image_shape`` and int32 labels of shape
        ``[batch_size]`` (or ``[batch_size, n_class]`` when ``one_hot``).
    """
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })

    height, width, channels = image_shape
    # Decode the stored PNG bytes; use decode_jpeg() for JPEG records.
    img = tf.image.decode_png(img_features['img_raw'], channels=channels)
    # The decoder yields an unknown static shape; batching needs a fixed one.
    image = tf.reshape(img, [height, width, channels])
    label = tf.cast(img_features['label'], tf.int32)
    image = tf.image.per_image_standardization(image)

    if shuffle:
        # A small capacity / min_after_dequeue mixes the data poorly, while a
        # large shuffle buffer costs memory.
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=64,
            capacity=20000,
            min_after_dequeue=1000)
    else:
        image_batch, label_batch = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=64,
            capacity=2000)

    image_batch = tf.cast(image_batch, tf.float32)

    if one_hot:
        # One-hot labels call for a different loss setup than sparse labels.
        label_batch = tf.one_hot(label_batch, depth=n_class)
        label_batch = tf.cast(label_batch, dtype=tf.int32)
        label_batch = tf.reshape(label_batch, [batch_size, n_class])

    return image_batch, label_batch

 线程读取数据

def Read_Record(filepath):
    """Pull up to ``MAX_STEP`` shuffled batches from the TFRecord at ``filepath``.

    Builds the input pipeline with ``read_and_decode`` (two classes, sparse
    labels, module-level ``batch_size``), starts the queue-runner threads and
    fetches one batch per step.  The fetched batches are not used here; this
    is the skeleton into which a training step is meant to be inserted.

    Args:
        filepath: Path of the TFRecord file to read.
    """
    with tf.Session() as sess:
        # The pipeline must exist before the queue runners are started.
        image_batch, label_batch = read_and_decode(
            filepath,
            batch_size=batch_size,
            shuffle=True,
            n_class=2,
            one_hot=False,
        )

        sess.run(tf.global_variables_initializer())

        # The coordinator is essential: it lets the reader threads be
        # stopped and joined cleanly.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for _ in range(MAX_STEP):
                if coord.should_stop():
                    break
                batch_images, batch_labels = sess.run([image_batch, label_batch])
                # Plug your own model's training step in here, or inspect
                # the batch, e.g.:
                # plot_images(batch_images, batch_labels, batch_size=batch_size)

        except tf.errors.OutOfRangeError as err:
            # Raised once the input queue is exhausted.
            print(err)
        finally:
            coord.request_stop()

        coord.join(threads)

猜你喜欢

转载自blog.csdn.net/junqing_wu/article/details/80210413
今日推荐