本文已参与「新人创作礼」活动,一起开启掘金创作之路。
首先,这篇文章是接着上篇文章继续写的,上篇文章地址在这里 上一篇文章
1、简单回顾
简单回顾下上篇内容,上篇我们简单讲解了如何对自己的数据集进行封装,实现了一个 Dataset
class Pokemon(Dataset):
复制代码
那么实现了这个的功能是为了做什么呢? 相比很多人都知道下面这一段代码,没错就是手写数字识别的代码,那么当我们实现了自己数据集的封装之后,我们就相当于也可以这样使用
# 自带的手写数字数据集
train_loader = torch.utils.data.DataLoader(
datasets.MNIST("../MNIST", train=True, download=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307), (0.3081)),
])), batch_size=batchSize, shuffle=True)
# 自定义数据集的使用
tf = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor()
])
db = Pokemon("pokeman", 64, "train")
x, y = next(iter(db))
loader = DataLoader(db, batch_size=32, shuffle=True)
复制代码
同时在这里,在这里再介绍一种方法,上面我们自定义Dataset,花了很大的精力,但是在某些情况下并不需要花费这么大的精力,也就是在数据文件夹是每个类别一个文件夹
当数据集满足上面的格式要求的时候呢,我们只需要执行下面这几行代码就可以
tf = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor()
])
db = torchvision.datasets.ImageFolder(root="pokeman",transform=tf)
loader = DataLoader(db,shuffle=True,batch_size=32)
复制代码
通过这样简单的几行代码,同样能够帮我们实现对数据集的封装效果,很实用
2、开始训练
如果对深度学习有所了解的肯定知道,训练一般有几个步骤
- 1、获取训练集、测试集等数据
- 2、创建模型
- 3、创建优化器
- 4、创建损失函数
- 5、执行训练
- 6、反向传播
- 7、计算损失 大概就是以上这几个步骤,当然如果要划分的更详细也可以,这里只是简单的说明下问题
import torch
import torchvision
from torch import optim, nn
from torch.utils.data import DataLoader
import visdom
from pokemon import Pokemon
from resnet import ResNet18
batchSize = 32
lr = 1e-3
epochs = 6000
torch.manual_seed(100)
train_db = Pokemon("pokeman", 224, "train")
test_db = Pokemon("pokeman", 224, "train")
val_db = Pokemon("pokeman", 224, "val")
train_loader = DataLoader(train_db, batch_size=batchSize, shuffle=True, num_workers=4)
test_loader = DataLoader(test_db, batch_size=batchSize, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchSize, shuffle=True, num_workers=4)
viz = visdom.Visdom()
def evalute(model, loader):
correct = 0
total = len(loader.dataset)
with torch.no_grad():
for x, y in loader:
logist = model(x)
pred = logist.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct / total
def main():
# 模型
model = ResNet18(5) # 5个类别
# 优化器
optimizer = optim.Adam(model.parameters(), lr=lr)
# 损失函数
criterion = nn.CrossEntropyLoss()
best_acc, best_epoch = 0, 0
# 使用 Visdom可视化
viz.line([0],[-1],win="loss",opts=dict(title="loss"))
viz.line([0],[-1],win="acc",opts=dict(title="acc"))
global_step = 0
# 开始训练
for epoch in range(epochs):
# 训练集开始训练
for step, (x, y) in enumerate(train_loader):
logist = model(x)
loss = criterion(logist, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 绘制损失函数折线
viz.line([loss.item()], [global_step], win="loss", update="append")
global_step +=1
# 验证
if epoch % 2 == 0:
acc_val = evalute(model, val_loader)
if acc_val>best_acc: # 当精度最好时,保存精度值
best_epoch = epoch
best_acc = acc_val
torch.save(model.state_dict(),"best.pkl")
viz.line([best_acc], [global_step], win="acc", update="append")
print("acc:",acc_val)
# 加载保存的模型,并验证
print("best acc",best_acc,"best epoch",best_epoch)
model.load_state_dict(torch.load("best.pkl"))
print("load pkl success!")
test_acc = evalute(model,test_loader)
print("test_acc",test_acc)
if __name__ == '__main__':
main()
复制代码
上面的训练中我们使用了Visdom来可视化展示,接下来我们展示可视化内容
经过训练我们最后的精度在0.86左右的样子
3、使用迁移学习的预训练模型
首先是导包
扫描二维码关注公众号,回复:
14472933 查看本文章
![](/qrcode.jpg)
# 我们自己的 ResNet18 网络模型,这里暂时注释
# from resnet import ResNet18
# 导入自带的 resnet18
from torchvision.models import resnet18
# 导入 Flatten 其实就是一个拉平操作
from utils import Flatten
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
# torch.prod 计算某个维度的乘积
# item() 是吧tensor转化为 numpy 但是只能转单个元素
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)
复制代码
改变模型的使用
# 这个是我们之前的模型
# model = ResNet18(5) # 5个类别
# 得到resnet18网络模型 pretrained = True表示使用提供的预训练模型
it_model = resnet18(pretrained=True)
# 得到前面的17层 使用 *打散
model = nn.Sequential(
*list(it_model.children())[:-1], # 得到 [b,512,1,1]
Flatten(), # 得到 [b,512]
nn.Linear(512, 5)
)
复制代码
从上面的改动部分我们可以看到。我们使用的是提供的resnet18其中的前17层,最后一层我们通过nn.Linear(512, 5)来达到训练自己类型的目的,最后我们通过训练可以发现,能够达到 0.92以上的精度,相比之前我们自己从头开始训练的效果有明显的提高
4、总结
通过上面一个完整的例子,我们总结了如果实现对自己数据集从零到开始训练,并取得很好精度的完整过程,这个过程很重要,使我们基础菜鸟需要学习的重要流程,如果有读者能看到,希望读者能够认真的看完,我这里其实也是学习总结的过程,如果有不对,希望读者批评改正