# amp error: ZeroDivisionError: float division by zero
After searching for a long time, most of the explanations online say the cause is a zero somewhere in a divisor. After carefully checking the code, however, the error location appears to be related to amp in apex. Comparing against some training code found online, the following block is where mine differs, so the problem should be here:

```
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})

optimizer.zero_grad()
#loss.backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```
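Why would this raise a ZeroDivisionError? One plausible mechanism, shown here as a toy illustration in plain Python (this is not apex's actual code): a dynamic loss scaler halves its scale whenever gradients overflow, and if every step overflows, the scale can underflow to `0.0`, so a later `grad / scale` step divides by zero.

```python
# Toy sketch of dynamic loss scaling, NOT apex's implementation.
# On each overflow the scaler backs off by halving the scale; enough
# consecutive halvings underflow the float to 0.0, and the unscale
# step (grad / scale) then raises "float division by zero".

def unscale(grad, scale):
    # This division is where the ZeroDivisionError surfaces.
    return grad / scale

scale = 2.0 ** 15          # a typical initial dynamic loss scale
for _ in range(2000):      # pretend every step overflowed
    scale /= 2.0           # backoff-on-overflow rule
print(scale)               # underflows to 0.0

try:
    unscale(1.0, scale)
except ZeroDivisionError as e:
    print(e)               # float division by zero
```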

The key difference is this part:

```
optimizer.zero_grad()
#loss.backward()

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```

Fix: change it to the following:

```
optimizer.zero_grad()
loss.backward()

#with amp.scale_loss(loss, optimizer) as scaled_loss:
#    scaled_loss.backward()
optimizer.step()
```
Re-running the training code, the problem was solved. This means giving up apex's acceleration, but that is better than code that will not run at all!
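Rather than dropping mixed precision entirely, PyTorch's built-in AMP (`torch.cuda.amp`, available since PyTorch 1.6) can replace the deprecated apex amp. A minimal sketch under that assumption; `net`, `criterion`, and the tensor names are placeholders matching the training loop above:

```python
import torch

def train_step(net, imgs, true_masks, criterion, optimizer, scaler):
    # autocast runs the forward pass in mixed precision on CUDA;
    # with enabled=False (no GPU) it is a no-op, so this also runs on CPU.
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        masks_pred = net(imgs)
        loss = criterion(masks_pred, true_masks)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scaled backward, like amp.scale_loss
    scaler.step(optimizer)         # unscales grads, skips step on overflow
    scaler.update()                # adjusts the loss scale dynamically
    return loss.item()

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
```

Unlike the toy backoff above, `GradScaler` skips the optimizer step on overflow and re-grows the scale afterwards, so the scale does not collapse to zero.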


Origin blog.csdn.net/qq_45014374/article/details/127258558