tensorflow 笔记6--迁移学习

tensorflow 笔记6–迁移学习


参考文档https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb


一、冻结部分层权重

法一:

with tf.name_scope("train"):                                        
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    # 指定要训练的那部分层
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="hidden[34]|outputs")
    training_op = optimizer.minimize(loss, var_list=train_vars)


# 恢复冻结层的数据,其实也可以全部恢复
reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden[12]") 
restore_saver = tf.train.Saver(reuse_vars) 

with tf.Session() as sess:
    restore_saver.restore(sess, "./my_model_final.ckpt")

法二:

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1") # reused frozen
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") # reused frozen
    # 在此之前的层不会进行梯度更新
    hidden2_stop = tf.stop_gradient(hidden2)
    # 注意以下的层要相应的修改为hidden2_stop
    hidden3 = tf.layers.dense(hidden2_stop, n_hidden3, activation=tf.nn.relu, name="hidden3") # reused, not frozen
    hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") # new!
    logits = tf.layers.dense(hidden4, n_outputs, name="outputs") # new!

# 剩下的和正常的一样

二、缓存冻结层结果

# 先设置冻结层,再进行以下操作

with tf.Session() as sess:
    init.run()
    restore_saver.restore(sess, "./my_model_final.ckpt")

    # 缓存冻结层的结果,即训练期间只计算一次
    h2_cache = sess.run(hidden2, feed_dict={X: X_train})
    h2_cache_valid = sess.run(hidden2, feed_dict={X: X_valid}) 

    for epoch in range(n_epochs):
        shuffled_idx = np.random.permutation(len(X_train))
        # feed的数据应该相应的改为冻结层的结果
        hidden2_batches = np.array_split(h2_cache[shuffled_idx], n_batches)
        y_batches = np.array_split(y_train[shuffled_idx], n_batches)
        for hidden2_batch, y_batch in zip(hidden2_batches, y_batches):
            sess.run(training_op, feed_dict={hidden2:hidden2_batch, y:y_batch})

        accuracy_val = accuracy.eval(feed_dict={hidden2: h2_cache_valid,  y: y_valid})      
        print(epoch, "Validation accuracy:", accuracy_val)               

    save_path = saver.save(sess, "./my_new_model_final.ckpt")

猜你喜欢

转载自blog.csdn.net/Wang_Jiankun/article/details/81135864