NLP任务之翻译

目录

1  加载预训练模型的分词器

2  加载本地数据集 

3  数据预处理

4  创建数据加载器

5  定义下游任务的模型 

6  测试代码 

7  训练代码 

8.保存与加载训练好的模型

 


#加载预训练的翻译分词器之前需要先安装一个第三方库

# -后面接的是清华源

! pip install sentencepiece -i Simple Index

 #sentencepiece开源工具, 可以更好的生成词向量

1  加载预训练模型的分词器

from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained('../data/model/opus-mt-en-ro/', use_fast=True)
print(tokenizer)
MarianTokenizer(name_or_path='../data/model/opus-mt-en-ro/', vocab_size=59543, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	59542: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
#假设文本,查看分词器的输出结果
text = [['hello, everyone today is a good day', 'It is late, please go home']]
tokenizer.batch_encode_plus(text)
{'input_ids': [[92, 778, 3, 1773, 879, 32, 8, 265, 431, 84, 32, 1450, 3, 709, 100, 540, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

2  加载本地数据集 

from datasets import load_dataset


dataset = load_dataset('../data/datasets/wmt16-ro-en/')
dataset
DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 610320
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
}) 
#数据采样,数据量太多的, 需要随机抽取一些
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))

3  数据预处理

#查看训练数据的第一条数据
dataset['train'][0]
{'translation': {'en': 'For these reasons I voted in favour of the proposal for a new regulation that aims for greater clarity and transparency in the GSP system.',
  'ro': 'Din aceste motive am votat în favoarea propunerii de nou regulament care își propune o mai mare claritate și transparență în sistemul SPG.'}}
def preprocess_function(data, tokenizer):
    """定义数据预处理的函数"""
    #分别获取'en'与'ro'对应的文本句子
    en = [ex['en'] for ex in data['translation']]
    ro = [ex['ro'] for ex in data['translation']]
        
    #对‘en’文本进行编码分词
    data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)
    #对'ro'文本进行编码分词,并将结果的'input_ids'作为labels
    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(
        ro, max_length=128, truncation=True)['input_ids']

    return data
#用map函数将定义的预处理函数加载进来
dataset = dataset.map(preprocess_function,
                      batched=True,
                      batch_size=1000,
                      num_proc=1, 
                      remove_columns=['translation'],
                      fn_kwargs={'tokenizer' : tokenizer})
#查看训练数据的第一条数据
print(dataset['train'][0])
{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
#数据批次处理函数:将数据一批批进行输出
def collate_fn(data):
    # 求最长的label
    max_length=max([len(i['labels']) for i in data])
    
    for i in data:
        #获取每一句需要补充的pad数量,赋值为100,
        pads = [-100] * (max_length - len(i['labels']))
        #每一句都加上需要补的pad
        i['labels'] = i['labels'] + pads
        
    #会自动将数据集中的所有类型的数据都按照最大序列长度进行补全pad
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,   #数据位数补齐到指定的倍数上(否)
        return_tensors='pt'
    )
    
    #序列数据也有编码器的输入数据  decoder_input_ids
    #字典添加数据的方式
    data['decoder_input_ids'] = torch.full_like(data['labels'], 
                                               tokenizer.get_vocab()['pad'],
                                               dtype=torch.long)
    #第一个token是cls,不需要传入校验预测值,就从索引为1的开始
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    data['decoder_input_ids'][data['decoder_input_ids'] == -100] = tokenizer.get_vocab()['<pad>']
    return data
tokenizer.get_vocab()['pad'], tokenizer.get_vocab()['<pad>']

4  创建数据加载器

import torch


loader = torch.utils.data.DataLoader(dataset=dataset['train'],
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for data in loader:
    break


data
{'input_ids': tensor([[   12,   182,   381,   129,    13,  3177,     4,   397,  3490,    51,
             4, 31307,  8305,    30,   196,   451,  1304,    30,   314,    57,
           462,  5194,    14,     4,  6170,  1323,    13,   198,    13,    64,
           239,  3473,  1151,    20,  1273,     2,     0, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542],
        [   40, 16127,    56,  3024,    12,    76,   248,    13,  2043, 13500,
             3,    85,   932, 10119,     3,  4077,    14,     4,  2040,  5589,
          3551,    12,   123,   444,     4,  1586,  2716, 15373,     3,   193,
           174,   154,   166, 11192,   279,  4391,  4166,    20,    85,  3524,
            18,    33,    32,   381,   510,    20,   238, 14180,     2,     0],
        [   67,  3363,    14,  8822,     3, 16751,    18,   244,  4704,  2028,
           108,     4, 20738,  1058,  1136,  2936,     2,     0, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542],
        [   20,  4243,     2,    10,  2587,    14,   102,     3,    12,   182,
           129,    13, 22040,   238, 11617,     3,   372,    11,  3292, 46367,
            21,   464,   732,     3,    37,     4,  2082,    14,     4,  1099,
           211, 10197,   879,     2,     0, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542],