Although PyTorch already makes implementing neural networks fairly convenient, human laziness knows no bounds. This section introduces a further layer of abstraction, the ignite package, which simplifies the PyTorch training loop even more.
1. Installation
pip install pytorch-ignite
2. Basic example
import torch
import torch.nn as nn

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# Net, get_data_loaders, train_batch_size, val_batch_size and log_interval
# are assumed to be defined elsewhere
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, criterion)

val_metrics = {
    "accuracy": Accuracy(),
    "nll": Loss(criterion)
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics)

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f}  Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f}  Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

trainer.run(train_loader, max_epochs=100)
Evidently, we first create the network model, the DataLoaders, the optimizer, and the loss function, then use the ignite factory functions create_supervised_trainer and create_supervised_evaluator to replace the usual boilerplate loops. In addition, ignite offers an aspect-oriented style of handling: you can run whatever you like before or after an epoch, an iteration, and so on.
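For instance, a minimal sketch of this hook style, reusing the trainer created above (the my_counter attribute is purely hypothetical per-epoch bookkeeping):

@trainer.on(Events.EPOCH_STARTED)
def reset_counter(trainer):
    # reset a custom counter at the start of every epoch (hypothetical bookkeeping)
    trainer.state.my_counter = 0

@trainer.on(Events.ITERATION_COMPLETED)
def count_iterations(trainer):
    trainer.state.my_counter += 1

@trainer.on(Events.EPOCH_COMPLETED)
def report_counter(trainer):
    print("Epoch {} ran {} iterations".format(trainer.state.epoch, trainer.state.my_counter))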
3. Engine
This is ignite's core class. It is an abstraction that loops a given number of times over the provided data, executing a processing function and returning the result:
while epoch < max_epochs:
    # run an epoch on data
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)
        if iter_counter == epoch_length:
            break
A model trainer, then, is just an engine that loops over the training dataset several times and updates the model parameters. For example:
from ignite.engine import Engine

# model, optimizer, loss_fn and prepare_batch are assumed to be defined
def train_step(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
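As a side note, run() returns the engine's State object, so the final statistics can be read off directly; a tiny usage sketch assuming the trainer and data above:

state = trainer.run(data, max_epochs=100)
print(state.epoch, state.iteration)  # epochs and iterations actually completed
print(state.output)                  # value returned by the last train_step call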
Example 1: Create a basic trainer
def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))

trainer.run(data_loader, max_epochs=5)
Example 2: Create a basic evaluator and compute metrics
from ignite.metrics import Accuracy
# device and non_blocking are assumed to be defined
def predict_on_batch(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)
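Once the run finishes, the attached metric is available under the name it was registered with; a quick usage sketch assuming the evaluator above:

state = evaluator.run(val_dataloader)
print("validation accuracy:", state.metrics["val_acc"])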
Example 3: Compute the image mean/std over the training dataset
from ignite.metrics import Average

def compute_mean_std(engine, batch):
    b, c, *_ = batch['image'].shape
    data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
    mean = torch.mean(data, dim=-1).sum(dim=0)
    mean2 = torch.mean(data ** 2, dim=-1).sum(dim=0)
    return {"mean": mean, "mean^2": mean2}

compute_engine = Engine(compute_mean_std)
img_mean = Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')
img_mean2 = Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')

state = compute_engine.run(train_loader)
state.metrics['std'] = torch.sqrt(state.metrics['mean2'] - state.metrics['mean'] ** 2)
mean = state.metrics['mean'].tolist()
std = state.metrics['std'].tolist()
Example 4: Resume an engine's run from a state. Users can load a state_dict and run the engine starting from the loaded state:
# Restore from an epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or from an iteration
# state_dict = {"iteration": 500, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)
Engine objects also provide the following methods:
- terminate(): sends a termination signal to the engine so that it stops the run completely after the current iteration.
- terminate_epoch(): sends a termination signal to the engine so that it ends the current epoch after the current iteration.

In addition, ignite.engine.create_supervised_trainer is a factory function for creating a trainer for supervised models:
def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
) -> Engine:
- model: the model to train
- optimizer: the optimizer to use
- loss_fn: the loss function to use
- device: device type specification (default: None); the device can be CPU, GPU or TPU
- non_blocking: if True and the copy is between CPU and GPU, the copy may occur asynchronously with respect to the host; in other cases this argument has no effect
- prepare_batch: function that receives (batch, device, non_blocking) and outputs a tuple of tensors (batch_x, batch_y)
- output_transform: function that receives x, y, y_pred, loss and returns the value to be assigned to the engine's state.output after each iteration; defaults to returning loss.item()
- deterministic: if True, the returned engine is a DeterministicEngine, otherwise a plain Engine (default: False)
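As a sketch of how these arguments fit together (model, optimizer and criterion as in the basic example; the output_transform shown is just one possible choice):

device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = create_supervised_trainer(
    model, optimizer, criterion,
    device=device,
    non_blocking=True,
    # keep y_pred around in state.output instead of only the scalar loss
    output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred),
)

With this output_transform, the engine.state.output seen by the handlers becomes a (loss, y_pred) tuple rather than a bare float.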
Similarly there is ignite.engine.create_supervised_evaluator, which takes fewer arguments than the trainer factory:

def create_supervised_evaluator(
    model: torch.nn.Module,
    metrics: Optional[Dict[str, Metric]] = None,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
) -> Engine:
- model: the trained model
- metrics: a mapping from metric names to Metrics
- device: device type specification (default: None); the device can be CPU, GPU or TPU
- output_transform: function that receives x, y, y_pred and returns the value to be assigned to the engine's state.output after each iteration; defaults to returning (y_pred, y), which fits the output the metrics expect. If you change it, you should use output_transform in the metrics as well.
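A minimal sketch along the same lines (model, criterion and device as in the earlier examples); since the default output is already (y_pred, y), the metrics need no extra transform:

evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy(), "nll": Loss(criterion)},
    device=device,
)
state = evaluator.run(val_loader)
print(state.metrics["accuracy"], state.metrics["nll"])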
Example 5: Resuming training from a checkpoint
It is possible to resume training from a checkpoint and approximately reproduce the original run's behavior. With Ignite this is easily done using the checkpoint handler. The engine provides two methods for serializing and deserializing its internal state, state_dict() and load_state_dict(). Besides serializing the model, optimizer, lr scheduler, and so on, the user can store the trainer itself and later restore the training. For example:
from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...

to_save = {
    'trainer': trainer,
    'model': model,
    'optimizer': optimizer,
    'lr_scheduler': lr_scheduler
}
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)
We can then resume training from the last checkpoint:
from ignite.handlers import Checkpoint

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...

to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
# checkpoint_file points at a file previously written by the DiskSaver above
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

trainer.run(train_loader, max_epochs=100)
4. Events & Handlers
To make the Engine more flexible, an event system is introduced that facilitates interaction at each step of the run:
- engine is started/completed
- epoch is started/completed
- batch iteration is started/completed
The full list of events can be found in ignite.engine.events. The following shows the details of what Engine's run() method executes:
fire_event(Events.STARTED)
while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)
        output = process_function(batch)
        fire_event(Events.ITERATION_COMPLETED)
    fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)
The code above shows where each event fires. There are two ways to hook into events: add_event_handler() or the on() decorator:
trainer = Engine(update_model)
trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))

# or
@trainer.on(Events.STARTED)
def on_training_started(engine):
    print("Another message of start training")

# or even simpler, use only what you need!
@trainer.on(Events.STARTED)
def on_training_started():
    print("Another message of start training")

# attach handler with args, kwargs
mydata = [1, 2, 3, 4]

def on_training_ended(data):
    print("Training is ended. mydata={}".format(data))

trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
The add_event_handler() method also allows attaching handlers dynamically:
model = ...
train_loader, validation_loader, test_loader = ...

# optimizer and loss are assumed to be defined
trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})

def log_metrics(engine, title):
    print("Epoch: {} - {} accuracy: {:.2f}"
          .format(trainer.state.epoch, title, engine.state.metrics["acc"]))

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
        evaluator.run(train_loader)
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):
        evaluator.run(validation_loader)
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):
        evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=100)
Event handlers can also be configured to be called on a user-defined pattern: every n-th event, once, or via a custom event-filtering function:
model = ...
train_loader, validation_loader, test_loader = ...

# optimizer and loss are assumed to be defined
trainer = create_supervised_trainer(model, optimizer, loss)

@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():
    print("{} / {} : {} - loss: {:.2f}"
          .format(trainer.state.epoch, trainer.state.max_epochs, trainer.state.iteration, trainer.state.output))

@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():
    # do something
    pass

def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@trainer.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
    # do something on the 1st, 2nd, 5th, 10th, 50th and 100th iterations
    pass

trainer.run(train_loader, max_epochs=100)
Events can also be user-defined:
from ignite.engine import EventEnum

class CustomEvents(EventEnum):
    """Custom events defined by the user"""
    CUSTOM_STARTED = 'custom_started'
    CUSTOM_COMPLETED = 'custom_completed'

engine.register_events(*CustomEvents)
Multiple events can be bound to a single handler at once:
events = Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3)
engine = ...

@engine.on(events)
def call_on_events(engine):
    # do something
    pass
Custom events can be used to attach any handler, and they are triggered with fire_event():
@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):
    # do something
    pass

@engine.on(Events.STARTED)
def fire_custom_events(engine):
    engine.fire_event(CustomEvents.CUSTOM_STARTED)
Handler functions do not have to take engine as an argument: if it is not needed they may take no arguments at all, or they may take several other arguments.
Event filters can also be passed to the engine:
engine = Engine(process_function)  # some process function is assumed to be defined

# a) custom event filter
def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
    # do something on the 1st, 2nd, 5th, 10th, 50th and 100th iterations
    pass

# b) "every" event filter
@engine.on(Events.ITERATION_STARTED(every=10))
def call_every(engine):
    # do something every 10th iteration
    pass

# c) "once" event filter
@engine.on(Events.ITERATION_STARTED(once=50))
def call_once(engine):
    # do something on the 50th iteration
    pass
5. Built-in handlers
The library provides a set of built-in handlers for checkpointing the training pipeline, saving the best model, stopping training when there is no improvement, logging to experiment tracking systems, and so on (see the EarlyStopping sketch after this list). They can be found in the following two modules:
- ignite.handlers
- ignite.contrib.handlers
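For instance, a minimal EarlyStopping sketch, assuming the trainer and evaluator from the earlier examples and that the evaluator computes an "accuracy" metric:

from ignite.handlers import EarlyStopping

def score_function(engine):
    # EarlyStopping treats a higher score as better, so accuracy can be used directly
    return engine.state.metrics["accuracy"]

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# attach to the evaluator so the score is checked each time validation runs
evaluator.add_event_handler(Events.COMPLETED, handler)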
Some of these classes can simply be added to an Engine as callables. For example:
from ignite.handlers import TerminateOnNan
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
An attach() method is also provided, so handlers can be attached to an Engine manually while the program runs:
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log the model's weights as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=WeightsHistHandler(model)
)
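Recent ignite versions also expose convenience attach_*() methods on the loggers; for example, a sketch (trainer as above) that logs the batch loss every 100 iterations:

tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)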
6. State
State stores the Engine's results; every Engine object has a state attribute. Some of its fields:
- engine.state.seed: Seed to set at each data "epoch".
- engine.state.epoch: Number of epochs the engine has completed. Initialized as 0; the first epoch is 1.
- engine.state.iteration: Number of iterations the engine has completed. Initialized as 0; the first iteration is 1.
- engine.state.max_epochs: Number of epochs to run for. Initialized as 1.
- engine.state.output: The output of the process_function defined for the Engine.
- etc.
The rest can be found in the technical documentation.
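A small sketch of inspecting these fields after a run, assuming the trainer and data_loader from the earlier examples:

state = trainer.run(data_loader, max_epochs=5)
print(state.epoch)      # number of epochs completed, here 5
print(state.iteration)  # total number of iterations completed
print(state.output)     # last value returned by the process function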
In the code below, engine.state.output stores the batch loss; this output is used to print the loss at every iteration:
def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

def on_iteration_completed(engine):
    iteration = engine.state.iteration
    epoch = engine.state.epoch
    loss = engine.state.output
    print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss))

trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
In the code below, engine.state.output is the tuple (loss, y_pred, y) for the processed batch. If you want to attach Accuracy to the engine, you need an output_transform to extract y_pred and y from engine.state.output:
def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), y_pred, y

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output[0]
    print('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))

accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')

trainer.run(data, max_epochs=10)
Similar to the above, but this time the output of the process_function is a dictionary containing the loss, y_pred and y for the processed batch; here is how the user can extract y_pred and y from engine.state.output with output_transform:
def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {
        'loss': loss.item(),
        'y_pred': y_pred,
        'y': y
    }

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output['loss']
    print('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))

accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')

trainer.run(data, max_epochs=10)
It is also good practice to use State to store user data created in update or handler functions. For example, to save a new_attribute on the state:
def user_handler_function(engine):
    engine.state.new_attribute = 12345
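Such an attribute then behaves like any other state field; a small sketch reading it back in a later handler, assuming engine is the Engine that user_handler_function was attached to:

@engine.on(Events.COMPLETED)
def read_user_attribute(engine):
    print("new_attribute =", engine.state.new_attribute)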
7. Metrics
The library provides a list of ready-made metrics for various machine learning tasks. Two ways of computing metrics are supported: 1) online and 2) storing the entire output history.
Metrics can be attached to an Engine:
from ignite.metrics import Accuracy
accuracy = Accuracy()
accuracy.attach(evaluator, "accuracy")
state = evaluator.run(validation_data)
print("Result:", state.metrics)
# > {"accuracy": 0.12345}
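The "online" style is also handy on the trainer itself: for example, RunningAverage maintains a smoothed value of the batch loss without storing the whole output history. A minimal sketch, assuming a trainer whose process function returns the batch loss:

from ignite.metrics import RunningAverage

RunningAverage(output_transform=lambda loss: loss).attach(trainer, "running_loss")

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_running_loss(engine):
    print("running loss:", engine.state.metrics["running_loss"])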