# PyTorch 0.4.0: LeNet for CIFAR-10 classification.
# @Time: 2018/6/15
# @Author: xfLi
import torchvision as tv
import torch.nn as nn
import torch as t
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import optim
# Script configuration: number of passes over the training set, and the
# number of output classes of the network (CIFAR-10 has 10 labels).
MAX_EPOCH, CLASS_NUM = 2, 10
class Net(nn.Module):
    """LeNet-style CNN used for CIFAR-10 classification.

    Two stages of 5x5 convolution + ReLU + 2x2 max-pooling, followed by three
    fully connected layers. ``fc1`` expects 16*5*5 flattened features, which
    is what a 3-channel 32x32 input yields after the two conv/pool stages.

    Args:
        class_num: number of output scores produced by the last layer
            (defaults to the module-level ``CLASS_NUM``).
    """

    def __init__(self, class_num=CLASS_NUM):
        super(Net, self).__init__()
        # Construction order determines how parameter initialisation draws
        # from the RNG, so the layers are created in a fixed sequence.
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 -> 6 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)   # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, class_num)

    def forward(self, x):
        """Return unnormalised class scores (no softmax) for the batch ``x``."""
        feats = x
        for conv in (self.conv1, self.conv2):
            feats = F.max_pool2d(F.relu(conv(feats)), kernel_size=2)
        # Flatten everything but the batch dimension: (batch, 16*5*5) for 32x32 inputs.
        flat = feats.view(feats.size(0), -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
def getData(root='/data/', batch_size=4, num_workers=0):
    """Build the CIFAR-10 train and test data loaders.

    Each image is converted to a tensor and every RGB channel is normalised
    with mean 0.5 and std 0.5. The dataset is downloaded into ``root`` if it
    is not already present there.

    Args:
        root: directory holding (or receiving) the CIFAR-10 files.
        batch_size: mini-batch size used by both loaders.
        num_workers: number of loader worker processes (0 loads in the
            main process, the DataLoader default).

    Returns:
        ``(train_loader, test_loader, classes)``, where ``classes`` is the
        tuple of human-readable label names indexed by label id.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Training split: reshuffled every epoch.
    train_set = tv.datasets.CIFAR10(root=root, train=True, transform=transform, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    # Test split: fixed order so repeated evaluations are comparable.
    test_set = tv.datasets.CIFAR10(root=root, train=False, transform=transform, download=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return train_loader, test_loader, classes
def train(max_epoch=MAX_EPOCH, log_every=1000):
    """Train a ``Net`` on CIFAR-10 with SGD and return it.

    Every ``log_every`` mini-batches, the mean training loss over those
    batches and the current test-set accuracy are printed.

    Args:
        max_epoch: number of passes over the training set.
        log_every: logging interval in mini-batches; must be >= 1.

    Returns:
        The trained ``Net`` instance.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got %r' % (log_every,))
    net = Net()
    train_loader, test_loader, _ = getData()  # load data
    criterion = nn.CrossEntropyLoss()  # cross-entropy on the raw network scores
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(max_epoch):
        net.train()
        running_loss = 0.0
        for step, (inputs, labels) in enumerate(train_loader):
            # Since PyTorch 0.4 plain tensors carry autograd state; no Variable wrapper needed.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if step % log_every == log_every - 1:
                acc = test_net(net, test_loader)
                net.train()  # make sure we resume in training mode after evaluation
                # Report the mean over the interval: a single small batch is very noisy.
                print('Epoch: ', epoch, ' |step: ', step, ' |train_loss: ', running_loss / log_every,
                      '|test accuracy:%.4f' % acc)
                running_loss = 0.0
    print('Finished Training')
    return net
def test_net(net, test_dataloader):
    """Return the classification accuracy of ``net`` over ``test_dataloader``.

    The network is put in eval mode and run under ``torch.no_grad()`` so no
    autograd graph is built; its previous train/eval mode is restored
    afterwards, even if an exception is raised.

    Args:
        net: module mapping an input batch to per-class scores; the
            predicted class is the argmax along dimension 1.
        test_dataloader: iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a float, or 0.0 when
        the loader yields no samples.
    """
    was_training = net.training
    net.eval()
    correct, total = 0, 0
    try:
        with t.no_grad():
            for inputs, label in test_dataloader:
                output = net(inputs)
                _, predicted = t.max(output, 1)  # predicted class index per sample
                total += label.size(0)
                correct += (predicted == label).sum().item()
    finally:
        net.train(was_training)
    return float(correct) / total if total else 0.0


# The name matches pytest's ``test_*`` pattern; keep pytest from collecting it.
test_net.__test__ = False
if "__main__" == __name__:
    # Script entry point: train the network and keep the result bound to `net`.
    net = train()
# Source: "[PyTorch] LeNet for CIFAR-10 classification",
# reposted from blog.csdn.net/qq_30159015/article/details/80756515