Tutorial教程:生成triplet训练基于triplet loss的深度模型
我相信很多想使用triplet loss的人被triplet的生成难倒了。当然,如果你的机器足够好,网络上的很多代码就可以满足你了,github上有很多用于手写数字识别的代码,他们都很好用,譬如说:
https://github.com/charlesLucky/keras-triplet-loss-mnist
但是,里面只是怎么计算loss,没有怎么生成triplet,他的triplet就是每次从batch里面随即挑,好在mnist只有十个class,怎么挑选都很好挑到足够的positive 和 negetive. 对于很大的dataset,class又很多,你又要怎么生成呢?
FaceNet源代码有给出一些:
https://github.com/davidsandberg/facenet
可是TF1 的代码,看也看不大明白啊,所以我们来看看TF2怎么生成!
首先我们需要看看tf.data.Dataset.interleave()
interleave()是Dataset的类方法,所以interleave是作用在一个Dataset上的。
interleave(
map_func,
cycle_length=AUTOTUNE,
block_length=1,
num_parallel_calls=None
)
以下解释和案例来自[1],[2]
解释:
- 假定我们现在有一个Dataset——A
- 从该A中取出cycle_length个element,然后对这些element apply map_func,得到cycle_length个新的Dataset对象。
- 然后从这些新生成的Dataset对象中取数据,取数逻辑为轮流从每个对象里面取数据,每次取block_length个数据
- 当这些新生成的某个Dataset的对象取尽时,从原Dataset中再取cycle_length个element,,然后apply map_func,以此类推。
举个例子:
a = tf.data.Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
b=a.interleave(lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
cycle_length=2, block_length=4)
for item in b:
print(item.numpy(),end=', ')
结果为:
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5,
上面程序的图示,看示意图可能更清晰:
其中map_func在这里是重复6次-repeat(6)。
那么我们可以用它来生成我们想要的triplet
def pair_parser(imgs):
# Note y_true shape will be [batch,3]
return (imgs[0], imgs[1], imgs[2]),([1,1,2])
def processOneDir4(basedir):
list_ds = tf.data.Dataset.list_files(basedir+"/*.jpg").shuffle(100).repeat()
return list_ds
def generateTriplet(imgs,label):
labels = [int(tf.strings.split(imgs[0],os.path.sep)[0,-2]),int(tf.strings.split(imgs[1],os.path.sep)[0,-2]),int(tf.strings.split(imgs[2],os.path.sep)[0,-2])]
return (imgs),(labels)
dbdir = "./data"
allsubdir = [os.path.join(dbdir, o) for o in os.listdir(dbdir)
if os.path.isdir(os.path.join(dbdir,o))]
path_ds = tf.data.Dataset.from_tensor_slices(allsubdir)
ds = path_ds.interleave(lambda x: processOneDir4(x), cycle_length=5751,
block_length=2,
num_parallel_calls=4).batch(4, True).map(pair_parser, -1).batch(1, True).map(generateTriplet, -1)
即可。
参考:
[1] https://tensorflow.google.cn/api_docs/python/tf/data/Dataset?version=stable#interleave
[2]https://blog.csdn.net/menghuanshen/article/details/104240189