深度学习中Dropout浅析

关于Dropout这个概念,之前在学习深度学习的那本花书中遇到过,在项目中通过Caffe、Keras也用过,但是没有仔细研究过它是做什么的,而仅仅是拿过来用,最近看了一些相关的博客,想总结一下。

Dropout的作用是为了防止过拟合的,特别是当训练集数据量较小的时候,这就会导致其泛化能力大大下降。而加上Dropout的操作,其作用是让某些神经元失去作用,即在每次训练的时候,每个神经元有一定的几率被移除。

其实,Dropout可以被认为是集成大量深层神经网络的实用Bagging方法。Bagging涉及训练多个模型,并在每个测试样本上评估多个模型。当每个模型都是一个很大的神经网络时,这似乎是不切实际的,因为训练和评估这样的网络需要花费很多运行时间和内存。Dropout提供了一种廉价的Bagging集成近似,能够训练和评估指数级数量的神经网络。

它所做的就是在训练过程中集成包括所有从基础网络除去非输出单元后形成的子网络。简单一点的是乘零操作,即将一些单元的输出乘零就能有效地删除一个单元。而有些框架的实现就是如此。比如keras。

这里写图片描述

回想一下Bagging学习,我们定义 k 个不同的模型,从训练集有替换采样构造 k 个不同的数据集,然后在训练集 i 上训练模型 i 。Dropout的目标是在指数级数量的神经网络上近似这个过程。具体来说,在训练中使用Dropout时,我们会使用基于小批量的学习算法和较小的步长,如梯度下降等。我们每次在小批量中加载一个样本,然后随机抽样应用于网络中所有输入和隐藏单元的不同二值掩码。对于每个单元,掩码是独立采样的。掩码值为1 的采样概率(导致包含一个单元)是训练开始前一个固定的超参数。其实这个就是dropout的参数,其取值范围为 [ 0 , 1 ] .

以下是keras中关于dropout的函数实现,其实它使用随机数生成器生成0,1的向量,然后分别乘上该神经元的值。最后需要注意的是每个值还需要除以(1-level)的概率。其中level即是dropout的参数(以[0,1]之间某个值作为丢弃概率的参数)。

def dropout(x, level, noise_shape=None, seed=None):
    """Sets entries in `x` to zero at random,
    while scaling the entire tensor.

    # Arguments
        x: tensor
        level: fraction of the entries in the tensor
            that will be set to 0.
        noise_shape: shape for randomly generated keep/drop flags,
            must be broadcastable to the shape of `x`
        seed: random seed to ensure determinism.
    """
    if level < 0. or level >= 1:
        raise ValueError('Dropout level must be in interval [0, 1[.')
    if seed is None:
        seed = np.random.randint(1, 10e6)
    if isinstance(noise_shape, list):
        noise_shape = tuple(noise_shape)

    rng = RandomStreams(seed=seed)
    retain_prob = 1. - level

    if noise_shape is None:
        random_tensor = rng.binomial(x.shape, p=retain_prob, dtype=x.dtype)
    else:
        random_tensor = rng.binomial(noise_shape, p=retain_prob, dtype=x.dtype)
        random_tensor = T.patternbroadcast(random_tensor,
                                           [dim == 1 for dim in noise_shape])
    x *= random_tensor
    x /= retain_prob
    return x

至于为什么要除以那个值,可以查看这篇博客的解释。

猜你喜欢

转载自blog.csdn.net/qxconverse/article/details/79510205