info_nce_loss
这个是clip作对比学习的损失函数
各个博客上都有详细介绍了,我这里就不赘述
def info_nce_loss(image_features, text_features,logit_scale,labels, temperature=0.07):
batch_size = image_features.shape[0]
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity_matrix = torch.matmul(image_features, text_features.T) / temperature
logits_per_image = similarity_matrix
logits_per_text = similarity_matrix.T
# 构造标签,正样本对应的位置为1,其余为0,这里假设批次内第一个文本特征是对应图像的正样本文本特征
gen_labels = torch.arange(batch_size).long().to(image_features.device)
total_loss = (
F.cross_entropy(logits_per_image, gen_labels)+
F.cross_entropy(logits_per_text, gen_labels)
)/2
return total_loss, logits_per_image, logits_per_text
我踩的坑
我微调 c l i p clip clip 做分类任务类别数为3
:
- 数据集为图像-文本对数据集:即一个数据样本为一个图像和对应的文本在json文件里。
这里每个类别的图像的文本都是一样的,也就是a类别下图像可能会有细微不同,但是文本都是一样的
- 微调 c l i p clip clip 的结构同原始 c l i p clip clip 一致,输出的图像特征维度为 [ 输入图像数量 , 512 ] [输入图像数量,512] [输入图像数量,512],文本特征维度为 [ 输入的文本数量