A PyTorch Implementation of a Neural Network on the MNIST Dataset

A simple fully connected neural network with one input layer, one hidden layer, and one output layer.

First, use torchvision to load the dataset.

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

# hyperparameters
input_size = 28 * 28 # image size of MNIST data
num_classes = 10
num_epochs = 10
batch_size = 100
learning_rate = 1e-3

# MNIST dataset
train_dataset = dsets.MNIST(root = '../../data_sets/mnist', # root directory of the data
                           train = True, # select the training set
                           transform = transforms.ToTensor(), # convert images to tensors
                           download = False) # do not download the data from the internet
test_dataset = dsets.MNIST(root = '../../data_sets/mnist', # root directory of the data
                           train = False, # select the test set
                           transform = transforms.ToTensor(), # convert images to tensors
                           download = False) # do not download the data from the internet
# load the data

train_loader = torch.utils.data.DataLoader(dataset = train_dataset, 
                                           batch_size = batch_size, 
                                           shuffle = True)  # shuffle the data
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                          batch_size = batch_size,
                                          shuffle = True)
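
If you want to confirm what the loaders yield before building the model, a quick sanity check (not in the original post) is to pull a single batch and inspect its shape:

# pull one batch from the training loader and look at its dimensions
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([100, 1, 28, 28]): 100 grayscale 28x28 images
print(labels.size())  # torch.Size([100]): one class index (0-9) per image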

Build the model

hidden_size = 100

# define the neural network model
class neural_net(nn.Module):
    def __init__(self, input_num, hidden_size, out_put):
        super(neural_net, self).__init__()
        self.fc1 = nn.Linear(input_num, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, out_put)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = neural_net(input_size, hidden_size, num_classes)
print(model)

Output:

neural_net (
  (fc1): Linear (784 -> 100)
  (relu): ReLU ()
  (fc2): Linear (100 -> 10)
)
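
Printing the model only lists the layers; to count the trainable parameters of this network, one option (not from the original post) is:

# total trainable parameters: (784*100 + 100) + (100*10 + 10) = 79510
num_params = sum(p.numel() for p in model.parameters())
print(num_params)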

Optimization

# optimization
learning_rate = 1e-3
num_epochs = 5
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

for epoch in range(num_epochs):
    print('current epoch = %d' % epoch)
    for i, (images, labels) in enumerate(train_loader): # enumerate yields the batch index together with each (images, labels) batch
        images = Variable(images.view(-1, 28 * 28))
        labels = Variable(labels)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('current loss = %.5f' % loss.data[0])

After only a few epochs of optimization, the results are already better than logistic regression.
The output is as follows:

current epoch = 0
current loss = 2.31471
current loss = 0.37235
current loss = 0.44407
current loss = 0.29467
current loss = 0.32532
current loss = 0.15531
current epoch = 1
current loss = 0.23495
current loss = 0.18591
current loss = 0.19729
current loss = 0.24272
current loss = 0.22655
current loss = 0.21057
current epoch = 2
current loss = 0.19062
current loss = 0.07879
current loss = 0.17636
current loss = 0.08401
current loss = 0.08000
current loss = 0.18595
current epoch = 3
current loss = 0.12932
current loss = 0.13962
current loss = 0.06450
current loss = 0.11173
current loss = 0.14006
current loss = 0.10523
current epoch = 4
current loss = 0.03298
current loss = 0.07106
current loss = 0.04047
current loss = 0.13380
current loss = 0.09479
current loss = 0.05664
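
For reference, the logistic regression baseline mentioned above can be written in the same framework as a single linear layer with no hidden layer (a sketch, not code from the original post); it can be trained with exactly the same loop by swapping in this model and its optimizer:

# logistic regression baseline: a single linear map from pixels to class scores
logistic_model = nn.Linear(input_size, num_classes)
logistic_optimizer = torch.optim.Adam(logistic_model.parameters(), lr = learning_rate)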

Now let's try prediction:

# make predictions on the test set
total = 0
correct = 0

for images, labels in test_loader:

    images = Variable(images.view(-1, 28 * 28))
    outputs = model(images)

    _, predicts = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicts == labels).sum()

print('Accuracy = %.2f' % (100 * correct / total))

Output:

Accuracy = 97.05
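
To look at the prediction for a single test image (not shown in the original post), the same forward pass works on one flattened example:

# predict the class of the first test image
image, label = test_dataset[0]
output = model(Variable(image.view(-1, 28 * 28)))
_, predicted = torch.max(output.data, 1)
print('predicted class = %d, true class = %d' % (int(predicted[0]), int(label)))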

The result is decent. No regularization was used, mainly because my computer is too slow; running more epochs might improve it further.
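
If regularization were added, two simple options (a sketch under assumptions, not part of the original code) are L2 weight decay on the optimizer or a Dropout layer between the two linear layers:

# option 1: L2 weight decay directly on the Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 1e-4)

# option 2: the same network with dropout on the hidden activations
class neural_net_dropout(nn.Module):
    def __init__(self, input_num, hidden_size, out_put):
        super(neural_net_dropout, self).__init__()
        self.fc1 = nn.Linear(input_num, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = 0.5)  # randomly zeroes 50% of hidden units during training
        self.fc2 = nn.Linear(hidden_size, out_put)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

# remember model.train() before training and model.eval() before prediction when using dropout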


Reposted from blog.csdn.net/u012840636/article/details/78997033