SRGAN代码结构分析

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/bla234/article/details/89322878
  • data_utils.py

    1. 数据的基本处理方法定义,由torchvision.transforms来定义返回Compose对象
    2. 继承Dataset类,来定义train,test,val等数据的读取和处理方式
  • loss.py

    1. 数学公式的常规操作,矩阵运算,
    2. 然后写测试代码来运行验证
  • model.py

    1. 这部分是最好理解和编写的
    2. 先写好基本的res模块和upsample模块
    3. 然后用nn.Sequential串联各个模块
  • train.py

    1. 设置超参数
    2. 读取数据集然后用DataLoader实现batch
    3. 定义网络对象,统计其中参数的总数
    4. 定义优化器,传网络参数进去
    5. 训练,验证循环的编写(核心)
      • 更新鉴别器:
        (1)低分辨率作为噪声传入生成器得到fake_img,高分辨率作为real_img
        (2)鉴别器梯度归零,两图传入鉴别器,计算D_loss并backward回传梯度,然后调用optimizerD.step()更新鉴别器的速率
      • 更新生成器:
        (1)生成器梯度归零,计算G_loss并backward回传梯度,然后调用optimizerD.step()更新鉴别器的速率
      • 记录loss
        (1)重新从噪声生成fake_img,计算g_loss,d_loss(G和D的loss),d_score,g_score(G和D的输出,或者说置信度?)
        (2)train_bar记录数据并输出
      • 验证模式
        (1)生成sr,计算sr和hr的mse,ssim和psnr
        (2)输出展示hr,线性插值后的hr和sr
      • 保存模型参数并记录数据
  • test_benchmark.py

    1. 模型设置为eval模式
    2. 和val有点像,也是数据加载和模型加载,传入模型,计算各项指标和展示结果,保存并输出数据
  • test_image.py

    1. benchmark的单张模式
    2. 图片计时
  • test_video.py
    主要是cv2包读取和处理视频的一些方法和属性

猜你喜欢

转载自blog.csdn.net/bla234/article/details/89322878