有关PyTorch中Checkpoint的原理、实现和问题

有关PyTorch中Checkpoint的原理、实现和问题

一、动机

​ 由于复现某些论文中的代码时,使用正常的方法跑,显存不够。了解到这个方法是牺牲时间来降低显存,使用完之后,果然可以跑起来,而且显存降低了好多。那个代码至少30G显存才可能跑起来,使用完之后,不到9个G。

​ 写这个博客希望可以帮助到一些有需要的人。

二、原理

我们使用pytorch训练模型的时候主要有四部分消耗显存。

  • 模型参数
  • 模型参数的梯度
  • 优化器状态
  • 中间激活值

模型的现存之所以那么大,其中原因之一就是计算梯度时,模型会把所有前向传播的中间激活值都保存下来,这非常消耗显存,这样的好处是,需要那个中间激活值时,可以直接用,就不需要再次计算,节省了时间。

Checkpointing采取的策略是:保留一部分中间激活值,其余部分丢弃,如果用到的中间激活值没有的话,就重新计算,这样大大节省了显存,但是增加了时间。

三、实现

for cascade in self.cascades:
     if is_training:
        kspace_pred = checkpoint.checkpoint(cascade, x1, x2)
     else:
        kspace_pred = cascade(x1, x2)
        
# cascade:网络
# x1:网络的参数1
# x2:网络的参数2

上述是在训练的时候使用checkpoint技术,在验证和测试的时候不使用。

checkpoint放在你进入网络,开始迭代的时候。

四、问题

如果,你使用的时候遇到下面这个警告。

警告:UserWarning: None of the inputs have requires_grad=True.

可能的解决办法之一:

你把所有的 requires_grad设置为True。

扫描二维码关注公众号,回复: 17469259 查看本文章

可能的解决办法之一:

你在测试或者验证的时候也使用了checkpoint,因为测试的或者验证的时候,不需要梯度传播,也就引发了这个警告。

你可以不用管,结果应该是一样的。

如果你不想看到警告,你就设置个判断,测试和验证的时候不使用checkpoint,仅在训练的时候使用。

参考文章

  • https://blog.csdn.net/Solo95/article/details/131606918?s
  • https://blog.csdn.net/Shirelle_/article/details/137868196
  • https://zhuanlan.zhihu.com/p/424512257
  • https://blog.csdn.net/P_LarT/article/details/122521212

猜你喜欢

转载自blog.csdn.net/lihaiyuan_0324/article/details/139299374