TextCNN实现imdb数据集情感分类任务

目录

一、导入必须的库

二、加载数据集

三、创建TextCNN模型

四、加载词向量

五、设置超参数、优化器和损失函数

六、训练过程设计

七、运用模型进行预测


一、导入必须的库

import torch
from torch import nn
from d2l import torch as d2l

二、加载数据集

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

三、创建TextCNN模型

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embed_size, kernel_sizes, num_channels, **kwargs):
        super(TextCNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 这个嵌入层不需要训练
        self.constant_embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Linear(sum(num_channels), 2)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()

        self.convs = nn.ModuleList()

        for c, k in zip(num_channels, kernel_sizes):
            self.convs.append(nn.Conv1d(2 * embed_size, c, k))

    def forward(self, inputs):
        embeddings = torch.cat((self.embedding(inputs), self.constant_embedding(inputs)), dim=2)
        embeddings = embeddings.permute(0, 2, 1)
        encoding = torch.cat([torch.squeeze(self.relu(self.pool(conv(embeddings))), dim=-1) for conv in self.convs],
                             dim=1)
        outputs = self.decoder(self.dropout(encoding))
        return outputs

# 实例化
embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
net = TextCNN(len(vocab), embed_size, kernel_sizes, nums_channels)


def init_weights(m):
    if type(m) in (nn.Linear, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

四、加载词向量

net.embedding.weight.data.copy_(embeds)
# constant_embedding层不进行更新
net.constant_embedding.weight.data.copy_(embeds)
net.constant_embedding.weight.requires_grad = False

五、设置超参数、优化器和损失函数

device = d2l.try_gpu()
lr, num_epochs = 1e-3, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')

六、训练过程设计

d2l.train_ch13(net,train_iter,test_iter,loss,optimizer,num_epochs,device)

七、运用模型进行预测

d2l.predict_sentiment(net, vocab, 'I am a big apple.')

猜你喜欢

转载自blog.csdn.net/qq_38901850/article/details/125176659
今日推荐