tf.keras.layers.Flatten

展平层

keras.layers.Flatten(data_format=None)

将输入张量展平,不影响批量大小。

data_format 表示输入张量的维度,默认为[batch, height, width, channel]


 示例

from tensorflow.keras.layers import Flatten
import tensorflow as tf
import numpy as np

# 定义一个展平层
flatten = Flatten()
 
# 生成一个维度为[64, 720, 720, 3]的矩阵
x = np.random.random((64, 720, 720, 3))
 
# 转成tensor类型,第一个维度64表示batch
# numpy中的数据类型和tensorflow中的数据类型完全兼容,所以这一步可以省略
x = tf.convert_to_tensor(x)
print(x.shape) # [64, 720, 720, 3]
 
# 进行展平
y = flatten(x)
print(y.shape) # [64, 1555200]

猜你喜欢

转载自blog.csdn.net/weixin_46566663/article/details/127618235