Trainer 的父类(基类):封装模型加载、设备选择、训练/验证引擎配置等通用逻辑。
from typing import Mapping, Dict, Optional
import torch
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan, EarlyStopping
from ignite.metrics import Loss, RunningAverage
from src.exception import ModelNotFoundException
from src.experiment import Number
class Trainer:
    """Base class wrapping an ignite training/evaluation pipeline.

    Subclasses must implement :meth:`set_dataset`, :meth:`create_trainer`
    and :meth:`create_evaluator`; :meth:`run` then builds the engines,
    wires the handlers and launches training.
    """

    def __init__(self, *, model: Optional[torch.nn.Module] = None, file: Optional[str] = None,
                 save: Optional[str] = None, device: Optional[str] = None):
        """
        :param model: an already-constructed model; takes precedence over ``file``
        :param file: path of a model saved with ``torch.save``, loaded when ``model`` is None
        :param save: path where the trained model will be stored (required)
        :param device: "cuda"/"cpu"; auto-detected when None
        :raises ModelNotFoundException: if neither ``model`` nor ``file`` is given
        :raises ValueError: if ``save`` is not given
        """
        if model is not None:
            self.model = model
        elif file is not None:
            # Load a previously saved model; map_location honours the requested device.
            self.model = torch.load(file, map_location=device)
        else:
            raise ModelNotFoundException("模型未定义,请传入 torch.nn.Module 对象或可加载的模型的文件路径.")
        # Resolve the device once so subclasses can rely on self.device being concrete.
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if save is None:
            raise ValueError("模型存储路径未定义!")
        self.save: str = save
        # Default evaluation metrics; extend or override via set_metrics().
        self.metrics: Dict = {
            "MSE": Loss(torch.nn.MSELoss())}
        # Engines are created lazily by the subclass hooks in run().
        self.trainer: Optional[Engine] = None
        self.evaluator: Optional[Engine] = None

    def set_dataset(self, train_batch_size, val_batch_size=1) -> None:
        """Create the ``train_set`` / ``test_set`` data loaders. Must be overridden."""
        raise NotImplementedError("请重写 set_dataset.")

    def set_metrics(self, metric: Mapping) -> None:
        """Merge user-supplied evaluation metrics (name -> ignite metric) into the defaults."""
        self.metrics.update(metric)

    @staticmethod
    def score_function(engine: Engine) -> Number:
        """Score used by EarlyStopping: higher is better, hence the negated MSE."""
        return -engine.state.metrics["MSE"]

    def early_stop(self, every: int = 1, patience: int = 10, min_delta: float = 0,
                   output_transform=lambda x: {
                       'MSE': torch.nn.MSELoss()(*x)}) -> None:
        """Stop training early when the test-set score stops improving.

        :param every: run the evaluator every ``every`` epochs
        :param patience: number of evaluations without improvement before stopping
        :param min_delta: minimum score increase that counts as an improvement
        :param output_transform: maps the evaluator output (y_pred, y) to the
            values displayed in the progress bar
        """
        evaluator_bar_format = "\033[0;32m 测试集验证:{percentage:3.0f}%|{bar}{postfix} 【已执行时间:{elapsed},剩余时间:{remaining}】\033[0m"
        bar = ProgressBar(persist=True, bar_format=evaluator_bar_format)
        bar.attach(self.evaluator, output_transform=output_transform)
        handler = EarlyStopping(patience=patience, score_function=self.score_function,
                                trainer=self.trainer, min_delta=min_delta)
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
        # Run the evaluator on the test set every `every` epochs; the
        # EarlyStopping handler attached above then fires on its COMPLETED event.
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED(every=every),
                                       lambda: self.evaluator.run(self.test_set))

    def create_trainer(self) -> None:
        """Create the trainer engine. Must be overridden."""
        raise NotImplementedError("请重写 create_trainer.")

    def create_evaluator(self) -> None:
        """Create the evaluator engine. Must be overridden."""
        raise NotImplementedError("请重写 create_evaluator.")

    def set_trainer(self):
        """Attach cross-cutting handlers: NaN guard, progress bar, checkpointing, logging."""
        # Abort training as soon as the loss becomes NaN/inf.
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        # Console progress bar; bar_format controls the rendered layout.
        trainer_bar_format = "\033[0;34m{desc}【{n_fmt:0>5s}/{total_fmt:0>5s}】 {percentage:3.0f}%|{bar}{postfix} 【已执行时间:{elapsed},剩余时间:{remaining}】\033[0m"
        bar = ProgressBar(persist=True, bar_format=trainer_bar_format)
        # Show an exponentially-smoothed running loss in the bar.
        RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(self.trainer, 'loss')
        bar.attach(self.trainer, metric_names=["loss"])
        # Persist the full model object once training finishes.
        self.trainer.add_event_handler(Events.COMPLETED, lambda: torch.save(self.model, self.save))
        self.trainer.add_event_handler(Events.COMPLETED, lambda: print("训练结束...."))
        self.trainer.add_event_handler(Events.STARTED, lambda: print("训练开始...."))

    def run(self, max_epochs, test_frequency=10) -> None:
        """Build the engines, attach all handlers and start training.

        :param max_epochs: number of epochs to train for
        :param test_frequency: evaluate on the test set every this many epochs
        :raises FileExistsError: if set_dataset() was not called first.
            NOTE(review): FileExistsError is semantically wrong here (nothing
            file-related failed) — kept only for backward compatibility.
        """
        if not hasattr(self, "train_set") or not hasattr(self, "test_set"):
            raise FileExistsError("请先通过 set_dataset 方法设置数据集.")
        self.create_trainer()
        self.create_evaluator()
        self.set_trainer()
        self.early_stop(every=test_frequency)
        self.trainer.run(self.train_set, max_epochs=max_epochs)
Trainer 的实现类:针对 ConvLSTM 模型的具体训练器。
import torch
from ignite.contrib.handlers import LRScheduler
from ignite.engine import create_supervised_trainer, Events, create_supervised_evaluator
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR
from src.data import get_data_loaders
from src.experiment.Trainer import Trainer
from src.model import ConvLSTM
from src.util import config
from src.util.patch import reshape_patch_back
# Hyper-parameters for the ConvLSTM model, loaded once from the project config.
cfg = config.load_model_parameters("ConvLSTM")
class ConvLSTMTrainer(Trainer):
    """Concrete trainer for the ConvLSTM model."""

    def __init__(self, *, model: torch.nn.Module = None, file: str = None, save: str = None, device: str = None):
        super().__init__(model=model, file=file, save=save, device=device)
        # BUG FIX: move the model to self.device (resolved by the base class)
        # instead of the raw `device` argument, which may be None — previously
        # the model was never moved when the device was auto-detected.
        self.model.to(self.device)

    def create_trainer(self) -> None:
        """Build the supervised trainer (Adam + MSE) with exponential LR decay.

        The LR-decay handler lives here rather than in set_trainer because it
        needs access to the optimizer instance created for this trainer.
        """
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        self.trainer = create_supervised_trainer(model=self.model, optimizer=optimizer, loss_fn=criterion,
                                                 device=self.device)
        # Decay the learning rate by `gamma` after every iteration.
        step_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.98)
        scheduler = LRScheduler(step_scheduler)
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

    def create_evaluator(self) -> None:
        """Build the evaluator; predictions and targets are un-patched before metrics."""
        self.evaluator = create_supervised_evaluator(
            model=self.model, metrics=self.metrics, device=self.device,
            # assumes the data loaders patch inputs with patch_size=4
            # (matches the `* 4 * 4` channel expansion at model construction) — TODO confirm
            output_transform=lambda x, y, y_pred: (
                reshape_patch_back(y_pred, patch_size=4),
                reshape_patch_back(y, patch_size=4)
            ))

    def set_dataset(self, train_batch_size, val_batch_size=1) -> None:
        """Fetch the ConvLSTM data loaders and expose them as train_set / test_set."""
        train_set, test_set = get_data_loaders("ConvLSTM", train_batch_size, val_batch_size)
        self.train_set = train_set
        self.test_set = test_set
测试代码(脚本入口):
if __name__ == '__main__':
    # Build the ConvLSTM from the loaded config; in_channels is expanded by
    # the 4x4 patching applied by the data pipeline.
    model = ConvLSTM(in_channels=cfg["in_channels"] * 4 * 4,
                     hidden_channels_list=cfg["hidden_channels_list"],
                     kernel_size_list=cfg["kernel_size_list"],
                     forget_bias=cfg["forget_bias"])
    # Train on CUDA, checkpointing to test.pth, validating after every epoch.
    convlstm_trainer = ConvLSTMTrainer(model=model, save="test.pth", device="cuda")
    convlstm_trainer.set_dataset(train_batch_size=2, val_batch_size=1)
    convlstm_trainer.run(max_epochs=3, test_frequency=1)
控制台显示(训练过程的日志输出示例):