这周我主要负责封装风格迁移网络的对外接口。经过一周的训练，我们已经得到了七种风格的网络模型，分别保存在对应的 ckpt 文件中。接口的处理流程是：首先判断用户选择的风格，然后用 TensorFlow 加载该风格对应的网络模型，对用户传入的图片进行风格迁移并保存结果，最后将生成的图片返回给用户。
主要代码如下：
# Checkpoint file name for each style id accepted in argv[1].
STYLE_CHECKPOINTS = {
    '1': 'shuimo.ckpt-done',
    '2': 'cubist.ckpt-6000',
    '3': 'denoised_starry.ckpt-done',
    '4': 'feathers.ckpt-done',
    '5': 'mosaic.ckpt-done',
    '6': 'scream.ckpt-done',
    '7': 'udnie.ckpt-done',
    '8': 'wave.ckpt-done',
    '9': 'jianzhi.ckpt-4000',
}

# Directory holding the trained checkpoints (machine-specific default;
# override with main(..., model_path=...)).
DEFAULT_MODEL_PATH = "E:\\programming\\PYworkbench\\Style-Transformer-Website\\trained_models\\"

# Style whose output gets an extra sky-segmentation post-processing pass,
# and the script (relative to the current working directory) that does it.
SKY_POSTPROCESS_STYLE = '1'
SKY_POSTPROCESS_SCRIPT = 'Sky_segment_postProcessing/sky_postprocessing.py'


def main(argv, model_path=None):
    """Apply a trained style-transfer model to one image and save the result.

    Args:
        argv: Command-line style list: ``argv[1]`` is the style id (a key of
            ``STYLE_CHECKPOINTS``), ``argv[2]`` the input image path and
            ``argv[3]`` the output path. The output is always JPEG-encoded.
        model_path: Directory containing the checkpoint files. Defaults to
            ``DEFAULT_MODEL_PATH``.

    Raises:
        ValueError: If ``argv[1]`` is not a known style id. This is checked
            before any TensorFlow work is done.

    NOTE(review): TensorFlow flags are registered on every call. Re-registering
    a flag name in the same process is presumably rejected by the flags
    library, so this looks single-use per process. Confirm before calling it
    repeatedly from a long-running server.
    """
    import subprocess
    import sys

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    style = argv[1]
    raw_img = argv[2]
    gen_img = argv[3]
    print(style, raw_img, gen_img)

    # Fail fast. Otherwise an unknown style leaves the "model_file" flag
    # undefined, and the run only crashes after the input has been decoded.
    if style not in STYLE_CHECKPOINTS:
        raise ValueError('Unknown style %r; expected one of %s'
                         % (style, ', '.join(sorted(STYLE_CHECKPOINTS))))
    if model_path is None:
        model_path = DEFAULT_MODEL_PATH

    tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'You can view all the support models in nets/nets_factory.py')
    tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.')
    tf.app.flags.DEFINE_string("model_file", os.path.join(model_path, STYLE_CHECKPOINTS[style]), "")
    tf.app.flags.DEFINE_string("image_file", raw_img, "")
    FLAGS = tf.app.flags.FLAGS

    # Decode the input once, only to learn its height and width.
    # PNG is detected by the filename suffix; everything else is treated as JPEG.
    with open(FLAGS.image_file, 'rb') as src:
        with tf.Session().as_default() as sess:
            if FLAGS.image_file.lower().endswith('png'):
                image = sess.run(tf.image.decode_png(src.read()))
            else:
                image = sess.run(tf.image.decode_jpeg(src.read()))
            height = image.shape[0]
            width = image.shape[1]
    tf.logging.info('Image size: %dx%d' % (width, height))

    with tf.Graph().as_default():
        with tf.Session().as_default() as sess:
            # Read image data.
            image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)
            # Add batch dimension.
            image = tf.expand_dims(image, 0)
            generated = model.net(image, training=False)
            generated = tf.cast(generated, tf.uint8)
            # Remove batch dimension.
            generated = tf.squeeze(generated, [0])

            # Restore model variables.
            saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # Use absolute path.
            FLAGS.model_file = os.path.abspath(FLAGS.model_file)
            saver.restore(sess, FLAGS.model_file)

            # Make sure the directory the output is written to actually exists.
            generated_file = gen_img
            os.makedirs(os.path.dirname(os.path.abspath(generated_file)), exist_ok=True)

            # Generate and write image data to file.
            start_time = time.time()
            with open(generated_file, 'wb') as dst:
                dst.write(sess.run(tf.image.encode_jpeg(generated)))
            # Run the post-processing only after the `with` has closed (and
            # flushed) the output, because the script reads that same file.
            # An argument list with no shell keeps user-supplied paths from
            # being interpreted as shell syntax.
            if style == SKY_POSTPROCESS_STYLE:
                cmd = [sys.executable, SKY_POSTPROCESS_SCRIPT, raw_img, generated_file]
                ret = subprocess.run(cmd).returncode
                if ret != 0:
                    # Best-effort, as before, but no longer silently ignored.
                    print('Warning: sky post-processing exited with code %d' % ret)
            end_time = time.time()
            print('Elapsed time: %fs' % (end_time - start_time))
            print('Done. Please check %s.' % generated_file)