Tensorflow将图片集制作为tfrecord格式数据

以上代码将各个文件夹中的图片和标签制作为tfrecord格式数据集,使用了PIL打开图片

import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

cwd = 'D:/python学习/神经网络动物分类/train/'

classes = ["airplane", "automobile","bird","cat","deer",
           "dog","frog","horse","ship","truck"]

tfRecord_train = "D:\\python学习\\神经网络动物分类\\train.tfrecords"
tfRecord_test = "D:\\python学习\\神经网络动物分类\\test.tfrecords"

writer = tf.python_io.TFRecordWriter("train.tfrecords")

for index, name in enumerate(classes):
    class_path = cwd + name + '/'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name
        img = Image.open(img_path)
        # img = img.resize((128,128))
        img_raw = img.tobytes()
        example = tf.train.Example(features = tf.train.Features(feature={
            "label":tf.train.Feature(int64_list = tf.train.Int64List(value=[index])),
            'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())

writer.close()


猜你喜欢

转载自blog.csdn.net/wwxy1995/article/details/80488568