【对比学习】CUT模型论文解读与NCE loss代码解析

标题:Contrastive Learning for Unpaired Image-to-Image Translation(基于对比学习的非配对图像转换)
作者:Taesung Park, Alexei A. Efros, Richard Zhang ,Jun-Yan Zhu
论文下载地址:http://arxiv.org/abs/2007.15651
开源代码:https://github.com/taesungp/contrastive-unpaired-translation

一、介绍

在图像转换(image-to-image translation)的任务中,我们想要的是在保留输入图像的结构特征的基础上,加入目标域的外观特征。一个经典的任务就是把马转换成斑马,在保留输入的马的图像结构的同时,将纹路换成目标域(斑马)的纹路。目前主流的做法基本上都是基于CycleGAN方法的变体,利用对抗损失(adversarial loss)强化目标域的外观特征,使用循环一致性损失(cycle-consistency loss)来保证原始输入图像的结构不变。但是CycleGAN的假设非常严格,要求输入的图像域和目标域之间存在双射关系,这一点在其实是很难满足的。所以这篇论文提出了一个替代性方案,通过最大化输入输出图像块的互信息(mutual information),使用一个对比损失函数infoNCE loss, 来学习一个编码器Encoder将对应的图像块之间相互联系起来,与其他的图像块分离;如此一来Encoder可以专注于两个域之间共性的部分如形状,而忽略两个域之间的差异性部分如纹理
CUT这篇论文证明了以多层次,图像块的范式运用对比学习技术的有效性,并且发现从单张图像本身中提取负性图像块的效果要好于从整个数据集中其他的图像中提取,因此甚至可以在单张图像上实现图像转换。
如下图中,使用多层图像块的对比损失,最大化相对应的多层图像块之间的互信息,这样将生成器和Encoder相结合,取得对应输入图像的生成图像。
Alt

二、相关概念

1.图像转换(image-to-image translation)

对称图像转换(pix2pix),使用对抗损失和重建损失形成输入和输出图像之间的映射。在非对称图像转换中,没有目标域的对应样本,循环一致损失成为事实上的标准做法(CycleGAN),通过学习一个从目标域到输入图像的映射,来检查是否输入图像被正确映射到了目标域。之后的做法大多是在循环一致损失的基础上完成的(如UNIT,MUNIT),在这个领域,循环一致损失主要在三个层面上使用,图像与图像之间,隐空间到图像,图像到隐空间。但是这些都基于输入域和目标域之间存在双射关系的严格假设,这一点当某个域的图像由相较于另个域更多的信息时就更难获得很好的效果。

2.关系保存(relationship preservation)

为了避开双射的限制,一个替代的想法是输入图像中存在的关系,类似地也应该在生成的图像中体现,就比如同一张图内近似的图像块,在生成的图像中也应该有这样近似的图像块。TraVeLGAN, DistanceGAN and GcGAN通过预定义的距离函数保证共享相似的内容,或是使用triplet loss保存输入图像之间的向量计算,再或是计算输入图像之间的距离和生成图像的距离使之保持一致等等做法,绕开循环一致性损失的限制。但是这些方法要么是需要预定义一个距离函数,要么保存的关系是基于整个图像的。CUT的做法是通过最大化互信息的方法,学习一个输入输出图像块之间的相似性函数,避免了以上方法的缺陷。

3.深度网络的感知相似性度量

大多是图像转换工作都是使用的逐像素重建进行度量,这无法反映人类的感知习惯并且会导致生成图片非常模糊。因此可以定义一个高维信号的感知距离函数,这一点使用在ImageNet上预训练的VGG分类网络就可以实现 ,并且其在人类感知测试中取得了超过传统度量方法(SSIM and FSIM)的效果。但是这个方法没法适应其他的数据集,并且它也不是一个基于图像对的相似性度量。CUT以互信息作为约束,将图像本身中的负样本利用起来,可以适用于不同特定的输入输出域,从而避免了对相似性函数的预定义。

4.对比特征学习

传统的无监督学习需要预先设计好的损失函数来衡量预测表现,新的方法通过最大化互信息绕开这个问题,使用噪声对比估计(noise contrastive estimation,NCE)来学习一个Encoder,将关联的信号拉近,并与数据集中的其他样本形成对比。信号可以是图像本身,也可以是下采样特征,相邻图像块等等。CUT首先将infoNCE loss应用到了条件图像生成领域。

三、CUT的基于对比学习方法

首先要定义图像转换问题,图像输入域为 X ∈ R H × W × C \mathcal{X} \in \mathbb{R}^{H \times W \times C} XRH×W×C,而输出图像域为 Y ∈ R H × W × 3 \mathcal{Y}\in\mathbb{R}^{H\times W\times 3} YRH×W×3,数据集为 X = { x ∈ X } , Y = { y ∈ Y } X=\{x \in \mathcal{X}\} ,Y=\{y \in \mathcal{Y}\} X={ xX}Y={ yY}, 其中在CUT的方法中数据集可以只包含单张图像。

在CUT方法中,生成器被G分解为两个部分, 先是一个encoder再是一个decoder,这样生成输出图像 y ^ \hat y y^ 的过程变成了, y ^ = G ( z ) = G d e c ( G e n c ( x ) ) \hat y=G(z)=G_{dec}(G_{enc(x)}) y^=G(z)=Gdec(Genc(x))

在GAN的图像生成部分,CUT仍然是使用GAN的对抗损失,来保证生成的图像能和目标域的图像尽可能相似,这部分的损失就是:
L ( G , D , X , Y ) = E y ∼ Y  ​ l o g D ( y ) + E x ∼ X ​  l o g ( 1 − D ( G ( x ) ) ) L(G,D,X,Y)=E_{y∼Y} \ ​logD(y)+E_{x∼X} ​\ log(1−D(G(x))) L(G,D,X,Y)=EyY logD(y)+ExX log(1D(G(x)))
在互信息最大化方面,采用noise contrastive estimation(NCE)框架。对比学习的问题有三个信号组成,查询样本( q q q)和正样本( k + k^+ k+),负样本( k − k^- k),要做的就是让 q q q k + k^+ k+ 信号相关联和 k − k^- k 形成对比。将 q q q k + k^+ k+,以及N个 k − k^- k,分别映射成K维向量 v ,   v + ∈ R K ,   v − ∈ R N × K v, \ v^+\in \mathbb{R}^K,\ v^-\in \mathbb{R}^{N\times K} v, v+RK, vRN×K,并用 v n − ∈ R K v^-_n\in \mathbb{R}^K vnRK表示第n个负样本,将这些样本归一化至单位球中,防止空间扩张或坍缩。这样就形成了一个N+1的分类问题,交叉熵损失计算如下其中 τ \tau τ 是比例超参,表示正样本被选中的概率。
ℓ ( v , v + , v − ) = − l o g   [ e x p ( v ⋅ v + / τ ) e x p ( v ⋅ v + / τ ) + ∑ n = 1 N ​ e x p ( v ⋅ v − / τ ) ​ ] ℓ(v,v^+,v^−)=−log\ [{\frac{exp(v⋅v^+/τ)}{exp(v⋅v^+/τ)+∑^{N}_{n=1}​exp(v⋅v^−/τ)}}​] (v,v+,v)=log [exp(vv+/τ)+n=1Nexp(vv/τ)exp(vv+/τ)]
无监督学习中用到对比学习,既有图像层次也有图像块层次,具体到CUT要解决的任务中,对于输入输出图像不仅整个图像应该有着同样的结构,对应的图像块之间也应该有相应的结构 。所以应该用多层次图像块(multilayer patch-based)的学习目标。通过 G e n c G_{enc} Genc​ 编码特征层,其中不同层不同空间位置代表了不同的图像块,层数越深图像块越大。

Alt
CUT模型选择了L层特征图,将其通过2层MLP网络 H l H_l Hl ​产生了一系列的特征 { z l } L = { H l ( G e n c l ( x ) } L ​ \{z_l\}_L=\{H_l(G^l_{enc}(x)\}_L​ { zl}L={ Hl(Gencl(x)}L,其中 G e n c l G^l_{enc} Gencl​ 表示第 l l l 层输出特征。序列 l ∈ { 1 , 2 , 3 , . . . , L } ,   s ∈ { 1 , 2 , . . . , S l } l\in\{1,2,3,...,L\},\ s\in\{1,2,...,S_l\} l{ 1,2,3,...,L}, s{ 1,2,...,Sl}, 其中 S l S_l Sl 表示第 l l l 层有 S l S_l Sl 个空间位置。将对应特征记为 z l s ∈ R C l z^s_l\in\mathbb{R}^{C_l} zlsRCl 其他特征标记为 z l S \ s ∈ R ( S l − 1 ) × C l z^{S\backslash s}_l\in\mathbb{R}^{(S_l-1)\times{C_l}} zlS\sR(Sl1)×Cl,其中 C l C_l Cl是每层通道数。同样将输出图像 y ^ \hat y y^ 也编码成 { z ^ l } L ​ = { H l ​ ( G e n c l ​ ( x ) } L ​ \{\hat{z}_l\}_L​=\{H_l​(G_{enc}^l​(x)\}_L​ { z^l}L={ Hl(Gencl(x)}L.
CUT模型的目标是将输入输出对应位置的图像块进行匹配,同一张图像其他位置的图像块作为负样本,将损失记做PatchNCE loss:
L P a t c h N C E ​ ( G , H , X ) = E x ∼ X ​ = ∑ l = 1 L ​ ∑ s = 1 S l ​​  ℓ ( z l s ^ ​ , z l s ​ , z l S \ s ​ ) L_{PatchNCE}​(G,H,X)=E_{x∼X​}=\sum_{l=1}^L​ \sum_{s=1}^{S_l}​​\ ℓ(\hat {z^s_l}​,z^s_l​,z^{S\backslash s}_l​) LPatchNCE(G,H,X)=ExX=l=1Ls=1Sl​​ (zls^,zls,zlS\s)
同样也可以从数据集的其他图像中提取图像块做负样本记做 z ~ \tilde z z~ ,可以像MoCo一样用一个辅助的移动平均编码器 H ^ l \hat H_l H^l​ 和移动平均 MLP层 H ^ \hat H H^ 共同计算,维护一个负样本字典 Z − Z^- Z
L e x t e r n a l ​ ( G , H , X ) = E x ∼ X , z ^ ∼ Z   ∑ l = 1 L ∑ s = 1 S l   ℓ ( z ^ l s ​ , z l s ​ , z ~ l ​ ) L_{external​}(G,H,X)=E_{x∼X,\hat z∼Z} \ \sum_{l=1}^L \sum_{s=1}^{S^l}\ ℓ(\hat z^s_l​, z^s_l​,\tilde z_l​) Lexternal(G,H,X)=ExX,z^Z l=1Ls=1Sl (z^ls,zls,z~l)
最终的目标函数,和CycleGAN一样也添加了一致损失(identity loss)— L P a t c h N C E ( G , H , Y ) \mathcal{L}_{PatchNCE}(G,H,Y) LPatchNCE(G,H,Y),以使 E y ∼ Y ∥ G ( y ) − y ∥ 1 \mathbb{E}_{y\sim Y}\|G(y)-y\|_1 EyYG(y)y1 尽量小避免生成器对产生的图片造成不必要的变化。所以总损失包含对抗损失,对比损失,一致性损失三个部分:
L ( G , D , X , Y ) + λ X ​  L P a t c h N C E ​ ( G , H , X ) + λ Y ​  L P a t c h N C E ​ ( G , H , Y ) L(G,D,X,Y)+λ_X​\ L_{PatchNCE​}(G,H,X)+λ_Y​\ L_{PatchNCE}​(G,H,Y) L(G,D,X,Y)+λX LPatchNCE(G,H,X)+λY LPatchNCE(G,H,Y)
当使用 λ X = 1 λ_X = 1 λX=1, λ Y = 1 λ_Y = 1 λY=1 联合训练时称为CUT,当取 λ Y = 0 λ_Y = 0 λY=0 时,作为补偿取 λ X = 10 λ_X = 10 λX=10 时称为FastCUT, 可以被看做是更快更轻量级的CycleGAN。可以看出CUT所采用的损失函数组成部分不多,要求的超参也不多。

四、总结

​ 综合来看,CUT这篇论文相较于其他非对称图像转换的论文,主要的创新点还是在于引入了对比学习的概念,将CycleGAN的循环一致性损失改换成对比损失,放松了对图像域要求存在双射关系的假设,因此可以用在单向的图像转换任务中去,并且在结构上更加轻量级,避免了CycleGAN额外的生成器和判别器,减少了计算花费。从消融实验中可以看出,CUT的最关键核心点在于基于最大化互信息的,使用输入图像本身的图像块,以及要使用多层Encoder获得不同层级的。

五、对比损失代码解析

(1)伪代码理解

对于NCE loss的理解可以参考学习B站大神的视频讲解:MoCo 论文逐段精读【论文精读】
这里贴上MoCo论文里对比损失NCE loss的伪代码,CUT模型的官方代码中的对比损失基本就是根据这个伪代码实现的
在这里插入图片描述
第一步: 获取positive logits,此处q的维度是[N,1,C],k的维度是[N,C,1],N是训练的数据量。q@k可以得到正样本的相关性向量,即q和k中一 一对应的样本视为正样本,进行矩阵相乘后得到彼此之间的相关性。因此l_pos的维度是[N,1,1]。
第二步: 获取negative logits,此处q的维度是[N,C],queue(负样本字典)的维度是[C,K],N是训练的数据量,K是类别数(当作多分类的类别数量,可以理解成每一个样本点(是采样的样本点,因为负样本是靠采样得到的)即为一个类别)。q@queue可以得到所有负样本之间的混淆矩阵,及q中每个点与queue每个点之间的相关性,因此l_neg的维度是[N,K]。
第三步: 获取整个logits,此时维度大小为[N,K+1],K+1可以理解成有K+1个类别,所以后面计算NCE loss其实就是计算分类问题而已。
第四步: 设置labels标签,此处labels设置成大小为N的全零矩阵。

(1) 这里可能会很奇怪为什么要这么设置,我也想想了许久才明白的。
上面提到N代表的是训练的数据量,所以labels设置成N就能理解了,每个数据对应一个标签嘛!
(2)那为什么金标准都是零?
在第三步中,整个logits的大小是[N,K+1],即代表有N个数据,K+1个类别,所以NCE loss可以理解成计算K+1类的分类问题。其中正样本对应的就是第0类,其他的K个类别不管是什么都无所谓因为都不是我们想要得到的。因此最后分类的目标就是把正样本给分出来,所以标签labels可以设置成全零。

第五步: 计算交叉熵损失。

(2)代码解析

代码部分只摘取部分重要的解释,详细内容请自行查看官方代码

# 一些基础参数赋值
batch_size=2
image_size=512
netF='mlp_sample'  # 对应特征提取的Hl模块
netF_nc=256        # mlp层输出的维度大小
nce_T=0.07         # NCE loss的温度系数
num_patches=512    # 计算NCE loss时每一层采样点的数量
nce_layers='0,4,8,12,16'   #计算NCE loss的层序号
nce_includes_all_negatives_from_minibatch=False  # 该参数为True代表在计算负样本时,负样本字典应包含batch里的其他图片,在执行当单图片转换是才会赋值True。对于CUT和FastCUT任务默认为False

# 生成loss的定义
def compute_G_loss(self):
    """Calculate GAN and NCE loss for the generator"""
    fake = self.fake_B
    # First, G(A) should fake the discriminator
    if self.opt.lambda_GAN > 0.0:
        pred_fake = self.netD(fake)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
    else:
        self.loss_G_GAN = 0.0

    if self.opt.lambda_NCE > 0.0:
        self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
    else:
        self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0

    if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
        self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
        loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
    else:
        loss_NCE_both = self.loss_NCE

    self.loss_G = self.loss_G_GAN + loss_NCE_both
    return self.loss_G

# 计算NCE loss
def calculate_NCE_loss(self, src, tgt):
	n_layers = len(self.nce_layers)  # n_layers=5
    # 提取编码器中对应的5层特征,输出的feat_q的形式为:list[[2,1,518,518], [2,64,512,512], [2,128,256,256], [2,128,128,128], [2,128,128,128]]。
    # list里面每个元素的维度为[batches,channels,heights,weights],输入图像默认大小是512*512,而第一个元素为518是因为在数据处理的时候做了padding
    feat_q = self.netG(tgt, self.nce_layers, encode_only=True)  
      
	# 只有在FastCUT模式才会做此强制翻转作为额外的正则化
    if self.opt.flip_equivariance and self.flipped_for_equivariance:  
        feat_q = [torch.flip(fq, [3]) for fq in feat_q]
          
	# 同样feat_k的形式为list[[2,1,518,518], [2,64,512,512], [2,128,256,256], [2,128,128,128], [2,128,128,128]]
    feat_k = self.netG(src, self.nce_layers, encode_only=True)  
    
   	# 通过MLP层提取特征和选取采样点,首先在k中随机采样num_patches=512个样本点,并返回采样点对应的ids
    feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
    # q也是经过MLP层提取特征,并选取和k对应ids的采样点
    feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)

	# 计算NCE loss
    total_nce_loss = 0.0
    for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
        loss = crit(f_q, f_k) * self.opt.lambda_NCE
        total_nce_loss += loss.mean()

    return total_nce_loss / n_layers

# 采样前经过的MLP层提取特征,此处netF选用PatchSampleF,以下是PatchSampleF的定义
class PatchSampleF(nn.Module):
    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
        # potential issues: currently, we use the same patch_ids for multiple images in the batch
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded
        self.mlp_init = False
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids

    def create_mlp(self, feats):   # 创建MLP层结构
        for mlp_id, feat in enumerate(feats):
            input_nc = feat.shape[1]
            mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
            if len(self.gpu_ids) > 0:
                mlp.cuda()
            setattr(self, 'mlp_%d' % mlp_id, mlp)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=512, patch_ids=None):
        return_ids = []
        return_feats = []
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)
        for feat_id, feat in enumerate(feats): # 此处feats的形式为list[[2,1,518,518], [2,64,512,512], [2,128,256,256], [2,128,128,128], [2,128,128,128]]
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]   # B=2,H和W为不同层特征图的大小
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)   # [B,C,H,W]——>[B,H,W,C]——>[B,HW,C]
            if num_patches > 0:
                if patch_ids is not None:  # 对于feat_q_pool,因为此时传入了feat_k_pool采样到采样点对应ids
                    patch_id = patch_ids[feat_id]
                else:  # 一开始feat_k_pool是没有采样点id传入的,所以需要先随机选取采样点;后面feat_q_pool根据feat_k_pool得到的采样点进行对应查询采样 
                    patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device) # 打乱feat_reshape中HW维度的顺序
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]  # 选择打乱后feat_reshape的前num_patches个点,作为采样点的ids
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)  # 获取对应采样点的特征,此刻x_sample的维度为[B,num_patches,C]——>[BXnum_patches, C]=[2x512, 256]
            # 【注意:此处采样点的数量是num_patches=512,这512个采样点的id是不连续的,也就是说随机在特征图里采样512个点,对这随机采样的512个点做对比学习,这和论文中画出patch作为采样块有些不同】
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, 'mlp_%d' % feat_id)
                x_sample = mlp(x_sample)
            return_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)

            if num_patches == 0:
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        # 此时返回的return_feats的形式为list[[1024, 256],[1024, 256],[1024, 256],[1024, 256],[1024, 256]]
        # return_ids返回每一层对应的采样点id号,[[512],[512],[512],[512],[512]]
        return return_feats, return_ids

# NCE loss的定义
class PatchNCELoss(nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        batchSize = feat_q.shape[0]   # batchSize=1024,batch上所有的采样点 
        dim = feat_q.shape[1]         # dim=256,每个采样点的特征维度大小
        feat_k = feat_k.detach()

        # pos logit
        # 变换后feat_q的维度变为[1024,1,256], feat_k的维度变为[1024,256,1].进行矩阵乘法后得到l_pos的维度为[1024,1,1]
        # 此操作可以理解为feat_q与feat_k一 一对应的位置是相同的类别也就是正样本,因此feat_q与feat_k对应位置的矩阵乘法相当于求q与k+之间的相关性,也就是正样本之间相关性系数。
        # 而batchSize=2x512是因为这是对应位置的矩阵相乘,因此可以将不同patch的采样点合并计算
        l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
        l_pos = l_pos.view(batchSize, 1) # [1024,1,1]——>[1024,1]

        # neg logit
        # Should the negatives from the other samples of a minibatch be utilized?
        # In CUT and FastCUT, we found that it's best to only include negatives from the same image. Therefore, we set ‘--nce_includes_all_negatives_from_minibatch’ as False
        # However, for single-image translation, the minibatch consists of crops from the "same" high-resolution image.
        # Therefore, we will include the negatives from the entire minibatch.
        if self.opt.nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = self.opt.batch_size # batch_dim_for_bmm=2

        # reshape features to batch size
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) # [1024, 256]——>[2,512,256]
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) # [1024, 256]——>[2,512,256]
        npatches = feat_q.size(1)  # npatches=512
        
        # feat_q的维度为[2,512,256],feat_k变换后的维度为[2,256,512].对feat_q和feat_k进行矩阵相乘得到的是q中每个采样点与k中每个采样点的相关性矩阵(类似混淆矩阵)大小是512x512,结果l_neg_curbatch的维度为[2,512,512]
        # 【注意:为什么此处不将不同batch的样本合并来获得更大的负样本?】
        # (1)作者提到在FastCUT和CUT模式中,仅使用同一张图像的采样点作为负样本点结果比使用不同图像的结果更好。
        #     至于计算l_pos可以这么做的原因是l_pos计算的是对应位置的采样点,是一一对应的,所以l_pos合并不同图像计算与不合并没有差别。
        # (2)除此之外,我觉得另外一个原因是:合并不同图像的负样本点的计算量开销和内存消耗远远大于不合并。
        #     合并后将变成[1,1024,256]@[1,256,1024](计算量:1024x256x1024,保存的矩阵大小1024x1024)
        #     而不合并是[2,512,256]@[2,256,512](计算量:2x512x256x512,保存的矩阵大小为2x512x512)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))  # 输出维度[2,512,512]

        # diagonal entries are similarity between same features, and hence meaningless.
        # just fill the diagonal with very small number, which is exp(-10) and almost zero
        # 由于对角线计算的是相同采样位置的相似性,也就是计算的q@k+。所以计算负样本的时候要把对角线的值变成0,这样得到的矩阵才是真正意义上的q@k-。
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]  #生成对角线为1,其他元素均是0,就是大小为512x512的对角矩阵
        l_neg_curbatch.masked_fill_(diagonal, -10.0)   # 因为对角线是相同采样位置之间的相关性,所以-10操作相当于将其置0,得到负样本的相关性矩阵
        l_neg = l_neg_curbatch.view(-1, npatches)  # 输出维度为[1024,512]

		# 合并正负样本的logits,输出维度为[1024,513],1024可以理解成数据量大小,513可以理解成有513个类别,其中正样本的类别序号是0.
        out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T  

		# 此时NCE loss变成了513类别的分类问题,只有0类别是正类,所以变成可以设置成全零。进行交叉熵损失计算后的结果就是NCE loss的结果
        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss

猜你喜欢

转载自blog.csdn.net/Joker00007/article/details/127678008