在图像分割任务中,我换了个MS-SSIM+L1损失函数,该损失函数的要求是输入两张图片,我的网络输出的是经过softmax的概率值。
为了将输出复现为图片与标签一同输入到损失中进行计算,我开始的处理是这样的:
#模型输出
outputs = self.model(images)
#复现预测图为灰度图像
pred = torch.argmax(outputs[0], dim=1).unsqueeze(1) #[1,1,H,W]
target = targets.unsqueeze(1) #[1,1,H,W]
loss_dict = dict(loss=self.criterion(pred, target))
losses = sum(loss for loss in loss_dict.values())
# xxxxxx somethings
losses.requires_grad_(True)
losses.backward()
上述train过程中使用了argmax函数,影响了losses反向传播的过程,导致训练没有任何效果(损失值固定不变)
这里的解决代码为:
#模型输出
outputs = self.model(images)
#复现预测图为灰度图像
target = F.one_hot(targets, num_classes=2).permute(0,3,1,2).float()
loss_dict = dict(loss=self.criterion(outputs[0], target))
losses = sum(loss for loss in loss_dict.values())
# xxxxxx somethings
losses.requires_grad_(True)
losses.backward()
另附:测试代码中如何把预测出的结果还原到图片
outputs = self.model(images)
pred = torch.argmax(outputs[0], dim=1).cpu().data.numpy() #[1,H,W]
pred = pred.squeeze(0) #[H,w]
#图像着色
predict = get_color_pallete(pred, "citys").convert("RGB") # PIL Image
predict.save(save_pth)
MS-SSIM+L1 loss(PyTorch):