(10-3)基于多模态模型的文生图系统:多模态成对抗网络(GAN)模型(1)

10.4  多模态成对抗网络(GAN)模型

在本项目中,GAN模型的作用是根据输入的文本描述生成与之匹配的高质量图像,通过生成器网络生成图像,判别器网络评估图像的真实性,以及匹配网络判断生成图像与文本描述的一致性,以实现文本到图像的转换。

10.4.1  准备CLIP模型

文件perpare.py定义了一系列函数来加载和准备CLIP模型及其相关组件,设置训练和评估所需的生成器、判别器、比较器模型,以及文本和图像编码器。另外,还定义了函数来准备数据集和数据加载器,以便在训练过程中能够有效地读取和处理图像和文本数据。这些功能旨在支持在多GPU环境下的分布式训练,确保模型和数据能够高效地进行并行处理。

def load_clip(clip_info, device):
    """
    导入并加载CLIP模型
    """
    import clip as clip
    model = clip.load(clip_info['type'], device=device)[0]
    return model

def prepare_models(args):
    """
    准备所需的模型
    """
    # 从命令行参数中设置设备、GPU和本地排名用于分布式训练
    device = args.device
    local_rank = args.local_rank
    multi_gpus = args.multi_gpus

    # 创建用于训练和评估的CLIP模型
    CLIP4trn = load_clip(args.clip4trn, device).eval()
    CLIP4evl = load_clip(args.clip4evl, device).eval()

    # 创建生成器、判别器、比较器模型及文本和图像编码器
    NetG, NetD, NetC, CLIP_IMG_ENCODER, CLIP_TXT_ENCODER = choose_model(args.model)

    # 冻结CLIP图像编码器的权重并设置为评估模式
    CLIP_img_enc = CLIP_IMG_ENCODER(CLIP4trn).to(device)
    for p in CLIP_img_enc.parameters():
        p.requires_grad = False
    CLIP_img_enc.eval()

    # 冻结CLIP文本编码器的权重并设置为评估模式
    CLIP_txt_enc = CLIP_TXT_ENCODER(CLIP4trn).to(device)
    for p in CLIP_txt_enc.parameters():
        p.requires_grad = False
    CLIP_txt_enc.eval()

    # 初始化并配置CLIP-GAN模型
    netG = NetG(args.nf, args.z_dim, args.cond_dim, args.imsize, args.ch_size, args.mixed_precision, CLIP4trn).to(device)
    netD = NetD(args.nf, args.imsize, args.ch_size, args.mixed_precision).to(device)
    netC = NetC(args.nf, args.cond_dim, args.mixed_precision).to(device)

    # 如果有多个GPU且训练为True,将模型移到分布式训练环境中并包装到DistributedDataParallel()中
    if (args.multi_gpus) and (args.train):
        # 打印可用GPU的数量
        print("Let's use ", torch.cuda.device_count(), " GPUs!")

        # 使用torchrun将生成器模型包装到DistributedDataParallel()中以进行分布式和并行训练
        netG = torch.nn.parallel.DistributedDataParallel(
            netG,
            broadcast_buffers=False,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True
        )

        # 使用torchrun将判别器模型包装到DistributedDataParallel()中以进行分布式和并行训练
        netD = torch.nn.parallel.DistributedDataParallel(
            netD,
            broadcast_buffers=False,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True
        )

        # 使用torchrun将比较器模型包装到DistributedDataParallel()中以进行分布式和并行训练
        netC = torch.nn.parallel.DistributedDataParallel(
            netC,
            broadcast_buffers=False,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True
        )
    # 返回配置和初始化的模型
    return CLIP4trn, CLIP4evl, CLIP_img_enc, CLIP_txt_enc, netG, netD, netC


def prepare_dataset(args, split, transform):
    # 如果图像不是RGB,则设置输入图像大小为256,否则设置为给定值
    if args.ch_size != 3:
        imsize = 256
    else:
        imsize = args.imsize

    # 定义输入图像的变换
    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose([
            transforms.Resize(int(imsize * 76 / 64)),
            transforms.RandomCrop(imsize),
            transforms.RandomHorizontalFlip(),
        ])

    # 导入自定义数据集类并初始化
    from lib.datasets import TextImgDataset as Dataset
    dataset = Dataset(split=split, transform=image_transform, args=args)
    return dataset


def prepare_datasets(args, transform):
    """
    切分数据集并创建数据集对象
    """
    # 训练数据集
    train_dataset = prepare_dataset(args, split='train', transform=transform)
    # 测试数据集
    val_dataset = prepare_dataset(args, split='test', transform=transform)
    return train_dataset, val_dataset


def prepare_dataloaders(args, transform=None):
    """
    创建数据加载器以从文件夹中检索图像数据集
    """
    # 定义加载数据集的超参数,如批量大小和工作线程数量
    batch_size = args.batch_size
    num_workers = args.num_workers

    # 调用prepare_datasets函数将数据集分为训练和测试数据集
    train_dataset, valid_dataset = prepare_datasets(args, transform)

    # 创建训练数据加载器并将其包装用于多GPU分布式训练
    if args.multi_gpus == True:
        train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            drop_last=True,
            num_workers=num_workers,
            sampler=train_sampler
        )
    else:
        train_sampler = None
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            drop_last=True,
            num_workers=num_workers,
            shuffle='True'
        )

    # 创建验证数据加载器并将其包装用于多GPU分布式训练
    if args.multi_gpus == True:
        valid_sampler = DistributedSampler(valid_dataset)
        valid_dataloader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=batch_size,
            drop_last=True,
            num_workers=num_workers,
            sampler=valid_sampler
        )
    else:
        valid_dataloader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=batch_size,
            drop_last=True,
            num_workers=num_workers,
            shuffle='True'
        )

    return train_dataloader, valid_dataloader, train_dataset, valid_dataset, train_sampler

10.4.2  训练、评估和保存GAN模型

文件modules.py定义了一些函数,用于训练、评估和保存生成对抗网络(GAN)的模型,以及生成图像样本和计算评价指标。

(1)定义函数train,用于训练生成对抗网络(GAN)的函数。在训练过程中,通过优化器更新生成器(NetG)和鉴别器(NetD),同时使用匹配器(NetC)来评估图像与文本描述的一致性。

def train(dataloader, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args):
    """
    使用给定的数据集通过dataloader训练GAN网络。
    """
    # 从命令行参数中获取超参数
    batch_size = args.batch_size
    device = args.device
    epoch = args.current_epoch
    max_epoch = args.max_epoch
    z_dim = args.z_dim
    
    # 设置模型为训练模式
    netG, netD, netC, image_encoder = netG.train(), netD.train(), netC.train(), image_encoder.train()
    
    # 如果启用了多GPU训练,则由主GPU显示进度条
    if (args.multi_gpus == True) and (get_rank() != 0):
        None
    else:
        loop = tqdm(total=len(dataloader))
    
    for step, data in enumerate(dataloader, 0):
        ##############################
        # 训练鉴别器(NetD)          #
        ##############################
        optimizerD.zero_grad()

        # 启用混合精度上下文
        with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
            # 准备数据以进行前向传播处理
            real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device)
            real = real.requires_grad_()
            sent_emb = sent_emb.requires_grad_()
            words_embs = words_embs.requires_grad_()
            
            # 预测真实图像的嵌入和特征向量
            CLIP_real, real_emb = image_encoder(real)
            real_feats = netD(CLIP_real)

            # 预测真实图像的相似度分数
            pred_real, errD_real = predict_loss(netC, real_feats, sent_emb, negtive=False)
            
            # 预测与文本提示不匹配
            mis_sent_emb = torch.cat((sent_emb[1:], sent_emb[0:1]), dim=0).detach()
            _, errD_mis = predict_loss(netC, real_feats, mis_sent_emb, negtive=True)
            
            # 合成假图像和特征嵌入
            noise = torch.randn(batch_size, z_dim).to(device)
            fake = netG(noise, sent_emb)
            CLIP_fake, fake_emb = image_encoder(fake)
            fake_feats = netD(CLIP_fake.detach())

            # 预测假图像的相似度分数
            _, errD_fake = predict_loss(netC, fake_feats, sent_emb, negtive=True)

        # 计算Manifold-Aware Gradient Penalty(MAGP)以对GAN中的文本语义进行正则化
        if args.mixed_precision:
            errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, scaler_D)
        else:
            errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real)

        # 计算鉴别器损失
        with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
            errD = errD_real + (errD_fake + errD_mis) / 2.0 + errD_MAGP

        # 更新鉴别器网络权重
        if args.mixed_precision:
            scaler_D.scale(errD).backward()
            scaler_D.step(optimizerD)
            scaler_D.update()
            if scaler_D.get_scale() < args.scaler_min:
                scaler_D.update(16384.0)
        else:
            errD.backward()
            optimizerD.step()

        ##############################
        # 训练生成器(NetG)          #
        ##############################
        optimizerG.zero_grad()
        with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
            # 创建假图像的特征张量或嵌入
            fake_feats = netD(CLIP_fake)

            # 预测假图像特征与句子嵌入之间的相似度
            output = netC(fake_feats, sent_emb)

            # 计算假嵌入与句子嵌入之间的文本相似度
            text_img_sim = torch.cosine_similarity(fake_emb, sent_emb).mean()

            # 计算生成器网络的损失
            errG = -output.mean() - args.sim_w * text_img_sim

        # 更新生成器网络权重
        if args.mixed_precision:
            scaler_G.scale(errG).backward()
            scaler_G.step(optimizerG)
            scaler_G.update()
            if scaler_G.get_scale() < args.scaler_min:
                scaler_G.update(16384.0)
        else:
            errG.backward()
            optimizerG.step()

        # 显示训练信息
        if (args.multi_gpus == True) and (get_rank() != 0):
            None
        else:
            loop.update(1)
            loop.set_description(f'Train Epoch [{epoch}/{max_epoch}]')
            loop.set_postfix()

    if (args.multi_gpus == True) and (get_rank() != 0):
        None
    else:
        loop.close()

(2)定义函数test,用于评估生成对抗网络(GAN)在测试数据集上的性能,通过计算生成图像的FID(Fréchet Inception Distance)和文本-图像相似度(TI_sim)。

def test(dataloader, text_encoder, netG, PTM, device, m1, s1, epoch, max_epoch, times, z_dim, batch_size):
    FID, TI_sim = calculate_FID_CLIP_sim(
        dataloader, text_encoder, netG, PTM, device, m1, s1,
        epoch, max_epoch, times, z_dim, batch_size
    )
    return FID, TI_sim

(3)函数save_model的功能是将生成器(netG)、判别器(netD)、匹配器(netC)以及它们各自的优化器(optG、optD)的状态保存到持久化存储中,并附带额外的信息,如当前轮数和训练步数。

def save_model(netG, netD, netC, optG, optD, epoch, multi_gpus, step, save_path):
    """
    保存模型到持久化存储,并附带额外信息

    Args:
        netG: 生成器模型
        netD: 判别器模型
        netC: 匹配器模型
        optG: 生成器优化器
        optD: 判别器优化器
        epoch: 当前的训练轮数
        multi_gpus: 是否使用多GPU训练
        step: 当前训练步数
        save_path: 模型保存路径

    Returns:
        None
    """
    # 检查是否是分布式训练
    if (multi_gpus == True) and (get_rank() != 0):
        None
    else:
        # 保存生成器、判别器、匹配器模型
        state = {
            'model': {
                'netG': netG.state_dict(),
                'netD': netD.state_dict(),
                'netC': netC.state_dict()
            },
            # 保存优化器状态
            'optimizers': {
                'optimizer_G': optG.state_dict(),
                'optimizer_D': optD.state_dict()
            },
            # 保存轮数信息
            'epoch': epoch
        }
        # 保存模型检查点
        torch.save(state, '%s/state_epoch_%03d_%03d.pth' % (save_path, epoch, step))

(4)函数MA_GP_MP的功能是计算生成对抗网络(GAN)中的MA-GP(Manifold-Aware Gradient Penalty)惩罚项,用于对抗器(判别器)的训练。给定真实图像、句子嵌入和判别器的相似度分数,函数计算对输入句子关于图像的梯度,并返回MA-GP惩罚项。

def MA_GP_MP(img, sent, out, scaler):
    """
    Args:
        img: 真实图像
        sent: 句子嵌入
        out: 判别器的相似度分数
        scaler: 混合精度缩放器对象

    Returns:
        d_loss_gp: MA-GP惩罚项
    """
    # 计算关于图像和句子嵌入的得分梯度
    grads = torch.autograd.grad(
        outputs=scaler.scale(out),
        inputs=(img, sent),
        grad_outputs=torch.ones_like(out),
        retain_graph=True,
        create_graph=True,
        only_inputs=True
    )
    inv_scale = 1. / (scaler.get_scale() + float("1e-8"))
    # 重新缩放梯度以抵消缩放因子
    grads = [grad * inv_scale for grad in grads]
    with torch.cuda.amp.autocast():
        # 将梯度连接成一个张量
        grad0 = grads[0].view(grads[0].size(0), -1)
        grad1 = grads[1].view(grads[1].size(0), -1)
        grad = torch.cat((grad0, grad1), dim=1)

        # 计算MA-GP惩罚项
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6)
    return d_loss_gp

(5)函数MA_GP_FP32的功能是在单精度(FP32)下计算生成对抗网络(GAN)中的MA-GP(Manifold-Aware Gradient Penalty)损失。提供了真实图像、句子嵌入和判别器的相似度分数,函数MA_GP_FP32计算对输入句子关于图像的梯度,并返回MA-GP惩罚项。

def MA_GP_FP32(img, sent, out):
    """
    Args:
        img: 真实图像
        sent: 句子嵌入
        out: 判别器的相似度分数
    Returns:
        d_loss_gp: MA-GP惩罚项
    """
    # 计算关于图像和句子嵌入的得分梯度
    grads = torch.autograd.grad(
        outputs=out,
        inputs=(img, sent),
        grad_outputs=torch.ones(out.size()).cuda(),
        retain_graph=True,
        create_graph=True,
        only_inputs=True
    )
    # 将梯度连接成一个张量
    grad0 = grads[0].view(grads[0].size(0), -1)
    grad1 = grads[1].view(grads[1].size(0), -1)
    grad = torch.cat((grad0, grad1), dim=1)
    # 计算MA-GP惩罚项
    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6)
    return d_loss_gp

(6)函数sample的功能是在训练后生成样本并保存,从数据加载器中获取数据,通过生成器生成假图像,并将生成的图像和相关文本保存到指定目录下的文件夹中。

def sample(dataloader, netG, text_encoder, save_dir, device, multi_gpus, z_dim, stamp):
    """
    生成样本并保存
    Args:
        dataloader: 数据加载器
        netG: 生成器模型
        text_encoder: 文本编码器
        save_dir: 保存目录
        device: 设备(GPU或CPU)
        multi_gpus: 是否使用多个GPU进行训练
        z_dim: 噪声向量的维度
        stamp: 时间戳或标识符
    """
    # 将生成器设置为评估模式
    netG.eval()
    for step, data in enumerate(dataloader, 0):
        # 准备真实数据
        real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device)
        # 生成假图像
        batch_size = sent_emb.size(0)
        with torch.no_grad():
            noise = torch.randn(batch_size, z_dim).to(device)
            fake_imgs = netG(noise, sent_emb, eval=True).float()
            # 将像素值限制在[-1, 1]范围内
            fake_imgs = torch.clamp(fake_imgs, -1., 1.)
            # 保存生成的图像,附加GPU id
            if multi_gpus == True:
                batch_img_name = 'step_%04d.png' % (step)
                batch_img_save_dir = osp.join(save_dir, 'batch', str('gpu%d' % (get_rank())), 'imgs')
                batch_img_save_name = osp.join(batch_img_save_dir, batch_img_name)
                batch_txt_name = 'step_%04d.txt' % (step)
                batch_txt_save_dir = osp.join(save_dir, 'batch', str('gpu%d' % (get_rank())), 'txts')
                batch_txt_save_name = osp.join(batch_txt_save_dir, batch_txt_name)
            else:
                batch_img_name = 'step_%04d.png' % (step)
                batch_img_save_dir = osp.join(save_dir, 'batch', 'imgs')
                batch_img_save_name = osp.join(batch_img_save_dir, batch_img_name)
                batch_txt_name = 'step_%04d.txt' % (step)
                batch_txt_save_dir = osp.join(save_dir, 'batch', 'txts')
                batch_txt_save_name = osp.join(batch_txt_save_dir, batch_txt_name)
            # 创建目录
            mkdir_p(batch_img_save_dir)
            vutils.save_image(fake_imgs.data, batch_img_save_name, nrow=8, value_range=(-1, 1), normalize=True)
            mkdir_p(batch_txt_save_dir)
            # 保存图像和文本
            txt = open(batch_txt_save_name, 'w')
            for cap in captions:
                txt.write(cap + '\n')
            txt.close()
            for j in range(batch_size):
                im = fake_imgs[j].data.cpu().numpy()
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                # 保存假图像
                if multi_gpus == True:
                    single_img_name = 'batch_%04d.png' % (j)
                    single_img_save_dir = osp.join(save_dir, 'single', str('gpu%d' % (get_rank())), 'step%04d' % (step))
                    single_img_save_name = osp.join(single_img_save_dir, single_img_name)
                else:
                    single_img_name = 'step_%04d.png' % (step)
                    single_img_save_dir = osp.join(save_dir, 'single', 'step%04d' % (step))
                    single_img_save_name = osp.join(single_img_save_dir, single_img_name)
                mkdir_p(single_img_save_dir)
                im.save(single_img_save_name)

        # 打印进度
        if (multi_gpus == True) and (get_rank() != 0):
            None
        else:
            print('Step: %d' % (step))

未完待续

猜你喜欢

转载自blog.csdn.net/asd343442/article/details/143402857