Implementing Recurrent Neural Networks in PyTorch

The code below assumes the setup from earlier sections of the book: the imports, the device, and the Jay Chou lyrics dataset with its character-index vocabulary. A reconstruction of that setup is included at the top (d2lzh_pytorch is the helper package from the PyTorch adaptation of Dive into Deep Learning).

import math
import time
import torch
from torch import nn
import d2lzh_pytorch as d2l  # helper package from the PyTorch adaptation of Dive into Deep Learning

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Character-level corpus: index sequence, index<->char mappings, vocabulary size
(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
num_hiddens = 256  # number of hidden units, as in the book's earlier sections

rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)
num_steps, batch_size = 35, 2
X = torch.rand(num_steps, batch_size, vocab_size)
state = None  # a None initial state defaults to all zeros
Y, state_new = rnn_layer(X, state)
print(Y.shape, state_new.shape)
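
With a single-layer, unidirectional nn.RNN, Y has shape (num_steps, batch_size, num_hiddens) and the new state has shape (num_layers * num_directions, batch_size, num_hiddens), so the print statement should show:

# Expected output, assuming num_hiddens = 256:
# torch.Size([35, 2, 256]) torch.Size([1, 2, 256])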

class RNNModel(nn.Module):
    def __init__(self, rnn_layer, vocab_size):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer
        # A bidirectional RNN concatenates the forward and backward hidden states
        self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1)
        self.vocab_size = vocab_size
        self.dense = nn.Linear(self.hidden_size, vocab_size)

    def forward(self, inputs, state):
        # inputs.shape: (batch_size, num_steps)
        X = to_onehot(inputs, self.vocab_size)
        X = torch.stack(X)  # X.shape: (num_steps, batch_size, vocab_size)
        hiddens, state = self.rnn(X, state)
        hiddens = hiddens.view(-1, hiddens.shape[-1])  # hiddens.shape: (num_steps * batch_size, hidden_size)
        output = self.dense(hiddens)
        return output, state
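
The forward pass relies on a to_onehot helper that this post never defines; in the book it comes from the d2lzh_pytorch package. A minimal sketch of equivalent logic, assuming the input is an index tensor of shape (batch_size, num_steps):

import torch.nn.functional as F

def to_onehot(X, n_class):
    # Split X column by column and one-hot encode each time step,
    # returning a list of num_steps tensors of shape (batch_size, n_class)
    return [F.one_hot(X[:, i].long(), n_class).float() for i in range(X.shape[1])]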

def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,
                        char_to_idx):
    state = None
    output = [char_to_idx[prefix[0]]]  # output records the prefix plus the num_chars predicted characters
    for t in range(num_chars + len(prefix) - 1):
        X = torch.tensor([output[-1]], device=device).view(1, 1)
        (Y, state) = model(X, state)  # the forward pass does not need model parameters passed in
        if t < len(prefix) - 1:
            output.append(char_to_idx[prefix[t + 1]])
        else:
            output.append(Y.argmax(dim=1).item())
    return ''.join([idx_to_char[i] for i in output])

model = RNNModel(rnn_layer, vocab_size).to(device)
predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)
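
Since the model has not been trained yet, this first call is only a sanity check of the prediction loop: the 10 characters generated after the prefix '分开' are essentially random.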

def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                  corpus_indices, idx_to_char, char_to_idx,
                                  num_epochs, num_steps, lr, clipping_theta,
                                  batch_size, pred_period, pred_len, prefixes):
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    for epoch in range(num_epochs):
        l_sum, n, start = 0.0, 0, time.time()
        data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device)  # consecutive sampling
        state = None
        for X, Y in data_iter:
            if state is not None:
                # Detach the hidden state from the computation graph so that
                # gradients do not propagate across batches
                if isinstance(state, tuple):  # LSTM, state: (h, c)
                    state[0].detach_()
                    state[1].detach_()
                else:
                    state.detach_()
            (output, state) = model(X, state)  # output.shape: (num_steps * batch_size, vocab_size)
            y = torch.flatten(Y.T)
            l = loss(output, y.long())

            optimizer.zero_grad()
            l.backward()
            grad_clipping(model.parameters(), clipping_theta, device)
            optimizer.step()
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]

        if (epoch + 1) % pred_period == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn_pytorch(
                    prefix, pred_len, model, vocab_size, device, idx_to_char,
                    char_to_idx))
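
grad_clipping is another helper from the book that the post does not define; it clips the global L2 norm of all gradients to at most clipping_theta, which keeps the RNN's gradients from exploding. A minimal sketch matching that behavior (the exact d2lzh_pytorch implementation may differ slightly):

def grad_clipping(params, theta, device):
    params = list(params)  # materialize, since model.parameters() returns a generator
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)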

num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                              corpus_indices, idx_to_char, char_to_idx,
                              num_epochs, num_steps, lr, clipping_theta,
                              batch_size, pred_period, pred_len, prefixes)
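
The training loop also depends on d2l.data_iter_consecutive, which implements consecutive (adjacent) sampling: successive minibatches are adjacent slices of the corpus, which is what makes carrying the hidden state across batches meaningful. A minimal sketch of equivalent logic, assuming corpus_indices is a flat list of character indices:

def data_iter_consecutive(corpus_indices, batch_size, num_steps, device=None):
    # Lay the corpus out as batch_size parallel sequences, then yield
    # consecutive windows of num_steps characters; Y is X shifted by one step
    corpus_indices = torch.tensor(corpus_indices, dtype=torch.float32, device=device)
    data_len = len(corpus_indices)
    batch_len = data_len // batch_size
    indices = corpus_indices[0: batch_size * batch_len].view(batch_size, batch_len)
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        i = i * num_steps
        X = indices[:, i: i + num_steps]
        Y = indices[:, i + 1: i + num_steps + 1]
        yield X, Y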

Reposted from www.cnblogs.com/hahasd/p/12333291.html