tflearn加载两个模型,分别根据判断条件使用对应的模型进行预测

最近实现论文的一个代码,记录下自己的采坑过程:

实现目标:训练多个模型,保存下来,测试的时候,加载这些模型,根据条件判断需要使用哪个模型进行预测。

1、网络训练,模型保存

前面的网络设计省略,每个人任务不同,设计的网络也不同,以net表示。我使用的是tflearn(TensorFlow的用于快速开发的高层封装)

model = tflearn.DNN(net,checkpoint_path='./check_point',max_checkpoints=5, tensorboard_verbose=0,tensorboard_dir='log')

model.fit(X, Y, n_epoch=100, validation_set=1/6,
          snapshot_epoch=True, snapshot_step=100,
          show_metric=True, batch_size=batch_size, shuffle=True,
          run_id=dataset_name)
model.save("model/"+dataset_name+".model")

跑两个不同的网络,保存两个模型:com_model和sim_model(每个模型有4个文件:checkpoint,index,data,meta)

2、加载模型(采坑重点)

刚开始直接写了两段加载模型的代码,以为一切很顺利,

c_net = com_net(_input)
c_model = tflearn.DNN(c_net)
c_model.load(complex_model_path)
s_net = sim_net(_input2)
s_model = tflearn.DNN(s_net)
s_model.load(simple_model_path)
#根据判断条件,使用对应模型进行预测
if c>threshod:
    block_label = c_model.predict_label(block)  
else :
    block_label = s_model.predict_label(block)  

但是很可惜,报错了,各种key xxx not found in checkpoint的错误,然后查了一下网上的解决办法,看到一篇博客:

https://blog.csdn.net/xfgryujk/article/details/79597289里面提到了tflearn加载模型碰到这种问题怎么解决,写的很简单,

主要是使用tf.reset_default_graph(),然后加载模型时要传参数weights_only=True,代码如下:

加载第一个模型没有差别

c_net = com_net(_input)
c_model = tflearn.DNN(c_net)
c_model.load(complex_model_path)
加载第二个模型的时候就要修改了:
    ##很关键,不能去掉
     tf.reset_default_graph()
    s_net = sim_net(_input2)
    s_model = tflearn.DNN(s_net)

    s_model.load(simple_model_path,weights_only=True)

然后就可以顺利加载这两个模型了。后面的预测还是使用前面的代码。TensorFlow学的不是很细,自己思考了下,这个错误应该是两个模型不能同时加载到一个图中的原因吧。后面实验了一些上面那个博客所说的第二种方法,主要是给出了一个官方的例子,参考了下。代码如下:

with tf.variable_scope("scope1") as scope1:
   _input = tflearn.input_data(shape=[None, image_size, image_size, color_channel])
   c_net = com_net(_input)
   c_model = tflearn.DNN(c_net)
   c_model.load("./com_model/s_uniward_high.model",scope_for_restore="scope1",weights_only=True)

with tf .variable_scope("scope2") as scope2:
   _input2 = tflearn.input_data(shape=[None, image_size, image_size, color_channel])
   s_net = sim_net(_input2)
   s_model = tflearn.DNN(s_net)
   s_model.load("./sim_model/s_uniward_low.model",scope_for_restore="scope2",weights_only=True)

试着运行了下,结果很让人不解,两个模型加载没有问题,但是在预测的语句出错了:

block_label = c_model.predict_label(block)  

其中这个block是待预测的图像,我按照模型的输入维度,将其reshape成了(1,256,256,1),但是预测时候出现了维度不匹配的问题:


这个就很纳闷了,用单个模型测试完全没问题,说明不是维度出问题,后面发现,在加载第二个模型的时候,还是那句tf.reset_default_graph()不能少,加上去就没问题了。

3、总结

tensorflow中的操作和张量都是在图中完成的,因此需要注意图的使用,这里使用的tf.reset_default_graph()就是将当前默认图重置。






猜你喜欢

转载自blog.csdn.net/sinat_33486980/article/details/80545755