tf.keras.layers.Flatten() example

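import tensorflow as tf

# Create the layer; Flatten has no trainable weights
flatten = tf.keras.layers.Flatten()
print(flatten)
# <tensorflow.python.keras.layers.core.Flatten object at 0x000001E3E5AA94C0>

# A random 4-D batch: 32 samples, each of shape (77, 10, 8)
inputs = tf.random.normal([32, 77, 10, 8])
print(inputs.shape)
# (32, 77, 10, 8)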

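Every dimension except the first (the number of samples) is collapsed into a single dimension, so each sample becomes a vector of 77 × 10 × 8 = 6160 values.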

output = flatten(inputs)
print(output.shape)
# (32, 6160)
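
The shape arithmetic above can be checked without TensorFlow. The following is only a minimal sketch: `flattened_shape` is a hypothetical helper written for illustration, not part of Keras. It keeps the first axis and multiplies the remaining ones, which should match what the layer reports for the input used here.

```python
from math import prod


def flattened_shape(shape):
    """Hypothetical helper: keep axis 0 and multiply all remaining axes."""
    batch, *rest = shape
    # math.prod of an empty sequence is 1, so (batch,) becomes (batch, 1)
    return (batch, prod(rest))


print(flattened_shape((32, 77, 10, 8)))  # (32, 6160)
```

On the tensor itself, the same result can usually be obtained with `tf.reshape(inputs, [inputs.shape[0], -1])`, assuming the batch size is known statically.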


Reposted from blog.csdn.net/weixin_44493841/article/details/121419353