STM首次将Memory Network引入VOS领域,引申为一个space-tine的memory network,并实现了较好的分割准确率以及较快的速度。DAVIS2020大赛很多优秀的模型都是根据STM进行改造的,可见其具有很棒的指导意义和研究价值。在STM出现之间,VOS的方法基本包括以下三种:
(1)Propagation-based methods:基于mask传播的方法,该方法主要利用前一帧的mask进行不断的传播,其好处是可以解决appearance变化较大的问题,且速度较快,但不适用于occlusion,drift等问题;
(2)Detection-based methods:基于第一帧的mask,大部分方法需要进行online training,所以fine tune的时间会大大影响模型的性能,无法做到实时,但基于第一帧微调的模型会对occlusion等问题更加鲁棒;
(3)Hybrid methods:以上两种方法的混合产品,基于前一帧和第一帧来对当前帧进行预测分割,这样就融合了两者的优点,Hybrid的一些算法性能和精度也是比前两者优秀的。
因此本文的motivation就是基于Hybrid methods:我们如何设计网络模型,来合理利用更多frame的信息,以对当前帧的mask进行精准预测。因此本文设计了如下图所示的网络结构:
算法的流程大致是这样的:将前面的frame和mask都保存在一个external memory bank中,当对当前帧预测mask时,首先从memory bank中选出若干帧,将这些帧及其mask输入到Memory Encoder中,得到对应的key和value;然后将当前帧输入到Query Decoder中,得到Query的key和value。随后将Memory的key和Query的key进行内积运算以计算相似度,相当于一种时空的Attention机制,为不同时间和区域的value分配权重。将这个相似度图和Memory中的Value相乘,这就是Space-time Memory Read的结果,将该结果和Query的value进行拼接,送入最后的Decoder进行mask的还原预测。
上面有很多术语,其基本的结构大致如上图所示。key和value构成了一个键值对,key的作用是用来寻址,而value保存了一些用来生成mask的更加细节的信息。
Memory Encoder:以memory中之前存放的若干帧及其mask(之前得到的0-1的概率图)为输入,经过一个Resnet18后分成两个分支,一个生成通道数为C/8的key特征图,一个生成通道数为C/2的value特征图(C表示Resnet18生成的特征图通道数);
Query Encoder:将当前帧作为输入,输出得到key和value,维度与Memory Encoder一致(不同的是Memory中存在若干之前的帧,因此还会多一个维度T,用来表示时间)。
Space-time Memory Read是本文的核心部分,已经有了memory的key和value,query的key和value,如何对这些信息加以利用,来生成更精准的pixel-wise的mask预测。如上图所示,首先是将Query的key和Memory的key进行矩阵内积(代码里用的torch.bmm,这分明是普通的矩阵乘法啊),得到一个similarity map,通过一个softmax约束到0-1的范围;然后将这个相似度图与Query的value做内积(代码中仍然是矩阵乘法),相当于为value分配了一个time-space的权重矩阵。最后将Query的value和刚才内积后的结果concat一下,作为从memory中read出来的结果,送入后续的decoder进行最终mask的预测。
这个read操作相当于是对一些原始的memory network的拓展,扩展到3d时空空间来进行像素的matching。
Decoder部分就是基于上述read操作的输出,去重构当前帧的mask,其结构如本文第一张图所示,通过一些refinement module来对图像的像素信息进行还原,并最终生成mask(该mask的尺寸为输入图像的1/4)。
对于多目标分割来说,作者采用了soft aggregation operation的方法,该方法与Fast Video Object Segmentation by Reference-Guided Mask Propagation这篇论文的方法类似。
训练过程主要分为两个阶段,首先pre-train是在有mask标注的静态图像上进行预训练,对这些数据进行随机仿射变换以得到不同数据;随后Main training是在DAVIS和Youtube VOS这类特定的数据集上进行训练,利用三帧来进行训练,采样的间隔为0-25不等。推断的时候是利用5帧的信息来预测mask。
实验结果,由于采用了external memory,即可以利用前面若干帧的信息进行mask预测的指导,因此其精度相比于propagation方法和detection方法都有较大的提升:
一些ablation study的实验结果: