eval_model.py可以参考train_model.py编写,因此差不多
大体思路就是:
1.读入图片并进行预处理2.从保存的风格ckpt文件中恢复模型权重(注意不需要vgg)
3.将图片tensor输入到net中,得到转换后的image并保存
代码及注释:
# coding: utf-8 from __future__ import print_function import tensorflow as tf from preprocessing import preprocessing_factory import reader import model import time import os """ 编码思路: 1.读入图片并进行预处理 2.从保存的风格ckpt文件中恢复模型权重(注意不需要vgg) 3.将图片tensor输入到net中,得到转换后的image并保存 """ ###################### # define the parameter# ###################### tf.app.flags.DEFINE_string('loss_model', 'vgg_16', '损失网络模型名 ') tf.app.flags.DEFINE_string('loss_model_file', 'loss_model_ckpt/vgg_16.ckpt', '损失网络ckpt文件路径 ') tf.app.flags.DEFINE_integer('image_size', 256, '图像大小') #model的ckpt相关 tf.app.flags.DEFINE_string("model_path", "transfer_model_ckpt", "风格ckpt文件路径") tf.app.flags.DEFINE_string("model_name", "candy", "风格名") tf.app.flags.DEFINE_string("model_file", "models.ckpt", "风格ckpt文件名") #内容图片与风格图片 tf.app.flags.DEFINE_string("image_file", "srcImg/test.jpg", "输入模型的图片路径") tf.app.flags.DEFINE_string("res_file", "resImg", "模型输出的图片保存目录") tf.app.flags.DEFINE_string("res_image", "res.jpg", "模型输出的图片保存目录") tf.app.flags.DEFINE_string("style_image", "styleImg/candy.jpg", "风格图片的路径") #损失函数权重参数 tf.app.flags.DEFINE_float('content_weight', 1.0, '内容损失函数权重') tf.app.flags.DEFINE_float('style_weight', 100.0, '风格损失函数权重') tf.app.flags.DEFINE_float('tv_weight', 0.5, 'total variation损失函数权重') #训练数据相关参数 tf.app.flags.DEFINE_integer( 'batch_size', 128, 'batch大小') tf.app.flags.DEFINE_integer( 'epoch', 2, 'epoch个数') #layers tf.app.flags.DEFINE_list("content_layers", "vgg_16/conv3/conv3_3", "用于计算内容损失的layers") tf.app.flags.DEFINE_list("style_layers", ["vgg_16/conv1/conv1_2", "vgg_16/conv2/conv2_2" "vgg_16/conv3/conv3_3" "vgg_16/conv4/conv4_3"], "用于计算风格损失的layers") tf.app.flags.DEFINE_string("checkpoint_exclude_scopes", "vgg_16/fc", "不从ckpt中恢复权重的层") #learning_rate tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.') FLAGS = tf.app.flags.FLAGS height = 0 width = 0 def main(_): #指定image路径,读取图片获取宽和高 FLAGS.model_file=FLAGS.model_path+ FLAGS.model_file with open(FLAGS.image_file, 'rb') as img: with tf.Session().as_default() as sess: if FLAGS.image_file.lower().endswith('png'): image = sess.run(tf.image.decode_png(img.read())) else: image = sess.run(tf.image.decode_jpeg(img.read())) height = image.shape[0] width = image.shape[1] with tf.Graph().as_default(): with tf.Session().as_default() as sess: # 读入image数据,并进行预处理 image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing( FLAGS.loss_model, is_training=False) image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn) # 增加batch维度 image = tf.expand_dims(image, 0) #转换网络模型的输出,真正运行是在后面恢复权重以后 generated = model.net(image, training=False) generated = tf.cast(generated, tf.uint8)#转换数据格式 # 去除batch维度 generated = tf.squeeze(generated, [0]) saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1) sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) #从已训练的风格转换模型的ckpt文件中恢复权重 FLAGS.model_file = os.path.abspath(FLAGS.model_file) saver.restore(sess, FLAGS.model_file) generated_file = FLAGS.res_file+FLAGS.res_img if os.path.exists(FLAGS.res_file) is False: os.makedirs(FLAGS.res_file) # 生成图片 with open(generated_file, 'wb') as img: start_time = time.time() img.write(sess.run(tf.image.encode_jpeg(generated))) end_time = time.time() tf.logging.info('Elapsed time: %fs' % (end_time - start_time)) tf.logging.info('Done. Please check %s.' % generated_file) if __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) tf.app.run()