RNN系列

import torch
import torch.nn as nn

seed = 0
torch.manual_seed( seed )



def compute_diff( t1, t2 ):
    return (t1 - t2).mean().item()






################### 研究GRU的计算原理 ###################




# 5表示序列长度
# 7表示batch_size
# 10表示输入x的维度
x_seq = torch.randn(5, 7, 10)




# 1表示RNN层数(即使是1层也不能少了这一维)
# 7表示batch_size
# 20表示隐藏状态h的维度
h_init = torch.randn(1, 7, 20)




# 定义1层GRU
gru = nn.GRU( input_size=10, hidden_size=20, num_layers=1 )





# GRU参数
W_hh, b_hh = gru.weight_hh_l0, gru.bias_hh_l0
W_ih, b_ih = gru.weight_ih_l0, gru.bias_ih_l0



# W_hh实际上是由( W_hr | W_hz | W_hn )拼接而成
W_hr, W_hz, W_hn = W_hh[0:20], W_hh[20:40], W_hh[40:60]
b_hr, b_hz, b_hn = b_hh[0:20], b_hh[20:40], b_hh[40:60]

# W_ih实际上是由( W_ir | W_iz | W_in )拼接而成
W_ir, W_iz, W_in = W_ih[0:20], W_ih[20:40], W_ih[40:60]
b_ir, b_iz, b_in = b_ih[0:20], b_ih[20:40], b_ih[40:60]











print( '\nStep1: 计算GRU(检查output_seq[-1]和h_last是否相等)' )
output_seq, h_last = gru( x_seq, h_init )

# output_seq.size() = [5, 7, 20]
# h_last.size() = [1, 7, 20]

# GRU的输出值一定是hidden状态组成的序列
# 即GRU将序列x转换为序列hidden




# RNN直接将hidden状态用作输出
# 所以最后的output_seq的最后一个元素,与h_last相等
print( 'equal =', torch.equal( output_seq[-1], h_last.squeeze() ) )
print( 'diff =', compute_diff( output_seq[-1], h_last.squeeze() ) )












# 自己实现GRU中的计算
print( '\nStep2: 自己实现GRU中的计算' )



my_output_seq = torch.zeros( [5, 7, 20] )
ht = h_init.squeeze()

for i in range(5):

    xt = x_seq[i]
    h_prev = ht

    rt = torch.sigmoid( xt.mm( W_ir.t() ) + b_ir + h_prev.mm( W_hr.t() ) + b_hr )   # reset gate
    zt = torch.sigmoid( xt.mm( W_iz.t() ) + b_iz + h_prev.mm( W_hz.t() ) + b_hz )   # update gate
    nt = torch.tanh( xt.mm( W_in.t() ) + b_in + rt * ( h_prev.mm( W_hn.t() ) + b_hn ) ) # new gate
    ht = ( 1 - zt ) * nt + zt * h_prev

    my_output_seq[i] = ht


print( 'equal =', torch.equal( my_output_seq, output_seq ) )
print( 'diff =', compute_diff( my_output_seq, output_seq ) )
print( '【注】由于浮点数计算的误差,不可能保证计算结果完全相等,但diff值足够小' )
















print( '\nStep3: 通过RNNCell的方式进行计算' )

gru_cell = nn.GRUCell(10, 20)

gru_cell.weight_hh, gru_cell.bias_hh = W_hh, b_hh
gru_cell.weight_ih, gru_cell.bias_ih = W_ih, b_ih

hx = h_init.squeeze()
out = torch.zeros( [5, 7, 20] )

for i in range(5):
    hx = gru_cell( x_seq[i], hx )
    out[i] = hx


print( 'equal =', torch.equal( out, output_seq ) )
print( 'diff =', compute_diff( out, output_seq ) )



猜你喜欢

转载自blog.csdn.net/o0Helloworld0o/article/details/82287151
RNN