概述
在上一篇博客 使用Pytorch裁剪图片并保存 中说明了如何使用Pytorch对单张图片进行裁剪并保存。在实际使用过程中可能会面临对于一整个文件夹里的数据集批量进行裁剪并以指定的文件名保存到某一路径之下,本文基于这种需求对其进行实现。
实现思路
- 创建DataSet类,实现其抽象方法,一个是支持下标索引操作的_getitem__,一个是支持获取数据长度的__len__;
- 实例化DataSet的对象dataset;
- 获取数据集的总长度;
- 根据长度遍历数据集;
- 每个循环中对图片进行裁剪并保存。
代码结构
其中,data/ 文件夹是要处理的目标文件夹数据,本示例主要有两个:crop_data_random.py
使用transforms.RandomCrop() 随机裁剪图片,crop_data_five.py
使用transforms.FiveCrop() 将每张图片裁剪为左上、左下、右上、右下和中间五张图片。
下面分别给出代码的实现方式及实现结果:
代码实现及运行结果
crop_data_random.py
import PIL.Image as Image
import os
from torch.utils.data import Dataset
from torchvision import transforms as transforms
import time
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" #设置使用GPU跑代码
class MyData(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.img_path = os.listdir(self.data_dir)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.data_dir, img_name)
img = Image.open(img_item_path)
return img
def __len__(self):
return len(self.img_path)
# pytorch提供的torchvision主要使用PIL的Image类进行处理,所以它数据增强函数大多数都是以PIL作为输入,并且以PIL作为输出。
# 读取图片
def read_PIL(image_path):
image = Image.open(image_path)
return image
# 获取读到图片的不带后缀的名称
def get_name(image):
im_path = image.filename
im_name = os.path.split(im_path) # 将路径分解为路径中的文件名+扩展名,获取到的是一个数组格式,最后一个是文件名
name = os.path.splitext(im_name[len(im_name) - 1]) # 获取不带扩展名的文件名,是数组的最后一个
return name[0] # arr[0]是不带扩展名的文件名,arr[1]是扩展名
# 将图片Reszie
def resize_img(image):
Resize = transforms.Resize(512)
resize_img = Resize(image)
return resize_img
# 随机裁剪
def random_crop(image):
RandomCrop = transforms.RandomCrop(size=(256, 256))
random_image = RandomCrop(image)
return random_image
##################################################################################################
data_dir = 'data' #数据集的路径
dataset = MyData(data_dir) #创建对象
sum = dataset.__len__() #获取数据集的总长度
# 创建输出目录
outDir = 'data_crop_random/'
os.makedirs(outDir, exist_ok=True)
start=time.time() #程序开始的时间
for i in range(0, sum):
image = dataset.__getitem__(i) #获取下标为 i 的图像
name = get_name(image) #获取图像的文件名
resie_img = resize_img(image) #将图像Resize
for cont in range(3):
random_cropped_image = random_crop(resie_img) # 随机裁剪
# random_cropped_image.show() # 显示裁剪后的图片
out_name = name + '_crop_' + str(cont) + '.png' #输出的文件名
print(out_name)
random_cropped_image.save(os.path.join(outDir, out_name)) # 按照路径保存图片
end=time.time() #程序结束的时间
print('Running time: %s Seconds'%(end-start))
运行结果
裁剪后的图片已经保按要求存到文件夹中,如下图所示:
crop_data_five.py
import PIL.Image as Image
import os
from torch.utils.data import Dataset
from torchvision import transforms as transforms
import time
os.environ["CUDA_VISIBLE_DEVICES"] = "0" #设置使用GPU跑代码
class MyData(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.img_path = os.listdir(self.data_dir)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.data_dir, img_name)
img = Image.open(img_item_path)
return img
def __len__(self):
return len(self.img_path)
# pytorch提供的torchvision主要使用PIL的Image类进行处理,所以它数据增强函数大多数都是以PIL作为输入,并且以PIL作为输出。
# 读取图片
def read_PIL(image_path):
image = Image.open(image_path)
return image
# 获取读到图片的不带后缀的名称
def get_name(image):
im_path = image.filename
im_name = os.path.split(im_path) # 将路径分解为路径中的文件名+扩展名,获取到的是一个数组格式,最后一个是文件名
name = os.path.splitext(im_name[len(im_name) - 1]) # 获取不带扩展名的文件名,是数组的最后一个
return name[0] # arr[0]是不带扩展名的文件名,arr[1]是扩展名
# 将图片Reszie
def resize_img(image):
Resize = transforms.Resize(512)
resize_img = Resize(image)
return resize_img
# 裁剪左上、左下、右上、右下和中间等 5 张
def five_crop(image):
FiveCrop = transforms.FiveCrop(size=(256, 256))
five_images = FiveCrop(image)
return five_images
##################################################################################################
data_dir = 'data' #数据集的路径
dataset = MyData(data_dir) #创建对象
sum = dataset.__len__() #获取数据集的总长度
# 创建输出目录
outDir = 'data_crop_five/'
os.makedirs(outDir, exist_ok=True)
start=time.time() #程序开始的时间
for i in range(0, sum): #遍历每张图片
image = dataset.__getitem__(i) #获取下标为 i 的图像
name = get_name(image) #获取图像的文件名
resie_img = resize_img(image) #将图像Resize
five_imgs = five_crop(resie_img) #获取截取的 5 张图片
length = len(five_imgs) #获取长度
# print(length)
for cont in range(length): #遍历获取
img = five_imgs[cont] #截取到的单张图片
out_name = name + '_crop_five_' + str(cont) + '.png' # 输出的文件名
img.save(os.path.join(outDir, out_name)) # 按照路径保存单张图片
end=time.time() # 程序结束的时间
print('Running time: %s Seconds'%(end-start))
运行结果:
裁剪后的图片已经保按要求存到文件夹中,如下图所示:
总结
上面给的示例中,文件夹中只有3张图片,但是只要上面代码码跑通,无论文件夹中有多少张图片都是可以处理的。同时,可以更根据上面的示例去用Pytorch提供的其他方法切割图片,实现数据增强。