pytorch repeat_interleave()函数详细讲解(附代码)

repeat_interleave 函数原型

torch.repeat_interleave(input, repeats, dim=None) → Tensor

官方介绍

详解:
      函数的功能:重复张量的元素,返回一个张量
      输入参数:

       input (类型:torch.Tensor):输入张量
      repeats(类型:int或torch.Tensor):重复的次数。repeats参数会被广播来适应输入张量的维度,可
       dim(类型:int):在哪个维度进行重复。默认情况下,将把输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。
     返回值:

    返回一个张量,这个张量与输入张量在dim维的shape不同,其他维的shape与输入的shape一致
 

个人理解

torch.repeat_interleave(self: Tensor, repeats: int, dim: Optional[int]=None)

参数说明:

self: 传入的数据为tensor

repeats: 复制的份数

dim: 要复制的维度,可设定为0/1/2…

代码演示

>>> import matplotlib.pyplot as plt
>>> import torch.nn as nn
>>> data1 = torch.rand([2,1,3,3])
>>> print("data1_shape: ",data1.shape)
data1_shape:  torch.Size([2, 1, 3, 3])
>>> print("data1: ",data1)
data1:  tensor([[[[0.1947, 0.5847, 0.3518],
          [0.3061, 0.6423, 0.2480],
          [0.1966, 0.1335, 0.7225]]],


        [[[0.1014, 0.0204, 0.1819],
          [0.9579, 0.1556, 0.7318],
          [0.7449, 0.8234, 0.1277]]]])
>>> data2 = torch.repeat_interleave(data1,repeats=3,dim=1)
>>> print("data2_shape: ", data2.shape)
data2_shape:  torch.Size([2, 3, 3, 3])
>>> print("data2: ",data2)
data2:  tensor([[[[0.1947, 0.5847, 0.3518],
          [0.3061, 0.6423, 0.2480],
          [0.1966, 0.1335, 0.7225]],

         [[0.1947, 0.5847, 0.3518],
          [0.3061, 0.6423, 0.2480],
          [0.1966, 0.1335, 0.7225]],

         [[0.1947, 0.5847, 0.3518],
          [0.3061, 0.6423, 0.2480],
          [0.1966, 0.1335, 0.7225]]],


        [[[0.1014, 0.0204, 0.1819],
          [0.9579, 0.1556, 0.7318],
          [0.7449, 0.8234, 0.1277]],

         [[0.1014, 0.0204, 0.1819],
          [0.9579, 0.1556, 0.7318],
          [0.7449, 0.8234, 0.1277]],

         [[0.1014, 0.0204, 0.1819],
          [0.9579, 0.1556, 0.7318],
          [0.7449, 0.8234, 0.1277]]]])

猜你喜欢

转载自blog.csdn.net/Vertira/article/details/130983664