[pytorch] Eine kurze Einführung in nn.EMBEDDING

1. Referenzen

Angenommen, es gibt ein Wörterbuch mit insgesamt nur 10 Wörtern, und jedes Wort besteht aus 5 Buchstaben. Auf jede Seite wird nur ein Wort geschrieben, also werden diese 10 Wörter auf jede der 10 Seiten geschrieben.

Innen wie folgt,

[
[a,p,p,l,e],  # page 0
[g,r,e,e,n],  # page 1
[s,m,a,l,l],  # page 2
[w,a,t,c,h],  # page 3
[b,a,s,i,c],  # page 4
[e,n,j,o,y],  # page 5
[c,l,a,s,s],  # page 6
[e,m,b,e,d],  # page 7
[h,a,p,p,y],  # page 8
[p,l,a,t,e]   # page 9
]

Wir gehen davon aus, dass das Wörterbuch aufgerufen wirdEinbettung(10,5), die 10 und 5 hier sind die oben eingeführten Bedeutungen, 10 Wörter, jedes Wort hat 5 Buchstaben;

Jetzt möchte ich Seite 2 und 3 (beginnend bei 0) anzeigen, dann bekomme ich [s,m,a,l,l], [w,a,t,c,h] Inhalt.

Angenommen, wir einigen uns auf ein Passwort, Sie sagen mir die Anzahl der Seiten, und ich gebe die Wörter zurück, die der Anzahl der Seiten entsprechen.

Zum Beispiel schickst du mir das Passwort [ [2,3], [1,0], [8,6] ] ( also den LongTensor mit der Form (3, 2) )

Ich sage es Ihnen, indem ich das Wörterbuch abfrage

[
[ [s,m,a,l,l],  [w,a,t,c,h] ],
[ [g,r,e,e,n],   [a,p,p,l,e] ],
[ [h,a,p,p,y],   [c,l,a,s,s] ]
]

Das Wörterbuch ist hier die Einbettungstabelle , und das Passwort soll den Indexwert dieser Tabelle abfragen.


2. Warum ist eine Einbettung erforderlich?

Manchmal ist das, was wir intuitiv sehen, nicht unbedingt die Essenz der Dinge, wir müssen die „wesentlichen Merkmale“ oder „versteckten Merkmale“ durch das Phänomen sehen. Wie kommst du also durch? Oder was ist ein "verstecktes Feature"?

Embeding macht das. Es fragt die „embeding table“ ab, um die „verborgenen Merkmale“ eines Satzes oder eines Tonstücks zu erhalten.

Die Einbettungstabelle ist im Allgemeinen ein Satz von Fließkommawerten, die mit CNN- und LSTM-Netzwerken identisch sind und zu den Parametern gehören, die vom Netzwerk gelernt werden können.
Daher wird sein Wert nicht von Menschen definiert, und ein solches „Wörterbuch“ kann nicht von Menschen definiert werden, sondern wird schrittweise durch Deep-Learning-Netzwerke erlernt.


3. Zurück zu Python

nn.Einbettung in Pytorch bietet eine solche Implementierung;

Unten ist ein Beispiel

import torch

# 如同上面例子中的page索引
a = torch.LongTensor([[1,2], [5,2]]) 

# 一个10个单词,每个单词5个字母的字典
emb = torch.nn.Embedding(10,5)
print(emb.weight, emb.weight.shape)

# 同过索引查询embeding内容
y = emb(a)
print(y, y.shape)

Bildbeschreibung hier einfügen

Es ist ersichtlich, dass das „Wörterbuch“ kein Wort mehr ist, sondern einige Fließkommazahlen, die verborgene Merkmale darstellen.


4. Offizielle APIs

Einbettung
Bildbeschreibung hier einfügen

4.1 Parametereinführung

4.1.1 num_embedding und embedding_dim

num_embedding , embedding_dim sind die oben eingeführten "Wörter" und "Anzahl der Buchstaben pro Wort", die die Anzahl der Einbettungen im Wörterbuch und die Dimension jeder Einbettung darstellen.

4.1.2 padding_idx

padding_idx ist der Index des „Wortes“, das den Gradienten nicht aktualisiert, eine nicht trainierte Einbettung kann im Dictionary angegeben werden.
Siehe das Beispiel unten:

import torch

a = torch.LongTensor([[1,2], [5,2]])

emb = torch.nn.Embedding(10,5, padding_idx=0)
print(emb.weight, emb.weight.shape)
y = emb(a)
print(y, y.shape)

Bildbeschreibung hier einfügen
Hier padding_index=0, was bedeutet, dass die Einbettung unter diesem Index nicht lernen wird, sich zu aktualisieren, und der Standardwert ist auch 0 während der Initialisierung.

4.1.3 max_norm und norm_type

max_norm , norm_type werden nach Erhalt der Einbettung regularisiert; die möglichen Werte von
norm_type sind 1 und 2. Sie repräsentieren Paradigma 1 bzw. Paradigma 2, und der Standardwert ist 2.

max_norm ist der Maximalwert im Definitionsparadigma. Wenn der Wert in der Einbettung größer als dieser Schwellenwert ist, wird die Norm neu erstellt.

Guess you like

Origin blog.csdn.net/mimiduck/article/details/127095433