TensorFlow中的图(自定义图)

上一节(TensorFlow2.0之计算过程与名字来历)中我们反复提到了图,但是代码中并没有看到图的定义,也没看到任何跟图有关的代码。其实,在TensorFlow中,TensorFlow会定义默认图。用户可以自己显式定义图,并将自定义图作为默认图。

TensorFlow的图中包含tf.Operation对象集合,一个tf.Operation对象表示一个计算单元,如加、减、乘、除等计算都是tf.Operation对象。TensorFlow中的图还包含tf.Tensor对象,tf.Tensor对象表示参与运算的数据,这些数据在路径的各个tf.Operation节点中参与运算,并在图中的各个路径中传递。

代码:

import tensorflow as tf
# 定义Python中int类型二维矩阵
A = [[1, 2, 3],
     [4, 5, 6]]
B = [[1, 1], 
     [1, 1],
     [1, 1]]
my_graph = tf.compat.v1.Graph()
with my_graph.as_default():
    # 将Python类型数据A和B传入图中
    A_tf = tf.compat.v1.constant(A, dtype=tf.float32, name="A")
    B_tf = tf.compat.v1.constant(B, dtype=tf.float32, name="B")
    #  构建图中的计算节点
    C_tf = tf.compat.v1.matmul(A_tf, B_tf)
print("C_tf is my_graph:", C_tf.graph is my_graph)
# 图构建完毕
with tf.compat.v1.Session(graph=my_graph) as sess:
    C = sess.run(C_tf)
    print(C)

输出:

C_tf is my_graph: True
[[ 6.  6.]
 [15. 15.]]

上述代码中,第8行自定义了图 ,并在第9-14行往自定义的图中加入了数据节点和计算节点:第15行打印验证加入自定义图中的结点是否正确;第17行将自定义的图作为tf.Session的默认图。从输出结果可以看到,每个数据对象的计算节点对象可以在指定图中存放。有一点需要注意的是,计算节点的输出数据对象会被放置到输入数据对象所在的图中。下面通过一个例子说明,代码如下:

import tensorflow as tf
# 定义Python中int类型二维矩阵
A = [[1, 2, 3],
     [4, 5, 6]]
B = [[1, 1],
     [1, 1],
     [1, 1]]
my_graph1 = tf.compat.v1.Graph()
my_graph2 = tf.compat.v1.Graph()
with my_graph1.as_default():
    # 将Python类型数据A传入图中
    A_tf = tf.compat.v1.constant(A, dtype=tf.float32, name="A")
    # 将Python类型数据B传入图中
    B_tf = tf.compat.v1.constant(B, dtype=tf.float32, name="B")
with my_graph2.as_default():
    # 试图将C_tf放入图my_graph2中
    C_tf = tf.compat.v1.matmul(A_tf, B_tf)

print("C_tf.graph is my_graph1:", C_tf.graph is my_graph1)
print("C_tf.graph is my_graph2:", C_tf.graph is my_graph2)

输出:

C_tf.graph is my_graph1: True
C_tf.graph is my_graph2: False

可以看到 ,即使指定将C_tf放在图my_graph2中,还是无法改变C_tf实际存放在my_graph1中的事实。此外,矩阵相乘计算节点也不会在my_graph2中,而会在my_graph1中。接下来我们自定义不同的图,看看不同的图中的数据和计算节点之间交叉引用会怎么样。

代码:

import tensorflow as tf
# 定义Python中int类型二维矩阵
A = [[1, 2, 3],
     [4, 5, 6]]
B = [[1, 1],
     [1, 1],
     [1, 1]]
my_graph1 = tf.compat.v1.Graph()
my_graph2 = tf.compat.v1.Graph()
my_graph3 = tf.compat.v1.Graph()
with my_graph1.as_default():
    # 将Python类型数据A传入图中:
    A_tf = tf.compat.v1.constant(A, dtype=tf.float32, name="A")
with my_graph2.as_default():
    # 将Python类型数据B传入图中:
    B_tf = tf.compat.v1.constant(B, dtype=tf.float32, name="B")
with my_graph3.as_default():
    # 构建图中的计算节点:
    C_tf = tf.matmul(A_tf, B_tf)

# 图构建完毕
with tf.compat.v1.Session(graph=my_graph3) as sess:
    C = sess.run(C_tf)
    print(C)

此时会报错,输出报错结果如下:

Traceback (most recent call last):
  File "E:/Pycharm专业版/Workspace/Data_Science/gensim_operation/word2vec_test/tensorflow_test/preparation_work/图中数据与计算节点交叉引用.py", line 19, in <module>
    C_tf = tf.matmul(A_tf, B_tf)
  File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 2687, in matmul
    with ops.name_scope(name, "MatMul", [a, b]) as name:
  File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\framework\ops.py", line 6337, in __enter__
    g_from_inputs = _get_graph_from_inputs(self._values)
  File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\framework\ops.py", line 5982, in _get_graph_from_inputs
    _assert_same_graph(original_graph_element, graph_element)
  File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\framework\ops.py", line 5917, in _assert_same_graph
    (item, original_item))
ValueError: Tensor("B:0", shape=(3, 2), dtype=float32) must be from the same graph as Tensor("A:0", shape=(2, 3), dtype=float32).

从上述报错结果中可以看到,不同图中的数据和计算节点相互引用计算,会出现错误。ValueError提示很明显,即在第19行中计算矩阵运算时,名为“A:0”的数据对象(即Tensor对象)与名为“B:0”的数据对象在不同图中。在构建图时,各个数据对象和计算节点对象必须在当前图中,不同图之间的资源是不能交叉引用的。

注意:tf.Graph()构造函数是非线程安全的函数,在创建图时需要在单线程或外部保证线程安全。

发布了105 篇原创文章 · 获赞 17 · 访问量 11万+

猜你喜欢

转载自blog.csdn.net/qq_38890412/article/details/104058919