torch和numpy的维度转换

torch维度转换

PIL读入的图片的格式是(H,W, C)
numpy储存图片的格式是(batch_size, H, W, C)
通常卷积需要的是(batch_size,C,H, W)

因此需要进行维度转换。

1、numpy中的维度转换

numpy中使用reshape来进行形状变换。
transpose的作用是坐标轴变换,切换角度来看待问题。

n = np.random.randn(2, 3, 4)
reshape_n = n.reshape(-1, 12)
print(reshape_n.shape) # (2, 12)
transpose_n = n.transpose(1, 2, 0)
print(transpose_n.shape) # (3, 4, 12)

2、torch维度转换

使用torch.view来进行变换,尺寸转换,总数量不变。

t = torch.randn(2, 3, 4)
view_t = t.view(-1, 12)
print(view_t.shape) # torch.Size([2, 12])

torch.squeeze()/torch.unsqueeze()这个是用来压缩或者添加维度的。
squeeze(n)只能压缩第n个维度为1的维度。

t = torch.randn(1, 3, 4)
squeeze_t = t.squeeze(0)
print(squeeze_t.shape) # torch.Size([3, 4])

torch.unsqueeze(n)在第n个维度前增加一个维度1

t = torch.randn(1, 3, 4)
unsqueeze_t = t.unsqueeze(0)
print(unsqueeze_t.shape) # torch.Size([1, 1, 3, 4])

torch.permute()维度转换。
对维度进行重新排序,和numpy的transpose差不多。

t = torch.randn(1, 3, 4)
permute_t = t.permute(1, 2, 0)
print(permute_t.shape) # torch.Size([3, 4, 1])

猜你喜欢

转载自blog.csdn.net/m0_59967951/article/details/126532068