[论文笔记] LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS

ICLR2019在审文章,作者单位DeepMind


总述

文章希望既保证GAN生成图像的保真度又保证生成图像的多样性。对此,作者认为他们有三点贡献:
1、论证了GANs能通过scaling来提升性能。他们使用与原先技术相比,2~4倍的参数量和8倍的batch size,引入了两种简单的结构调整方法来提升网络的可扩展性,并修改一种正则化方案来提高conditioning。
2、上述修改产生的另一种影响是,模型非常适用于“trucation trick”,它是一种简单的采样技术,可以对样本多样性与保真性进行外部地细粒度地调节。
3、发现大型GAN特有的不稳定性,并从经验上对他们进行描述。经过分析表明通过现有技术与创新的技术的结合能够减少这种不稳定性,但是训练时完全的稳定性只能通过以较大地牺牲模型表现来实现。

作者训练的class-condition GAN在ImageNet上的表现很好(128X128分辨率),与state-of-art相比,Inception Score(IS)从52.52提升到166.3,Frechet Inception Distance(FID)从18.65下降到9.6.

Scaling up GANs

Baseline 模型

基于SA-GAN结构,使用hinge-loss作为GAN的目标函数。使用class-conditional BN向生成器G中加入类别信息,用projection向鉴别器D中加入类别信息。优化设置与原SA-GAN论文相同,但学习率减半,D每更新两次,G更新一次。对G的权重采用滑动平均(但文章发现progressive learning对模型并不必要)。不同于其他论文使用 N ( 0 , 0.02 I ) \mathcal{N}(0,0.02I) 或Xavier进行初始化,本文使用正交初始化。BN层的统计值是基于所有设备上的统计,不同于标准实现采用基于每个设备的统计。

A. 加大 BATCH SIZE

基于此模型,作者发现,将batch size提高为原来8倍,IS分数提升约46%.大的batchsize一方面提高模型表现,使模型更快收敛;另一方面,作者发现,这种scaling使得模型更不稳定,训练中很容易collapse。

B. 提高通道数

接着,作者尝试将模型中每层的通道数提高50%,参数量翻番,这使得IS分数进一步提升21%。

C. 共享嵌入层

作者还发现,条件BN中嵌入类别c占用了很多的权重,文章于是采用共享的嵌入来取代独立的层嵌入。这降低了内存与计算成本,模型训练速度提高37%。

D. 多层级潜在空间

此外,作者使用了多种hierarchical latent spaces,即将噪声向量 z z 输入到生成器的不同层中,而不是仅仅输入到第一层。这种做的直觉思路是用潜在空间来直接影响不同分辨率以及不同层次下的特征。hierarchical latents降低了计算量与内存占用,模型表现提升4%,训练速度提高18%。

E. 截断技巧

一般的噪声向量服从分布 z N ( 0 , I ) z\sim\mathcal{N}(0,I) ,但该技巧为其采样设置一个阈值,当采样超过该阈值时,重新采样,以使得采样点落入阈值范围。减小该阈值会发现,GAN生成的图像多样性降低,质量提高。如下图所示,从左到右为逐渐降低阈值。
在这里插入图片描述
作者在此处将IS类比为precision,FID类比为recall,通过改变截断的阈值,做出FID-IS曲线如下。阈值减小,多样性下降,质量提高,IS对多样性并不敏感,而FID对多样性和质量都敏感。所以可以看到,最初FID会有提高,但当阈值越来越小时,模型多样性下降,FID急剧下降。

FID-IS曲线图

直接使用截断技巧对很多模型来说是有问题的,会导致saturation artifacts,如下图所示:


在这里插入图片描述

为解决这个问题,作者希望通过限制G变得更平滑来使得 z z 的全部空间能投射到好的输出样本上。作者尝试使用正交正则化,即直接应用正交条件:
R β ( W ) = β W T W I F 2 R_\beta(W)=\beta||W^TW-I||_F^2
其中 W W 是权值矩阵, β \beta 是超参数。但是这个正则化被认为太过于limiting,因此作者使用了该正则化的改进形式:
R β ( W ) = β W T W ( 1 I ) F 2 R_\beta(W)=\beta||W^TW\odot(1-I)||_F^2
其中 1 1 表示元素全为1的矩阵。
作者发现,不使用正交正则化,仅有16%的模型可以截断;使用正交正则化后,60%的模型可以被截断。

上述各种改进的效果对比如下表所示:
从左到右依次是Batch size,通道数,参数量,共享嵌入层,多层级潜在空间,正交正则,迭代次数,FID,IS分数。
在这里插入图片描述

Scaling导致的模型不稳定性分析

生成器G

鉴别器D

评价指标

常用评价指标,用来判断GAN生成的图片的质量好坏。下面给出其定义,计算方式以及代码。

Inception Score (IS)

最初在Improved Techniques for Training GANs (2016)一文中提出。将GAN生成的图像输出到Inception模型中,得到条件标签分布 p ( y x ) p(y|x) 。包含有意义目标的图像的 p ( y x ) p(y|x) 熵值会较小;此外,我们还希望GAN模型能产生更多样的图像,因此 p ( y x = G ( z ) ) d z \int{p(y|x=G(z))dz} 应该有较高的边际熵。综合这两点,提出metric如下:
e x p ( E x K L ( p ( y x ) p ( y ) ) ) exp(\mathbb{E}_\mathbf{x}\mathbf{KL}(p(y|x)||p(y)))

指数形式使得值更方便比较。实际写代码的时候, p ( y x ) p(y|x) 就是每张图输入到Inception的输出,而 p ( y ) p(y) 就是所有图的Inception输出均值。
pytorch代码如下:

#https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data

from torchvision.models.inception import inception_v3

import numpy as np
from scipy.stats import entropy

def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs
    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval();
    up = nn.Upsample(size=(299, 299), mode='bilinear',align_corners=True).type(dtype)
    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        return F.softmax(x).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]

        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py)) #calculate KL-div using entropy(a,b)
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

if __name__ == '__main__':
    class IgnoreLabelDataset(torch.utils.data.Dataset):
        def __init__(self, orig):
            self.orig = orig

        def __getitem__(self, index):
            return self.orig[index][0]

        def __len__(self):
            return len(self.orig)

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    cifar = dset.CIFAR10(root='data/', download=True,
                             transform=transforms.Compose([
                                 transforms.Scale(32),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                             ])
    )

    IgnoreLabelDataset(cifar)

    print ("Calculating Inception Score...")
    print (inception_score(IgnoreLabelDataset(cifar), cuda=True, batch_size=32, resize=True, splits=10))

计算得到cifar-10图片集Inception Score 分数均值约为9.3701,方差为0.1496。
Inception Score 本身有许多缺陷,具体见这篇文章

Fréchet Inception Distance (FID)

Inception Score的一个缺点是它没有用到真实世界样本的统计值来和生成样本作比较。用 p ( ) p(\cdot) 表示生成模型产生的样本分布, p ω ( ) p_\omega(\cdot) 表示真实样本的分布。当给定均值与方差时,高斯分布是熵最大的分布。两个高斯分布的距离用Fréchet distance来度量。均值方差为 ( m , C ) (m,C) 的高斯分布 p ( ) p(\cdot) 和均值方差为 ( m ω , C ω ) (m_\omega,C_\omega) 的高斯分布 p ω ( ) p_\omega(\cdot) 的Fréchet 距离 d(.,.) 被定义为Fréchet Inception Distance(FID),由下式给出:
d 2 ( ( m , C ) , ( m ω , C ω ) ) = m m ω 2 2 + T r ( C + C ω 2 ( C C ω ) 1 / 2 ) d^2((m,C),(m_\omega,C_\omega))=||m-m_\omega||_2^2+\mathbf{Tr}(C+C_\omega-2(CC_\omega)^{1/2})

为计算FID,类似于将图像输入到Inception模型中得到Inception Score,不同的是,FID使用最后一个pooling层作为编码层,对这个编码层来计算均值 m ω m_\omega 和协方差 C ω C_\omega
代码如下:

#https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on an 
               representive data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on an 
               representive data set.
    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(images, model, batch_size=64,
                                    dims=2048, cuda=False, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
                     must lie between 0 and 1.
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : If set to True and parameter out_step is given, the
                     number of calculated batches is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(images, model, batch_size, dims, cuda, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
    if path.endswith('.npz'):
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
        f.close()
    else:
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))

        imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])

        # Bring images to shape (B, 3, H, W)
        imgs = imgs.transpose((0, 3, 1, 2))

        # Rescale images to be between 0 and 1
        imgs /= 255

        m, s = calculate_activation_statistics(imgs, model, batch_size,
                                               dims, cuda)

    return m, s


def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    """Calculates the FID of two paths"""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx])
    if cuda:
        model.cuda()

    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
                                         dims, cuda)
    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
                                         dims, cuda)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value

猜你喜欢

转载自blog.csdn.net/qq_26020233/article/details/83004755