Video_based_ReID_RNN

前言

接下来,我们就来看看视频行人重识别训练模型的其中一种temporal aggregation method:RNN。
这是在序列模型训练中常用的一种模型,RNN可以提取到连续图像蕴含的信息,这里使用的是最简单的RNN结构。
目前这种方式的试验结果不如其他几种,如B部分:
在这里插入图片描述

模型输入

输入和之前的相同 差别只在经过的网络:

  • imgs
    • imgs.size() = [b,s,c,h,w]
    • 在训练级中 b为batch通常设置为32,seq_len设置为4,c为通道数为3,h图片高,w图片宽

模型初始化参数

        model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={
    
    'xent', 'htri'})
  • name 使用的模型名称
  • dataset.num_train_pids 分类时的分类数
  • loss xent=交叉熵损失 htri=Tripletloss

模型实现

class ResNet50RNN(nn.Module):
    def __init__(self, num_classes, loss={
    
    'xent'}, **kwargs):
        super(ResNet50RNN, self).__init__()
        self.loss = loss
        resnet50 = torchvision.models.resnet50(pretrained=True)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.hidden_dim = 512
        self.feat_dim = 2048
        self.classifier = nn.Linear(self.hidden_dim, num_classes)
        #                   输入特征维数2048              LSTM中隐层的维度              循环神经网络的层数
        #输入的数据shape=(batch_size,seq_length,embedding_dim),而batch_first默认是False,所以我们的输入数据最好送进LSTM之前将batch_size与seq_length这两个维度调换?
        self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim, num_layers=1, batch_first=True)

	# x = [32,4,3,224,112] [b,s,c,h,w]
    def forward(self, x):
        # b=32
        b = x.size(0)
        # t= 4
        t = x.size(1)
        # x = [128,3,224,112]
        x = x.view(b*t,x.size(2), x.size(3), x.size(4))
		# x = [128,2048,7,4]
        x = self.base(x)
		# [128,2048,1,1]
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(b,t,-1)
        # x = [32,2048,4]
        # 使用RNN直接获取特征
        # output = [32,4,512]?
        output, (h_n, c_n) = self.lstm(x)
        # output = [32,512,4]
        output = output.permute(0, 2, 1)
        # f = [32,512]
        f = F.avg_pool1d(output, t)
        f = f.view(b, self.hidden_dim)
        
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {
    
    'xent'}:
            return y
        elif self.loss == {
    
    'xent', 'htri'}:
            return y, f
        elif self.loss == {
    
    'cent'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))

猜你喜欢

转载自blog.csdn.net/qq_37747189/article/details/114746628
RNN
今日推荐