tensorflow中tf.one_hot()函数的作用是将一个值化为one-hot编码的向量,指的是在分类问题中,将存在数据类别的那一类用X表示,不存在的用Y表示,这里的X常常是1, Y常常是0。
one-hot 的API如下:
one_hot( indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None )
需要指定
indices
,和
depth
,其中
depth
是编码深度。
例子:
var = tf.one_hot(indices=[
1
,
2
,
3
], depth=
4
, axis=
0
)
with
tf.Session()
as
sess:
sess.run(tf.global_variables_initializer())
a = sess.run(var)
print(a)
输出:
[[
0.
1.
0.
0.
]
[
0.
0.
1.
0.
]
[
0.
0.
0.
1.
]]