接上文的seq2seq encoder和decoder阶段实现前的 text文本的处理. 无论是对英文还是中文都可以处理
比如 txt的一个文本, 一行代表一句话 无论是翻译还是聊天对答 我假设你的数据文本是以下的情况 以对话聊天为例子:
1. source文本是存粹的问话,一行就是一句问话
2. target文本是存粹的答话, 对应source的每一行
很多数据文本是 一句问一句答的, 这样也很简单 你他拆成source 跟target 就好了 相信会用python的人 都会处理
说到文本处理不得不说 nltk了 python的工具 pip3 install nltk 就可以了! 由于我对nltk也只是一知半解所以只是很少一部分用到了nltk的方法, 大部分还是利用了python的 list特性来做文本处理 一贯的方式,不多说直接上代码
import nltk
import itertools
FILEPATH_S = '/Users/apple/Desktop/twiiter_sample_s'
FILEPATH_T = '/Users/apple/Desktop/twiiter_sample_t'
EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist
CH_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''
MAX_LENGTH = 30
most_vocab_size = 10000
def read_data(filepath):
data = open(filepath).readlines()
data = [line[:-1] for line in data]
return data
def process_all_data(data_source, data_target,en_ch=True):
data = data_source + data_target
if en_ch:
data_lines = [line.lower() for line in data]
lines = [filter_line(nline, EN_WHITELIST, en_ch=True) for nline in data_lines]
else:
lines = [filter_line(nline, CH_BLACKLIST, en_ch=False) for nline in data]
data_lines_list = [line.split(" ") for line in lines]
freq_dist = nltk.FreqDist(itertools.chain(*data_lines_list))
VOCAB = freq_dist.most_common(most_vocab_size)
int2word = ['<PAD>'] + ['<UNK>'] + ['<GO>'] + ['<EOS>'] + [x[0] for x in VOCAB]
word2int = dict([(w, i) for i, w in enumerate(int2word)])
for line in data_lines_list:
for i in range(len(line)):
line[i] = word2int.get(line[i], '<UNK>')
for line in data_lines_list:
if len(line) < MAX_LENGTH:
for _ in range(MAX_LENGTH - len(line)):
line.append(word2int.get('<PAD>'))
return data_lines_list
def process_data():
data_source = read_data(FILEPATH_S)
data_target = read_data(FILEPATH_T)
data_lines_list = process_all_data(data_source, data_target,en_ch=True)
input_source_int = data_lines_list[:len(data_source)]
output_target_int = data_lines_list[len(data_source):]
return input_source_int, output_target_int
def filter_line(line, charlist, en_ch=True):
if en_ch:
return "".join([ch for ch in line if ch in charlist])
else:
return "".join([ch for ch in line if ch not in charlist])
将这个代码保存为 data.py 然后在你需要用到的地方 import data 调用 data.process_data() 就可以了, 默认是英文, 如果要处理中文请注意编码以及将 en_ch 更改成False 就可以, 最终呈现的是类似这样的数据:
其中为什么会有那么多个0呢? 是因为我规定了max_sequence_length = 30 然后将不足30 的句子 append ‘<PAD>’
以上就是将你的文本文件 转成成 数字vector的办法, 当然要真成为vector 你还可以得用 np.save np.load转换一下 显示的数据是一样的. 另外,对于output_target_int 还要做个处理, 就是在每行前台添加 '<GO>' 每行末尾变成 '<EOS>' 当然 有些程序最终输入数据的时候又会去掉末尾的 EOS~~~这个就各位自己考虑了!
注: 以上部分代码思路参考了他人的著作,不好意思忘记在哪里看到了.......