Add the following to your training script:
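Step 1: create a list before the training loop, then append the loss at every iteration:
train_losses = []  # create once, before the training loop
train_losses.append(round(loss.item(), 3))  # inside the loop, after computing loss; .item() returns a Python float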
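To show where these lines sit, here is a minimal sketch of a training loop on a toy linear-regression problem. The model, data, learning rate and number of steps are hypothetical placeholders, not part of the original setup; only the two `train_losses` lines correspond to Step 1. `loss.item()` returns the scalar loss as a plain Python float.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: y = 2x + 1 plus a little noise
x = torch.randn(64, 1)
y = 2 * x + 1 + 0.05 * torch.randn(64, 1)

model = torch.nn.Linear(1, 1)  # hypothetical placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

train_losses = []  # Step 1: create the list once, before the loop

for step in range(100):  # hypothetical number of steps
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    # Step 1: record the scalar loss, rounded to 3 decimals
    train_losses.append(round(loss.item(), 3))

print(len(train_losses), train_losses[0], train_losses[-1])
```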
Step 2: after training finishes, save the recorded losses to disk:
torch.save(train_losses, '/home/wu/local/loss.pth')
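A quick way to check that the saved file reads back unchanged is a round trip through a temporary directory. The sample values and file name below are hypothetical. Because the list holds only Python floats, it should also load under the `weights_only=True` default that newer PyTorch releases use for `torch.load`.

```python
import os
import tempfile

import torch

# Hypothetical sample values standing in for a real train_losses list
train_losses = [0.693, 0.512, 0.401, 0.355]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'loss.pth')  # hypothetical file name
    torch.save(train_losses, path)        # Step 2: write the list to disk
    loaded = torch.load(path)             # what the plotting script reads back

print(loaded == train_losses)  # → True
```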
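Then, in a separate script, load the saved file and plot the loss curve:
import torch
import matplotlib.pyplot as plt

loss = torch.load('./loss.pth')  # run from the directory holding loss.pth, or use the full path from Step 2
num = len(loss)
x = [i + 1 for i in range(num)]  # 1-based step indices
print(num)

plt.figure(1)
plt.plot(x, loss)
plt.xlabel('step')
plt.ylabel('loss')
plt.show()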
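On a remote server without a display, `plt.show()` may not open a window. A minimal variation, assuming writing the figure to a PNG file is acceptable, selects matplotlib's non-interactive `Agg` backend and calls `savefig` instead. The helper name `plot_losses`, the sample values and the output name `loss.png` are hypothetical.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: render to files, no window
import matplotlib.pyplot as plt


def plot_losses(losses, out_path):
    """Hypothetical helper: plot per-step losses and save the figure to out_path."""
    steps = list(range(1, len(losses) + 1))  # 1-based x axis
    fig, ax = plt.subplots()
    ax.plot(steps, losses)
    ax.set_xlabel('step')
    ax.set_ylabel('loss')
    fig.savefig(out_path)
    plt.close(fig)  # release the figure's memory


# Hypothetical values; in practice pass the list returned by torch.load
out = os.path.join(tempfile.gettempdir(), 'loss.png')
plot_losses([0.693, 0.512, 0.401, 0.355], out)
print(os.path.exists(out))
```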