tensorflow调用预训练模型

最近在使用tensorflow的预训练模型,把自己的心得记录下来~


Tensorflow读取并输出已保存模型的权重数值

参考链接
https://blog.csdn.net/AManFromEarth/article/details/81057577
https://blog.csdn.net/aiseu001/article/details/79851176

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python import pywrap_tensorflow
#首先,使用tensorflow自带的python打包库读取模型
model_reader = pywrap_tensorflow.NewCheckpointReader(r"saver-test")

#然后,使reader变换成类似于dict形式的数据
var_dict = model_reader.get_variable_to_shape_map()
print(len(var_dict))#输出模型中的变量个数
print(var_dict) #输出模型中的变量名称

#提取模型中的某一个变量名称和值
w1 = model_reader.get_tensor("conv1/W") #提取模型中名为conv1/W的变量(conv1的权重参数)
print(type(w1)) #输出变量w1的类型  <class 'numpy.ndarray'>
print(w1.shape) #输出变量w1的形状  (11, 11, 3, 96)
# print(w1)  #输出变量w1的值

#循环输出模型中所有参数的名称和值
for key in var_dict:
    print("variable name: ", key)
    print(model_reader.get_tensor(key))


#如果想要输出到文件,使用:
with open("output.txt","w+") as f:
#循环打印输出
    for key in var_dict:
        f.write(str(key))
        f.write(str(model_reader.get_tensor(key)))

读取预训练模型,有选择性的加载参数

我是对原网络进行了一些修改,添加了一些新的参数,所以想要导入预训练模型来初始化原来就有的那部分参数,新添加的参数采用随机初始化。代码:

reader = tf.train.NewCheckpointReader("output/saver-test")
restore_dict = dict()
for v in tf.trainable_variables(): #只读取当前网络结构中待训练的变量
    tensor_name = v.name.split(':')[0] # 把变量后面的:0去掉了(conv1/b:0->conv1/b)
    # print(tensor_name) #输出训练变量的名称列表
    if reader.has_tensor(tensor_name): #如果预训练模型中含有我们想要加载的参数,就把它添加到待restore的参数字典中
        print('has tensor', tensor_name)
        restore_dict[tensor_name] = v
saver = tf.train.Saver(restore_dict)#恢复指定的变量字典

with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer()) #对所有参数进行随机初始化
    sess.run(tf.local_variables_initializer())
    saver.restore(sess, "output/saver-test") # restore指定的变量
    # 获取当前网络结构中的conv1/b:0变量,查看其值,b的值就是预训练模型中保存的conv1/b的取值,说明restore成功
    b=tf.get_default_graph().get_tensor_by_name("conv1/b:0") 
    print(sess.run(b))

从提取上一次训练结果继续训练

与第二种情况的区别就是,模型结构都是一样的,这个恢复起来更简单,直接这样操作即可,连参数初始化都不用~~~

saver = tf.train.Saver(max_to_keep=100) #max_to_keep参数是保存ckpt文件的个数,默认是5
with tf.Session(config=config) as sess:
    saver.restore(sess, "output/saver")

分享一个链接:
https://blog.csdn.net/qq_25737169/article/details/78125061
https://stackoverflow.com/questions/52532150/how-to-restore-pretrained-model-to-initialize-parameters

猜你喜欢

转载自blog.csdn.net/aaon22357/article/details/82862692