Troubleshooting a nan loss value in PyTorch

Problem

When I ran the experiment, the loss value came out as nan.

Troubleshooting

The code contains some conditional logic: the statement bd_index = torch.where(s_label != o_label)[0] selects the indices that satisfy a condition. When no element satisfies it, this returns an empty tensor. Because the result then goes through torch.log() and further calculations, an empty selection can end up as loss = nan (for example, the mean of an empty tensor is nan). Each case still has to be analyzed on its own, though. The most effective way to troubleshoot is to wrap the training code in a with torch.autograd.detect_anomaly(): block, which quickly locates the operation that produces the nan. Remove it again for normal training, because it slows training down considerably.
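As a minimal sketch of how the anomaly check can be used, the snippet below wraps one forward and backward pass in torch.autograd.detect_anomaly(). The tensors prob and target and the loss formula are hypothetical stand-ins, not the original training code; they are chosen only so that a nan shows up in the backward pass.

```python
import torch

# Hypothetical stand-ins for the real model outputs and targets.
prob = torch.tensor([0.0, 1.0], requires_grad=True)  # one probability is exactly 0
target = torch.tensor([0.0, 1.0])

anomaly_message = None
try:
    # Debugging only: anomaly mode slows training down noticeably.
    with torch.autograd.detect_anomaly():
        loss = -(target * torch.log(prob)).sum()  # 0 * log(0) = 0 * -inf = nan
        loss.backward()  # the backward of log computes 0 / 0 = nan and is reported
except RuntimeError as err:
    anomaly_message = str(err)

print(loss)             # the forward value is already nan
print(anomaly_message)  # names the backward function that returned nan
```

Anomaly mode checks the backward pass, so the error is raised at loss.backward(), and the accompanying warning shows the traceback of the forward call that created the failing operation.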

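In the end, the problem was traced to the torch.log() calls. torch.log(0) returns -inf, which turns into nan in later arithmetic or in the backward pass. The fix was to add a small constant, log_prob = torch.log(prob + 1e-7), so that the argument is never exactly zero and no nan appears.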
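The fix can be sketched roughly as follows. The function name safe_log_loss, the EPS constant name, the sample tensors, and the decision to return a zero loss for an empty selection are all assumptions made for illustration; only the torch.where selection and the + 1e-7 inside torch.log() come from the description above.

```python
import torch

EPS = 1e-7  # the small constant used in the fix above

def safe_log_loss(prob, s_label, o_label):
    # Indices where the two label tensors differ (may be empty).
    bd_index = torch.where(s_label != o_label)[0]
    if bd_index.numel() == 0:
        # Assumption: contribute zero loss when nothing is selected,
        # instead of taking the mean of an empty tensor (which is nan).
        return prob.sum() * 0.0
    log_prob = torch.log(prob[bd_index] + EPS)  # log(0 + EPS) is finite
    return -log_prob.mean()

# Hypothetical sample data: only index 0 differs, and prob[0] is exactly 0.
prob = torch.tensor([0.0, 0.5, 0.5], requires_grad=True)
s_label = torch.tensor([0, 1, 1])
o_label = torch.tensor([1, 1, 1])

loss = safe_log_loss(prob, s_label, o_label)
loss.backward()
print(torch.isfinite(loss).item(), torch.isfinite(prob.grad).all().item())

# Identical labels give an empty selection, handled by the guard.
empty_loss = safe_log_loss(prob, o_label, o_label)
print(empty_loss.item())
```

Whether an empty selection should count as zero loss or the batch should be skipped depends on the training setup, so the guard here is one possible choice rather than the only one.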

After this change, training ran successfully.


Origin blog.csdn.net/qq_56039091/article/details/127675557