TensorFlow: saving and loading models, and adding checkpoints

After a model has been trained, it should be saved so that it can be reused later.

Saving a model with save():

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(...)
    saver.save(sess, 'savePath/fileName')

Loading a model with restore():

saver = tf.train.Saver()
with tf.Session() as sess1:
    saver.restore(sess1, 'savePath/fileName')
    sess1.run(...)

When saving, you can also specify an explicit mapping between checkpoint keys and variables; a restore-side sketch follows the list below:

1) saver = tf.train.Saver({key1: value1, key2: value2})
   e.g.: saver = tf.train.Saver({'weights': w, 'biases': b})
2) saver = tf.train.Saver([w, b])  # pass the variables in a list
3) saver = tf.train.Saver({v.op.name: v for v in [w, b]})  # use each op's name as the key
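
For example, the dict form makes it possible to restore a checkpoint into variables whose names differ from the ones that were saved. The following is a minimal sketch, assuming a checkpoint written with the keys 'weights' and 'biases' (as in the full example below); the names w_new, b_new and the path are made up for illustration:

import tensorflow as tf

# New variables whose graph names differ from the names stored in the checkpoint
w_new = tf.Variable(tf.zeros([1]), name='w_new')
b_new = tf.Variable(tf.zeros([1]), name='b_new')
# Map the checkpoint keys ('weights', 'biases') onto the new variables
restore_saver = tf.train.Saver({'weights': w_new, 'biases': b_new})
with tf.Session() as sess:
    restore_saver.restore(sess, 'logs/linearModel.cpkt')  # hypothetical path
    print(sess.run([w_new, b_new]))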

Printing the contents of a saved model:

print_tensors_in_checkpoint_file(save_dir + 'linearModel.cpkt', None, True)

Below is a complete example of saving and loading a model:

import numpy as np
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# Build the training data
train_x = np.linspace(-1, 1, 100)
# y = 2 * x + b
train_y = 2. * train_x + np.random.randn(*train_x.shape) * 0.3

# Create the model
# Placeholders
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
# Model parameters
weights = tf.Variable(tf.random_normal([1]), name='weights')
biases = tf.Variable(tf.zeros([1]), name='biases')
z = tf.multiply(X, weights) + biases

# Loss function
loss = tf.reduce_mean(tf.square(Y - z))
# Learning rate
learning_rate = 0.01
# Optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Training op: minimize the loss
train = optimizer.minimize(loss)

# Initialize all variables
init = tf.global_variables_initializer()

# Number of training epochs
training_epochs = 20
# Display intermediate values every two epochs
display_step = 2

# Store batch indices and loss values (not used further in this example)
plot_data = {'batchsize': [], 'loss': []}


# Create the Saver object
saver = tf.train.Saver()
save_dir = 'logs/'  # directory where the model will be saved
# Start the session
with tf.Session() as sess:
    # Initialize global variables
    sess.run(init)

    # Feed data into the model
    for epoch in range(training_epochs):
        for (x, y) in zip(train_x, train_y):
            feed_dict = {X: x, Y: y}
            sess.run(train, feed_dict=feed_dict)

        # Display training progress
        if epoch % display_step == 0:
            loss_ = sess.run(loss, feed_dict={X: train_x, Y: train_y})
            print('epoch:', epoch + 1, 'loss = ', loss_, 'weights=',
                  sess.run(weights), 'biases=', sess.run(biases))

    print('Finished...')
    # Save the model (the directory is created automatically if it does not exist)
    saver.save(sess, save_dir + 'linearModel.cpkt')
    print('loss=', sess.run(loss, feed_dict={X: train_x, Y: train_y}), 'weights=',
          sess.run(weights), 'biases=', sess.run(biases))


# Use the saved model
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir + 'linearModel.cpkt')
    print('Results after loading the model: ')
    print('x=0.2, z=', sess_2.run(z, feed_dict={X: 0.2}))


# Print the contents of the checkpoint file
print_tensors_in_checkpoint_file(save_dir + 'linearModel.cpkt', None, True)
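
Besides print_tensors_in_checkpoint_file, the stored values can also be read back programmatically. Below is a minimal sketch, assuming the checkpoint written above, using tf.train.NewCheckpointReader:

# Read variable names, shapes and values directly from the checkpoint (no Session needed)
reader = tf.train.NewCheckpointReader(save_dir + 'linearModel.cpkt')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape, reader.get_tensor(name))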

Adding checkpoints

Training can be interrupted at any time, so it is a good idea to save checkpoints along the way.

saver = tf.train.Saver(max_to_keep=1)  # create the saver
with tf.Session() as sess1:
    sess1.run(...)
    saver.save(sess1, 'savePath/fileName', global_step=epoch)

The max_to_keep argument specifies how many checkpoint files are kept at most; older ones are deleted automatically.

Loading a checkpoint

load_epoch = 18           # just the suffix of the checkpoint file (the global_step value); change it as needed
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir+'linearModel.cpkt-' + str(load_epoch))
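
If you would rather not hard-code the suffix, the most recent checkpoint in a directory can be located with tf.train.latest_checkpoint. A minimal sketch, assuming checkpoints were saved under save_dir as above (sess_3 and ckpt_path are just illustrative names):

# Find and restore the most recent checkpoint in save_dir
ckpt_path = tf.train.latest_checkpoint(save_dir)
if ckpt_path is not None:
    with tf.Session() as sess_3:
        saver.restore(sess_3, ckpt_path)
        print('Restored from', ckpt_path)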

Another way to add checkpoints: tf.train.MonitoredTrainingSession

tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='logs/ckpt',
                                       save_checkpoint_secs=2) as sess:
    print(sess.run([global_step]))
    while not sess.should_stop():  # loop until the session signals it should stop
        i = sess.run(step)
        print(i)

If save_checkpoint_secs is not set, checkpoints are saved every 10 minutes by default.
A global_step must be defined for this approach, otherwise an error is raised.
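
As written, the while loop above never ends on its own, because should_stop() only becomes true when a hook requests it. A minimal sketch of a bounded run, assuming we want to stop after 100 steps and using the standard tf.train.StopAtStepHook:

tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='logs/ckpt',
                                       save_checkpoint_secs=2,
                                       hooks=[tf.train.StopAtStepHook(last_step=100)]) as sess:
    while not sess.should_stop():  # should_stop() becomes true once global_step reaches 100
        print(sess.run(step))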

Below is a complete example of adding and loading checkpoints:

import numpy as np
import tensorflow as tf

# Build the training data
train_x = np.linspace(-1, 1, 100)
# y = 2 * x + b
train_y = 2. * train_x + np.random.randn(*train_x.shape) * 0.3

# Create the model
# Placeholders
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
# Model parameters
weights = tf.Variable(tf.random_normal([1]), name='weights')
biases = tf.Variable(tf.zeros([1]), name='biases')
z = tf.multiply(X, weights) + biases

# Loss function
loss = tf.reduce_mean(tf.square(Y - z))
# Learning rate
learning_rate = 0.01
# Optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Training op: minimize the loss
train = optimizer.minimize(loss)

# Initialize all variables
init = tf.global_variables_initializer()

# Number of training epochs
training_epochs = 20
# Display intermediate values every two epochs
display_step = 2

# Store batch indices and loss values (not used further in this example)
plot_data = {'batchsize': [], 'loss': []}


# Create the Saver object, keeping at most the 2 most recent checkpoints
saver = tf.train.Saver(max_to_keep=2)
save_dir = 'logs/'  # directory where the model will be saved
# Start the session
with tf.Session() as sess:
    # Initialize global variables
    sess.run(init)

    # Feed data into the model
    for epoch in range(training_epochs):
        for (x, y) in zip(train_x, train_y):
            feed_dict = {X: x, Y: y}
            sess.run(train, feed_dict=feed_dict)

        # Display training progress
        if epoch % display_step == 0:
            loss_ = sess.run(loss, feed_dict={X: train_x, Y: train_y})
            print('epoch:', epoch + 1, 'loss = ', loss_, 'weights=',
                  sess.run(weights), 'biases=', sess.run(biases))
            # Save a checkpoint; global_step appends the epoch number to the file name
            saver.save(sess, save_dir + 'linearModel.cpkt', global_step=epoch)
    print('Finished...')
    # Save the final model (the directory is created automatically if it does not exist)
    saver.save(sess, save_dir + 'linearModel.cpkt')
    print('loss=', sess.run(loss, feed_dict={X: train_x, Y: train_y}), 'weights=',
          sess.run(weights), 'biases=', sess.run(biases))

# Load a checkpoint
load_epoch = 18
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir+'linearModel.cpkt-' + str(load_epoch))
    print('Results after loading the checkpoint: ')
    print('x=0.2, z=', sess_2.run(z, feed_dict={X: 0.2}))

# MonitoredTrainingSession
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='logs/ckpt',
                                       save_checkpoint_secs=2) as sess:
    print(sess.run([global_step]))
    while not sess.should_stop():  # loop until the session signals it should stop
        i = sess.run(step)
        print(i)

Reposted from blog.csdn.net/qq_42413820/article/details/80902737