Keras文档解读——The Model class

Keras原文链接

写在前面

不定期更新解读翻译Keras文档,这是tensorflow2.x版本的keras接口文档,注意一下。

1. Model class

Model类将许多layer组合在一起形成一个对象,一次训练或者预测特征数据

tf.keras.Model(inputs,outputs,name)

参数:

  • inputs:模型的输入。 一个keras.Input对象或者以list形式的多个keras.Input对象。
  • outputs:模型的输出。同上,也存在多个输出
  • name:字符串类型,为模型的名字(id)

2. 两种方法去构建一个Model

2.1 方法一

从一个Input开始,然后连接所有layers的call函数从而明确model的前向传播,最后用inputs和outputs创建模型

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

2.2 方法二

写一个Model的子类,如果这样做,你需要在自己定义的子类中的__ init __函数中声明layers,在call函数中声明前向传播。

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

model = MyModel()

如果你选择这种方法,你可以选择性的有一个“training”参数在你的call函数中,通过这样,你可以表明不同的意图,例如训练和预测

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, inputs, training=False):
    x = self.dense1(inputs)
    if training:
      x = self.dropout(x, training=training)
    return self.dense2(x)

model = MyModel()

一旦模型被创建,你可以通过model.compile()去配置模型的损失函数和评价指标,通过model.fit去训练模型,通过model.predict去预测

3. summary方法

Model.summary(line_length=None, positions=None, print_fn=None)

参数(这些参数一般默认就可以)

  • line_length:打印的总长度(它是为了使用不同平台窗口的大小的)
  • position:打印的位置,选择默认即可([.33, .55, .67, 1.].)
  • print_fn:打印自定义的函数,这个不用管‘
    可能会触发的错误:如果在模型没有build前调用summary函数会发生错误

模型build有两种方式:
1.调用model.build(input=(shape))
2.延迟创建,即传入数据时自动调用build函数(⭐推荐)

4. get__layer方法

通过layer的name或者index获取模型的某个layer,如果name和index都被声明有限考虑index( Indices are based on order of horizontal graph traversal (bottom-up).)

Model.get_layer(name=None, index=None)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_44441131/article/details/107264765