tf.cast()函数解析

版权声明:本文版权归作者和CSDN共有,欢迎转载。转载时请注明原作者并保留此段声明,若不保留我也不咬你,随你了=-=。 https://blog.csdn.net/TeFuirnever/article/details/88934687

tf.cast()函数用于执行 tensorflow 中张量数据类型转换。

tf.cast(
	x, 
	dtype, 
	name=None
)

参数:

  • x:待转换的数据(张量)
  • dtype:目标数据类型
  • name:可选参数,定义操作的名称

将x的数据格式转化成dtype。

例如,原来x的数据格式是bool,那么将其转化成float以后,就能够将其转化成0和1的序列,反之也可以。

tensorflow中的数据类型列表如下:

Python 类型 描述
tf.float32 32 位浮点数
tf.float64 64 位浮点数
tf.int64 64 位有符号整型
tf.int32 32 位有符号整型
tf.uint8 8 位无符号整型.
tf.string 可变长度的字节数组.每一个张量元素都是一个字节数组
tf.bool 布尔型

例子:

import tensorflow as tf
 
x = tf.Variable([1,2,3,4,5])
y = tf.cast(x,dtype=tf.float32)
 
print('x: {}'.format(x))
print('y: {}'.format(y))
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(y)
    print(sess.run(y))
> x: <tf.Variable 'Variable:0' shape=(5,) dtype=int32_ref>
  y: Tensor("Cast:0", shape=(5,), dtype=float32)
  [ 1.  2.  3.  4.  5.]

参考文章:

tf.cast()数据类型转换

猜你喜欢

转载自blog.csdn.net/TeFuirnever/article/details/88934687