版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/bla234/article/details/89322878
-
data_utils.py
- 数据的基本处理方法定义,由torchvision.transforms来定义返回Compose对象
- 继承Dataset类,来定义train,test,val等数据的读取和处理方式
-
- 数学公式的常规操作,矩阵运算,
- 然后写测试代码来运行验证
-
- 这部分是最好理解和编写的
- 先写好基本的res模块和upsample模块
- 然后用nn.Sequential串联各个模块
-
- 设置超参数
- 读取数据集然后用DataLoader实现batch
- 定义网络对象,统计其中参数的总数
- 定义优化器,传网络参数进去
- 训练,验证循环的编写(核心)
- 更新鉴别器:
(1)低分辨率作为噪声传入生成器得到fake_img,高分辨率作为real_img
(2)鉴别器梯度归零,两图传入鉴别器,计算D_loss并backward回传梯度,然后调用optimizerD.step()更新鉴别器的速率 - 更新生成器:
(1)生成器梯度归零,计算G_loss并backward回传梯度,然后调用optimizerD.step()更新鉴别器的速率 - 记录loss
(1)重新从噪声生成fake_img,计算g_loss,d_loss(G和D的loss),d_score,g_score(G和D的输出,或者说置信度?)
(2)train_bar记录数据并输出 - 验证模式
(1)生成sr,计算sr和hr的mse,ssim和psnr
(2)输出展示hr,线性插值后的hr和sr - 保存模型参数并记录数据
- 更新鉴别器:
-
test_benchmark.py
- 模型设置为eval模式
- 和val有点像,也是数据加载和模型加载,传入模型,计算各项指标和展示结果,保存并输出数据
-
test_image.py
- benchmark的单张模式
- 图片计时
-
test_video.py
主要是cv2包读取和处理视频的一些方法和属性