首先安利一篇博客,分析这篇论文很清楚:https://blog.csdn.net/sinat_26253653/article/details/79416234
本文主要是在这篇博客的基础上结合代码进行分析。
文章依然采用了encoder-decoder的框架。作者认为decoder的时候非视觉词多依赖的是语义信息而不是视觉信息。而且,在生成caption的过程中,非视觉词的梯度会误导或者降低视觉信息的有效性。因此,本文提出了带有视觉标记的自适应的attention模型(adative attention model with a visual sentinel),在每一个time step,模型决定更依赖于图像还是visual sentinel。其中,visual sentinel存放了decoder已经知道的信息。
本文的贡献在于 :
1、提出了带有视觉标记的自适应的attention模型
2、提出了新的spatial attention机制
3、提出了LSTM的扩展,在hidden state以外加入了一个额外的visual sentinel vector
模型结构:
1、首先是encoder部分:
encoder分别提取图像的全局特征和局部特征。采用的是resnet-152,去除最后两层,提取卷积特征。
v_g代表全局特征([batch_size,256]),v代表局部特征([batch_size,49,512])。
局部特征分为v1,v2.v3.....v49
A = self.resnet_conv( images )#[batch_size,2048,7,7]
a_g = self.avgpool( A ) #[batch_size,2048,1,1]
a_g = a_g.view( a_g.size(0), -1 ) #[batch_size,2048]
# V = [ v_1, v_2, ..., v_49 ]
V = A.view( A.size( 0 ), A.size( 1 ), -1 ).transpose( 1,2 )#[2, 49, 2048]
V = F.relu( self.affine_a( self.dropout( V ) ) )#[2, 49, 512]
v_g = F.relu( self.affine_b( self.dropout( a_g ) ) )
return V, v_g
#V:[batch_size, 49, 512(hidden size)] v_g:[batch_size,256(embeding size)]
2、其次是decoder部分
先上模型结构图
首先LSTM的输入不再只是当前时刻的word_embedding,而是和全局图像特征cat起来。
即
作者在LSTM中加入sentinel(哨兵机制),产生s_t:
具体做法是:
相应代码为:
g_t的定义与公式略有不同。
# g_t = sigmoid( W_x * x_t + W_h * h_(t-1) )
gate_t = self.affine_x( self.dropout( x_t ) ) + self.affine_h( self.dropout( h_t_1 ) )
gate_t = F.sigmoid( gate_t )
# Sentinel embedding
s_t = gate_t * F.tanh( cell_t )
return s_t
attention机制:
输入各个局部特征V,各个时刻的hidden_state、s_t
对于普通的attention:
基于hidden_state,decoder会关注图像的不同区域,ct就是该区域经过CNN后提取出的feature map。
这就是在不使用自适应时对各个局部特征加权得到的最终图像特征。
相关代码:
# W_v * V + W_g * h_t * 1^T
content_v = self.affine_v( self.dropout( V ) ).unsqueeze( 1 ) \
+ self.affine_g( self.dropout( h_t ) ).unsqueeze( 2 )
# z_t = W_h * tanh( content_v )
z_t = self.affine_h( self.dropout( F.tanh( content_v ) ) ).squeeze( 3 )
alpha_t = F.softmax( z_t.view( -1, z_t.size( 2 ) ) ).view( z_t.size( 0 ), z_t.size( 1 ), -1 )
# Construct c_t: B x seq x hidden_size
c_t = torch.bmm( alpha_t, V ).squeeze( 2 )
在本文中,作者使用了自适应,context vector变为:
权重也变为:
上述公式可以简化为:
相应代码:
# W_s * s_t + W_g * h_t
content_s = self.affine_s( self.dropout( s_t ) ) + self.affine_g( self.dropout( h_t ) )
# w_t * tanh( content_s )
z_t_extended = self.affine_h( self.dropout( F.tanh( content_s ) ) )
# Attention score between sentinel and image content
extended = torch.cat( ( z_t, z_t_extended ), dim=2 )
alpha_hat_t = F.softmax( extended.view( -1, extended.size( 2 ) ) ).view( extended.size( 0 ), extended.size( 1 ), -1 )
beta_t = alpha_hat_t[ :, :, -1 ]
# c_hat_t = beta * s_t + ( 1 - beta ) * c_t
beta_t = beta_t.unsqueeze( 2 )
c_hat_t = beta_t * s_t + ( 1 - beta_t ) * c_t
return c_hat_t, alpha_t, beta_t
最后计算单词的概率分布:
相应公式为
相应代码为:
scores = self.mlp( self.dropout( c_hat + hiddens ) )#最后单词的概率分布
计算损失函数:
loss = LMcriterion( packed_scores[0], targets )#评价
最后放一下代码链接