pytorch如何在测试的时候启用dropout

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?
在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x

net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)

net.eval()
y = net(x)
print('eval mode result: ', y)

net.eval()
y = net(x)
print('eval2 mode result: ', y)

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

发布了443 篇原创文章 · 获赞 149 · 访问量 55万+

猜你喜欢

转载自blog.csdn.net/qian99/article/details/89052262