【个人思考】Tensorflow Triplet SemiHard Loss 代码详解

Triplet SemiHard Loss 代码详解

导读

这段时间的triplet loss真是让我头痛
当然也看了非常多不错的解析

  1. Triplet-Loss原理及其实现
  2. Tensorflow实现Triplet Loss
    都是非常不错的解析,让人清晰易懂

triplet loss中,可以说最关键的就是semihard loss,原论文中也是使用了这种训练方式。所以在自己的项目中也是强调了这一部分。
triplet loss理应有三个input, anchor, positive,negative。为什么我选择这个呢?
因为这个是通过输入label以及embedding,就可以对应的计算出来semihard loss了,非常的简单易用。如果自己去做筛选A P N,就会有预处理数据筛选三个样本的情况。

所以在原始框架下看看tensorflow代码:
tf.contrib.losses.metric_learning.triplet_semihard_loss
从tf 0.8开始就支持,而且代码可以copy出来用,毕竟arrays_op等基础package在tensorflow里面都是通用的

这一部分的代码直接拿过来用了,但是咱们毕竟是需要去思考其背后实现原理的。(其实我只是觉得tensorflow很烦躁,想换成torch去实现哈哈哈

于是我实现了2份代码,算上tensorflow提供的,一共3份代码版本

  • 1 tensorflow版本
  • 2 numpy version
  • 3 torch version
    首先我们观测

tensorflow version 核心代码

def triplet_semihard_loss(labels, embeddings, margin=1.0):
    """Computes the triplet loss with semi-hard negative mining.
    The loss encourages the positive distances (between a pair of embeddings with
    the same labels) to be smaller than the minimum negative distance among
    which are at least greater than the positive distance plus the margin constant
    (called semi-hard negative) in the mini-batch. If no such negative exists,
    uses the largest negative distance instead.
    See: https://arxiv.org/abs/1503.03832.
    Args:
      labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
        multiclass integer labels.
      embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
    Returns:
      triplet_loss: tf.float32 scalar.
    """
    # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor.
    lshape = array_ops.shape(labels)
    assert len(lshape.shape) == 1
    labels = array_ops.reshape(labels, [lshape[0], 1])
    # Build pairwise squared distance matrix.
    pdist_matrix = pairwise_distance(embeddings, squared=True)
    # Build pairwise binary adjacency matrix.
    adjacency = math_ops.equal(labels, array_ops.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = math_ops.logical_not(adjacency)

    batch_size = array_ops.size(labels)

    # Compute the mask.
    pdist_matrix_tile = array_ops.tile(pdist_matrix, [batch_size, 1])
    mask = math_ops.logical_and(
        array_ops.tile(adjacency_not, [batch_size, 1]),
        math_ops.greater(
            pdist_matrix_tile, array_ops.reshape(
                array_ops.transpose(pdist_matrix), [-1, 1])))

    mask_final = array_ops.reshape(
        math_ops.greater(
            math_ops.reduce_sum(
                math_ops.cast(mask, dtype=dtypes.float32), 1, keepdims=True),
            0.0), [batch_size, batch_size])
    mask_final = array_ops.transpose(mask_final)

    adjacency_not = math_ops.cast(adjacency_not, dtype=dtypes.float32)
    mask = math_ops.cast(mask, dtype=dtypes.float32)

    # negatives_outside: smallest D_an where D_an > D_ap.
    negatives_outside = array_ops.reshape(
        masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size])
    negatives_outside = array_ops.transpose(negatives_outside)

    # negatives_inside: largest D_an.
    negatives_inside = array_ops.tile(
        masked_maximum(pdist_matrix, adjacency_not), [1, batch_size])
    semi_hard_negatives = array_ops.where(
        mask_final, negatives_outside, negatives_inside)
    loss_mat = math_ops.add(margin, pdist_matrix - semi_hard_negatives)

    mask_positives = math_ops.cast(
        adjacency, dtype=dtypes.float32) - array_ops.diag(
        array_ops.ones([batch_size]))

    # In lifted-struct, the authors multiply 0.5 for upper triangular
    #   in semihard, they take all positive pairs except the diagonal.
    num_positives = math_ops.reduce_sum(mask_positives)

    triplet_loss = math_ops.truediv(
        math_ops.reduce_sum(
            math_ops.maximum(
                math_ops.multiply(loss_mat, mask_positives), 0.0)),
        num_positives,
        name='triplet_semihard_loss')

    return triplet_loss

其中pairwise_distance太简单了,此处略过,如果不懂可以看上述链接

Numpy 版本代码

因为tensorflow的debug起来太麻烦了(懂得都懂)
所以我根据tf实现了numpy版本的triplet semihard loss,并根据其来解析triplet semihard loss的具体工作流程,方便大家可以用来使用

import numpy as np
# 测试样例
labels = np.array([0,1,1,0,1])
embeddings = np.array([[0.20251631, 0.49964871, 0.31357543, 0.99332346, 0.40536699,
        0.05654062, 0.07307319, 0.2950833 , 0.5154805 , 0.43801481],
       [0.05170506, 0.92920793, 0.50820659, 0.80957615, 0.59039356,
        0.83899964, 0.3024558 , 0.29522561, 0.90828209, 0.7059259 ],
       [0.06045745, 0.73130719, 0.81192888, 0.37673241, 0.41282683,
        0.00261911, 0.54569239, 0.52696678, 0.94666249, 0.4798159 ],
       [0.9031102 , 0.09828223, 0.67050717, 0.77313736, 0.47979198,
        0.93205683, 0.30714715, 0.66625816, 0.11693463, 0.75662641],
       [0.13010331, 0.70302084, 0.29719897, 0.4037086 , 0.60219295,
        0.18917132, 0.0928293 , 0.70829784, 0.6350869 , 0.74187586]], dtype=np.float32)
margin = 1.0


pairwise_distances_squared = np.add(
    np.sum(np.square(feature), axis=1, keepdims=True),
    np.sum(np.square(np.transpose(feature)), axis=0,keepdims=True)) - 2.0 * np.matmul(feature, np.transpose(feature))

# pairwise_distances_squared = np.maximum(pairwise_distances_squared, 0.0)

error_mask = np.less_equal(pairwise_distances_squared, 0.0)

pairwise_distances = np.multiply(pairwise_distances, np.logical_not(error_mask)+0.0)

def masked_maximum(data, mask, dim=1):
    axis_minimums = np.min(data, dim, keepdims=True)
    masked_maximums = np.max(np.multiply(data - axis_minimums, mask), dim, keepdims=True) + axis_minimums
    return masked_maximums


def masked_minimum(data, mask, dim=1):
    axis_maximums = np.max(data, dim, keepdims=True)
    masked_minimums = np.min(np.multiply(data - axis_maximums, mask), dim, keepdims=True) + axis_maximums
    return masked_minimums


lshape = np.shape(labels)
assert len(lshape) == 1
labels = np.reshape(labels, [lshape[0], 1])

pdist_matrix = pairwise_distance(embeddings)
adjacency = np.equal(labels, np.transpose(labels))
# only the instances with different labels should be trained.
adjacency_not = np.logical_not(adjacency)
batch_size = np.size(labels)


# compute the mask
pdist_matrix_tile = np.tile(pdist_matrix, [batch_size, 1])

# 不同label,并且
# B * B 个element,每一个作为standard进行对比。
mask = np.logical_and(np.tile(adjacency_not, [batch_size, 1]),
                      np.greater(pdist_matrix_tile, np.reshape(np.transpose(pdist_matrix), [-1, 1])))


mask_final = np.reshape(
    np.greater(np.sum(mask+0.0, 1, keepdims=True),0.0), [batch_size, batch_size])
mask_final = np.transpose(mask_final)

adjacency_not = adjacency_not + 0.0
mask = mask + 0.0


# negatives_outside: smallest D_an where D_an > D_ap.
negatives_outside = np.reshape(masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size])

negatives_outside = np.transpose(negatives_outside)


# negatives_inside: largest D_an.
negatives_inside = np.tile(
    masked_maximum(pdist_matrix, adjacency_not), [1, batch_size])
semi_hard_negatives = np.where(
    mask_final, negatives_outside, negatives_inside)


loss_mat = np.add(margin, pdist_matrix - semi_hard_negatives)

mask_positives = adjacency+0.0 - np.diag(np.ones([batch_size]))

num_positives = np.sum(mask_positives)

triplet_loss = np.true_divide(
        np.sum(np.maximum(np.multiply(loss_mat, mask_positives), 0.0)),
        num_positives)

我们来看一些重点numpy代码

adjacency = np.equal(labels, np.transpose(labels))
adjacency_not = np.logical_not(adjacency)

这一部分中label的size为[ B B , 1],与其转秩矩阵size为[1, B B ],所以其结果为[ B B , B B ]
B B 为batch的大小, 且其轴坐标与列坐标对应了两个对应index的embedding的label是否是相同的。这样就筛选出了是否为P或者N
adjacency这个变量就是去衡量二者是否属于同一label,若相同则为True
adjacency_not则是相反,也就是只有两个embedding属于不同label,才为True

# compute the mask
pdist_matrix_tile = np.tile(pdist_matrix, [batch_size, 1])

# 不同label,并且
# B * B 个element,每一个作为standard进行对比。
mask = np.logical_and(np.tile(adjacency_not, [batch_size, 1]),
                      np.greater(pdist_matrix_tile, np.reshape(np.transpose(pdist_matrix), [-1, 1])))

这一处,首先将pdist_matrix(也就是pairwise distance的矩阵。其中每一个element对应了两个embedding的距离)做了一个纵向复制,且复制倍数为 B B
你就先记着,一开始我也无法理解的。这一步是为了后面的比较进行操作。

pdist_matrix:(其中有 B B * B B 个distance(虽然其中有接近一半是重复的,对称矩阵))

然后我们开始计算mask,这个mask是什么呢?
我们可以看到其中的adjacency_not,也就是说仅有不同label才奏效。(这个是在logical and之前)
后面一部分,则是判断pdist_matrix 是否比对应的距离大。
我们这里具体的列出来。 将pdist_matrix reshape成-1, 1之后,其实是 B B B*B 大小的一个列向量。
然后我们可以发现是这样的对比:
d ( e 1 , e 1 ) d ( e 1 , e 2 ) d ( e 1 , e 3 ) d ( e 1 , e b ) d ( e 1 , e 1 ) \begin{array}{ccc} d(e_1, e_1)& d(e_1, e_2)& d(e_1, e_3)& \cdots & d(e_1, e_b) \end{array} 与 d(e_1, e_1)

d ( e 2 , e 1 ) d ( e 2 , e 2 ) d ( e 2 , e 3 ) d ( e 2 , e b ) d ( e 1 , e 2 ) \begin{array}{ccc} d(e_2, e_1)& d(e_2, e_2)& d(e_2, e_3)& \cdots & d(e_2, e_b) \end{array} 与 d(e_1, e_2)

d ( e 3 , e 1 ) d ( e 3 , e 2 ) d ( e 3 , e 3 ) d ( e 3 , e b ) d ( e 1 , e 3 ) \begin{array}{ccc} d(e_3, e_1)& d(e_3, e_2)& d(e_3, e_3)& \cdots & d(e_3, e_b) \end{array} 与 d(e_1, e_3)

d ( e b , e 1 ) d ( e b , e 2 ) d ( e b , e 3 ) d ( e b , e b ) d ( e 1 , e b ) \begin{array}{ccc} d(e_b, e_1)& d(e_b, e_2)& d(e_b, e_3)& \cdots & d(e_b, e_b) \end{array} 与 d(e_1, e_b)

d ( e 1 , e 1 ) d ( e 1 , e 2 ) d ( e 1 , e 3 ) d ( e 1 , e b ) d ( e 2 , e 1 ) \begin{array}{ccc} d(e_1, e_1)& d(e_1, e_2)& d(e_1, e_3)& \cdots & d(e_1, e_b) \end{array} 与 d(e_2, e_1)

其实是不同label下, 不同label,并且 B * B 个element,每一个作为standard进行对比。度量其在行列中哪些比它本身更大。更大的作为True(为什么? 因为正式这一个个的distance,其中可能存在AP的pair距离,需要通过哪些比这个距离更大,从而筛选出其中的negative。)
因为我们此时已经拥有了adjacency_not,我们知道哪一对的距离是anchor和positive的,所以我们此时知道哪些是AP的距离,哪些是在 B B B*B 矩阵中比AP更大的距离。(semihard loss正是要求那些AN距离大于AP,但不足以大过margin的pair,因此这里要更大
所以只需要同时符合这两个条件的:

  1. 属于不同的label下
  2. 距离大于AP

就是我们的AN semihard loss候选

def masked_minimum(data, mask, dim=1):
    axis_maximums = np.max(data, dim, keepdims=True)
    masked_minimums = np.min(np.multiply(data - axis_maximums, mask), dim, keepdims=True) + axis_maximums
    return masked_minimums

# negatives_outside: smallest D_an where D_an > D_ap.
negatives_outside = np.reshape(masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size])

negatives_outside = np.transpose(negatives_outside)

这一处, 我们将函数 masked_minimum放进去,
可以发现在mask的情况下,我们求得最小的距离,注意,返回的结果,就已经是一大堆的distance了。并且是对应了每一个pair,是否满足AP条件的情况下,求的最小的距离。注意看注释 最小的满足 D a n > D a p D_{an} > D_{ap} 情况的 D a n D_{an} 。**这,就是我们所需要的semihard loss呀!**满足这样的情况下,最小,我们再尽可能的将这个 D a n D_{an} 优化的远一些。

那么问题来了,既然都求出来了semihard_loss,后面的操作是干什么呢?

def masked_maximum(data, mask, dim=1):
    axis_minimums = np.min(data, dim, keepdims=True)
    masked_maximums = np.max(np.multiply(data - axis_minimums, mask), dim, keepdims=True) + axis_minimums
    return masked_maximums

mask_final = np.reshape(
    np.greater(np.sum(mask+0.0, 1, keepdims=True),0.0), [batch_size, batch_size])
    
# negatives_inside: largest D_an.
negatives_inside = np.tile(
    masked_maximum(pdist_matrix, adjacency_not), [1, batch_size])
semi_hard_negatives = np.where(
    mask_final, negatives_outside, negatives_inside)

这边我们可以看到,masked_maximum函数就是在mask情况下,求的最大的distance。注释中也写了,是最大的 D a n D_{an} ,也就是说,我们也需要求得,对于满足条件的 A P AP 来说,距离他最远的 D a n D_{an} 是什么。
主要原因:

因为semihard triplet loss满足了一个重大缺陷,就是要大于 D a p D_{ap} ,但是某些情况比这个更严重。那就是有一些 D a n < D a p D_{an} < D_{ap} ,这是我们更需要去优化的,在这边的代码下,可能会出现不存在 D a n > D a p D_{an} > D_{ap} ,但是存在 D a n D_{an} 的情况,那么这个时候,我们需要去计算最远的 D a n D_{an} ,去优化 D a n D_{an} D a p D_{ap} 的距离。这种情况叫easy triplet.

并且最终我们求得了mask_final, 也就是满足存在 D a n D_{an} 情况下的一些distance,我们才对其求semihard_negatives。也就是这一行代码:

semi_hard_negatives = np.where(
    mask_final, negatives_outside, negatives_inside)

满足mask_final的情况下,选取negatives_outside和inside中最小的距离来作为优化目标。
最终的loss_matrix也就是:

loss_mat = np.add(margin, pdist_matrix - semi_hard_negatives)

这边就是二者相见然后加上定义好的margin,使得二者之间的距离满足我们定义的空间中的间隔。这是一个很重要的hyperparameter,大家可以好好调一下。

mask_positives = adjacency+0.0 - np.diag(np.ones([batch_size]))

num_positives = np.sum(mask_positives)

triplet_loss = np.true_divide(
        np.sum(np.maximum(np.multiply(loss_mat, mask_positives), 0.0)),
        num_positives)

最后一个关键点就是,我们并不能使用对角线上的distance,首先他们是anchor与其本身的距离作为度量。我们要选取的是有anchor, positive的,也就是同样label下的不同的两个embedding的距离。而不是anchor自己与自己的距离。那样距离肯定为0,优化没有意义。
所以我们将adjacency对角线上的都至为0,最后我们计算求和取平均。得到了triplet semihard loss

发布了1157 篇原创文章 · 获赞 1804 · 访问量 1455万+

猜你喜欢

转载自blog.csdn.net/weixin_40400177/article/details/105213578
今日推荐