高分辨率图像剪切——目标检测

高分辨率图像剪切


前言

最近做目标检测遇到一个问题,图像很大,几千x几千的像素,直接输入网络训练效果不好,通常目标检测输入的图像都是固定大小的,例如:yolo系列,图片输入网路前会resize到416*416。很大的图片直接做resize会丢失很多信息,因此训练出来的效果好,因此我们重叠剪切,对图像和标签同时做重叠剪切,然后输入网络训练。
在这里插入图片描述

代码

代码包含同时对标签和图像处理,标签格式是xml

import sys
import os
import cv2
import numpy as np
import os.path
from xml.dom.minidom import Document
from tqdm import tqdm
import xml.etree.cElementTree as ET


def clip_img(number, oriname):

    #------------------------------------------------------------#
    # 读入原始图像
    # 保存原图的大小
    # 可以resize也可以不resize,看情况而定
    #------------------------------------------------------------#
    from_name = os.path.join(image_path, oriname+'.png')
    img = cv2.imread(from_name)
    h_ori,w_ori, _ =img.shape
    img = cv2.resize(img, (2048, 1536))
    h, w, _ = img.shape

    #------------------------------------------------------------#
    # 输入.xml文件
    # 创建存放坐标四个值和类别的列表
    #------------------------------------------------------------#
    xml_name = os.path.join(label_xml_path, oriname+'.xml')
    xml_ori = ET.parse(xml_name).getroot()
    res = np.empty((0,5))

    #------------------------------------------------------------#
    # 找到每个.xml文件中的bbox
    # lower().strip()转化小写移除字符串头尾空格
    # vstack() 水平堆叠
    #------------------------------------------------------------#
    for obj in xml_ori.iter('object'):
        difficult = int(obj.find('difficult').text) == 1
        if difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        pts = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []
        for i, pt in enumerate(pts):
            cur_pt = int(bbox.find(pt).text) - 1
            cur_pt = int(cur_pt*h / h_ori) if i % 2== 1 else int(cur_pt * w / w_ori)
            bndbox.append(cur_pt)
     
        bndbox.append(name)
        res = np.vstack((res, bndbox))
    print('*'*5, res)

    #-------------------------------------------------------------#
    # 开始剪切 + 写入标签信息
    #-------------------------------------------------------------#
    i = 0
    win_size = 512 # 分块的大小
    stride = 128 # 重叠的大小,设置这个可以使分块有重叠
    for r in range(0, h - win_size, stride):
        for c in range(0, w - win_size, stride):
            flag =  np.zeros([1, len(res)])

            youwu = False
            xiefou = True

            tmp = img[r: r + win_size, c: c + win_size]
            for re in range(res.shape[0]):
                xmin, ymin, xmax, ymax, label = res[re]
                #------------------------------------------------#
                # 判断bb是否在当前剪切的区域内
                #------------------------------------------------#
                if int(xmin)>= c and int(xmax) <= c + win_size and int(ymin) >= r and int(ymax) <= r + win_size:
                    flag[0][re] = 1
                    youwu = True
                elif int(xmin) < c or int(xmax) > c + win_size or int(ymin) < r or int(ymax) > r + win_size:
                    pass
                else:
                    xiefou = False
                    break
            
            # 如果物体被分割了,则忽略不写入
            if xiefou:
                # 有物体则写入xml文件
                if youwu: 
                    #---------------------------------------------------#
                    # 创建.xml文件 + 写入bb
                    #---------------------------------------------------#
                    doc = Document()

                    width, height, channel = str(win_size), str(win_size), str(3)

                    annotation = doc.createElement('annotation')
                    doc.appendChild(annotation)

                    size_chartu = doc.createElement('size')
                    annotation.appendChild(size_chartu)

                    width1 = doc.createElement('width')
                    width1_text = doc.createTextNode(width)
                    width1.appendChild(width1_text)
                    size_chartu.appendChild(width1)

                    height1 = doc.createElement('height')
                    height1_text = doc.createTextNode(height)
                    height1.appendChild(height1_text)
                    size_chartu.appendChild(height1)

                    channel1 = doc.createElement('channel')
                    channel1_text = doc.createTextNode(channel)
                    channel1.appendChild(channel1_text)
                    size_chartu.appendChild(channel1)

                    for re in range(res.shape[0]):

                        xmin, ymin, xmax, ymax, label = res[re]

                        xmin=int(xmin)
                        ymin=int(ymin)
                        xmax=int(xmax)
                        ymax=int(ymax)

                        if flag[0][re] == 1:

                            xmin=str(xmin-c)
                            ymin=str(ymin-r)
                            xmax=str(xmax-c)
                            ymax=str(ymax-r)

                            object_charu = doc.createElement('object')
                            annotation.appendChild(object_charu)

                            name_charu = doc.createElement('name')
                            name_charu_text = doc.createTextNode(label)
                            name_charu.appendChild(name_charu_text)
                            object_charu.appendChild(name_charu)

                            dif = doc.createElement('difficult')
                            dif_text = doc.createTextNode('0')
                            dif.appendChild(dif_text)
                            object_charu.appendChild(dif)

                            bndbox = doc.createElement('bndbox')
                            object_charu.appendChild(bndbox)

                            xmin1 = doc.createElement('xmin')
                            xmin_text = doc.createTextNode(xmin)
                            xmin1.appendChild(xmin_text)
                            bndbox.appendChild(xmin1)

                            ymin1 = doc.createElement('ymin')
                            ymin_text = doc.createTextNode(ymin)
                            ymin1.appendChild(ymin_text)
                            bndbox.appendChild(ymin1)

                            xmax1 = doc.createElement('xmax')
                            xmax_text = doc.createTextNode(xmax)
                            xmax1.appendChild(xmax_text)
                            bndbox.appendChild(xmax1)

                            ymax1 = doc.createElement('ymax')
                            ymax_text = doc.createTextNode(ymax)
                            ymax1.appendChild(ymax_text)
                            bndbox.appendChild(ymax1)

                        else:
                            continue
                    xml_name = oriname+'_%3d.xml' % (i)
                    to_xml_name = os.path.join(lablel_xml_path, xml_name)
                    with open(to_xml_name, 'wb+') as f:
                        f.write(doc.toprettyxml(indent="\t", encoding='utf-8'))
                    #name = '%02d_%02d_%02d_.bmp' % (number, int(r/win_size), int(c/win_size))
                    img_name = oriname+'_%3d.png' %(i)
                    to_name = os.path.join(image_crop_path, img_name)
                    i = i+1
                    cv2.imwrite(to_name, tmp)

if __name__ == "__main__":

    image_path = r'E:\wcs\neural_new\neural_new\test\images'
    label_xml_path = r'E:\wcs\neural_new\neural_new\test\Annotations'

    image_crop_path = 'E:\\wcs\\cell\\test\\image\\'
    lablel_xml_path = 'E:\\wcs\\cell\\test\\label\\'

    if not os.path.exists(image_crop_path):
        os.makedirs(image_crop_path)
    if not os.path.exists(lablel_xml_path):
        os.makedirs(lablel_xml_path)

    for i, name in tqdm(enumerate(os.listdir(image_path))):
        clip_img(i, name.rstrip('.png'))


参考文献:https://blog.csdn.net/m0_37615398/article/details/84982384

猜你喜欢

转载自blog.csdn.net/CharmsLUO/article/details/124109179