深度模型(七):Sampled Softmax

Softmax

给定softmax的输入 ( z 1 , z 2 , . . . , z n ) (z_1,z_2,...,z_n) ,则输出为 f ( z 1 , f ( z 2 ) , . . . , f ( z n ) ) f(z_1,f(z_2),...,f(z_n)) ,其中 f ( z i ) , i [ 1 , n ] f(z_i),i\in[1,n] 的计算方式为:

f ( z i ) = e z i j = 1 n e z j f(z_i)=\frac{e^{z_i}}{\sum_{j=1}^ne^{z_j}}

Sampled softmax

目前流行的基于神经网络的机器翻译(NMT)模型,采用的是Encoder-Decoder结构,对于输入序列 x = ( x 1 , x 2 , . . . , x S n ) \boldsymbol{x}=(x_1,x_2,...,x_{S_n}) ,生成对应的目标序列 y = ( y 1 , y 2 , . . . , y T n ) \boldsymbol{y}=({y_1,y_2,...,y_{T_n})} ,模型的建模目标是最大化目标序列的条件概率。

l o g P ( y x ) = t = 1 T n P ( y t y < t , x ) logP(\boldsymbol{y}|\boldsymbol{x})=\sum_{t=1}^{T_n}P(y_t|y_{<t},\boldsymbol{x})

对于包含N个样本的训练数据,模型的训练目标就是最大化整体条件概率:

θ = a r g m a x θ n = 1 N t = 1 T n l o g p ( y t n y < t n , x n ) \theta^*=argmax_{\theta}\sum_{n=1}^N\sum_{t=1}^{T_n}logp(y_t^n|y_{<t}^n,\boldsymbol{x}_n)

详细的模型结构在这里就不再展开,可以参考我之前的文章深度模型(二):Attention,在这里我们主要关注模型softmax层的计算。

softmax层输出目标序列中第 t t 位的符号概率分布,计算方式为:

p ( y t y < t , x ) = e x p ( w t T ϕ ( y t 1 , z t , c t ) + b t ) Z p(y_t|y_{<t},x)=\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{Z}
= e x p ( w t T ϕ ( y t 1 , z t , c t ) + b t ) y k V e x p ( w k T ϕ ( y t 1 , z t , c t ) + b k ) =\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{\sum_{y_k\in V}exp(w_k^T\phi(y_{t-1},z_t,c_t)+b_k)}

其中V表示目标序列的词汇表, y t 1 y_{t-1} 表示目标序列中前一位的符号, z t z_t 表示Decoder当前的隐状态, c t c_t 表示Encoder隐状态的Attention值。

可以看出为了计算目标符号 y t y_t 的条件概率,必须计算 Z Z 值,这需要对词表 V V 中的符号进行遍历,计算量随着词表规模的变大而变大,目前词表的规模从几千到几万不等。

为了支持超大规模的词表,一个很自然的思路就是,能不能通过一些算法达到近似计算 Z Z 值的目的呢。论文《On Using Very Large Target Vocabulary for Neural Machine Translation》提出了一种对 Z Z 值的近似计算方法,这就是sampled softmax:

p ( y t y < t , x ) = e x p ( w t T ϕ ( y t 1 , z t , c t ) + b t ) Z ^ p(y_t|y_{<t},x)=\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{\widehat Z}
= e x p ( w t T ϕ ( y t 1 , z t , c t ) + b t ) y k V e x p ( w k T ϕ ( y t 1 , z t , c t ) + b k ) =\frac{exp(w_t^T\phi(y_{t-1},z_t,c_t)+b_t)}{\sum_{y_k\in V'}exp(w_k^T\phi(y_{t-1},z_t,c_t)+b_k)}

其中 V V' 就是采样得到的词表,词表规模要远小雨整体的词表 V V ,因此整体的词表 V V 的规模不再造成计算量增长的问题。采样方式和 V V' 的选择方式,以后有时间再补上。

发布了52 篇原创文章 · 获赞 105 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/gaofeipaopaotang/article/details/99071854