Partial code for "Knowing When to Look: Adaptive Attention via a Visual Sentinel for Image Captioning"

class AdaptiveLSTMCell(nn.Module):
    """LSTM cell extended with a sentinel gate for adaptive attention.

    Implements the modified LSTM from "Knowing When to Look: Adaptive
    Attention via a Visual Sentinel for Image Captioning" (Lu et al., 2017).
    In addition to the four standard LSTM gates, a fifth "sentinel" gate
    produces the visual sentinel s_t = g_t * tanh(c_t), which lets the
    decoder fall back on language-model state instead of image features.
    """

    def __init__(self, inputSize, hiddenSize):
        """
        Args:
            inputSize: dimensionality of the input vector at each step.
            hiddenSize: dimensionality of the hidden/cell/sentinel states.
        """
        super(AdaptiveLSTMCell, self).__init__()
        self.hiddenSize = hiddenSize
        # 5 * hiddenSize: input, forget, cell, output, and sentinel gates
        # stacked into a single matrix so one matmul computes all of them.
        self.w_ih = nn.Parameter(torch.Tensor(5 * hiddenSize, inputSize))
        self.w_hh = nn.Parameter(torch.Tensor(5 * hiddenSize, hiddenSize))
        self.b_ih = nn.Parameter(torch.Tensor(5 * hiddenSize))
        self.b_hh = nn.Parameter(torch.Tensor(5 * hiddenSize))
        self.init_parameters()

    def init_parameters(self):
        """Initialize weights uniformly in [-1/sqrt(H), 1/sqrt(H)], biases to zero.

        Matches the default initialization of torch.nn.LSTM except that
        biases start at zero rather than uniform.
        """
        stdv = 1.0 / math.sqrt(self.hiddenSize)
        self.w_ih.data.uniform_(-stdv, stdv)
        self.w_hh.data.uniform_(-stdv, stdv)
        self.b_ih.data.fill_(0)
        self.b_hh.data.fill_(0)

    def forward(self, inp, states):
        """Run one time step.

        Args:
            inp: input tensor of shape (batch, inputSize).
            states: tuple (h_t, c_t) of the previous hidden and cell states,
                each of shape (batch, hiddenSize).

        Returns:
            Tuple (h_new, c_new, s_new): next hidden state, next cell state,
            and the visual sentinel, each of shape (batch, hiddenSize).
        """
        ht, ct = states
        # One fused affine transform for all five gates.
        gates = F.linear(inp, self.w_ih, self.b_ih) + F.linear(ht, self.w_hh, self.b_hh)

        ingate, forgetgate, cellgate, outgate, sgate = gates.chunk(5, 1)

        # torch.sigmoid/torch.tanh replace the deprecated F.sigmoid/F.tanh.
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        outgate = torch.sigmoid(outgate)
        cellgate = torch.tanh(cellgate)

        # Standard LSTM state update.
        c_new = (forgetgate * ct) + (ingate * cellgate)
        h_new = outgate * torch.tanh(c_new)

        # Visual sentinel: s_t = g_t * tanh(c_t), Eq. (9)-(10) of the paper.
        sgate = torch.sigmoid(sgate)
        s_new = sgate * torch.tanh(c_new)

        return h_new, c_new, s_new

You may also like

Reprinted from www.cnblogs.com/czhwust/p/10614163.html