1 前言
在github给出的代码中,除OroJaR
的快速使用之外,还附带了用PyTorch(2个)和TensorFlow(1个)的demo。
环境配置文件environment.yml
:
name: orojar
channels:
- pytorch
- plotly
- defaults
dependencies:
- pillow=7.1.2
- pip
- python=3.6.10
- pytorch=1.7.1
- scipy=1.4.1
- tensorflow-gpu=1.14.0
- torchvision=0.8.2
- pip:
- matplotlib==3.3.0
- moviepy==1.0.3
- numpy
- dominate
- opencv-python==4.2.0.34
- pandas==1.0.3
- plotly==4.9.0
- seaborn==0.10.1
- tensorboardX
- tensorboard>=1.15
- tensorflow-probability==0.7.0
- tqdm==4.46.0
- wget
- h5py
2 simplegan_experiments
2.1 概述
这部分代码放在OroJaR-master
的simplegan_experiments
目录下。
这个demo是将OroJaR
用于CycleGAN
中,使用的数据集是Dsprites
和CelebA
。
框架采用PyTorch
2.2 可视化预训练模型
1)下载模型
百度网盘(lg6v)
2)下载模型之后,将其放入pretrained_models
文件夹
3)在终端运行如下代码可以在visuals
文件夹中生成视频和图片
python visualize.py --model_path model_path --model_name OroJaR --model_type gan --nz <input dimension> --nc_out <output channel>
python .\visualize.py --model_path .\pretrained_models\dsprites_gan.pth --model_name dsprites_gan
其中:
model_path
表示与训练模型的位置(必填项);
model_name
表示用于保存生成的视频和图片的保存地址;
model_type
表示选择的模型类型,有两个模型可供选择:gan
和gan128
,其中gan
对应64x64输入的Dsprites
数据集,而gan128
对应128x128输入的CelebA
数据集;
nz
表示输入的维度;
nc_out
表示输出的通道数;
更加详细的参数解释可以看这段源码:
parser.add_argument('--model_path', required=True,
help='Number of model paths. You can specify multiple experiments '
'to generate visuals for all of them with one call to this script.')
parser.add_argument('--nz', default=6, type=int,
help='Number of components in G\'s latent space.')
parser.add_argument('--nc_out', default=1, type=int,
help='Channal number of the output image')
parser.add_argument('--samples', default=1, type=int,
help='Number of z samples to use (=interp_batch_size). This controls the "width" of the '
'generated videos.')
parser.add_argument('--extent', default=2.0, type=float,
help='How "far" to move the z components (from -extent to extent)')
parser.add_argument('--steps', default=40, type=int,
help='Number of frames in video (=granularity of interpolation).')
parser.add_argument('--n_frames_to_save', type=int, default=9,
help='Number of "flattened" frames from video to save to png (0=disable).')
parser.add_argument('--model_name', default='OroJaR',
help='Give names to the models you are evaluating')
parser.add_argument('--model_type', default='gan',
help='Give model types to the models you are evaluating')
parser.add_argument('--sefa', default=False, type=str2bool,
help='Use SeFa on the first conv/fc layer to achieve disentanglement.')
parser.add_argument('--save_dir', type=str, default='./', help='figures are saved here')
4)放入自己的图片
要想放入自己的图片,我们需要做一些调整:
- 在使用的时候需要设置
nz
和samples
的值,其中samples
相当于batch_size
,nz
表示输入维度。若输入的图片形状为(h, w)
,则samples=h, nz=w
- 修改代码中
sample_z
的值为输入图片的像素矩阵