图像重建Restormer介绍与使用


前言

图像恢复是计算机视觉领域中的一个重要研究方向,它旨在通过算法修复损坏、模糊或缺失的图像信息,从而恢复图像的原始质量。随着图像处理技术的不断发展,图像恢复在许多领域都发挥着重要作用,如医学影像、遥感图像、文化遗产保护等。
本次我将介绍一款在图像恢复的多个任务中表现都不错的一个网络Restormer,并介绍其环境配置与代码使用,帮助大家在实际项目中使用Restormer。


一、Restormer介绍

在计算机视觉领域,高分辨率图像恢复是一个重要的挑战。图像在采集、传输或处理过程中,往往因为各种原因受到模糊、噪声等干扰,导致图像质量下降。为了解决这个问题,研究者们提出了各种模型和技术。近年来,Transformer模型在自然语言处理和计算机视觉领域取得了巨大成功。然而,传统的Transformer模型在处理高分辨率图像时,由于其庞大的计算量和参数量,难以实现高效的处理。
为了解决这个问题,作者提出了一种高效的Transformer模型——Restormer。Restormer模型通过在构建块中进行了几个关键的设计,包括多头注意和前馈网络,使其能够捕获长程像素交互,同时仍然适用于大型图像。这种设计使得Restormer在处理高分辨率图像时,能够更高效地恢复图像质量。
Restormer模型在多个图像恢复任务中取得了最先进的结果。这些任务包括图像去模糊、单图像运动去模糊(单图像和双像素数据)和图像去噪(高斯灰度/颜色去噪和真实图像去噪)。这些结果证明了Restormer模型在图像恢复任务中的有效性。

模型结构:在这里插入图片描述
在这里插入图片描述

二、环境安装与配置

1.下载项目

2.安装虚拟环境:

conda create -n pytorch181 python=3.7
conda activate pytorch181

3.安装依赖

conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch
pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm
pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips

4.下载预训练模型

三、代码使用与效果

运行下面代码,demo.py:修改def get_weights_and_parameters(task, parameters):函数中的
elif task == ‘Single_Image_Defocus_Deblurring’:
weights = r"F:\Restormer-main\Defocus_Deblurring\pretrained_models\single_image_defocus_deblurring.pth"
将权重路径修改为自己刚才下载的预训练模型路径,运行即可

## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881

##--------------------------------------------------------------
##------- Demo file to test Restormer on your own images---------
## Example usage on directory containing several images:   python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/' --result_dir './demo/restored/'
## Example usage on a image directly: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/'
## Example usage with tile option on a large image: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/' --tile 720 --tile_overlap 32
##--------------------------------------------------------------

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import os
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse
from pdb import set_trace as stx
import numpy as np

parser = argparse.ArgumentParser(description='Test Restormer on your own images')
parser.add_argument('--input_dir', default='./demo/degraded/', type=str, help='Directory of input images or path of single image')
parser.add_argument('--result_dir', default='./demo/restored/', type=str, help='Directory for restored results')
parser.add_argument('--task', default="Single_Image_Defocus_Deblurring", type=str, help='Task to run', choices=['Motion_Deblurring',
                                                                                    'Single_Image_Defocus_Deblurring',
                                                                                    'Deraining',
                                                                                    'Real_Denoising',
                                                                                    'Gaussian_Gray_Denoising',
                                                                                    'Gaussian_Color_Denoising'])
parser.add_argument('--tile', type=int, default=None, help='Tile size (e.g 720). None means testing on the original resolution image')
parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')

args = parser.parse_args()

def load_img(filepath):
    return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)

def save_img(filepath, img):
    cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

def load_gray_img(filepath):
    return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)

def save_gray_img(filepath, img):
    cv2.imwrite(filepath, img)

def get_weights_and_parameters(task, parameters):
    if task == 'Motion_Deblurring':
        weights = os.path.join('Motion_Deblurring', 'pretrained_models', 'motion_deblurring.pth')
    elif task == 'Single_Image_Defocus_Deblurring':
        weights = r"F:\Restormer-main\Defocus_Deblurring\pretrained_models\single_image_defocus_deblurring.pth"
    elif task == 'Deraining':
        weights = os.path.join('Deraining', 'pretrained_models', 'deraining.pth')
    elif task == 'Real_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'real_denoising.pth')
        parameters['LayerNorm_type'] =  'BiasFree'
    elif task == 'Gaussian_Color_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'gaussian_color_denoising_blind.pth')
        parameters['LayerNorm_type'] =  'BiasFree'
    elif task == 'Gaussian_Gray_Denoising':
        weights = os.path.join('Denoising', 'pretrained_models', 'gaussian_gray_denoising_blind.pth')
        parameters['inp_channels'] =  1
        parameters['out_channels'] =  1
        parameters['LayerNorm_type'] =  'BiasFree'
    return weights, parameters

task    = args.task
inp_dir = args.input_dir
out_dir = os.path.join(args.result_dir, task)

os.makedirs(out_dir, exist_ok=True)

extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']

if any([inp_dir.endswith(ext) for ext in extensions]):
    files = [inp_dir]
else:
    files = []
    for ext in extensions:
        files.extend(glob(os.path.join(inp_dir, '*.'+ext)))
    files = natsorted(files)

if len(files) == 0:
    raise Exception(f'No files found at {
      
      inp_dir}')

# Get model weights and parameters
parameters = {
    
    'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}
weights, parameters = get_weights_and_parameters(task, parameters)

load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

checkpoint = torch.load(weights)
model.load_state_dict(checkpoint['params'])
model.eval()

img_multiple_of = 8

print(f"\n ==> Running {
      
      task} with weights {
      
      weights}\n ")

with torch.no_grad():
    for file_ in tqdm(files):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        if task == 'Gaussian_Gray_Denoising':
            img = load_gray_img(file_)
        else:
            img = load_img(file_)

        input_ = torch.from_numpy(img).float().div(255.).permute(2,0,1).unsqueeze(0).to(device)

        # Pad the input if not_multiple_of 8
        height,width = input_.shape[2], input_.shape[3]
        H,W = ((height+img_multiple_of)//img_multiple_of)*img_multiple_of, ((width+img_multiple_of)//img_multiple_of)*img_multiple_of
        padh = H-height if height%img_multiple_of!=0 else 0
        padw = W-width if width%img_multiple_of!=0 else 0
        input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

        if args.tile is None:
            ## Testing on the original resolution image
            restored = model(input_)
        else:
            # test the image tile by tile
            b, c, h, w = input_.shape
            tile = min(args.tile, h, w)
            assert tile % 8 == 0, "tile size should be multiple of 8"
            tile_overlap = args.tile_overlap

            stride = tile - tile_overlap
            h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
            w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
            E = torch.zeros(b, c, h, w).type_as(input_)
            W = torch.zeros_like(E)

            for h_idx in h_idx_list:
                for w_idx in w_idx_list:
                    in_patch = input_[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
                    out_patch = model(in_patch)
                    out_patch_mask = torch.ones_like(out_patch)

                    E[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch)
                    W[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch_mask)
            restored = E.div_(W)

        restored = torch.clamp(restored, 0, 1)

        # Unpad the output
        restored = restored[:,:,:height,:width]

        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])

        f = os.path.splitext(os.path.split(file_)[-1])[0]
        # stx()
        if task == 'Gaussian_Gray_Denoising':
            save_gray_img((os.path.join(out_dir, f+'.png')), restored)
        else:
            save_img((os.path.join(out_dir, f+'.png')), restored)

    print(f"\nRestored images are saved at {
      
      out_dir}")

使用效果:
原图:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
修复图:
请添加图片描述
在这里插入图片描述

在这里插入图片描述