mmdetection generates anchors from three settings: anchor_scales, anchor_ratios, and anchor_strides.
anchor_scales is computed from scales_per_octave and octave_base_scale. The computation is in: mmdetection/mmdetection/mmdet/models/anchor_heads/retina_head.py
octave_scales = np.array([2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
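As a quick check of the formula above, the sketch below evaluates it for scales_per_octave = 3 and octave_base_scale = 4. Both values are assumptions chosen for illustration, not read from any config file, and plain Python lists stand in for np.array so the snippet has no dependencies.

```python
# Illustrative values only; they are assumptions, not read from a config.
scales_per_octave = 3
octave_base_scale = 4

# Same formula as above, written with plain Python lists instead of np.array.
octave_scales = [2 ** (i / scales_per_octave) for i in range(scales_per_octave)]
anchor_scales = [s * octave_base_scale for s in octave_scales]

print([round(s, 2) for s in anchor_scales])  # [4.0, 5.04, 6.35]
```

With these values, every level gets three scales spaced evenly in log space between 4 and 8 times its stride.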
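anchor_strides is determined by the stride of each feature-map level, so the parameter we can tune is anchor_scales.
Related code:
mmdet/models/anchor_heads/anchor_head.py
mmdet/core/anchor/anchor_generator.py
These compute and generate the anchors on the feature map of every level.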
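(1) What the network ultimately needs is an N*4 tensor of anchors. N is the number of anchors (for one feature map, the number of cells times the number of anchors per cell), and the 4 values are the top-left and bottom-right coordinates of each anchor in the original image. These coordinates are then used to crop features from the corresponding feature map.
To get the top-left and bottom-right coordinates for every cell of a feature_map (get_anchors), we first compute those of the anchors for cell(0,0) (gen_base_anchors), then shift them by each cell's offset to cover all cells of the feature_map.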
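The shifting step can be sketched in plain Python as follows. The helper name shift_anchors is hypothetical and is not part of mmdetection; the single base anchor, the 2x2 feature map, and stride 16 are taken from the docstring example of AnchorGenerator later in this post.

```python
def shift_anchors(base_anchors, featmap_size, stride):
    """Tile base anchors (x1, y1, x2, y2) over every cell of a feature map.

    Hypothetical helper for illustration; cells are visited row by row,
    so the anchors of cell (0, 0) come first, then (0, 1), (0, 2), ...
    """
    feat_h, feat_w = featmap_size
    all_anchors = []
    for row in range(feat_h):        # y direction
        for col in range(feat_w):    # x direction, varies fastest
            dx, dy = col * stride, row * stride
            for x1, y1, x2, y2 in base_anchors:
                all_anchors.append((x1 + dx, y1 + dy, x2 + dx, y2 + dy))
    return all_anchors


base = [(0.0, 0.0, 8.0, 8.0)]  # anchor of cell (0, 0)
anchors = shift_anchors(base, featmap_size=(2, 2), stride=16)
for a in anchors:
    print(a)
# (0.0, 0.0, 8.0, 8.0)
# (16.0, 0.0, 24.0, 8.0)
# (0.0, 16.0, 8.0, 24.0)
# (16.0, 16.0, 24.0, 24.0)
```

The real implementation does the same thing with tensor broadcasting rather than loops.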
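(2) To get the top-left and bottom-right coordinates of the anchors for cell(0,0), we first compute their center and their width and height.
Suppose the feature_map has stride=8 relative to the original image (8x downsampling). Then cell(0,0) covers the region [0:8,0:8] of the original image, i.e. pixels 0 through 7, so its center is (3.5,3.5).
The width and height of the anchors are determined by the stride of the feature_map (used as anchor_strides) together with the specified anchor_scales and anchor_ratios: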
w = anchor_stride
h = anchor_stride
h_ratios = torch.sqrt(anchor_ratios)
w_ratios = 1 / h_ratios
ws = (w * w_ratios* anchor_scales)
hs = (h * h_ratios * anchor_scales)
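To make the numbers concrete, here is a small sketch that applies these formulas to one anchor and converts the result to corner coordinates the same way gen_base_anchors does in the AnchorGenerator listing further down. The function name base_anchor is hypothetical; the inputs stride 8, scale 8, and ratio 0.5 are the ones used in the __main__ block of that listing.

```python
import math


def base_anchor(stride, scale, ratio):
    """Return (x1, y1, x2, y2) of one anchor for cell (0, 0), before rounding.

    Hypothetical helper; it mirrors the w/h formulas above for a single
    scale and a single ratio.
    """
    w = h = stride
    x_ctr = y_ctr = 0.5 * (stride - 1)   # (3.5, 3.5) when stride is 8
    h_ratio = math.sqrt(ratio)
    w_ratio = 1 / h_ratio
    ws = w * w_ratio * scale
    hs = h * h_ratio * scale
    return (x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
            x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1))


box = base_anchor(stride=8, scale=8, ratio=0.5)
print([round(v) for v in box])  # [-41, -19, 48, 26]
```

Note that ratio is h/w, so ratio 0.5 gives an anchor that is about twice as wide as it is tall (roughly 90.5 x 45.3 here).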
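(3) Therefore, to use anchors obtained from our own clustering, first decide from each anchor's width and height which feature_map level it belongs to, then solve for the corresponding anchor_scales and anchor_ratios.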
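Inverting the formulas in (2) gives ratio = h / w and scale = sqrt(w * h) / stride. The sketch below applies this to one clustered anchor. How to pick the level is not specified in these notes, so the rule used here is purely an assumption: choose the stride whose resulting scale is closest (in log space) to a target scale of 4. The function name, the stride list, and the target scale are all hypothetical.

```python
import math


def anchor_to_params(w, h, strides=(8, 16, 32, 64, 128), target_scale=4.0):
    """Map a clustered anchor (w, h) to (stride, scale, ratio).

    Assumption: the level is the stride that brings the scale closest to
    `target_scale` in log space. This selection rule is illustrative only.
    """
    ratio = h / w                  # from hs / ws = h_ratio / w_ratio = ratio
    size = math.sqrt(w * h)        # from ws * hs = (stride * scale) ** 2
    stride = min(strides,
                 key=lambda s: abs(math.log(size / s) - math.log(target_scale)))
    scale = size / stride
    return stride, scale, ratio


stride, scale, ratio = anchor_to_params(90.0, 45.0)
print(stride, round(scale, 3), ratio)  # 16 3.977 0.5

# Round trip: plugging the result back into the formulas recovers (w, h).
ws = stride * scale / math.sqrt(ratio)
hs = stride * scale * math.sqrt(ratio)
```

Any other level rule works the same way: once the stride is fixed, scale and ratio follow directly from w and h.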
import torch


class AnchorGenerator(object):
    """
    Examples:
        >>> from mmdet.core import AnchorGenerator
        >>> self = AnchorGenerator(9, [1.], [1.])
        >>> all_anchors = self.grid_anchors((2, 2), device='cpu')
        >>> print(all_anchors)
        tensor([[ 0.,  0.,  8.,  8.],
                [16.,  0., 24.,  8.],
                [ 0., 16.,  8., 24.],
                [16., 16., 24., 24.]])
    """

    def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
        self.base_size = base_size
        self.scales = torch.Tensor(scales)
        self.ratios = torch.Tensor(ratios)
        self.scale_major = scale_major
        self.ctr = ctr
        self.base_anchors = self.gen_base_anchors()

    @property
    def num_base_anchors(self):
        return self.base_anchors.size(0)

    def gen_base_anchors(self):
        w = self.base_size
        h = self.base_size
        if self.ctr is None:
            x_ctr = 0.5 * (w - 1)
            y_ctr = 0.5 * (h - 1)
        else:
            x_ctr, y_ctr = self.ctr

        h_ratios = torch.sqrt(self.ratios)
        w_ratios = 1 / h_ratios
        if self.scale_major:
            ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)
            hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
        else:
            ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
            hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)

        # yapf: disable
        base_anchors = torch.stack(
            [
                x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
                x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
            ],
            dim=-1).round()
        # yapf: enable

        return base_anchors

    def _meshgrid(self, x, y, row_major=True):
        xx = x.repeat(len(y))
        yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
        if row_major:
            return xx, yy
        else:
            return yy, xx

    def grid_anchors(self, featmap_size, stride=16, device='cuda'):
        base_anchors = self.base_anchors.to(device)

        feat_h, feat_w = featmap_size
        shift_x = torch.arange(0, feat_w, device=device) * stride
        shift_y = torch.arange(0, feat_h, device=device) * stride
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        shifts = shifts.type_as(base_anchors)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)

        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
        # first A rows correspond to A anchors of (0, 0) in feature map,
        # then (0, 1), (0, 2), ...
        return all_anchors

    def valid_flags(self, featmap_size, valid_size, device='cuda'):
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        valid = valid[:, None].expand(
            valid.size(0), self.num_base_anchors).contiguous().view(-1)
        return valid


if __name__ == '__main__':
    self = AnchorGenerator(8, [8], [0.5])
    all_anchors = self.grid_anchors((5, 5), device='cpu')
    print(all_anchors)
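Another way to handle anchors:
First generate our own anchors by clustering, then specify those clustered anchors directly on the feature map of every level. This avoids the previous approach's step of working out the actual size of each level's anchors from anchor_stride, anchor_ratio, and so on.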
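A minimal sketch of this second approach, under stated assumptions: the clustered (w, h) pairs below are made-up placeholders rather than real clustering output, the helper anchors_from_sizes is hypothetical and not an mmdetection API, and the cell-center and corner conventions are copied from gen_base_anchors above.

```python
def anchors_from_sizes(sizes, featmap_size, stride):
    """Place every (w, h) in `sizes` at the center of every feature-map cell.

    Hypothetical helper: the anchor sizes are used as-is, so no
    anchor_scales or anchor_ratios are needed.
    """
    feat_h, feat_w = featmap_size
    ctr = 0.5 * (stride - 1)   # center of cell (0, 0), as in gen_base_anchors
    anchors = []
    for row in range(feat_h):
        for col in range(feat_w):
            cx, cy = ctr + col * stride, ctr + row * stride
            for w, h in sizes:
                anchors.append((cx - 0.5 * (w - 1), cy - 0.5 * (h - 1),
                                cx + 0.5 * (w - 1), cy + 0.5 * (h - 1)))
    return anchors


# Placeholder sizes standing in for the result of clustering (assumption).
clustered = [(32.0, 64.0), (90.0, 45.0)]

# The same sizes are reused on every level; only the stride changes.
# The 2x2 feature-map size is also just an illustrative value.
per_level = {s: anchors_from_sizes(clustered, (2, 2), s) for s in (8, 16)}
print(per_level[8][0])  # (-12.0, -28.0, 19.0, 35.0)
```

The trade-off is that every level carries every anchor size, so the number of anchors per cell equals the number of clusters on all levels.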