A Classic PyTorch Neural Network Classification Model Training Framework

Data Preprocessing

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

url='https://raw.githubusercontent.com/efosler/cse5522data/master/vowelfmts.csv'
df=pd.read_csv(url)

vowels = df['vowel'].unique()
print(vowels)
sample_num = df['vowel'].shape[0]
# map each vowel symbol to an integer class index
vowels_dict = dict(zip(vowels, range(len(vowels))))
print(vowels_dict)

features = np.zeros((sample_num, 2))
labels = np.zeros((sample_num, 1))
print(features.shape)

# build labels 
for i, vowel in enumerate(df['vowel']):
    labels[i] = vowels_dict[vowel]

print(labels)

# build features and normalize 
def normalize(x):
    return (x - x.min()) / (x.max() - x.min())

features[:, 0] = normalize(df['f1'])
features[:, 1] = normalize(df['f2'])

1. Encode the labels as integers. Since nn.CrossEntropyLoss() is used later, there is no need to convert them to one-hot vectors; the sketch after this list illustrates why.

2. Normalize the training data to the [0, 1] range.
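
A minimal sketch of why integer labels suffice: nn.CrossEntropyLoss() takes raw logits and integer class indices directly (the tensor values here are made up purely for illustration).

loss_func = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)           # raw scores for 4 samples, 10 classes
targets = torch.tensor([3, 0, 9, 1])  # integer class indices, no one-hot needed
print(loss_func(logits, targets))     # a single scalar loss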

Splitting the Data and Loading It into a DataLoader

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32

train_data, test_data, train_targets, test_targets = train_test_split(features, labels)

train_dataset = TensorDataset(torch.from_numpy(train_data), torch.from_numpy(train_targets))
training_data = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = TensorDataset(torch.from_numpy(test_data), torch.from_numpy(test_targets))
testing_data = DataLoader(test_dataset, batch_size=BATCH_SIZE)

1. With features and labels prepared, perform the train/test split (train_test_split defaults to holding out 25% for testing); a validation set can be carved out the same way, as sketched after this list.

2. Wrap the resulting tensors in PyTorch's built-in DataLoader and set batch_size; shuffle=True is used only for the training set.
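
A hedged sketch of extending this to a three-way train/valid/test split with two calls to train_test_split; the 60/20/20 ratio and the valid_* names are illustrative choices of mine, not part of the original post.

# hold out 20% for testing first
train_val_x, test_x, train_val_y, test_y = train_test_split(features, labels, test_size=0.2)
# then 25% of the remaining 80% as validation, giving 60/20/20 overall
train_x, valid_x, train_y, valid_y = train_test_split(train_val_x, train_val_y, test_size=0.25)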

Defining the Network

class MLPNet(nn.Module):

    def __init__(self, input_dim=2, hid_dim=100, output_dim=vowels.shape[0]):
        super(MLPNet, self).__init__()
        self.nn1 = nn.Linear(input_dim, hid_dim)
        self.nn2 = nn.Linear(hid_dim, output_dim)

    def forward(self, x):
        x = self.nn1(x)
        x = F.relu(x)
        x = self.nn2(x)  # raw logits; CrossEntropyLoss applies the softmax itself
        return x

The two essential parts are __init__() and forward(): __init__() defines the layers from torch.nn (which hold learnable parameters), while forward() mostly applies stateless functions from torch.nn.functional.
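
For a network this simple, nn.Sequential is an equivalent alternative that needs no custom forward(); this is a side note of mine, not something from the original post.

model_seq = nn.Sequential(
    nn.Linear(2, 100),               # input_dim -> hid_dim
    nn.ReLU(),                       # same nonlinearity as F.relu above
    nn.Linear(100, vowels.shape[0])  # hid_dim -> number of vowel classes
)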

Defining the Loss Function and the Optimizer

model = MLPNet()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.02, momentum=0.9)
print(model)

For classification, nn.CrossEntropyLoss() is the simplest choice: it combines LogSoftmax and NLLLoss internally, so the network only has to output raw logits.
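
A quick sketch of that equivalence on made-up tensors: LogSoftmax followed by NLLLoss matches CrossEntropyLoss applied to the raw logits.

logits = torch.randn(8, vowels.shape[0])             # fake logits for 8 samples
y = torch.randint(0, vowels.shape[0], (8,))          # fake integer labels
ce = nn.CrossEntropyLoss()(logits, y)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), y)
print(torch.allclose(ce, nll))                       # True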

Defining train() and test()

def train_func(training_data):

    model.train()
    train_loss = 0
    train_acc = 0

    for i, (x, y) in enumerate(training_data):
        y = y.squeeze(1).long()    # (batch, 1) float64 -> (batch,) int64
        optimizer.zero_grad()
        output = model(x.float())  # cast float64 features to float32
        loss = loss_func(output, y)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == y).sum().item()

    return train_loss / len(train_targets), train_acc / len(train_targets)
    
def test_func(testing_data):

    model.eval()
    test_loss = 0
    acc = 0

    for i, (x, y) in enumerate(testing_data):
        y = y.squeeze(1).long()
        with torch.no_grad():
            output = model(x.float())
            loss = loss_func(output, y)
            test_loss += loss.item()  # accumulate into a separate variable so the
                                      # batch loss is not overwritten
            acc += (output.argmax(1) == y).sum().item()

    return test_loss / len(test_targets), acc / len(test_targets)

Watch the small details: tensor dtypes may need converting to float/long (the DataLoader yields float64 here because the underlying NumPy arrays are float64), and mismatched shapes may need squeeze/unsqueeze, as with y.squeeze(1) above. The standalone sketch below walks through these conversions.
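
A standalone sketch of those conversions; the shapes mirror what this DataLoader yields, but the zero tensors are made up for illustration.

y = torch.zeros(32, 1, dtype=torch.float64)  # labels arrive as (batch, 1) float64
y = y.squeeze(1).long()                      # -> (batch,) int64, as CrossEntropyLoss expects
x = torch.zeros(32, 2, dtype=torch.float64)  # features arrive as float64
x = x.float()                                # -> float32, matching nn.Linear's weights
print(y.shape, y.dtype, x.dtype)             # torch.Size([32]) torch.int64 torch.float32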

Setting the Epoch Count and Running Training and Testing

import time

N_EPOCHS = 2001
best_test_acc = 0
for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(training_data)
    test_loss, test_acc = test_func(testing_data)
    best_test_acc = max(test_acc, best_test_acc)

    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60
    if epoch % 50 == 0:
        print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
        print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
        print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')

print(f'Best Testing Acc: {best_test_acc * 100:.1f}%')

Reference:

https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html

Reposted from www.cnblogs.com/sbj123456789/p/12749266.html