数据扩充:旋转、翻转、明暗、高斯模糊

'''
Function: this script is designed for hand train data augmentation
Method  : rotate/flip/bright/blur
Author  : xiakj
Date    : 2019/7/26
'''

import os
import numpy as np
from PIL import Image
from PIL import ImageEnhance
import random
import cv2

def Gaussblur(imgpath, new_img_path, imgname, gaussNum):
    """Write ``gaussNum`` Gaussian-blurred copies of one image.

    Each copy uses a randomly chosen odd kernel size from {3, 5, 7, 9};
    the larger the kernel, the blurrier the result. Sigma is derived from
    the kernel size by OpenCV (sigmaX=sigmaY=0).

    Args:
        imgpath: directory holding the source image. It is joined to
            ``imgname`` by plain string concatenation, so it must end with
            a path separator (e.g. ``'./train/8/'``).
        new_img_path: output directory, same trailing-separator convention.
        imgname: file name of the source image, e.g. ``'0.jpg'``.
        gaussNum: number of blurred copies to write. Outputs are named
            ``g<i>_<imgname>`` for ``i`` in ``range(gaussNum)``.

    Raises:
        OSError: if the source image cannot be decoded or an output file
            cannot be written.
    """
    if gaussNum <= 0:
        return
    src = imgpath + imgname
    img = cv2.imread(src)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise OSError('could not read image: {}'.format(src))
    for i in range(gaussNum):
        # GaussianBlur needs an odd kernel size. randrange excludes its stop
        # value, so stop=10 is required for 9 to be a possible choice.
        scale = random.randrange(3, 10, 2)
        img_ = cv2.GaussianBlur(img, ksize=(scale, scale), sigmaX=0, sigmaY=0)
        fullpath = '{}{}{}{}{}'.format(new_img_path, 'g', str(i), '_', imgname)
        # cv2.imwrite returns False (e.g. missing directory) instead of raising.
        if not cv2.imwrite(fullpath, img_):
            raise OSError('could not write image: {}'.format(fullpath))

def BrightEnhance(imgpath, imgname, briNum, new_img_path=None):
    """Write ``briNum`` randomly brightened or darkened copies of one image.

    Each copy scales brightness by a factor drawn uniformly from [0.5, 1.5]
    (a factor below 1 darkens, above 1 brightens). The dark/bright mix is
    therefore random rather than split evenly.

    Args:
        imgpath: directory holding the source image. It is concatenated
            directly with ``imgname``, so it must end with a path separator
            (e.g. ``'./train/8/'``).
        imgname: file name of the source image, e.g. ``'0.jpg'``.
        briNum: number of copies to write. Outputs are named
            ``b<i>_<imgname>`` for ``i`` in ``range(briNum)``.
        new_img_path: output directory (same trailing-separator convention).
            Defaults to ``imgpath`` for backward compatibility, which writes
            the copies next to the source image. Pass an explicit output
            directory to keep the source folder untouched.
    """
    out_dir = imgpath if new_img_path is None else new_img_path
    # The context manager closes the file handle Image.open keeps open;
    # all enhancement and saving happens while the image is still open.
    with Image.open(imgpath + imgname) as img:
        enh_bri = ImageEnhance.Brightness(img)
        for i in range(briNum):
            scale1 = random.uniform(0.5, 1.5)
            img_bri = enh_bri.enhance(scale1)
            fullpath = '{}{}{}{}{}'.format(out_dir, 'b', str(i), '_', imgname)
            img_bri.save(fullpath)

def rotateImg(img, cols, rows, angle, scale=1.0):
    """Rotate an image about the centre of a ``cols`` x ``rows`` frame.

    Args:
        img: image array as returned by ``cv2.imread``. It is not modified.
        cols: output width in pixels (normally the image width).
        rows: output height in pixels (normally the image height).
        angle: rotation angle in degrees. Positive values rotate
            counter-clockwise (OpenCV convention, origin at top-left).
        scale: isotropic scale applied together with the rotation. The
            default of 1.0 keeps the original size; values above 1 can hide
            the empty corners a pure rotation leaves.

    Returns:
        The rotated image of size ``cols`` x ``rows``. Regions rotated in
        from outside the source are filled with warpAffine's default
        constant border (zeros).
    """
    # (n - 1) / 2 is the centre of an n-pixel axis in pixel-index coordinates.
    center = ((cols - 1) / 2.0, (rows - 1) / 2.0)
    M = cv2.getRotationMatrix2D(center, angle, scale)
    # warpAffine writes to a new array and never touches its source,
    # so a defensive copy of img is unnecessary.
    return cv2.warpAffine(img, M, (cols, rows))


if __name__ == '__main__':
    # Script entry: augment every image in img_path into new_img_path, then
    # move every 6th resulting file into val_path as validation data.
    import shutil  # byte-level copy for the validation split (no re-encoding)

    # #### ******************* for train data augmentation ********************
    brightNum = 2     # random-brightness copies per image
    blurNum = 2       # Gaussian-blurred copies per image
    rotNum = 2        # randomly rotated copies per image
    useFlip = False   # horizontal-flip augmentation (disabled, as originally)
    img_path = r'./train/8/'          # source images (only read, never written)
    new_img_path = r'./new_train/8/'  # augmented training output
    val_path = r'./new_val/8/'        # validation files moved out of new_img_path
    count = 0
    re_size = 128    # resize scale
    # cv2.imwrite silently returns False when the target directory is missing.
    os.makedirs(new_img_path, exist_ok=True)
    os.makedirs(val_path, exist_ok=True)
    ## ************** step1: read image and process ***********
    for filename in sorted(os.listdir(img_path)):
        count = count + 1
        print('train #%d' % count)
        # read image; cv2.imread returns None for non-image/unreadable entries
        rawPath = img_path + filename
        img = cv2.imread(rawPath)
        if img is None:
            print('skip unreadable file: %s' % rawPath)
            continue
        img = cv2.resize(img, (re_size, re_size), interpolation=cv2.INTER_LINEAR)
        rows, cols = img.shape[:2]
        # raw (resized) copy; the blur and brightness steps below read it back
        new_name = new_img_path + filename
        if not cv2.imwrite(new_name, img):
            raise OSError('failed to write %s' % new_name)
        # aug1: rotate image
        for i in range(rotNum):
            rotAngle = random.randint(-30, 30)
            out_img = rotateImg(img, cols, rows, rotAngle)
            new_name = '{}{}{}{}{}'.format(new_img_path, 'r', str(i), '_', filename)
            cv2.imwrite(new_name, out_img)
        # aug2: horizontal flip
        if useFlip:
            out_img = cv2.flip(img, 1)
            new_name = new_img_path + 'f_' + filename
            cv2.imwrite(new_name, out_img)
        # aug3: gaussian blur. Read the resized copy from new_img_path; reading
        # from img_path would blur the full-size original and mix image sizes.
        Gaussblur(new_img_path, new_img_path, filename, blurNum)
        # aug4: bright. BrightEnhance writes next to the image it reads, so
        # pointing it at new_img_path keeps its outputs resized, inside the
        # training set, and out of the source folder img_path.
        BrightEnhance(new_img_path, filename, brightNum)

    #### ******************* for val data select and save ********************
    count = 0
    count_val = 0
    # sorted() makes the every-6th-file split reproducible across runs
    for filename in sorted(os.listdir(new_img_path)):
        count = count + 1
        print('val #%d' % count)
        if count % 6 == 0:
            count_val = count_val + 1
            # Copy bytes rather than imread/imwrite so JPEGs are not
            # recompressed, then remove the file from the training data.
            shutil.copyfile(new_img_path + filename, val_path + filename)
            os.remove(new_img_path + filename)
    print('num(train)=%d num(val)=%d' % (count - count_val, count_val))















发布了59 篇原创文章 · 获赞 57 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/xiakejiang/article/details/97373847
今日推荐