BERT预训练模型字向量提取工具--使用BERT编码句子

本文将介绍三个使用BERT编码句子(从BERT中提取向量)的工具。
(1)Embedding-as-service
github
这个库类似于bert-as-service,可以将句子编码成固定长度的向量,目前支持的预训练模型有BERT、ALBERT、XLNet、ELMO、Golve、word2vec等,我们可以将其作为我们模型的一部分,也可以将其作为一个服务直接使用它编码我们的句子。下面给除一个其作为服务编码句子的示例:

>>> from embedding_as_service_client import EmbeddingClient
>>> en = EmbeddingClient(host=<host_server_ip>, port=<host_port>)
>>> vecs = en.encode(texts=['hello aman', 'how are you?'])  
>>> vecs  
array([[[ 1.7049843 ,  0.        ,  1.3486509 , ..., -1.3647075 ,  
 0.6958289 ,  1.8013777 ], ... [ 0.4913215 ,  0.60877025,  0.73050433, ..., -0.64490885, 0.8525057 ,  0.3080206 ]]], dtype=float32)  
>>> vecs.shape  
(2, 128, 768) # batch x max_sequence_length x embedding_size  

详细内容请点击上方给出的github链接。

(2)BERT预训练模型字向量提取工具
本工具直接读取BERT预训练模型,从中提取样本文件中所有使用到字向量,保存成向量文件,为后续模型提供embdding。

本工具直接读取预训练模型,不需要其它的依赖,同时把样本中所有 出现的字符对应的字向量全部提取,后续的模型可以非常快速进行embdding
github完整源码

#!/usr/bin/env python
# coding: utf-8

__author__ = 'xmxoxo<[email protected]>'

'''
BERT预训练模型字向量提取工具
版本: v 0.3.2
更新:  2020/3/25 11:11
git: https://github.com/xmxoxo/BERT-Vector/
'''

import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
import os
import sys
import traceback
import pickle


gblVersion = '0.3.2'
# 如果模型的文件名不同,可修改此处
model_name = 'bert_model.ckpt'
vocab_name = 'vocab.txt'

# BERT embdding提取类 
class bert_embdding(): 
    def __init__(self, model_path='', fmt='pkl'):
        # 模型和词表的文件名
        ckpt_path = os.path.join(model_path, model_name)
        vocab_file = os.path.join(model_path, vocab_name)
        if not os.path.isfile(vocab_file):
            print('词表文件不存在,请检查...')
            #sys.exit()
            return 
        
        # 从模型读出指定层
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
        #param_dict = reader.get_variable_to_shape_map()
        self.emb = reader.get_tensor("bert/embeddings/word_embeddings")
        self.vocab = open(vocab_file,'r', encoding='utf-8').read().split("\n")
        print('embeddings size: %s' % str(self.emb.shape))
        print('词表大小:%d' % len(self.vocab))

        # 兼容不同格式
        self.fmt=fmt

    # 取出指定字符的embdding,返回向量
    def get_embdding (self, char):
        if char in self.vocab:
            index = self.vocab.index(char)
            return self.emb[index,:]
        else:
            return None

    # 根据字符串提取向量并保存到文件
    def export (self, txt_all, out_file=''):
        # 过滤重复,形成字典
        txt_lst = sorted(list(set(txt_all)))

        print('文本字典长度:%d, 正在提取字向量...' % len(txt_lst))
        count = 0
        # 可选择输出哪种格式 2020/3/25 
        if self.fmt=='pkl':
            print('正在保存为pkl格式文件...')
            # 使用字典存储,使用时更加方便。 2020/3/23
            lst_vector = dict()
            for word in txt_lst:
                v = self.get_embdding(word)
                if not (v is None):
                    count += 1
                    lst_vector[word] = v

            # 改为使用pickle保存文件 2020/3/23
            with open(out_file, 'wb') as out: 
                pickle.dump(lst_vector, out, 2)

        if self.fmt=='txt':
            print('正在保存为txt格式文件...')
            with open(out_file, 'w', encoding='utf-8') as out: 
                for word in txt_lst:
                    v = self.get_embdding(word)
                    if not (v is None):
                        count += 1
                        out.write(word + " " + " ".join([str(i) for i in v])+"\n")

        print('字向量共提取:%d个' % count)

    # get all files and floders in a path
    # fileExt: ['png','jpg','jpeg']
    # return: 
    #    return a list ,include floders and files , like [['./aa'],['./aa/abc.txt']]
    @staticmethod
    def getFiles (workpath, fileExt = []):
        try:
            lstFiles = []
            lstFloders = []

            if os.path.isdir(workpath):
                for dirname in os.listdir(workpath) :
                    file_path = os.path.join(workpath, dirname)
                    if os.path.isfile(file_path):
                        if fileExt:
                            if dirname[dirname.rfind('.')+1:] in fileExt:
                               lstFiles.append (file_path)
                        else:
                            lstFiles.append (file_path)
                    if os.path.isdir( file_path ):
                        lstFloders.append (file_path)      

            elif os.path.isfile(workpath):
                lstFiles.append(workpath)
            else:
                return None
            
            lstRet = [lstFloders,lstFiles]
            return lstRet
        except Exception as e :
            return None

    # 增加批量处理目录下的某类文件 v 0.1.2  xmxoxo 2020/3/23
    def export_path (self, path, ext=['csv','txt'], out_file=''):
        try:
            files = self.getFiles(path,ext)
            # 合并数据
            txt_all = []
            tmp = ''
            for fn in files[1]:
                print('正在读取数据文件:%s' % fn)
                with open(fn, 'r', encoding='utf-8') as f: 
                    tmp = f.read()
                txt_all += list(set(tmp))
                txt_all = list(set(txt_all))
            
            self.export(txt_all, out_file=out_file)

        except Exception as e :
            print('批量处理出错:')
            print('Error in get_randstr: '+ traceback.format_exc())
            return None

# 命令行
def main_cli ():
    parser = argparse.ArgumentParser(description='BERT模型字向量提取工具')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s ' + gblVersion )
    parser.add_argument('--model_path', default='', required=True, type=str, help='BERT预训练模型的目录')
    parser.add_argument('--in_file', default='', required=True, type=str, help='待提取的文件名或者目录名')
    parser.add_argument('--out_file', default='./bert_embedding.pkl', type=str,  help='输出文件名')
    parser.add_argument('--ext', default=['csv','txt'], type=str, nargs='+', help='指定目录时读取的数据文件扩展名')
    parser.add_argument('--fmt', default='pkl', type=str, help='输出文件的格式,可设置txt或者pkl, 默认为pkl')

    args = parser.parse_args()

    # 预训练模型的目录
    model_path = args.model_path
    # 输出文件名
    out_file = args.out_file
    # 包含所有文本的内容
    in_file = args.in_file
    # 指定的扩展名
    ext = args.ext
    # 文件格式
    fmt = args.fmt
    if not fmt in ['pkl', 'txt']:
        fmt='pkl'
    
    if fmt=='txt' and out_file[-4:]=='.pkl':
        out_file = out_file[:-3] + 'txt'

    if not os.path.isdir(model_path):
        print('模型目录不存在,请检查:%s' % model_path)
        sys.exit()

    if not (os.path.isfile(in_file) or os.path.isdir(in_file)):
        print('数据文件不存在,请检查:%s' % in_file)
        sys.exit()
    print('\nBERT 字向量提取工具 V' + gblVersion )
    print('-'*40)
   
    bertemb = bert_embdding(model_path=model_path, fmt=fmt)
    # 针对文件和目录分别处理 2020/3/23 by xmxoxo
    if os.path.isfile(in_file):
        txt_all = open(in_file,'r', encoding='utf-8').read()
        bertemb.export(txt_all, out_file=out_file)
    if os.path.isdir(in_file):
        bertemb.export_path(in_file, ext=ext, out_file=out_file)

if __name__ == '__main__':
    pass
    main_cli()

(3)使用BERT编码句子
本文将BERT进行了封装,我们可以直接输入句子并得到句子对应的向量。
如下所示:

from bert_encoder import BertEncoder
be = BertEncoder()
embedding = be.encode("新年快乐,恭喜发财,万事如意!")
print(embedding)
print(embedding.shape)

完整封装:
完整代码

# -*- coding:utf-8 -*-

import os
from bert import modeling
import tensorflow as tf
from bert import tokenization

flags = tf.flags
FLAGS = flags.FLAGS

bert_path = r'chinese_L-12_H-768_A-12'
root_path = os.getcwd()

flags.DEFINE_string(
    "bert_config_file", os.path.join(bert_path, 'bert_config.json'),
    "The config json file corresponding to the pre-trained BERT model."
)
flags.DEFINE_string("vocab_file", os.path.join(bert_path, 'vocab.txt'),
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text."
)
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization."
)

bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

def data_preprocess(sentence):
    tokens = []
    for i, word in enumerate(sentence):
        # 分词,如果是中文,就是分字
        token = tokenizer.tokenize(word)
        tokens.extend(token)
    # 序列截断
    if len(tokens) >= FLAGS.max_seq_length - 1:
        tokens = tokens[0:(FLAGS.max_seq_length - 2)]  # -2 的原因是因为序列需要加一个句首和句尾标志
    ntokens = []
    segment_ids = []
    ntokens.append("[CLS]")  # 句子开始设置CLS 标志
    segment_ids.append(0)
    # append("O") or append("[CLS]") not sure!
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
    ntokens.append("[SEP]")  # 句尾添加[SEP] 标志
    segment_ids.append(0)
    # append("O") or append("[SEP]") not sure!
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # 将序列中的字(ntokens)转化为ID形式
    # print(input_ids)
    input_mask = [1] * len(input_ids)
    # print(input_mask)
    while len(input_ids) < FLAGS.max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
    input_ids = [input_ids]
    return input_ids, input_mask

class BertEncoder(object):

    def __init__(self):
        self.bert_model = modeling.BertModel(config=bert_config, is_training=False, max_seq_length=FLAGS.max_seq_length)
        tvars = tf.trainable_variables()
        (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_cheeckpoint)
        tf.train.init_from_checkpoint(FLAGS.init_cheeckpoint, assignment_map)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def encode(self, sentence):
        input_ids, input_mask = data_preprocess(sentence)
        return self.sess.run(self.bert_model.embedding_output, feed_dict={
    
    self.bert_model.input_ids:input_ids})



if __name__ == "__main__":
    be = BertEncoder()
    embedding = be.encode("新年快乐,恭喜发财,万事如意!")
    print(embedding)
    print(embedding.shape)

关注编程ABC,靠近算法和NLP~
在这里插入图片描述

参考:
https://github.com/xmxoxo/BERT-Vector
https://github.com/lzphahaha/bert_encoder
https://github.com/amansrivastava17/embedding-as-service#-supported-embeddings-and-models

猜你喜欢

转载自blog.csdn.net/broccoli2/article/details/105465585