tensorflow学习(2):计算图,tf.get_default_graph(),tf.Graph()

一、基本概念

顾名思义,TensorFlow的名字已经出卖了它的“灵魂”,TensorFlow=tensor(张量)+flow(流动)。TensorFlow是一个通过计算图的形式来表达计算的编程框架。其每一个计算都是计算图上的一个节点,而节点之间的边描述了计算之间的依赖关系。
计算图是TensorFlow中最基本的一个概念,TensorFlow中的所有计算都会被转化为计算图上的节点。
在这里插入图片描述
上图是两个张量进行某种计算的计算图,张量都用节点表示,边代表了计算之间的依赖关系,这里的operation其实也应该是个圆圈。
假设以上是加的关系,那么有代码如下:

import tensorflow as tf
a = tf.constant([1,2])  #定义常量
b = tf.constant([3,4])
result = a + b            #定义关系
with tf.Session() as sess:
	sess.run(result)    #输出[4,6]

二、tf.get_default_graph()

功能:这个函数可以获取当前默认的计算图
例如,在上述代码的with语句块中加入如下语句

#通过a.graph可以查看张量所属的计算图
print(a.graph is tf.get_default_graph()) #输出True

三、tf.Graph()

除了使用默认计算图,TensorFlow支持通过tf.Graph()来生成新的计算图。不同计算图上的张量和运算都不会共享

import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
	#定义变量v,并设置初始值为0
	v = tf.get_variable("v", shape = [1], initializer = tf.zeros_initializer)
g2 = tf.Graph()
with g2.as_default():
	#定义变量v,并设置初始值为1
	v = tf.get_variable("v", shape = [1], initializer = tf.ones_initializer)
#在计算图g1中读取变量v的值
with tf.Session(graph = g1) as sess:
	tf.global_variables_initializer().run()
	with tf.variable_scope("",reuse = True):
		print(sess.run(tf.get_variable("v")))

#在计算图g2中读取变量v的值
with tf.Session(graph = g2) as sess:
	tf.global_variables_initializer().run()
	with tf.variable_scope("",reuse = True):
		print(sess.run(tf.get_variable("v")))

猜你喜欢

转载自blog.csdn.net/shanlepu6038/article/details/84475267