# -*- coding: utf-8 -*-
"""
Created on 2019/9/29 23:52
@author: Johnson
Email:[email protected]
@software: PyCharm
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as tranforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.utils import make_grid
def imshow_batch(sample_batch):
    """Draw one batch of images as a grid on the current matplotlib axes.

    Args:
        sample_batch: An indexable pair ``(images, labels)``.  ``images`` is
            handed to ``make_grid`` (a batched tensor, or a list of image
            tensors); ``labels`` is a tensor and becomes the plot title.

    The grid has four images per row and the axes are hidden.  The caller
    creates the figure beforehand and calls ``plt.show()`` afterwards.
    """
    images = sample_batch[0]
    labels = sample_batch[1]

    # Pick a white padding value that is valid for the image dtype: float
    # images (e.g. produced by ToTensor) live in [0, 1], so 255 would be out
    # of range and matplotlib would clip it with a warning.
    first_image = images if torch.is_tensor(images) else images[0]
    pad_value = 1.0 if first_image.is_floating_point() else 255

    grid = make_grid(images, nrow=4, pad_value=pad_value)
    # The grid is channel-first (C, H, W); matplotlib wants (H, W, C).
    # detach()/cpu() make the conversion work for CUDA or grad-tracking tensors.
    grid_hwc = np.transpose(grid.detach().cpu().numpy(), (1, 2, 0))

    plt.imshow(grid_hwc)
    plt.axis('off')
    plt.title(labels.detach().cpu().numpy())
class Net(nn.Module):
    """Custom CNN classifier for 28x28 single-channel images.

    ``feature`` contains four convolution stages, each built as
    conv -> ReLU -> batch norm; the last three stages are followed by a
    2x2 max pool.  ``classifier`` contains three fully connected layers with
    dropout after the first two and produces 10 output scores per sample.

    Input:  (N, 1, 28, 28)
    Output: (N, 10)
    """

    def __init__(self):
        super().__init__()
        self.feature = nn.Sequential(
            OrderedDict(
                [
                    # 28x28x1
                    ('conv1', nn.Conv2d(in_channels=1,
                                        out_channels=32,
                                        kernel_size=5,
                                        stride=1,
                                        padding=2)),
                    ('relu1', nn.ReLU()),
                    ('bn1', nn.BatchNorm2d(num_features=32)),
                    # 28x28x32
                    ('conv2', nn.Conv2d(in_channels=32,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu2', nn.ReLU()),
                    ('bn2', nn.BatchNorm2d(num_features=64)),
                    ('pool1', nn.MaxPool2d(kernel_size=2)),
                    # 14x14x64
                    ('conv3', nn.Conv2d(in_channels=64,
                                        out_channels=128,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu3', nn.ReLU()),
                    ('bn3', nn.BatchNorm2d(num_features=128)),
                    ('pool2', nn.MaxPool2d(kernel_size=2)),
                    # 7x7x128
                    ('conv4', nn.Conv2d(in_channels=128,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu4', nn.ReLU()),
                    ('bn4', nn.BatchNorm2d(num_features=64)),
                    ('pool3', nn.MaxPool2d(kernel_size=2)),
                    # out 3x3x64
                ]
            )
        )
        # NOTE(review): there is no activation between the linear layers, so
        # at inference time they compose into a single linear map.  Left as
        # is so that previously saved checkpoints keep producing the same
        # outputs.
        self.classifier = nn.Sequential(
            OrderedDict(
                [
                    ('fc1', nn.Linear(in_features=3 * 3 * 64,
                                      out_features=128)),
                    # Plain Dropout: the input here is a 2-D (N, features)
                    # tensor, for which Dropout2d is deprecated.
                    ('dropout1', nn.Dropout(p=0.5)),
                    ('fc2', nn.Linear(in_features=128,
                                      out_features=64)),
                    ('dropout2', nn.Dropout(p=0.6)),
                    ('fc3', nn.Linear(in_features=64, out_features=10)),
                ]
            )
        )

    def forward(self, x):
        """Return the 10 class scores for a batch of images ``x``."""
        out = self.feature(x)
        # Flatten everything except the batch dimension.  Unlike
        # view(-1, 576), this never folds a wrong-sized feature map into the
        # batch dimension; a size mismatch surfaces as an error in fc1.
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
# --------------------------- Dataset -------------------------------------
import os  # used below to create the output directories

data_dir = '/media/weipenghui/Extra/FashionMNIST'
tranform = tranforms.Compose([tranforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, transform=tranform)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=tranform)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random training batch.
plt.figure()
imshow_batch(next(iter(train_dataloader)))
plt.show()

# ------------------------- Network and settings --------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Net()
print(net)
net = net.to(device)
loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# ----------------------------- Training ----------------------------------
# Make sure the output directories exist before anything is written to them.
os.makedirs('./log', exist_ok=True)
os.makedirs('./model', exist_ok=True)

epoch_num = 100
# The context manager closes both log files even if training is interrupted.
with open('./log/running_loss.txt', 'w') as file_runing_loss, \
        open('./log/test_accuracy.txt', 'w') as file_test_accuarcy:
    for epoch in range(epoch_num):
        running_loss = 0.0
        accuracy = 0.0
        for i, sample_batch in enumerate(train_dataloader):
            inputs = sample_batch[0].to(device)
            labels = sample_batch[1].to(device)

            # Re-enable training mode: the evaluation below switches to eval().
            net.train()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_fc(outputs, labels)
            loss.backward()
            optimizer.step()
            print(i, loss.item())

            # Statistics: running loss and validation accuracy.
            running_loss += loss.item()
            if i % 20 == 19:
                correct = 0
                total = 0
                net.eval()
                # No gradients are needed for evaluation.
                with torch.no_grad():
                    for val_inputs, val_labels in val_dataloader:
                        # Evaluation data must live on the same device as the model.
                        val_inputs = val_inputs.to(device)
                        val_labels = val_labels.to(device)
                        val_outputs = net(val_inputs)
                        _, prediction = torch.max(val_outputs, 1)
                        correct += ((prediction == val_labels).sum()).item()
                        total += val_labels.size(0)
                accuracy = correct / total
                print('[{},{}] running loss = {:.5f} acc = {:.5f}'.format(epoch + 1, i+1, running_loss / 20, accuracy))
                file_runing_loss.write(str(running_loss / 20)+'\n')
                file_test_accuarcy.write(str(accuracy)+'\n')
                running_loss = 0.0

        # Step the LR schedule once per epoch, after the optimizer updates;
        # stepping it first would shift the schedule one epoch early.
        scheduler.step()

print('\n train finish')
torch.save(net.state_dict(), './model/model_100_epoch.pth')
# ---------------------------- Test the network ----------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as tranforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
data_dir = '/media/weipenghui/Extra/FashionMNIST'
tranform = tranforms.Compose([tranforms.ToTensor()])
test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=tranform)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)

# Fetch the first test batch once and use it for both the plot and the
# prediction, instead of building two separate loader iterators.
sample_batch = next(iter(test_dataloader))
images, labels = sample_batch
plt.figure()
imshow_batch(sample_batch)

# Build a fresh network from the Net class and load the trained weights.
net = Net()
net.load_state_dict(torch.load(f='./model/model_100_epoch.pth', map_location='cpu'))
# Inference mode: disables dropout and makes batch norm use running statistics.
net.eval()
print(net)

with torch.no_grad():
    outputs = net(images)
_, prediction = torch.max(outputs, 1)
print('label:', labels)
print('prdeiction:', prediction)
plt.show()
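# 0020-pytorch: classification with a custom network
# Reposted from blog.csdn.net/zhonglongshen/article/details/112805857
# (Remaining web-page residue from the source blog: "You may also like",
# "Today's recommendations", "Weekly ranking".)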