[Repost] Understanding the nn.Linear module in PyTorch
This article is a repost. The quoted text is collected only to organize my own notes for future reference and has no other purpose.
This module implements the formula: Y = XA^T + B
Source: https://blog.csdn.net/u012936765/article/details/52671156
Linear is a subclass of Module. It is a parameterized module and, as its name indicates, it performs a linear transformation.
Creation
The constructor calls the parent's init function.
Creating a Linear requires two parameters, inputSize and outputSize:
inputSize: number of input nodes
outputSize: number of output nodes
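So Linear has seven fields:
weight: Tensor, outputSize × inputSize
bias: Tensor, outputSize
gradWeight: Tensor, outputSize × inputSize
gradBias: Tensor, outputSize
gradInput: Tensor
output: Tensor
_type: output:type()
These names appear to follow the original (Lua) Torch implementation that the quoted article describes. In PyTorch the constructor arguments are called in_features and out_features.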
Example
module = nn.Linear(10, 5)
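For comparison, the same layer can be inspected in PyTorch. The following is a minimal sketch, assuming a recent PyTorch release; the names m, x and loss are only illustrative. It suggests that weight and bias keep the shapes listed above, while gradients are not separate fields but live on each parameter as .grad once a backward pass has run.

```python
import torch

m = torch.nn.Linear(10, 5)  # in_features=10, out_features=5

# weight has shape (out_features, in_features), bias has shape (out_features,)
weight_shape = tuple(m.weight.shape)
bias_shape = tuple(m.bias.shape)
print(weight_shape)
print(bias_shape)

# Before any backward pass, no gradient has been stored yet
grad_before = m.weight.grad

# Run a forward and backward pass on a small illustrative batch
x = torch.randn(3, 10)
loss = m(x).sum()
loss.backward()

# The gradients have the same shapes as the parameters they belong to
grad_weight_shape = tuple(m.weight.grad.shape)
grad_bias_shape = tuple(m.bias.grad.shape)
print(grad_before is None, grad_weight_shape, grad_bias_shape)
```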
Forward Pass
----------------
Disclaimer: This is an original article by CSDN blogger "bubbleoooooo", released under the CC 4.0 BY-SA license. When reposting, please include the original source link and this statement.
Original link: https://blog.csdn.net/u012936765/article/details/52671156
This article has a good example:
import torch
x = torch.randn(128, 20)  # input has shape (128, 20)
m = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
# Equivalent to: ans = torch.mm(x, torch.t(m.weight)) + m.bias
ans = torch.mm(x, m.weight.t()) + m.bias
print('ans.shape:\n', ans.shape)
# allclose tolerates tiny floating-point differences between the two computations
print(torch.allclose(ans, output))
The output is:
m.weight.shape:
torch.Size([30, 20])
m.bias.shape:
torch.Size([30])
output.shape:
torch.Size([128, 30])
ans.shape:
torch.Size([128, 30])
True
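To make the formula Y = XA^T + B concrete, here is a small worked sketch with hand-picked parameters instead of random ones. The specific weight, bias and input values are assumptions chosen only so the result can be checked by hand; they do not come from the original article.

```python
import torch

m = torch.nn.Linear(2, 3)  # in_features=2, out_features=3

# Overwrite the random initial parameters with hand-picked values
with torch.no_grad():
    m.weight.copy_(torch.tensor([[1.0, 0.0],
                                 [0.0, 1.0],
                                 [1.0, 1.0]]))    # A, shape (3, 2)
    m.bias.copy_(torch.tensor([0.5, -0.5, 0.0]))  # B, shape (3,)

x = torch.tensor([[1.0, 2.0]])  # X, shape (1, 2)
y = m(x)                        # Y, shape (1, 3)

# By hand: [1*1 + 0*2 + 0.5, 0*1 + 1*2 - 0.5, 1*1 + 1*2 + 0.0]
expected = torch.tensor([[1.5, 1.5, 3.0]])
print(y)
print(torch.allclose(y, expected))
```

Each output element is the dot product of the input row with one row of the weight matrix, plus the matching bias entry, which is why the weight is stored as (out_features, in_features).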
Note that the two-dimensional input tensor of shape 128 × 20 becomes a tensor of shape 128 × 30 after the linear transformation. If the input is replaced with:
x = torch.randn(20, 128)  # input has shape (20, 128)
m = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = m(x)
This raises an error. The formula is Y = XA^T + B, and the output above shows that A has shape 30 × 20, so its transpose A^T has shape 20 × 30. For the matrix product to be defined, the number of columns of X must therefore be 20. In general, the layer's two arguments refer to the number of columns of the input and of the output. If the input is changed to:
x = torch.randn(20, 20)  # input has shape (20, 20)
m = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = m(x)
The output has shape 20 × 30: once again, only the number of columns changes.
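The shape rule described above can be checked directly. The following is a minimal sketch, assuming a recent PyTorch release; the variable names and the extra three-dimensional input are illustrative additions, not part of the original article. It indicates that only the last dimension has to match in_features, and that a mismatch surfaces as a RuntimeError.

```python
import torch

m = torch.nn.Linear(20, 30)  # in_features=20, out_features=30

# The last dimension matches in_features, so only that dimension changes
shape_2d = tuple(m(torch.randn(20, 20)).shape)
shape_3d = tuple(m(torch.randn(128, 7, 20)).shape)

# The last dimension is 128 rather than 20, so the matrix product is undefined
try:
    m(torch.randn(20, 128))
    raised = False
except RuntimeError as err:
    raised = True
    print('RuntimeError:', err)

print(shape_2d, shape_3d, raised)
```

If the data really is laid out as (features, samples), transposing it with x.t() before calling the layer is one way to bring the feature dimension to the last position.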