torch报错:StopIteration: Caught StopIteration in replica 0 on device 0.

pytorch DataParallel报错解决

错误展示

错误名称:

StopIteration: Caught StopIteration in replica 0 on device 0.

包版本:

pytorch-pretrained-bert 0.6.2
torch                   1.6.0

错误如下:

报错显示图1
报错显示图2

问题原因

使用单gpu的时候是正常的,但是使用多gpu的时候会报错。问题是多gpu进行模型训练的时候产生的,具体为,不能够用多gpu加载预训练的bert。应该是torch版本的问题。根据2可以知道,torch1.5版本有这个问题,我是torch1.6也有这个问题,据3替换为torch1.4可以解决该问题。

解决方法

比较简单粗暴的解决方法如下:
注意有如下问题:

  File "/miniconda/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py", line 727, in forward
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility

进入site-packages目录
/miniconda/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py 这个路径下的modeling.py脚本把727行的
next(self.parameters()).dtype换成torch.float32

猜你喜欢

转载自blog.csdn.net/weixin_44152453/article/details/109290978
今日推荐