PyTorch中flatten() 函数的用法

一. 用法

Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。

二. 参数

      1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。

       2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。

三. 实例

 (1). 首先随机定义一个满足正态分布的(2,3,4)的数据x

import torch 

x = torch.randn(2,3,4)
print(x)
x = x.flatten(0)
print(x)

------------------------------------
tensor([[[ 0.1281,  1.6878,  0.2301, -0.0721],
         [ 1.2374, -0.6929,  1.1186,  0.4372],
         [ 0.5122,  1.4653, -0.1673,  0.7258]],

        [[ 0.2772, -1.9994, -1.2284,  0.2764],
         [-0.0451, -0.9195,  0.5749,  0.1942],
         [ 0.8539, -0.0434, -0.7313,  0.0234]]])
tensor([ 0.1281,  1.6878,  0.2301, -0.0721,  1.2374, -0.6929,  1.1186,  0.4372,
         0.5122,  1.4653, -0.1673,  0.7258,  0.2772, -1.9994, -1.2284,  0.2764,
        -0.0451, -0.9195,  0.5749,  0.1942,  0.8539, -0.0434, -0.7313,  0.0234])

此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。

 (2).

import torch 

x = torch.randn(2,3,4)
print(x)
x = x.flatten(1)
print(x)

===========================================
tensor([[[-0.7137, -0.0859, -1.5284,  0.7284],
         [ 0.8425,  0.3606,  1.7639,  0.1848],
         [ 0.4040, -1.6575,  1.9134, -1.0787]],

        [[ 0.6981,  1.3494, -0.5817, -1.1824],
         [-0.4972,  0.4179,  2.1742, -0.2462],
         [ 0.2429, -1.9315, -0.3497,  0.7190]]])
tensor([[-0.7137, -0.0859, -1.5284,  0.7284,  0.8425,  0.3606,  1.7639,  0.1848,
          0.4040, -1.6575,  1.9134, -1.0787],
        [ 0.6981,  1.3494, -0.5817, -1.1824, -0.4972,  0.4179,  2.1742, -0.2462,
          0.2429, -1.9315, -0.3497,  0.7190]])

此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)

注意:start_dimend_dim参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim() 之间。

猜你喜欢

转载自blog.csdn.net/m0_62278731/article/details/134263429