Vector Retrieval: Building an Image Search System Based on a Pre-trained ResNet Model

Table of contents

1 Project background introduction

2 Introduction of key technologies

2.1 ResNet network

2.2 Milvus vector database

3 System code implementation

3.1 Construction of the operating environment

3.2 Dataset download

3.3 Pre-training model download

3.4 Code implementation

3.4.1 Create Milvus tables and indexes

3.4.2 Building a ResNet Encoding Network

3.4.3 Data vectorization and loading

3.4.4 Build the search web page

4 Summary


1 Project background introduction

Image-to-image search (search by image) is a vector retrieval technique: the user uploads an image, and the system finds other images or information related to it. It offers a more intuitive and efficient way to retrieve information, and it is widely used in product search and shopping, plant and animal identification, food identification, knowledge retrieval, and other fields. The key technical questions in image search are:

  • How to encode images as vectors
  • How to store massive amounts of vector data
  • How to quickly retrieve vectors from massive datasets
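To make the retrieval question concrete, here is a minimal brute-force sketch in NumPy (the function name knn_search and the random data are illustrative assumptions, not part of the project). Every stored vector is compared with the query using squared L2 distance and the k closest are returned. Its cost grows linearly with the number of stored vectors, which is why a vector database with indexes becomes necessary at scale.

```python
import numpy as np


def knn_search(database: np.ndarray, query: np.ndarray, k: int = 5):
    """Exact k-nearest-neighbor search by squared L2 distance.

    database: (n, d) array of stored vectors
    query:    (d,) query vector
    returns:  (indices, distances) of the k closest vectors, nearest first
    """
    # squared Euclidean distance from the query to every stored vector: O(n * d)
    dists = np.sum((database - query) ** 2, axis=1)
    k = min(k, len(database))
    # argpartition selects the k smallest without a full sort; then order those k
    top = np.argpartition(dists, k - 1)[:k]
    top = top[np.argsort(dists[top])]
    return top, dists[top]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vectors = rng.random((1000, 2048), dtype=np.float32)  # stand-in for image embeddings
    idx, d = knn_search(vectors, vectors[42], k=3)
    print(idx, d)  # index 42 (the query itself) comes first, with distance 0.0
```

Scanning every vector is exact but slow for millions of images; Milvus indexes such as IVF_FLAT, used later in this project, trade a little accuracy for much less scanning.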

Combining a pre-trained ResNet model with the Milvus vector database, this project implements an image-to-image search system on a fruit dataset. Readers can extend the dataset to other domains and build an image search system that fits their own business needs.

2 Introduction of key technologies

2.1 ResNet network

ResNet (Residual Network) is one of the most important convolutional neural network (CNN) architectures in deep learning. It was proposed by Kaiming He et al. in 2015 and achieved remarkable results in that year's competitions, winning first place in the ImageNet (ILSVRC 2015) classification, detection, and localization tasks and in the COCO 2015 detection and segmentation tasks. Its key innovation is the residual (shortcut) connection, which makes very deep networks much easier to train.

In plain deep networks, accuracy can saturate or even degrade as more layers are added, because problems such as vanishing and exploding gradients make deep networks hard to optimize. ResNet addresses this with residual blocks. Instead of learning a target mapping H(x) directly, the stacked convolutional layers in a block learn the residual F(x) = H(x) - x, that is, the difference between the desired output and the input; the input is then added back through a shortcut connection, so the block outputs F(x) + x. This lets information and gradients propagate easily through the network even when it becomes very deep.
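The residual idea fits in a few lines of PyTorch. The sketch below is a simplified version of the "basic block" used by ResNet-18/34; the class name BasicResidualBlock is illustrative, and the downsampling (1x1 projection) shortcut of the real torchvision block is omitted, so input and output shapes are identical.

```python
import torch
from torch import nn


class BasicResidualBlock(nn.Module):
    """Simplified ResNet basic block: output = ReLU(F(x) + x).

    F is two 3x3 conv + BatchNorm layers. Channels and spatial size are
    unchanged, so the shortcut is a plain identity.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.residual = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the layers learn only the residual F(x); the input is added back via the shortcut
        return self.relu(self.residual(x) + x)


if __name__ == "__main__":
    block = BasicResidualBlock(64).eval()
    x = torch.randn(1, 64, 56, 56)
    print(block(x).shape)  # torch.Size([1, 64, 56, 56])
```

If the residual branch outputs zeros, the block reduces to ReLU(x), so an extra block can at worst pass its input through; this is what makes stacking more blocks safe.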

The classic ResNet configurations are ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152. ResNet-18 and ResNet-34 are built from the same basic block (two 3×3 convolutions) and are relatively shallow. The other three are deeper and use a bottleneck block (1×1, 3×3, and 1×1 convolutions); of these, ResNet-50 is the most commonly used.

The advantages of ResNet include:

  • Training deeper networks: residual connections make it possible to build very deep networks that still converge readily during training.
  • Mitigating vanishing and exploding gradients: residual connections give gradients a direct path through the network, reducing both problems.
  • Better feature learning: each block only has to learn a residual correction on top of its input, which is easier to optimize and helps the network capture fine-grained features.
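The gradient argument can be made explicit for a single block. Writing the block as y = x + F(x), with L the training loss and the final ReLU ignored for simplicity, the chain rule gives:

```latex
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}\,\frac{\partial y}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)
  = \frac{\partial \mathcal{L}}{\partial y}
    + \frac{\partial \mathcal{L}}{\partial y}\,\frac{\partial F}{\partial x}
```

The first term carries the gradient through the shortcut unchanged, so even when the Jacobian of F becomes very small in a deep stack, the gradient reaching earlier layers does not vanish.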

For a detailed introduction to ResNet, see: ResNet

2.2 Milvus vector database

Milvus is a cloud-native vector database that is highly available, high-performance, and easy to scale. It is designed for real-time retrieval (recall) over massive amounts of vector data.

Milvus is built on top of vector search libraries such as Faiss, Annoy, and HNSW, and its core purpose is dense vector similarity search. On top of these libraries, Milvus adds data sharding and partitioning, data persistence, incremental data ingestion, hybrid scalar-vector queries, time travel, and more, and it heavily optimizes retrieval performance to suit a wide range of vector retrieval scenarios. For the best availability and resilience, deploying Milvus on Kubernetes is generally recommended.

Milvus adopts a shared-storage architecture in which storage and computing are fully separated, and compute nodes scale horizontally. Architecturally, Milvus separates the data flow from the control flow and is organized into four layers: the access layer, the coordinator service, worker nodes, and the storage layer. The layers are independent of each other, so each can be scaled and recovered from failures independently.

The Milvus vector database helps users retrieve massive amounts of unstructured data (images, video, audio, and text). According to the Milvus project, a single-node deployment can complete a search over billions of vectors within seconds, and the distributed architecture supports horizontal scaling when more capacity is needed.

The main characteristics of Milvus can be summarized as follows:

  • High performance: performs vector similarity search on massive datasets.
  • High availability and reliability: Milvus can scale out in the cloud, and its disaster recovery capability keeps the service highly available.
  • Hybrid query: Milvus can filter on scalar fields during vector similarity search, combining both in one query.
  • Developer friendly: Milvus has an ecosystem of SDKs in multiple languages and supporting tools.
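Conceptually, a hybrid query restricts the candidate set with a scalar condition and then ranks only the surviving vectors by distance. The NumPy sketch below illustrates those semantics on in-memory data; the category labels and the function name hybrid_search are illustrative assumptions. Milvus applies this kind of filter inside the engine, via the boolean expression passed as expr to a search call.

```python
import numpy as np


def hybrid_search(vectors, categories, query, category, k=3):
    """Indices of the k nearest vectors (squared L2) among rows whose scalar
    field `categories` equals `category`, nearest first."""
    # 1) scalar filter: keep only rows that satisfy the condition
    candidates = np.flatnonzero(np.asarray(categories) == category)
    if candidates.size == 0:
        return np.array([], dtype=int)
    # 2) vector ranking restricted to the filtered rows
    dists = np.sum((vectors[candidates] - query) ** 2, axis=1)
    return candidates[np.argsort(dists)[:k]]


if __name__ == "__main__":
    vecs = np.array([[0.0, 0.0], [0.1, 0.0], [0.2, 0.0], [5.0, 5.0]])
    cats = ["apple", "mango", "apple", "apple"]
    # row 1 is nearest overall, but it is filtered out because it is a mango
    print(hybrid_search(vecs, cats, np.array([0.12, 0.0]), "apple", k=2))  # [2 0]
```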

For details on Milvus, see: Milvus

3 System code implementation

3.1 Construction of the operating environment

For preparing the conda environment, see: Anaconda

git clone https://gitcode.net/ai-medical/image_image_search.git
cd image_image_search

pip install -r requirements.txt
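After installing the requirements, a quick check can confirm that the key packages are present. The package list below is an assumption based on the libraries imported by the scripts in this article; the exact pinned versions live in requirements.txt.

```python
from importlib import metadata

# distributions imported by the scripts in this article (assumed to be in requirements.txt)
REQUIRED = ["torch", "torchvision", "pymilvus", "gradio", "Pillow", "numpy"]


def installed_versions(packages=REQUIRED):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12s} {version or 'NOT INSTALLED'}")
```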

3.2 Dataset download

Download links:

First package: package01

Second package: package01

The dataset directory contains 10 folders, each named after a fruit type, and each folder contains hundreds to thousands of pictures of that fruit, as shown in the following figure:

Taking the apple folder as an example, its contents are as follows:

After downloading, unzip the archives into the D:/dataset/fruit directory. The directory listing looks like this:

# ll fruit/
total 508
drwxr-xr-x 2 root root 36864 Aug  2 16:35 apple
drwxr-xr-x 2 root root 24576 Aug  2 16:36 apricot
drwxr-xr-x 2 root root 40960 Aug  2 16:36 banana
drwxr-xr-x 2 root root 20480 Aug  2 16:36 blueberry
drwxr-xr-x 2 root root 45056 Aug  2 16:37 cherry
drwxr-xr-x 2 root root 12288 Aug  2 16:37 citrus
drwxr-xr-x 2 root root 49152 Aug  2 16:38 grape
drwxr-xr-x 2 root root 16384 Aug  2 16:38 lemon
drwxr-xr-x 2 root root 36864 Aug  2 16:39 litchi
drwxr-xr-x 2 root root 49152 Aug  2 16:39 mango
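To confirm the download is complete, the images per class can be counted. This is a minimal sketch; the set of file extensions treated as images is an assumption, and the function name count_images is illustrative.

```python
import os

# file extensions treated as images (an assumption; extend as needed)
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def count_images(data_path):
    """Return {class_name: image_count}, one entry per sub-directory of data_path."""
    counts = {}
    for dir_name in sorted(os.listdir(data_path)):
        sub_dir = os.path.join(data_path, dir_name)
        if not os.path.isdir(sub_dir):  # skip stray files at the top level
            continue
        counts[dir_name] = sum(
            1 for name in os.listdir(sub_dir)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS
        )
    return counts
```

Calling count_images("D:/dataset/fruit") should report the ten fruit folders listed above, each with its number of image files.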

3.3 Pre-training model download

The torchvision download URLs for the pre-trained ResNet family are:

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

Download the pre-trained ResNet-50 model: resnet50, and store it in the D:/models directory.
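The download can also be scripted. torchvision checkpoint names end with the first hex digits of the file's SHA-256 hash (0676ba61 for resnet50), so the download can be verified against that prefix. This sketch uses torch.hub.download_url_to_file, whose hash_prefix argument performs that check; the helper names are illustrative.

```python
import os
import re

import torch

RESNET50_URL = "https://download.pytorch.org/models/resnet50-0676ba61.pth"
MODEL_DIR = "D:/models"


def hash_prefix_from_url(url: str) -> str:
    """Extract the SHA-256 prefix from a '<name>-<hash>.pth' checkpoint URL."""
    match = re.search(r"-([0-9a-f]{8,})\.pth$", url)
    if match is None:
        raise ValueError(f"no hash prefix in {url}")
    return match.group(1)


def download_weights(url: str = RESNET50_URL, model_dir: str = MODEL_DIR) -> str:
    """Download the checkpoint once into model_dir and return its local path."""
    os.makedirs(model_dir, exist_ok=True)
    dst = os.path.join(model_dir, os.path.basename(url))
    if not os.path.exists(dst):
        # raises RuntimeError if the file's SHA-256 does not start with the prefix
        torch.hub.download_url_to_file(url, dst, hash_prefix=hash_prefix_from_url(url))
    return dst
```

Calling download_weights() places resnet50-0676ba61.pth in D:/models, the path the encoder below loads from.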

3.4 Code implementation

3.4.1 Create Milvus tables and indexes

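from pymilvus import connections, db

connections.connect(host="192.168.1.156", port=19530)

# create the database only if it does not exist yet, so the script can be re-run
if "image_vector_db" not in db.list_database():
    db.create_database("image_vector_db")

db.using_database("image_vector_db")
print(db.list_database())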

Create a collection

from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus import Collection, db, connections


connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")

# primary key; ids are supplied by the loading script (auto_id is off)
m_id = FieldSchema(name="m_id", dtype=DataType.INT64, is_primary=True)
# 2048-dimensional ResNet-50 image embedding
embeding = FieldSchema(name="embeding", dtype=DataType.FLOAT_VECTOR, dim=2048)
# path of the source image, returned with search results
path = FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=256)
schema = CollectionSchema(
  fields=[m_id, embeding, path],
  description="image to image embeding search",
  enable_dynamic_field=True
)

collection_name = "fruit_vector"
collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)

Create an index

from pymilvus import Collection, utility, connections, db

connections.connect(host="192.168.1.156", port=19530)
db.using_database("image_vector_db")

# IVF_FLAT: cluster the vectors into nlist cells; distances are Euclidean (L2)
index_params = {
  "metric_type": "L2",
  "index_type": "IVF_FLAT",
  "params": {"nlist": 1024}
}

collection = Collection("fruit_vector")
collection.create_index(
  field_name="embeding",
  index_params=index_params
)

# report how many rows have been indexed so far
print(utility.index_building_progress("fruit_vector"))
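IVF_FLAT partitions the vector space with k-means into nlist clusters ("cells") and keeps the full-precision (flat) vectors in one inverted list per cell; a search scans only the nprobe cells closest to the query. The NumPy toy below illustrates the idea; it is not Milvus's implementation, and all names are illustrative. In Milvus, nprobe is passed at search time, e.g. {"metric_type": "L2", "params": {"nprobe": 16}}.

```python
import numpy as np


def _nearest_centroid(vectors, centroids):
    """Index of the closest centroid (squared L2) for every row of `vectors`."""
    d = ((vectors[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return d.argmin(axis=1)


def build_ivf(vectors, nlist, iters=10, seed=0):
    """Toy IVF_FLAT: k-means centroids plus one inverted list of row ids per cell."""
    rng = np.random.default_rng(seed)
    # initialise centroids with nlist distinct random vectors
    centroids = vectors[rng.choice(len(vectors), nlist, replace=False)].astype(float)
    for _ in range(iters):  # plain Lloyd (k-means) iterations
        assign = _nearest_centroid(vectors, centroids)
        for c in range(nlist):
            members = vectors[assign == c]
            if len(members):  # leave an empty cell's centroid where it was
                centroids[c] = members.mean(axis=0)
    assign = _nearest_centroid(vectors, centroids)
    lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, lists


def ivf_search(vectors, centroids, lists, query, k=5, nprobe=1):
    """Scan only the `nprobe` cells whose centroids are nearest to the query."""
    probe = np.argsort(((centroids - query) ** 2).sum(axis=1))[:nprobe]
    candidates = np.concatenate([lists[c] for c in probe])
    # exact ("flat") distances, but only for vectors in the probed cells
    dists = ((vectors[candidates] - query) ** 2).sum(axis=1)
    return candidates[np.argsort(dists)[:k]]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = rng.random((2000, 16))
    centroids, lists = build_ivf(data, nlist=32)
    query = rng.random(16)
    print("nprobe=1 :", ivf_search(data, centroids, lists, query, nprobe=1))
    print("nprobe=32:", ivf_search(data, centroids, lists, query, nprobe=32))  # exact
```

A larger nprobe scans more cells, raising recall at the cost of speed; with nprobe equal to nlist the search becomes exact.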

3.4.2 Building a ResNet Encoding Network

Load the pre-trained ResNet-50 model and remove its final fully connected (classification) layer, so that the ResNet encoder outputs a 2048-dimensional feature vector for each image instead of 1000 class scores.

from torchvision.models import resnet50
import torch
from torchvision import transforms
from torch import nn


class ResnetEmbeding:
    pretrained_model = 'D:/models/resnet50-0676ba61.pth'

    def __init__(self):
        self.model = resnet50()
        self.model.load_state_dict(torch.load(self.pretrained_model, map_location='cpu'))

        # delete fc layer: the model now returns the 2048-d pooled feature
        self.model.fc = nn.Identity()
        # inference mode: BatchNorm uses its stored running statistics
        self.model.eval()
        # ImageNet mean/std, the normalization torchvision's ResNet weights were trained with
        self.transform = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])

    @torch.no_grad()
    def embeding(self, image):
        trans_image = self.transform(image)
        trans_image = trans_image.unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
        return self.model(trans_image)          # shape (1, 2048)


resnet_embeding = ResnetEmbeding()

3.4.3 Data vectorization and loading

import os

from PIL import Image

from resnet_embeding import resnet_embeding               # encoder from 3.4.2
from milvus_operator import resnet_image, MilvusOperator  # project's Milvus wrapper (see complete code)

BATCH_SIZE = 50  # rows per insert call


def insert_batch(operator: MilvusOperator, idxs, embedings, paths):
    operator.insert_data([idxs, embedings, paths])
    print(f'success insert {operator.coll_name} items:{len(idxs)}')


def update_image_vector(data_path, operator: MilvusOperator):
    idxs, embedings, paths = [], [], []

    total_count = 0
    # one sub-directory per fruit class
    for dir_name in os.listdir(data_path):
        sub_dir = os.path.join(data_path, dir_name)
        for file in os.listdir(sub_dir):
            image_path = os.path.join(sub_dir, file)
            image = Image.open(image_path).convert('RGB')
            embeding = resnet_embeding.embeding(image)  # shape (1, 2048)

            idxs.append(total_count)  # primary key m_id
            embedings.append(embeding[0].detach().numpy().tolist())
            paths.append(image_path)
            total_count += 1

            # insert a full batch, then start a new one (avoids re-inserting rows)
            if len(idxs) == BATCH_SIZE:
                insert_batch(operator, idxs, embedings, paths)
                idxs, embedings, paths = [], [], []

    # insert the remaining rows (fewer than BATCH_SIZE)
    if idxs:
        insert_batch(operator, idxs, embedings, paths)

    print(f'finish update {operator.coll_name} items: {total_count}')


if __name__ == '__main__':
    data_dir = 'D:/dataset/fruit'
    update_image_vector(data_dir, resnet_image)
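The milvus_operator module (with MilvusOperator.search_data) belongs to the project's complete code and is not shown here. As an assumption about what such a search step might look like, the hypothetical helper search_similar below calls pymilvus's Collection.search on the embeding field with the L2 metric from 3.4.1 and returns the stored image paths. The nprobe value of 16 is an illustrative choice.

```python
# search parameters for the IVF_FLAT index; nprobe = number of cells scanned (assumed value)
SEARCH_PARAMS = {"metric_type": "L2", "params": {"nprobe": 16}}


def search_similar(collection, query_vector, top_k=16, expr=None):
    """Return [{'id', 'distance', 'path'}, ...] for the top_k nearest images.

    `collection` is a loaded pymilvus Collection; `expr` is an optional
    scalar filter expression (hybrid query), e.g. 'm_id < 1000'.
    """
    results = collection.search(
        data=[[float(x) for x in query_vector]],  # one query -> one hit list
        anns_field="embeding",
        param=SEARCH_PARAMS,
        limit=top_k,
        expr=expr,
        output_fields=["path"],
    )
    return [
        {"id": hit.id, "distance": hit.distance, "path": hit.entity.get("path")}
        for hit in results[0]
    ]
```

Usage: after connections.connect(...) and db.using_database("image_vector_db"), create collection = Collection("fruit_vector") and call collection.load(), because a collection must be loaded into memory before it can be searched. Then pass a ResNet embedding, e.g. search_similar(collection, resnet_embeding.embeding(image)[0]).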

3.4.4 Build the search web page

import argparse

import gradio as gr
from PIL import Image

from net_helper import net_helper              # project helper that resolves the host IP
from resnet_embeding import resnet_embeding    # encoder from 3.4.2
from milvus_operator import resnet_image       # project's Milvus wrapper

TOP_K = 16  # number of result slots on the page


def image_search(image):
    # Gradio needs one value per output component
    if image is None:
        return [None] * TOP_K

    image = image.convert("RGB")

    # ResNet encoding -> 2048-d vector
    input_embeding = resnet_embeding.embeding(image)
    input_embeding = input_embeding[0].detach().cpu().numpy()

    results = resnet_image.search_data(input_embeding)
    pil_images = [Image.open(result['path']) for result in results[:TOP_K]]
    # pad with empty slots when fewer than TOP_K results come back
    return pil_images + [None] * (TOP_K - len(pil_images))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true",
                        default=False, help="share gradio app")
    args = parser.parse_args()

    app = gr.Blocks(theme='default', title="image",
                    css=".gradio-container, .gradio-container button {background-color: #009FCC} "
                        "footer {visibility: hidden}")
    with app:
        with gr.Tabs():
            with gr.TabItem("image search"):
                with gr.Row():
                    with gr.Column():
                        image = gr.Image(type="pil")
                        btn = gr.Button("search")

                    with gr.Column():
                        with gr.Row():
                            output_images = [gr.Image(type="pil", show_label=False) for _ in range(TOP_K)]

                btn.click(image_search, inputs=[image], outputs=output_images, show_progress=True)

    ip_addr = net_helper.get_host_ip()
    # queue(concurrency_count=...) is the Gradio 3.x signature
    app.queue(concurrency_count=3).launch(show_api=False, share=args.share, server_name=ip_addr, server_port=9099)
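net_helper is a project module that is not shown in this article. A common way to implement a function like get_host_ip is the UDP-socket technique below; this is a hypothetical sketch, not the project's actual helper. Binding Gradio to the returned address (server_name) makes the page reachable from other machines on the LAN on port 9099.

```python
import socket


def get_host_ip() -> str:
    """Best-effort LAN IPv4 address of this machine (hypothetical net_helper stand-in)."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # connect() on a UDP socket sends no packets; it only makes the OS pick
        # the outgoing interface, whose address getsockname() then reports
        sock.connect(("8.8.8.8", 80))
        return sock.getsockname()[0]
    except OSError:
        return "127.0.0.1"  # no route available: fall back to loopback
    finally:
        sock.close()
```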

4 Summary

Building on two key technologies, a pre-trained ResNet model and the Milvus vector database, this project implements an image-to-image search system. During construction, the ResNet network was modified by removing its fully connected layer, so that each image is encoded as a 2048-dimensional vector, which is stored in the Milvus vector database. To keep retrieval efficient, a vector index is built in Milvus with a script. The project can serve as a reference for, and be reused directly in, real-world search-by-image projects.

Complete project code: code


Origin blog.csdn.net/lsb2002/article/details/132456845