SSD原理及Pytorch代码解读——网络架构(二):特征提取网络及总体计算过程

特征提取网络

前面我们已经知道了SSD采用PriorBox机制,也知道了SSD多层特征图来做物体检测,浅层的特征图检测小物体,深层的特征图检测大物体。上一篇博客也看到了SSD是如何在VGG基础的网络结构上进行一下改进。但现在的问题是SSD是使用哪些卷积层输出的特征图来做目标检测的?如下图所示:
在这里插入图片描述

从上图中可以看到,SSD使用了第4、7、8、9、10、11层的这6个卷积层输出作为特征图来做目标检测,但是这些特征图通道大小不一且数量很大,所以SSD在每一个特征图后面都接上了一个分类与位置卷积层使得输出的通道数符合要求。还有也可以从上图看出这6个特征图尺寸越来越小,而其对应的感受野越来越大。6个特征图上的每一个点分别对应4、6、6、6、4、4个PriorBox。接下来分别利用3×3的卷积,即可得到每一个PriorBox对应的类别与位置预测量。
举个例子,第8个卷积层得到的特征图大小为10×10×512,每个点对应6个PriorBox,一共有600个PriorBox。由于采用的PASCAL VOC数据集的物体类别为21类,因此3×3卷积后得到的类别特征维度为6×21=126,位置特征维度为6×4=24。

源码

代码文件为ssd.py。

# 每个特征图上一个点对应的PriorBox数量
mbox = [4, 6, 6, 6, 4, 4]

def multibox(vgg, extra_layers, cfg, num_classes):
	"""
	建立特征提取网络
	parameter:
		vgg: 基础VGG结构层列表,type:list
		extra_layers: 深度卷积层列表,type:list
		cfg:# 每个特征图上一个点对应的PriorBox数量
		num_classes:类别数量
	return:
		vgg: 基础VGG结构层列表,type:list
		extra_layers: 深度卷积层列表,type:list
		(loc_layers, conf_layers):元组,分别是每一个特征图上的回归层输出列表和分类层输出列表
	"""
    loc_layers = []		# 回归层输出
    conf_layers = []	# 分类层输出
    vgg_source = [21, -2]
    # 取第4、7卷积层输出并接上3×3的卷积
    for k, v in enumerate(vgg_source):
        loc_layers += [nn.Conv2d(vgg[v].out_channels,
                                 cfg[k] * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(vgg[v].out_channels,
                        cfg[k] * num_classes, kernel_size=3, padding=1)]
                        
	# 取第8、9、10、11卷积层输出并接上3×3的卷积
    for k, v in enumerate(extra_layers[1::2], 2):
        loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
                                 * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
                                  * num_classes, kernel_size=3, padding=1)]
    return vgg, extra_layers, (loc_layers, conf_layers)

总体网络计算过程

为了更好地梳理网络的前向过程,将从代码角度讲述SSD网络的整个前向过程。

训练阶段

class SSD(nn.Module):
    """Single Shot Multibox Architecture
    The network is composed of a base VGG network followed by the
    added multibox conv layers.  Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        phase: (string) 模型所处阶段,为"test"或者"train"
        size: 输入图像大小
        base: 基础VGG16结构层列表,输入尺寸为300或者500,type:list
        extras: 深度卷积层列表,type:list
        head: "multibox head" 元组,分别是每一个特征图上的回归层输出列表和分类层输出列表
        num_classes:类别数量
    """

    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.cfg = voc	# voc为配置信息,用于生成prior box
        self.priorbox = PriorBox(self.cfg)	
        #import pdb
        #pdb.set_trace()
        self.priors = self.priorbox.forward()	# 生成每个特征图上的prior box
        self.size = size

        # SSD network
        self.vgg = nn.ModuleList(base)	# 生成基础VGG结构网络
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.extras = nn.ModuleList(extras)	# 生成深度卷积层结构网络

        self.loc = nn.ModuleList(head[0])	# 生成回归网络结构
        self.conf = nn.ModuleList(head[1])	# 生成分类网络结构

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):
        """.
		SSD前向传播过程
        Args:
            x: 批量数据. Shape: [batch,3,300,300].

        Return:
            取决于不同阶段:
            test:
               输出预测结果,包括置信度和相应位置预测. Shape: [batch,num_classes*topk,5]

            train:
                结果列表组成:
                    1: 回归网络输出, Shape: [batch,num_priors,4]
                    2:分类网络输出, Shape: [batch,num_priors,num_classes]
                    3: prior box, Shape: [num_priors,4]
        """
        # sources保存特征图,loc与conf保存所有PriorBox的位置与类别预测特征
        sources = list()
        loc = list()
        conf = list()

        # 对输入图像卷积到conv4_3,将特征添加到sources中
        for k in range(23):
            x = self.vgg[k](x)

        s = self.L2Norm(x)
        sources.append(s)

        # 继续卷积到conv7,将特征添加到sources中
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)

        # 继续利用额外的卷积层计算,并将特征添加到sources中
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # 对sources中的特征图利用类别与位置网络进行卷积计算,并保存到loc与conf中
        # 列表元素的尺寸为【batch, f_h, f_w, priors, f_h, f_w是该特征图的高和宽,priors是该特征图上每一个点对应的priorbox数量
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())
		
		# 合并多层特征并修改尺寸,num_priors为priors box总数量,如果输入图像为300*300,那么就一共又8732个priors box
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)	# shape[batch, num_priors*4]
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)	# shape[batch, num_priors*num_classes]
        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),                   # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                             self.num_classes)),                # conf preds
                self.priors.type(type(x.data))                  # default boxes
            )
        else:
            # 对于训练来说,output包括了loc与conf的预测值以及PriorBox的信息
            output = (
                loc.view(loc.size(0), -1, 4),
                conf.view(conf.size(0), -1, self.num_classes),
                self.priors
            )
        return output

预测阶段

class Detect(Function):
    """
    在预测阶段,构建Detect模块(即SSD的最后一层)。
    对边框位置坐标解码,并基于分类置信度对边框位置进行NMS。
    最后选择top_k数量的分类和边框。
    """
    def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
        self.num_classes = num_classes	# 类别数
        self.background_label = bkg_label	# 背景类别,数值为0
        self.top_k = top_k	# 筛选数量,默认200
        # Parameters used in nms.
        self.nms_thresh = nms_thresh	# 非最大化抑制阈值
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be non negative.')
        self.conf_thresh = conf_thresh	# 分类阈值
        self.variance = cfg['variance']		

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) 回归网络层输出值
                Shape: [batch,num_priors,4]
            conf_data: (tensor) 分类网络输出值
                Shape: [batch*num_priors,num_classes]
            prior_data: (tensor) Prior boxes
                Shape: [num_priors,4]
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)	# prior box数量,共8732
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # 将边框预测偏移值解码成预测边框(以左上角和右下角坐标表示,因为标签也是这种表示形式).
        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)	# 解码函数,shape:[num_priors,4]
            # 对每一个类别使用nms
            conf_scores = conf_preds[i].clone()
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)	# 获得大于分类阈值的索引掩码
                scores = conf_scores[cl][c_mask]	# 取出对应的分类置信度
                if scores.dim() == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)	# shape:[num_priors,4]
                boxes = decoded_boxes[l_mask].view(-1, 4)	# 取出对应预测框
                # 使用nms进行筛选,并选出top_k个
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                output[i, cl, :count] = \
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]]), 1)
        flt = output.contiguous().view(num, -1, 5)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)
        flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
        return output

猜你喜欢

转载自blog.csdn.net/weixin_41693877/article/details/107580985
今日推荐