版权声明:转载注明出处:邢翔瑞的技术博客https://blog.csdn.net/weixin_36474809 https://blog.csdn.net/weixin_36474809/article/details/88863829
背景:我们需要搞懂cycleGAN如何对已有图片进行inference
目录
一、嵌套位置
1.1 调用位置
test.py之中,很容易看到调用inference的部分
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
model.test() # run inference
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
1.2 inference调用的函数
二、前馈运算
2.1 forward
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
- real_A输入netG_A 生成 fake_B
- fake_B输入netG_B生成fake_A
- real_B输入netG_B生成fake_A
- fake_A输入netG_A生成rec_B
- 即fake就是根据real通过生成器G生成的
- rec就是re-cycle,就是A通过两次生成器返回的A
2.2 实验结果及解释
这也给实验结果一定的解释:
- 训练时模型是从trainA是正常布料,trainB是棉布料,测试时testA是正常,testB是撕裂布料
- 即训练: 正常——棉 ,模型是从正常布料到棉布料的迁移
- 测试:正常——破裂
分别是realA,fakeA,recA
分别是realB,fakeB,recB
三、模型
3.1 模型定义
base_model.py与cycle_gan_model.py之中定义了模型,loss等各种信息。
3.2 定义loss
具体参见:
初始化时定义了几种loss的名称,后面定义了backward_D_basic和backward_G
3.3 模型结构
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))