9.5.4 DiT模型的标准训练
文件train_amp.py是一个用于训练 DiT(Diffusion Transformer)模型的最小训练脚本,使用 PyTorch 和加速库来支持分布式训练,首先设置了日志记录和实验目录,然后创建和初始化 DiT 模型、优化器和数据加载器。脚本通过循环进行多个训练轮次,计算损失并更新模型参数,同时使用指数移动平均(EMA)来改进模型稳定性。定期记录训练损失并保存模型检查点。训练完成后,模型会进入评估模式,准备进行进一步的推断或评估。
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
"""
将EMA模型更新为当前模型。
"""
ema_params = OrderedDict(ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
for name, param in model_params.items():
name = name.replace("module.", "