版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Laox1ao/article/details/79896656
问题
数据输入的格式为
input = tf.placeholder([None,xxx,xxx],dtype=tf.float32)
需要得到batch的维度来进行中间Variable的初始化
val = tf.zeros([batch_size,xxx,xxx],dtype=tf.float32)
方法
可行:
val = tf.zeros([tf.shape(input)[0],xxx,xxx],dtype=tf.float32)
失败:
val = tf.zeros([input.get_shape()[0],xxx,xxx],dtype=tf.float32)
val = tf.zeros([input.shape()[0],xxx,xxx],dtype=tf.float32)
错误提示:
ValueError: Cannot convert a partially known TensorShape to a Tensor
原因:
可能是由于get_shape()返回的是元组,tf.shape()返回的是tensor,所以tf.shape()返回的为None的batch_size维度可以继续作为其他tensor的维度,而get_shape()由于返回的是元组,取到的batch_size维度直接是None了,无法作为其他tensor的维度。