CycleGAN论文详解:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

版权声明:转载注明出处:邢翔瑞的技术博客https://blog.csdn.net/weixin_36474809 https://blog.csdn.net/weixin_36474809/article/details/88778213

背景:ICCV2017的spotlight论文 cycleGAN在图像域迁移任务之中,不需要源域和目标域成对的样本对,只需要源域和目标域的图像即可。非常实用的地方就是输入的两张图片可以是任意的两张图片,也就是unpaired。对于我们项目很有作用。

目的:详解CycleGAN论文。

论文地址:https://arxiv.org/abs/1703.10593

目录

一、效果概览

二、传统GAN

1.1 传统GAN原理

1.2 GAN分类

三、cycleGAN方法概览

3.1 motivation

3.2 方法

3.3 实验及结果

3.4 训练方法

3.5 相关工作

四、公式及详尽描述

4.1 cycleGAN描述

4.2 adversarial loss

4.3 cycle consistency loss

4.4 总体loss

4.5 网络结构

4.6 训练细节

五、实验结果

5.1 AMT perceptual studies地图与遥感图像迁移

5.2 labels-photos

5.3 loss分析

六、个人总结

6.1 testA->testB

6.2 feature loss GAN


一、效果概览

训练时不需要成对的配对样本,只需要源域和目标域的图像。训练后网络就能实现对图像源域到目标域的迁移。

二、传统GAN

1.1 传统GAN原理

https://blog.csdn.net/leviopku/article/details/81292192

https://www.jianshu.com/p/40feb1aa642a

传统GAN有一个生成器Generator和判别器Descriminator,生成器G用于生成样本,判别起D用于判断这个样本是否为真样本。G用随机噪声生成假图,D根据真假图进行二分类的训练。D根据输入的图像生成score,这个score表示G生成的图像是否成功,进而进一步的训练G生成更好的图像。

  • 判别器D的监督信息就是真实的数据和G生成的数据打成的标签。
  • 判别器G的监督信息就是D(G(z)),也就是G生成图像再判别器中的score。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。(此公式必须深入并且反复理解,需要时参考GAN原论文)

  • 整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
  • D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率。
  • G的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G。
  • D的目的:D的能力越强,D(x)应该越大,D(G(x))应该越小。这时V(D,G)会变大。因此式子对于D来说是求最大(max_D)

1.2 GAN分类

传统GAN 通过随机向量z生成图像y: z -> y

conditionalGAN(cGAN),pix2pix,通过随机向量z和图像x生成需要图像y :(z,x) -> y

cycleGAN,discoGAN与dualGAN

https://www.sohu.com/a/135098277_680233

DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN

https://blog.csdn.net/qq_25737169/article/details/78857788

三、cycleGAN方法概览

3.1 motivation

对于image2image translation的任务,实现迁移需要对齐的样本对,但是对于很多任务,对齐样本对不易得到。

3.2 方法

希望从domain X-> domain Y

  • mapping: G实现X->Y,即G(X)与Y分布相同
  • invearse mapping F:Y-> X,
  • cycle consistency loss :  F(G(X)) ≈X 
  • 反之亦然, G(F(Y)) ≈Y

3.3 实验及结果

很多数据集上,unpaired training data实现了 style transfer, object transfiguration, season transfer, photo enhancement

3.4 训练方法

之前的方法需要image2image样本对,cycleGAN可以在缺少样本对信息的情况下,实现训练。

CycleGAN其实就是一个A→B单向GAN加上一个B→A单向GAN。两个GAN共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。一个单向GAN有两个loss,而CycleGAN加起来总共有四个loss。

  • GAB实现domainA到domainB迁移
  • GBA实现domainB到domainA迁移
  • DA实现判别GBA生成的数据or真实的A数据
  • DB实现判别GAB生成的数据or真实的B数据

训练过程如下:

  • GAB尽可能生成更真的图像愚弄DB
  • GBA尽可能生成更真的图像愚弄DA
  • DA尽可能判别出真实的A或者GBA生成的A
  • DB尽可能判别出真实的B或者GAB生成的B

3.5 相关工作

  • GAN
  • image2image translation
  • unpaired image2image translation
  • cycle consistency
  • neural style transfer

四、公式及详尽描述

4.1 cycleGAN描述

此处即上文所讲

  • 分布X,分布Y,分别来自不同的domain
  • G: X->Y ,生成器G来实现从X到Y的迁移
  • F: Y->X  ,生成器F来实现从Y到X的迁移
  • Dx判别X与 F(y) ,判别器Dx判别到底是真X还是F根据Y生成的与X同分布的数据
  • Dy判别Y与 G(x),判别器Dy判别到底是真Y还是G根据X生成的与Y同分布的数据

其过程包含了两种loss:

  • adversarial losses:尽可能让生成器生成的数据分布接近于真实的数据分布
  • cycle consistency losses: 防止生成器G与F相互矛盾,即两个生成器生成数据之后还能变换回来近似看成X->Y->X

4.2 adversarial loss

尽可能让生成器生成的数据接近于真实的数据分布:

与GAN一样,G用于实现X->Y, 训练应当尽可能让此G(X)接近于Y,判别器Dy用于判别样本的真假。与GAN的公式一样:

同理,对于F实现 Y->X,  

4.3 cycle consistency loss

用于让两个生成器生成的样本之间不要相互矛盾。

上一个adversarial loss只可以保证生成器生成的样本与真实样本同分布,但是我们希望对应的域之间的图像是一一对应的。即A-B-A还可以再迁移回来。

我们希望x -> G(x) -> F(G(x)) ≈ x,称作forward cycle consistency

同理,y -> F(y) -> G(F(y)) ≈ y, 称作 backward cycle consistency

为了尽可能保证consistency,我们设定相应的loss:

4.4 总体loss

即生成器G尽可能实现X到Y的迁移,生成器F尽可能实现Y到X的迁移,同时,希望两生成器的生成器是可以实现互逆,即相互迭代回到自身。(作者后面细节之中,λ 取10 )

4.5 idt loss

有一个loss再论文主要部分没有提及,但是在application之中提及了,并且代码之中有涉及,是idt loss

cycle_gan_model.py之中对它的定义是这样:

parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

idt loss的定义在论文的application之中,防止input 与out put之间的color compostion过多。网络所有的loss的定义就是,reconstruction loss就是GAN loss和cycle consistency loss两个加在一起,GAN loss用于迁移类,cycle consistency loss用于尽量保留原图可以循环迁移。但是还有一个更直观的loss叫idt loss尽量的避免迁移过多。

4.5 网络结构

生成器运用下文中的网络结构,网络包含两个stride-2 的卷积,和数个residual blocks(看代码时候我们详细研究网络结构)

J. Johnson, A. Alahi, and L. Fei-Fei. Perceptual losses for real-time style transfer and super-resolution. In ECCV, 2016

对于判别器网络,我们运用70*70 PatchGANs

4.6 训练细节

对于loss训练而言,作者运用了两个techniques

第一个是针对loss的设置而言:对于G和D分别实现下面的两个minimize,

另一个是针对训练过程而言:

防止模型抖动:to reduce model oscillation , we update the discriminators using a history of generated images rather than the
ones produced by the latest generators. We keep an image buffer that stores the 50 previously created images

上面所说的运用了下文的方法:

A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, and R. Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, 2017

五、实验结果

个人感觉没有用的部分不细看了,有需要再参阅原论文,https://arxiv.org/abs/1703.10593

我们只看cycleGAN可以实现的运用,以及关于loss设置的实验

5.1 AMT perceptual studies地图与遥感图像迁移

即map与aerial图像之间的迁移(注意,我们这里可以看到feature loss可以划出来)

5.2 labels-photos

FCN score (labels-photos)

Semantic segmentation metrics(photo-labels)

5.3 loss分析

我们希望x -> G(x) -> F(G(x)) ≈ x,称作forward cycle consistency

同理,y -> F(y) -> G(F(y)) ≈ y, 称作 backward cycle consistency

作者实验了几种loss的组合方法:根据标签生成图像时loss用GAN+forward cycle loss就达到最优

六、个人总结

6.1 testA->testB

训练好网络进行测试的时候,testB中必须有图像,且cycleGAN必定根据testB来生成相应的结果。

我们看完论文至今没有弄明白testB中必须放入的图像什么作用,后续需要根据代码继续研究。

源码中关于test命令行的注释:

Example (You need to train models first or download pre-trained models from our website):
    Test a CycleGAN model (both sides):
        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

    Test a CycleGAN model (one side only):
        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

即直接运行test.py是for both sides。

但是还是没太明白其中testA和testB文件夹中数据的作用。看了源码之后,大致明了,见下面6.3

6.2 feature loss GAN

实验中我们可以看到feature loss GAN的图像是完全可以print出来的,对于我们的项目很有帮助。

说明我们可以通过改代码实现找到feature loss GAN的图像。

6.3 fake real rec的定义

real我们可以确定为实际的用于测试的样本,fake为根据随机生成的假样本 ? rec的定义是什么?

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
  • real_A输入netG_A 生成 fake_B
  • fake_B输入netG_B生成rec_A 
  • real_B输入netG_B生成fake_A
  • fake_A输入netG_A生成rec_B

猜你喜欢

转载自blog.csdn.net/weixin_36474809/article/details/88778213