[Repost] Understanding the nn.Linear module in PyTorch

This article collects reposted and quoted material simply to organize this knowledge for my own future reference; nothing more is intended.

This module implements the formula: y = xAᵀ + b

Source: https://blog.csdn.net/u012936765/article/details/52671156

Linear is a subclass of Module; it is a parameterized module that, as its name indicates, performs a linear transformation. (The quoted post describes the original Lua Torch implementation, hence the field names below.)

Creation

[Image from the original post; not reproduced here.]

The parent class's __init__ function

[Image from the original post; not reproduced here.]

Creating a Linear requires two parameters, inputSize and outputSize:
inputSize: number of input nodes
outputSize: number of output nodes
A Linear therefore has seven fields:

weight: Tensor, outputSize × inputSize
bias: Tensor, outputSize
gradWeight: Tensor, outputSize × inputSize
gradBias: Tensor, outputSize
gradInput: Tensor
output: Tensor
_type: output:type()
Example
module = nn.Linear(10, 5)
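The snippet above uses the original Torch API; in PyTorch the equivalent module is created the same way, and the weight and bias fields listed above can be inspected directly. A minimal sketch:

```python
import torch

# PyTorch counterpart of the nn.Linear(10, 5) example above.
# weight has shape (outputSize, inputSize); bias has shape (outputSize,).
m = torch.nn.Linear(10, 5)
print(m.weight.shape)  # torch.Size([5, 10])
print(m.bias.shape)    # torch.Size([5])
```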
Forward Pass

[Image from the original post; not reproduced here.]

----------------
Disclaimer: This is an original article by CSDN blogger "bubbleoooooo", released under the CC 4.0 BY-SA license; when reposting, please include the original source link and this statement.
Original link: https://blog.csdn.net/u012936765/article/details/52671156

This article has a good example:

import torch

x = torch.randn(128, 20)  # the input shape is (128, 20)
m = torch.nn.Linear(20, 30)  # 20 and 30 are the input and output feature dimensions
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)

# torch.mm(x, torch.t(m.weight)) + m.bias is equivalent to the line below
ans = torch.mm(x, m.weight.t()) + m.bias
print('ans.shape:\n', ans.shape)

print(torch.equal(ans, output))

The output is:

m.weight.shape:
  torch.Size([30, 20])
m.bias.shape:
 torch.Size([30])
output.shape:
 torch.Size([128, 30])
ans.shape:
 torch.Size([128, 30])
True
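As a side note, the manual torch.mm computation above is the same operation PyTorch exposes as torch.nn.functional.linear, which is what the module calls in its forward pass. A small sketch verifying this:

```python
import torch
import torch.nn.functional as F

x = torch.randn(128, 20)
m = torch.nn.Linear(20, 30)

# F.linear(x, W, b) computes x @ W.t() + b, the same operation
# the module performs when called as m(x).
out_module = m(x)
out_functional = F.linear(x, m.weight, m.bias)
print(torch.equal(out_module, out_functional))  # True
```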

Note that the input is a two-dimensional tensor of shape 128 × 20, and after the linear transformation it becomes 128 × 30. If the input is replaced with:

x = torch.randn(20, 128)  # the input shape is (20, 128)
m = torch.nn.Linear(20, 30)  # 20 and 30 are the input and output feature dimensions
output = m(x)
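Running this snippet raises a shape-mismatch RuntimeError; a minimal sketch that catches it:

```python
import torch

x = torch.randn(20, 128)     # last dimension is 128
m = torch.nn.Linear(20, 30)  # but the layer expects 20 input features

try:
    m(x)
except RuntimeError as e:
    # The matrices of shape 20x128 and 20x30 (A transposed) cannot be multiplied.
    print('shape mismatch:', e)
```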

an error is reported. The formula is y = xAᵀ + b; from the output above, A has shape 30 × 20, so its transpose is 20 × 30, and the 20 must match the number of columns of x. In general, what a linear layer changes is the number of columns of the input. If we change the input to:

x = torch.randn(20, 20)  # the input shape is (20, 20)
m = torch.nn.Linear(20, 30)  # 20 and 30 are the input and output feature dimensions
output = m(x)

Looking at the output, you will find that what changes is still only the number of columns.
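This "only the number of columns changes" behaviour extends to inputs with more than two dimensions: nn.Linear is applied to the last dimension, and all leading dimensions pass through unchanged. A small sketch:

```python
import torch

m = torch.nn.Linear(20, 30)

# nn.Linear acts on the last dimension only; the leading (batch) dimensions
# are preserved as-is.
x3d = torch.randn(4, 128, 20)  # e.g. (batch, sequence, features)
out = m(x3d)
print(out.shape)  # torch.Size([4, 128, 30])
```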

Origin www.cnblogs.com/jiading/p/11945346.html