pytorch框架学习(6)训练一个简单的自己的CNN (三)细节篇

由于上两篇文章是按照视频教程顺下来的,很多细节地方有所忽略,本文重点对各个地方的细节进行学习与记录。


细节一torchvision.transforms.Compose
在制作数据集的过程中,我们使用了transforms.Compose对数据进行了预处理。
torchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起:如:

transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(224),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ])

在这里插入图片描述
其中ToTensor是使得图片从HWC转换为CHW从0 ~ 255转化为0 ~ 1


细节二torchvision.datasets.ImageFolder
ImageFolder是一个通用的数据加载器,它要求以各个文件夹的这种格式来组织数据集的训练、验证或者测试图片。

#root 为data/train或者data/valid
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
 
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

参数详解:

dataset=torchvision.datasets.ImageFolder(
                       root,  #图片储存的根目录,即各类别文件夹所在目录的上一级目录
                       transform=None,#预处理操作函数(细节1所定义的函数)
                       target_transform=None, # 对图片类别进行预处理的操作,输入为 target,输出对其的转换。 如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
                       loader=<function default_loader>, # 表示数据集加载方式,通常默认加载方式即可。
                       is_valid_file=None) #获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

返回的dataset都有以下三种属性:

self.classes:用一个 list 保存类别名称
self.class_to_idx:类别对应的索引,与不做任何转换返回的
target 对应 self.imgs:保存(img-path, class) tuple的 list

在这里插入图片描述


细节三torch.utils.data.DataLoader
torch.utils.data.DataLoader 主要是对数据进行 batch 的划分。

torch.utils.data.DataLoader(image_datasets, # 数据,要求是dataset类型
							batch_size=8, # 批量大小
							shuffle=True, #是否进行数据洗牌)

其他的操作可见别人写的博客


细节四model.parameters()与model.state_dict()
model.parameters()与model.state_dict()都是Pytorch中用于查看网络参数的方法。
一般来说,前者多见于优化器的初始化;后者多见于模型的保存。
如:

optomizer = torch.optim.Adam(model.parameters(), lr=1e-5)
torch.save(model.state_dict(), ‘best_model,pth’)

model.state_dict()返回的是一个OrderDict,存储了网络结构的名字和对应的参数。
model.state_dict() 获取了 model 中所有的可学习参数(weight、bias),同时还获取了不可学习参数(BN layer 的 running mean 和 running var 等)。可以将 model.state_dict() 看作是在 model.parameters() 功能的基础上,又额外获取了所有不可学习参数。
————————————————
版权声明:本文为CSDN博主「yaoyz105」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_31347869/article/details/125065271


额外知识点1:深拷贝、浅拷贝和赋值之间的区别

Python提供了三种赋值方法,最常见的是赋值=、浅拷贝copy.copy()和深拷贝copy.deepcopy()

赋值:Python中的赋值都是进行对象的引用传递,即内存地址的传递。
浅拷贝:浅拷贝只拷贝对象本身,并不会拷贝对象内部的嵌套对象。
赋值:深拷贝会拷贝对象本身以及其所有的嵌套对象。
具体可以参考这个博客

  • model.state_dict()也是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。

细节五torch.load()与model.load_state_dict()
torch.load(“path路径”)表示加载已经训练好的模型,这个模型就是一个state_dict
model.load_state_dict()表示将训练好的模型参数重新加载至网络模型中

# save
torch.save(model.state_dict(), PATH)
 
# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

额外知识点2:Image.resize()和Image.thumbnail()

**Image.resize()**函数用于修改图片的尺寸。 > **Image.thumbnail()**函数用于制作当前图片的缩略图。

详情请参考这一篇博文~


猜你喜欢

转载自blog.csdn.net/vibration_xu/article/details/126176197