Constructing the PairConLoss contrastive similarity loss

Source: SCCL (Supporting Clustering with Contrastive Learning)

Given a minibatch of M samples, each instance $x_i$ is augmented twice, yielding an augmented batch $B^a$ of size $2M$. The following loss pulls the two views $x_i^1$ and $x_i^2$ of the same instance together while pushing them away from all other samples in $B^a$:

$$\ell_i = -\log \frac{\exp\left(\mathrm{sim}(x_i^1, x_i^2)/\tau\right)}{\sum_{j=1}^{2M} \mathbb{1}_{j \neq i}\, \exp\left(\mathrm{sim}(x_i^1, x_j)/\tau\right)}$$
where $\mathbb{1}_{j \neq i}$ is an indicator function and $\tau$ is the temperature parameter, set to 0.05. The similarity function $\mathrm{sim}(\cdot)$ is the dot product of L2-normalized vectors, i.e. cosine similarity:

$$\mathrm{sim}(x_i, x_j) = \frac{x_i^\top x_j}{\lVert x_i \rVert \, \lVert x_j \rVert}$$
The loss is then averaged over all $2M$ instances of the augmented batch $B^a$:

$$\mathcal{L} = \frac{1}{2M} \sum_{i=1}^{2M} \ell_i$$
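
Because $\mathrm{sim}(\cdot)$ is a dot product of L2-normalized vectors, it coincides with cosine similarity. A minimal sketch of this equivalence (the dimension 128 is an arbitrary illustrative choice):

import torch
import torch.nn.functional as F

# L2-normalize two random vectors, then compare dot product with cosine similarity
z_i = F.normalize(torch.randn(128), dim=0)
z_j = F.normalize(torch.randn(128), dim=0)
dot = torch.dot(z_i, z_j)
cos = F.cosine_similarity(z_i.unsqueeze(0), z_j.unsqueeze(0))[0]
print(torch.allclose(dot, cos, atol=1e-6))  # True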

The loss can be reproduced in PyTorch as follows:

import torch
import torch.nn as nn

class PairConLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(PairConLoss, self).__init__()
        self.temperature = temperature
        self.eps = 1e-08  # small constant; not used in forward() below
        print(f"\n Initializing PairConLoss \n")

    def forward(self, features_1, features_2):
        device = features_1.device
        batch_size = features_1.shape[0]  # number of original samples B
        # Stack the two augmented views into a single (2B, D) tensor
        features = torch.cat([features_1, features_2], dim=0)
        # Boolean identity matrix: True on the diagonal, False elsewhere
        mask = torch.eye(batch_size, dtype=torch.bool).to(device)
        # Tile to (2B, 2B); True now marks positions (i, i), (i, i+B),
        # (i+B, i) and (i+B, i+B), i.e. each sample paired with itself
        # or with its own augmentation
        mask = mask.repeat(2, 2)
        # Logical NOT: True exactly where a pair is a negative
        mask = ~mask
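        # For batch_size B = 2, for example, the tiled-and-inverted mask is
        #   [[False,  True, False,  True],
        #    [ True, False,  True, False],
        #    [False,  True, False,  True],
        #    [ True, False,  True, False]]
        # so every row keeps exactly 2B - 2 = 2 negatives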
        # Positive-pair similarities: row-wise dot product of the paired views
        # (equal to cosine similarity when the features are L2-normalized)
        pos = torch.exp(torch.sum(features_1 * features_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        # Full (2B, 2B) similarity matrix between all samples; t() transposes,
        # and contiguous() copies the transposed view into contiguous memory
        neg = torch.exp(torch.mm(features, features.t().contiguous()) / self.temperature)
        # masked_select keeps only the entries where mask is True,
        # i.e. the 2B - 2 negatives of each of the 2B anchors
        neg = neg.masked_select(mask).view(2 * batch_size, -1)
        neg_mean = torch.mean(neg)
        pos_n = torch.mean(pos)
        Ng = neg.sum(dim=-1)

        # NT-Xent: the denominator Ng + pos is the sum over all j != i,
        # matching the indicator function in the loss above
        loss_pos = (- torch.log(pos / (Ng + pos))).mean()
        
        return {
            "loss": loss_pos,
            "pos_mean": pos_n.detach().cpu().numpy(),
            "neg_mean": neg_mean.detach().cpu().numpy(),
            "pos": pos.detach().cpu().numpy(),
            "neg": neg.detach().cpu().numpy(),
        }

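A quick smoke test of the module, with random tensors standing in for encoder outputs (the feature dimension 128, batch size 8, and the explicit L2-normalization are illustrative assumptions, not part of the original code):

import torch
import torch.nn.functional as F

criterion = PairConLoss(temperature=0.05)

# Random stand-ins for the embeddings of two augmented views of 8 inputs;
# a real pipeline would produce these with an encoder and L2-normalize them
features_1 = F.normalize(torch.randn(8, 128), dim=1)
features_2 = F.normalize(torch.randn(8, 128), dim=1)

out = criterion(features_1, features_2)
print(out["loss"])                       # scalar loss tensor
print(out["pos_mean"], out["neg_mean"])  # diagnostics returned by forward()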
            


Origin: blog.csdn.net/weixin_46398647/article/details/126668202