Dataset类分批加载数据集

在做NLP任务的时候,需要分批加载数据集进行训练,这个时候可以继承pytorch.utils.data中的Dataset类,就可以进行分批加载数据,并且可以将数据转换成tensor对象数据.
处理流程:
image.png

1.自定义Dataset类

这个类要配合的torch.utils.data 中的DataLoader类才可以发挥作用

# 因为我在数据预处理的时候将转换成id的数据集全部持久化处理了,所以需要这个方法加载数据
# 获取文件
def load_pkl(path,obj_name):
    print(f'get{obj_name} in {path}')
    with codecs.open(path,'rb')as f:
        data=pkl.load(f)
    return data

# 第三方库
import torch
from torch.utils.data import Dataset

# 自定义库
from BruceNRE.utils import load_pkl
# 数据集的加载
class CustomDataset(Dataset):
    def __init__(self,file_path,obj_name):
        self.file=load_pkl(file_path,obj_name)

    def __getitem__(self, item):
        sample=self.file[item]
        return sample

    def __len__(self):
        return len(self.file)

# 这个方法负责将数据进行填充,并且转换成tensor对象
def collate_fn(batch):
# 把这个批次中的数据按照list长度由高到低排序
    batch.sort(key=lambda data: len(data[0]),reverse=True)
# 将这个批次中数据长度放到len集合中
    lens=[len(data[0])for data in batch]
# 获得最大的长度
    max_len=max(lens)

    sent_list=[]
    head_pos_list=[]
    tail_pos_list=[]
    mask_pos_list=[]
    relation_list=[]

    # 填充数据,都用0来填充
    def _padding(x,max_len):
        return x+[0]*(max_len-len(x))
# 把数据集转换成tensor对象,然后封装到对应的list中
    for data in batch:
        sent,head_pos,tail_pos,mask_pos,relation=data
        sent_list.append(_padding(sent,max_len))
        head_pos_list.append(_padding(tail_pos,max_len))
        tail_pos_list.append(_padding(tail_pos,max_len))
        mask_pos_list.append(_padding(mask_pos,max_len))
        relation_list.append(relation)

    # 将numpy转换为tensor
    return torch.tensor(sent_list),torch.tensor(head_pos_list),torch.tensor(tail_pos_list),torch.tensor(mask_pos_list),torch.tensor(relation_list)

这个类解释一下作用:

  • init方法:把所有数据集加载进来
  • getitem:如果设置suffle为True就会打乱数据,传递数据的索引给getitem,就是item,然后根据索引加载数据.
  • len:获取数据集的索引长度
  • collate_fn:因为使用DataLoader这个类要求每一个批次中的数据的长度必须要一样,所以这个方法有两个作用,第一个作用就是把数据集全部用0填充到相同的长度,然后将数据集(是转换成字典标志位的数据集)转换成tensor对象

2.使用Dataset类

# 调用Dataset实现类
train_dataset=CustomDataset(train_data_path,'train-data')
# 将train_dataset放到DataLoader中,才可以使用
train_dataloader=DataLoader(
        dataset=train_dataset,
        batch_size=128,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn
    )

    for batch_idx,batch in enumerate(train_dataloader):
        *x,y=[data.to(device) for data in batch]
    print('dataloader测试完成')

参数解析:
dataset:Dataset类封装的数据集
batch_size:每个批次处理的数据量,一般128或者64
shuffle:是否打乱顺序
drop_last:丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
collate_fn:处理数据集成一样的长度,并且转换成tensor对象的方法

==============================================================

3.再看个例子:

我的原始数据格式:

体验2D巅峰 倚天屠龙记十大创新概览	8
60年铁树开花形状似玉米芯(组图)	5
同步A股首秀:港股缩量回调	2
中青宝sg现场抓拍 兔子舞热辣表演	8
锌价难续去年辉煌	0
2岁男童爬窗台不慎7楼坠下获救(图)	5
布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击	7
金科西府 名墅天成	1
状元心经:考前一周重点是回顾和整理	3
发改委治理涉企收费每年为企业减负超百亿	6
一年网事扫荡10年纷扰开心网李鬼之争和平落幕	4
2010英国新政府“三把火”或影响留学业	3
俄达吉斯坦共和国一名区长被枪杀	6
朝鲜要求日本对过去罪行道歉和赔偿	6
《口袋妖怪 黑白》日本首周贩售255万	8
图文:借贷成本上涨致俄罗斯铝业净利下滑21%	2
组图:新《三国》再曝海量剧照 火战场面极震撼	9
麻辣点评:如何走出“被留学”的尴尬	3
  • 创建一个Dataset的子类来处理数据
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer,BertConfig,BertModel
bert_model='./bert-base-chinese'
myconfig = BertConfig.from_pretrained("./bert-base-chinese")
tokenizer=BertTokenizer.from_pretrained(bert_model)
MAX_LEN = 256 - 2

class ElementDataset(Dataset):
    def __init__(self, f_path):
        sents, label_li = [], []  # list of lists
        with open(f_path, 'r', encoding='utf-8') as fr:
            for line in fr:
                if len(line) < 10:
                    continue
                entries = line.strip().split('\t')
                words = entries[0]
                label = entries[1:]
                label = list(map(int, label))
                sents.append(words)
                label_li.append(label)
        self.sents, self.label_li = sents, label_li

    def __getitem__(self, item):
        words,tags=self.sents[item],self.label_li[item]
        inputs=tokenizer.encode_plus(words)
        label=tags
        seqlen = len(inputs['input_ids'])
        sample=(inputs,label,seqlen)
        return sample

    def __len__(self):
        print('sents')
        return len(self.sents)

    # 填充
def collate_fn(batch):
    all_input_ids=[]
    all_attention_mask=[]
    all_token_type_ids=[]
    all_labels=[]
    lens=[data[2] for data in batch]
    max_len=max(lens)
    def padding(input,max_len,pad_token):
        return input+[pad_token]*(max_len-len(input))

    for data in batch:
        input,tags,_=data
        all_input_ids.append(padding(input['input_ids'],max_len,1))
        all_token_type_ids.append(padding(input['token_type_ids'],max_len,0))
        all_attention_mask.append(padding(input['attention_mask'],max_len,0))
        all_labels.append(tags)
    return torch.tensor(all_input_ids),torch.tensor(all_token_type_ids),torch.tensor(all_attention_mask),all_labels
  • 然后再调用的时候使用DataLoader加载数据
train_data=ElementDataset(args.Train)
    test_data=ElementDataset(args.Test)

    train_iter=DataLoader(dataset=train_data,
                               batch_size=10,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=collate_fn)

    test_iter =DataLoader(dataset=test_data,
                                 batch_size=10,
                                 shuffle=True,
                                 drop_last=True,
                                 collate_fn=collate_fn)
# 可以使用一个for循环查看数据
    for i, batch in enumerate(iterator):
        input_ids,token_type_ids,attention_mask,labels= batch

batch就是每一个批次的数据,我设置的这个批次的数据是10个,则这个10个的数据的长度就是一样的长度,并且都是tensor格式.

猜你喜欢

转载自blog.csdn.net/qq_35653657/article/details/126003653