seq2seq中对文本数据的处理

接上文的seq2seq encoder和decoder阶段实现前的 text文本的处理. 无论是对英文还是中文都可以处理

比如 txt的一个文本, 一行代表一句话 无论是翻译还是聊天对答 我假设你的数据文本是以下的情况 以对话聊天为例子:

1. source文本是存粹的问话,一行就是一句问话
2. target文本是存粹的答话, 对应source的每一行

很多数据文本是 一句问一句答的, 这样也很简单 你他拆成source 跟target 就好了 相信会用python的人 都会处理

说到文本处理不得不说 nltk了 python的工具 pip3 install nltk 就可以了! 由于我对nltk也只是一知半解所以只是很少一部分用到了nltk的方法, 大部分还是利用了python的 list特性来做文本处理 一贯的方式,不多说直接上代码

import nltk
import itertools

FILEPATH_S = '/Users/apple/Desktop/twiiter_sample_s'
FILEPATH_T = '/Users/apple/Desktop/twiiter_sample_t'
EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz '  # space is included in whitelist
CH_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''
MAX_LENGTH = 30
most_vocab_size = 10000

def read_data(filepath):
    data = open(filepath).readlines()
    data = [line[:-1] for line in data]
    return data


def process_all_data(data_source, data_target,en_ch=True):
    data = data_source + data_target

    if en_ch:
        data_lines = [line.lower() for line in data]
        lines = [filter_line(nline, EN_WHITELIST, en_ch=True) for nline in data_lines]
    else:
        lines = [filter_line(nline, CH_BLACKLIST, en_ch=False) for nline in data]

    data_lines_list = [line.split(" ") for line in lines]

    freq_dist = nltk.FreqDist(itertools.chain(*data_lines_list))
    VOCAB = freq_dist.most_common(most_vocab_size)
    int2word = ['<PAD>'] + ['<UNK>'] + ['<GO>'] + ['<EOS>'] + [x[0] for x in VOCAB]
    word2int = dict([(w, i) for i, w in enumerate(int2word)])

    for line in data_lines_list:
        for i in range(len(line)):
            line[i] = word2int.get(line[i], '<UNK>')

    for line in data_lines_list:
        if len(line) < MAX_LENGTH:
            for _ in range(MAX_LENGTH - len(line)):
                line.append(word2int.get('<PAD>'))

    return data_lines_list

def process_data():
    data_source = read_data(FILEPATH_S)
    data_target = read_data(FILEPATH_T)
    data_lines_list = process_all_data(data_source, data_target,en_ch=True)
    input_source_int = data_lines_list[:len(data_source)]
    output_target_int = data_lines_list[len(data_source):]

    return input_source_int, output_target_int


def filter_line(line, charlist, en_ch=True):
    if en_ch:
        return "".join([ch for ch in line if ch in charlist])
    else:
        return "".join([ch for ch in line if ch not in charlist])

将这个代码保存为 data.py  然后在你需要用到的地方 import data  调用 data.process_data() 就可以了, 默认是英文, 如果要处理中文请注意编码以及将 en_ch 更改成False 就可以, 最终呈现的是类似这样的数据:

其中为什么会有那么多个0呢? 是因为我规定了max_sequence_length = 30 然后将不足30 的句子 append ‘<PAD>’ 

以上就是将你的文本文件 转成成 数字vector的办法, 当然要真成为vector 你还可以得用 np.save np.load转换一下 显示的数据是一样的. 另外,对于output_target_int 还要做个处理, 就是在每行前台添加 '<GO>' 每行末尾变成 '<EOS>' 当然 有些程序最终输入数据的时候又会去掉末尾的 EOS~~~这个就各位自己考虑了!

注: 以上部分代码思路参考了他人的著作,不好意思忘记在哪里看到了.......

猜你喜欢

转载自blog.csdn.net/weixin_42724775/article/details/81100342
今日推荐