在keras中自定义 Callbacks 的一个模板

# -*- coding: utf-8 -*-
from keras import layers
from keras.callbacks import Callback
from keras.models import Sequential
from sklearn import datasets


class MyCallback(Callback):
    """Keras (version=2.3.1) Callback 编写模板"""

    def __init__(self):
        super().__init__()

    def on_train_begin(self, logs: dict):
        """在整个训练开始时会调用次函数

            Parameters:
            ----------
                logs: dict, 该参数在当前版本默认为 None, 主要是为未来的 keras 版本的新行为预留位置
        """
        print('On train begin', logs)
        return

    def on_train_end(self, logs: dict):
        """在整个训练结束时调用次函数

            Parameters:
            ----------
                logs: dict, 该参数在当前版本默认为 None, 主要是为未来的 keras 版本的新行为预留位置
        """
        print('On train end', logs)
        return

    def on_epoch_begin(self, epoch, logs: dict):
        """在每个 epoch 开始的时候调用此函数

            Parameters:
            ----------
                epoch: int, 当前为第几个 epoch, 从 1 开始
                logs: dict, 为空
        """
        print('On epoch begin', epoch, logs)
        return

    def on_epoch_end(self, epoch, logs: dict):
        """在每个 epoch 结束的时候调用此函数

            Parameters:
            ----------
                epoch: int, 当前为第几个 epoch, 从 1 开始
                logs: dict, 包含了当前 epoch 的一些信息,主要的 key 有:
                    - accuracy
                    - loss
                    - val-accuracy(仅在 fit 中开启 validation 时才有)
                    - val-loss(仅在 fit 中开启 validation 时才有)
        """
        print('On epoch end', epoch, logs)
        pass

    def on_batch_begin(self, batch, logs: dict):
        """在每个 batch 开始的时候调用此函数

            Parameters:
            ----------
                batch: int, 当前为第几个 batch, 从 1 开始
                logs: dict, 包含了当前 batch 的一些信息,主要的 key 有:
                    - batch: 同参数 batch
                    - size: batch 的大小
        """
        print('On batch begin', batch, logs)
        return

    def on_batch_end(self, batch, logs: dict):
        """在每个 batch 结束的时候调用此函数

            Parameters:
            ----------
                batch: int, 当前为第几个 batch, 从 1 开始
                logs: dict, 包含了当前 batch 的一些信息,主要的 key 有:
                    - batch: 同参数 batch
                    - size: batch 的大小
                    - loss
                    - accuracy(仅当启用了 acc 监视)
        """
        print('On batch end', batch, logs)
        return


def main():
    X, y = datasets.load_breast_cancer(return_X_y=True)
    model = Sequential([
        layers.Dense(16, input_dim=30, activation='elu'),
        layers.Dense(8, activation='elu'),
        layers.Dense(1, activation='sigmoid')
    ])
    model.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy']
    )
    print('Start'.center(75, '='))
    model.fit(X, y, batch_size=200, epochs=1, verbose=2, callbacks=[MyCallback()])
    loss, acc = model.evaluate(X, y)
    print(f'{loss:.3f}, {acc:.2%}')


if __name__ == '__main__':
    main()

结果如下:

On train begin None
Epoch 1/1
On epoch begin 0 {}
On batch begin 0 {'batch': 0, 'size': 200}
On batch end 0 {'batch': 0, 'size': 200, 'loss': 20.647158, 'accuracy': 0.675}
On batch begin 1 {'batch': 1, 'size': 200}
On batch end 1 {'batch': 1, 'size': 200, 'loss': 25.70241, 'accuracy': 0.6425}
On batch begin 2 {'batch': 2, 'size': 169}
On batch end 2 {'batch': 2, 'size': 169, 'loss': 24.250557, 'accuracy': 0.6274165}
 - 0s - loss: 23.4943 - accuracy: 0.6274
On epoch end 0 {'loss': 23.494301593157026, 'accuracy': 0.6274165}
On train end None
569/569 [==============================] - 0s 68us/step
21.578, 62.21%

猜你喜欢

转载自blog.csdn.net/frostime/article/details/105083151