from torch import nn
m = nn.Bilinear(96, 96, 96)
input1 = torch.randn(8,7, 96)
input2 = torch.randn(8,7, 96)
output = m(input1, input2)
print(output.size())
torch.Size([8, 7, 96])
比起Linear层,Bilinear层有什么特点与优势吗?
参考资料
python - Understanding Bilinear Layers - Stack Overflow
pytorch中的nn.Bilinear的计算原理详解_nihate的博客-CSDN博客_pytorch中bilinear //讲得很细,非常推荐,如果像搞清楚计算原理的话