以上代码将各个文件夹中的图片和标签制作为 TFRecord 格式的数据集,并使用 PIL 打开图片。
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

# Root directory that holds one sub-folder of images per class name.
cwd = 'D:/python学习/神经网络动物分类/train/'

# Class names. A class's position in this list is the integer label that is
# stored alongside each of its images.
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# NOTE(review): these two paths are defined but not used below -- the writer
# targets "train.tfrecords" relative to the current working directory.
# Confirm which location is intended before switching.
tfRecord_train = "D:\\python学习\\神经网络动物分类\\train.tfrecords"
tfRecord_test = "D:\\python学习\\神经网络动物分类\\test.tfrecords"

# Serialize every image under cwd/<class>/ into a single TFRecord file.
# Each record carries two features:
#   "label"   - int64, index of the class in `classes`
#   "img_raw" - bytes, the result of PIL's Image.tobytes() for the image
# The `with` block guarantees the writer is closed even if an image fails to
# open part-way through, so the output file is not left with a dangling handle.
with tf.python_io.TFRecordWriter("train.tfrecords") as writer:
    for index, name in enumerate(classes):
        class_path = os.path.join(cwd, name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            # Images are written at whatever size they have on disk; call
            # img.resize((128, 128)) here if a fixed size is required.
            with Image.open(img_path) as img:
                img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())