class AdaptiveLSTMCell(nn.Module):
    """LSTM cell augmented with a sentinel gate that emits a "visual sentinel"
    vector, per "Knowing When to Look: Adaptive Attention via a Visual
    Sentinel for Image Captioning" (Lu et al., CVPR 2017).

    Unlike a vanilla LSTM cell, the fused weight matrices pack 5 gates
    instead of 4; the extra (sentinel) gate modulates tanh(c) to produce
    the sentinel vector s used by the adaptive-attention decoder.
    """

    def __init__(self, inputSize, hiddenSize):
        super(AdaptiveLSTMCell, self).__init__()
        self.hiddenSize = hiddenSize
        # Fused parameters for all 5 gates (input, forget, cell, output,
        # sentinel), stacked along dim 0 — hence the factor of 5.
        self.w_ih = nn.Parameter(torch.Tensor(5 * hiddenSize, inputSize))
        self.w_hh = nn.Parameter(torch.Tensor(5 * hiddenSize, hiddenSize))
        self.b_ih = nn.Parameter(torch.Tensor(5 * hiddenSize))
        self.b_hh = nn.Parameter(torch.Tensor(5 * hiddenSize))
        self.init_parameters()

    def init_parameters(self):
        """Initialize weights uniformly in [-1/sqrt(H), 1/sqrt(H)], biases to zero."""
        stdv = 1.0 / math.sqrt(self.hiddenSize)
        self.w_ih.data.uniform_(-stdv, stdv)
        self.w_hh.data.uniform_(-stdv, stdv)
        self.b_ih.data.fill_(0)
        self.b_hh.data.fill_(0)

    def forward(self, inp, states):
        """Run one recurrence step.

        Args:
            inp: input at this step, shape (batch, inputSize).
            states: tuple (h, c) of previous hidden and cell states,
                each of shape (batch, hiddenSize).

        Returns:
            Tuple (h_new, c_new, s_new): next hidden state, next cell
            state, and the visual-sentinel vector, each
            (batch, hiddenSize).
        """
        ht, ct = states
        # Single fused affine transform for all 5 gates, then split.
        gates = F.linear(inp, self.w_ih, self.b_ih) + F.linear(ht, self.w_hh, self.b_hh)
        ingate, forgetgate, cellgate, outgate, sgate = gates.chunk(5, 1)
        # torch.sigmoid / torch.tanh replace F.sigmoid / F.tanh, which were
        # deprecated and later removed from torch.nn.functional.
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        outgate = torch.sigmoid(outgate)
        cellgate = torch.tanh(cellgate)
        c_new = (forgetgate * ct) + (ingate * cellgate)
        h_new = outgate * torch.tanh(c_new)
        sgate = torch.sigmoid(sgate)
        s_new = sgate * torch.tanh(c_new)  # visual sentinel
        return h_new, c_new, s_new
Partial code from "Knowing When to Look: Adaptive Attention via a Visual Sentinel for Image Captioning"
You may also like
Reprinted from www.cnblogs.com/czhwust/p/10614163.html
Today's picks
Weekly ranking