计算和比较不同图像重建方法之间的视觉相似性,使用 LPIPS(Learned Perceptual Image Patch Similarity)度量来评估。
-
导入必要的库:
argparse
:用于命令行参数解析。os
:用于操作系统交互。numpy
:用于数值计算。torch
:PyTorch 库,用于深度学习。lpips
:用于计算图像间的视觉相似度。dipy.io.image
:用于加载 NIfTI 格式的医学图像。
-
定义
gray_to_rgb
函数:- 将灰度图像转换为三通道 RGB 图像。灰度值被复制到 RGB 图像的红色通道,其他两个通道被设置为零。
-
初始化 LPIPS 模型:
- 使用
lpips.LPIPS
初始化一个 LPIPS 模型,使用alex
网络和版本0.1
。
- 使用
-
定义数据集名称和参数:
datasets_name
:包含多个数据集的名称,这些数据集可能是不同重建方法的结果。sun_names
和pun_names
:分别代表不同的参数或条件,用于迭代处理不同的数据集。
-
循环处理每个参数组合:
- 外层循环遍历
sun_names
,内层循环遍历pun_names
。 - 对于每对参数,计算重建图像和目标图像之间的 LPIPS 距离。
dir_path
和res_path
、target_path
用于构建文件路径,从文件系统中加载重建图像和目标图像。- 根据
sun
的值提取相应的 DWI 数据切片。 - 将提取的数据转换为预处理形式,应用
gray_to_rgb
函数将灰度图像转换为 RGB 图像。 - 将 RGB 图像转换为 PyTorch 张量,并计算 LPIPS 距离。
- 累积 LPIPS 距离并计算平均 LPIPS 距离。
- 外层循环遍历
-
输出结果:
- 打印每个参数组合的平均 LPIPS 距离。
代码中有几个假设和限制:
- 假设所有图像都是单位球面上的点,且关于球心对称的两个点被视为同一点。
- 代码中的路径和文件名是硬编码的,实际使用时需要根据实际文件系统进行调整。
- 代码没有处理可能的异常,例如文件不存在或路径错误。
import argparse
import os
import numpy as np
import torch
import lpips
from dipy.io.image import load_nifti
def gray_to_rgb(image_gray):
# 创建一个全零的三通道图像
height, width = image_gray.shape
image_rgb = np.zeros((3,height, width), dtype=np.uint8)
# 将灰度图像的值复制到红通道
image_rgb[2, :,:] = image_gray
return image_rgb
## Initializing the model
loss_fn = lpips.LPIPS(net='alex', version=0.1)
# the total list of images
datasets_name = [
"all_our.nii.gz",
"all_sh1.nii.gz",
"all_q_dl.nii.gz",
"all_SR_qdl.nii.gz",
"all_PSR_qdl.nii.gz",
"all_SAR.nii.gz",
"all_rcnn_20.nii.gz",
"target_all.nii.gz",
]
sun_names=[
30,20,15,10,6
]
pun_names=[
0,1,2,3,4,5,6
]
# for k,name in enumerate(datasets_name):
# if k==0:
# continue
for i,sun in enumerate(sun_names):
total_lpips_distance = 0
average_lpips_distance = 0
for j,pun in enumerate(pun_names):
dir_path=rf'../Results/myAE_RDN/patients/{
pun}/{
sun}'
res_path = rf'{
dir_path}/{
datasets_name[0]}'
target_path = rf'{
dir_path}/{
datasets_name[7]}'
print("res_path:",res_path)
print("target_path:", target_path)
data_res, _ = load_nifti(res_path, return_img=False)
data_tar, _ = load_nifti(target_path, return_img=False)
b1 =data_res[...,0:90-sun]
b1_tar = data_tar[...,0:90-sun]
b2 =data_res[...,90:180-sun]
b2_tar = data_tar[...,90:180-sun]
b3 = data_res[...,180:270-sun]
b3_tar = data_tar[...,180:270-sun]
pre_b1 = b1
lr_b1 = b1_tar
pre_b2 = b2
lr_b2 = b2_tar
pre_b3 = b3
lr_b3 = b3_tar
print("lr_b3.shape",lr_b3.shape)
h,w,s, b = pre_b1.shape
for id in range(s):
for idt in range(b):
pre =pre_b1[:,:,id,idt]
tar = lr_b1[:,:,id,idt]
pre = gray_to_rgb(pre)
tar = gray_to_rgb(tar)
pre = torch.tensor(pre.reshape(1,pre.shape[0],pre.shape[1],pre.shape[2]))
tar = torch.tensor(tar.reshape(1,tar.shape[0],tar.shape[1],tar.shape[2]))
current_lpips_distance = loss_fn.forward(pre.to(torch.float32), tar.to(torch.float32))
total_lpips_distance = total_lpips_distance + current_lpips_distance
# print('%d: %.3f' % (j, float(total_lpips_distance) / (72*(90-sun))))
average_lpips_distance = float(total_lpips_distance) / (72*(90-sun)*7)
print("The input numble is ", sun, "and the average_lpips_distance is: %.3f" % average_lpips_distance)