找出与给定图片最相似的三张图片

参考文献:

pytorch以图搜图作业__-周-_的博客-CSDN博客_以图搜图算法pytorch
pytorch加载自己的图片数据集的两种方法__-周-_的博客-CSDN博客_pytorch读取图片数据集

pytorch对网络层的增,删, 改, 修改预训练模型结构__-周-_的博客-CSDN博客_pytorch修改网络结构

1.网络的修改

1)保留 VGG16 的特征提取网络,去除全连接层和 avgpool 层,并把最后一个卷积层(features[28])的输出通道数改为 1。注意:这个新卷积层是随机初始化的,没有预训练权重。

# Build a feature extractor from ImageNet-pretrained VGG16: keep the
# convolutional trunk (net.features) and neutralise the pooling/FC head.
net = models.vgg16(pretrained=True)

# An empty nn.Sequential returns its input unchanged, so these assignments
# effectively remove avgpool and the fully-connected classifier.
net.avgpool = nn.Sequential()
net.classifier = nn.Sequential()

# Swap features[28] for a 3x3 conv with 512 input and 1 output channel.
# This replacement layer is freshly (randomly) initialised, not pretrained.
net.features[28] = nn.Conv2d(in_channels=512, out_channels=1,
                             kernel_size=3, stride=1, padding=1)

2.数据集的加载

1)定义一个函数,把文件夹中的文件名逐行写入 txt 文件

def mak_txt(root, file_name):
    """Write the names of the entries in ``root/file_name`` to ``f.txt`` there.

    One name per line. Names are sorted because ``os.listdir`` returns them in
    arbitrary order; the line order becomes the dataset index order. An existing
    ``f.txt`` is excluded from its own listing, so the function can be re-run.

    Args:
        root: Parent directory.
        file_name: Name of the sub-folder to index.
    """
    folder = os.path.join(root, file_name)
    names = sorted(name for name in os.listdir(folder) if name != 'f.txt')
    # os.path.join instead of a literal '\\' so this also works off Windows;
    # `with` guarantees the file is closed even if a write fails.
    with open(os.path.join(folder, 'f.txt'), 'w') as f:
        f.writelines(name + '\n' for name in names)

2)调用 mak_txt 函数,为两个图片文件夹分别生成文件列表 f.txt

# Root directory that contains both image folders.
path = r'D:\AI\images_retreve'
# Derive the folder paths from `path` so they always point at the same
# folders that mak_txt() writes f.txt into (a separately typed literal can
# drift, e.g. a misspelled root directory).
image_packages = os.path.join(path, 'image_packages')
inputs_images = os.path.join(path, 'inputs_images')
mak_txt(path, 'image_packages')
mak_txt(path, 'inputs_images')

3)图片预处理


# Image preprocessing: resize the shorter side to 256, center-crop 224x224,
# convert to a float tensor, then normalize with the ImageNet per-channel
# mean/std that torchvision's pretrained VGG16 weights were trained with.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

4)定义mydataset

class MyDataset(Dataset):
    """Image dataset whose file list is read from ``img_path/f.txt``.

    ``f.txt`` holds one file name per line (relative to ``img_path``). Items
    are returned in that order: each is the image opened as RGB and, when
    ``transform`` is given, passed through it.
    """

    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.img_path = img_path
        # os.path.join rather than a literal '\f.txt' so the path works on any OS.
        self.txt_root = os.path.join(img_path, 'f.txt')
        with open(self.txt_root, 'r') as f:
            # Keep the whole stripped line so names containing spaces survive,
            # and skip blank lines instead of failing on them.
            names = [name for name in (line.strip() for line in f) if name]

        # Full paths to the images, in f.txt order.
        self.img = [os.path.join(self.img_path, name) for name in names]
        self.transform = transform

    def __len__(self):
        """Return the number of image paths listed in f.txt."""
        return len(self.img)

    def __getitem__(self, item):
        """Load image ``item`` as RGB and apply ``transform`` if one was given."""
        img = self.img[item]

        img = Image.open(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

5)加载数据集

扫描二维码关注公众号,回复: 14715465 查看本文章
# Query images and the gallery ("packages") to search, both preprocessed
# with the same transform.
dataset_inputs = MyDataset(inputs_images, transform=transform)
dataset_packages = MyDataset(image_packages, transform=transform)

# shuffle=False keeps batches in f.txt order so positions map back to names.
data_loader_inputs = DataLoader(dataset_inputs, batch_size=1, shuffle=False)
data_loader_packages = DataLoader(dataset_packages, batch_size=100, shuffle=False)

3.计算相似性

1)开始输入数据

# Extract features for every query image and every gallery image.
# Concatenating all batches means the search covers the whole gallery and
# row k of output_packages corresponds to line k of image_packages/f.txt.
# torch.no_grad(): inference only, so no autograd graph is kept in memory.
with torch.no_grad():
    output_inputs = torch.cat([net(data) for data in data_loader_inputs])
    output_packages = torch.cat([net(data) for data in data_loader_packages])

# The distance/ranking steps that follow compare a single query against the
# gallery; fail loudly instead of silently using just one of several queries.
if output_inputs.shape[0] != 1:
    raise ValueError(
        f'expected exactly one query image, got {output_inputs.shape[0]}')

print(output_inputs.shape)
print(output_packages.shape)
2)调用 torch.nn.functional(F)中的欧氏距离函数 pairwise_distance
# Euclidean (p=2) distance between the query features and each gallery row.
# Assumes output_inputs holds a single row that broadcasts against
# output_packages, giving one distance per gallery image.
dist2 = F.pairwise_distance(x1=output_inputs, x2=output_packages, p=2)
print(dist2.shape)

4.输出最相似的三张图片

1)输出最相似的三张图片的索引
# Indices (into the gallery, i.e. line numbers of image_packages/f.txt) of the
# three images closest to the query, nearest first. torch.topk replaces the
# old "overwrite the minimum with a sentinel" loop, which mutated dist2 and
# could re-select an index whenever real distances exceeded the sentinel.
# (The name max_list is kept for downstream code; these are the SMALLEST distances.)
if dist2.numel() < 3:
    raise ValueError(f'need at least 3 gallery images, got {dist2.numel()}')
max_list = torch.topk(dist2, k=3, largest=False).indices.tolist()
print(max_list)

2)根据索引找到原图片

# Map the ranked indices back to file paths using the same f.txt that built
# the gallery dataset, so index k refers to line k of that file.
path_dir = os.path.join(image_packages, 'f.txt')
with open(path_dir, 'r') as f:
    # Strip trailing newlines and skip blank lines (they hold no file name).
    names = [name for name in (line.strip() for line in f) if name]
data_img = [os.path.join(image_packages, names[idx]) for idx in max_list]

3)创建画布,将图片放在画布上展示出来

# Show the retrieved images stacked vertically: one row per image, in
# ranking order (closest match at the top).
fig = plt.figure(figsize=(10, 10))
n_rows = len(data_img)
for row, img_file in enumerate(data_img, start=1):
    ax = fig.add_subplot(n_rows, 1, row)
    # strip() guards against a trailing newline left on the path.
    img = Image.open(img_file.strip())
    ax.imshow(img)
plt.show()

猜你喜欢

转载自blog.csdn.net/weixin_52950958/article/details/125781318