inferno简介
Inferno是一个小库,提供了围绕PyTorch的实用程序和方便的函数/类。
其主要功能包括:
- 一个基本的训练类,用来封装训练过程(迭代/epoch循环、验证和checkpoint创建)
- 由networkx提供的用于构建复杂架构模型的图形API
- 数据并行在多个gpu上更容易
- Pytorch神经网络的子模块,模块级参数初始化
- 数据预处理/变换的子模块
- 支持Tensorboard
- 一个回调API来支持与调参工程师的灵活交互
- 未来将拥有更多功能
安装
Conda packages for python >= 3.6 for all distributions are availaible on conda-forge:
$ conda install -c pytorch -c conda-forge inferno
一个简单的示例
安装成功之后,可以直接运行下面的代码(来自官网 )。
import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers.convolutional import ConvELU2D
from inferno.extensions.layers.reshape import Flatten
# Fill these in:
LOG_DIRECTORY = '...'
SAVE_DIRECTORY = '...'
DATASET_DIRECTORY = '...'
DOWNLOAD_CIFAR = True
USE_CUDA = True
# Build torch model
model = nn.Sequential(
ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
nn.Linear(in_features=(256 * 4 * 4), out_features=10),
nn.LogSoftmax(dim=1)
)
# Load loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
download=DOWNLOAD_CIFAR)
# Build trainer
trainer = Trainer(model) \
.build_criterion('NLLLoss') \
.build_metric('CategoricalError') \
.build_optimizer('Adam') \
.validate_every((2, 'epochs')) \
.save_every((5, 'epochs')) \
.save_to_directory(SAVE_DIRECTORY) \
.set_max_num_epochs(10) \
.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
log_images_every='never'),
log_directory=LOG_DIRECTORY)
# Bind loaders
trainer \
.bind_loader('train', train_loader) \
.bind_loader('validate', validate_loader)
if USE_CUDA:
trainer.cuda()
# Go!
trainer.fit()
可视化方法
tensorboard --logdir="./" --port=6007
如果一切正常的话,会显示以下结果
以下是我目前(2020.11)总结的inferno(0.1.29)的各个模块的功能:
部分模块(上图画框部分)的源码介绍及使用见我的其他文章:
inferno Pytorch: inferno.io.transform 介绍及使用
inferno Pytorch: inferno.io.box.cifar下载cifar10 cifar100数据集 介绍及使用
inferno Pytorch: inferno.extensions.layers.convolutional 介绍及使用