import torch

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
# Build the pieces of a supervised training run: model, data, optimizer, loss.
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
# NLLLoss consumes log-probabilities — assumes Net ends in log_softmax (TODO confirm).
loss = torch.nn.NLLLoss()

# trainer drives the optimization loop; evaluator only runs forward passes
# and accumulates the attached metrics.
trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(
    model,
    metrics={
        "accuracy": Accuracy(),  # classification accuracy
        "nll": Loss(loss),       # average negative log-likelihood
    },
)
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    """Print the current loss after every training iteration.

    trainer.state.output holds whatever the update step returned —
    for create_supervised_trainer that is the scalar batch loss.
    """
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """At the end of each epoch, evaluate on the TRAINING set and log metrics."""
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    """At the end of each epoch, evaluate on the VALIDATION set and log metrics."""
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
# Start training: iterate over train_loader for up to 100 epochs,
# firing the handlers registered above at each event.
trainer.run(train_loader, max_epochs=100)
下面进行分解。
# The building blocks of the supervised setup, shown again for discussion:
# network, data loaders, SGD optimizer, and negative-log-likelihood loss.
model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()
这里定义了网络模型,数据加载器,模型参数优化器,损失函数的定义。
# Wrap the training step in one Engine (trainer) and the evaluation
# step in a second Engine (evaluator) that tracks the given metrics.
trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(
    model,
    metrics={
        "accuracy": Accuracy(),
        "nll": Loss(loss),
    },
)
这里使用create_supervised_trainer()
定义了一个Engine类的对象,名字叫 trainer 。
使用create_supervised_evaluator()定义了一个Engine类的对象,名字叫 evaluator 。
其实不使用create_supervised_trainer和create_supervised_evaluator(),也可以直接定义一个Engine类,Engine类参数为训练过程的处理函数,或者验证过程的处理函数。
典型的训练过程:
def _update(engine, batch):
    """One supervised training step: forward, backward, optimizer update.

    Returns the scalar batch loss; Ignite stores it in engine.state.output.
    Relies on module-level model, optimizer, loss_fn, _prepare_batch, device.
    """
    model.train()
    optimizer.zero_grad()
    x, y = _prepare_batch(batch, device=device)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()
典型的验证过程:
def _inference(engine, batch):
    """One evaluation step: forward pass only, with gradients disabled.

    Returns (y_pred, y) so attached metrics can consume the pair.
    """
    model.eval()
    with torch.no_grad():
        x, y = _prepare_batch(batch, device=device)
        y_pred = model(x)
        return y_pred, y
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Handler fired after every iteration: print epoch and current batch loss."""
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, engine.state.output))
这里就是注册一个函数,在每次迭代结束时(Events.ITERATION_COMPLETED)执行,函数的功能是写日志。
# Kick off the event loop: up to 100 epochs over train_loader.
trainer.run(train_loader, max_epochs=100)
最后,我们启动训练过程。训练按照时间轴上的事件步骤依次进行。每次迭代的输出(即该批次的损失值)通过 engine.state.output 来获取。
写到这,我想起了tensorflow的计算图和会话模型。发现这个框架和tensorflow的计算图/会话模型 有点神似呢。