nlp任务之预测中间词-huggingface

目录

1.加载编码器

1.1编码试算 

2.加载数据集 

3.数据集处理 

3.1 map映射:只对数据集中的'sentence'数据进行编码

3.2用filter()过滤 单词太少的句子过滤掉

3.3截断句子 

4.创建数据加载器Dataloader 

5. 下游任务模型 

6.测试预测代码 

7.训练代码

 8.保存与加载模型


1.加载编码器

from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained(r'../data/model/distilroberta-base/')
print(tokenizer)
RobertaTokenizerFast(name_or_path='../data/model/distilroberta-base/', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),
}

1.1编码试算 

tokenizer.batch_encode_plus([
    'hide new secretions from the parental units',
    'this moive is great'  
])
{'input_ids': [[0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2], [0, 9226, 7458, 2088, 16, 372, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]}

# 'input_ids'中的0:表示 'bos_token': '<s>'
#'input_ids'中的2:表示 'eos_token': '</s>'
#Bert模型有特殊字符!!!!!!! 

2.加载数据集 

from datasets import load_from_disk  #从本地加载已经下载好的数据集


dataset_dict = load_from_disk('../data/datasets/glue_sst2/')
dataset_dict
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

#若是从网络下载(国内容易网络错误,下载不了,最好还是先去镜像网站下载,本地加载)
# from datasets import load_dataset


# dataset_dict2 = load_dataset(path='glue', name='sst2')
# dataset_dict2

3.数据集处理 

3.1 map映射:只对数据集中的'sentence'数据进行编码

#预测中间词任务:只需要'sentence' ,不需要'label'和'idx'
#用map()函数,映射:只对数据集中的'sentence'数据进行编码
def f_1(data, tokenizer):
    return tokenizer.batch_encode_plus(data['sentence'])

dataset_dict = dataset_dict.map(f_1, 
                 batched=True,
                 batch_size=16,
                 drop_last_batch=True,
                 remove_columns=['sentence', 'label', 'idx'],
                 fn_kwargs={'tokenizer': tokenizer},
                 num_proc=8)  #8个进程, 查看任务管理器>性能>逻辑处理器



dataset_dict
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 67328
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 768
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1792
    })
})

3.2用filter()过滤 单词太少的句子过滤掉

#处理句子,让每一个句子的都至少有9个单词,单词太少的句子过滤掉
#用filter()过滤
def f_2(data):
    return [len(i) >= 9 for i in data['input_ids']]

dataset_d