repeat_interleave 函数原型
torch.repeat_interleave(input, repeats, dim=None) → Tensor
官方介绍
详解:
函数的功能:重复张量的元素,返回一个张量
输入参数:
input (类型:torch.Tensor):输入张量
repeats(类型:int或torch.Tensor):重复的次数。repeats参数会被广播来适应输入张量的维度,可
dim(类型:int):在哪个维度进行重复。默认情况下,将把输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。
返回值:
返回一个张量,这个张量与输入张量在dim维的shape不同,其他维的shape与输入的shape一致
个人理解
torch.repeat_interleave(self: Tensor, repeats: int, dim: Optional[int]=None)
参数说明:
self: 传入的数据为tensor
repeats: 复制的份数
dim: 要复制的维度,可设定为0/1/2…
代码演示
>>> import matplotlib.pyplot as plt
>>> import torch.nn as nn
>>> data1 = torch.rand([2,1,3,3])
>>> print("data1_shape: ",data1.shape)
data1_shape: torch.Size([2, 1, 3, 3])
>>> print("data1: ",data1)
data1: tensor([[[[0.1947, 0.5847, 0.3518],
[0.3061, 0.6423, 0.2480],
[0.1966, 0.1335, 0.7225]]],
[[[0.1014, 0.0204, 0.1819],
[0.9579, 0.1556, 0.7318],
[0.7449, 0.8234, 0.1277]]]])
>>> data2 = torch.repeat_interleave(data1,repeats=3,dim=1)
>>> print("data2_shape: ", data2.shape)
data2_shape: torch.Size([2, 3, 3, 3])
>>> print("data2: ",data2)
data2: tensor([[[[0.1947, 0.5847, 0.3518],
[0.3061, 0.6423, 0.2480],
[0.1966, 0.1335, 0.7225]],
[[0.1947, 0.5847, 0.3518],
[0.3061, 0.6423, 0.2480],
[0.1966, 0.1335, 0.7225]],
[[0.1947, 0.5847, 0.3518],
[0.3061, 0.6423, 0.2480],
[0.1966, 0.1335, 0.7225]]],
[[[0.1014, 0.0204, 0.1819],
[0.9579, 0.1556, 0.7318],
[0.7449, 0.8234, 0.1277]],
[[0.1014, 0.0204, 0.1819],
[0.9579, 0.1556, 0.7318],
[0.7449, 0.8234, 0.1277]],
[[0.1014, 0.0204, 0.1819],
[0.9579, 0.1556, 0.7318],
[0.7449, 0.8234, 0.1277]]]])