tensorflow中的条件语句和循环语句

tensorflow中,不可以直接拿tensor比较的结果作为 if 语句的条件,因此tensorflow中实现了自己的条件语句:

a = tf.get_variable("a",initializer=1)
b = tf.get_variable("b",initializer=2)

pred = tf.equal(a,b)

## 下面这种写法是正确的
def fun1():
	return a
def fun2():
	return b
c = tf.cond(pred, fun1, fun2)

## 下面这种写法是错误的
# if tf.equal(a,b):
# 	c = a
# else:
# 	c = b

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(c))

同理,tensorflow中,while函数也是需要有条件判断语句的,所以tensorflow实现了自己的while循环:

a = tf.get_variable("a",initializer=1)
b = tf.get_variable("b",initializer=5)

def cond(a, b):
	# 输入为loop_vars, 输出为布尔值
	return tf.less(a, b)

def body(a, b):
	# 输入为loop_vars, 输出为lop_vars
	a = a + 2
	b = b + 1
	return a, b
a, b = tf.while_loop(cond, body, loop_vars=[a, b])
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run([a,b])) # 9,9

下面我们借助条件语句和循环语句实现一个在numpy中非常容易实现的功能:

给定两个placeholder,input_a和input_b,两者的维度都是[None,5],而实际输入是维度分别为[seq_len,5]和[seq_len+1,5],得到一个output_c,形状为[seq_len*2+1, 5],其奇数位来自于input_b,偶数位来自于input_a。即,对于output_c而言,[2*i,5]的位置来自于input_a[i,5],[2*i+1,5]的位置来自于input_b[i,5]。

题目如下:

input_a = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len, 5]
input_b = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len+1, 5]

"""
请在此处键入你的代码
"""

with tf.Session() as sess:
	
    feed_a = [[1,2,3,4,5],[2,3,4,5,6]] #shape=[3,5]
    feed_b = [[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]] #shape=[4,5]

    sess.run(tf.global_variables_initializer())
	output_c_value = sess.run(output_c, feed_dict={input_a:feed_a, input_b:feed_b})

	print(output_c_value)
	#期望得到的值是[[0,0,0,0,0],[1,2,3,4,5],[0,0,0,0,0],[2,3,4,5,6],[0,0,0,0,0]]

在实现的过程中,有几个关键点:

1. 因为placeholder形状的第一维是None,所以取出来的seq_len是一个Tensor

2. 因为seq_len是一个Tensor,所以无法用seq_len来实现python原生的条件和循环语句

3. Tensor不支持对特定位进行赋值,所以必须建一个新的Tensor,然后把原有的Tensor中的特定位赋给新的Tensor

最终实现的代码如下:

import tensorflow as tf

input_a = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len, 5]
input_b = tf.placeholder(dtype=tf.int64, shape=[None, 5]) #实际输入时为[seq_len+1, 5]

seq_len = tf.shape(input_a)[0]

i = tf.get_variable("i", initializer=0)
output_c = tf.expand_dims(input_b[0,:], axis=0)

def loop_cond(input_a, input_b, output_c, i, seq_len):
	return tf.less(i, 2*seq_len)

def loop_body(input_a, input_b, output_c, i, seq_len):

	def concat_a():
		return tf.concat([output_c, tf.expand_dims(input_a[i//2,:], axis=0)], axis=0)
	def concat_b():
		return tf.concat([output_c, tf.expand_dims(input_b[i//2,:], axis=0)], axis=0)
	pred = tf.equal(i%2, 0)
	output_c = tf.cond(pred, concat_a, concat_b)
	i += 1
	return input_a, input_b, output_c, i, seq_len

_, _, output_c, _, _ = tf.while_loop(cond=loop_cond, body=loop_body, 
                                     loop_vars=[input_a, input_b, output_c, i, seq_len],                                                                          
                                     shape_invariants=[input_a.get_shape(), 
                                                       input_b.get_shape(),      
                                                       tf.TensorShape([None,5]), 
                                                       i.get_shape(), 
                                    	               seq_len.get_shape()])

with tf.Session() as sess:
	
	feed_a = np.array([[1,2,3,4,5],[2,3,4,5,6]]) #shape=[2,5]
	feed_b = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]]) #shape=[3,5]

	sess.run(tf.global_variables_initializer())
	output_c_value = sess.run(output_c, feed_dict={input_a:feed_a, input_b:feed_b})

	print(output_c_value)
	#得到的值是[[0,0,0,0,0],[1,2,3,4,5],[0,0,0,0,0],[2,3,4,5,6],[0,0,0,0,0]]

有趣之处在于,tensorflow的while_loop是先循环再判断,也就是说,循环终止条件需要提前一步。

猜你喜欢

转载自blog.csdn.net/bonjourdeutsch/article/details/102684930