(9-3-2)基于Diffusion Transformer的文生图系统:训练模型(2)最小训练脚本+ 实现DiT模型

9.5.2  最小训练脚本

文件extract_features.py实现了一个最小的训练脚本,用于通过 PyTorch 的分布式数据并行(DDP)训练 DiT 模型。它主要完成以下功能:加载图像数据集,使用预训练的 VAE 模型将输入图像编码为潜在空间并进行归一化,然后将提取的特征和标签保存为 NumPy 文件。这一训练过程支持多 GPU 训练,并且可以有效地处理大规模数据集。

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    使 EMA 模型朝当前模型更新。
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())
    
    for name, param in model_params.items():
 

猜你喜欢

转载自blog.csdn.net/asd343442/article/details/143163733