keras 整理之 Layers

版权声明:本文为博主原创文章,欢迎交流分享,未经博主允许不得转载。 https://blog.csdn.net/HHTNAN/article/details/82493952

» 嵌入层 Embedding

Embedding

keras.layers.Embedding(input_dim, output_dim, 
embeddings_initializer='uniform', embeddings_regularizer=None, 
activity_regularizer=None, embeddings_constraint=None, mask_zero=False, 
input_length=None)

将正整数(索引值)转换为固定尺寸的稠密向量。 例如: [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]

该层只能用作模型中的第一层。

参数

input_dim: int > 0。词汇表大小, 即,最大整数 index + 1。
output_dim: int >= 0。词向量的维度。
embeddings_initializer: embeddings 矩阵的初始化方法 (详见 initializers)。
embeddings_regularizer: embeddings matrix 的正则化方法 (详见 regularizer)。
embeddings_constraint: embeddings matrix 的约束函数 (详见 constraints)。
mask_zero: 是否把 0 看作为一个应该被遮蔽的特殊的 "padding" 值。 这对于可变长的 循环神经网络层 十分有用。 如果设定为 True,那么接下来的所有层都必须支持 masking,否则就会抛出异常。 如果 mask_zero 为 True,作为结果,索引 0 就不能被用于词汇表中 (input_dim 应该与 vocabulary + 1 大小相同)。
input_length: 输入序列的长度,当它是固定的时。 如果你需要连接 Flatten 和 Dense 层,则这个参数是必须的 (没有它,dense 层的输出尺寸就无法计算)。

输入尺寸

尺寸为 (batch_size, sequence_length) 的 2D 张量。

输出尺寸

尺寸为 (batch_size, sequence_length, output_dim) 的 3D 张量。

参考文献

A Theoretically Grounded Application of Dropout in Recurrent Neural Networks

案例:

from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
import numpy as np

model = Sequential()
model.add(Embedding(input_dim=1000, output_dim=60, input_length=10))
# 模型将输入一个大小为 (batch, input_length) 的整数矩阵。
# 输入中最大的整数(即词索引)不应该大于 999 (词汇表大小)
# 现在 model.output_shape == (None, 10, 60),其中 None 是 batch 的维度。

input_array = np.random.randint(1000, size=(32, 10))
print("input_array.shape={},len(input_array)={}".format(input_array.shape,len(input_array)))
print(input_array)

model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
assert output_array.shape == (32, 10, 60)
input_array.shape=(32, 10),len(input_array)=32
[[ 21 772 551 347 451 993 593 219 923 117]
 [711 600 601  66 984 581 671 292 963  39]
 [810 978 800 377 224  68 113 526 466 258]
 [908 145 471 724 519 795 926 904 879  29]
 [475 230 469 157   0 715 274 680 880 820]
 [344 889  34 938 915 563 384 947 752 405]
 [302 371 427  77 861  99 352 467 438 653]
 [682 536 321 221 137  48 387 380  36 409]
 [569 812 825 751 850   8 704 532 443 973]
 [226 634 491 294 512  65 434  88 653  76]
 [229 419 633 426 751 966 599 794 404 488]
 [792 259 833 130  65 561 361 282 815 372]
 [733 282 692 434 949 939 221 847 425 341]
 [666 510 690 842 801 981 556 777  10 438]
 [156  71 338 705 475 548  48 766 317 237]
 [109 919 138 640 508 522 236  17 444 604]
 [869 817 372 725 369  24  78 330 910 684]
 [573 579 409  41  83 310 591 617   0  56]
 [669 327 353  92 238 741 429 692 626 174]
 [924 328  43 529 329 409 929  44 204 114]
 [981 408  10 212 999 150 233 384 911 557]
 [ 14 615 573 565 422 899  35 498 204 534]
 [126 906 160 352 690 405 427 422 657 693]
 [821 520 896 164 898 539 450 355 236 292]
 [390 970 631  93 112 589 506 625  76 436]
 [732 790 494 874 113 131 657 426 558 398]
 [753 748 146 554 255 849 824 766 954 809]
 [ 96 997 313 376 986 839 378 959 689 395]
 [ 98 502 699 400 131 718  20 619 909 385]
 [867 757 430 605  63 172 964 344 835 309]
 [637 746 759 790 382 811 647 899 867 580]
 [478 284 838 146 428 637 311 221 175 849]]

猜你喜欢

转载自blog.csdn.net/HHTNAN/article/details/82493952