tensorflow 2.1 断点续训模型存取

一、介绍

        模型训练往往需要很长的时间,而且我们往往想要寻找最佳的模型参数,因此需要将最优的模型参数能够保存下来,以供继续训练,方便测试的同时也便于寻找最优解。

        断点续训的意思就是可以从保存的模型开始集训训练模型。

        同时也可以直接加载训练好的模型去做predict 预测。

二、如何存取

        先贴出代码:

import tensorflow as tf
import os

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
# 读取模型
# 定义路径:checkpoint_save_path = "*********.ckpt" ,该.ckpt文件就是保存的模型的文件
# 由于如若已有保存的模型就一定会有index目录文件,于是用以下的代码检测是否已有存在的模型
# if os.path.exists(checkpoint_save_path + '.index'):
#     print('-------------load the model-----------------')
# 若存在模型,则直接用函数 .load_weights()去加载模型
#     model.load_weights(checkpoint_save_path)

# 保存模型
# 这里的.callbacks.ModelCheckpoint是keras的callback的一种功能,将会在另一篇博客介绍callbacks的详细用法
# callback =tf.keras.callbacks.ModelCheckpoint(
#   filepath=文件路径,
#   save_weights_only=是否仅保存参数,
#   save_best_only=是否仅保存最优模型)
# 同时需要在fit中加入回调选项callback,并返回给history,(这里的callback和上面定义的callback是同一个),即: history = model.fit(...,callbacks=[callback])

 在每个training/epoch/batch结束时,如果我们想执行某些任务,例如模型缓存、输出日志、计算当前的acurracy等等,Keras中的callback就派上用场了。

# callbacks可以做到以下功能:
# ModelCheckpoint模型断点续训:保存当前模型的所有权重
# EarlyStopping提早结束:当模型的损失不再下降的时候就终止训练,当然,会保存最优的模型。
# LearningRateSchedule动态调整训练时的参数,比如优化的学习率

 代码执行结果:

该模型参数已经被保存到这里:

扫描二维码关注公众号,回复: 14824531 查看本文章

我们可以很方便的从模型里读取参数使用了

三、查看保存的参数

        有时候你可能想要直观的查看可训练参数,下面将介绍如何查看参数并存入文本

print(model.trainable_variables) 

该语句可以打印出参数的值,但是会有以下问题:

 打印出来的参数被省略号代替了很大部分,我们想要看到全部的参数怎么办呢?

设置最大显示数目即可:在开头添加下面代码

import numpy as np
np.set_printoptions(threshold=np.inf)

 

 现在就完全ok了,可以看到全部参数

然后保存参数到txt文件:

file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

 当前目录下查看txt:

已经保存成功了

完整代码上:

import tensorflow as tf
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

猜你喜欢

转载自blog.csdn.net/qq_46006468/article/details/119645336
2.1