tf.cast(x, DstT, name)
参数: x --> 张量Tensor
DstT --> tf.DType 要转换的目标类型
name --> 运算名称(可选)
一般用于真实值和预测值比较后的布尔型转换为浮点型进行后续计算。
import tensorflow as tf
import numpy as np
y_pre = [0.9, 1.2, 0.75, 0.5, 0.8]
y = [0.8, 1.2, 0.75, 0.9, 0.8]
equal = tf.equal(y_pre, y)
cast = tf.cast(equal, 'float')
cast1 = tf.cast(equal, dtype = float)
cast2 = tf.cast(equal, dtype = tf.float32)
with tf.Session() as sess:
print(sess.run(equal))
print(sess.run(cast))
print(sess.run(cast1))
print(sess.run(cast2))
[False True True False True]
[ 0. 1. 1. 0. 1.]
[ 0. 1. 1. 0. 1.]
[ 0. 1. 1. 0. 1.]