torch.cat在给定维度上对输入的张量序列进行连接操作

torch.cat(inputs, dimension=0) → Tensor
参数 :
inputs (sequence of Tensors) – 可以是任意相同 Tensor 类型的 python 序列
dimension (int, optional) – 沿着此维连接张量序列

例子1  dim = 0:

import torch
import numpy
a = torch.randn(2, 2, 3) # 生成通道为2 的2行3列张量
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=0) # dim=0表示在通道上连接两个张量a,a,通道数翻倍,行列不变
print(b.data.size())
print('b:\n', b)

cat结果如下:可以看出通道数变化而行列没有任何变化!

torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
 tensor([[[ 0.4702, -1.1235,  1.6840],
         [ 0.8872,  1.1102, -1.8428]],

        [[ 0.8182, -0.5710, -1.2721],
         [ 0.2145,  1.2150,  0.7147]]])
torch.Size([4, 2, 3])
b:
 tensor([[[ 0.4702, -1.1235,  1.6840],
         [ 0.8872,  1.1102, -1.8428]],

        [[ 0.8182, -0.5710, -1.2721],
         [ 0.2145,  1.2150,  0.7147]],

        [[ 0.4702, -1.1235,  1.6840],
         [ 0.8872,  1.1102, -1.8428]],

        [[ 0.8182, -0.5710, -1.2721],
         [ 0.2145,  1.2150,  0.7147]]])

例子2  dim = 1:

import torch
import numpy
a = torch.randn(2, 2, 3)
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=1) # dim=1表示在行上连接两个张量a,a,行增加而通道数和列不变
print(b.data.size())
print('b:\n', b)

cat结果如下:可以看出行增加,而通道数和列没有任何变化!

torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
 tensor([[[ 0.7129, -0.1589, -1.4144],
         [ 1.2887,  2.2833,  0.7735]],

        [[-1.8561,  0.2988, -0.3955],
         [-1.8440,  1.8290,  0.1959]]])
torch.Size([2, 4, 3])
b:
 tensor([[[ 0.7129, -0.1589, -1.4144],
         [ 1.2887,  2.2833,  0.7735],
         [ 0.7129, -0.1589, -1.4144],
         [ 1.2887,  2.2833,  0.7735]],

        [[-1.8561,  0.2988, -0.3955],
         [-1.8440,  1.8290,  0.1959],
         [-1.8561,  0.2988, -0.3955],
         [-1.8440,  1.8290,  0.1959]]])

例子3  dim = 2:

import torch
import numpy
a = torch.randn(2, 2, 3)
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=2) # dim=2表示在列上连接两个张量a,a,列增加,而行和通道数不变
print(b.data.size())
print('b:\n', b)

cat结果如下:可以看出列增加,而通道数和行没有任何变化!

torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
 tensor([[[ 0.7935,  0.3900, -1.0024],
         [-0.2843, -0.5554,  0.4073]],

        [[ 1.0388, -0.4608,  0.4172],
         [ 0.6668, -1.2096, -1.2609]]])
torch.Size([2, 2, 6])
b:
 tensor([[[ 0.7935,  0.3900, -1.0024,  0.7935,  0.3900, -1.0024],
         [-0.2843, -0.5554,  0.4073, -0.2843, -0.5554,  0.4073]],

        [[ 1.0388, -0.4608,  0.4172,  1.0388, -0.4608,  0.4172],
         [ 0.6668, -1.2096, -1.2609,  0.6668, -1.2096, -1.2609]]])

猜你喜欢

转载自blog.csdn.net/lzdjlu/article/details/142906018