PyTorch : nn.Linear() 详解

线性转换:
在这里插入图片描述
举例:

input1 = torch.randn(128, 20)
input2 = torch.randn(128, 3, 20) #中间 * 可以添加任意维度
input3 = torch.randn(128, 3, 4, 20) 
m = nn.Linear(20, 30)
output1 = m(input1)
output2 = m(input2)
output3 = m(input3)
print(output1.size(), output2.size(), output3.size())
#
torch.Size([128, 30]) torch.Size([128, 3, 30]) torch.Size([128, 3, 4, 30])

中间 * 可以是任意维度,原理解释:

input2 = torch.randn(128, 3, 20)

m = nn.Linear(20, 30)
output2 = m(input2)

input3 = input2.reshape(128 * 3, 20)
output3 = m(input3)

print(output3 == output2.reshape(128 * 3, -1))
#
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], dtype=torch.uint8)

可见,将所有前面的维度相乘变为了二维矩阵,nn.Linear() 线性变换,也就是全连接层的变换。

PyTorch官方文档

发布了70 篇原创文章 · 获赞 87 · 访问量 7533

猜你喜欢

转载自blog.csdn.net/qq_40263477/article/details/105154821