import tensorflow as tf
# LSTM input must be 3-D: (batch, timesteps, features).
# A random 3-D batch. Per the Keras LSTM docs the axes are
# (batch, timesteps, features): 32 sequences, 10 steps, 8 features per step.
inputs = tf.random.normal([32, 10, 8])
print(inputs.shape)
# -> (32, 10, 8)
# Default LSTM with 4 units: only the output of the last timestep is returned.
lstm = tf.keras.layers.LSTM(4)
print(lstm)
# -> <...LSTM object at 0x...>  (module path and address vary by TF version / run)
output = lstm(inputs)
print(output.shape)
# -> (32, 4): (batch, units)
# return_sequences=True: the layer emits its output at every timestep
# instead of only the last one, so the timestep axis is kept.
lstm = tf.keras.layers.LSTM(
    4,
    return_sequences=True,
)
output = lstm(inputs)
print(output.shape)
# -> (32, 10, 4): (batch, timesteps, units)
# return_state=True additionally returns the final states, so the call yields
# three tensors: the output (here the whole sequence, because
# return_sequences=True), the final hidden state, and the final cell state.
# Variable names follow the example in the Keras LSTM documentation.
lstm = tf.keras.layers.LSTM(
    4,
    return_sequences=True,
    return_state=True,
)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
print(whole_seq_output.shape)
# -> (32, 10, 4): output at every timestep, (batch, timesteps, units)
print(final_memory_state.shape)
# -> (32, 4): final hidden state h, (batch, units)
print(final_carry_state.shape)
# -> (32, 4): final cell (carry) state c, (batch, units)