前言
接下来,我们就来看看视频行人重识别训练模型的其中一种temporal aggregation method:temporal attention。
在这个模型中对sequence of image feature使用attention weighted average,给每一个切片c一个attention系数。
论文中叙述这是效果最好的一种方式,如C部分:
模型输入
- imgs
- imgs.size() = [b,s,c,h,w]
- 在训练级中 b为batch通常设置为32,seq_len设置为4,c为通道数为3,h图片高,w图片宽
模型初始化参数
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={
'xent', 'htri'})
- name 使用的模型名称
- dataset.num_train_pids 分类时的分类数
- loss xent=交叉熵损失 htri=Tripletloss
模型实现
class ResNet50TA(nn.Module):
def __init__(self, num_classes, loss={
'xent'}, **kwargs):
# 使用父类初始化方法
super(ResNet50TA, self).__init__()
# loss
self.loss = loss
# 使用resnet50
resnet50 = torchvision.models.resnet50(pretrained=True)
# 使用resnet50 除了最后两层
self.base = nn.Sequential(*list(resnet50.children())[:-2])
# attention network 1
self.att_gen = 'softmax' # method for attention generation: softmax or sigmoid
self.feat_dim = 2048 # feature dimension
self.middle_dim = 256 # middle layer dimension
self.classifier = nn.Linear(self.feat_dim, num_classes)
# 输入通道数2048,输出256
self.attention_conv = nn.Conv2d(self.feat_dim, self.middle_dim, [7,4]) # 7,4 cooresponds to 224, 112 input image size
self.attention_tconv = nn.Conv1d(self.middle_dim, 1, 3, padding=1)
# x=[b,s,c,w,h]
def forward(self, x):
b = x.size(0)
t = x.size(1)
x = x.view(b*t, x.size(2), x.size(3), x.size(4))
# x =[128,2048,7,4] resnet 得到每一帧的特征f
x = self.base(x)
# a = [128,256,1,1] 先经过卷基层,在经过relu函数 计算score
a = F.relu(self.attention_conv(x))
# a =[32,4,256] 重构张量维度
a = a.view(b, t, self.middle_dim)
# a= [32,256,4]
a = a.permute(0,2,1)
# a = [32,1,4] spatial + temporal conv
a = F.relu(self.attention_tconv(a))
# a[32,4]
a = a.view(b, t)
# x=[128,2048,1,1] 平均池化 这里得到的是每一帧的features
x = F.avg_pool2d(x, x.size()[2:])
# 两种方式计算最终的attention softmax sigmod
if self. att_gen=='softmax':
# print("softmax")
# a=[32,4] 归一化
a = F.softmax(a, dim=1)
elif self.att_gen=='sigmoid':
# print("sigmoid")
a = F.sigmoid(a)
a = F.normalize(a, p=1, dim=1)
else:
raise KeyError("Unsupported attention generation function: {}".format(self.att_gen))
# x= [32,4,2048] 重构feature [b,s,f]
x = x.view(b, t, -1)
# a = [32,4,1]
# 在最后一维上增加一维
a = torch.unsqueeze(a, -1)
# list.extend(list1) 参数必须是列表类型,可以将参数中的列表合并到原列表的末尾,使原来的 list长度增加len(list1)。
# a = [32, 4, 2048]
a = a.expand(b, t, self.feat_dim)
# 矩阵x和a对应位相乘,x和a的维度必须相等 计算得分
att_x = torch.mul(x,a)
# 对attr的第二维进行求和了?
att_x = torch.sum(att_x,1)
# f = [32,2048] [b,f]
f = att_x.view(b,self.feat_dim)
# 不是训练 输出特征f
if not self.training:
return f
# 全链接层
y = self.classifier(f)
# 根据loss 返回不同的参数
if self.loss == {
'xent'}:
return y
elif self.loss == {
'xent', 'htri'}:
return y, f
elif self.loss == {
'cent'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))