PyTorch经验分享:新手如何搭建PyTorch程序
为什么是PyTorch
在2017年的10月份,笔者还在研二的时候,曾发布过一篇有关TensorFlow编程经验的博客。当时,笔者刚接触深度学习框架TensorFlow不足半年,相较于笔者最开始接触的深度学习框架Caffe,TensorFlow更易于上手,更方便随心所欲地实现各种网络结构,确实是一个非常优良的适合算法研究的深度学习框架。但是,在笔者对TensorFlow一年多的使用,写过一些程序后,对TensorFlow的部分局限性感受深刻,比如:
- TensorFlow的数据流图是静态的,因此在TensorFlow框架编程的过程中,是无法直接对Tensor进行条件判断的,只能通过会话层tf.Session()将Tensor的值run出来再进行判断。因此,在TensorFlow中,进行条件分支判断就非常难,比如采用tf.cond或者tf.py_func。
- 上述的静态图机制,也使得TensorFlow程序非常难调试。
- 笔者不得不吐槽TensorFlow的接口,在不断更新换代升级的过程中变化太快也太大(比如TensorFlow 2.0与TensorFlow 1.0)。这就造成了许多老版本程序,无法在新版本的TensorFlow框架上运行。
当然,纵使TensorFlow具有上述缺点,依然是非常优秀的深度学习框架,非常多的开源代码仍旧选择TensorFlow作为实现平台。
可是,该如何克服TensorFlow的上述缺点呢?答案是选择PyTorch。时间来到2018-2019年,深度学习框架PyTorch的热度越来越高,因为PyTorch具有非常多的优点。首先,TensorFlow的优点PyTorch也同样拥有,比如:
- 使用Python接口,使得算法实现轻松容易。
- 在构造普通模块的过程中,用户无需关心训练过程中的梯度反传,这为算法研究与实现带来了极大的方便。
- 接口简单,模块集成度高,支持用户较快地实现算法思想。
除此之外,PyTorch还具有许多其他的优点:
- 采用动态图机制。在PyTorch中,能够非常方便地将GPU中的Tensor与Python中的Numpy Array进行相互转换,读出Tensor的值。这就使得用户能非常方便地实现分支程序,而且极大地增加了调试的方便性。
- PyTorch官方文档比较全面且详细,不同版本的程序接口变化不大。为用户提供了极大的便利。
基于以上的优点,笔者也对PyTorch进行了了解并上手。在本篇博客中,笔者将与大家分享搭建PyTorch程序的经验,下面开始干货。
如何搭建PyTorch程序
在计算机视觉相关的深度学习任务中,简单的PyTorch程序主要包含以下这4个模块。
-网络模型定义文件,比如名称为network.py。
-数据读取文件,比如名称为dataset.py。
-训练接口程序,比如名称为train.py。
-测试接口程序,比如名称为evaluate.py。
比如,随便打开一个笔者的PyTorch程序,如下图所示:
在上图中,就包含之前提到的network,dataset,train和evaluate。当然也包含一些笔者自己定制的模块,比如cal_metrics用来计算实验指标,cfg用来设置某些超参数,image_io_and_process用来对图像进行预处理和后处理,loss_functions用来定义与实现某些损失,utils用来承载一些其他操作,比如读写txt文档。
网络模型的定义
在PyTorch中,定义一个网络模型,需要定义一个Python类,并继承torch.nn.Module这个类。在这个类中,通常是通过重写构造函数__init__来定义网络中使用的层,在定义层或者模块的时候,经常会使用到torch.nn这个库。然后再重写父类的forward函数,在其中使用__init__函数中定义的层实现网络的前传过程。比如下面的例子:
class My_Network(nn.Module):
def __init__(self):
super(network, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) #3×3卷积
self.bn1 = nn.BatchNorm2d(64) #BatchNorm
self.relu = nn.ReLU() #激活
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) #最大池化
self.fc = nn.Linear(64 * 14 * 14, 512) #全连接,将[n, (64*14*14)]变成[n, 512]
for m in self.modules(): #初始化
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def forward(self, x): #前传,输入x尺寸为[n, c, 28, 28]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = x.view(x.size(0), -1) #将四维Tensor变成两维,作为全连接的输入
x = self.fc(x)
return x
在构造函数__init__里面,定义好了网络所使用的各个模块。这个网络对输入完依次进行了卷积,批归一化,激活,最大池化,全连接的功能。并且用一个for循环对可训练参数进行初始化。初始化的过程一般都是在网络定义的构造函数中完成。在初始化的过程中可见,在构造函数中定义的所有模块,都被存放在self.modules()这个列表中。
然后,我们重写函数forward,就能实现对输入的x进行前传了。这样,就定义好了一个完整的简单的网络。
数据读取接口
在PyTorch中,与其他框架具有鲜明差异的一点,就是数据读取接口非常规范。在PyTorch中,通过继承torch.utils.data.Dataset这个类实现自己的数据集读取。在继承时,主要需要重写三个函数:__init__函数,__getitem__函数和__len__函数。其中,__init__函数主要进行一些初始化,比如读取一下所需的记录数据的txt或者xml文件,传入一些预处理参数等等。而__getitem__函数主要规定了,在每个batch训练给网络喂数据的时候,应该采用怎样的方式读取数据,以及做怎样的预处理。至于__len__函数的主要功能,就是告诉PyTorch,数据集中有多少数据。下面,就是一个数据读取接口的简单实现:
class My_Dataset(torch.utils.data.Dataset):
def __init__(self, data_txt_path):
data_list = read_txt(data_txt_path) #读取一下记录数据与标签的txt
self.data_list = np.random.permutation(data_list) #打乱一下txt
self.transform = T.ToTensor() #预处理,将数据转化成[n, c, h, w]的形式,并归一化到[0,1]
def __getitem__(self, index):
sample = self.data_list[index] #读取txt中的一行,是一个数据对(image_path, label)
image_path = sample.split(' ')[0]
label = int(sample.split(' ')[-1])
resized_image = read_image(image_path) #读取一下图像
image = self.transform(resized_image) #转化成Tensor
return image.float(), label
def __len__(self):
return len(self.data_list)
如上所示,在定义数据接口时,通过重写__init__,__getitem__和__len__三个函数,实现读取数据的功能,该数据接口可用于简单的图像分类。然后在训练程序中,只需要使用torch.utils.data.DataLoader开始在每一次训练中对数据集进行读取与遍历就行。
train_data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None)
其中,除去dataset,用的比较多的是batch_size,shuffle,num_workers和pin_memory这几个参数。batch_size指定了一个批次的数据容量,shuffle指定了是否打乱数据,num_workers指定了用多少个线程对数据进行读取,而pin_memory则指定了是否将数据放入显存。
训练程序
在搭建好网络,完善好数据读取接口后,我们就可以进行训练程序的搭建了。在进行训练程序的搭建时,与其他深度学习框架类似,都是先搭建图(模型),然后在图上进行训练。比如,一个简单的训练接口示意就如下所示:
from __future__ import print_function
import torch.nn
from torch.nn import DataParallel
import torch.optim
import torch.utils.data
import argparse
#import 所需的各种模块
parser = argparse.ArgumentParser(description="")
#添加各种自定义参数
args = parser.parse_args()
def main():
device = torch.device("cuda")
train_dataset = My_Dataset(txt_path)
train_data_loader = torch.utils.data.DataLoader(train_dataset,
batch_size,
shuffle,
num_workers)
model = My_Network()
#若有预训练参数,可以载入预训练参数
model.load_state_dict(torch.load(ckpt_path), strict=False)
#定义损失,比如criterion = torch.nn.CrossEntropyLoss()
criterion = My_Loss
model.to(device)
#可以将模型训练放在多张GPU上并行
model = DataParallel(model)
optimizer = torch.optim.SGD(model.parameters(), lr, weight_decay)
model.train() #将模型置为训练模式
for i in range(epoch):
for ii, data in enumerate(train_data_loader):
data_input, label = data
data_input = data_input.to(device)
label = label.to(device).long()
output = model(data_input)
loss = criterion(output, label)
print("loss =: ", loss)
optimizer.zero_grad() #首先梯度置零
loss.backward() #然后求梯度
optimizer.step() #通过梯度更新参数
iters = i * len(train_data_loader) + ii
if iters % save_interval == 0:
torch.save(model.state_dict(), save_path)
if __name__ == "__main__":
main()
如上所示,在进行训练时,首先进行四个步骤,即:
- 引用之前定义的数据读取接口。
- 引用model,即网络,需要的话就载入预训练参数。
- 定义loss。
- 定义优化器。
然后就可以开始在for循环中进行可训练参数更新了。
最后,在PyTorch训练程序中,还需要注意三点。
第一点是保存模型和加载预训练参数。
在保存模型时,通常是保存模型的已训练参数(也可以保存整个模型),通过torch.save函数完成;在加载预训练参数时,通过model的父类,即torch.nn.Module的load_state_dict完成,里面的“strict”参数表示是否需要将预训练模型里面的参数与model里面的完全对齐。
第二点是使用GPU训练
如果需要使用GPU训练时,需要使用到to(device)函数,意思就是将数据或者模型放到GPU上面。
第三点是参数更新过程
在进行参数更新时,一共分成三步:
- 将可训练参数梯度置零 optimizer.zero_grad()
- 根据损失值求梯度 loss.backward()
- 更新可训练参数 optimizer.step()
第四点是将model置为训练模式
对应代码中的
model.train()
训练模式会对某些定义的网络模块有影响,比如使用dropout层,在训练时会被激活。
测试程序
在模型训练完毕之后,需要对模型进行测试。在测试时,与其他深度学习框架类似,还是通过先对模型进行搭建,载入已训练参数,将测试数据前传得到结果。比如,一个简单的测试接口示意就如下所示:
from __future__ import print_function
import torch.nn
from torch.nn import DataParallel
import torch.optim
import torch.utils.data
import argparse
#import 所需的各种模块
parser = argparse.ArgumentParser(description="")
#添加各种自定义参数
args = parser.parse_args()
def main():
device = torch.device("cuda")
model = My_Network()
model = DataParallel(model)
model.load_state_dict(torch.load(snapshot_path)) #载入参数
model.to(device)
model.eval() #将模型置为测试模式
for eval_data_path in eval_data_list:
eval_data = read_image(eval_data_path)
data = torch.from_numpy(eval_data) #将data转化为Tensor
data = data.to(device)
output = model(data) #前传得到结果
#自定义操作,比如计算精度
if __name__ == "__main__":
main()
可以看到,测试程序比训练程序简单很多,主要先进行两个步骤:
- 定义网络
- 载入已训练参数
然后就可以读取测试样本进行前传了。需要注意的是,在进行模型的测试时,需要将model置为测试模式,对应代码中的
model.eval()
经过以上的示例,可见PyTorch程序比较规范,简洁与清晰。并且在PyTorch中还使用到了非常多的面向对象的规范,有许多操作都是通过继承Python类进行实现的。
写在后面
到这里,本篇博客就接近尾声了。在本篇博客中,笔者只是简单地展示了PyTorch程序搭建的架构,这也仅仅是笔者的个人习惯,旨在展示PyTorch程序的执行过程。各位读者朋友可以根据自己的喜好进行自定制的PyTorch代码搭建。
在学习PyTorch的过程中,推荐大家多阅读pyTorch官方文档。也可以多阅读github上的优秀开源项目,比如HRNet和mmdetection。
欢迎阅读笔者后续博客,各位读者朋友的支持与鼓励是我最大的动力!
written by jiong
不忘初心,牢记使命!