生成对抗网络CycleGAN

1.介绍

论文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

论文地址:https://arxiv.org/abs/1703.10593

什么是CycleGAN:CycleGAN主要用于图像之间的转换,假设有两个不成对的图像X和Y,算法训练去学习一个“自动相互转换”,训练时不需要成对的配对样本,只需要源域和目标域的图像。训练后网络就能实现对图像源域到目标域的迁移。CycleGAN适用于非配对的图像到图像转换,解决了模型需要成对数据进行训练的困难。

与pix2pixGAN的区别:二者都可以做图像变换,pix2pix模型必须要求成对数据(paired data),而CycleGAN利用非成对数据也能进行训练(unpaired data)。

2. CycleGAN网络架构

相关工作:

GAN、DCGAN、CGAN、pix2pixGAN

CycleGAN其实就是一个 A→B 的单向 GAN 加上一个 B→A 的单向 GAN。两个 GAN 共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。一个单向 GAN 有两个 loss, 故 CycleGAN 加起来总共有四个 loss。

循环一致损失:因为网络需要保证生成的图像必须保留有原始图像的特性,所以如果我们使用生成器GeneratorA-B生成一张假图像,那么要能够使用另外一个生成器GeneratorB-A来努力恢复成原始图像。此过程必须满足循环一致性。

identity loss:可以理解为,生成器是负责域x到域y的图像生成, 如果输入域y的图片还是应该生成域y的图片。

# CycleGAN: translate dog images into cat images (and back).
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from PIL import Image
import tqdm
import glob

# Raw strings keep the Windows backslashes literal: plain strings here rely on
# '\c', '\D' etc. not being recognized escapes, which triggers a SyntaxWarning
# on Python >= 3.12 and silently breaks if a path segment ever starts with
# 't', 'n', ... The glob results are unchanged.
dogs_path = glob.glob(r'D:\cnn\All Classfication\AlexNet\data/train\Dog/*.jpg')  # training .jpg images, domain A (dogs)
cats_path = glob.glob(r'D:\cnn\All Classfication\AlexNet\data/train\Cat/*.jpg')  # training .jpg images, domain B (cats)
# print(cats_path[:3])
# print(dogs_path[:3])
cats_path_test = glob.glob(r'D:\cnn\All Classfication\AlexNet\data/val\Cat/*.jpg')  # validation .jpg images (cats)
dogs_path_test = glob.glob(r'D:\cnn\All Classfication\AlexNet\data/val\Dog/*.jpg')  # validation .jpg images (dogs)

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(mean=0.5, std=0.5)])  # Normalize maps pixel values into [-1, 1]

# Dataset over a flat list of image file paths (no labels).
class SGANDataset(Dataset):
    """Load single images from `imgs_path` and apply the module-level
    `transform` (to tensor, resize to 256x256, scale to [-1, 1])."""

    def __init__(self, imgs_path):
        super(SGANDataset, self).__init__()
        self.imgs_path = imgs_path  # list of image file paths

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img_path = self.imgs_path[index]

        # Force 3 channels: real-world cat/dog datasets contain grayscale,
        # palette and RGBA images; without convert('RGB') those yield tensors
        # whose channel count != 3 and crash the 3-channel networks mid-epoch.
        pil_img = Image.open(img_path).convert('RGB')
        pil_img = transform(pil_img)
        return pil_img

# Training datasets (one per domain).
dog_dataset = SGANDataset(dogs_path) # domain A: dogs
cat_dataset = SGANDataset(cats_path) # domain B: cats

# Test datasets.
dog_dataset_test = SGANDataset(dogs_path_test) # validation dogs
cat_dataset_test = SGANDataset(cats_path_test) # validation cats

dog_dataloader = torch.utils.data.DataLoader(dog_dataset, batch_size=4, shuffle=True)
cat_dataloader = torch.utils.data.DataLoader(cat_dataset, batch_size=4, shuffle=True)

dog_dataloader_test = torch.utils.data.DataLoader(dog_dataset_test, batch_size=4)
cat_dataloader_test = torch.utils.data.DataLoader(cat_dataset_test, batch_size=4)

# cat_bath = next(iter(cat_dataloader)) # inspect one batch
# dog_bath = next(iter(dog_dataloader)) # inspect one batch
# print(dog_bath.shape) #torch.Size([4, 3, 256, 256])
# print(cat_bath.shape) #torch.Size([4, 3, 256, 256])

# Visualize a few samples from both domains.
# plt.figure(figsize=(8, 12))
# for i, (dog, cat) in enumerate(zip(dog_bath[:3], cat_bath[:3])):
#     # Tensors -> numpy; Normalize mapped values to [-1, 1], so (x+1)/2 maps back to [0, 1] for imshow.
#     dog = (dog.permute(1, 2, 0).numpy() + 1) / 2
#     cat = (cat.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('dog')
#     plt.imshow(dog)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('cat')
#     plt.imshow(cat)
# plt.show()


#定义下采样模块
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为True
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x

#定义上采样模块
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为False
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x

# U-Net style generator: 6 downsampling blocks, 5 upsampling blocks with
# skip connections, and a final transposed conv back to 3 channels.
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)     # 3,256,256 -> 64,128,128
        self.down2 = Downsample(64, 128)   # 64,128,128 -> 128,64,64
        self.down3 = Downsample(128, 256)  # 128,64,64 -> 256,32,32
        self.down4 = Downsample(256, 512)  # 256,32,32 -> 512,16,16
        self.down5 = Downsample(512, 512)  # 512,16,16 -> 512,8,8
        self.down6 = Downsample(512, 512)  # 512,8,8 -> 512,4,4

        # Input channels double from up2 on because of skip concatenation.
        self.up1 = Upsample(512, 512)      # 512,4,4 -> 512,8,8
        self.up2 = Upsample(1024, 512)     # 1024,8,8 -> 512,16,16
        self.up3 = Upsample(1024, 256)     # 1024,16,16 -> 256,32,32
        self.up4 = Upsample(512, 128)      # 512,32,32 -> 128,64,64
        self.up5 = Upsample(256, 64)       # 256,64,64 -> 64,128,128
        # 128,128,128 -> 3,256,256
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        # Encoder: keep every activation for the skip connections.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3,
                     self.down4, self.down5, self.down6):
            h = down(h)
            skips.append(h)

        # Decoder: dropout on the three deepest blocks; after each upsample,
        # concatenate the mirrored encoder activation along the channel axis.
        h = self.up1(skips[5], is_drop=True)
        h = torch.cat([h, skips[4]], dim=1)
        h = self.up2(h, is_drop=True)
        h = torch.cat([h, skips[3]], dim=1)
        h = self.up3(h, is_drop=True)
        h = torch.cat([h, skips[2]], dim=1)
        h = self.up4(h)
        h = torch.cat([h, skips[1]], dim=1)
        h = self.up5(h)
        h = torch.cat([h, skips[0]], dim=1)

        # tanh maps the output into [-1, 1], matching the input normalization.
        return torch.tanh(self.last(h))

# Discriminator: two downsampling blocks, then an unpadded 3x3 conv
# producing a 1-channel map of per-patch real/fake scores.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)    # 3,H,W -> 64,H/2,W/2
        self.down2 = Downsample(64, 128)  # halve again, 64 -> 128 channels
        self.last = nn.Conv2d(128, 1, 3)  # 1-channel score map

    def forward(self, img):
        features = self.down2(self.down1(img))
        # sigmoid squashes each score into a (0, 1) probability for BCELoss.
        return torch.sigmoid(self.last(features))

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Two generators: gen_AB maps domain A (dogs) -> B (cats); gen_BA the reverse.
gen_AB = Generator().to(device)
gen_BA = Generator().to(device)

# Two discriminators: dis_A judges domain-A images, dis_B judges domain-B images.
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# Losses: 1. GAN loss (BCE)  2. cycle-consistency loss (L1)  3. identity loss (L1)
bce_loss = torch.nn.BCELoss()
l1_loss = torch.nn.L1Loss()

# Optimizers.
# Both generators are optimized together; itertools.chain iterates both parameter sets.
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), lr=2e-4, betas=(0.5, 0.999))

# Each discriminator gets its own optimizer.
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Plot helper: saves the generator's output on a fixed test image each epoch.
def gen_img_plot(model, epoch, test_input):
    """Save a side-by-side figure of `test_input` and `model(test_input)`.

    model: a generator (gen_AB / gen_BA).
    epoch: integer used to name the file ./image/image_at_{epoch}.png.
    test_input: tensor of shape (1, 3, H, W) with values in [-1, 1].
    """
    import os

    # no_grad makes the call safe even outside an existing no_grad block —
    # otherwise .numpy() on a grad-tracking tensor raises a RuntimeError.
    with torch.no_grad():
        generate = model(test_input).permute(0, 2, 3, 1).cpu().numpy()  # channels last
        test_input = test_input.permute(0, 2, 3, 1).cpu().numpy()  # 1,3,H,W -> 1,H,W,3
    fig = plt.figure(figsize=(10, 6))
    display_list = [test_input[0], generate[0]]
    title = ['Input image', 'Generate image']
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title[i])
        plt.imshow((display_list[i]+1)/2)  # rescale [-1, 1] -> [0, 1]
        plt.axis('off')
    os.makedirs('./image', exist_ok=True)  # savefig fails if the directory is missing
    plt.savefig('./image/image_at_{}.png'.format(epoch))
    plt.close(fig)  # avoid leaking one open figure per epoch

# One fixed test image from domain A (dogs) to visualize gen_AB's progress.
test_batch = next(iter(dog_dataloader_test))  # (batch_size, 3, 256, 256)
# Take the first image and restore a batch dimension: (3,256,256) -> (1,3,256,256)
test_input = torch.unsqueeze(test_batch[0], 0).to(device)

# CycleGAN training.
D_loss = []
G_loss = []
epochs = 50
for epoch in range(epochs):
    d_epoch_loss = 0
    g_epoch_loss = 0
    # zip pairs one real dog batch (domain A) with one real cat batch (domain B).
    for step, (real_A, real_B) in enumerate(zip(dog_dataloader, cat_dataloader)):
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        #--------------------begin--------------------#
        # Generator training
        gen_optimizer.zero_grad()  # clear gradients before the update
        # Identity loss: a generator fed an image already in its output
        # domain should return it unchanged.
        same_B = gen_AB(real_B)  # gen_AB outputs domain B, so real B should pass through
        identity_B_loss = l1_loss(same_B, real_B)
        # BUG FIX: this previously called gen_AB(real_A); the identity term
        # for domain A must use gen_BA, the generator that outputs domain A.
        same_A = gen_BA(real_A)
        identity_A_loss = l1_loss(same_A, real_A)
        # Adversarial (GAN) loss: each generator wants its fakes judged real.
        fake_B = gen_AB(real_A)
        D_pred_fake_B = dis_B(fake_B)
        gen_loss_AB = bce_loss(D_pred_fake_B, torch.ones_like(D_pred_fake_B, device=device))
        fake_A = gen_BA(real_B)
        D_pred_fake_A = dis_A(fake_A)
        gen_loss_BA = bce_loss(D_pred_fake_A, torch.ones_like(D_pred_fake_A, device=device))
        # Cycle-consistency loss: A -> fake_B -> recovered_A should equal A (and likewise for B).
        recovered_A = gen_BA(fake_B)
        cycle_loss_ABA = l1_loss(recovered_A, real_A)

        recovered_B = gen_AB(fake_A)
        cycle_loss_BAB = l1_loss(recovered_B, real_B)

        # Total generator loss.
        # NOTE(review): the CycleGAN paper weights cycle loss x10 and identity
        # loss x5; the original equal weighting is preserved here unchanged.
        g_loss = identity_A_loss + identity_B_loss + gen_loss_AB + gen_loss_BA + cycle_loss_ABA + cycle_loss_BAB

        g_loss.backward()
        gen_optimizer.step()
        # --------------------end--------------------#

        # --------------------begin--------------------#
        # Discriminator training
        # dis_A: real domain-A images -> 1, generated ones -> 0.
        dis_A_optimizer.zero_grad()
        dis_A_real_output = dis_A(real_A)
        dis_A_real_loss = bce_loss(dis_A_real_output, torch.ones_like(dis_A_real_output, device=device))

        # detach() stops gradients from flowing back into the generator.
        dis_A_fake_output = dis_A(fake_A.detach())
        dis_A_fake_loss = bce_loss(dis_A_fake_output, torch.zeros_like(dis_A_fake_output, device=device))

        dis_A_loss = dis_A_real_loss + dis_A_fake_loss  # total loss of discriminator A
        dis_A_loss.backward()
        dis_A_optimizer.step()

        # dis_B: real domain-B images -> 1, generated ones -> 0.
        dis_B_optimizer.zero_grad()
        dis_B_real_output = dis_B(real_B)
        dis_B_real_loss = bce_loss(dis_B_real_output, torch.ones_like(dis_B_real_output, device=device))

        dis_B_fake_output = dis_B(fake_B.detach())
        dis_B_fake_loss = bce_loss(dis_B_fake_output, torch.zeros_like(dis_B_fake_output, device=device))

        dis_B_loss = dis_B_real_loss + dis_B_fake_loss  # total loss of discriminator B
        dis_B_loss.backward()
        dis_B_optimizer.step()
        # --------------------end--------------------#

        with torch.no_grad():
            g_epoch_loss += g_loss.item()  # accumulate per-batch generator loss
            d_epoch_loss += (dis_A_loss + dis_B_loss).item()  # accumulate per-batch discriminator loss

    with torch.no_grad():
        g_epoch_loss /= (step + 1)  # mean loss over the epoch
        d_epoch_loss /= (step + 1)
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)
        gen_img_plot(gen_AB, epoch, test_input)

猜你喜欢

转载自blog.csdn.net/m0_56247038/article/details/130395582
今日推荐