9.5.2 最小训练脚本
文件extract_features.py实现了一个最小的训练脚本,用于通过 PyTorch 的分布式数据并行(DDP)训练 DiT 模型。它主要完成以下功能:加载图像数据集,使用预训练的 VAE 模型将输入图像编码为潜在空间并进行归一化,然后将提取的特征和标签保存为 NumPy 文件。这一训练过程支持多 GPU 训练,并且可以有效地处理大规模数据集。
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
"""
使 EMA 模型朝当前模型更新。
"""
ema_params = OrderedDict(ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
for name, param in model_params.items():