ECCV 2022 | Learning Implicit Feature Alignment Function for Semantic Segmentation概述与代码分析
主要工作
基于隐神经表示设计了一种隐式特征对齐函数,来替换现有的基于插值的不同分辨率特征对齐方案。可以更加方便和高效的对齐多个不同分辨率的特征。
原始的隐式特征函数:
不考虑专业术语。直观来讲,隐式特征函数本身是基于原始特征和目标特征之间的坐标关系,构建了一个从原始特征到目标特征映射变换。其中的变换关系可以通过神经网络学习和建模。
这一过程需要提供以下三种信息:
- 已有的原始特征 z z z
- 原始特征对应的连续的(归一化)网格坐标 x x x
- 我们想要生成的目标特征/预测对应的连续的(归一化)网格坐标 x q x_q xq
注意这里强调了归一化坐标。这一方法核心的一个假定是网格坐标系本身是对齐的,可能只是单位刻度上有差异。
通过这些信息,我们可以利用坐标之间的相对关系,从原始特征变换得到目标特征/预测。
需要注意的是,这一变换过程中,主要关注坐标系中与目标位置最邻近的原始特征点。
在此基础上,作者们引入了相对位置编码获得了更好的对齐效果:
通过同时集成多个不同层级的特征来实现对于最终预测的检索和计算:
实验结果
核心代码解析
- https://github.com/hzhupku/IFA/blob/main/pyseg/models/ifa_utils.py
- https://github.com/hzhupku/IFA/blob/main/pyseg/models/ifa.py
- https://github.com/hzhupku/IFA/blob/main/pyseg/models/fpn_ifa.py
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_coord(hw, flatten=True):
"""构建网格坐标系,原点位于各轴有效范围的中心点。
使用的网格坐标系的三个参考点:网格区域的左右边界为-1和1或者ranges的两个值,正中心为0。
返回的网格坐标为 [N,[...,]len(hw)],其中最后一维表示具体的坐标,坐标顺序与hw中轴的顺序一致。
"""
start_idx, end_idx = -1, 1
axes_grid_centers = []
for i, n in enumerate(hw):
# 单一轴向的半个网格的宽度
width_of_half_grid = (end_idx - start_idx) / (2 * n)
# 这里计算的是各个方形网格区域的中心点坐标。
start_grid_center = start_idx + width_of_half_grid
grid_centers = (
start_grid_center + (2 * width_of_half_grid) * torch.arange(n).float()
)
# 使用linspace替换会导致精度无法对齐
# end_grid_center = end_idx - width_of_half_grid
# grid_centers = torch.linspace(start_grid_center, end_grid_center, steps=n)
axes_grid_centers.append(grid_centers)
paired_grid_centers = torch.stack(
torch.meshgrid(*axes_grid_centers, indexing="ij"), dim=-1
)
if flatten:
paired_grid_centers = paired_grid_centers.reshape(
-1, paired_grid_centers.shape[-1]
)
return paired_grid_centers
def ifa_feat_ann(src, tgt_hw, stride=1, local_ensemble=False):
bs, src_h, src_w = src.shape[0], src.shape[-2], src.shape[-1]
tgt_h, tgt_w = tgt_hw
coord_tgt_hw = make_coord((tgt_h, tgt_w)).to(device=src.device)
# hw,[tgt_h_id,tgt_w_id] =(repeat)=> bs,hw,[tgt_h_id,tgt_w_id] in (-1,1)
coord_tgt_hw = coord_tgt_hw.unsqueeze(0).expand(bs, *coord_tgt_hw.shape)
# 使用后可以与原始实现对齐,但是实际属于冗余操作
# coord_tgt_hw = (coord_tgt_hw + 1) / 2 * 2 - 1
coord_src_hw = make_coord((src_h, src_w), flatten=False).to(device=src.device)
# src_h,src_w,[src_h_id,src_w_id]
# => [src_h_id,src_w_id],src_h,src_w
coord_src_hw = coord_src_hw.permute(2, 0, 1)
# =(repeat)=> bs,[src_h_id,src_w_id],src_h,src_w in (-1,1)
coord_src_hw = coord_src_hw.unsqueeze(0).expand(bs, 2, src_h, src_w)
if local_ensemble:
# 利用局部ensemble来缓解基于索引的预测方式导致的预测不连续的问题
# 直接利用目标位置与周围四个隐编码位置之间的包围矩形面积来加权组合获得的四个预测,
# 从而平滑索引改变时带来的预测变化。
# 这一加权平滑的方式基本是沿用了双线性插值的思路。
tgt_x_shifts = [-1, 1]
tgt_y_shifts = [-1, 1]
eps_shift = 1e-6
rel_coord_hws = []
src2tgt_feats = []
areas = []
else:
tgt_x_shifts, tgt_y_shifts, eps_shift = [0], [0], 0
# tgt网格坐标系下的相对步长
tgt_x_stride = stride / tgt_w
tgt_y_stride = stride / tgt_h
for tgt_x_shift in tgt_x_shifts:
for tgt_y_shift in tgt_y_shifts:
# bs,hw,[tgt_w_id,tgt_h_id] in (-1,1)
coord_tgt_xy = coord_tgt_hw.flip(-1).clone()
# 在考虑局部ensemble的时候,这里对tgt坐标进行一个单位的相对偏移后再对src进行查询与映射
coord_tgt_xy[:, :, 0] += tgt_x_shift * tgt_x_stride + eps_shift
coord_tgt_xy[:, :, 1] += tgt_y_shift * tgt_y_stride + eps_shift
coord_tgt_xy.clamp_(-1 + 1e-6, 1 - 1e-6)
# bs,1,hw,[tgt_w_id,tgt_h_id]
coord_tgt_xy = coord_tgt_xy.unsqueeze(1)
# 使用tgt网格坐标对src特征网格坐标调整
# 采样 bs,[src_h_id,src_w_id],src_h,src_w 到 bs,[src_h_id',src_w_id'],1,hw
coord_src2tgt_hw = F.grid_sample(
coord_src_hw, coord_tgt_xy, mode="nearest", align_corners=False
)
# bs,hw,[src_h_id',src_w_id']
coord_src2tgt_hw = coord_src2tgt_hw[:, :, 0, :].permute(0, 2, 1)
# 与nearest latent code,即这里的src,相对坐标偏移
rel_coord_hw = coord_tgt_hw - coord_src2tgt_hw
rel_coord_hw[:, :, 0] *= src_h # src.shape[-2]
rel_coord_hw[:, :, 1] *= src_w # src.shape[-1]
# 使用目标网格坐标对输入特征重新采样
# bs,c,src_h,src_w => bs,c,1,tgt_h*tgt_w => bs,tgt_h*tgt_w,c
src2tgt_feat = F.grid_sample(
src, coord_tgt_xy, mode="nearest", align_corners=False
)
src2tgt_feat = src2tgt_feat[:, :, 0, :].permute(0, 2, 1)
if local_ensemble:
rel_coord_hws.append(rel_coord_hw)
src2tgt_feats.append(src2tgt_feat)
# 在局部ensemble的时候,需要统计tgt与周围四个src位置之间矩形的面积,用来加权平均从而平滑结果
# 而面积的计算正好是相对坐标乘积的绝对值
area = torch.abs(rel_coord_hw[:, :, 0] * rel_coord_hw[:, :, 1])
areas.append(area + 1e-9)
if not local_ensemble:
return rel_coord_hw, src2tgt_feat
else:
return rel_coord_hws, src2tgt_feats, areas
class ifa_simfpn(nn.Module):
def __init__(...):
super().__init__()
if learn_pe:
self.pos1 = PositionEmbeddingLearned(self.pos_dim // 2)
self.pos2 = PositionEmbeddingLearned(self.pos_dim // 2)
self.pos3 = PositionEmbeddingLearned(self.pos_dim // 2)
self.pos4 = PositionEmbeddingLearned(self.pos_dim // 2)
if ultra_pe:
self.pos1 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
self.pos2 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
self.pos3 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
self.pos4 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad)
self.pos_dim += 2
in_dim = 4 * (256 + self.pos_dim)
if unfold:
in_dim = 4 * (256 * 9 + self.pos_dim)
self.imnet = ... # in_dim -> num_classes
def forward(self, x, size, level=0, after_cat=False):
h, w = size
if after_cat:
return self.imnet(x).reshape(x.shape[0], -1, h, w)
# Feature unfolding: 为了丰富隐码包含的信息,对特征中3×3相邻隐码合并
if self.unfold:
x = F.unfold(x, 3, padding=1).reshape(
x.shape[0], x.shape[1] * 9, x.shape[2], x.shape[3]
)
if not self.local:
rel_coord_hw, src2tgt_feat = ifa_feat_ann(src=x, tgt_hw=[h, w])
if self.ultra_pe or self.learn_pe:
rel_coord_hw = eval("self.pos" + str(level))(rel_coord_hw)
x = torch.cat([rel_coord_hw, src2tgt_feat], dim=-1)
else:
rel_coord_hws, src2tgt_feats, areas = ifa_feat_ann(
src=x,
tgt_hw=[h, w],
stride=self.stride,
local_ensemble=True,
)
contexts = []
for rel_coord_hw, src2tgt_feat, area in zip(
rel_coord_hws, src2tgt_feats, areas
):
if self.ultra_pe or self.learn_pe:
rel_coord_hw = eval("self.pos" + str(level))(rel_coord_hw)
contexts.append(torch.cat([rel_coord_hw, src2tgt_feat], dim=-1))
# 这里将对角区域的面积进行了交换。0号与3号,1号与2号
# 整体的特征组合方式与双线性插值形式一致
# 关于双线性插值可见 https://blog.csdn.net/qq_58664081/article/details/129079354
areas[0], areas[3] = areas[3], areas[0]
areas[1], areas[2] = areas[2], areas[1]
total_area = torch.stack(areas).sum(dim=0)
for cxt, area in zip(contexts, areas):
x = cxt * ((area / total_area).unsqueeze(-1))
return x
class fpn_ifa(nn.Module):
def __init__(...):
super().__init__()
...
self.ifa = ifa_simfpn(
ultra_pe=ultra_pe,
pos_dim=pos_dim,
num_classes=num_classes,
local=local,
unfold=unfold,
stride=stride,
learn_pe=learn_pe,
require_grad=require_grad,
num_layer=num_layer,
)
def forward(self, x):
x1, x2, x3, x4 = x
aspp_out = ...
context = []
h, w = x1.shape[-2], x1.shape[-1]
target_feat = [x1, x2, x3, aspp_out]
for i, feat in enumerate(target_feat):
context.append(self.ifa(feat, size=[h, w], level=i + 1))
context = torch.cat(context, dim=-1).permute(0, 2, 1) # B,HW,C -> B,C,HW
return self.ifa(context, size=[h, w], after_cat=True)
这里代码的设计应当是借鉴自图像超分辨算法LIIF中的设计,代码基本一致https://github.com/yinboc/liif/blob/main/models/liif.py。
本文保留了LIIF中的Local Ensemble和Feature Unfolding的设计,但是不同之处主要有两点:
- 相对位置信息的使用不同于LIIF中直接将其作为imnet的输入的部分通道,这里使用了位置编码的方式进行处理。
- 没有使用LIIF中的Cell Decoding。