torch.cat(inputs, dimension=0) → Tensor
参数
:
•
inputs (sequence of Tensors) –
可以是任意相同
Tensor
类型的
python
序列
•
dimension (int, optional) –
沿着此维连接张量序列
例子1 dim = 0:
import torch
import numpy
a = torch.randn(2, 2, 3) # 生成通道为2 的2行3列张量
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=0) # dim=0表示在通道上连接两个张量a,a,通道数翻倍,行列不变
print(b.data.size())
print('b:\n', b)
cat结果如下:可以看出通道数变化而行列没有任何变化!
torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
tensor([[[ 0.4702, -1.1235, 1.6840],
[ 0.8872, 1.1102, -1.8428]],
[[ 0.8182, -0.5710, -1.2721],
[ 0.2145, 1.2150, 0.7147]]])
torch.Size([4, 2, 3])
b:
tensor([[[ 0.4702, -1.1235, 1.6840],
[ 0.8872, 1.1102, -1.8428]],
[[ 0.8182, -0.5710, -1.2721],
[ 0.2145, 1.2150, 0.7147]],
[[ 0.4702, -1.1235, 1.6840],
[ 0.8872, 1.1102, -1.8428]],
[[ 0.8182, -0.5710, -1.2721],
[ 0.2145, 1.2150, 0.7147]]])
例子2 dim = 1:
import torch
import numpy
a = torch.randn(2, 2, 3)
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=1) # dim=1表示在行上连接两个张量a,a,行增加而通道数和列不变
print(b.data.size())
print('b:\n', b)
cat结果如下:可以看出行增加,而通道数和列没有任何变化!
torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
tensor([[[ 0.7129, -0.1589, -1.4144],
[ 1.2887, 2.2833, 0.7735]],
[[-1.8561, 0.2988, -0.3955],
[-1.8440, 1.8290, 0.1959]]])
torch.Size([2, 4, 3])
b:
tensor([[[ 0.7129, -0.1589, -1.4144],
[ 1.2887, 2.2833, 0.7735],
[ 0.7129, -0.1589, -1.4144],
[ 1.2887, 2.2833, 0.7735]],
[[-1.8561, 0.2988, -0.3955],
[-1.8440, 1.8290, 0.1959],
[-1.8561, 0.2988, -0.3955],
[-1.8440, 1.8290, 0.1959]]])
例子3 dim = 2:
import torch
import numpy
a = torch.randn(2, 2, 3)
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=2) # dim=2表示在列上连接两个张量a,a,列增加,而行和通道数不变
print(b.data.size())
print('b:\n', b)
cat结果如下:可以看出列增加,而通道数和行没有任何变化!
torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
tensor([[[ 0.7935, 0.3900, -1.0024],
[-0.2843, -0.5554, 0.4073]],
[[ 1.0388, -0.4608, 0.4172],
[ 0.6668, -1.2096, -1.2609]]])
torch.Size([2, 2, 6])
b:
tensor([[[ 0.7935, 0.3900, -1.0024, 0.7935, 0.3900, -1.0024],
[-0.2843, -0.5554, 0.4073, -0.2843, -0.5554, 0.4073]],
[[ 1.0388, -0.4608, 0.4172, 1.0388, -0.4608, 0.4172],
[ 0.6668, -1.2096, -1.2609, 0.6668, -1.2096, -1.2609]]])