tf.cast() 用法

tf.cast(x, DstT, name)

参数: x --> 张量Tensor

             DstT  --> tf.DType 要转换的目标类型

             name  --> 运算名称(可选)

一般用于真实值和预测值比较后的布尔型转换为浮点型进行后续计算。

import tensorflow as tf
import numpy as np
 
y_pre = [0.9, 1.2, 0.75, 0.5, 0.8]
y = [0.8, 1.2, 0.75, 0.9, 0.8]
equal =  tf.equal(y_pre, y)
cast = tf.cast(equal, 'float')
cast1 = tf.cast(equal, dtype = float)
cast2 = tf.cast(equal, dtype = tf.float32)
with tf.Session() as sess:
    print(sess.run(equal))
    print(sess.run(cast))
    print(sess.run(cast1))
    print(sess.run(cast2))
[False  True  True False  True]
[ 0.  1.  1.  0.  1.]
[ 0.  1.  1.  0.  1.]
[ 0.  1.  1.  0.  1.]

猜你喜欢

转载自blog.csdn.net/Muzi_Water/article/details/81363372