常见的数据扩展方式unsqueeze与expand的用法与区别

常见的数据扩展方式unsqueeze与expand的用法与区别

unsqueeze以及expand的区别

  • unsqueeze可以增加一个维度,但是维度的siz只是1而已;
  • 然而,expand却可以将数据进行复制,将增加的数据维度变为n。
# 获得一开始的初始化数值:tensor([[a1,a2,a3]])
nn1=torch.rand(1,3)
print(nn1)
print("nn1.shape",nn1.shape)
# unsqueeze是解压的意思,在第i个维度上进行扩展,将其扩展为tensor([[[a1,a2,a3]]])
nn2=nn1.unsqueeze(0)
print("*"*100)
print(nn2)
print("nn2.shape",nn2.shape)
nn3=nn1.unsqueeze(2)
print('='*100)
print(nn3)
print("nn3.shape",nn3.shape)
#利用expand对数据进行扩展
nn4=nn1.expand(1,3,3)
print("*"*100)
print(nn4)
print("nn4.shape",nn4.shape)
tensor([[0.8664, 0.8674, 0.7234]])
nn1.shape torch.Size([1, 3])
>>>
输出结果如下:
****************************************************************************************************
tensor([[[0.8664, 0.8674, 0.7234]]])
nn2.shape torch.Size([1, 1, 3])
====================================================================================================
tensor([[[0.8664],
         [0.8674],
         [0.7234]]])
nn3.shape torch.Size([1, 3, 1])
****************************************************************************************************
tensor([[[0.8664, 0.8674, 0.7234],
         [0.8664, 0.8674, 0.7234],
         [0.8664, 0.8674, 0.7234]]])
nn4.shape torch.Size([1, 3, 3])

相反地,squeeze()函数用于减小维度,它只能减少size=1的维度;

猜你喜欢

转载自blog.csdn.net/weixin_42782150/article/details/127363041