基于softmax回归的图像分类二

欢迎关注

利用Pytorch框架实现softmax回归的图像分类

1.导入基础包

import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
# sys.path.append('..')
# import d21zh_pytorch as d21

2.获取读取数据

batch_size=256
train_iter,test_iter=d21.load_data_fashion_mnist(batch_size)

3.定义和初始化模型

softmax回归的输出层是一个全连接层,所以用一个线性模块就可以了。每个batch样本x的形状为(batch_size,1,28,28),所以我们要先用view()将x的形状转换成(batch_size,784)才送入全连接层

num_inputs=784
num_outputs=10

class LinearNet(nn.Module):
    def __init__(self,num_inputs,num_outputs):
        super(LinearNet,self).__init__()
        self.linear=nn.Linear(num_inputs,num_outputs)
    def forward(self,x):#x.shape=[batch_size,1,28,28]
        print(x.shape)
        print(x.shape[0])#x.shape[0]获取x的第一个元素
        y=self.linear(x.view(x.shape[0],-1))
        print(y.shape)
        return y

net=LinearNet(num_inputs,num_outputs)#定义线性模型,784个输入,10个输出
#将x的形状进行变换
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer,self).__init__()
    def forward(self,x):
        return x.view(x.shape[0],-1) 
    

# print(net)
# for x,y in train_iter:
#     net.forward(x)
from collections import OrderedDict
net=nn.Sequential(
                OrderedDict([
                    ('flatten',FlattenLayer()),
                    ('linear',nn.Linear(num_inputs,num_outputs))
                    ])
                )
# print(net[1])
# print(net.linear)
# print(net.flatten)
# for param in net.parameters():
#     print(param.shape)

Python字典使用方法包括OrderedDict

使用均值为0、标准差为0.01的正态分布随机初始化模型的权重参数

init.normal_(net.linear.weight,mean=0,std=0.01)
init.constant_(net.linear.bias,val=0)

4.softmax和交叉熵损失函数

#该函数同时包括softmax运算和交叉熵损失计算
loss=nn.CrossEntropyLoss()

5.定义优化算法

# for param in net.parameters():
#     print(param.shape)
optimizer=torch.optim.SGD(net.parameters(),lr=0.1)

6.训练模型

num_epochs=5
def evaluate_accuracy(data_iter,net):
    acc_sum,n=0.0,0
    for x,y in data_iter:
#         print('net(x): ',net(x).shape,'结束')#256行10列
        acc_sum+=(net(x).argmax(dim=1)==y).float().sum().item()#.sum()不是.mean(),这是统计正确的个数,与上面不同
#         print('acc_sum: ',acc_sum)
#         print('y.shape[0]:',y.shape[0])#:y.shape[0]=256
        n+=y.shape[0]
    return acc_sum/n
# print(evaluate_accuracy(test_iter,net))

def train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):
    for epoch in range(num_epochs):
        train_1_sum,train_acc_sum,n=0.0,0.0,0
        for x,y in train_iter:
#             print(x.shape)#torch.Size([256, 1, 28, 28])
            y_hat=net(x)
            L=loss(y_hat,y).sum()
            #梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            L.backward()
            if optimizer is None:
                sgd(params,lr,batch_size)
            else:
                optimizer.step()#softmax回归的简介实现方式
                
            train_1_sum+=L.item()
            train_acc_sum+=(y_hat.argmax(dim=1)==y).sum().item()
            n+=y.shape[0]
        test_acc=evaluate_accuracy(test_iter,net)
        print('epoch %d loss % .4f train_acc %.3f test_acc %.3f ' % (epoch+1,train_1_sum/n,train_acc_sum/n,test_acc))
train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,optimizer)

运行结果:

epoch 1 loss  0.0031 train_acc 0.750 test_acc 0.774 
epoch 2 loss  0.0022 train_acc 0.815 test_acc 0.797 
epoch 3 loss  0.0021 train_acc 0.824 test_acc 0.818 
epoch 4 loss  0.0020 train_acc 0.833 test_acc 0.816 
epoch 5 loss  0.0019 train_acc 0.836 test_acc 0.821 

7.预测

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress',
                                'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[i] for i in labels]

#图片设置
def use_svg_display():
    """用矢量图显示svg"""
#在一行里面画出多张图像和对应标签的函数
def show_fashion_mnist(images,labels):
    use_svg_display()
    #'_'表示我们忽略,不使用的变量
    _,figs=plt.subplots(1,len(images),figsize=(25,25))#
    for f,img,lbl in zip(figs,images,labels):
        f.imshow(img.view((28,28)).numpy())
        f.set_title(lbl,color='white')
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
    
x,y=iter(test_iter).next()
true_labels=get_fashion_mnist_labels(y.numpy())
pred_labels=get_fashion_mnist_labels(net(x).argmax(dim=1).numpy())
titles=[true+'\n'+pred for true,pred in zip(true_labels,pred_labels)]
show_fashion_mnist(x[0:9],titles[0:9])

运行结果:

预测结果

发布了17 篇原创文章 · 获赞 7 · 访问量 1448

猜你喜欢

转载自blog.csdn.net/qq_40211493/article/details/103945891
今日推荐