从零开始AlignedReID_01

数据准备

前言

从这篇博客开始, 就要记录一下复现AlignedReID算法的代码,以及在复现过程中遇到的一些问题,作为后续自己实现模型的基础,阅读和理解别人的代码是非常有必要的。并且,在实现代码的过程,会尽量使得代码工程化。

那么接下来,就开始吧!

1. 认识Market1501数据集

在对行人重识别领域有过一定了解的情况下,应该都知道这个数据集了,代码的实现就以该数据集进行训练和测试。

1.1 数据集下载

数据集下载地址:Market1501
(也可以去之前博客中提到的那个网站去下载,不过那个网站链接好像经常打不开~)

1.2 目录介绍

数据集下载完成后,对文件夹进行解压之后可以进行一下重命名,为了方便我们之后代码中路径的添加,目录结构如下图所示(我把数据集存放到了新建的data文件夹下):
在这里插入图片描述
1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)
2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像
3) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box
4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)
5) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像

Note:现在主要使用1、2、5文件夹,弃用3,4。

1.3 命名规则

以 0001_c1s1_000151_01.jpg 为例
1) 0001 表示每个人的标签编号,从0001到1501;
2) c1 表示第一个摄像头(camera1),共有6个摄像头;
3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;
4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;
5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框

2. 数据集的加载

那么先新建工程项目文件目录,这里我使用的是pycharm + python3环境。工程目录文件如下:
在这里插入图片描述
在data_process文件夹下新建文件data_manager.py加载数据集:

#-*-coding:utf-8-*-
# 此文件用于加载数据集Market1501
"""
主要步骤:
1.拼接文件夹路径
2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
3.统计行人、图片总数

"""
from __future__ import print_function, absolute_import
import os.path as osp
import glob
import re

from IPython import embed


class Market1501(object):
    """
    Market1501

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    # 数据集Market1501目录
    dataset_dir = "market"

    # 通过创建类对象完成对数据集的加载,因此把读取操作都放入 init方法
    # 默认传入参数root 为数据集所在根目录
    # 默认传入参数min_seq_len 为最小序列长度 默认值为0
    # **kwargs可能会有其他参数
    def __init__(self, root='/home/dmb/Desktop/materials/data', min_seq_len=0,**kwargs):
        # 1.加载几个文件夹目录 拼接路径
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        # 检查是否成功加载
        self._check_before_run()

        # 调用目录处理方法
        # 2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
        # train: ('/home/dmb/Desktop/materials/data/market/bounding_box_train/0796_c3s2_089653_01.jpg', 420, 2)
        train,num_train_pids,num_train_imgs = self._process_dir(self.train_dir,relabel=True)
        query,num_query_pids,num_query_imgs = self._process_dir(self.query_dir,relabel=False)
        gallery,num_gallery_pids,num_gallery_imgs = self._process_dir(self.gallery_dir,relabel=False)
        # 3.统计行人、图片总数
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
        # embed()
        # 打印信息
        print("=> Market1501 loaded")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # ids | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print("  ------------------------------")
        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
        print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        # 定义检验加载是否成功方法
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("{} is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("{} is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("{} is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("{} is not available".format(self.gallery_dir))


    def _process_dir(self,dir_path,relabel=False):

        # 此函数返回一个符合glob匹配的pathname的list,返回结果有可能是空
        # 2.1匹配该路径下所有以.jpg结尾的文件,放入list
        img_paths = glob.glob(osp.join(dir_path,'*.jpg'))
        # 正则表达式设置匹配规则 只提取行人id以及摄像头id
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # 2.2实现relabel
        # 原因是由于训练集只有751个行人,但标注是到1501,直接使用1501会使模型产生750个无效神经元
        # set集合存放的行人ID 后面会用的到
        # 使用set集合可以去重
        pid_container = set()
        # 遍历list集合中的图片名
        for img_path in img_paths:
            # 只关心每张图片的pid,其他值设置为缺省值
            # map() 会根据提供的函数对指定序列做映射。
            # 第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表
            pid,_ = map(int,pattern.search(img_path).groups())

            # 跳过所有pid为-1的项
            if pid == -1:continue
            # 添加pid到列表
            pid_container.add(pid)
        pid2label = {
    
    pid:label for label,pid in enumerate(pid_container)}
        # embed()

        dataset = []
        for img_path in img_paths:
            pid,camid = map(int,pattern.search(img_path).groups())
            if pid == -1:continue
            assert 0 <= pid <=1501
            assert 1 <= camid <= 6
            camid -= 1
            # 这里有个判断 只有relabel = True 我才relabel
            if relabel :pid = pid2label[pid]
            dataset.append((img_path,pid,camid))

        num_pids = len(pid_container)
        num_imgs = len(dataset)
        # 返回值为dataset,图片id数量,图片数量
        return dataset,num_pids,num_imgs


"""Create dataset"""

__img_factory = {
    
    
    'market1501': Market1501,
    # 'cuhk03': CUHK03,
    # 'dukemtmcreid': DukeMTMCreID,
    # 'msmt17': MSMT17,
}

# __vid_factory = {
    
    
#     'mars': Mars,
#     'ilidsvid': iLIDSVID,
#     'prid': PRID,
#     'dukemtmcvidreid': DukeMTMCVidReID,
# }

def get_names():
    return __img_factory.keys()

def init_img_dataset(name, **kwargs):
    if name not in __img_factory.keys():
        raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, __img_factory.keys()))
    return __img_factory[name](**kwargs)

# 验证
if __name__ == "__main__":
    init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")

结果:
在这里插入图片描述

3. 重构data_loader库

data_loader是pytorch比较重要的一个库,主要负责数据的吞吐,我们需要数据按照自己需要的方式进行吞吐,那么就需要在源码的基础上进行一定量的修改。

#-*-coding:utf-8-*-
from __future__ import print_function, absolute_import
from PIL import Image
import numpy as np
import os.path as osp

import torch
from torch.utils.data import Dataset
from IPython import embed
from AlignedReId.data_process import data_manager

# 设置图片读取方法
def read_image(image_path):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    # 标志位表示是否读取到图片
    got_image = False
    if not osp.exists(image_path):
        raise IOError("{} is not exists".format(image_path))
    # 没读到图片就一直读
    while not got_image:
        try:
            # 把读到的图片转化为RGB格式
            img = Image.open(image_path).convert('RGB')
            got_image = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(image_path))
            pass
        return img


# 重写dataset类
class ImageDataset(Dataset):
    """Image Person ReID Dataset"""
    def __init__(self,dataset,transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,index):
        # 读取dataset的一行信息
        img_path,pid,camid =self.dataset[index]
        # 使用read_image读取图片
        img = read_image(img_path)
        # 判断是否进行数据增广
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, camid



# 验证
# if __name__ == "__main__":
#     dataset =data_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
#     train_loader = ImageDataset(dataset.train)
#     for batch_id,(img,pid,camid) in enumerate(train_loader):
#         break
#     print(batch_id,img,pid,camid)


返回结果:
在这里插入图片描述可以看到train_loader是dataset的一个实例,可以通过迭代器进行取值,得到的值分别为batch_id,图片,行人id,以及摄像头id。
(但这里需要注意的一点,我们想要的并不是图片本身,而是要把图片转换成一个tensor,后面会提到)

数据采样

在utils中新建sample.py文件,负责对训练集进行采样,每个epoch,我们会对训练集中每个行人采集4张图片,也就是751*4=3004张图片,代码如下:

from __future__ import absolute_import
from collections import defaultdict
import numpy as np

import torch
from torch.utils.data.sampler import Sampler

class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Args:
        data_source (Dataset): dataset to sample from.
        num_instances (int): number of instances per identity.
    """
    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        indices = torch.randperm(self.num_identities)
        ret = []
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            replace = False if len(t) >= self.num_instances else True
            t = np.random.choice(t, size=self.num_instances, replace=replace)
            ret.extend(t)
        return iter(ret)

    def __len__(self):
        return self.num_identities * self.num_instances

if __name__ == "__main__":
    dataset =dataset_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
    train_loader = ImageDataset(dataset.train)
    sample = RandomIdentitySampler(train_loader)
    for ret in enumerate(sample):
        print(ret)

打印ret结果如下:
在这里插入图片描述括号内第一项表示为第n张图片,第二项表示为图片对应的标号。

5.数据预处理

这个地方还应该有关于数据增强的部分,但是pytorch提供了丰富的数据增强方法,这里就不自己写了,后面直接使用提供的数据增强方法。
如果后面需要自己设计数据增强方法,那就再记录把!

猜你喜欢

转载自blog.csdn.net/qq_37747189/article/details/111867247
今日推荐