目录
3.1 map映射:只对数据集中的'sentence'数据进行编码
1.加载编码器
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(r'../data/model/distilroberta-base/')
print(tokenizer)
RobertaTokenizerFast(name_or_path='../data/model/distilroberta-base/', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={ 0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True), 1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True), 2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True), 3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True), 50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True), }
1.1编码试算
tokenizer.batch_encode_plus([
'hide new secretions from the parental units',
'this moive is great'
])
{'input_ids': [[0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2], [0, 9226, 7458, 2088, 16, 372, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]}
# 'input_ids'中的0:表示 'bos_token': '<s>'
#'input_ids'中的2:表示 'eos_token': '</s>'
#Bert模型有特殊字符!!!!!!!
2.加载数据集
from datasets import load_from_disk #从本地加载已经下载好的数据集
dataset_dict = load_from_disk('../data/datasets/glue_sst2/')
dataset_dict
DatasetDict({ train: Dataset({ features: ['sentence', 'label', 'idx'], num_rows: 67349 }) validation: Dataset({ features: ['sentence', 'label', 'idx'], num_rows: 872 }) test: Dataset({ features: ['sentence', 'label', 'idx'], num_rows: 1821 }) })
#若是从网络下载(国内容易网络错误,下载不了,最好还是先去镜像网站下载,本地加载)
# from datasets import load_dataset
# dataset_dict2 = load_dataset(path='glue', name='sst2')
# dataset_dict2
3.数据集处理
3.1 map映射:只对数据集中的'sentence'数据进行编码
#预测中间词任务:只需要'sentence' ,不需要'label'和'idx'
#用map()函数,映射:只对数据集中的'sentence'数据进行编码
def f_1(data, tokenizer):
return tokenizer.batch_encode_plus(data['sentence'])
dataset_dict = dataset_dict.map(f_1,
batched=True,
batch_size=16,
drop_last_batch=True,
remove_columns=['sentence', 'label', 'idx'],
fn_kwargs={'tokenizer': tokenizer},
num_proc=8) #8个进程, 查看任务管理器>性能>逻辑处理器
dataset_dict
DatasetDict({ train: Dataset({ features: ['input_ids', 'attention_mask'], num_rows: 67328 }) validation: Dataset({ features: ['input_ids', 'attention_mask'], num_rows: 768 }) test: Dataset({ features: ['input_ids', 'attention_mask'], num_rows: 1792 }) })
3.2用filter()过滤 单词太少的句子过滤掉
#处理句子,让每一个句子的都至少有9个单词,单词太少的句子过滤掉
#用filter()过滤
def f_2(data):
return [len(i) >= 9 for i in data['input_ids']]
dataset_d