使用 PyTorch 实现简单的 CNN 图像分类(数据集:torchvision 的 datasets.MNIST)

 1 import numpy as np
 2 import torch
 3 import torch.nn as nn
 4 import torch.nn.functional as F
 5 import torch.optim as optim
 6 from torchvision import datasets,transforms
 7 print("PyTorch Version",torch.__version__)
 8 
 9 class Net(nn.Module):
10     def __init__(self):
11         super(Net,self).__init__()
12         self.conv1=nn.Conv2d(1,20,5,1)
13         self.conv2=nn.Conv2d(20,50,5,1)
14         self.fc1=nn.Linear(4*4*50,500)
15         self.fc2=nn.Linear(500,10)
16     def forward(self,x):
17         #x=1*28*28
18         x=F.relu(self.conv1(x))
19         x=F.max_pool2d(x,2,2)
20         x=F.relu(self.conv2(x))
21         x=F.max_pool2d(x,2,2)
22         x=x.view(-1,4*4*50)
23         x=F.relu(self.fc1(x))
24         x=self.fc2(x)
25         return F.log_softmax(x,dim=1)
26 
27 
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, scored with the negative
    log-likelihood loss (the model is expected to output log-probabilities),
    and used for a single optimizer step. The loss of every 100th batch
    (starting with batch 0) is printed.

    Args:
        model: Module to optimise; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch: Epoch index, used only in the progress message.
    """
    log_every = 100
    model.train()

    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % log_every:
            continue
        print("Train Epoch:{},iteration:{},Loss:{}".format(
            epoch, step, batch_loss.item()))
44 
45 
46 def test(model, device, test_loader):
47     model.eval()
48     total_loss = 0.
49     correct = 0.
50     with torch.no_grad():
51         for idx, (data, target) in enumerate(test_loader):
52             data, target = data.to(device), target.to(device)
53 
54             output = model(data)  # batch_size*10
55             total_loss += F.nll_loss(output, target, reduction="sum").item()
56             pred = output.argmax(dim=1)  # batch_zize*1
57             correct += pred.eq(target.view_as(pred)).sum().item()
58         total_loss /= len(test_loader.dataset)
59         acc = correct / len(test_loader.dataset) * 100.
60         print("Test loss:{},Accuracy:{}".format(total_loss, acc))
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

# Page-locked host memory only speeds up host-to-GPU copies; requesting it on
# a CPU-only run gains nothing.
pin_memory = device.type == "cuda"

# Shared preprocessing for both splits: convert images to tensors, then
# normalise with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory,
)
test_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("./mnist_data", train=False, download=True,
                   transform=mnist_transform),
    batch_size=batch_size,
    # Evaluation metrics are sums over all samples, so order is irrelevant;
    # leaving the test set unshuffled keeps evaluation reproducible.
    shuffle=False,
    pin_memory=pin_memory,
)

lr = 0.01
momentum = 0.5
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

num_epochs = 2
for epoch in range(num_epochs):
    # One training pass followed by an evaluation on the held-out split.
    train(model, device, train_dataloader, optimizer, epoch)
    test(model, device, test_dataloader)

# Persist only the learned parameters (not the pickled module).
torch.save(model.state_dict(), "mnist_cnn.pt")
 1 PyTorch Version 1.5.0
 2 Train Epoch:0,iteration:0,Loss:2.2912797927856445
 3 Train Epoch:0,iteration:100,Loss:0.5194383859634399
 4 Train Epoch:0,iteration:200,Loss:0.2786406874656677
 5 Train Epoch:0,iteration:300,Loss:0.22770120203495026
 6 Train Epoch:0,iteration:400,Loss:0.1970459222793579
 7 Train Epoch:0,iteration:500,Loss:0.4494241774082184
 8 Train Epoch:0,iteration:600,Loss:0.059495702385902405
 9 Train Epoch:0,iteration:700,Loss:0.0606967955827713
10 Train Epoch:0,iteration:800,Loss:0.024992913007736206
11 Train Epoch:0,iteration:900,Loss:0.059543460607528687
12 Train Epoch:0,iteration:1000,Loss:0.052940040826797485
13 Train Epoch:0,iteration:1100,Loss:0.04461891949176788
14 Train Epoch:0,iteration:1200,Loss:0.07550729811191559
15 Train Epoch:0,iteration:1300,Loss:0.24627575278282166
16 Train Epoch:0,iteration:1400,Loss:0.15704390406608582
17 Train Epoch:0,iteration:1500,Loss:0.01860976219177246
18 Train Epoch:0,iteration:1600,Loss:0.01433052122592926
19 Train Epoch:0,iteration:1700,Loss:0.15008395910263062
20 Train Epoch:0,iteration:1800,Loss:0.015903979539871216
21 Test loss:0.07957145636081696,Accuracy:97.55
22 Train Epoch:1,iteration:0,Loss:0.03802196681499481
23 Train Epoch:1,iteration:100,Loss:0.03068508207798004
24 Train Epoch:1,iteration:200,Loss:0.010870471596717834
25 Train Epoch:1,iteration:300,Loss:0.01409807801246643
26 Train Epoch:1,iteration:400,Loss:0.0440949946641922
27 Train Epoch:1,iteration:500,Loss:0.028995990753173828
28 Train Epoch:1,iteration:600,Loss:0.01879104971885681
29 Train Epoch:1,iteration:700,Loss:0.013935379683971405
30 Train Epoch:1,iteration:800,Loss:0.021408677101135254
31 Train Epoch:1,iteration:900,Loss:0.2530629634857178
32 Train Epoch:1,iteration:1000,Loss:0.07138478755950928
33 Train Epoch:1,iteration:1100,Loss:0.05166466534137726
34 Train Epoch:1,iteration:1200,Loss:0.010311082005500793
35 Train Epoch:1,iteration:1300,Loss:0.06208721548318863
36 Train Epoch:1,iteration:1400,Loss:0.008724316954612732
37 Train Epoch:1,iteration:1500,Loss:0.035819828510284424
38 Train Epoch:1,iteration:1600,Loss:0.06683069467544556
39 Train Epoch:1,iteration:1700,Loss:0.0015392005443572998
40 Train Epoch:1,iteration:1800,Loss:0.10098540782928467
41 Test loss:0.04268560756444931,Accuracy:98.69

猜你喜欢

转载自www.cnblogs.com/-xuewuzhijing-/p/12971924.html
今日推荐