一、DINOv2 模型简介及使用
DINOv2
是由Meta AI
开发的第二代自监督视觉变换器模型,采用 Vision Transformer (ViT)
架构 。其核心特点是在无需人工标签的情况下,通过自监督学习技术,从海量无标注图像中学习有意义的视觉特征表示,类似于 NLP
领域的自监督 Base
模型,DINOv2
已经具有了对图像的理解能力,和强大的图像特征提取能力,因此它可以作为几乎所有计算机视觉任务的骨干模型。
下面是官方演示地址:
深度估计效果:
语义分割效果:
GitHub
开源地址:
huggingface
模型地址:
本文借助 DINOv2
强大的特征提取能力,实现图图相似度检索任务,但开始前,首先 先了解一下如何基于 DINOv2
实现图像的相似度计算:
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import matplotlib.pyplot as plt
import torch
plt.rcParams['font.sans-serif'] = ['SimHei']
# 生成图像特征
def gen_image_features(processor, model, device, image):
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt").to(device)
outputs = model(**inputs)
image_features = outputs.last_hidden_state
image_features = image_features.mean(dim=1)
return image_features[0]
# 计算两个图像的相似度
def similarity_image(processor, model, device, image1, image2):
features1 = gen_image_features(processor, model, device, image1)
features2 = gen_image_features(processor, model, device, image2)
cos_sim = torch.cosine_similarity(features1, features2, dim=0)
cos_sim = (cos_sim + 1) / 2
return cos_sim.item()
def main():
model_dir = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(model_dir)
model = AutoModel.from_pretrained(model_dir)
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
model.to(device)
image1 = Image.open("img/dog1.jpg")
image2 = Image.open("img/dog2.jpg")
similarity = similarity_image(processor, model, device, image1, image2)
plt.figure()
plt.axis('off')
plt.title(f"相似度: {
similarity}")
plt.subplot(1, 2, 1)
plt.imshow(image1)
plt.subplot(1, 2, 2)
plt.imshow(image2)
plt.show()
if __name__ == '__main__':
main()
相似度计算的主要核心就是基于 DINOv2 生成特征向量,图图相似度检索也依赖这一点,不过需要多出一个特征向量的持久化存储端,整体实现架构如下图所示,其中特征向量存储采用 Milvus
数据库。
被检索图像数据集特征提取过程:
图像相似度检索过程:
关于 Milvus
的使用,可以参考下面这篇博客:
二、图图相似度检索实现 - 图像特征持久化
首先准备图像数据集,这里我随便准备了几张猫和狗的图片:
创建 Milvus Collection
,其中 DINOv2
特征向量维度为 768
维:
from pymilvus import MilvusClient, DataType
client = MilvusClient("http://192.168.0.5:19530")
schema = MilvusClient.create_schema(
auto_id=True,
enable_dynamic_field=False,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=768)
schema.add_field(field_name="image_name", datatype=DataType.VARCHAR, max_length=256)
schema.verify()
index_params = client.prepare_index_params()
index_params.add_index(
field_name="id",
index_type="STL_SORT"
)
index_params.add_index(
field_name="vector",
index_type="IVF_FLAT",
metric_type="L2",
params={
"nlist": 1024}
)
# 创建 collection
client.create_collection(
collection_name="dinov2_collection",
schema=schema,
index_params=index_params
)
图像数据集提取特征向量后,持久化到 Milvus
中:
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
from pymilvus import MilvusClient
from tqdm import tqdm
import torch
import os
# 生成特征向量
def gen_image_features(processor, model, device, image):
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt").to(device)
outputs = model(**inputs)
image_features = outputs.last_hidden_state
image_features = image_features.mean(dim=1)
return image_features[0]
def main():
# 创建Milvus客户端
client = MilvusClient("http://192.168.0.5:19530")
# 加载模型
model_dir = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(model_dir)
model = AutoModel.from_pretrained(model_dir)
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
model.to(device)
# 读取数据集
dataset_path = "./img"
for image_name in tqdm(os.listdir(dataset_path)):
image_path = os.path.join(dataset_path, image_name)
image = Image.open(image_path)
# 提取特征向量
features = gen_image_features(processor, model, device, image)
# 存吃至 milvus
client.insert(
collection_name="dinov2_collection",
data={
"vector": features,
"image_name": image_name
}
)
if __name__ == '__main__':
main()
运行后,在 Milvus insight
工具中,可以看到存储的内容:
三、图图相似度检索实现 - 图像特征检索
检索就是拿着当前图像的特征,去向量数据库中检索相似的特征信息,实现过程如下:
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
from pymilvus import MilvusClient
import matplotlib.pyplot as plt
import torch
import os
# 生成特征向量
def gen_image_features(processor, model, device, image):
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt").to(device)
outputs = model(**inputs)
image_features = outputs.last_hidden_state
image_features = image_features.mean(dim=1)
return image_features[0].tolist()
def main():
# 创建Milvus客户端
client = MilvusClient("http://192.168.0.5:19530")
# 加载模型
model_dir = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(model_dir)
model = AutoModel.from_pretrained(model_dir)
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
model.to(device)
# 检索图像, 采用不在Milvus数据集中的图像
image = Image.open("./img2/dog.jpeg")
# 提取特征向量
features = gen_image_features(processor, model, device, image)
# 特征召回
results = client.search(
collection_name="dinov2_collection",
data=[features],
limit=2,
output_fields=["image_name"],
search_params={
"metric_type": "L2",
"params": {
}
}
)
plt.figure()
plt.axis('off')
for i, res in enumerate(results[0]):
image_name = res["entity"]["image_name"]
image_path = os.path.join("./img", image_name)
image = Image.open(image_path)
plt.subplot(1, 2, (i+1))
plt.imshow(image)
plt.show()
if __name__ == '__main__':
main()
测试输入检索图像:
召回图像:
测试输入检索图像:
召回图像: