【深度学习】Dropout原理以及代码实现

Drouout起源

随着深度学习的发展,各种网络结构层出不穷,导致网络越来越深,数据的容纳量越来越大,模型的参数也越来越多,这样就会导致神经网络很容易过拟合,过拟合的模型几乎是不能用于实践中的,因为拟合的数据与测试数据不一致,导致模型的泛化性能较低。

由于这个问题,现在有很多解决办法,比如使用集成模型,同时训练多个模型,最终让几个模型进行加权,这样能够解决过拟合是因为使用了不同的模型,每个模型拟合到的数据特征不同,最终由于互补,能够使模型的泛化性能更强,随机森林就是很经典的ensemble模型,分别单独训练多个子树,每个子树之间互不依赖。

但是这种方法对于深度学习不太现实,因为要训练多个网络,并且使得每个网络最终得到它们的最优解,我们都知道现在的网络很复杂,训练太过昂贵,训练一个模型都很不容易,何况是同时训练多个,这显然是不太现实的,另外一点就是可以使用同一个网络结构,然后训练不同的数据,也就是说,分别使用不同的训练子集进行训练获得每个训练子集对应的网络结构,这样的问题就是导致浪费了数据,而且真实中没有那么多的数据进行切分。

神经网络的两个缺点:

  • 捕捉高阶特征依赖,容易过拟合
  • 训练费时费力

什么是Dropout

Dropout是一项技术可以解决上面提到的问题,它可以理解是一种正则化的技术,它的原理是在模型正向传播过程中,以一定的概率p使隐层中的神经元暂时性失活,注意只是暂时失活,并不是永久失活,每个mini-batch进行传播时,都会随机失活一定神经元,这样就会缓解过拟合的现象,原因是训练每个mini-batch时,并不是所有的神经元都会参与训练,说白了就是每个传播都是网络结构的部分神经元在发挥作用,这样与原网络就会少很多训练参数,因为失活一定神经元,会导致每个mini-batch的数据不会完全拟合,这就是Dropout的原理所在。

对于概率p是我们指定的,它的意思就是每个神经元有多大的概率仍在工作,如果令p=0.3,那么也就是说每个神经元有70%的概率会失活,如果p=0,那么所有的神经元会全部失活,导致网络不会传播。

在训练期间,我们会指定p,然后随机产生一个符合伯努利分布的张量,该张量的形状和该隐层神经元的个数相同,如果一个隐层的神经元的个数为 5 个,那么产生的张量就是 【 0,0,1,1,0】,就是说只有第三个和第四个工作,其余全部失活。

image-20211124114944775

还有个问题就是,由于在训练期间使用Dropout让部分神经元失活,而在模型推断的过程中所有的神经元是存活的,这样就会导致训练期间该层的输出与推断时该层的输出期望不一致

有人会说在推断时也采用Dropout那这样不就和训练一致了,这样期望是一致的,但是这样在现实中会存在一个问题,因为Dropout会随机的进行失活神经元,如果推断时使用Dropout随机失活的话,这样就会导致每个推断的结果不一致,带来结果不稳定的问题

  • 训练: p x ∗ 1 + ( 1 − p ) ∗ 0 = p x px*1+(1-p)*0=px px1+(1p)0=px
  • 推断: x x x

这样导致推断时的期望会与训练时不一致,解决办法有两种,一种是训练时进行缩放,第二种是在推断时进行放大

  • 方法一: x / = p x/=p x/=p
  • 方法二: x ∗ = p x*=p x=p

代码实现

import numpy as np

"""
discard_prob : 每个神经元被丢弃的概率,计算时要换成keep_prob
"""
def dropout(x, discard_prob=0.5, seed=None):
    if discard_prob < 0 or discard_prob > 1:
        raise ValueError('Dropout prob must be in interval [0,1].')
        
    keep_prob = 1 - discard_prob

    seed = np.random.seed(seed)

    random_tensor = np.random.binomial(n=1, p=keep_prob, size=x.shape)
    
    x *= random_tensor

    x /= keep_prob
    
    return x

x = np.random.uniform(low=1, high=5, size=10)
out = dropout(x,discard_prob=0.2)
out

猜你喜欢

转载自blog.csdn.net/m0_47256162/article/details/121512403

相关文章