import math
import json
import re
import random
import numpy as np
from collections import defaultdict
import cn2an
from tqdm import tqdm
from nl2sql.utils import read_data, read_tables, SQL, Query, Question, Table
from keras_bert import get_checkpoint_paths, load_vocabulary, Tokenizer, load_trained_model_from_checkpoint
from keras.utils.data_utils import Sequence
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import multi_gpu_model
# --- Captured notebook output (import-time warnings) ---
# Using TensorFlow backend.
# E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
#   _np_qint8 = np.dtype([("qint8", np.int8, 1)])
# E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
#   _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
# E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
#   _np_qint16 = np.dtype([("qint16", np.int16, 1)])
# E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
#   _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
# E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
#   _np_qint32 = np.dtype([("qint32", np.int32, 1)])
# E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
#   np_resource = np.dtype([("resource", np.ubyte, 1)])
# ===== Configuration =====
# Locations of the NL2SQL train / val / test splits (table schemas + questions).
_DATA_ROOT = 'E:/zym_test/test/nlp/data'
train_table_file = _DATA_ROOT + '/train/train.tables.json'
train_data_file = _DATA_ROOT + '/train/train.json'
val_table_file = _DATA_ROOT + '/val/val.tables.json'
val_data_file = _DATA_ROOT + '/val/val.json'
test_table_file = _DATA_ROOT + '/test/test.tables.json'
test_data_file = _DATA_ROOT + '/test/test.json'

# Pretrained Chinese BERT checkpoint directory.
# Download from https://github.com/ymcui/Chinese-BERT-wwm
bert_model_path = 'E:\\zym_test\\test\\nlp\\base-line\\chinese_wwm_ext_L-12_H-768_A-12'
paths = get_checkpoint_paths(bert_model_path)

# Predictions written by the task-1 model, read back below as JSON lines.
task1_file = 'task1_output.json'
# ===== 数据的读取 (Load the data) =====
# Read every split's table schemas first, then parse each split's questions
# against its own tables.
train_tables, val_tables, test_tables = (
    read_tables(train_table_file),
    read_tables(val_table_file),
    read_tables(test_table_file),
)
train_data = read_data(train_data_file, train_tables)
val_data = read_data(val_data_file, val_tables)
test_data = read_data(test_data_file, test_tables)
# ===== 构建Dataset (Build the dataset) =====
def is_float(value):
    """Return True if ``value`` can be converted with ``float()``.

    Numeric strings such as ``'1'``, ``'1.'`` or ``'-2.5e3'`` return True.
    Anything ``float`` rejects returns False instead of raising. That covers
    unparsable strings (``'abc'``) and unsupported types such as ``None``.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        # ValueError: unparsable text; TypeError: e.g. None or a list.
        return False
    return True
def cn_to_an(string):
    """Convert Chinese numerals in ``string`` to an Arabic-number string.

    Uses ``cn2an`` in ``'normal'`` mode, which also accepts digit-by-digit
    spellings (``'一二三'`` -> ``'123'``) besides unit forms
    (``'五百五十'`` -> ``'550'``). If ``cn2an`` raises ValueError, the input
    is returned unchanged.
    """
    converted = string
    try:
        converted = str(cn2an.cn2an(string, 'normal'))
    except ValueError:
        pass
    return converted
def an_to_cn(string):
    """Convert an Arabic-number string to Chinese numerals via ``cn2an``.

    For example ``'123'`` -> ``'一百二十三'``. Input that ``cn2an`` rejects
    with ValueError (such as ``'cb'``) is returned as-is.
    """
    try:
        chinese = cn2an.an2cn(string)
    except ValueError:
        return string
    return str(chinese)
def str_to_num(string):
    """Normalise a (possibly Chinese) numeric string to a canonical number string.

    Chinese numerals are first converted with ``cn_to_an``. The result is
    parsed as a float and rendered without a trailing ``.0`` when it is
    integral (``'五百五十'`` -> ``'550'``, ``'1.50'`` -> ``'1.5'``).

    Returns None when the text is not numeric. It also returns None when the
    text parses to a non-finite value (``nan`` / ``inf``).
    """
    try:
        float_val = float(cn_to_an(string))
    except ValueError:
        return None
    # float() accepts 'nan'/'inf', and int() of those raises
    # ValueError/OverflowError, so reject them explicitly.
    if not math.isfinite(float_val):
        return None
    if float_val.is_integer():
        return str(int(float_val))
    return str(float_val)
def str_to_year(string):
    """Expand an abbreviated two-digit year to a year of the 2000s.

    The ``'年'`` suffix is removed and Chinese numerals are converted with
    ``cn_to_an``. Only whole numbers in ``0..99`` are treated as abbreviated
    years: ``'20年'`` -> ``'2020'``, ``'一九年'`` -> ``'2019'``.

    Everything else returns None: non-numeric text, fractions, negative
    numbers and full years such as ``'2020年'`` or ``'1800年'``.
    """
    year = cn_to_an(string.replace('年', ''))
    if not is_float(year):
        return None
    value = float(year)
    # Restrict to two-digit whole numbers. A looser bound (e.g. < 1900) would
    # turn '1800年' into '3800', and int() of a fractional string would raise.
    if not value.is_integer() or not 0 <= value < 100:
        return None
    return str(int(value) + 2000)
def load_json(json_file, encoding=None):
    """Load a JSON-lines file (one JSON document per line).

    Args:
        json_file: path to the file. A falsy value (``None`` / ``''``) yields
            an empty list without touching the filesystem.
        encoding: text encoding passed to ``open``. ``None`` uses the
            platform's locale default. Pass ``'utf-8'`` for UTF-8 JSON files.

    Returns:
        A list of the decoded objects, in file order. Blank lines, such as a
        trailing newline at end of file, are skipped instead of making
        ``json.loads`` raise.
    """
    result = []
    if not json_file:
        return result
    with open(json_file, encoding=encoding) as file:
        for line in file:
            if line.strip():
                result.append(json.loads(line))
    return result
# ----- 小demo (small demo) -----
# Quick sanity checks of the helpers above. Each case prints
# "<helper>:<input>-><output>".
_helper_demo_cases = [
    ("is_float:{}->{}", is_float, 'abc'),
    ("is_float:{}->{}", is_float, '1'),
    ("is_float:{}->{}", is_float, '1.'),
    ("cn_to_an:{}->{}", cn_to_an, '五百五十'),
    ("cn_to_an:{}->{}", cn_to_an, 'abc'),
    ("cn_to_an:{}->{}", cn_to_an, '一二三'),
    ("an_to_cn:{}->{}", an_to_cn, '1'),
    ("an_to_cn:{}->{}", an_to_cn, '123'),
    ("an_to_cn:{}->{}", an_to_cn, 'cb'),
    ("str_to_num:{}->{}", str_to_num, '123'),
    ("str_to_num:{}->{}", str_to_num, 'cb'),
    ("str_to_year:{}->{}", str_to_year, '20年'),
    ("str_to_year:{}->{}", str_to_year, '2020年'),
    ("str_to_year:{}->{}", str_to_year, '1800年'),
    ("str_to_year:{}->{}", str_to_year, '一九年'),
    ("str_to_year:{}->{}", str_to_year, '二零一九年'),
]
for _template, _helper, _sample in _helper_demo_cases:
    print(_template.format(_sample, _helper(_sample)))
# Output:
# is_float:abc->False
# is_float:1->True
# is_float:1.->True
# cn_to_an:五百五十->550
# cn_to_an:abc->abc
# cn_to_an:一二三->123
# an_to_cn:1->一
# an_to_cn:123->一百二十三
# an_to_cn:cb->cb
# str_to_num:123->123
# str_to_num:cb->None
# str_to_year:20年->2020
# str_to_year:2020年->None
# str_to_year:1800年->3800
# str_to_year:一九年->2019
# str_to_year:二零一九年->None
class QuestionCondPair:
    """One (question, candidate condition) example.

    Attributes are stored exactly as given:
      * query_id: identifier of the originating query.
      * question: the question text.
      * cond_text: the candidate condition rendered as text.
      * cond_sql: the candidate condition in structured form.
      * label: target label for the pair.
    """

    def __init__(self, query_id, question, cond_text, cond_sql, label):
        self.query_id, self.question = query_id, question
        self.cond_text, self.cond_sql = cond_text, cond_sql
        self.label = label

    def __repr__(self):
        # One "name: value" line per attribute.
        parts = [
            'query_id: {}\n'.format(self.query_id),
            'question: {}\n'.format(self.question),
            'cond_text: {}\n'.format(self.cond_text),
            'cond_sql: {}\n'.format(self.cond_sql),
            'label: {}\n'.format(self.label),
        ]
        return ''.join(parts)
class NegativeSampler:
    """Sample question-cond pairs: all positives plus a random subset of negatives.

    Args:
        neg_sample_ratio: number of negatives drawn per positive. When fewer
            negatives exist than ``len(positives) * neg_sample_ratio``, all
            of them are used.
    """

    def __init__(self, neg_sample_ratio=10):
        self.neg_sample_ratio = neg_sample_ratio

    def sample(self, data):
        """Return positives (label 1) followed by sampled negatives (label 0).

        Items with any other label are dropped. A fresh random subset of
        negatives is drawn on every call.
        """
        positive_data = [d for d in data if d.label == 1]
        negative_data = [d for d in data if d.label == 0]
        # random.sample raises ValueError when asked for more items than exist.
        n_negative = min(len(negative_data),
                         len(positive_data) * self.neg_sample_ratio)
        negative_sample = random.sample(negative_data, n_negative)
        return positive_data + negative_sample
class FullSampler:
    """Sampler that performs no sampling: every question-cond pair is kept."""

    def sample(self, data):
        """Return ``data`` itself (the same object, not a copy)."""
        all_pairs = data
        return all_pairs
# ----- 小demo (small demo) -----
# Demo: a filtering list comprehension builds the same list as an explicit loop.
a = [d for d in [1, 2, 3] if d < 10]
c = []
for d in [1, 2, 3]:
    if not d < 10:
        continue
    c.append(d)
print(a, c)
print(a + c)
# Output:
# [1, 2, 3] [1, 2, 3]
# [1, 2, 3, 1, 2, 3]
class CandidateCondsExtractor:
    """Collect candidate condition values for every (query, column).

    Candidates come from two sources:
      * numbers and years mentioned in the question text (for ``real``
        columns);
      * distinct cell values of the column that share at least one
        character with the question (for ``text`` columns, and for ``real``
        columns when exactly one such cell exists).

    params:
    - share_candidates: if True, candidates are pooled per (table, column),
      so all queries on the same table/column share one candidate set.
      Otherwise they are kept per query.
    """

    CN_NUM = '〇一二三四五六七八九零壹贰叁肆伍陆柒捌玖貮两'
    CN_UNIT = '十拾百佰千仟万萬亿億兆点'

    def __init__(self, share_candidates=True):
        self.share_candidates = share_candidates
        self._cached = False

    def build_candidate_cache(self, queries):
        """Fill ``self.cache``: cache key -> set of candidate value strings."""
        # defaultdict(set): a key seen for the first time starts as an empty set.
        self.cache = defaultdict(set)
        print('building candidate cache')
        # tqdm shows a progress bar sized by the number of queries.
        for query_id, query in tqdm(enumerate(queries), total=len(queries)):
            # Numbers / years mentioned in the question text.
            value_in_question = self.extract_values_from_text(query.question.text)
            for col_id, (col_name, col_type) in enumerate(query.table.header):
                # Cells of this column sharing at least one character with the question.
                value_in_column = self.extract_values_from_column(query, col_id)
                if col_type == 'text':
                    cond_values = value_in_column
                elif col_type == 'real':
                    if len(value_in_column) == 1:
                        cond_values = value_in_column + value_in_question
                    else:
                        cond_values = value_in_question
                else:
                    # Any other column type gets no candidates. Without this
                    # branch cond_values would be unbound (NameError) or would
                    # silently reuse the previous column's candidates.
                    cond_values = []
                cache_key = self.get_cache_key(query_id, query, col_id)
                self.cache[cache_key].update(cond_values)
        self._cached = True

    def get_cache_key(self, query_id, query, col_id):
        """Key candidates by (table, column), or per query when not shared."""
        if self.share_candidates:
            return (query.table.id, col_id)
        else:
            return (query_id, query.table.id, col_id)

    def extract_year_from_text(self, text):
        """Return abbreviated years (``'18年'``, ``'一九年'``) as ``'20xx'`` strings."""
        values = []
        # Exactly two Arabic digits before '年'. The look-behind keeps the
        # pattern from matching the tail of a four-digit year ('1998年' would
        # otherwise give '98年' -> '2098'). Full years are still picked up by
        # extract_num_from_text.
        num_year_texts = re.findall(r'(?<![0-9])[0-9]{2}年', text)
        values += ['20{}'.format(year_text[:-1]) for year_text in num_year_texts]
        # Same rule for exactly two Chinese numerals before '年',
        # converted with str_to_year.
        cn_year_texts = re.findall(
            r'(?<![{0}])[{0}]{{2}}年'.format(self.CN_NUM), text)
        cn_year_values = [str_to_year(year_text) for year_text in cn_year_texts]
        values += [value for value in cn_year_values if value is not None]
        return values

    def extract_num_from_text(self, text):
        """Return number strings found in ``text`` (Arabic, Chinese or mixed)."""
        values = []
        # Plain Arabic numbers, e.g. '-3', '2.5', '.5'.
        values += re.findall(r'[-+]?[0-9]*\.?[0-9]+', text)
        # Runs of Chinese numerals/units, normalised by str_to_num.
        cn_num_unit = self.CN_NUM + self.CN_UNIT
        cn_num_texts = re.findall(
            r'[{}]*\.?[{}]+'.format(cn_num_unit, cn_num_unit), text)
        cn_num_values = [str_to_num(t) for t in cn_num_texts]
        values += [value for value in cn_num_values if value is not None]
        # Arabic digits followed by Chinese units, e.g. '3万' or '1.5亿'.
        # The digits are rewritten in Chinese and the whole word is then
        # normalised. The optional digits after '.' keep a decimal intact;
        # otherwise '1.5万' only matched '5万' and produced 50000.
        cn_num_mix = re.findall(
            r'[0-9]*\.?[0-9]*[{}]+'.format(self.CN_UNIT), text)
        for word in cn_num_mix:
            for n in re.findall(r'[-+]?[0-9]*\.?[0-9]+', word):
                word = word.replace(n, an_to_cn(n))
            str_num = str_to_num(word)
            if str_num is not None:
                values.append(str_num)
        return values

    def extract_values_from_text(self, text):
        """Return the de-duplicated years and numbers found in ``text``."""
        values = []
        values += self.extract_year_from_text(text)
        values += self.extract_num_from_text(text)
        return list(set(values))

    def extract_values_from_column(self, query, col_ids):
        """Return the distinct cells of column ``col_ids`` (a single column
        index) that share at least one character with the question.

        Cells are compared as strings.
        """
        question_chars = set(query.question.text)
        unique_col_values = set(query.table.df.iloc[:, col_ids].astype(str))
        return [v for v in unique_col_values if not question_chars.isdisjoint(v)]
# Demo: step through the two-digit year regexes on two sample sentences.
text = "今年是2020年,明年是2021年,去年是二零一九年"
text_1 = '我的天哪哈哈哈18年,二零二二年,一九年'
CN_NUM = '〇一二三四五六七八九零壹贰叁肆伍陆柒捌玖貮两'
CN_UNIT = '十拾百佰千仟万萬亿億兆点'

# Two Arabic digits followed by '年', expanded to '20xx'.
num_year_texts = re.findall(r'[0-9][0-9]年', text)
print(num_year_texts)
values = []
values += ['20{}'.format(year_text[:-1]) for year_text in num_year_texts]
print(values)
for text in num_year_texts:
    print(text)
    print(text[:-1])  # the year without its trailing '年'

# Two Chinese numerals followed by '年', converted via str_to_year.
cn_year_texts = re.findall(r'[{}][{}]年'.format(CN_NUM, CN_NUM), text_1)
print(cn_year_texts)
cn_year_values = [str_to_year(year_text) for year_text in cn_year_texts]
print(cn_year_values)
# Output:
# ['20年', '21年']
# ['2020', '2021']
# 20年
# 20
# 21年
# 21
# ['二二年', '一九年']
# ['2022', '2019']
class QuestionCondPairsDataset:
    """Dataset of (question, candidate condition) pairs.

    For every query, each selected column is paired with every cached
    candidate value under every operator pattern for that column's type.
    Each combination yields one QuestionCondPair.

    Args:
        queries: parsed queries (``.question``, ``.table`` and, when
            labelled, ``.sql``).
        candidate_extractor: provides ``_cached``, ``build_candidate_cache``,
            ``get_cache_key`` and ``cache``. Its cache is built from
            ``queries`` if it has not been built yet.
        has_label: whether ``query.sql.conds`` can be used to label pairs.
        model_1_outputs: optional per-query predictions. Their ``'conds'``
            entries choose which columns are considered.
    """

    # Condition templates per column type. The wordings 大于 / 小于 / 是
    # mean "greater than" / "less than" / "is", for cond_op_idx 0 / 1 / 2.
    OP_PATTERN = {
        'real':
        [
            {'cond_op_idx': 0, 'pattern': '{col_name}大于{value}'},
            {'cond_op_idx': 1, 'pattern': '{col_name}小于{value}'},
            {'cond_op_idx': 2, 'pattern': '{col_name}是{value}'}
        ],
        'text':
        [
            {'cond_op_idx': 2, 'pattern': '{col_name}是{value}'}
        ]
    }

    def __init__(self, queries, candidate_extractor, has_label=True, model_1_outputs=None):
        self.candidate_extractor = candidate_extractor
        self.has_label = has_label
        self.model_1_outputs = model_1_outputs
        self.data = self.build_dataset(queries)

    def build_dataset(self, queries):
        """Return the flat list of QuestionCondPair for all ``queries``."""
        if not self.candidate_extractor._cached:
            self.candidate_extractor.build_candidate_cache(queries)
        pair_data = []
        for query_id, query in enumerate(queries):
            # Set for O(1) membership; repeated column ids make no difference.
            select_col_id = set(self.get_select_col_id(query_id, query))
            for col_id, (col_name, col_type) in enumerate(query.table.header):
                if col_id not in select_col_id:
                    continue
                cache_key = self.candidate_extractor.get_cache_key(query_id, query, col_id)
                values = self.candidate_extractor.cache.get(cache_key, [])
                pattern = self.OP_PATTERN.get(col_type, [])
                pair_data += self.generate_pairs(query_id, query, col_id, col_name,
                                                 values, pattern)
        return pair_data

    def get_select_col_id(self, query_id, query):
        """Column ids to build pairs for. Task-1 predictions take priority,
        then gold conditions (when labelled), otherwise every column."""
        if self.model_1_outputs:
            select_col_id = [cond_col for cond_col, *_ in self.model_1_outputs[query_id]['conds']]
        elif self.has_label:
            select_col_id = [cond_col for cond_col, *_ in query.sql.conds]
        else:
            select_col_id = list(range(len(query.table.header)))
        return select_col_id

    def generate_pairs(self, query_id, query, col_id, col_name, values, op_patterns):
        """Build one QuestionCondPair per (value, op pattern) for one column.

        ``cond_sql`` is ``(col_id, cond_op_idx, value)``. The label is 1 when
        that tuple is among the query's gold conditions (only checked when
        ``has_label``), else 0.
        """
        if not values or not op_patterns:
            return []
        # Gold conditions, computed once per call instead of once per pair.
        real_sql = {tuple(c) for c in query.sql.conds} if self.has_label else set()
        pairs = []
        for value in values:
            for op_pattern in op_patterns:
                cond = op_pattern['pattern'].format(col_name=col_name, value=value)
                cond_sql = (col_id, op_pattern['cond_op_idx'], value)
                label = 1 if cond_sql in real_sql else 0
                pairs.append(QuestionCondPair(query_id, query.question.text,
                                              cond, cond_sql, label))
        return pairs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# Task-1 predictions, indexed by test query id.
task1_result = load_json(task1_file)

# Training pairs: labelled, candidates kept separately for each query.
tr_qc_pairs = QuestionCondPairsDataset(
    train_data,
    candidate_extractor=CandidateCondsExtractor(share_candidates=False),
)
# Test pairs: unlabelled. Columns come from the task-1 predictions, and
# candidates are pooled per (table, column).
te_qc_pairs = QuestionCondPairsDataset(
    test_data,
    candidate_extractor=CandidateCondsExtractor(share_candidates=True),
    has_label=False,
    model_1_outputs=task1_result,
)
# Output:
#   0%|▏ | 81/41522 [00:00<00:51, 804.12it/s]
# building candidate cache
# 100%|███████████████████████████████████████████████████████████████████████████| 41522/41522 [00:58<00:00, 706.46it/s]
#   2%|█▎ | 65/4086 [00:00<00:06, 645.07it/s]
# building candidate cache
# 100%|█████████████████████████████████████████████████████████████████████████████| 4086/4086 [00:06<00:00, 656.83it/s]
# ===== 构建模型 (Build the model) =====
class SimpleTokenizer(Tokenizer):
    """Character-level tokenizer.

    Each character becomes one token: the character itself if it is in the
    vocabulary, ``'[unused1]'`` for whitespace, otherwise ``'[UNK]'``.
    """

    def _tokenize(self, text):
        def to_token(ch):
            if ch in self._token_dict:
                return ch
            if self._is_space(ch):
                return '[unused1]'
            return '[UNK]'

        return [to_token(ch) for ch in text]
def construct_model(paths, use_multi_gpus=False, gpus=2, learning_rate=1e-5):
    """Build the BERT question/condition similarity classifier.

    Args:
        paths: checkpoint paths from ``get_checkpoint_paths``; ``.vocab``,
            ``.config`` and ``.checkpoint`` are used.
        use_multi_gpus: wrap the model with ``multi_gpu_model``.
        gpus: GPU count passed to ``multi_gpu_model``. Only used when
            ``use_multi_gpus`` is True.
        learning_rate: Adam learning rate.

    Returns:
        ``(model, tokenizer)``. ``model`` is the compiled Keras model with
        inputs ``input_x1`` / ``input_x2`` and sigmoid output
        ``output_similarity``, trained with binary cross-entropy.
        ``tokenizer`` is a SimpleTokenizer over the checkpoint vocabulary.
    """
    token_dict = load_vocabulary(paths.vocab)
    tokenizer = SimpleTokenizer(token_dict)

    bert_model = load_trained_model_from_checkpoint(
        paths.config, paths.checkpoint, seq_len=None)
    # Fine-tune every BERT layer.
    for layer in bert_model.layers:
        layer.trainable = True

    # Inputs fed with the two id lists returned by tokenizer.encode.
    x1_in = Input(shape=(None,), name='input_x1', dtype='int32')
    x2_in = Input(shape=(None,), name='input_x2')
    x = bert_model([x1_in, x2_in])
    # Classify from the hidden state at position 0 (the [CLS] slot in BERT's packing).
    x_cls = Lambda(lambda seq: seq[:, 0])(x)
    y_pred = Dense(1, activation='sigmoid', name='output_similarity')(x_cls)
    model = Model([x1_in, x2_in], y_pred)

    if use_multi_gpus:
        print('using multi-gpus')
        model = multi_gpu_model(model, gpus=gpus)

    model.compile(loss={'output_similarity': 'binary_crossentropy'},
                  optimizer=Adam(learning_rate),
                  metrics={'output_similarity': 'accuracy'})
    return model, tokenizer
# Build the classifier on a single device and print its layer summary.
model, tokenizer = construct_model(paths, use_multi_gpus=False)
model.summary()
# Output:
# WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
# Instructions for updating:
# Colocations handled automatically by placer.
# WARNING:tensorflow:From E:\Anaconda\anaconda\envs\tensorflow1\lib\site-packages\keras\backend\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
# Instructions for updating:
# Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
# __________________________________________________________________________________________________
# Layer (type)                    Output Shape         Param #     Connected to
# ==================================================================================================
# input_x1 (InputLayer)           (None, None)         0
# __________________________________________________________________________________________________
# input_x2 (InputLayer)           (None, None)         0
# __________________________________________________________________________________________________
# model_2 (Model)                 (None, None, 768)    101677056   input_x1[0][0]
#                                                                  input_x2[0][0]
# __________________________________________________________________________________________________
# lambda_1 (Lambda)               (None, 768)          0           model_2[1][0]
# __________________________________________________________________________________________________
# output_similarity (Dense)       (None, 1)            769         lambda_1[0][0]
# ==================================================================================================
# Total params: 101,677,825
# Trainable params: 101,677,825
# Non-trainable params: 0
# __________________________________________________________________________________________________
# ===== 构建输入数据 (Build the input data) =====
class QuestionCondPairsDataseq(Sequence):
    """Keras ``Sequence`` that tokenises question-cond pairs into padded batches.

    Args:
        dataset: indexable collection of pairs with ``.question``,
            ``.cond_text`` and (for training) ``.label``.
        tokenizer: its ``encode(first=..., second=..., max_len=...)`` returns
            the two id lists fed to ``input_x1`` / ``input_x2``.
        is_train: when True a batch is ``(inputs, outputs)`` with labels;
            otherwise it is just ``inputs``.
        max_len: maximum packed sequence length. ``None`` pads each batch to
            its longest sequence instead.
        sampler: object whose ``sample(dataset)`` picks the pairs used for
            each epoch. ``None`` uses the whole dataset.
        shuffle: reshuffle the pair order at construction and every epoch end.
        batch_size: number of pairs per batch.
    """

    def __init__(self, dataset, tokenizer, is_train=True, max_len=120,
                 sampler=None, shuffle=False, batch_size=32):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.max_len = max_len
        self.sampler = sampler
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.on_epoch_end()

    def _pad_sequences(self, seqs, max_len=None):
        return pad_sequences(seqs, maxlen=max_len, padding='post', truncating='post')

    def __getitem__(self, batch_id):
        start = batch_id * self.batch_size
        batch_data_indices = self.global_indices[start:start + self.batch_size]
        batch_data = [self.data[i] for i in batch_data_indices]

        X1, X2, Y = [], [], []
        for pair in batch_data:
            # Pass max_len so the tokenizer truncates the longer of the two
            # segments and keeps the separator tokens. Post-truncating the
            # packed ids instead can cut off the whole condition text.
            x1, x2 = self.tokenizer.encode(first=pair.question.lower(),
                                           second=pair.cond_text.lower(),
                                           max_len=self.max_len)
            X1.append(x1)
            X2.append(x2)
            if self.is_train:
                Y.append([pair.label])

        X1 = self._pad_sequences(X1, max_len=self.max_len)
        X2 = self._pad_sequences(X2, max_len=self.max_len)
        inputs = {'input_x1': X1, 'input_x2': X2}
        if self.is_train:
            # Turns [[label], ...] into an array of shape (batch, 1).
            Y = self._pad_sequences(Y, max_len=1)
            outputs = {'output_similarity': Y}
            return inputs, outputs
        else:
            return inputs

    def on_epoch_end(self):
        # Re-sample the pairs for the next epoch (or keep all without a sampler).
        if self.sampler is None:
            self.data = self.dataset
        else:
            self.data = self.sampler.sample(self.dataset)
        self.global_indices = np.arange(len(self.data))
        if self.shuffle:
            np.random.shuffle(self.global_indices)

    def __len__(self):
        return math.ceil(len(self.data) / self.batch_size)
# Training batches: negatives re-sampled by NegativeSampler (default ratio 10)
# and shuffled every epoch.
tr_qc_pairs_seq = QuestionCondPairsDataseq(
    tr_qc_pairs, tokenizer, sampler=NegativeSampler(), shuffle=True)
# Test batches: every pair in dataset order, so predictions stay aligned
# with te_qc_pairs.
te_qc_pairs_seq = QuestionCondPairsDataseq(
    te_qc_pairs, tokenizer, sampler=FullSampler(), shuffle=False, batch_size=128)
# ===== 训练模型 (Train the model) =====
# Five epochs over the re-sampled training pairs, with 4 data-loading workers.
model.fit_generator(generator=tr_qc_pairs_seq, epochs=5, workers=4)
# ===== 预测测试集 (Predict on the test set) =====
# One similarity score per test pair, in te_qc_pairs order.
te_result = model.predict_generator(generator=te_qc_pairs_seq, verbose=1)
# ===== 对任务二做预测 (Make the task-2 predictions) =====
def merge_result(qc_pairs, result, threshold):
    """Group the condition tuples whose predicted score exceeds ``threshold``.

    Args:
        qc_pairs: iterable of pairs with ``.query_id`` and ``.cond_sql``.
        result: scores aligned one-to-one with ``qc_pairs``. As with ``zip``,
            surplus items on either side are ignored.
        threshold: a score must be strictly greater than this to be kept.

    Returns:
        dict mapping query_id -> set of selected ``cond_sql`` values. Queries
        with nothing selected are absent.
    """
    selected = defaultdict(set)
    for pair, score in zip(qc_pairs, result):
        if not score > threshold:
            continue
        selected[pair.query_id].add(pair.cond_sql)
    return dict(selected)
# Keep only very confident conditions (score strictly above 0.995).
task2_result = merge_result(qc_pairs=te_qc_pairs, result=te_result, threshold=0.995)
# ===== 最终输出 (Final output) =====
# Write one JSON object per line: each task-1 prediction with its 'conds'
# replaced by the task-2 conditions selected for that query (possibly none).
final_output_file = 'final_output.json'
# Explicit UTF-8. With ensure_ascii=False the Chinese values are written
# verbatim, and the locale-dependent default encoding (e.g. a GBK code page
# on Chinese Windows) may not match what the reader expects.
with open(final_output_file, 'w', encoding='utf-8') as f:
    for query_id, pred_sql in enumerate(task1_result):
        # sorted() gives a reproducible order; the selection itself is a set.
        cond = sorted(task2_result.get(query_id, []))
        pred_sql['conds'] = cond
        json_str = json.dumps(pred_sql, ensure_ascii=False)
        f.write(json_str + '\n')