【pytorch】基于Apex的混合精度加速

    双倍训练速度,双倍的快乐,简单记录Nvidia开发的基于PyTorch的混合精度训练加速神器--Apex,可以用短短三行代码就能实现不同程度的混合精度加速,训练时间直接缩小一半

三行代码搞定:

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

opt_level:

代码中只有一个opt_level需要用户自行配置:

  • O0:纯FP32训练,可以作为accuracy的baseline;
  • O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
  • O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
  • O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;

参考链接:https://nvidia.github.io/apex/amp.html

发布了33 篇原创文章 · 获赞 46 · 访问量 4万+

猜你喜欢

转载自blog.csdn.net/qq_34795071/article/details/103539168