python-ignite在pytorch中的使用

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model,
metrics={
‘accuracy’: Accuracy(),
‘nll’: Loss(loss)
})

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
print(“Epoch[{}] Loss: {:.2f}”.format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
evaluator.run(train_loader)
metrics = evaluator.state.metrics
print(“Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}”
.format(trainer.state.epoch, metrics[‘accuracy’], metrics[‘nll’]))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
evaluator.run(val_loader)
metrics = evaluator.state.metrics
print(“Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}”
.format(trainer.state.epoch, metrics[‘accuracy’], metrics[‘nll’]))

trainer.run(train_loader, max_epochs=100)

下面进行分解。

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()

这里定义了网络模型,数据加载器,模型参数优化器,损失函数的定义。

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model,
metrics={
‘accuracy’: Accuracy(),
‘nll’: Loss(loss)
})

这里使用create_supervised_trainer()定义了一个Engine类的对象,名字叫 trainer 。

使用create_supervised_evaluator()定义了一个Engine类的对象,名字叫 evaluator 。

其实不使用create_supervised_trainer和create_supervised_evaluator(),也可以直接定义一个Engine类,Engine类参数为训练过程的处理函数,或者验证过程的处理函数。

典型的训练过程:

def _update(engine, batch):
model.train()
optimizer.zero_grad()
x, y = _prepare_batch(batch, device=device)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return loss.item()

典型的验证过程:

def _inference(engine, batch):
model.eval()
with torch.no_grad():
x, y = _prepare_batch(batch, device=device)
y_pred = model(x)
return y_pred, y


@trainer.on(Events.ITERATION_COMPLETED)

def log_training_loss(engine):
print(“Epoch[{}] Loss: {:.2f}”.format(engine.state.epoch, engine.state.output))

这里就是注册一个函数,在训练结束时执行,函数的功能是进行写日志。


trainer.run(train_loader, max_epochs=100)

最后,我们进行训练的过程。训练时过程,是按照时间轴的步骤进行。训练结束的输出,通过engine.state.output来获取。

写到这,我想起了tensorflow的计算图和会话模型。发现这个框架和tensorflow的计算图/会话模型 有点神似呢。

发布了63 篇原创文章 · 获赞 7 · 访问量 3396

猜你喜欢

转载自blog.csdn.net/weixin_44523062/article/details/105118139
今日推荐