Mapping a transformers tokenizer's output (e.g. after BPE) back to the original input sequence

I have written before about ways to generate embeddings with Huggingface's transformers, for example this post: 怎样通过预训练的Transformers的模型得到一个Sentence的Representation_蛐蛐蛐的博客-CSDN博客

In those posts, however, I mainly generated an embedding for a whole sequence and rarely needed to extract the embedding of an individual token or word. When I tried, it turned out to be a bit more involved than I expected, so here is a summary.

Because subword methods such as BPE are usually applied, the sequence fed into the transformer encoder is not the same length as the real input sequence, and there is no strict one-to-one correspondence between the two. To compute the embedding of a single word of the original sequence, we therefore need additional APIs. This has also been discussed on Stack Overflow: tokenize - Mapping huggingface tokens to original input text - Stack Overflow

After reading the existing answers carefully, however, I found that they are not quite correct. As an example, suppose I want to find the correspondence for the following sequence:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-large', do_lower_case=True)
example = "push r15 push r14 mov r15 , r8 push r13 push r12 mov r12 , rdi"
encoded = tokenizer(example)
print(encoded['input_ids'])
print(len(encoded['input_ids']))
print(encoded.tokens())
print(len(encoded.tokens()))
print(encoded.word_ids())
print(len(encoded.word_ids()))

The output is:

[0, 41935, 910, 996, 1920, 910, 1570, 32924, 910, 996, 2156, 910, 398, 1920, 910, 1558, 1920, 910, 1092, 32924, 910, 1092, 2156, 910, 7506, 2]
26
['<s>', 'push', 'Ġr', '15', 'Ġpush', 'Ġr', '14', 'Ġmov', 'Ġr', '15', 'Ġ,', 'Ġr', '8', 'Ġpush', 'Ġr', '13', 'Ġpush', 'Ġr', '12', 'Ġmov', 'Ġr', '12', 'Ġ,', 'Ġr', 'di', '</s>']
26
[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 22, None]
26

My input above is a string of assembly code, chosen simply for the sake of the example. In short, tokenizer(InputSentence) returns a BatchEncoding object, and the methods of that object give us the information we need. From the output above we can see that encoded['input_ids'] is exactly the tensor input fed to the transformer encoder, .tokens() returns the tokens after tokenization, and .word_ids() returns, for each token, a word index defined by the tokenizer's own pre-tokenization (not an index into the whitespace-separated words of the original input sequence). So how do we map back to the original input sequence? I tried many approaches and found that only the token_to_chars function does the job. The API documentation is here: Tokenizer
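A side note before going through the BatchEncoding methods: fast (Rust-backed) tokenizers, which AutoTokenizer returns for roberta-large by default, can also hand back per-token character offsets directly from the call itself via return_offsets_mapping=True. The following is only a minimal sketch of that option; the rest of this post sticks with the BatchEncoding methods.

# Minimal sketch (requires a fast tokenizer): ask the tokenizer call itself
# for character offsets; special tokens such as <s> and </s> get (0, 0).
encoded_with_offsets = tokenizer(example, return_offsets_mapping=True)
for token, (start, end) in zip(encoded_with_offsets.tokens(),
                               encoded_with_offsets['offset_mapping']):
    print(token, (start, end), repr(example[start:end]))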

To see what the various mapping methods on BatchEncoding actually return, the code below goes through every token and prints the relevant cases:

for token_index in range(len(encoded.tokens())):
    this_word_id = encoded.word_ids()[token_index]  # word index from the tokenizer's pre-tokenization; None for special tokens
    if this_word_id is not None:
        print('###########################')
        print(token_index)
        print(encoded.token_to_chars(token_index))
        print(encoded.token_to_word(token_index))
        char_span = encoded.token_to_chars(token_index)
        print('...')
        for char_index in range(char_span.start, char_span.end):
            print(encoded.char_to_word(char_index))
            print(encoded.char_to_token(char_index))
        print('...')
        print('###########################')

The corresponding output is:

###########################
1
CharSpan(start=0, end=4)
0
...
0
1
0
1
0
1
0
1
...
###########################
###########################
2
CharSpan(start=5, end=6)
1
...
1
2
...
###########################
###########################
3
CharSpan(start=6, end=8)
2
...
2
3
2
3
...
###########################
###########################
4
CharSpan(start=9, end=13)
3
...
3
4
3
4
3
4
3
4
...
###########################
###########################
5
CharSpan(start=14, end=15)
4
...
4
5
...
###########################
###########################
6
CharSpan(start=15, end=17)
5
...
5
6
5
6
...
###########################
###########################
7
CharSpan(start=18, end=21)
6
...
6
7
6
7
6
7
...
###########################
###########################
8
CharSpan(start=22, end=23)
7
...
7
8
...
###########################
###########################
9
CharSpan(start=23, end=25)
8
...
8
9
8
9
...
###########################
###########################
10
CharSpan(start=26, end=27)
9
...
9
10
...
###########################
###########################
11
CharSpan(start=28, end=29)
10
...
10
11
...
###########################
###########################
12
CharSpan(start=29, end=30)
11
...
11
12
...
###########################
###########################
13
CharSpan(start=31, end=35)
12
...
12
13
12
13
12
13
12
13
...
###########################
###########################
14
CharSpan(start=36, end=37)
13
...
13
14
...
###########################
###########################
15
CharSpan(start=37, end=39)
14
...
14
15
14
15
...
###########################
###########################
16
CharSpan(start=40, end=44)
15
...
15
16
15
16
15
16
15
16
...
###########################
###########################
17
CharSpan(start=45, end=46)
16
...
16
17
...
###########################
###########################
18
CharSpan(start=46, end=48)
17
...
17
18
17
18
...
###########################
###########################
19
CharSpan(start=49, end=52)
18
...
18
19
18
19
18
19
...
###########################
###########################
20
CharSpan(start=53, end=54)
19
...
19
20
...
###########################
###########################
21
CharSpan(start=54, end=56)
20
...
20
21
20
21
...
###########################
###########################
22
CharSpan(start=57, end=58)
21
...
21
22
...
###########################
###########################
23
CharSpan(start=59, end=60)
22
...
22
23
...
###########################
###########################
24
CharSpan(start=60, end=62)
22
...
22
24
22
24
...
###########################

As you can see, only token_to_chars returns a CharSpan, and that span refers to positions in the original input string. The other methods, such as char_to_word and char_to_token, just take you back to indices on the tokenized side, which is not at all what I had understood from the API documentation. With token_to_chars we can therefore construct two mappings that connect each word of the original input sequence to its tokens in the encoded sequence; a quick sanity check and the mapping code follow.
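First, a quick sanity check (a minimal sketch reusing the encoded and example objects from above): slicing the original string with each token's CharSpan should recover exactly the piece of text the token came from.

# Sanity check: the CharSpan of every non-special token should slice out
# the surface text it was derived from in the original string.
for token_index in range(len(encoded.tokens())):
    if encoded.word_ids()[token_index] is None:  # skip <s> and </s>
        continue
    span = encoded.token_to_chars(token_index)
    print(encoded.tokens()[token_index], '->', repr(example[span.start:span.end]))

The two mappings themselves are then built as follows: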

corpora_records = example.split(' ')
word_2_char_mapping = {}  # word index -> [start_char, end_char) in the original string
char_cursor = 0
for ind in range(len(corpora_records)):
    if len(corpora_records[ind]) > 0:  # an empty piece (e.g. from a trailing space) is skipped
        start = char_cursor
        end = char_cursor + len(corpora_records[ind])
        word_2_char_mapping[ind] = [start, end]
        char_cursor = char_cursor + len(corpora_records[ind]) + 1  # +1 for the separating whitespace

print(word_2_char_mapping)

word_2_token_mapping = {}  # word index -> list of token indices in the encoded sequence
for token_index in range(len(encoded.tokens())):
    this_word_id = encoded.word_ids()[token_index]
    if this_word_id is not None:  # skip the special tokens <s> and </s>
        char_span = encoded.token_to_chars(token_index)
        for each_word in word_2_char_mapping:
            start = word_2_char_mapping[each_word][0]
            end = word_2_char_mapping[each_word][1]
            if char_span.start >= start and char_span.end <= end:
                # print(encoded.tokens()[token_index])  # check the results to make sure our mapping is correct
                # print('--->')
                # print(corpora_records[each_word])
                if each_word in word_2_token_mapping:
                    word_2_token_mapping[each_word].append(token_index)
                else:
                    word_2_token_mapping[each_word] = [token_index]

print(word_2_token_mapping)

The corresponding output is:

[0, 41935, 910, 996, 1920, 910, 1570, 32924, 910, 996, 2156, 910, 398, 1920, 910, 1558, 1920, 910, 1092, 32924, 910, 1092, 2156, 910, 7506, 2]
26
['<s>', 'push', 'Ġr', '15', 'Ġpush', 'Ġr', '14', 'Ġmov', 'Ġr', '15', 'Ġ,', 'Ġr', '8', 'Ġpush', 'Ġr', '13', 'Ġpush', 'Ġr', '12', 'Ġmov', 'Ġr', '12', 'Ġ,', 'Ġr', 'di', '</s>']
26
[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 22, None]
26
{0: [0, 4], 1: [5, 8], 2: [9, 13], 3: [14, 17], 4: [18, 21], 5: [22, 25], 6: [26, 27], 7: [28, 30], 8: [31, 35], 9: [36, 39], 10: [40, 44], 11: [45, 48], 12: [49, 52], 13: [53, 56], 14: [57, 58], 15: [59, 62]}
{0: [1], 1: [2, 3], 2: [4], 3: [5, 6], 4: [7], 5: [8, 9], 6: [10], 7: [11, 12], 8: [13], 9: [14, 15], 10: [16], 11: [17, 18], 12: [19], 13: [20, 21], 14: [22], 15: [23, 24]}

For readability, the long per-token prints from the exploration loop are left out here (they are commented out in the complete listing below). As you can see, the final word_2_token_mapping is exactly right. Finally, here is the complete code in one piece, so it is easy to copy and verify:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-large', do_lower_case=True)
example = "push r15 push r14 mov r15 , r8 push r13 push r12 mov r12 , rdi"
encoded = tokenizer(example)
print(encoded['input_ids'])
print(len(encoded['input_ids']))
print(encoded.tokens())
print(len(encoded.tokens()))
print(encoded.word_ids())
print(len(encoded.word_ids()))

# for token_index in range(len(encoded.tokens())):
#     this_word_id = encoded.word_ids()[token_index]
#     if this_word_id is not None:
#         print('###########################')
#         print(token_index)
#         print(encoded.token_to_chars(token_index))
#         print(encoded.token_to_word(token_index))
#         char_span = encoded.token_to_chars(token_index)
#         print('...')
#         for char_index in range(char_span.start, char_span.end):
#             print(encoded.char_to_word(char_index))
#             print(encoded.char_to_token(char_index))
#         print('...')
#         print('###########################')

corpora_records = example.split(' ')
word_2_char_mapping = {}  # word index -> [start_char, end_char) in the original string
char_cursor = 0
for ind in range(len(corpora_records)):
    if len(corpora_records[ind]) > 0:  # an empty piece (e.g. from a trailing space) is skipped
        start = char_cursor
        end = char_cursor + len(corpora_records[ind])
        word_2_char_mapping[ind] = [start, end]
        char_cursor = char_cursor + len(corpora_records[ind]) + 1  # +1 for the separating whitespace

print(word_2_char_mapping)

word_2_token_mapping = {}  # word index -> list of token indices in the encoded sequence
for token_index in range(len(encoded.tokens())):
    this_word_id = encoded.word_ids()[token_index]
    if this_word_id is not None:  # skip the special tokens <s> and </s>
        char_span = encoded.token_to_chars(token_index)
        for each_word in word_2_char_mapping:
            start = word_2_char_mapping[each_word][0]
            end = word_2_char_mapping[each_word][1]
            if char_span.start >= start and char_span.end <= end:
                # print(encoded.tokens()[token_index])  # check the results to make sure our mapping is correct
                # print('--->')
                # print(corpora_records[each_word])
                if each_word in word_2_token_mapping:
                    word_2_token_mapping[each_word].append(token_index)
                else:
                    word_2_token_mapping[each_word] = [token_index]

print(word_2_token_mapping)
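
With word_2_token_mapping in hand, the natural next step is extracting per-word embeddings. The following is only a minimal sketch of one way to do that: it assumes we load the matching roberta-large model via AutoModel and simply mean-pool the last hidden states of each word's sub-tokens (other pooling schemes are just as valid):

import torch
from transformers import AutoModel

# Sketch: run the same example through the model and mean-pool the hidden
# states of the sub-tokens belonging to each whitespace-separated word.
model = AutoModel.from_pretrained('roberta-large')
model.eval()

inputs = tokenizer(example, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
hidden_states = outputs.last_hidden_state[0]  # (sequence_length, hidden_size)

word_embeddings = {}
for word_index, token_indices in word_2_token_mapping.items():
    word_embeddings[word_index] = hidden_states[token_indices].mean(dim=0)

print(corpora_records[0], word_embeddings[0].shape)  # 'push' -> torch.Size([1024]) for roberta-large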

After carefully checking and testing them, I found the answers on Stack Overflow to be not very reliable; still, it was those answers that pointed me to the token_to_chars function in the first place. That is all for this short summary.

Reprinted from blog.csdn.net/qysh123/article/details/126438203