计算和比较不同图像重建方法之间的视觉相似性,使用 LPIPS(Learned Perceptual Image Patch Similarity)度量来评估。

计算和比较不同图像重建方法之间的视觉相似性,使用 LPIPS(Learned Perceptual Image Patch Similarity)度量来评估。

  1. 导入必要的库

    • argparse:用于命令行参数解析。
    • os:用于操作系统交互。
    • numpy:用于数值计算。
    • torch:PyTorch 库,用于深度学习。
    • lpips:用于计算图像间的视觉相似度。
    • dipy.io.image:用于加载 NIfTI 格式的医学图像。
  2. 定义 gray_to_rgb 函数

    • 将灰度图像转换为三通道 RGB 图像。灰度值被复制到 RGB 图像的红色通道,其他两个通道被设置为零。
  3. 初始化 LPIPS 模型

    • 使用 lpips.LPIPS 初始化一个 LPIPS 模型,使用 alex 网络和版本 0.1
  4. 定义数据集名称和参数

    • datasets_name:包含多个数据集的名称,这些数据集可能是不同重建方法的结果。
    • sun_namespun_names:分别代表不同的参数或条件,用于迭代处理不同的数据集。
  5. 循环处理每个参数组合

    • 外层循环遍历 sun_names,内层循环遍历 pun_names
    • 对于每对参数,计算重建图像和目标图像之间的 LPIPS 距离。
    • dir_pathres_pathtarget_path 用于构建文件路径,从文件系统中加载重建图像和目标图像。
    • 根据 sun 的值提取相应的 DWI 数据切片。
    • 将提取的数据转换为预处理形式,应用 gray_to_rgb 函数将灰度图像转换为 RGB 图像。
    • 将 RGB 图像转换为 PyTorch 张量,并计算 LPIPS 距离。
    • 累积 LPIPS 距离并计算平均 LPIPS 距离。
  6. 输出结果

    • 打印每个参数组合的平均 LPIPS 距离。

代码中有几个假设和限制:

  • 假设所有图像都是单位球面上的点,且关于球心对称的两个点被视为同一点。
  • 代码中的路径和文件名是硬编码的,实际使用时需要根据实际文件系统进行调整。
  • 代码没有处理可能的异常,例如文件不存在或路径错误。
import argparse
import os
import numpy as np
import torch
import lpips
from dipy.io.image import load_nifti



def gray_to_rgb(image_gray):
    # 创建一个全零的三通道图像
    height, width = image_gray.shape
    image_rgb = np.zeros((3,height, width), dtype=np.uint8)

    # 将灰度图像的值复制到红通道
    image_rgb[2, :,:] = image_gray

    return image_rgb




## Initializing the model
loss_fn = lpips.LPIPS(net='alex', version=0.1)

# the total list of images






datasets_name = [
    "all_our.nii.gz",
    "all_sh1.nii.gz",
    "all_q_dl.nii.gz",
    "all_SR_qdl.nii.gz",
    "all_PSR_qdl.nii.gz",
    "all_SAR.nii.gz",
    "all_rcnn_20.nii.gz",
    "target_all.nii.gz",
]
sun_names=[
    30,20,15,10,6
]
pun_names=[
    0,1,2,3,4,5,6
]
# for k,name in enumerate(datasets_name):
#     if k==0:
#         continue
for i,sun in enumerate(sun_names):
        total_lpips_distance = 0
        average_lpips_distance = 0
        for j,pun in enumerate(pun_names):
            dir_path=rf'../Results/myAE_RDN/patients/{
      
      pun}/{
      
      sun}'
            res_path = rf'{
      
      dir_path}/{
      
      datasets_name[0]}'
            target_path = rf'{
      
      dir_path}/{
      
      datasets_name[7]}'
            print("res_path:",res_path)
            print("target_path:", target_path)

            data_res, _ = load_nifti(res_path, return_img=False)
            data_tar, _ = load_nifti(target_path, return_img=False)
            b1 =data_res[...,0:90-sun]
            b1_tar = data_tar[...,0:90-sun]
            b2 =data_res[...,90:180-sun]
            b2_tar = data_tar[...,90:180-sun]
            b3 = data_res[...,180:270-sun]
            b3_tar = data_tar[...,180:270-sun]
            pre_b1 = b1
            lr_b1 = b1_tar
            pre_b2 = b2
            lr_b2 = b2_tar
            pre_b3 = b3
            lr_b3 = b3_tar
            print("lr_b3.shape",lr_b3.shape)
            h,w,s, b = pre_b1.shape
            for id in range(s):
                for idt in range(b):
                    pre =pre_b1[:,:,id,idt]
                    tar = lr_b1[:,:,id,idt]
                    pre = gray_to_rgb(pre)
                    tar = gray_to_rgb(tar)
                    pre = torch.tensor(pre.reshape(1,pre.shape[0],pre.shape[1],pre.shape[2]))
                    tar = torch.tensor(tar.reshape(1,tar.shape[0],tar.shape[1],tar.shape[2]))
                    current_lpips_distance = loss_fn.forward(pre.to(torch.float32), tar.to(torch.float32))
                    total_lpips_distance = total_lpips_distance + current_lpips_distance

            # print('%d: %.3f' % (j, float(total_lpips_distance) / (72*(90-sun))))

        average_lpips_distance = float(total_lpips_distance) / (72*(90-sun)*7)
        print("The input numble is ", sun, "and the average_lpips_distance is: %.3f" % average_lpips_distance)

猜你喜欢

转载自blog.csdn.net/qq_44050612/article/details/143349915