参考:https://github.com/xmxoxo/BERT-train2deploy
1. 准备
-
下载bert源码
https://github.com/google-research/bert.git -
下载bert预训练模型,本文使用中文预训练模型
https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
- bert_model.ckpt:负责模型变量载入
- vocab.txt:字典
- bert_config.json:bert训练时的可调参数
2. 语料
建议采用 \t 进行分割,第一列为标签,第二列为文本,训练集保存为train.tsv,测试集保存为test.tsv,测试集可以仅有1列,无标签。如下图语料有3个标签。将 train.tsv、val.tsv、test.tsv放入同一目录下。
0 光度是光作用于人眼所引起的明亮程度的感觉。
1 黄樟素属于易制毒化学品。
2 在佛教中,对僧人的称呼一般有
3. 建立bert多分类模型
bert多分类仅需修改 run_classifier.py 文件中的代码
- 在main()中添加多分类任务 classifytask,该任务类 ClassifyProcessor 继承 DataProcessor,与其它的cola、mnli等任务类似
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"classifytask": ClassifyProcessor, # 自定义的多分类任务
}
- 编写ClassifyProcessor类
class ClassifyProcessor(DataProcessor):
# 与训练集中定义的标签名一致
def __init__(self):
self.labels = [0, 1, 2]
# 加载训练集
def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
# 加载验证集
def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "val.tsv")), "val")
# 加载测试集
def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
# 获取标签
def get_labels(self):
return self.labels
# 读取数据
def _create_examples(self, lines, set_type):
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
if set_type == "test":
# 测试集没有标签,因此可以设置一个默认值
text_a = tokenization.convert_to_unicode(line[0])
label = 0
else:
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
4. 模型训练
新建脚本 start.sh,输入下列内容。
export DATA_DIR= # 语料路径
export BERT_BASE_DIR= # 预训练模型路径
python run_classifier.py \
--task_name=classifytask \
--do_train=true \ # 是否进行fine tune
--do_eval=true \ # 是否进行evaluation
--data_dir=$DATA_DIR/ \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \ # 句子的最长长度
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=$BERT_BASE_DIR/output # 输出目录
运行sh命令,bert分类任务开始执行,建议采用gpu进行训练,否则训练过程可能在Saving checkpoints for 0 … 处卡住。
sh start.sh
5. 预测
训练完成后得到的模型文件如图,后续的预测任务即可调用该模型进行,通常会采用ckpt模型生成pb模型文件进行调用。
使用下列命令进行预测,将init_checkpoint定义为训练生成的模型即可。
export BERT_BASE_DIR=chinese_L-12_H-768_A-12
export NER_DIR=dat
export OUTPUT=output
python run_mobile.py \
--task_name=setiment \
--do_predict=true \
--data_dir=$NER_DIR/ \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$OUTPUT/model.ckpt-455 \
--max_seq_length=128 \
--output_dir=$OUTPUT/
输出的预测文件如图,一行为一条记录对应每个类别的概率,将概率进行转换即可得到分类标签。本次训练的准确率约为85%左右。