pytorch 的tensor的索引,切片,连接,换位 Indexing, Slicing,Joining, Mutating Ops

pytorch 的tensor的索引,切片,连接,换位 Indexing, Slicing,Joining, Mutating Ops

对 Tensor 张量拼接

torch.cat(inputs, dimension=0) ->Tensor

在给定的维度上对输入的张量序列 seq 进行拼接操作。torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。cat()函数可以通过下面的例子更好理解。
参数:

  • inputs (sequence of Tensors) - 可以是任意相同 Tensor 类型的 python 序列
  • dimension (int, optional) - 沿着此维连接张量序列。
    例子:
    在这里插入图片描述
    在这里插入图片描述
    在给定维度上将张量进行分块:
torch.chunk(tensor, chuncks, dim=0)

参数:

  • tensor(Tensor) - 代分块的输入张量
  • chunks(int) - 分块的个数
  • dim(int) - 沿此维度进行分块

沿着指定轴,按照索引 index 中指定的位置进行聚合。

torch.gather(input, dim, index, out=None) ->Tensor

参数:

  • input(Tensor) - 输入张量
  • dim(int) - 指定的维度
  • index(LongTensor) - 聚合元素的下标(index 的 size 和 input 的 size 一致)
  • out(Tensor, optional) - 目标张量
    在这里插入图片描述
    例子:
    在这里插入图片描述
    沿着指定维度对输入进行切片,取 index 中指定的相应项 (index 为一个 LongTensor),然后返回一个新的张量,返回的张量与原始张量_Tensor_ 有相同的维度(在指定维度上),返回的张量与原始张量共享内存空间。
    参数:
  • input(Tensor) - 输入张量
  • dim(int) - 索引的维度
  • index(LongTensor) - 包含索引下标的一维张量
  • out(Tensor, optional) - 目标张量
    例子:
    在这里插入图片描述

根据掩码张量 mask 中的二元值获取一个新的张量:

torch.masked_select(input, mask, out=None)-> Tensor

张量 mask 与 input 张量有相同数量的元素数目,但形状和维度不需要相同。返回的张量不与原始张量共享内存空间
参数:

  • input(Tensor) - 输入张量
  • mask(ByteTensor) - 掩码张量,包含了二元索引值
  • out(Tensor, optional) - 目标张量

返回一个包含输入 input 中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
如果输入 input 有 n 维,则输出的索引张量 output 的形状为 z * n,其中 z 是输入张量 input 中所有非零元素的个数。

torch.nonzero(input, out=None) ->LongTensor

参数:

  • input(Tensor) - 源张量
  • out(LongTensor,optional) - 包含索引值的结果张量
    例子:
    在这里插入图片描述
    将输入张量分割成相等形状的 chunks(如果可分)。如果沿指定维度的张量形状大小不能被split_size整分,则最后一个分块小于其它分块。
torch.split(input, split_size, dim=0)

参数:

  • input(Tensor) - 待分割张量
  • split_size(int) - 单个分块的形状大小
  • dim(int) - 沿着此维进行分割

将输入张量形状中的1去除并返回

torch.squeeze(input, dim=None, out=None)

如果输入是形如(A * 1 * B * 1 * C * 1 * D),那么输出形状就为:(A * B * C * D),当指定 dim 时,那么挤压操作只在给定维度上。返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
参数:

  • input(Tensor) - 待分割张量
  • dim(int, optional) - 如果给定,则 input 只会在给定维度挤压
  • out(Tensor, optional) - 目标张量

沿着一个新的维度对输入张量序列进行拼接。序列中的所有张量都应该为相同形状。

torch.stack(sequence,dim=0)

参数:

  • sequence(Sequence)- 待连接的张量序列
  • dim(int) - 插入的维度。必须介于 0 与待连接张量序列数之间。

对矩阵进行转置

torch.t(input, out=None) ->Tensor

输入一个矩阵(2维张量),并转置0,1维。可以被视为transpose(input, 0, 1)的简写函数。
参数:

  • input(Tensor) - 输入张量
  • output(Tensor, optional) - 结果张量

交换矩阵指定维度

torch.transpose(input, dim0, dim1, out=None) ->Tensor

返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存, 所以改变其中一个会导致另外一个也被修改。
参数:

  • input(Tensor) - 输入张量
  • dim0(int) - 转置的第一维度
  • dim1(int) - 转置的第二维度
    例子:
    在这里插入图片描述

删除矩阵的指定维度,并返回指定维度切片的各个切片

torch.unbind(tensor, dim=0)

参数:

  • tensor(Tensor) - 输入张量
  • dim(int) - 删除的维度

对tensor 矩阵的指定维度位置插入维度 1

torch.unsequeeze(input, dim, out=None)

返回一个新张量,对输入的指定位置插入维度 1。返回张量和输入张量共享内存;如果dim为负,则将被转化 dim + input.dim() + 1
参数:

  • input(Tensor) - 输入张量
  • dim(int, optional) - 插入维度的索引
  • out(Tensor, optional) - 结果张量

例子:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43915090/article/details/134763079
今日推荐