tf.cast:将张量转换为新类型
tf.cast(
x, dtype, name=None
)
该操作将x
(如果是Tensor
)或x.values
(如果是SparseTensor
或IndexedSlices
)强制转换为dtype
。
例子:
import tensorflow as tf
with tf.Session() as sess:
x = tf.constant([1.8, 2.2], dtype=tf.float32)
print(x)
b = tf.dtypes.cast(x, tf.int32)
print(b)
# 输出结果:
# Tensor("Const:0", shape=(2,), dtype=float32)
# Tensor("Cast:0", shape=(2,), dtype=int32)
操作支持的数据类型(x
和dtype
的)为: uint8
,uint16
,uint32
,uint64
,int8
,int16
,int32
,int64
, float16
,float32
,float64
,complex64
,complex128
,bfloat16
。如果从复杂类型(complex64
,complex128
)转换为实类型,则仅返回x的实部。如果将实类型转换为复杂类型(complex64
,complex128
),则将返回值的虚部设置为0
。这里对复杂类型的处理与numpy的行为相匹配。
Args |
|
---|---|
x |
一个Tensor 或者SparseTensor 或IndexedSlices 数字型的。这可能是uint8 ,uint16 ,uint32 ,uint64 ,int8 ,int16 ,int32 , int64 ,float16 ,float32 ,float64 ,complex64 ,complex128 , bfloat16 。 |
dtype |
目标类型。支持的dtypes列表与x 相同。 |
name |
操作的名称(可选)。 |
Returns |
---|
一个Tensor 或SparseTensor 或IndexedSlices 具有与x相同的形状和dtype 相同的类型。 |
Raises |
|
---|---|
TypeError |
如果x 无法转换为dtype 。 |