PairConLoss对比学习相似度损失构建

来源:SCCL

通过下面的损失函数拉近xi1和xi2的距离,并拉远与Ba 中其他样本间的距离

在这里插入图片描述
1 j!=i 1是一个indicator函数且τ为温度参数,设置为0.05。相似函数sim(⋅)选择正则化向量的点积
在这里插入图片描述
对比学习在整个数据集Ba的平均为
在这里插入图片描述

代码复现如下

from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np

class PairConLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(PairConLoss, self).__init__()
        self.temperature = temperature
        self.eps = 1e-08
        print(f"\n Initializing PairConLoss \n")

    def forward(self, features_1, features_2):
        device = features_1.device
        batch_size = features_1.shape[0] # 得到int数字
        features= torch.cat([features_1, features_2], dim=0) # 在第0维进行拼接
        mask = torch.eye(batch_size, dtype=torch.bool).to(device) # 这个函数主要是为了生成对角线全1,其余部分全0的二维数组
        mask = mask.repeat(2, 2)
        # 复制成4份,并两两拼接  eg:[1,2,3]--> repeat(3,3)--> [[1, 2, 3, 1, 2, 3, 1, 2, 3],
        #                                                   [1, 2, 3, 1, 2, 3, 1, 2, 3],
        #                                                   [1, 2, 3, 1, 2, 3, 1, 2, 3]]
        #                       size(3)                                 size(3,9)
        mask = ~mask
        # 取反操作 里面涉及到各种二进制,只需要记住一个公式就行了:~a=-(a+1)
        pos = torch.exp(torch.sum(features_1*features_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        neg = torch.exp(torch.mm(features, features.t().contiguous()) / self.temperature)
        #          mm是两个矩阵相乘,即两个二维的张量相乘  t()表示T转置 contiguous()函数不会对原始数据进行任何修改,而仅仅对其进行复制,这么做的目的是,在对tensor元素进行转换和维度变换等操作之后,元素地址在内存空间中保证连续性,在后续利用指针对tensor元素进行读取时,能够减少读取便利,提高内存空间优化
        neg = neg.masked_select(mask).view(2*batch_size, -1)
                # masked_select() 掩码替换
        neg_mean = torch.mean(neg) # 均值
        pos_n = torch.mean(pos)
        Ng = neg.sum(dim=-1)
            
        loss_pos = (- torch.log(pos / (Ng+pos))).mean()
        
        return {
    
    "loss":loss_pos, "pos_mean":pos_n.detach().cpu().numpy(), "neg_mean":neg_mean.detach().cpu().numpy(), "pos":pos.detach().cpu().numpy(), "neg":neg.detach().cpu().numpy()}

            

猜你喜欢

转载自blog.csdn.net/weixin_46398647/article/details/126668202