展平层
keras.layers.Flatten(data_format=None)
将输入张量展平,不影响批量大小。
data_format 表示输入张量的维度,默认为[batch, height, width, channel]
示例
from tensorflow.keras.layers import Flatten
import tensorflow as tf
import numpy as np
# 定义一个展平层
flatten = Flatten()
# 生成一个维度为[64, 720, 720, 3]的矩阵
x = np.random.random((64, 720, 720, 3))
# 转成tensor类型,第一个维度64表示batch
# numpy中的数据类型和tensorflow中的数据类型完全兼容,所以这一步可以省略
x = tf.convert_to_tensor(x)
print(x.shape) # [64, 720, 720, 3]
# 进行展平
y = flatten(x)
print(y.shape) # [64, 1555200]