tf.shape()和tensor.get_shape()

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Laox1ao/article/details/79896656

问题

数据输入的格式为

input = tf.placeholder([None,xxx,xxx],dtype=tf.float32)

需要得到batch的维度来进行中间Variable的初始化

val = tf.zeros([batch_size,xxx,xxx],dtype=tf.float32)

方法

可行:

val = tf.zeros([tf.shape(input)[0],xxx,xxx],dtype=tf.float32)

失败:

val = tf.zeros([input.get_shape()[0],xxx,xxx],dtype=tf.float32)
val = tf.zeros([input.shape()[0],xxx,xxx],dtype=tf.float32)

错误提示:

ValueError: Cannot convert a partially known TensorShape to a Tensor

原因:
可能是由于get_shape()返回的是元组,tf.shape()返回的是tensor,所以tf.shape()返回的为None的batch_size维度可以继续作为其他tensor的维度,而get_shape()由于返回的是元组,取到的batch_size维度直接是None了,无法作为其他tensor的维度。

猜你喜欢

转载自blog.csdn.net/Laox1ao/article/details/79896656