TensorFlow 集群训练示例代码

    #encoding:utf-8  
    # -*- coding: utf-8 -*-  
    #使用说明:1、修改分类数目;2、修改输入图片大小;  
    # 3、修改是否启用集群; 4、修改batch size大小;5、修改数据路径、模型保存路径  
    #6、设置是否启用 bootstrap loss 损失函数  
    import os  

    import tensorflow as tf  
    from input_data import Data_layer  
    import net  

    # ---- hyper-parameters / input geometry ----
    num_class = 2                          # number of output classes
    input_height, input_width = 256, 256   # handed to Data_layer as image_height / image_width
    crop_height, crop_width = 224, 224     # handed to Data_layer as crop_height / crop_width
    learning_rate = 0.01                   # NOTE(review): not referenced anywhere else in this file
    # Graph-level seed of the default graph; set before any op is created.
    tf.set_random_seed(123)

    # ---- scalar placeholders fed by the training loop ----
    batch_size = tf.placeholder(dtype=tf.int32, shape=[], name='batch_size')
    tf.add_to_collection('batch_size', batch_size)  # exported so it can be looked up by collection name
    is_training = tf.placeholder(dtype=tf.bool, shape=[])
    tf.add_to_collection('is_training', is_training)
    # Passed to net.resnet256; presumably switches the bootstrap loss on/off -- confirm in net.
    is_boostrap = tf.placeholder(dtype=tf.bool, shape=[])
    # NOTE(review): only the commented-out net.resnet call consumed this; resnet256 is not given it.
    drop_prob = tf.placeholder(dtype=tf.float32, shape=[])
    def load_save_model(sess, saver, model_path, is_save):
        """Save the session's variables to, or restore them from, ``model_path``.

        Args:
            sess: session whose variables are saved or restored.
            saver: object providing ``save(sess, path)`` and
                ``restore(sess, path)`` (e.g. a ``tf.train.Saver``).
            model_path: checkpoint path passed straight to the saver.
            is_save: truthy to save; falsy (False, 0, None, ...) to restore.
        """
        if is_save:
            saver.save(sess, model_path)
        else:
            # Single-argument print() produces the same output on Python 2 and 3.
            print("***********restore model from %s***************" % model_path)
            saver.restore(sess, model_path)
    def train_cluster(train_volume_data, valid_volume_data, model_path):
        """Run this process's task of a distributed (parameter-server) training job.

        The task is selected by command-line flags defined here:
        ``--ps_hosts`` / ``--worker_hosts`` (comma-separated ``host:port``
        lists), ``--job_name`` (``'ps'`` or ``'worker'``) and ``--task_index``.
        A ``ps`` task blocks in ``server.join()``. A ``worker`` task builds the
        input pipeline and ``net.resnet256`` under ``replica_device_setter`` and
        runs up to 400000 training iterations; worker 0 is the chief.

        Args:
            train_volume_data: training data, passed through to ``Data_layer``.
            valid_volume_data: validation data, passed through to ``Data_layer``.
            model_path: checkpoint path. If a checkpoint exists there, the chief
                restores it before training starts. The chief also saves to it
                at every 2000th iteration after iteration 3000.

        Raises:
            ValueError: if ``--job_name`` is neither ``'ps'`` nor ``'worker'``.
        """
        flags = tf.app.flags
        flags.DEFINE_string("ps_hosts", "localhost:2222", "Comma-separated list of hostname:port pairs")
        flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                            "Comma-separated list of hostname:port pairs")
        flags.DEFINE_string("job_name", "", "Either 'ps' or 'worker'")
        flags.DEFINE_integer("task_index", 0, "Index of task within the job")
        flags.DEFINE_string("volumes", "", "volumes info")
        FLAGS = flags.FLAGS

        if FLAGS.job_name not in ("ps", "worker"):
            raise ValueError("--job_name must be 'ps' or 'worker', got %r" % (FLAGS.job_name,))

        ps_hosts = FLAGS.ps_hosts.split(",")
        print("ps_hosts: %s" % (ps_hosts,))
        worker_hosts = FLAGS.worker_hosts.split(",")
        print("worker_hosts: %s" % (worker_hosts,))
        cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
        print("FLAGS.task_index: %s" % (FLAGS.task_index,))
        # Create and start a server for the local task.
        server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

        if FLAGS.job_name == "ps":
            server.join()
            return

        is_chief = FLAGS.task_index == 0
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
            input_data = Data_layer(train_volume_data, valid_volume_data, batch_size=batch_size,
                                    image_height=input_height, image_width=input_width,
                                    crop_height=crop_height, crop_width=crop_width)
            images, labels = input_data.get_next_batch(is_training, num_class)

            net_worker = net.resnet256(images, labels, num_class, is_training, is_boostrap)
            saver = tf.train.Saver()
            init = tf.global_variables_initializer()

            def restore_if_present(sess):
                # Called by the Supervisor on the chief only, right after init_op.
                # Non-chief workers therefore never overwrite the shared ps
                # variables with a stale checkpoint, and a missing checkpoint
                # (first run) no longer aborts training.
                if tf.train.checkpoint_exists(model_path):
                    load_save_model(sess, saver, model_path, False)
                else:
                    print("no checkpoint at %s, training from scratch" % (model_path,))

            sv = tf.train.Supervisor(is_chief=is_chief, init_op=init, init_fn=restore_if_present,
                                     saver=saver, global_step=net_worker['global_step'])

        with sv.prepare_or_wait_for_session(server.target) as session:
            # prepare_or_wait_for_session() already started the graph's queue
            # runners on sv.coord. Use that coordinator so request_stop()/stop()
            # actually reach those threads.
            coord = sv.coord
            try:
                for i in range(400000):
                    if coord.should_stop():
                        break
                    # The bootstrap loss is switched on from iteration 10000 onwards.
                    train_dict = {batch_size: 32,
                                  drop_prob: 1,
                                  is_training: True,
                                  is_boostrap: i >= 10000}
                    step, _ = session.run([net_worker['global_step'], net_worker['train_op']],
                                          feed_dict=train_dict)
                    if i % 500 == 0:
                        # Report loss/accuracy with training settings but bootstrap off.
                        monitor_dict = {batch_size: 32,
                                        drop_prob: 1,
                                        is_training: True,
                                        is_boostrap: False}
                        entropy, train_acc = session.run(
                            [net_worker['cross_entropy'], net_worker['accuracy']],
                            feed_dict=monitor_dict)
                        print('***** {}:{},{} *****'.format(i, entropy, train_acc))
                    if i % 2000 == 0:
                        test_dict = {drop_prob: 1.0,
                                     is_training: False,
                                     batch_size: 256,
                                     is_boostrap: False}
                        acc = session.run(net_worker['accuracy'], feed_dict=test_dict)
                        print('*****locate step {},valid step {}:accuracy {} *****'.format(i, step, acc))

                        # Only the chief writes the checkpoint; concurrent saves
                        # from every worker to the same path would race.
                        if i > 3000 and is_chief:
                            print("**************save model***************")
                            load_save_model(session, saver, model_path, True)
            except Exception as e:
                coord.request_stop(e)
            finally:
                # Requests stop, joins the Supervisor's threads and re-raises the
                # first exception reported through request_stop().
                sv.stop()

猜你喜欢

转载自 https://blog.csdn.net/daydayup_668819/article/details/80004224