tf.keras.layers.Reshape 示例

import tensorflow as tf

除第一维外的维度改变

reshape = tf.keras.layers.Reshape(
    (5, 16),  # 输出的维度,输出维度乘积 = 输入维度的乘积
)
print(reshape)
<tensorflow.python.keras.layers.core.Reshape object at 0x0000020D524E59D0>

input的维度必须是大于1维,因为第1维是样本个数

inputs = tf.random.normal([
    32,  # 第1个维度是样本数
    10,
    8,
])
print(inputs.shape)
(32, 10, 8)

output改变除第一维的其它维度

output = reshape(inputs)
print(output.shape)
(32, 5, 16)

猜你喜欢

转载自blog.csdn.net/weixin_44493841/article/details/121419532