import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y
def _positive_int(text):
    """argparse ``type`` callback: parse ``text`` as an integer that must be > 0.

    Rejecting 0 and negatives here gives a clear command-line error instead of
    a ZeroDivisionError (``--scale 0``) or a ``range()`` ValueError
    (``--stride 0``) later on.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError('expected a positive integer, got {!r}'.format(text))
    if value <= 0:
        raise argparse.ArgumentTypeError('expected a positive integer, got {!r}'.format(text))
    return value


# Command-line options; parsed once at import time into the module-level ``args``.
parser = argparse.ArgumentParser(
    description='Build an HDF5 file of aligned low-/high-resolution Y-channel patches.')
parser.add_argument('--images-dir', type=str, default='/home/radio/DS/SRCNN-pytorch-master/train_out',
                    help='directory whose files are all read as input images')
parser.add_argument('--output-path', type=str, default='/home/radio/DS/SRCNN-pytorch-master/out.h5',
                    help='HDF5 file to write (overwritten if it already exists)')
parser.add_argument('--patch-size', type=_positive_int, default=320,
                    help='side length of each square patch, in pixels')
parser.add_argument('--stride', type=_positive_int, default=320,
                    help='step between neighbouring patch origins, in pixels')
parser.add_argument('--scale', type=_positive_int, default=2,
                    help='down/up-scaling factor used to synthesize the low-resolution input')
args = parser.parse_args()
def _load_lr_hr_pair(image_path, scale):
    """Return the ``(lr, hr)`` Y-channel arrays derived from one image file.

    ``hr`` is the image (as RGB) resized with bicubic interpolation so its
    width and height are multiples of ``scale``. ``lr`` is ``hr`` downscaled
    by ``scale`` and upscaled back, so both cover the same pixel grid. Both are
    converted to float32 before being passed through ``convert_rgb_to_y``.
    """
    # Close the underlying file handle as soon as the pixels are decoded.
    with pil_image.open(image_path) as img:
        hr = img.convert('RGB')
    # Snap the size to a multiple of ``scale`` so the down/up round trip lands
    # back on exactly the same dimensions (no leftover border pixels).
    hr_width = (hr.width // scale) * scale
    hr_height = (hr.height // scale) * scale
    hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
    lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
    hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
    lr = convert_rgb_to_y(np.array(lr).astype(np.float32))
    return lr, hr


def _extract_patches(lr, hr, patch_size, stride):
    """Yield aligned ``(lr_patch, hr_patch)`` squares taken on a regular grid.

    Window origins are spaced ``stride`` pixels apart along the first two
    axes; only windows that fit entirely inside ``lr`` are produced, in
    row-major order. ``hr`` is sliced with the same indices.
    """
    for i in range(0, lr.shape[0] - patch_size + 1, stride):
        for j in range(0, lr.shape[1] - patch_size + 1, stride):
            yield (lr[i:i + patch_size, j:j + patch_size],
                   hr[i:i + patch_size, j:j + patch_size])


def train(args):
    """Tile every image in ``args.images_dir`` into LR/HR patches and save them as HDF5.

    Every file matched by ``<images_dir>/*`` is visited in sorted path order.
    Each one is turned into a (low-resolution, high-resolution) Y-channel pair
    by ``_load_lr_hr_pair`` and cut into ``args.patch_size`` squares every
    ``args.stride`` pixels. All patches are then written to
    ``args.output_path`` as the datasets ``'lr'`` and ``'hr'``. An existing
    file at that path is overwritten.

    Args:
        args: namespace providing ``images_dir``, ``output_path``,
            ``patch_size``, ``stride`` and ``scale``.

    Raises:
        ValueError: if no patch could be extracted (empty or wrong directory,
            or images smaller than ``patch_size`` after size snapping). Nothing
            is written in that case.
    """
    lr_patches = []
    hr_patches = []
    # Sorted so the patch order in the output file is reproducible.
    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        lr, hr = _load_lr_hr_pair(image_path, args.scale)
        count_before = len(lr_patches)
        for lr_patch, hr_patch in _extract_patches(lr, hr, args.patch_size, args.stride):
            lr_patches.append(lr_patch)
            hr_patches.append(hr_patch)
        print('{}: y-channel {} -> {} patches'.format(
            image_path, lr.shape, len(lr_patches) - count_before))

    if not lr_patches:
        raise ValueError(
            'no {0}x{0} patches extracted from {1!r}; check --images-dir, '
            '--patch-size and --scale'.format(args.patch_size, args.images_dir))

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)
    print('lr patches:', lr_patches.shape, 'hr patches:', hr_patches.shape)

    # Open (and truncate) the output only once all images are processed, so a
    # failure part-way through does not destroy an existing file; ``with``
    # guarantees it is closed even if writing fails.
    with h5py.File(args.output_path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=lr_patches)
        h5_file.create_dataset('hr', data=hr_patches)
if __name__ == "__main__":
    # Script entry point: build the patch file from the command-line options.
    train(args)
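# 说明:使用的图片是裁剪之后大小为320x320的图片,patch_size为320x320(默认--scale 2时每张图片恰好产生一个patch),结果写入--output-path指定的h5文件(若文件已存在则被覆盖)。
# Note: the input images are pre-cropped to 320x320 and patch_size is 320x320, so with the default --scale 2 each image yields exactly one patch; the result is written to the HDF5 file given by --output-path (overwritten if it already exists).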