数据集均值和方差的计算方法,以及常用数据集的均值和方差

学习目标:

分享一下在分类任务中,数据集均值和方差的计算方法,以及常用数据集(Cifa10、Cifa100)的均值和方差


计算方法:

精确值是通过分别计算R,G,B三个通道的数据算出来的, 比如你有2张图片,都是100100大小的,那么两图片的像素点共有2100*100 = 20 000 个; 那么这两张图片的

  1. 均值的求法: mean_R: 这20000个像素点的R值加起来,除以像素点的总数,这里是20000;mean_G 和mean_B 两个通道 的计算方法 一样的。

  2. 标准差的求法:
    在这里插入图片描述

代码:

from itertools import repeat
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
from PIL import Image
import numpy as np
from tqdm import tqdm

NUM_THREADS = os.cpu_count()


def calc_channel_sum(img_path):  # 计算均值的辅助函数,统计单张图像颜色通道和,以及像素数量
    img = np.array(Image.open(img_path).convert('RGB')) / 255.0  # 准换为RGB的array形式
    h, w, _ = img.shape
    pixel_num = h * w
    channel_sum = img.sum(axis=(0, 1))  # 各颜色通道像素求和
    return channel_sum, pixel_num


def calc_channel_var(img_path, mean):  # 计算标准差的辅助函数
    img = np.array(Image.open(img_path).convert('RGB')) / 255.0
    channel_var = np.sum((img - mean) ** 2, axis=(0, 1))
    return channel_var

def mean_and_var(data_path,data_format='*.png',decimal_places=4):
    """
    计算均值方差
    @param data_path: 数据集路径
    @param data_format: 图片格式(默认为png)
    @param decimal_places: 均值和方差,保留的小数位数(默认为4)
    @return:
    """
    print("Data root is ",data_path)
    train_path = Path(data_path)
    img_f = list(train_path.rglob(data_format))
    n = len(img_f)
    print(f'Data Nums is : {n}')
    print("Calculate the mean value")
    result = ThreadPool(NUM_THREADS).imap(calc_channel_sum, img_f)  # 多线程计算
    channel_sum = np.zeros(3)
    cnt = 0
    pbar = tqdm(enumerate(result), total=n)
    for i, x in pbar:
        channel_sum += x[0]
        cnt += x[1]
    mean = channel_sum / cnt

    mean=np.around(mean, decimal_places)  # 使用around()函数保留小数位数
    print('R_mean,G_mean,B_mean is ',mean)

    print("Calculate the var value")
    result = ThreadPool(NUM_THREADS).imap(lambda x: calc_channel_var(*x), zip(img_f, repeat(mean)))
    channel_sum = np.zeros(3)
    pbar = tqdm(enumerate(result), total=n)
    for i, x in pbar:
        channel_sum += x
    var = np.sqrt(channel_sum / cnt)
    var = np.around(var, decimal_places)  # 使用around()函数保留小数位数
    print('R_var,G_var,B_var is ', var)


if __name__ == '__main__':
    mean_and_var('/home/yangzhanshan/disk/datasets/cifa10/val')

常用数据集均值和方差:

cifa10 train

cifa10_mean=[0.4914, 0.4822, 0.4465]
cifa10_var=[0.2023, 0.1994, 0.2010]

cifa100-train

cifa100_mean=[0.5071, 0.4865, 0.4408]
cifa100_var=[0.2675, 0.2565, 0.2761]

imagenet

imagenet_mean=[0.485, 0.456, 0.406]
imagenet_std=[0.229, 0.224, 0.225]

猜你喜欢

转载自blog.csdn.net/qq_41823532/article/details/128969904
今日推荐