对图像进行随机擦除(RandomErasing)的实现代码

对应论文:《Random Erasing Data Augmentation》
论文作者提供的原始 GitHub 代码地址:
https://github.com/zhunzhong07/Random-Erasing

class RandomErasing(object):
    """Randomly erase a rectangular patch of an image (Zhong et al.).

    Implements the Random Erasing augmentation from "Random Erasing Data
    Augmentation" by Zhong et al. On each call, with probability
    ``probability``, a rectangle of random area and aspect ratio is placed
    at a random position and its pixels are overwritten with ``mean``.

    Args:
        probability: Probability that erasing is performed on a call.
        sl: Minimum erased area, as a fraction of the image area.
        sh: Maximum erased area, as a fraction of the image area.
        r1: Minimum aspect ratio (height / width) of the erased patch. The
            ratio is sampled uniformly from ``[r1, 1 / r1]``, so ``r1`` must
            be non-zero.
        mean: Fill values. ``mean[0]``, ``mean[1]`` and ``mean[2]`` fill the
            three channels of a 3-channel image. Only ``mean[0]`` is used
            otherwise.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
                 mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        # A tuple default avoids the shared-mutable-default pitfall. Callers
        # may still pass any indexable sequence (e.g. a list).
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        """Possibly erase one random rectangle of ``img`` in place.

        Args:
            img: Image indexed as ``img[channel, row, col]`` with
                ``img.size()[0..2]`` giving (channels, height, width).
                Presumably a ``torch.Tensor`` in CHW layout (confirm
                against callers).

        Returns:
            The same ``img`` object. It is modified in place when a patch is
            erased. It is returned unchanged when erasing is skipped or no
            patch fits within 100 attempts.
        """
        import math
        import random

        if random.uniform(0, 1) > self.probability:
            return img

        size = img.size()
        channels, height, width = size[0], size[1], size[2]
        # Loop-invariant: compute the image area once, not per attempt.
        area = height * width

        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            # h / w == aspect_ratio and h * w ~= target_area.
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            # Resample unless the patch fits strictly inside the image.
            if w < width and h < height:
                top = random.randint(0, height - h)
                left = random.randint(0, width - w)
                if channels == 3:
                    for c in range(3):
                        img[c, top:top + h, left:left + w] = self.mean[c]
                else:
                    # NOTE: for any channel count other than 3, only
                    # channel 0 is erased (matches the reference code).
                    img[0, top:top + h, left:left + w] = self.mean[0]
                return img

        # No patch fit within 100 attempts: leave the image untouched.
        return img
发布了36 篇原创文章 · 获赞 1 · 访问量 6384

猜你喜欢

转载自blog.csdn.net/qq_34291583/article/details/103302310