[Computer Vision] A notebook walkthrough: downloading and running CLIP models, computing image-text similarity, and performing zero-shot image classification


Project repository:

https://github.com/biluko/Paper_Codes_for_fun/tree/master/CLIP


1. The CLIP Model

CLIP (Contrastive Language-Image Pretraining) is a deep learning model developed by OpenAI for learning joint representations of images and text. Its goal is to embed images and text into a shared vector space in which matching image-text pairs lie close together and mismatched pairs lie far apart.

What distinguishes CLIP is that it learns a general-purpose feature representation by contrastive learning between images and text. During training, the model parameters are adjusted to maximize the similarity of matching image-text pairs while minimizing the similarity of mismatched pairs. This contrastive objective lets CLIP transfer to many tasks, such as image classification, text classification, and (as a guidance signal) image generation.
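To make the objective concrete, here is a minimal sketch (not code from the CLIP repository) of the symmetric contrastive loss CLIP trains with, assuming a batch of already L2-normalized image and text embeddings in which the i-th image matches the i-th text:

import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_feats, text_feats, logit_scale):
    # image_feats, text_feats: [batch, dim], L2-normalized embeddings of matched pairs
    logits = logit_scale * image_feats @ text_feats.T            # scaled cosine similarities
    targets = torch.arange(logits.shape[0], device=logits.device)  # matches lie on the diagonal
    loss_img = F.cross_entropy(logits, targets)    # image -> text direction
    loss_txt = F.cross_entropy(logits.T, targets)  # text -> image direction
    return (loss_img + loss_txt) / 2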

CLIP has a wide range of applications. By mapping images and text into a shared vector space, it enables cross-modal retrieval and matching: encode an image and a text query describing its content, compute their distance in the vector space, and retrieve the most similar images or texts. This opens new possibilities for image search, product recommendation, and question answering.

A key advantage of CLIP is that it does not depend on large amounts of manually labeled training data; instead it learns a general-purpose representation by contrastive training on naturally occurring image-text pairs. This gives it strong generalization in cross-domain (and, to a lesser extent, cross-lingual) settings. In addition, because it can score how well any natural-language description matches an image, CLIP exhibits a degree of semantic understanding across modalities and is widely used to guide image-generation and captioning systems.

In short, CLIP is a powerful deep learning model that embeds images and text into a shared vector space and enables cross-modal retrieval and matching. It has broad prospects in image and text processing, multimodal applications, and transfer learning.

2. Setup

This includes installing the CLIP dependencies and switching the notebook runtime to GPU:

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git


import numpy as np
import torch

print("Torch version:", torch.__version__)


3. Loading the Model

List the available image encoders to choose from:

import clip
clip.available_models()
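On a recent install this returns ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], though the exact list depends on the clip version.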

Load the model and its image preprocessor:

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("模型参数:", f"{
      
      np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("输入图片尺寸:", input_resolution)
print("文本长度:", context_length)
print("词表大小:", vocab_size)


4. Inspecting the Image Preprocessor

The preprocessor resizes the image to 224 × 224, center-crops it, normalizes it with a fixed mean and standard deviation, and outputs a tensor:

preprocess
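For the released models, preprocess prints as a torchvision Compose: a bicubic Resize to the input resolution, a CenterCrop, conversion to an RGB tensor, and a Normalize using CLIP's training statistics (mean ≈ (0.4815, 0.4578, 0.4082), std ≈ (0.2686, 0.2613, 0.2758)).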


5. Text Tokenization

clip.tokenize("Hello World!")


6. Loading Input Images and Texts, with Visualization

import os 
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format='retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}
original_images = []
images = []
texts = []
plt.figure(figsize = (16, 5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
  
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{
      
      filename}\n{
      
      descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()


7. Encoding Images and Texts into Features

image_input = torch.tensor(np.stack(images)).cuda()

print(image_input.shape)
text_tokens = clip.tokenize(['This is '+ desc for desc in texts]).cuda()

with torch.no_grad():
  image_features = model.encode_image(image_input).float()
  text_features = model.encode_text(text_tokens).float()
print(image_features.shape)
print(text_features.shape)
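With the eight matched skimage images, image_input has shape torch.Size([8, 3, 224, 224]), and both feature tensors come out as torch.Size([8, 512]), the embedding dimension of ViT-B/32.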


8. Computing Cosine Similarity
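Both feature matrices are first L2-normalized in place, so the matrix product below computes exactly the pairwise cosine similarity between every text and every image: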

image_features /= image_features.norm(dim = -1, keepdim = True)
text_features /= text_features.norm(dim = -1, keepdim = True)

similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize = (20, 14))
plt.imshow(similarity, vmin = 0.1, vmax = 0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize = 18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent = (i - 0.5, i + 0.5, -1.6, -0.6), origin = "lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{
      
      similarity[y, x]:.2f}", ha = "center", va = "center", size = 12)

for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size = 20)


9. Zero-Shot Image Classification

The dataset is CIFAR100; zero-shot classification simply computes similarity scores against the label prompts and applies a softmax:

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform = preprocess, download = True)
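CIFAR100 has 100 classes, and torchvision exposes their names as cifar100.classes; only these label names are used below, while the images being classified are still the eight skimage images encoded earlier.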

Build prompts from a template for classification:

text_descriptions = [f"This is a photo of a {
      
      label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

Compute the similarity scores:

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim = -1, keepdim = True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
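The factor of 100.0 mirrors CLIP's learned logit scale (the exponentiated temperature, which trains up to about this value); it sharpens the softmax so that the best-matching label dominates.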

Visualize the results:

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace = 0.5)
plt.show()


10. Writing an Image Classification Function

Given an input image and a set of candidate labels, classify the image:

def show_result(image, probs, labels, label_name):
  plt.figure()
  plt.subplot(1, 2, 1)
  plt.imshow(image)
  plt.axis("off")

  plt.subplot(1, 2, 2)
  y = np.arange(probs.shape[-1])
  plt.grid()
  plt.barh(y, probs[0])
  plt.gca().invert_yaxis()
  plt.gca().set_axisbelow(True)
  plt.yticks(y, [label_name[index] for index in labels[0].numpy()])
  plt.xlabel("probability")

  plt.subplots_adjust(wspace = 0.5)
  plt.show()


def clip_classifier(image_path, choice_label, top_k = 5):
  # top_k must not exceed the number of candidate labels
  if top_k > len(choice_label):
    raise Exception('top_k exceeds the number of candidate labels')

  # read the image
  image = Image.open(image_path).convert("RGB")
  # build the input features
  text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
  text_tokens = clip.tokenize(text_descriptions).cuda()
  image_input = preprocess(image).cuda()

  with torch.no_grad():
    image_features = model.encode_image(image_input.unsqueeze(0)).float()
    text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)

  # similarity scores
  text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
  top_probs, top_labels = text_probs.cpu().topk(top_k, dim = -1)
  show_result(image, top_probs, top_labels, choice_label)

11. Testing the Function

clip_classifier('R.jpg', ['Luffy', 'pig', 'boy', 'girl', 'one piece', 'bleach', 'black', 'man', 'cartoon', 'red', 'detector'])


clip_classifier('Holmes.jpg', ['Holmes', 'pig', 'boy', 'girl', 'one piece', 'bleach', 'black', 'man', 'cartoon', 'red', 'detector'])


12. Writing a Function to Classify Multiple Images

def clip_classifier_m(image_dir, choice_label, top_k = 5):
  # image_dir must be a directory
  if not os.path.isdir(image_dir):
    raise Exception(image_dir + ' should be a directory of images')

  # top_k must not exceed the number of candidate labels
  if top_k > len(choice_label):
    raise Exception('top_k exceeds the number of candidate labels')

  # read the images
  original_images = []
  images = []

  for filename in [filename for filename in os.listdir(image_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    image = Image.open(os.path.join(image_dir, filename)).convert("RGB")

    original_images.append(image)
    images.append(preprocess(image))

  # build the input features
  text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
  text_tokens = clip.tokenize(text_descriptions).cuda()
  image_input = torch.tensor(np.stack(images)).cuda()
  with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)

  # similarity scores
  text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
  top_probs, top_labels = text_probs.cpu().topk(top_k, dim = -1)
  show_result_m(original_images, top_probs, top_labels, choice_label)


def show_result_m(images, probs, labels, label_name):
  length = len(images)
  num_row = (length + 1) // 2   # two image/bar pairs per row; round up for odd counts

  plt.figure(figsize = (16, 16))

  for i, image in enumerate(images):
    plt.subplot(num_row, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(num_row, 4, 2 * i + 2)
    y = np.arange(probs.shape[-1])
    plt.grid()
    plt.barh(y, probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [label_name[index] for index in labels[i].numpy()])
    plt.xlabel("probability")

  plt.subplots_adjust(wspace = 1)
  plt.show()


clip_classifier_m('img', ['Luffy', 'pig', 'boy', 'girl', 'one piece', 'bleach', 'black', 'man', 'cartoon', 'red', 'Holmes'])



Source: blog.csdn.net/wzk4869/article/details/131126354