Computing the Mean and Standard Deviation of an Image Dataset with PyTorch

1. Implementation

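When preprocessing data with PyTorch, images are usually standardized with torchvision.transforms.Normalize(mean, std), where mean and std are sequences holding the mean and standard deviation of each channel of the image dataset. These statistics therefore have to be computed from the dataset beforehand.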
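For each channel c, Normalize computes (input[c] - mean[c]) / std[c]. The sketch below reproduces that per-channel arithmetic in plain Python on a nested list shaped [channels][height][width]. It is only an illustration: normalize_image is a hypothetical helper, not a torchvision function, and real code would call transforms.Normalize on a tensor.

```python
from typing import List, Sequence

Image = List[List[List[float]]]  # [channels][height][width]


def normalize_image(img: Image, mean: Sequence[float], std: Sequence[float]) -> Image:
    """Hypothetical helper (not torchvision): apply (x - mean[c]) / std[c] to channel c."""
    if not (len(img) == len(mean) == len(std)):
        raise ValueError("need one mean and one std per channel")
    return [
        [[(x - m) / s for x in row] for row in channel]
        for channel, m, s in zip(img, mean, std)
    ]


# A 2-channel, 1x2 "image" with values already scaled to [0, 1], as ToTensor does
img = [[[0.0, 1.0]], [[0.5, 0.5]]]
print(normalize_image(img, mean=[0.5, 0.5], std=[0.5, 0.25]))
# → [[[-1.0, 1.0]], [[0.0, 0.0]]]
```

Each channel gets its own mean and std, which is why the statistics are computed per channel rather than over the whole image.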
First, here are the mathematical definitions of mean and std.
Given a dataset $X_i,\ i\in\{1,2,\cdots,n\}$, its mean is

$$
\text{mean}=\frac{\sum_{i=1}^{n}X_i}{n}\tag{1}
$$

The mean is usually denoted $\overline{X}$.
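For an image dataset, $X_i$ in Eq. (1) runs over every pixel of one channel across all images, so for data shaped [N, C, H, W] the mean of each channel averages $N\cdot H\cdot W$ values. In PyTorch this corresponds to reducing over dims 0, 2 and 3 while keeping dim 1, e.g. torch.mean(data, dim=[0, 2, 3]) for a single batch. A dependency-free sketch of the same reduction on nested lists (channel_means is a hypothetical helper name):

```python
from typing import List

Batch = List[List[List[List[float]]]]  # [N][C][H][W]


def channel_means(batch: Batch) -> List[float]:
    """Hypothetical helper: per-channel mean over images, rows and columns (dims 0, 2, 3)."""
    num_channels = len(batch[0])
    sums = [0.0] * num_channels
    counts = [0] * num_channels
    for image in batch:                      # dim 0: images
        for c, channel in enumerate(image):  # dim 1: channels (kept)
            for row in channel:              # dim 2: height
                sums[c] += sum(row)          # dim 3: width
                counts[c] += len(row)
    return [s / n for s, n in zip(sums, counts)]


# Two 2-channel images of size 1x2
batch = [
    [[[0.0, 1.0]], [[1.0, 1.0]]],
    [[[0.0, 0.0]], [[0.0, 1.0]]],
]
print(channel_means(batch))  # → [0.25, 0.75]
```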
The standard deviation of the dataset is

$$
\begin{aligned}
\text{std} &= \sqrt{\frac{\sum_{i=1}^{n}\left(X_i-\overline{X}\right)^2}{n}} \\
&= \sqrt{\frac{\sum_{i=1}^{n}\left(X_i^2-2X_i\overline{X}+\overline{X}^2\right)}{n}} \\
&= \sqrt{\frac{\left(\sum_{i=1}^{n}X_i^2\right)-n\overline{X}^2}{n}} \\
&= \sqrt{\frac{\sum_{i=1}^{n}X_i^2}{n}-\overline{X}^2}
\end{aligned}\tag{2}
$$

The third line uses $\sum_{i=1}^{n}X_i=n\overline{X}$: the cross term sums to $-2n\overline{X}^2$, and adding the $n\overline{X}^2$ from the last term leaves $-n\overline{X}^2$. The final form shows that the standard deviation needs only the mean of the values and the mean of their squares, both of which can be accumulated in a single pass over the data.

The following function computes the mean and standard deviation of each channel of an image dataset:

import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader

batch_size = 64

# Training set (CIFAR-10 used as an example)
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

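def get_mean_std_value(loader):
    '''
    Compute the per-channel mean and standard deviation of a dataset
    :param loader: DataLoader yielding (data, label), data of shape [batch_size,channels,height,width]
    :return: (mean, std), each a tensor of shape [channels]
    '''
    data_sum,data_squared_sum,num_pixels = 0,0,0

    for data,_ in loader:
        # data: [batch_size,channels,height,width]; accumulate in float64 to limit rounding error
        data = data.double()
        # Sum over dims 0,2,3; dim=1 is the channel dimension and is kept
        data_sum += torch.sum(data,dim=[0,2,3])    # [channels]
        # Sum of squares over dims 0,2,3
        data_squared_sum += torch.sum(data**2,dim=[0,2,3])  # [channels]
        # Count pixels per channel (not batches), so a smaller final batch is weighted correctly
        num_pixels += data.size(0)*data.size(2)*data.size(3)
    # Mean, Eq. (1)
    mean = data_sum/num_pixels
    # Standard deviation, Eq. (2): sqrt(E[X^2] - mean^2)
    std = (data_squared_sum/num_pixels - mean**2)**0.5
    return mean.float(),std.float()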

mean,std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean,std))
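Eq. (2) is what allows the statistics to be accumulated batch by batch: only $\sum X_i$, $\sum X_i^2$ and the count $n$ have to be carried across batches. The count must be the number of values (pixels), not the number of batches. Averaging per-batch means gives the dataset mean only when every batch has the same size, which fails whenever the dataset size is not a multiple of batch_size (CIFAR-10's 50,000 training images leave a final batch of 16 with batch_size = 64 and the DataLoader default drop_last=False). The plain-Python sketch below checks this on one channel split into uneven batches, comparing against statistics.pstdev, which, like Eq. (2), divides by $n$. Note that torch.std divides by $n-1$ by default, although over millions of pixels the difference is negligible. The function names are hypothetical.

```python
import math
import statistics
from typing import List, Sequence, Tuple


def one_pass_mean_std(batches: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Hypothetical helper: mean and population std via Eq. (2), summing over batches."""
    total, total_sq, count = 0.0, 0.0, 0
    for batch in batches:
        total += sum(batch)
        total_sq += sum(x * x for x in batch)
        count += len(batch)  # count values, not batches
    mean = total / count
    # max(..., 0.0) guards against a tiny negative difference caused by rounding
    std = math.sqrt(max(total_sq / count - mean * mean, 0.0))
    return mean, std


def mean_of_batch_means(batches: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: the unweighted alternative, averaging each batch's mean."""
    return statistics.fmean(statistics.fmean(b) for b in batches)


# One channel's values: a full batch of 4 followed by a short final batch of 2
batches: List[List[float]] = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0]]
flat = [x for b in batches for x in b]

mean, std = one_pass_mean_std(batches)
print(mean, statistics.fmean(flat))  # both ≈ 0.3333
print(std, statistics.pstdev(flat))  # both ≈ 0.4714
print(mean_of_batch_means(batches))  # 0.5: the short batch is over-weighted
```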

Running the function on the CIFAR-10 training set gives the following mean and standard deviation:

mean = tensor([0.4914, 0.4821, 0.4465]),std = tensor([0.2470, 0.2435, 0.2616])

For the MNIST dataset, the mean and standard deviation are:

mean = tensor([0.1307]),std = tensor([0.3081])
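These values are what get passed to transforms.Normalize. As a quick sanity check of what they imply: ToTensor scales pixels to [0, 1], so a channel normalized with mean $m$ and std $s$ spans $[-m/s,\ (1-m)/s]$. The sketch below evaluates that range from the statistics listed above (normalized_range is a hypothetical helper):

```python
from typing import List, Sequence, Tuple


def normalized_range(mean: Sequence[float], std: Sequence[float]) -> List[Tuple[float, float]]:
    """Hypothetical helper: range of (x - m) / s per channel for inputs x in [0, 1]."""
    return [(round(-m / s, 4), round((1.0 - m) / s, 4)) for m, s in zip(mean, std)]


# Statistics reported above
cifar10_mean, cifar10_std = [0.4914, 0.4821, 0.4465], [0.2470, 0.2435, 0.2616]
mnist_mean, mnist_std = [0.1307], [0.3081]

print(normalized_range(cifar10_mean, cifar10_std))
print(normalized_range(mnist_mean, mnist_std))  # → [(-0.4242, 2.8215)]
```

After normalization the values are centred near 0 with unit spread, but they are not confined to [-1, 1]; that only happens with mean = std = 0.5.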



Reprinted from blog.csdn.net/weixin_43821559/article/details/123459085