dual_contrastive_loss粗略解读

def dual_contrastive_loss(real_logits, fake_logits):
    device = real_logits.device
    real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

    def loss_half(t1, t2):
        t1 = rearrange(t1, 'i -> i ()')#最里面多了一个维度i*1
        t2 = repeat(t2, 'j -> i j', i = t1.shape[0])#i个j组合起来,重复的组合
        t = torch.cat((t1, t2), dim = -1)
        return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

    return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)

dual_contrastive_loss粗略解读

real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

这里是将real_logits, fake_logits两个矩阵分别转化为一维数组。

t1 = rearrange(t1, 'i -> i ()')#让t1增加一个维度从i-》i*1
t2 = repeat(t2, 'j -> i j', i = t1.shape[0])#t2是一个j个元素的一维数组,这里让重复i个j组合起来。j-》i*j
t = torch.cat((t1, t2), dim = -1)#shape为i*1的t1和shape为i*j的t2在第二个维度上拼接起来

loss_half(real_logits, fake_logits)

loss_half(real_logits, fake_logits)的作用如下,real_logits的每个元素都在t的每行的第一列,idnex=0,F.cross_entropy交叉熵损失函数的目标标签全部是0,也就是说这行代码就是让real_logits的每个元素与fake_logits的所有元素相比,real_logits的每个元素的值最大。

F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

loss_half(-fake_logits, -real_logits)

loss_half(-fake_logits, -real_logits)的作用如下,-fake_logits的每个元素都在t的每行的第一列,idnex=0,F.cross_entropy交叉熵损失函数的目标标签全部是0,意思是这行代码就是让-fake_logits的每个元素与-real_logits的所有元素相比,-fake_logits的每个元素的值最大。因为这里是两者的负号相比,所以这里的意思就是就是让fake_logits的每个元素与real_logits的所有元素相比,fake_logits的每个元素的值最小。

F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

猜你喜欢

转载自blog.csdn.net/qq_43263543/article/details/120849666