最近在复现一篇论文代码的过程中,想要输出中间图片的结果图,通过debug发现在pytorch网络中是用Tensor存储的四维张量。
1、维度顺序转换
第一维代表的是batch_size,然后是通道数和图像尺寸,首先要进行维度顺序的转换
通过permute函数实现
outputRs = outputR.permute(0,2,3,1)
shape转为96 * 128 * 3
2、转为numpy数组
#由于代码中的中间结果是带有梯度的要进行detach()操作
k = outputRs.cpu().detach().numpy()
3、根据第一维度batch_size逐个读取中间结果,并存储到磁盘中
Image需导入from PIL import Image
for i in range(10):
res = k[i] #得到batch中其中一步的图片
image = Image.fromarray(np.uint8(res)).convert('RGB')
#image.show()
#通过时间命名存储结果
timestamp = datetime.datetime.now().strftime("%M-%S")
savepath = timestamp + '_r.jpg'
image.save(savepath)