详细分析Pytorch中的repeat以及reshape基本知识(附Demo)

前言

两者的差异如下:

特性 torch.repeat torch.reshape
功能 重复数据 改变数据的形状
内存占用 可能增加内存占用 通常不增加内存占用
返回类型 新的张量(可能是不同内存地址) 视图或新张量(取决于情况)
用法 tensor.repeat(repeats) tensor.reshape(shape)
示例 tensor.repeat(2, 3) tensor.reshape(2, 3)

总结

  • repeat 用于数据重复,适合需要扩展张量的场景
  • reshape 用于改变张量的形状,适合需要重新组织数据的场景

1. repeat

.
torch.repeat 是 PyTorch 中用于扩展张量的函数,它会在指定的维度上重复张量的内容
它的基本语法为 tensor.repeat(repeats),其中 repeats 是一个包含每个维度重复次数的元组

原理分析

  • torch.repeat 不会创建新的数据,而是通过在指定的维度上重复原始数据的引用来实现扩展
  • 这使得 repeat 在内存占用方面相对高效,因为它不复制数据,而只是扩展视图
import torch

# 创建一个 1D 张量
tensor1d = torch.tensor([1, 2, 3])
print(tensor1d)
# 在第一个维度上重复 2 次
repeated1d = tensor1d.repeat(2)
print(repeated1d)  # 输出: tensor([1, 2, 3, 1, 2, 3])

print("-----------------------")
# 创建一个 2D 张量
tensor2d = torch.tensor([[1, 2], [3, 4]])
print(tensor2d)
# 在行和列上分别重复 2 次
repeated2d = tensor2d.repeat(2, 3)
print(repeated2d)
# 输出:
# tensor([[1, 2, 1, 2, 1, 2],
#         [3, 4, 3, 4, 3, 4],
#         [1, 2, 1, 2, 1, 2],
#         [3, 4, 3, 4, 3, 4]])


print("-----------------------")
# 创建一个 3D 张量
tensor3d = torch.tensor([[[1, 2]], [[3, 4]]])
print(tensor3d)
# 在第一个维度上重复 2 次,第二个维度上重复 2 次
repeated3d = tensor3d.repeat(2, 2, 1)
print(repeated3d)
# 输出:
# tensor([[[1, 2],
#          [1, 2]],
#
#         [[3, 4],
#          [3, 4]],
#
#         [[1, 2],
#          [1, 2]],
#
#         [[3, 4],
#          [3, 4]]])

截图如下:

在这里插入图片描述

2. reshape

torch.reshape 用于改变张量的形状,而不改变其数据内容

其基本语法为 tensor.reshape(shape),shape 是新的维度大小,可以是一个整数或元组

原理分析

  • reshape 试图返回一个新的视图,不会分配新的内存,只要数据的排列允许
  • 如果不能实现,则会返回一个新的张量,包含相同的数据,但在内存上不同
import torch

# 创建一个 1D 张量
tensor1d = torch.tensor([1, 2, 3, 4, 5, 6])
print(tensor1d)
# 重塑为 2x3 的张量
reshaped1d = tensor1d.reshape(2, 3)
print(reshaped1d)
# 输出:
# tensor([[1, 2, 3],
#         [4, 5, 6]])

# 创建一个 2D 张量
tensor2d = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(tensor2d)
# 重塑为 3x2 的张量
reshaped2d = tensor2d.reshape(2, 3)
print(reshaped2d)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# 创建一个 3D 张量
tensor3d = torch.tensor([[[1], [2]], [[3], [4]]])
print(tensor3d)
# 重塑为 4x1 的张量
reshaped3d = tensor3d.reshape(4)
print(reshaped3d)
# 输出:
# tensor([1, 2, 3, 4])

截图如下:

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_47872288/article/details/143182670