目录
一、导入必须的库
import torch
from torch import nn
from d2l import torch as d2l
二、加载数据集
# Number of reviews per mini-batch.
batch_size = 64
# d2l hands back three objects, unpacked here as the training iterator,
# the test iterator and the vocabulary of the IMDb data set.
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size=batch_size)
三、创建TextCNN模型
class TextCNN(nn.Module):
    """Text classifier made of parallel 1-D convolutions over word embeddings.

    Every token index is looked up in two embedding tables of identical
    shape; the two vectors are concatenated, so each convolution sees
    ``2 * embed_size`` input channels. Each convolution branch is pooled
    over the whole sequence down to one value per channel, the branches are
    concatenated, and a linear layer maps the result to 2 output scores.

    Args:
        vocab_size: Number of rows in each embedding table.
        embed_size: Width of each embedding vector.
        kernel_sizes: Kernel width of every convolution branch.
        num_channels: Output channels of every convolution branch; must
            have the same length as ``kernel_sizes``.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``kernel_sizes`` and ``num_channels`` differ in length.
    """

    def __init__(self, vocab_size, embed_size, kernel_sizes, num_channels, **kwargs):
        super().__init__(**kwargs)

        # Materialize both arguments so that one-shot iterables can be
        # measured, summed and zipped without being exhausted early.
        kernel_sizes = list(kernel_sizes)
        num_channels = list(num_channels)
        # The decoder width below is sum(num_channels), while the branches
        # are built pairwise. A length mismatch would otherwise be truncated
        # silently by zip() and only surface as a shape error in forward().
        if len(kernel_sizes) != len(num_channels):
            raise ValueError(
                "kernel_sizes and num_channels must have the same length, "
                f"got {len(kernel_sizes)} and {len(num_channels)}"
            )

        self.embedding = nn.Embedding(vocab_size, embed_size)
        # This second table is meant to stay fixed during training. The
        # constructor does not freeze it; the caller is expected to turn
        # off requires_grad on its weight.
        self.constant_embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Linear(sum(num_channels), 2)
        # The pooling layer has no parameters, so one instance is shared by
        # all branches.
        # NOTE(review): the original TextCNN (Kim, 2014) uses max-over-time
        # pooling; this code averages over time. Confirm this is intended.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()
        # One Conv1d per (channels, kernel width) pair.
        self.convs = nn.ModuleList(
            nn.Conv1d(2 * embed_size, channels, width)
            for channels, width in zip(num_channels, kernel_sizes)
        )

    def forward(self, inputs):
        """Return the 2 class scores for each sequence in ``inputs``.

        Args:
            inputs: 2-D integer tensor of token indices, one row per
                sequence. Each row must be at least as long as the largest
                kernel width.

        Returns:
            Tensor with one row per input sequence and 2 columns.
        """
        # Concatenate both embeddings along the feature axis:
        # (batch, steps, 2 * embed_size).
        embeddings = torch.cat(
            (self.embedding(inputs), self.constant_embedding(inputs)), dim=2
        )
        # Conv1d expects channels before the sequence axis.
        embeddings = embeddings.permute(0, 2, 1)
        # Per branch: convolve, pool to length 1, apply ReLU, then drop the
        # length-1 axis so each branch contributes (batch, channels).
        branch_features = [
            torch.squeeze(self.relu(self.pool(conv(embeddings))), dim=-1)
            for conv in self.convs
        ]
        encoding = torch.cat(branch_features, dim=1)
        outputs = self.decoder(self.dropout(encoding))
        return outputs
# Instantiate the model.
embed_size = 100
kernel_sizes = [3, 4, 5]
nums_channels = [100, 100, 100]
# The embedding tables get one row per vocabulary entry.
net = TextCNN(
    vocab_size=len(vocab),
    embed_size=embed_size,
    kernel_sizes=kernel_sizes,
    num_channels=nums_channels,
)
def init_weights(m):
    """Re-initialize the weight of a linear or 1-D convolution layer in place.

    Intended to be passed to ``nn.Module.apply``. Modules that are neither
    ``nn.Linear`` nor ``nn.Conv1d`` (nor subclasses of them) are left
    untouched. Only ``weight`` is changed; biases keep their values.

    Args:
        m: The module to inspect.
    """
    # isinstance (rather than an exact type comparison) also covers
    # subclasses of the two layer types.
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        nn.init.xavier_uniform_(m.weight)
# Run init_weights on net and on every submodule it contains.
net.apply(init_weights)
四、加载词向量
# Load pretrained 100-dimensional GloVe vectors. Their width has to equal
# the model's embedding width, otherwise the copies below fail on a shape
# mismatch. The first call downloads the vectors.
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
# Look up one vector per vocabulary token, in vocabulary index order, so
# that row i lines up with token index i in the embedding tables.
embeds = glove_embedding[vocab.idx_to_token]
# Both tables start from the same pretrained vectors.
net.embedding.weight.data.copy_(embeds)
# The constant_embedding layer is not updated during training.
net.constant_embedding.weight.data.copy_(embeds)
net.constant_embedding.weight.requires_grad = False
五、设置超参数、优化器和损失函数
# Device chosen by d2l.try_gpu().
device = d2l.try_gpu()
# Learning rate and number of training epochs.
lr = 1e-3
num_epochs = 5
# Adam is handed every parameter of the network.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
# reduction="none" keeps one loss value per example instead of a scalar.
# Presumably the d2l training loop reduces them itself; confirm before
# reusing this loss elsewhere.
loss = nn.CrossEntropyLoss(reduction="none")
六、训练过程设计
# d2l.train_ch13 expects a *list* of devices as its last argument: it
# indexes devices[0] and passes the list to nn.DataParallel as device_ids.
# A bare torch.device is not subscriptable, so wrap the single device.
d2l.train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, [device])
七、运用模型进行预测
# Classify one hand-written sentence with the trained network. The result
# is neither stored nor printed here, so it is presumably only visible
# when this runs as the last expression of a notebook cell.
d2l.predict_sentiment(net, vocab, 'I am a big apple.')