“用于无监督图像生成解耦的正交雅可比正则化”论文(OroJaR)——部分代码解读

1 前言

论文部分解读可以参考老师的博客
代码地址:GitHub

在github给出的代码中,除OroJaR的快速使用之外,还附带了用PyTorch(2个)和TensorFlow(1个)的demo。

环境配置文件environment.yml

name: orojar
channels:
  - pytorch
  - plotly
  - defaults
dependencies:
  - pillow=7.1.2
  - pip
  - python=3.6.10
  - pytorch=1.7.1
  - scipy=1.4.1
  - tensorflow-gpu=1.14.0
  - torchvision=0.8.2
  - pip:
    - matplotlib==3.3.0
    - moviepy==1.0.3
    - numpy
    - dominate
    - opencv-python==4.2.0.34
    - pandas==1.0.3
    - plotly==4.9.0
    - seaborn==0.10.1
    - tensorboardX
    - tensorboard>=1.15
    - tensorflow-probability==0.7.0
    - tqdm==4.46.0
    - wget
    - h5py

2 simplegan_experiments

2.1 概述

这部分代码放在OroJaR-mastersimplegan_experiments目录下。

这个demo是将OroJaR用于CycleGAN中,使用的数据集是DspritesCelebA

框架采用PyTorch

2.2 可视化预训练模型

1)下载模型
百度网盘(lg6v)
2)下载模型之后,将其放入pretrained_models文件夹
3)在终端运行如下代码可以在visuals文件夹中生成视频和图片

python visualize.py --model_path model_path --model_name OroJaR --model_type gan --nz <input dimension> --nc_out <output channel>
python .\visualize.py --model_path .\pretrained_models\dsprites_gan.pth --model_name dsprites_gan  

其中:
model_path表示与训练模型的位置(必填项);
model_name表示用于保存生成的视频和图片的保存地址;
model_type表示选择的模型类型,有两个模型可供选择:gangan128,其中gan对应64x64输入的Dsprites数据集,而gan128对应128x128输入的CelebA数据集;
nz表示输入的维度;
nc_out表示输出的通道数;
更加详细的参数解释可以看这段源码:

	parser.add_argument('--model_path', required=True,
                        help='Number of model paths. You can specify multiple experiments '
                             'to generate visuals for all of them with one call to this script.')
    parser.add_argument('--nz', default=6, type=int,
                        help='Number of components in G\'s latent space.')
    parser.add_argument('--nc_out', default=1, type=int,
                        help='Channal number of the output image')
    parser.add_argument('--samples', default=1, type=int,
                        help='Number of z samples to use (=interp_batch_size). This controls the "width" of the '
                             'generated videos.')
    parser.add_argument('--extent', default=2.0, type=float,
                        help='How "far" to move the z components (from -extent to extent)')
    parser.add_argument('--steps', default=40, type=int,
                        help='Number of frames in video (=granularity of interpolation).')
    parser.add_argument('--n_frames_to_save', type=int, default=9,
                        help='Number of "flattened" frames from video to save to png (0=disable).')
    parser.add_argument('--model_name', default='OroJaR',
                        help='Give names to the models you are evaluating')
    parser.add_argument('--model_type', default='gan',
                        help='Give model types to the models you are evaluating')
    parser.add_argument('--sefa', default=False, type=str2bool,
                        help='Use SeFa on the first conv/fc layer to achieve disentanglement.')
    parser.add_argument('--save_dir', type=str, default='./', help='figures are saved here')

4)放入自己的图片
要想放入自己的图片,我们需要做一些调整:

  • 在使用的时候需要设置nzsamples的值,其中samples相当于batch_sizenz表示输入维度。若输入的图片形状为(h, w),则samples=h, nz=w
  • 修改代码中sample_z的值为输入图片的像素矩阵

2.3 评估预训练模型

猜你喜欢

转载自blog.csdn.net/CesareBorgia/article/details/120380887