【深度学习】torch.squeeze()移除维度函数 | torch.unsqueeze()增加某一维度函数 | pytorch


前言

这两个函数在pytorch框架下的深度学习经常用到,这次把它们记录一下。

一、torch.squeeze()函数

torch.squeeze()用来“挤”掉某一个维度为1的维度,或者所有维度为1的维度。(只挤掉维度为1的维度)
例子如下:

import torch
A=torch.rand(1,3,224,224)
B=torch.unsqueeze(A,dim=0)
print(B.shape)

结果:
在这里插入图片描述
一般来说,这个函数多用于最后网络输出图片的可视化。
如果对维度不为1的维度进行去除:

import torch
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=1)
print(B.shape)
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=2)
print(B.shape)
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=3)
print(B.shape)

在这里插入图片描述
不会发生变化

二、torch.unsqueeze()函数

torch.unsqueeze()函数用来插入新的维度扩充张量。例子如下:
在第0维度增加一个维度大小为1的维度(也就是在最前面加一个1)

import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=0)
print(B.shape)

结果为:(这个一般用的最多,比如输入的VGG的照片是1,3,224,224.一般的三通道照片是3,224,224,这时就需要用unsqueeze函数)
在这里插入图片描述
在第1,2,3维度增加一个维度大小为1的维度,只需要把dim改改就行

import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=1)
print(B.shape)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=2)
print(B.shape)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=3)
print(B.shape)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_46274756/article/details/128101166