【62】Triplet 损失

1. 三元组损失函数
已经了解了Siamese网络架构,并且知道想要网络输出什么,即什么是好的编码
但是如何定义实际的目标函数
能够让神经网络学习并做到上节讨论的内容呢?

要想通过学习神经网络的参数来得到优质的人脸图片编码
方法之一就是:定义三元组损失函数然后应用梯度下降

为了应用三元组损失函数,需要比较成对的图像
比如这个图片,为了学习网络的参数,需要同时看几幅图片
比如这对图片(编号1和编号2),想要它们的编码相似,因为这是同一个人
然而假如是这对图片(编号3和编号4),会想要它们的编码差异大一些,因为这是不同的人

用三元组损失的术语来说,要做的通常是看一个 Anchor 图片
想让 Anchor图片 和 Positive图片(Positive意味着是同一个人)的距离很接近
当 Anchor图片 与 Negative图片(Negative意味着是非同一个人)对比时
会想让他们的距离离得更远一点

它代表你通常会同时看三张图片,需要看 Anchor图片、Postive图片,还有Negative图片
要把 Anchor图片、Positive图片和Negative图片简写成 A、P、N

2. 损失函数公式
把这些写成公式的话,想要的是网络的参数或者编码能够满足以下特性:
想要 ||f( A )-f( P )||2,希望这个数值很小
准确地说,想让它小于等 f( A ) 和 f( N ) 之间的距离
或者说是它们的范数的平方,即:||f( A ) - f( P )||2 ≤ ||f( A ) - f( N )||2
||f( A ) - f( P )||2 ,这就是 d(A,P)
||f( A ) - f( N )||2 ,这就是 d(A,N) 
可以把 d看作是距离(distance)函数,这也是为什么把它命名为 d 

现在如果把方程右边项移到左边,最终就得到:
||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 ≤ 0

现在要对这个表达式做一些小的改变
有一种情况满足这个表达式,但是没有用处,就是把所有的东西都学成0
如果 f 总是输出0,即0 - 0 ≤ 0,这就是0减去0还等于0
如果所有图像的 f 都是一个零向量,那么总能满足这个方程

所以为了确保网络对于所有的编码不会总是输出0
也为了确保它不会把所有的编码都设成互相相等的
另一种方法能让网络得到这种没用的输出
就是如果每个图片的编码和其他图片一样
这种情况,还是得到0 - 0

为了阻止网络出现这种情况,需要修改这个目标
也就是这个不能是刚好小于等于0,应该是比0还要小
所以这个应该小于一个 -a 值(即 ||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 ≤ -a)
这里的 a 是另一个超参数,这个就可以阻止网络输出无用的结果
按照惯例,习惯写 +a(即 ||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a ≤ 0)
而不是把 -a 写在后面,它也叫做 间隔(margin)
这个术语你会很熟悉,如果看过关于支持向量机 (SVM) 的文献
可以把上面这个方程(||f( A ) - f( P )||2 - ||f( A ) - f( N )||2)也修改一下,加上这个间隔参数

举个例子,假如间隔设置成0.2
如果在这个例子中,如果 Anchor和 Negative图片的d(A,P)= 0.5,
即 d(A,N)只大一点,比如说0.51,条件就不能满足
虽然0.51也是大于0.5的,但还是不够好

想要 d(A,N)比 d(A,P) 大很多,会想让 d(A,N)至少是0.7或者更高
或者为了使这个间隔,或者间距至少达到0.2,可以把这项调大或者这个调小
这样这个间隔a,超参数a 至少是0.2
在 d(A,P) 和 d(A,N) 之间至少相差0.2,这就是间隔参数a的作用
它拉大了Anchor和Positive 图片对和Anchor与Negative 图片对之间的差距

取下面的这个方框圈起来的方程式,更公式化表示,然后定义三元组损失函数

其中positive图片和anchor图片是同一个人,但是negative图片和anchor不是同一个人

接下来定义损失函数,这个例子的损失函数,它的定义基于三元图片组

所以为了定义这个损失函数:
L(A,P,N) ) =max (||f( A ) - f( P )||2 - ||f( A ) - f( N )||2+a,0)

这个max函数的作用就是,只要这个||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a ≤ 0
那么损失函数就是0
只要能使画绿色下划线部分小于等于0,只要能达到这个目标
那么这个例子的损失就是0

另一方面如果这个||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a ≤ 0
然后取它们的最大值,最终会得到绿色下划线部分,即||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a是最大值
这样会得到一个正的损失值
通过最小化这个损失函数达到的效果:使这部分 ||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a 小于或者等于0
只要这个损失函数小于等于0,网络不会关心它负值有多大

整个网络的代价函数应该是 训练集中这些单个三元组损失的总和
假如有一个10000个图片的训练集,里面是1000个不同的人的照片(每个人十张图片)
要做的就是取这10000个图片,然后生成这样的三元组,然后训练学习算法
对这种代价函数用梯度下降,这个代价函数就是定义在数据集里的这样的三元组图片上

注意,为了定义三元组的数据集需要成对的A和P
即同一个人的成对的图片,为了训练系统确实需要一个数据集,里面有同一个人的多个照片
这样在1000个不同的人的10000张照片中,也许是这1000个人平均每个人10张照片
如果只有每个人一张照片,那么根本没法训练这个系统

当然,训练完这个系统之后,可以应用到一次学习问题上
对于人脸识别系统,可能只有想要识别的某个人的一张照片
但对于训练集,需要确保有同一个人的多个图片,至少是训练集里的一部分人
这样就有成对的Anchor和Positive图片了

3. 训练集
现在来看,如何选择这些三元组来形成训练集
一个问题是如果从训练集中,随机地选择A、P和N
遵守A和P是同一个人,而A和N是不同的人这一原则

有个问题就是,如果随机的选择它们,那么这个约束条件(d(A,P)+ a ≤ d(A,N) )很容易达到
因为随机选择的图片,A和N比A和P差别很大的概率很大
如果A和N是随机选择的不同的人,有很大的可能性 ||f( A ) - f( N )||2会比左边这项 ||f( A ) - f( P )||2
而且差距远大于a,这样网络并不能从中学到什么

所以为了构建一个数据集,要做的就是尽可能选择难训练的三元组A、P和N

具体而言,想要所有的三元组都满足这个条件(d(A,P)+ a ≤ d(A,N))
难训练的三元组就是,A、P和N的选择使得d(A,P)很接近d(A,N) ,即 d(A,P)≈d(A,N) d(A,P)≈d(A,N)
这样学习算法会竭尽全力使右边这个式子变大(d(A,N) ),或者使左边这个式子(d(A,P) )变小
这样左右两边至少有一个a的间隔
并且选择这样的三元组还可以增加学习算法的计算效率

如果随机的选择这些三元组,其中有太多会很简单,梯度算法不会有什么效果
因为网络总是很轻松就能得到正确的结果
只有选择难的三元组梯度下降法才能发挥作用,使得这两边离得尽可能远

如果对此感兴趣的话,这篇论文中有更多细节
作者是Florian Schroff, Dmitry Kalenichenko, James Philbin
他们建立了这个叫做 FaceNet 的系统,博客的许多观点都是来自于他们的工作
FaceNet: A Unified Embedding for Face Recognition and Clustering

总结一下,训练这个三元组损失需要取训练集,然后把它做成很多三元组,这就是一个三元组(编号1)

有一个Anchor图片和Positive图片,这两个图片是同一个人,还有一张另一个人的Negative图片
这是另一组(编号2),其中Anchor和Positive图片是同一个人,但这两个图片不是同一个人,等等

定义了这些包括A、P和N图片的数据集之后
还需要做的就是用梯度下降最小化之前定义的代价函数 J JJ
这样做的效果就是反向传播到网络中的所有参数来学习到一种编码,使得如果两个图片是同一个人
那么它们的 d 就会很小,如果两个图片不是同一个人,它们的 d dd 就会很大

这就是三元组损失,并且如何用它来训练网络输出一个好的编码用于人脸识别
现在的人脸识别系统,尤其是大规模的商业人脸识别系统都是在很大的数据集上训练
超过百万图片的数据集并不罕见,一些公司用千万级的图片,还有一些用上亿的图片来训练这些系统
这些是很大的数据集,即使按照现在的标准,这些数据集并不容易获得

幸运的是,一些公司已经训练了这些大型的网络并且上传了模型参数
所以相比于从头训练这些网络,在这一领域,由于这些数据集太大
这一领域的一个实用操作就是下载别人的预训练模型,而不是一切都要从头开始
但是即使下载了别人的预训练模型,了解怎么训练这些算法也是有用的
以防针对一些应用需要从头实现这些想法

猜你喜欢

转载自www.cnblogs.com/lau1997/p/12385363.html
62