[Python] Implementing one-hot encoding

TensorFlow's built-in function tf.one_hot():

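import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

# 10 random integer labels in [0, 10)
z = np.random.randint(0, 10, size=[10])
# depth=10; axis=0 puts the class dimension first, so each column encodes one label
y = tf.one_hot(z, 10, on_value=1, off_value=None, axis=0)

with tf.Session() as sess:
    print(z)
    print(sess.run(y))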

[5 7 7 0 5 5 2 0 0 0]
[[0 0 0 1 0 0 0 1 1 1]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [1 0 0 0 1 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 1 1 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]]
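Because axis=0 was passed, the matrix above has one column per label rather than one row per label. The following is a minimal NumPy sketch of that layout, reusing the label values printed above. The names rows_layout and columns_layout are illustrative only, and the code mimics the shape of the result, not how TensorFlow computes it.

```python
import numpy as np

# The labels printed in the run above
z = np.array([5, 7, 7, 0, 5, 5, 2, 0, 0, 0])

# One row per label: row j has its 1 in column z[j]
rows_layout = np.eye(10, dtype=int)[z]

# One column per label: this matches the axis=0 output shown above
columns_layout = rows_layout.T

print(columns_layout)
```

Row 0 of columns_layout marks the positions where z equals 0 (indices 3, 7, 8 and 9), which is the first row of the printed matrix.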

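This looks perfect, but every call to tf.one_hot() adds new nodes to the graph. If it is called repeatedly, for example inside the training loop, the graph keeps growing, which can leak memory and make the model slower and slower.

So if you notice that the model's memory usage keeps growing, and running tf.get_default_graph().finalize() leads to an error (a finalized graph is read-only, so any later attempt to add a node raises an exception),

it is time to consider implementing one-hot encoding yourself.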

def one_hot(labels, Label_class):
    # Row j has a 1 in column labels[j] and 0 everywhere else
    one_hot_label = np.array([[int(i == int(labels[j])) for i in range(Label_class)] for j in range(len(labels))])
    return one_hot_label

Example:

import numpy as np

def one_hot(labels, Label_class):
    # Row j has a 1 in column labels[j] and 0 everywhere else
    one_hot_label = np.array([[int(i == int(labels[j])) for i in range(Label_class)] for j in range(len(labels))])
    return one_hot_label

y = [2, 5, 6, 7, 8]
Label_class = 20  # number of classes, i.e. the width of each one-hot row
print(one_hot(y, Label_class))

[[0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]]
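The nested list comprehension runs len(labels) × Label_class Python-level comparisons, which can become slow for large label sets. One possible alternative is to let NumPy do the work with integer-array indexing. This is a sketch, not the original author's method: the name one_hot_vectorized and the range check are assumptions added here.

```python
import numpy as np

def one_hot_vectorized(labels, num_classes):
    """Return a (len(labels), num_classes) integer array with a single 1 per row."""
    labels = np.asarray(labels, dtype=np.int64)
    if labels.ndim != 1:
        raise ValueError("labels must be one-dimensional")
    # Assumption: out-of-range labels are treated as an error
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError("every label must lie in [0, num_classes)")
    encoded = np.zeros((labels.size, num_classes), dtype=np.int64)
    # Set column labels[j] of row j to 1 for all rows at once
    encoded[np.arange(labels.size), labels] = 1
    return encoded

result = one_hot_vectorized([2, 5, 6, 7, 8], 20)
print(result)
```

For valid labels this should produce the same matrix as one_hot(y, Label_class) above. The difference is in error handling: the list-comprehension version silently returns an all-zero row for a label outside the range, whereas this sketch raises ValueError.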

https://blog.csdn.net/shanyicheng1111/article/details/80007193

Reposted from blog.csdn.net/qq_34106574/article/details/82760537