当训练数据量较小时,采用直接读取文件的方式,当训练数据量非常大时,直接读取文件的方式太耗内存,这时应采用高效的读取方法,读取tfrecords文件,这其实是一种二进制文件。tensorflow为其内置了各种存储和读取的函数,方便调用。
tensorflow 提供了统一的数据存储形式,就是tfrecord。它可以读取图片,csv.txt文本最新的dataset也可以直接读,所以不打算将,后面碰到了再分析吧。
保存在tfrecord里面。它保存的信息也更加复杂,多元。我们可以根据需要从里面获取数据。
总的来说转化只需3步:
1.我们可以写一段代码获取你的原始数据。
2.通过修改 tf.train.Example 的Features,将数据填入到Example协议内存块(protocol buffer)。
3.将协议内存块(protocol buffer)序列化为一个字符串,并且通过tf.python_io.TFRecordWriter将序列化的字符串写入到TFRecords文件。
===========================
直接上代码:我的代码结构图。
copy出来可能需要调整一下空格排版。
# -*- coding: utf-8 -*-
import osimport tensorflow as tf
from PIL import Image
#saved 图片路径,事先准备好的图片路径
cwd = './'
#will save文件路径,你的tfrecord文件做好后,将会保存的路径
filepath = './'
#指定每个tfrecord存放图片个数
bestnum = 300
#第几个图片,初始化作用不用管
num = 0
#第几个TFRecord文件,初始化作用不用管
recordfilenum = 0
#类别,也就是你原始图片的文件夹名字,注意你的文件夹名字就是你该类别的名字,有几个图片文件夹就写几个
classes=['rock',
'wood'
]
#指定好tfrecords格式文件名
ftrecordfilename = (" traindata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#类别和路径
for index,name in enumerate(classes):#遍历所有含图片的文件夹
class_path=cwd+name+'/'
for img_name in os.listdir(class_path):#遍历该图片文件夹里面的所有图片名字
#以下代码表示开始对一张图片--》tfrecord格式的操作
if num>bestnum:#当读取该图片数目大于上面指定的tfrecord保存总数,自动保存到下一个tfrecord文件里
num = 1 #当该tfrecord文件夹图片读取满了,再开始新的一个tfrecord,该图片的编号自然从1开始编号。
recordfilenum = recordfilenum + 1
#重新指定tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#否则不会进入下一个tfrecord,说明当前tfrecord 文件夹还可以继续保存图片,num很重要保证不会乱。
#因为即使进入新的tfrecord,或者还在原来的tfrecord 文件里面,保存图片为tfrecord操作一样,所以合并一起写
img=Image.open(img_path,'r')
size = img.size
img_raw=img.tobytes()#将图片转化为二进制格式
#这里是核心,指定了图片转为tfrecord数据格式后,该格式保存了哪些图片特征赋值在value里面,
#tfrecord是字典{键:值}形式。键是tf自己维护不用我们管,我们把图片特征都放在值里面。
#在这个图片的值里面,我们又放了图片数据,对应标签,图片宽和高。4个特性,这4个特性要指定
#保存的数据类型一般int和byte即str用的较多,很少用float。其实里面不止放4个特性,根据你自己
#的要求,可以放任意个。
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
}))
#表示当前文件夹里面的该图片转化为tfrecord形式,至此完成了一张图片--》tfrecord格式,
#下一次循环,开始新的一张pic
writer.write(example.SerializeToString()) #序列化为字符串
#所有包含图片的文件夹遍历结束,关闭数据流
writer.close()
=========code——end====注意说明======
1. 以上代码红色字体,表示需要自己修改的,注释很清晰,很容易修改。
2. 该代码需要把图片的文件夹命名为类别名字,指定路径等。
3. 需要指定每个tfrecord保存图片个数,自己定。然后它会按你指定了文件类别名字顺序进行打开,
然后遍历里面的图片,最终按指定文件的先后顺序,代码会自己计算最终保存多少个tfrecord文件。
源代码里面看这个定顺序:classes=['rock',
'wood'
]
比如
我有2个文件,每个文件800个图片,我们指定每个tfrecord保存300个图片,
最终会有 2*800/300=5,即余下的100放不下,会自动新建一个tfrecord。所以有6个tfrecord文件。
而且6个tfrecord文件,也是按你上面指定的文件夹先后顺序来保存的。对应的每个文件夹的label,
也是按上面的顺序来,从0开始给每个图片度标记了。因为我们文件夹即类别名字也。
4.大家如果要做test数据集,自己准备一个文件存图片用上面的代码,一样的套路。记得改上面的
路径啊,tfrecord的保存名字啊,上面tfrecord实际保存数目等,红色标示部分代码即可。
5.官方提供的生成tfrecord代码,好像会自动生成一个TXT文件保存label,这个自己模仿他们官方
的文件格式手动写进去,一般自己的分类不会太多嘛,就那么几个,实在不行就Python脚本生成吧。
6. 里面的图片名字任意取,代码会读取的,尽量不用中文担心乱码,认不出。
======效果图
参考:
https://blog.csdn.net/chaipp0607/article/details/72960028