Source: SCCL
The loss below pulls the augmented pair $x_{i1}$ and $x_{i2}$ closer in representation space while pushing it away from the other samples in the augmented batch $\mathcal{B}^a$:

$$\ell_i = -\log \frac{\exp\big(\mathrm{sim}(z_{i1}, z_{i2})/\tau\big)}{\sum_{j=1}^{2M} \mathbb{1}_{j \neq i} \cdot \exp\big(\mathrm{sim}(z_{i1}, z_j)/\tau\big)}$$

where $\mathbb{1}_{j \neq i}$ is an indicator function and $\tau$ is the temperature parameter, set to 0.05. The similarity function $\mathrm{sim}(\cdot)$ is the dot product of $\ell_2$-normalized vectors, i.e., cosine similarity.
Averaged over the whole augmented batch $\mathcal{B}^a$ (of size $2M$), the instance-level contrastive loss is

$$\mathcal{L}_{\text{Instance-CL}} = \frac{1}{2M} \sum_{i=1}^{2M} \ell_i$$
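To make the formula concrete, here is a per-sample sketch of $\ell_i$ and its batch average. This is an illustrative assumption, not the SCCL repo code: the function name instance_cl_loss and the explicit loop are mine, and the input layout (rows i and i+M are the two views of sample i) mirrors how the vectorized version below stacks its inputs.

import torch
import torch.nn.functional as F

def instance_cl_loss(z, tau=0.05):
    # z: (2M, d) embeddings; rows i and i+M are the two views of sample i.
    z = F.normalize(z, dim=-1)          # sim(.) = dot product of l2-normalized vectors
    two_m = z.shape[0]
    m = two_m // 2
    sim = z @ z.t() / tau               # (2M, 2M) similarity matrix scaled by 1/tau
    losses = []
    for i in range(two_m):
        pos = (i + m) % two_m           # index of i's augmented pair
        keep = torch.ones(two_m, dtype=torch.bool)
        keep[i] = False                 # 1_{j != i}: exclude only the self-similarity
        denom = torch.logsumexp(sim[i][keep], dim=0)
        losses.append(denom - sim[i, pos])   # l_i = -log(exp(pos) / denominator)
    return torch.stack(losses).mean()        # (1/2M) * sum_i l_i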
A vectorized reproduction of this loss follows:
from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np

class PairConLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(PairConLoss, self).__init__()
        self.temperature = temperature
        self.eps = 1e-08  # small constant (unused in this forward)
        print(f"\n Initializing PairConLoss \n")

    def forward(self, features_1, features_2):
        # features_1 / features_2: the two augmented views, assumed l2-normalized
        device = features_1.device
        batch_size = features_1.shape[0]  # int: number of original samples M
        # Stack the two views into one (2M, d) matrix.
        features = torch.cat([features_1, features_2], dim=0)
        # Boolean identity matrix: True on the diagonal, False elsewhere.
        mask = torch.eye(batch_size, dtype=torch.bool).to(device)
        # Tile it into a (2M, 2M) block matrix [[I, I], [I, I]]: True wherever
        # row and column refer to the same sample or to its augmented pair.
        mask = mask.repeat(2, 2)
        # Logical NOT on the bool tensor: True only at the negative pairs.
        mask = ~mask
        # Numerator: exp(sim(z_i1, z_i2) / tau) for each of the M pairs,
        # duplicated so both views get their positive score -> shape (2M,).
        pos = torch.exp(torch.sum(features_1*features_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        # All pairwise similarities: mm is a matrix product, t() the transpose;
        # contiguous() copies the transposed view into contiguous memory.
        neg = torch.exp(torch.mm(features, features.t().contiguous()) / self.temperature)
        # masked_select keeps only the entries where mask is True, dropping the
        # self- and positive-pair similarities; each row holds 2M-2 negatives.
        neg = neg.masked_select(mask).view(2*batch_size, -1)
        neg_mean = torch.mean(neg)
        pos_n = torch.mean(pos)
        Ng = neg.sum(dim=-1)  # per-sample sum over the negatives
        # l_i = -log( pos_i / (sum of negatives + pos_i) ), averaged over 2M samples
        loss_pos = (- torch.log(pos / (Ng+pos))).mean()
        return {"loss": loss_pos,
                "pos_mean": pos_n.detach().cpu().numpy(),
                "neg_mean": neg_mean.detach().cpu().numpy(),
                "pos": pos.detach().cpu().numpy(),
                "neg": neg.detach().cpu().numpy()}