pytorch 中的view 参数使用方法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/nijiayan123/article/details/85102044

最近在看pytorch 代码时发现了使用了view 这个参数。一开始还不知道是啥。查看之后发现原来跟 keras 中的flat和reshape 类似。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

看到上面的代码大家对比一下 keras 中的代码就可以发现其中的差别了。其实就是一个压缩的操作。
那么这里为什么使用了有参数 “-1”呢。当你不知道你那个位置的参数具体是多少时可以使用“-1”来代替。程序会根据后面的参数推断出这个“-1”具体是什么值。当然你不能存在歧义的。

猜你喜欢

转载自blog.csdn.net/nijiayan123/article/details/85102044