PyTorch的nn.Linear()详解

参考链接PyTorch的nn.Linear()详解 - douzujun - 博客园 (cnblogs.com)

这里演示了二维张量的全连接 :

 其实还可以输入三维张量,演示如下:

from torch import nn
import torch

# in_features由输入张量的形状决定,out_features则决定了输出张量的形状
linear = nn.Linear(in_features=64 * 3, out_features=5)

# 10个 大小为7*64*3, 3个channel 的张量
a = torch.rand(10, 3, 7, 64 * 3)

print(a.shape)  # torch.Size([10, 3, 7, 192])

print(linear.weight.shape)  # torch.Size([5, 192])

b = linear(a)

print(b.shape)  # torch.Size([10, 3, 7, 5])

猜你喜欢

转载自blog.csdn.net/Yang_4881002/article/details/127900126