assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."

Problem: assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."

Description:

When I tried to resume training from a checkpoint, the following error occurred:
Traceback (most recent call last):
  File "/home/yu/projects/mobilevit/ml-cvnets/engine/training_engine.py", line 682, in run
    train_loss, train_ckpt_metric = self.train_epoch(epoch)
  File "/home/yu/projects/mobilevit/ml-cvnets/engine/training_engine.py", line 353, in train_epoch
    self.gradient_scalar.step(optimizer=self.optimizer)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 338, in step
    retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 285, in _maybe_opt_step
    retval = optimizer.step(*args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/adamw.py", line 161, in step
    adamw(params_with_grad,
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/adamw.py", line 218, in adamw
    func(params,
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/adamw.py", line 259, in _single_tensor_adamw
    assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."

My environment:

Python: 3.8.18 (default, Sep 11 2023, 13:40:15) [GCC 11.2.0]
CUDA available: True
GPU 0,1,2,3: NVIDIA GeForce RTX 3090
CUDA_HOME: /usr/local/cuda-11.3
NVCC: Cuda compilation tools, release 11.3, V11.3.58
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
PyTorch: 1.12.0
TorchVision: 0.13.0
OpenCV: 4.8.0
MMCV: 1.7.0
MMCV Compiler: GCC 9.3
MMCV CUDA Compiler: 11.3
MMDetection: 2.25.0+
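
The traceback above was produced on PyTorch 1.12.0. To check quickly whether another environment runs that same release, something like the following sketch could be used. The helper name is_affected_version is hypothetical, and treating exactly 1.12.0 as the affected release is an assumption based on this post, not an exhaustive list of affected builds.

```python
import re


def is_affected_version(version):
    """Return True if the version string denotes PyTorch 1.12.0.

    Assumption (from this post): 1.12.0 is the release that shows the error.
    Local build suffixes such as "+cu113" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        # Unrecognized version format: do not claim it is affected.
        return False
    return tuple(int(part) for part in match.groups()) == (1, 12, 0)
```

In a real environment the argument would be torch.__version__, for example is_affected_version(torch.__version__).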

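Solution:

Upgrading PyTorch from 1.12.0 to 1.12.1 fixes the problem. In 1.12.0, restoring the optimizer state from a checkpoint can leave the per-parameter step counters on the GPU, which trips this assertion when capturable=False.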

pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
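
If upgrading is not possible, a commonly used workaround is to move the optimizer's step counters back to the CPU right after optimizer.load_state_dict(...) and before the next optimizer.step(). The sketch below is not from the original post. The helper name move_step_to_cpu is hypothetical, and it assumes the optimizer keeps a "step" entry in each per-parameter state dict, as Adam and AdamW do in PyTorch 1.12. It relies only on the tensor attributes is_cuda and cpu(), so it does not import torch itself.

```python
def move_step_to_cpu(optimizer):
    """Move every per-parameter "step" counter in the optimizer state to the CPU.

    Hypothetical helper: assumes optimizer.state maps each parameter to a dict
    that may hold a tensor under the key "step".
    Returns the number of entries that were moved.
    """
    moved = 0
    for param_state in optimizer.state.values():
        step = param_state.get("step")
        # Only tensor-like objects expose is_cuda; plain numbers are skipped.
        if getattr(step, "is_cuda", False):
            param_state["step"] = step.cpu()
            moved += 1
    return moved
```

Typical use would be to call move_step_to_cpu(optimizer) once, immediately after loading the checkpoint. Upgrading remains the simpler fix when it is available.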


Reposted from blog.csdn.net/shysea2019/article/details/133961947