Tensorflow常用接口

1. tf.get_variable

官网链接:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/get_variable

作用: 获取具有这些参数的现有变量或创建一个新变量
说明: tensorflow1.0中的接口,在tensorflow2.0中用tf.compat.v1.get_variable兼容本接口

# 官方接口
tf.get_variable(
    name, shape=None, dtype=None, initializer=None, regularizer=None,
    trainable=None, collections=None, caching_device=None, partitioner=None,
    validate_shape=True, use_resource=None, custom_getter=None, constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

 此函数在名称前面加上当前变量作用域,并执行检查

代码示例:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v
v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
print(v1 == v2) 
# True
print(v2.name)
# "foo/v:0"

2. tf.variable_scope

官网链接:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/variable_scope

作用:上下文管理器,用于定义创建变量的操作。
说明: tensorflow1.0中的接口,在tensorflow2.0中用tf.compat.v1.variable_scope兼容本接口

# 官方接口
tf.variable_scope(
    name_or_scope, 
    default_name=None,
    values=None, 
    initializer=None,
    regularizer=None, 
    caching_device=None, 
    partitioner=None, 
    custom_getter=None,
    reuse=None, 
    dtype=None, 
    use_resource=None, 
    constraint=None,
    auxiliary_name_scope=True
)

 此上下文管理器验证值是否来自同一个图,确保图是默认图,并推送名称范围和变量范围。
 重点关注第一个变量name_or_scope即可,name_or_scope为String或VariableScope类型,用户可以定义或打开上下文环境范围

代码示例:

with tf.variable_scope("foo"):
    with tf.variable_scope("bar"):
        v = tf.get_variable("v", [1])
        print(v.name)  # "foo/bar/v:0"

3. tf.global_variables

官方链接:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/global_variables

概念: 全局变量是在环境中共享的变量。Variable()构造函数或get_Variable()会自动将新变量添加到图形集合中
作用: 返回全局变量
说明: tensorflow1.0中的接口,在tensorflow2.0中用tf.compat.v1.global_variables兼容本接口

# 官方接口
tf.global_variables(scope=None)

若scope为None,则返回所有全局变量Variable对象列表;
若scope不为None,则返回指定作用域内的变量Variable对象列表

代码示例:

# 创建两个变量v1和v2
v1 = tf.Variable(tf.constant(0.0, shape=[1], dtype=tf.float32), name='v')
with tf.variable_scope("foo"):
    v2 = tf.get_variable("v2", [1])
# 通过tf.global_variables接口获取
print(tf.global_variables())     
# [<tf.Variable 'v:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'foo/v2:0' shape=(1,) dtype=float32_ref>]
print(tf.global_variables("foo"))  
# [<tf.Variable 'foo/v2:0' shape=(1,) dtype=float32_ref>]

4. tf.global_variables_initializer

作用:给全局变量初始化值
说明:tensorflow的机制是定义variable的时候仅定义,并没有实际执行初始化。等程序确定初始化了,即tf.global_variable_initializer了才真正给全局变量赋值

代码示例:

# 创建两个变量a和b
# tf中建立的变量是没有初始化的,现在还不是一个tensor量,而是一个Variable变量类型
a = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name="a")
b = tf.Variable(tf.constant(1), name="b")

# 1. 不执行初始化,直接打印变量 -- 报错
with tf.Session() as sess:
    print("a.name: {}   value: {}".format(a.name, sess.run(a)))
    print("b.name: {}   value: {}".format(b.name, sess.run(b)))
# 报错(没有初始化值):FailedPreconditionError: Attempting to use uninitialized value a[[{
    
    {node _retval_a_0_0}}]]

# 2. 执行初始化后,打印变量
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("a.name: {}   value: {}".format(a.name, sess.run(a))) # a.name: a:0   value: [0.69489884]
    print("b.name: {}   value: {}".format(b.name, sess.run(b))) # b.name: b:0   value: 1

5. tf.train.Saver

官方链接:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/Saver#args

作用:保存和恢复变量,常用于模型保存

# 官方接口
tf.train.Saver(
    var_list=None, reshape=False, sharded=False, max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False,
    saver_def=None, builder=None, defer_build=False, allow_empty=False,
    write_version=tf.train.SaverDef.V2, pad_step_number=False,
    save_relative_paths=False, filename=None
)
  • var_list: Variable/SaveableObject的列表,或将名称映射到SaveableObjects的字典。如果为None,则默认为所有可保存对象的列表
  • max_to_keep:要保留的最近检查点的最大数目。默认值为5。

(1) save方法

作用: 变量保存

# 官方接口
save(
    sess, save_path, global_step=None, latest_filename=None,
    meta_graph_suffix='meta', write_meta_graph=True, write_state=True,
    strip_default_attrs=False, save_debug_info=False
)
  • sess:用于保存变量的session
  • save_path: 字符串类型,模型保存路径
  • global_step: 如果不为None,则加到模型保存路径作为模型前缀名称

(2) restore

作用: 恢复之前保存的变量

# 官方接口
restore(
    sess, save_path
)

 训练一个简单的模型,拟合y = 2*x + 1, 展示模型保存和加载,代码示例:

import numpy as np
import tensorflow as tf

# 1. 准备一组训练数据x 和 y
x = np.random.rand(100).astype(np.float32)
y = x * 2 + 1

# 2. 搭建模型  y = w*x +b
# 创建两个变量, 用于拟合
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y_pre = w * x + b
# 构建损失函数
loss = tf.reduce_mean(tf.square(y_pre - y))
# 创建优化器--梯度下降法
optimizer = tf.train.GradientDescentOptimizer(0.5)
# 实例化模型保存对象
saver=tf.train.Saver(max_to_keep=10)

# 3. 模型训练
with tf.Session() as sess:
    # 全局变量初始化
    sess.run(tf.global_variables_initializer())
    for step in range(61):
        sess.run(optimizer.minimize(loss))
        if step % 20 == 0:
            print("step: {}  w: {}  b: {}".format(step, sess.run(w), sess.run(b)))
            # 模型保存
            saver.save(sess=sess, save_path=r'C:\Users\ASUS\Desktop\model\my-model', global_step=step)
'''
打印输出:
step: 0  w: [1.0791051]  b: [2.2803402]
step: 20  w: [1.7042794]  b: [1.1717249]
step: 40  w: [1.92974]  b: [1.0408]
step: 60  w: [1.983307]  b: [1.0096936]
'''

文件保存情况如下:
在这里插入图片描述
加载指定模型,打印对应参数:

# 模型加载,打印加载的变量参数
with tf.Session() as sess:
    saver.restore(sess=sess, save_path=r'C:\Users\ASUS\Desktop\model\my-model-20')
    print("w: {}  b: {}".format(sess.run(w), sess.run(b)))
# w: [1.7042794]  b: [1.1717249]

6. tf.train.exponential_decay

官网链接:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/exponential_decay

作用:将指数衰减应用于学习率

# 官方接口
tf.train.exponential_decay(
    learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None
)
'''
learning_rate:  初始学习率
global_step:    当前迭代次数
decay_steps:    衰减次数, 可理解为当global_step==decay_steps时,learning_rate衰减为learning_rate * decay_rate
decay_rate:     学习率衰减系数,通常介于0-1之间
staircase=False: 若staircase为True,则global_step/decay_steps始终取整数,每迭代decay_steps衰减一次,变化曲线是阶梯状。
name=None:
'''

代码示例:

import numpy as np
import tensorflow as tf

# 1.准备一组训练数据 y = w*x + b
x = np.random.rand(100).astype(np.float32)
y = x * 2 + 1

# 2. 搭建模型  y = w*x +b
# 创建两个变量, 用于拟合
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y_pre = w * x + b
# 构建损失函数
loss = tf.reduce_mean(tf.square(y_pre - y))
# 创建优化器--梯度下降法
global_steps = tf.train.create_global_step()
starter_learning_rate = 0.5
decay_steps = 10 # 衰减次数 (即每迭代10次,学习率衰减0.9)
decay_rate = 0.9
learning_rate = tf.train.exponential_decay(learning_rate=starter_learning_rate, global_step=global_steps, decay_steps=10, decay_rate=decay_rate, staircase=False, name='learning_rate')
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 实例化模型保存对象
saver=tf.train.Saver(max_to_keep=10)

# 3. 模型训练
with tf.Session() as sess:
    # 全局变量初始化
    sess.run(tf.global_variables_initializer())
    # 迭代60次
    for step in range(1, 61):
        sess.run(optimizer.minimize(loss, global_step=global_steps))
        if step % 20 == 0:
            print("step: {}  global_step: {}  w: {}  b: {}".format(step, sess.run(global_steps), sess.run(w), sess.run(b)))
            # 模型保存
            saver.save(sess=sess,save_path=r'C:\Users\ASUS\Desktop\model\my-model', global_step=global_steps

分别设置staircase=True和False时,训练过程中学习率衰减情况如下:
在这里插入图片描述

7. tf.placeholder

作用:预先为tensor插入占位符,在后续执行运算时将被填充数据

  • 关键点:如果直接执行运算将会报错。其值必须先使用feed_dict填充后再送到Session.run()、Tensor.eval()或Operation.run()中。
# 官方接口
tf.placeholder(
    dtype, shape=None, name=None
)
'''
dtype: 填充到tensor中的数据类型
shape: 填充到tensor的数据维度,若shape=None,可以填充任意维度的数据
name:  名称
'''

代码示例:

x = tf..placeholder(tf.float32, shape=(3, 3)
y = tf.matmul(x, x) # 矩阵乘法
with tf.compat.v1.Session() as sess:
  # print(sess.run(y))  # ERROR: will fail because x was not fed.
  rand_array = np.random.rand(3, 3)
  print(sess.run(y, feed_dict={
    
    x: rand_array}))  # Will succeed.
'''
[[0.86674637 0.97794515 1.1810058 ]
 [0.8465452  0.99383634 1.4519857 ]
 [0.1171783  0.14964394 0.45982745]]
'''

猜你喜欢

转载自blog.csdn.net/yewumeng123/article/details/131342886
今日推荐