深度学习(七十二)tensorflow 集群训练

版权声明:本文为博主原创文章,欢迎转载,转载请注明原文地址、作者信息。 https://blog.csdn.net/hjimce/article/details/79650210
#encoding:utf-8
# -*- coding: utf-8 -*-
#使用说明:1、修改分类数目;2、修改输入图片大小;
# 3、修改是否启用集群;4、修改batch size大小;5、修改数据路径、模型保存路径;
# 6、设置是否启用bootstrap loss损失函数
import os

import tensorflow as tf
from input_data import Data_layer
import net

# Number of output classes; passed to Data_layer.get_next_batch and net.resnet256.
num_class = 2
# Image size given to Data_layer as image_height / image_width.
input_height = 256
input_width = 256
# Crop size given to Data_layer as crop_height / crop_width
# (presumably the network input size after cropping -- confirm in input_data).
crop_height = 224
crop_width = 224
# NOTE(review): not referenced in the code visible here; confirm whether
# anything (e.g. net) actually reads it before relying on this value.
learning_rate = 0.01
# Graph-level random seed; set before any of the ops below are created.
tf.set_random_seed(123)
# Scalar batch size fed at run time (train_cluster feeds 32 for training
# steps and 256 for validation). Also stored in the 'batch_size' collection,
# presumably so other modules can fetch it -- verify against callers.
batch_size = tf.placeholder(tf.int32, [], 'batch_size')
tf.add_to_collection('batch_size', batch_size)
# Scalar bool: fed True for training/monitoring runs and False for validation.
is_training = tf.placeholder(tf.bool, [])
# Scalar bool passed to net.resnet256; per the usage notes it toggles the
# bootstrap loss. train_cluster feeds False for the first 10000 iterations,
# then True. (The "boostrap" spelling is part of the name used everywhere.)
is_boostrap = tf.placeholder(tf.bool, [])
# Scalar float: train_cluster always feeds 1 and does not pass it to
# net.resnet256. NOTE(review): the name says "drop" probability but a fed
# value of 1 suggests keep-probability semantics -- confirm against net.
drop_prob = tf.placeholder(tf.float32, [])
# Expose is_training through the 'is_training' collection as well.
tf.add_to_collection('is_training', is_training)
def load_save_model(sess, saver, model_path, is_save):
    """Restore a checkpoint into ``sess``, or save ``sess`` to a checkpoint.

    Args:
        sess: the session whose variables are restored or saved.
        saver: object exposing ``restore(sess, path)`` and ``save(sess, path)``
            (a ``tf.train.Saver`` in this file).
        model_path: checkpoint path handed unchanged to the saver.
        is_save: truthy to save ``sess`` to ``model_path``; falsy to restore
            ``sess`` from ``model_path``.

    Returns:
        None.
    """
    # Truth-test instead of ``is_save is False``: the identity check sent any
    # falsy non-bool (None, 0) down the *save* branch, silently overwriting
    # the checkpoint the caller meant to load.
    if not is_save:
        print("***********restore model from %s***************" % model_path)
        saver.restore(sess, model_path)
    else:
        saver.save(sess, model_path)
def train_cluster(train_volume_data, valid_volume_data, model_path):
    """Run one task of a distributed (ps/worker) TensorFlow training cluster.

    The role of this process comes from command-line flags:
    ``--ps_hosts`` / ``--worker_hosts`` list the cluster members,
    ``--job_name`` is ``ps`` or ``worker``, and ``--task_index`` picks this
    process's slot within its job. Worker 0 is the chief.

    A ``ps`` task serves variables forever via ``server.join()``. A
    ``worker`` task builds the input pipeline and network under
    ``tf.train.replica_device_setter``, then runs 400000 training iterations:
    the bootstrap loss is enabled from iteration 10000 on, training loss and
    accuracy are printed every 500 iterations, and validation accuracy every
    2000. The chief restores ``model_path`` once while the session is being
    initialized, and saves to it at each validation point after iteration 3000.

    Args:
        train_volume_data: training data spec, passed through to ``Data_layer``.
        valid_volume_data: validation data spec, passed through to ``Data_layer``.
        model_path: checkpoint path. The restore is unconditional, so this
            presumably must already hold a checkpoint -- confirm before use.
    """
    tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "Comma-separated list of hostname:port pairs")
    tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                               "Comma-separated list of hostname:port pairs")
    tf.app.flags.DEFINE_string("job_name", "", "Either 'ps' or 'worker'")
    tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
    tf.app.flags.DEFINE_string("volumes", "", "volumes info")
    FLAGS = tf.app.flags.FLAGS

    ps_hosts = FLAGS.ps_hosts.split(",")
    print("ps_hosts:", ps_hosts)
    worker_hosts = FLAGS.worker_hosts.split(",")
    print("worker_hosts:", worker_hosts)
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    print("FLAGS.task_index:", FLAGS.task_index)
    # Create and start a server for the local task.
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
        return
    if FLAGS.job_name != "worker":
        # Any other job name: nothing to run in this process.
        return

    is_chief = FLAGS.task_index == 0

    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
        input_data = Data_layer(train_volume_data, valid_volume_data, batch_size=batch_size,
                                image_height=input_height, image_width=input_width,
                                crop_height=crop_height, crop_width=crop_width)
        images, labels = input_data.get_next_batch(is_training, num_class)

        # drop_prob is not passed to net.resnet256; it is still fed below.
        net_worker = net.resnet256(images, labels, num_class, is_training, is_boostrap)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        # Restore through init_fn so that only the chief loads the checkpoint,
        # right after init_op, while preparing the session; other workers just
        # wait for the model to become ready. Previously every worker restored
        # after joining, so a late worker overwrote the shared ps variables and
        # rolled back training already done by the others.
        sv = tf.train.Supervisor(
            is_chief=is_chief, init_op=init, saver=saver,
            global_step=net_worker['global_step'],
            init_fn=lambda sess: load_save_model(sess, saver, model_path, False))

    with sv.prepare_or_wait_for_session(server.target) as session:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(session, coord=coord)

        try:
            for i in range(400000):
                # Plain loss for the first 10000 iterations, bootstrap loss after.
                train_dict = {batch_size: 32,
                              drop_prob: 1,
                              is_training: True,
                              is_boostrap: i >= 10000}
                step, _ = session.run([net_worker['global_step'], net_worker['train_op']],
                                      feed_dict=train_dict)

                if i % 500 == 0:
                    # Monitor loss/accuracy in training mode with the bootstrap loss off.
                    monitor_dict = {batch_size: 32,
                                    drop_prob: 1,
                                    is_training: True,
                                    is_boostrap: False}
                    entropy, train_acc = session.run(
                        [net_worker['cross_entropy'], net_worker['accuracy']],
                        feed_dict=monitor_dict)
                    print('***** {}:{},{} *****'.format(i, entropy, train_acc))

                if i % 2000 == 0:
                    test_dict = {drop_prob: 1.0,
                                 is_training: False,
                                 batch_size: 256,
                                 is_boostrap: False}
                    acc = session.run(net_worker['accuracy'], feed_dict=test_dict)
                    print('*****locate step {},valid step {}:accuracy {} *****'.format(i, step, acc))

                    # Only the chief writes checkpoints; every worker saving to
                    # the same model_path made them race on the same files.
                    if i > 3000 and is_chief:
                        print("**************save model***************")
                        load_save_model(session, saver, model_path, True)
        except Exception as e:
            # Record the failure; coord.join() below re-raises it (clean-stop
            # types such as OutOfRangeError are treated as a normal stop).
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)



猜你喜欢

转载自blog.csdn.net/hjimce/article/details/79650210
今日推荐