一起用代码吸猫!本文正在参与【喵星人征文活动】。
前言
本文介绍了如何使用 resnet 网络来进行猫狗图像的区分,准确率可达到 98% 哦。妈妈再也不用担心我认不出我的猫了。
MegEngine框架安装
megEngine 是一站式 “深度学习”模型开发平台, 开启你的 AI 技能成长之旅
MegEngine框架下载可以按照下面的命令,
pip3 install megengine -f https://megengine.org.cn/whl/mge.html
复制代码
同时,你也可以 fork 公开项目,来学习关于模型的知识
恰逢此月活动为代码吸猫,于是我就来学习 MegEngine 上猫狗大战项目。
学习记录
项目简介
猫狗大战是megengine平台上一个使用深度学习算法来区分辨别猫图和狗图,通过resnet网络,最终判别准确率可达到 98% , 再也不怕认不清猫猫和狗狗了。
数据准备
基于 MegEngine DataSet 构造数据集
先引入依赖
from typing import Tuple
import numpy as np
from megengine.data.dataset import Dataset
import os
import cv2
复制代码
将1000张图片划分为训练集和测试集,比例为9:1
class CatVsDogDataset(Dataset):
def __init__(self, mode, dir):
super().__init__()
self.mode = mode
self.dir = dir
self.data_size = 0
self.data = []
self.label = []
# self.data 是数据,self.label 数据对应的标签
if self.mode == 'train':
dir = os.path.join(dir, "train")
for file in os.listdir(dir):
# 读文件
img = cv2.imread(os.path.join(dir, file))
self.data.append(img)
name = file.split(sep='.')
if name[0] == 'cat':
self.label.append(0) # the label of cat is 0
else:
self.label.append(1) # the label of dog is 1
elif self.mode == 'test':
dir = os.path.join(dir, "test")
for file in os.listdir(dir):
img = cv2.imread(os.path.join(dir, file))
self.data.append(img)
name = file.split(sep='.')
if name[0] == 'cat':
self.label.append(0) # the label of cat is 0
else:
self.label.append(1) # the label of dog is 1
else:
print('Undefined Dataset!')
self.data = np.array(self.data)
self.label = np.array(self.label)
print(self.data.shape)
print(self.label.shape)
# 定义获取数据集中每个样本的方法
def __getitem__(self, index: int) -> Tuple:
return self.data[index], self.label[index]
# 定义返回数据集长度的方法
def __len__(self) -> int:
return len(self.data)
复制代码
对划分后的数据做一下检查:
import os
print("训练数据集总数:",len(os.listdir("./dataset/CatVsDog/train")))
print("测试数据集总数:",len(os.listdir("./dataset/CatVsDog/test")))
train_dataset = CatVsDogDataset("train", "./dataset/CatVsDog")
test_dataset = CatVsDogDataset("test", "./dataset/CatVsDog")
复制代码
说明数据准备工作已经做好了
构建一个resnet网络结构
什么是 resnet 网络结构
将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。意味着后面的特征层的内容会有一部分由其前面的某一层线性贡献。
从经验来看,网络的深度对模型的性能至关重要,当增加网络层数后,网络可以进行更加复杂的特征模式的提取,所以当模型更深时理论上可以取得更好的结果。但网络层数增加时,却出现了深度网络退化问题,这给深度网络的进步造成了巨大阻碍。
何博士提出的 ResNet 算法解决了 CNN 模型难以训练的问题,2014年VGC只有19层,儿15年就多达152层也印证了 ResNet 的优越性。
构造的核心代码引用了 megengine.functional 和 module 方法。
如果通过 python 纯原生手写的话,这个算法的实现并不简单,具体可以参考 Megengine 的公开项目。
Megengine 官方提供了训练好的模型,我们可以直接引用,我们直接把模型下载下来,当前模型已经被训练,无需再次训练。
os.system("wget https://data.megengine.org.cn/models/weights/resnet18_naiveaug_70312_78a63ca6.pkl")
复制代码
当然也可以对模型进行继续训练。
模型训练
def model_train():
import megengine as mge
smallnet = resnet18()
# 可选
state_dict = mge.load('resnet18_naiveaug_70312_78a63ca6.pkl')
smallnet.load_state_dict(state_dict)
batch_size = 16
sampler = RandomSampler(dataset=train_dataset, batch_size=batch_size, drop_last=True)
from megengine.data import transform
transform = transform.Compose([
transform.RandomResizedCrop(224),
transform.RandomHorizontalFlip(),
transform.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transform.Lighting(0.1),
transform.Normalize(
mean=[103.530, 116.280, 123.675], std=[57.375, 57.120, 58.395]
),
transform.ToMode("CHW"),
])
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
transform=transform,
)
# 定义静态图训练函数
@trace(symbolic=True)
def train_func(data, label, *, net, optimizer):
net.train() # 网络设置成训练模式
pred = net(data)
# 使用交叉熵损失
loss = F.cross_entropy_with_softmax(pred, label)
optimizer.backward(loss)
return pred, loss
# 定义优化器
opt = optim.SGD(smallnet.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
# 模型训练
import megengine as mge
import numpy as np
# set trace.enabled=False if you want to run eager mode
# trace.enabled = False
# 训练迭代,优化器更新参数
# 这里为了方便演示,只迭代 10 个 epochs 。
# 实际训练可以设成 200 个 epochs ,在第 100 和第 150 个 epoch 位置将 lr 分别降至 0.01 和 0.001 。
epochs = 100
data_tensor = mge.tensor(dtype=np.float32)
label_tensor = mge.tensor(dtype=np.int32)
losses = []
for i in range(epochs):
print(".")
loss_rec = []
for data, label in train_dataloader:
"""
# 验证数据的正确性
img = np.array(data[0])
img = np.transpose(img,[1,2,0])
print(img.shape)
cv2.imshow("img", img)
cv2.waitKey(0)
"""
data_tensor.set_value(data)
label_tensor.set_value(label.astype("int32"))
opt.zero_grad()
# pred = smallnet(data)
# print(pred.shape)
# exit()
_, loss = train_func(data_tensor, label_tensor, net=smallnet, optimizer=opt)
opt.step()
loss_rec.append(loss.numpy().item())
loss = sum(loss_rec) / len(loss_rec)
losses.append(loss)
print("[Epoch {}] loss: {}".format(i, loss))
"""
损失可视化
模型保存
"""
import matplotlib.pyplot as plt
plt.plot(range(len(losses)), losses, color='red')
plt.xlabel("iterator")
plt.ylabel('loss')
plt.show()
# 模型保存
mge.save(smallnet.state_dict(), 'resnet18_static_100.mge')
复制代码
原谅小白,模型训练部分还未完全理解透彻。
模型测试
如果使用 megengine 提供的模型,就无需再写单独的测试函数,否则要像训练一样写复杂的测试函数,这里就可以显示出旷视的强大之处了。
def model_test():
"""
模型的加载与测试
:return:
"""
smallnet = resnet18()
import megengine as mge
state_dict = mge.load('resnet18_static_100.mge')
smallnet.load_state_dict(state_dict)
# 创建 DataLoader 用于测试
from megengine.data import transform
batch_size = 1
sampler_test = SequentialSampler(dataset=test_dataset, batch_size=batch_size)
transform_test = transform.Compose([
transform.Resize(256),
transform.CenterCrop(224),
transform.Normalize(
mean=[103.530, 116.280, 123.675], std=[57.375, 57.120, 58.395]
), # BGR
transform.ToMode("CHW"),
])
test_dataloader = DataLoader(
test_dataset,
sampler=sampler_test,
transform=transform_test,
)
# 定义静态图测试函数,进行模型的测试
@trace(symbolic=True)
def eval_func(data, label, *, net):
net.eval() # 网络设置为测试模式
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
return pred, loss
data_tensor = mge.tensor()
label_tensor = mge.tensor(dtype=np.int32)
correct = 0
total = 0
for data, label in test_dataloader:
label = label.astype("int32")
pred, _ = eval_func(data, label, net=smallnet)
pred_label = F.argmax(pred, axis=1)
# if(pred_label.numpy()[0]!=label[0]):
# img = np.array(data[0])
# img = np.transpose(img, [1, 2, 0])
# print(img.shape)
# cv2.imshow("img", img)
# cv2.waitKey(0)
correct += (pred_label == label).sum().numpy().item()
total += label.shape[0]
print("correct: {}, total: {}, accuracy: {:.2f}%".format(correct, total, correct * 100.0 / total))
复制代码
注意事项
- MegStudio 提供的模型是 imagenet 上 1000 分类的,可以改进为分类变成2分类并加载预训练模型
- MegStudio 的训练速度有几分慢,建议大家拷贝代码到本地机器上训练
致谢
- 参考链接: Megengine 猫狗大战、MegStudio
- 感谢我的小伙伴 战场小包 的邀请,贴个它的主页吧(juejin.cn/user/442409…)