数据准备
前言
从这篇博客开始, 就要记录一下复现AlignedReID算法的代码,以及在复现过程中遇到的一些问题,作为后续自己实现模型的基础,阅读和理解别人的代码是非常有必要的。并且,在实现代码的过程,会尽量使得代码工程化。
那么接下来,就开始吧!
1. 认识Market1501数据集
在对行人重识别领域有过一定了解的情况下,应该都知道这个数据集了,代码的实现就以该数据集进行训练和测试。
1.1 数据集下载
数据集下载地址:Market1501
(也可以去之前博客中提到的那个网站去下载,不过那个网站链接好像经常打不开~)
1.2 目录介绍
数据集下载完成后,对文件夹进行解压之后可以进行一下重命名,为了方便我们之后代码中路径的添加,目录结构如下图所示(我把数据集存放到了新建的data文件夹下):
1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)
2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像
3) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box
4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)
5) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像
Note:现在主要使用1、2、5文件夹,弃用3,4。
1.3 命名规则
以 0001_c1s1_000151_01.jpg 为例
1) 0001 表示每个人的标签编号,从0001到1501;
2) c1 表示第一个摄像头(camera1),共有6个摄像头;
3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;
4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;
5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框
2. 数据集的加载
那么先新建工程项目文件目录,这里我使用的是pycharm + python3环境。工程目录文件如下:
在data_process文件夹下新建文件data_manager.py加载数据集:
#-*-coding:utf-8-*-
# 此文件用于加载数据集Market1501
"""
主要步骤:
1.拼接文件夹路径
2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
3.统计行人、图片总数
"""
from __future__ import print_function, absolute_import
import os.path as osp
import glob
import re
from IPython import embed
class Market1501(object):
"""
Market1501
Reference:
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
URL: http://www.liangzheng.org/Project/project_reid.html
Dataset statistics:
# identities: 1501 (+1 for background)
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
"""
# 数据集Market1501目录
dataset_dir = "market"
# 通过创建类对象完成对数据集的加载,因此把读取操作都放入 init方法
# 默认传入参数root 为数据集所在根目录
# 默认传入参数min_seq_len 为最小序列长度 默认值为0
# **kwargs可能会有其他参数
def __init__(self, root='/home/dmb/Desktop/materials/data', min_seq_len=0,**kwargs):
# 1.加载几个文件夹目录 拼接路径
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
# 检查是否成功加载
self._check_before_run()
# 调用目录处理方法
# 2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
# train: ('/home/dmb/Desktop/materials/data/market/bounding_box_train/0796_c3s2_089653_01.jpg', 420, 2)
train,num_train_pids,num_train_imgs = self._process_dir(self.train_dir,relabel=True)
query,num_query_pids,num_query_imgs = self._process_dir(self.query_dir,relabel=False)
gallery,num_gallery_pids,num_gallery_imgs = self._process_dir(self.gallery_dir,relabel=False)
# 3.统计行人、图片总数
num_total_pids = num_train_pids + num_query_pids
num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
# embed()
# 打印信息
print("=> Market1501 loaded")
print("Dataset statistics:")
print(" ------------------------------")
print(" subset | # ids | # images")
print(" ------------------------------")
print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
print(" ------------------------------")
print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
print(" ------------------------------")
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids = num_train_pids
self.num_query_pids = num_query_pids
self.num_gallery_pids = num_gallery_pids
def _check_before_run(self):
# 定义检验加载是否成功方法
if not osp.exists(self.dataset_dir):
raise RuntimeError("{} is not available".format(self.dataset_dir))
if not osp.exists(self.train_dir):
raise RuntimeError("{} is not available".format(self.train_dir))
if not osp.exists(self.query_dir):
raise RuntimeError("{} is not available".format(self.query_dir))
if not osp.exists(self.gallery_dir):
raise RuntimeError("{} is not available".format(self.gallery_dir))
def _process_dir(self,dir_path,relabel=False):
# 此函数返回一个符合glob匹配的pathname的list,返回结果有可能是空
# 2.1匹配该路径下所有以.jpg结尾的文件,放入list
img_paths = glob.glob(osp.join(dir_path,'*.jpg'))
# 正则表达式设置匹配规则 只提取行人id以及摄像头id
pattern = re.compile(r'([-\d]+)_c(\d)')
# 2.2实现relabel
# 原因是由于训练集只有751个行人,但标注是到1501,直接使用1501会使模型产生750个无效神经元
# set集合存放的行人ID 后面会用的到
# 使用set集合可以去重
pid_container = set()
# 遍历list集合中的图片名
for img_path in img_paths:
# 只关心每张图片的pid,其他值设置为缺省值
# map() 会根据提供的函数对指定序列做映射。
# 第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表
pid,_ = map(int,pattern.search(img_path).groups())
# 跳过所有pid为-1的项
if pid == -1:continue
# 添加pid到列表
pid_container.add(pid)
pid2label = {
pid:label for label,pid in enumerate(pid_container)}
# embed()
dataset = []
for img_path in img_paths:
pid,camid = map(int,pattern.search(img_path).groups())
if pid == -1:continue
assert 0 <= pid <=1501
assert 1 <= camid <= 6
camid -= 1
# 这里有个判断 只有relabel = True 我才relabel
if relabel :pid = pid2label[pid]
dataset.append((img_path,pid,camid))
num_pids = len(pid_container)
num_imgs = len(dataset)
# 返回值为dataset,图片id数量,图片数量
return dataset,num_pids,num_imgs
"""Create dataset"""
__img_factory = {
'market1501': Market1501,
# 'cuhk03': CUHK03,
# 'dukemtmcreid': DukeMTMCreID,
# 'msmt17': MSMT17,
}
# __vid_factory = {
# 'mars': Mars,
# 'ilidsvid': iLIDSVID,
# 'prid': PRID,
# 'dukemtmcvidreid': DukeMTMCVidReID,
# }
def get_names():
return __img_factory.keys()
def init_img_dataset(name, **kwargs):
if name not in __img_factory.keys():
raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, __img_factory.keys()))
return __img_factory[name](**kwargs)
# 验证
if __name__ == "__main__":
init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
结果:
3. 重构data_loader库
data_loader是pytorch比较重要的一个库,主要负责数据的吞吐,我们需要数据按照自己需要的方式进行吞吐,那么就需要在源码的基础上进行一定量的修改。
#-*-coding:utf-8-*-
from __future__ import print_function, absolute_import
from PIL import Image
import numpy as np
import os.path as osp
import torch
from torch.utils.data import Dataset
from IPython import embed
from AlignedReId.data_process import data_manager
# 设置图片读取方法
def read_image(image_path):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
# 标志位表示是否读取到图片
got_image = False
if not osp.exists(image_path):
raise IOError("{} is not exists".format(image_path))
# 没读到图片就一直读
while not got_image:
try:
# 把读到的图片转化为RGB格式
img = Image.open(image_path).convert('RGB')
got_image = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(image_path))
pass
return img
# 重写dataset类
class ImageDataset(Dataset):
"""Image Person ReID Dataset"""
def __init__(self,dataset,transform=None):
self.dataset = dataset
self.transform = transform
def __len__(self):
return len(self.dataset)
def __getitem__(self,index):
# 读取dataset的一行信息
img_path,pid,camid =self.dataset[index]
# 使用read_image读取图片
img = read_image(img_path)
# 判断是否进行数据增广
if self.transform is not None:
img = self.transform(img)
return img, pid, camid
# 验证
# if __name__ == "__main__":
# dataset =data_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
# train_loader = ImageDataset(dataset.train)
# for batch_id,(img,pid,camid) in enumerate(train_loader):
# break
# print(batch_id,img,pid,camid)
返回结果:
可以看到train_loader是dataset的一个实例,可以通过迭代器进行取值,得到的值分别为batch_id,图片,行人id,以及摄像头id。
(但这里需要注意的一点,我们想要的并不是图片本身,而是要把图片转换成一个tensor,后面会提到)
数据采样
在utils中新建sample.py文件,负责对训练集进行采样,每个epoch,我们会对训练集中每个行人采集4张图片,也就是751*4=3004张图片,代码如下:
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
"""
Randomly sample N identities, then for each identity,
randomly sample K instances, therefore batch size is N*K.
Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
Args:
data_source (Dataset): dataset to sample from.
num_instances (int): number of instances per identity.
"""
def __init__(self, data_source, num_instances=4):
self.data_source = data_source
self.num_instances = num_instances
self.index_dic = defaultdict(list)
for index, (_, pid, _) in enumerate(data_source):
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())
self.num_identities = len(self.pids)
def __iter__(self):
indices = torch.randperm(self.num_identities)
ret = []
for i in indices:
pid = self.pids[i]
t = self.index_dic[pid]
replace = False if len(t) >= self.num_instances else True
t = np.random.choice(t, size=self.num_instances, replace=replace)
ret.extend(t)
return iter(ret)
def __len__(self):
return self.num_identities * self.num_instances
if __name__ == "__main__":
dataset =dataset_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
train_loader = ImageDataset(dataset.train)
sample = RandomIdentitySampler(train_loader)
for ret in enumerate(sample):
print(ret)
打印ret结果如下:
括号内第一项表示为第n张图片,第二项表示为图片对应的标号。
5.数据预处理
这个地方还应该有关于数据增强的部分,但是pytorch提供了丰富的数据增强方法,这里就不自己写了,后面直接使用提供的数据增强方法。
如果后面需要自己设计数据增强方法,那就再记录把!