tensorflow—tf.one_hot()函数

tensorflow中tf.one_hot()函数的作用是将一个值化为one-hot编码的向量,指的是在分类问题中,将存在数据类别的那一类用X表示,不存在的用Y表示,这里的X常常是1, Y常常是0。
one-hot 的API如下:
one_hot( indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None )

需要指定 indices ,和 depth ,其中 depth 是编码深度。
例子:
var = tf.one_hot(indices=[ 1 , 2 , 3 ], depth= 4 , axis= 0 )  
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a = sess.run(var)
    print(a)
输出:
[[ 0. 1. 0. 0. ]
[ 0. 0. 1. 0. ]
[ 0. 0. 0. 1. ]]



猜你喜欢

转载自blog.csdn.net/qq_27150893/article/details/80889859