代码参考:https://github.com/heshuting555/TransReID
论文参考:https://arxiv.org/abs/2102.04378
1.环境
ubuntu16.04
python3.6
cuda110
torch==1.7.0+cu110
torchvision==0.8.0
timm
yacs
opencv-python==4.1.0.25
2.模型准备
https://www.kaggle.com/abhinand05/vit-base-models-pretrained-pytorch
3.修改
将下面路径修改为你自己的路径:
1)预训练模型路径
2)数据路径
3)模型保存路径
修改configs/Market/vit_transreid.yml文件,都改为自己的(最后一个是因为模型保存在TransReID上一级目录中):
MODEL:
PRETRAIN_PATH: '/home/***/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth'
DATASETS:
NAMES: ('market1501')
ROOT_DIR: ('/root/datasets/')
TEST:
EVAL: True
IMS_PER_BATCH: 256
RE_RANKING: False
WEIGHT: './logs/market_vit_transreid/transformer_120.pth'
OUTPUT_DIR: './logs/market_vit_transreid'
4.训练
python train.py --config_file configs/Market/vit_transreid.yml MODEL.DEVICE_ID "('0')"