10.4 多模态成对抗网络(GAN)模型
在本项目中,GAN模型的作用是根据输入的文本描述生成与之匹配的高质量图像,通过生成器网络生成图像,判别器网络评估图像的真实性,以及匹配网络判断生成图像与文本描述的一致性,以实现文本到图像的转换。
10.4.1 准备CLIP模型
文件perpare.py定义了一系列函数来加载和准备CLIP模型及其相关组件,设置训练和评估所需的生成器、判别器、比较器模型,以及文本和图像编码器。另外,还定义了函数来准备数据集和数据加载器,以便在训练过程中能够有效地读取和处理图像和文本数据。这些功能旨在支持在多GPU环境下的分布式训练,确保模型和数据能够高效地进行并行处理。
def load_clip(clip_info, device):
"""
导入并加载CLIP模型
"""
import clip as clip
model = clip.load(clip_info['type'], device=device)[0]
return model
def prepare_models(args):
"""
准备所需的模型
"""
# 从命令行参数中设置设备、GPU和本地排名用于分布式训练
device = args.device
local_rank = args.local_rank
multi_gpus = args.multi_gpus
# 创建用于训练和评估的CLIP模型
CLIP4trn = load_clip(args.clip4trn, device).eval()
CLIP4evl = load_clip(args.clip4evl, device).eval()
# 创建生成器、判别器、比较器模型及文本和图像编码器
NetG, NetD, NetC, CLIP_IMG_ENCODER, CLIP_TXT_ENCODER = choose_model(args.model)
# 冻结CLIP图像编码器的权重并设置为评估模式
CLIP_img_enc = CLIP_IMG_ENCODER(CLIP4trn).to(device)
for p in CLIP_img_enc.parameters():
p.requires_grad = False
CLIP_img_enc.eval()
# 冻结CLIP文本编码器的权重并设置为评估模式
CLIP_txt_enc = CLIP_TXT_ENCODER(CLIP4trn).to(device)
for p in CLIP_txt_enc.parameters():
p.requires_grad = False
CLIP_txt_enc.eval()
# 初始化并配置CLIP-GAN模型
netG = NetG(args.nf, args.z_dim, args.cond_dim, args.imsize, args.ch_size, args.mixed_precision, CLIP4trn).to(device)
netD = NetD(args.nf, args.imsize, args.ch_size, args.mixed_precision).to(device)
netC = NetC(args.nf, args.cond_dim, args.mixed_precision).to(device)
# 如果有多个GPU且训练为True,将模型移到分布式训练环境中并包装到DistributedDataParallel()中
if (args.multi_gpus) and (args.train):
# 打印可用GPU的数量
print("Let's use ", torch.cuda.device_count(), " GPUs!")
# 使用torchrun将生成器模型包装到DistributedDataParallel()中以进行分布式和并行训练
netG = torch.nn.parallel.DistributedDataParallel(
netG,
broadcast_buffers=False,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=True
)
# 使用torchrun将判别器模型包装到DistributedDataParallel()中以进行分布式和并行训练
netD = torch.nn.parallel.DistributedDataParallel(
netD,
broadcast_buffers=False,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=True
)
# 使用torchrun将比较器模型包装到DistributedDataParallel()中以进行分布式和并行训练
netC = torch.nn.parallel.DistributedDataParallel(
netC,
broadcast_buffers=False,
device_ids=[local_rank],
output_device=local_rank,
find_unused_parameters=True
)
# 返回配置和初始化的模型
return CLIP4trn, CLIP4evl, CLIP_img_enc, CLIP_txt_enc, netG, netD, netC
def prepare_dataset(args, split, transform):
# 如果图像不是RGB,则设置输入图像大小为256,否则设置为给定值
if args.ch_size != 3:
imsize = 256
else:
imsize = args.imsize
# 定义输入图像的变换
if transform is not None:
image_transform = transform
else:
image_transform = transforms.Compose([
transforms.Resize(int(imsize * 76 / 64)),
transforms.RandomCrop(imsize),
transforms.RandomHorizontalFlip(),
])
# 导入自定义数据集类并初始化
from lib.datasets import TextImgDataset as Dataset
dataset = Dataset(split=split, transform=image_transform, args=args)
return dataset
def prepare_datasets(args, transform):
"""
切分数据集并创建数据集对象
"""
# 训练数据集
train_dataset = prepare_dataset(args, split='train', transform=transform)
# 测试数据集
val_dataset = prepare_dataset(args, split='test', transform=transform)
return train_dataset, val_dataset
def prepare_dataloaders(args, transform=None):
"""
创建数据加载器以从文件夹中检索图像数据集
"""
# 定义加载数据集的超参数,如批量大小和工作线程数量
batch_size = args.batch_size
num_workers = args.num_workers
# 调用prepare_datasets函数将数据集分为训练和测试数据集
train_dataset, valid_dataset = prepare_datasets(args, transform)
# 创建训练数据加载器并将其包装用于多GPU分布式训练
if args.multi_gpus == True:
train_sampler = DistributedSampler(train_dataset)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
drop_last=True,
num_workers=num_workers,
sampler=train_sampler
)
else:
train_sampler = None
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
drop_last=True,
num_workers=num_workers,
shuffle='True'
)
# 创建验证数据加载器并将其包装用于多GPU分布式训练
if args.multi_gpus == True:
valid_sampler = DistributedSampler(valid_dataset)
valid_dataloader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=batch_size,
drop_last=True,
num_workers=num_workers,
sampler=valid_sampler
)
else:
valid_dataloader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=batch_size,
drop_last=True,
num_workers=num_workers,
shuffle='True'
)
return train_dataloader, valid_dataloader, train_dataset, valid_dataset, train_sampler
10.4.2 训练、评估和保存GAN模型
文件modules.py定义了一些函数,用于训练、评估和保存生成对抗网络(GAN)的模型,以及生成图像样本和计算评价指标。
(1)定义函数train,用于训练生成对抗网络(GAN)的函数。在训练过程中,通过优化器更新生成器(NetG)和鉴别器(NetD),同时使用匹配器(NetC)来评估图像与文本描述的一致性。
def train(dataloader, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args):
"""
使用给定的数据集通过dataloader训练GAN网络。
"""
# 从命令行参数中获取超参数
batch_size = args.batch_size
device = args.device
epoch = args.current_epoch
max_epoch = args.max_epoch
z_dim = args.z_dim
# 设置模型为训练模式
netG, netD, netC, image_encoder = netG.train(), netD.train(), netC.train(), image_encoder.train()
# 如果启用了多GPU训练,则由主GPU显示进度条
if (args.multi_gpus == True) and (get_rank() != 0):
None
else:
loop = tqdm(total=len(dataloader))
for step, data in enumerate(dataloader, 0):
##############################
# 训练鉴别器(NetD) #
##############################
optimizerD.zero_grad()
# 启用混合精度上下文
with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
# 准备数据以进行前向传播处理
real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device)
real = real.requires_grad_()
sent_emb = sent_emb.requires_grad_()
words_embs = words_embs.requires_grad_()
# 预测真实图像的嵌入和特征向量
CLIP_real, real_emb = image_encoder(real)
real_feats = netD(CLIP_real)
# 预测真实图像的相似度分数
pred_real, errD_real = predict_loss(netC, real_feats, sent_emb, negtive=False)
# 预测与文本提示不匹配
mis_sent_emb = torch.cat((sent_emb[1:], sent_emb[0:1]), dim=0).detach()
_, errD_mis = predict_loss(netC, real_feats, mis_sent_emb, negtive=True)
# 合成假图像和特征嵌入
noise = torch.randn(batch_size, z_dim).to(device)
fake = netG(noise, sent_emb)
CLIP_fake, fake_emb = image_encoder(fake)
fake_feats = netD(CLIP_fake.detach())
# 预测假图像的相似度分数
_, errD_fake = predict_loss(netC, fake_feats, sent_emb, negtive=True)
# 计算Manifold-Aware Gradient Penalty(MAGP)以对GAN中的文本语义进行正则化
if args.mixed_precision:
errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, scaler_D)
else:
errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real)
# 计算鉴别器损失
with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
errD = errD_real + (errD_fake + errD_mis) / 2.0 + errD_MAGP
# 更新鉴别器网络权重
if args.mixed_precision:
scaler_D.scale(errD).backward()
scaler_D.step(optimizerD)
scaler_D.update()
if scaler_D.get_scale() < args.scaler_min:
scaler_D.update(16384.0)
else:
errD.backward()
optimizerD.step()
##############################
# 训练生成器(NetG) #
##############################
optimizerG.zero_grad()
with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
# 创建假图像的特征张量或嵌入
fake_feats = netD(CLIP_fake)
# 预测假图像特征与句子嵌入之间的相似度
output = netC(fake_feats, sent_emb)
# 计算假嵌入与句子嵌入之间的文本相似度
text_img_sim = torch.cosine_similarity(fake_emb, sent_emb).mean()
# 计算生成器网络的损失
errG = -output.mean() - args.sim_w * text_img_sim
# 更新生成器网络权重
if args.mixed_precision:
scaler_G.scale(errG).backward()
scaler_G.step(optimizerG)
scaler_G.update()
if scaler_G.get_scale() < args.scaler_min:
scaler_G.update(16384.0)
else:
errG.backward()
optimizerG.step()
# 显示训练信息
if (args.multi_gpus == True) and (get_rank() != 0):
None
else:
loop.update(1)
loop.set_description(f'Train Epoch [{epoch}/{max_epoch}]')
loop.set_postfix()
if (args.multi_gpus == True) and (get_rank() != 0):
None
else:
loop.close()
(2)定义函数test,用于评估生成对抗网络(GAN)在测试数据集上的性能,通过计算生成图像的FID(Fréchet Inception Distance)和文本-图像相似度(TI_sim)。
def test(dataloader, text_encoder, netG, PTM, device, m1, s1, epoch, max_epoch, times, z_dim, batch_size):
FID, TI_sim = calculate_FID_CLIP_sim(
dataloader, text_encoder, netG, PTM, device, m1, s1,
epoch, max_epoch, times, z_dim, batch_size
)
return FID, TI_sim
(3)函数save_model的功能是将生成器(netG)、判别器(netD)、匹配器(netC)以及它们各自的优化器(optG、optD)的状态保存到持久化存储中,并附带额外的信息,如当前轮数和训练步数。
def save_model(netG, netD, netC, optG, optD, epoch, multi_gpus, step, save_path):
"""
保存模型到持久化存储,并附带额外信息
Args:
netG: 生成器模型
netD: 判别器模型
netC: 匹配器模型
optG: 生成器优化器
optD: 判别器优化器
epoch: 当前的训练轮数
multi_gpus: 是否使用多GPU训练
step: 当前训练步数
save_path: 模型保存路径
Returns:
None
"""
# 检查是否是分布式训练
if (multi_gpus == True) and (get_rank() != 0):
None
else:
# 保存生成器、判别器、匹配器模型
state = {
'model': {
'netG': netG.state_dict(),
'netD': netD.state_dict(),
'netC': netC.state_dict()
},
# 保存优化器状态
'optimizers': {
'optimizer_G': optG.state_dict(),
'optimizer_D': optD.state_dict()
},
# 保存轮数信息
'epoch': epoch
}
# 保存模型检查点
torch.save(state, '%s/state_epoch_%03d_%03d.pth' % (save_path, epoch, step))
(4)函数MA_GP_MP的功能是计算生成对抗网络(GAN)中的MA-GP(Manifold-Aware Gradient Penalty)惩罚项,用于对抗器(判别器)的训练。给定真实图像、句子嵌入和判别器的相似度分数,函数计算对输入句子关于图像的梯度,并返回MA-GP惩罚项。
def MA_GP_MP(img, sent, out, scaler):
"""
Args:
img: 真实图像
sent: 句子嵌入
out: 判别器的相似度分数
scaler: 混合精度缩放器对象
Returns:
d_loss_gp: MA-GP惩罚项
"""
# 计算关于图像和句子嵌入的得分梯度
grads = torch.autograd.grad(
outputs=scaler.scale(out),
inputs=(img, sent),
grad_outputs=torch.ones_like(out),
retain_graph=True,
create_graph=True,
only_inputs=True
)
inv_scale = 1. / (scaler.get_scale() + float("1e-8"))
# 重新缩放梯度以抵消缩放因子
grads = [grad * inv_scale for grad in grads]
with torch.cuda.amp.autocast():
# 将梯度连接成一个张量
grad0 = grads[0].view(grads[0].size(0), -1)
grad1 = grads[1].view(grads[1].size(0), -1)
grad = torch.cat((grad0, grad1), dim=1)
# 计算MA-GP惩罚项
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6)
return d_loss_gp
(5)函数MA_GP_FP32的功能是在单精度(FP32)下计算生成对抗网络(GAN)中的MA-GP(Manifold-Aware Gradient Penalty)损失。提供了真实图像、句子嵌入和判别器的相似度分数,函数MA_GP_FP32计算对输入句子关于图像的梯度,并返回MA-GP惩罚项。
def MA_GP_FP32(img, sent, out):
"""
Args:
img: 真实图像
sent: 句子嵌入
out: 判别器的相似度分数
Returns:
d_loss_gp: MA-GP惩罚项
"""
# 计算关于图像和句子嵌入的得分梯度
grads = torch.autograd.grad(
outputs=out,
inputs=(img, sent),
grad_outputs=torch.ones(out.size()).cuda(),
retain_graph=True,
create_graph=True,
only_inputs=True
)
# 将梯度连接成一个张量
grad0 = grads[0].view(grads[0].size(0), -1)
grad1 = grads[1].view(grads[1].size(0), -1)
grad = torch.cat((grad0, grad1), dim=1)
# 计算MA-GP惩罚项
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6)
return d_loss_gp
(6)函数sample的功能是在训练后生成样本并保存,从数据加载器中获取数据,通过生成器生成假图像,并将生成的图像和相关文本保存到指定目录下的文件夹中。
def sample(dataloader, netG, text_encoder, save_dir, device, multi_gpus, z_dim, stamp):
"""
生成样本并保存
Args:
dataloader: 数据加载器
netG: 生成器模型
text_encoder: 文本编码器
save_dir: 保存目录
device: 设备(GPU或CPU)
multi_gpus: 是否使用多个GPU进行训练
z_dim: 噪声向量的维度
stamp: 时间戳或标识符
"""
# 将生成器设置为评估模式
netG.eval()
for step, data in enumerate(dataloader, 0):
# 准备真实数据
real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device)
# 生成假图像
batch_size = sent_emb.size(0)
with torch.no_grad():
noise = torch.randn(batch_size, z_dim).to(device)
fake_imgs = netG(noise, sent_emb, eval=True).float()
# 将像素值限制在[-1, 1]范围内
fake_imgs = torch.clamp(fake_imgs, -1., 1.)
# 保存生成的图像,附加GPU id
if multi_gpus == True:
batch_img_name = 'step_%04d.png' % (step)
batch_img_save_dir = osp.join(save_dir, 'batch', str('gpu%d' % (get_rank())), 'imgs')
batch_img_save_name = osp.join(batch_img_save_dir, batch_img_name)
batch_txt_name = 'step_%04d.txt' % (step)
batch_txt_save_dir = osp.join(save_dir, 'batch', str('gpu%d' % (get_rank())), 'txts')
batch_txt_save_name = osp.join(batch_txt_save_dir, batch_txt_name)
else:
batch_img_name = 'step_%04d.png' % (step)
batch_img_save_dir = osp.join(save_dir, 'batch', 'imgs')
batch_img_save_name = osp.join(batch_img_save_dir, batch_img_name)
batch_txt_name = 'step_%04d.txt' % (step)
batch_txt_save_dir = osp.join(save_dir, 'batch', 'txts')
batch_txt_save_name = osp.join(batch_txt_save_dir, batch_txt_name)
# 创建目录
mkdir_p(batch_img_save_dir)
vutils.save_image(fake_imgs.data, batch_img_save_name, nrow=8, value_range=(-1, 1), normalize=True)
mkdir_p(batch_txt_save_dir)
# 保存图像和文本
txt = open(batch_txt_save_name, 'w')
for cap in captions:
txt.write(cap + '\n')
txt.close()
for j in range(batch_size):
im = fake_imgs[j].data.cpu().numpy()
im = (im + 1.0) * 127.5
im = im.astype(np.uint8)
im = np.transpose(im, (1, 2, 0))
im = Image.fromarray(im)
# 保存假图像
if multi_gpus == True:
single_img_name = 'batch_%04d.png' % (j)
single_img_save_dir = osp.join(save_dir, 'single', str('gpu%d' % (get_rank())), 'step%04d' % (step))
single_img_save_name = osp.join(single_img_save_dir, single_img_name)
else:
single_img_name = 'step_%04d.png' % (step)
single_img_save_dir = osp.join(save_dir, 'single', 'step%04d' % (step))
single_img_save_name = osp.join(single_img_save_dir, single_img_name)
mkdir_p(single_img_save_dir)
im.save(single_img_save_name)
# 打印进度
if (multi_gpus == True) and (get_rank() != 0):
None
else:
print('Step: %d' % (step))