PyTorch的torch.cat

转载来源:https://blog.csdn.net/qq_39709535/article/details/80803003

1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起。

2. 例子理解


  
  
  1. >>> import torch
  2. >>> A=torch.ones( 2, 3) #2x3的张量(矩阵)
  3. >>> A
  4. tensor([[ 1., 1., 1.],
  5. [ 1., 1., 1.]])
  6. >>> B= 2*torch.ones( 4, 3) #4x3的张量(矩阵)
  7. >>> B
  8. tensor([[ 2., 2., 2.],
  9. [ 2., 2., 2.],
  10. [ 2., 2., 2.],
  11. [ 2., 2., 2.]])
  12. >>> C=torch.cat((A,B), 0) #按维数0(行)拼接
  13. >>> C
  14. tensor([[ 1., 1., 1.],
  15. [ 1., 1., 1.],
  16. [ 2., 2., 2.],
  17. [ 2., 2., 2.],
  18. [ 2., 2., 2.],
  19. [ 2., 2., 2.]])
  20. >>> C.size()
  21. torch.Size([ 6, 3])
  22. >>> D= 2*torch.ones( 2, 4) #2x4的张量(矩阵)
  23. >>> C=torch.cat((A,D), 1) #按维数1(列)拼接
  24. >>> C
  25. tensor([[ 1., 1., 1., 2., 2., 2., 2.],
  26. [ 1., 1., 1., 2., 2., 2., 2.]])
  27. >>> C.size()
  28. torch.Size([ 2, 7])

上面给出了两个张量A和B,分别是2行3列,4行3列。即他们都是2维张量。因为只有两维,这样在用torch.cat拼接的时候就有两种拼接方式:按行拼接和按列拼接。即所谓的维数0维数1. 

C=torch.cat((A,B),0)就表示按维数0(行)拼接A和B,也就是竖着拼接,A上B下。此时需要注意:列数必须一致,即维数1数值要相同,这里都是3列,方能列对齐。拼接后的C的第0维是两个维数0数值和,即2+4=6.

C=torch.cat((A,B),1)就表示按维数1(列)拼接A和B,也就是横着拼接,A左B右。此时需要注意:行数必须一致,即维数0数值要相同,这里都是2行,方能行对齐。拼接后的C的第1维是两个维数1数值和,即3+4=7.

从2维例子可以看出,使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐

3.实例

在深度学习处理图像时,常用的有3通道的RGB彩色图像及单通道的灰度图。张量size为cxhxw,即通道数x图像高度x图像宽度。在用torch.cat拼接两张图像时一般要求图像大小一致而通道数可不一致,即h和w同,c可不同。当然实际有3种拼接方式,另两种好像不常见。比如经典网络结构:U-Net

                                    

里面用到4次torch.cat,其中copy and crop操作就是通过torch.cat来实现的。可以看到通过上采样(up-conv 2x2)将原始图像h和w变为原来2倍,再和左边直接copy过来的同样h,w的图像拼接。这样做,可以有效利用原始结构信息。

4.总结

使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。

猜你喜欢

转载自blog.csdn.net/humanpose/article/details/88375395