Tutorial教程:生成triplet训练基于triplet loss的深度模型

Tutorial教程:生成triplet训练基于triplet loss的深度模型


我相信很多想使用triplet loss的人被triplet的生成难倒了。当然,如果你的机器足够好,网络上的很多代码就可以满足你了,github上有很多用于手写数字识别的代码,他们都很好用,譬如说:

https://github.com/charlesLucky/keras-triplet-loss-mnist

但是,里面只是怎么计算loss,没有怎么生成triplet,他的triplet就是每次从batch里面随即挑,好在mnist只有十个class,怎么挑选都很好挑到足够的positive 和 negetive. 对于很大的dataset,class又很多,你又要怎么生成呢?

FaceNet源代码有给出一些:
https://github.com/davidsandberg/facenet

可是TF1 的代码,看也看不大明白啊,所以我们来看看TF2怎么生成!

首先我们需要看看tf.data.Dataset.interleave()

interleave()是Dataset的类方法,所以interleave是作用在一个Dataset上的。

interleave(
    map_func,
    cycle_length=AUTOTUNE,
    block_length=1,
    num_parallel_calls=None
)

以下解释和案例来自[1],[2]

解释:

  • 假定我们现在有一个Dataset——A
  • 从该A中取出cycle_length个element,然后对这些element apply map_func,得到cycle_length个新的Dataset对象。
  • 然后从这些新生成的Dataset对象中取数据,取数逻辑为轮流从每个对象里面取数据,每次取block_length个数据
  • 当这些新生成的某个Dataset的对象取尽时,从原Dataset中再取cycle_length个element,,然后apply map_func,以此类推。

举个例子:

a = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
b=a.interleave(lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
            cycle_length=2, block_length=4) 
for item in b:
    print(item.numpy(),end=', ')

结果为:

1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5, 

上面程序的图示,看示意图可能更清晰:
在这里插入图片描述
其中map_func在这里是重复6次-repeat(6)。

那么我们可以用它来生成我们想要的triplet

def pair_parser(imgs):
    # Note y_true shape will be [batch,3]
    return (imgs[0], imgs[1], imgs[2]),([1,1,2])
def processOneDir4(basedir):
    list_ds = tf.data.Dataset.list_files(basedir+"/*.jpg").shuffle(100).repeat()
    return list_ds
def generateTriplet(imgs,label):
    labels = [int(tf.strings.split(imgs[0],os.path.sep)[0,-2]),int(tf.strings.split(imgs[1],os.path.sep)[0,-2]),int(tf.strings.split(imgs[2],os.path.sep)[0,-2])]
    return (imgs),(labels)
    
dbdir = "./data"
allsubdir = [os.path.join(dbdir, o) for o in os.listdir(dbdir) 
                    if os.path.isdir(os.path.join(dbdir,o))]
path_ds = tf.data.Dataset.from_tensor_slices(allsubdir)
ds = path_ds.interleave(lambda x: processOneDir4(x), cycle_length=5751,
                  block_length=2,
                  num_parallel_calls=4).batch(4, True).map(pair_parser, -1).batch(1, True).map(generateTriplet, -1)

即可。

参考:
[1] https://tensorflow.google.cn/api_docs/python/tf/data/Dataset?version=stable#interleave
[2]https://blog.csdn.net/menghuanshen/article/details/104240189

猜你喜欢

转载自blog.csdn.net/MrCharles/article/details/105596292
今日推荐