使用TensorFlow 搭建神经网络的六步法

使用TensorFlow API:tf.keras 搭建神经网络

搭建神经网络六步法:

1.导入第三方库:import
2.导入并理解数据,划分训练集与测试集:train test
3.在Sequential()中搭建网络结构。逐层描述每层网络,相当于前向传播。:model=tf.keras.models.Sequential
4.在compile中配置训练方法。即选择哪种优化器,选择哪个损失函数,选择哪种评测指标。model.compile
5.在fit中进行训练。告知训练集和测试集的输入特征和标签。每个betch是多少,要迭代多少次数据集:model.fit
6.用model.summary打印出网络的结构和参数。

函数用法介绍

1.model=tf.keras.models.Sequential

Sequential 函数是一个容器,容器里封装了神经网络的网络结构,描述了在Sequential函数的输入参数从输入层到输出层的网络结构。
如:

拉直层:tf.keras.layers.Flatten()
拉直层可以变换张量的尺寸,把输入特征拉直为一维数组,是不含计算参数的层。

全连接层:tf.keras.layers.Dense( 神经元个数,activation=”激活函数”, kernel_regularizer=”正则化方式”)

其中:
activation(字符串给出)可选 relu、softmax、sigmoid、tanh 等,kernel_regularizer 可选 tf.keras.regularizers.l1()、
tf.keras.regularizers.l2()
卷积层:tf.keras.layers.Conv2D( filter = 卷积核个数, kernel_size = 卷积核尺寸,
strides = 卷积步长,padding = “valid” or “same”)

LSTM 层:tf.keras.layers.LSTM()。

2.Model.compile

Compile 用于配置神经网络的训练方法,告知训练时使用的优化器、损失函数和准确率评测标准。

Model.compile( optimizer = 优化器, loss = 损失函数, metrics = [“准确率”])
(1)optimizer 可以是字符串形式给出的优化器名字,也可以是函数形式,使用函数形式可以设置学习率、动量和超参数。
可选择有:
‘sgd’or tf.optimizers.SGD( lr=学习率,decay=学习率衰减率,momentum=动量参数)

‘adagrad’or tf.keras.optimizers.Adagrad(lr=学习率,decay=学习率衰减率)

‘adadelta’or tf.keras.optimizers.Adadelta(lr=学习率, decay=学习率衰减率)

‘adam’or tf.keras.optimizers.Adam (lr=学习率, decay=学习率衰减率)
(2) Loss 可以是字符串形式给出的损失函数的名字,也可以是函数形式。
可选项包括:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

损失函数常需要经过 softmax 等函数将输出转化为概率分布的形式。from_logits 则用来标注该损失函数是否需要转换为概率的形式,取 False 时表示转化为概率分布,取 True 时表示没有转化为概率分布,直接输出。
(3)Metrics 标注网络评测指标。
可选项包括:
‘accuracy’:y_和 y 都是数值,如 y_=[1] y=[1]。
‘categorical_accuracy’:y_和 y 都是以独热码和概率分布表示。
如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。
‘sparse_ categorical_accuracy’:y_是以数值形式给出,y 是以独热码形式
给出。 如 y_=[1],y=[0.256, 0.695, 0.048]。

3.model.fit()

fit 函数用于执行训练过程。
——model.fit(训练集的输入特征, 训练集的标签,batch_size, epochs, validation_data = (测试集的输入特征,测试集的标签), validataion_split = 从测试集划分多少比例给训练集, validation_freq = 测试的 epoch 间隔次数)

4.model.summary()

summary 函数用于打印网络结构和参数统计.
在这里插入图片描述上图是 model.summary()对鸢尾花分类网络的网络结构和参数统计,对于输入为 4 输出为 3 的全连接网络,共有 15 个参数。

猜你喜欢

转载自blog.csdn.net/waner_jiaki/article/details/109903373