[DEBUG diary] cannot import name 'amp'

Problem Description:

When training with WongKinYiu/PyTorch_YOLOv4, the following error is reported:

Traceback (most recent call last):
  File "train.py", line 15, in <module>
    from torch.cuda import amp
ImportError: cannot import name 'amp'

Cause Analysis:

1. torch.cuda.amp only exists in PyTorch 1.6 and later, so "from torch.cuda import amp" fails on older versions;
2. On older versions you have to install NVIDIA apex yourself and change the import in the source code to

from apex import amp
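Point 1 can be checked in code before editing train.py. This is a minimal sketch, not part of PyTorch_YOLOv4: has_native_amp is a hypothetical helper, and it assumes the version string starts with "major.minor", as PyTorch's do (e.g. "1.5.1+cu101").

```python
import re


def has_native_amp(version: str) -> bool:
    """Return True if a PyTorch version string is 1.6 or newer.

    Hypothetical helper: only the leading "major.minor" part is parsed,
    so suffixes such as "+cu101" or ".dev20200601" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Compare as integer tuples: as plain strings, "1.10" < "1.6".
    return (major, minor) >= (1, 6)


if __name__ == "__main__":
    for v in ["1.5.1+cu101", "1.6.0", "1.10.2", "2.1.0"]:
        print(v, has_native_amp(v))
```

In a real environment you would call has_native_amp(torch.__version__); printing torch.version.cuda next to it also shows which CUDA build of PyTorch is installed.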

Solution:

1. Check that your PyTorch and CUDA versions are compatible;
https://pytorch.org/get-started/previous-versions/
2. Upgrade PyTorch to 1.6 or later (recommended, because other parts of the model code may also require a newer PyTorch)

Or, if upgrading is not an option, install apex:

git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cpp_ext --cuda_ext
(The installation sometimes fails; check the PyTorch and CUDA versions, or remove --cuda_ext and it should install successfully.)

Then, in train.py, change "from torch.cuda import amp" to "from apex import amp".
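If the same training script has to run on machines with both old and new PyTorch, the two imports can be tried in order instead of editing the file by hand. This is a minimal sketch; load_amp is a hypothetical helper, not part of PyTorch_YOLOv4. One caveat worth checking against your copy of train.py: apex.amp exposes a different interface (amp.initialize, amp.scale_loss) from torch.cuda.amp (amp.GradScaler, amp.autocast), so code written for one backend may need further changes after the import is swapped.

```python
from types import ModuleType
from typing import Optional, Tuple


def load_amp() -> Tuple[Optional[ModuleType], Optional[str]]:
    """Return (amp_module, backend_name), or (None, None) if neither exists."""
    try:
        # Native mixed precision, available in PyTorch 1.6 and later.
        # On older PyTorch this raises "cannot import name 'amp'" (ImportError);
        # without PyTorch it raises ModuleNotFoundError, a subclass of ImportError.
        from torch.cuda import amp
        return amp, "torch.cuda.amp"
    except ImportError:
        pass
    try:
        # Fallback: NVIDIA apex, installed separately as described above.
        from apex import amp
        return amp, "apex.amp"
    except ImportError:
        return None, None


if __name__ == "__main__":
    amp, backend = load_amp()
    print("AMP backend:", backend)
```

Checking the returned backend name lets the rest of the script choose the matching training-loop code, rather than assuming one API.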


Original post: blog.csdn.net/lucifer479/article/details/111322564