一、介绍
模型训练往往需要很长的时间,而且我们往往想要寻找最佳的模型参数,因此需要将最优的模型参数能够保存下来,以供继续训练,方便测试的同时也便于寻找最优解。
断点续训的意思就是可以从保存的模型开始集训训练模型。
同时也可以直接加载训练好的模型去做predict 预测。
二、如何存取
先贴出代码:
import tensorflow as tf import os mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy']) checkpoint_save_path = "./checkpoint/mnist.ckpt" if os.path.exists(checkpoint_save_path + '.index'): print('-------------load the model-----------------') model.load_weights(checkpoint_save_path) cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True) history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) model.summary()
# 读取模型 # 定义路径:checkpoint_save_path = "*********.ckpt" ,该.ckpt文件就是保存的模型的文件 # 由于如若已有保存的模型就一定会有index目录文件,于是用以下的代码检测是否已有存在的模型 # if os.path.exists(checkpoint_save_path + '.index'): # print('-------------load the model-----------------') # 若存在模型,则直接用函数 .load_weights()去加载模型 # model.load_weights(checkpoint_save_path) # 保存模型 # 这里的.callbacks.ModelCheckpoint是keras的callback的一种功能,将会在另一篇博客介绍callbacks的详细用法 # callback =tf.keras.callbacks.ModelCheckpoint( # filepath=文件路径, # save_weights_only=是否仅保存参数, # save_best_only=是否仅保存最优模型) # 同时需要在fit中加入回调选项callback,并返回给history,(这里的callback和上面定义的callback是同一个),即: history = model.fit(...,callbacks=[callback])
在每个training/epoch/batch结束时,如果我们想执行某些任务,例如模型缓存、输出日志、计算当前的acurracy等等,Keras中的callback就派上用场了。
# callbacks可以做到以下功能: # ModelCheckpoint模型断点续训:保存当前模型的所有权重 # EarlyStopping提早结束:当模型的损失不再下降的时候就终止训练,当然,会保存最优的模型。 # LearningRateSchedule动态调整训练时的参数,比如优化的学习率
代码执行结果:
该模型参数已经被保存到这里:
扫描二维码关注公众号,回复:
14824531 查看本文章
我们可以很方便的从模型里读取参数使用了
三、查看保存的参数
有时候你可能想要直观的查看可训练参数,下面将介绍如何查看参数并存入文本
print(model.trainable_variables)
该语句可以打印出参数的值,但是会有以下问题:
打印出来的参数被省略号代替了很大部分,我们想要看到全部的参数怎么办呢?
设置最大显示数目即可:在开头添加下面代码
import numpy as np np.set_printoptions(threshold=np.inf)
现在就完全ok了,可以看到全部参数
然后保存参数到txt文件:
file = open('./weights.txt', 'w') for v in model.trainable_variables: file.write(str(v.name) + '\n') file.write(str(v.shape) + '\n') file.write(str(v.numpy()) + '\n') file.close()
当前目录下查看txt:
已经保存成功了
完整代码上:
import tensorflow as tf
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()