VLM--CLIP作分类任务的损失函数

info_nce_loss

这个是clip作对比学习的损失函数
各个博客上都有详细介绍了,我这里就不赘述

def info_nce_loss(image_features, text_features,logit_scale,labels, temperature=0.07):
    batch_size = image_features.shape[0]

    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    similarity_matrix = torch.matmul(image_features, text_features.T) / temperature

    logits_per_image = similarity_matrix
    logits_per_text = similarity_matrix.T

    # 构造标签,正样本对应的位置为1,其余为0,这里假设批次内第一个文本特征是对应图像的正样本文本特征
    gen_labels = torch.arange(batch_size).long().to(image_features.device)

    total_loss = (
        F.cross_entropy(logits_per_image, gen_labels)+
        F.cross_entropy(logits_per_text, gen_labels)
    )/2

    return total_loss, logits_per_image, logits_per_text

我踩的坑

微调 c l i p clip clip 做分类任务类别数为3

  1. 数据集为图像-文本对数据集:即一个数据样本为一个图像和对应的文本在json文件里。这里每个类别的图像的文本都是一样的,也就是a类别下图像可能会有细微不同,但是文本都是一样的
  2. 微调 c l i p clip clip 的结构同原始 c l i p clip clip 一致,输出的图像特征维度为 [ 输入图像数量 , 512 ] [输入图像数量,512] [输入图像数量,512],文本特征维度为 [ 输入的文本数量

猜你喜欢

转载自blog.csdn.net/qq_61786525/article/details/144647928