数据量太大无法一次性载入内存训练

问题:

主要特指在医学领域,尤其是3d分割。由于每个数据可能有好几百张的dicom。用patch分割,如果每次都去重新加载如此大的数据集会十分的费时。要么一次性将数据与标签都载入内存,那就需要几十到几百G的内存。显然一般情况下是不存在如此大的内存。

解决方法:

利用python的迭代器,可以完美的解决数据无法一次加载问题。

迭代器是一个可以记住遍历的位置的对象。

迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。

迭代器有两个基本的方法:iter() 和 next()。

重写Pytorch中的Dataset。

基本思路:将数据分为n份,每次先加载一份进行训练,当训练完成时。删除内存,加载新的一份数据。

import numpy as np
from torch.utils.data import DataLoader,Dataset


class train_test_dataset(Dataset):
    def __init__(self,data,fen):
        self.data = data
        self.index = iter(np.arange(self.data.shape[0]))
        self.jj = 0
        self.fen = fen


    def datagen(self):
        print("datagen")
        data2 = self.data[self.m]
        return data2




    def __getitem__(self, item):
        hh = self.data_gen1[item]
        return hh


    def __len__(self):
        if self.jj<self.fen-1:
            print("len")
            self.m = next(self.index)
            self.data_gen1 = self.datagen()
            #文件加载要在shuffle前完成
            np.random.shuffle(self.data_gen1)


            print(self.data_gen1.shape[0])
            self.jj+=1
            return self.data_gen1.shape[0]
        else:
            self.jj = 0
            print("len")
            self.m = next(self.index)
            self.data_gen1 = self.datagen()
            # 文件加载要在shuffle前完成
            np.random.shuffle(self.data_gen1)


            print(self.data_gen1.shape[0])
            self.index = iter(np.arange(self.data.shape[0]))
            return self.data_gen1.shape[0]


if __name__ == "__main__":
    data = np.arange(80).reshape(4,20)
    fen = 4
    test_data = train_test_dataset(data,fen)
    datalader = DataLoader(test_data, batch_size=3,)
    for epoch in range(51):
        for i in range(fen):#分为4份
            for _,step in enumerate(datalader):
                print(step)

真实例子(不知要加载多少个病例,只需将第一次用到的病例都加载进来。后续删除即可)

def load_data(self):
        #加载部分数据并且返回(将数据分为10份


        if self.m<len(self.every_num_lis)-1:
            temp_volume_file = self.volume_file[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_segmentation_file = self.segmentation_file[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_startpoint_list = self.startpoint_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_endpoint_list = self.endpoint_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_pad_list = self.pad_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]


        else:#当取到最后一轮数据的时候
            temp_volume_file = self.volume_file[self.every_num_lis[self.m]:]
            temp_segmentation_file = self.segmentation_file[self.every_num_lis[self.m]:]
            temp_startpoint_list = self.startpoint_list[self.every_num_lis[self.m]:]
            temp_endpoint_list = self.endpoint_list[self.every_num_lis[self.m]:]
            temp_pad_list = self.pad_list[self.every_num_lis[self.m]:]


        #清空之前的数据
        global_key = []
        for key, value in globals().items():
            if "vvolume_" in key or "ssegmentation_" in key:
                global_key.append(key)
        for i in global_key:
            del globals()[i]




        name_cv = "hh"
        name2 = 0
        filename_label = []
        for i in range(len(temp_volume_file)):
            if name_cv != temp_volume_file[i]:
                name_cv = temp_volume_file[i]
                name2 += 1
                print(name2)
                globals()["vvolume_" + str(name2)] = sitk.GetArrayFromImage(sitk.ReadImage(name_cv, sitk.sitkInt16))
                globals()["ssegmentation_" + str(name2)] = sitk.GetArrayFromImage(sitk.ReadImage(temp_segmentation_file[i], sitk.sitkUInt8))
            filename_label.append(name2)


        #shuffle
        state = np.random.get_state()
        np.random.shuffle(filename_label)
        np.random.set_state(state)
        np.random.shuffle(temp_startpoint_list)
        np.random.set_state(state)
        np.random.shuffle(temp_endpoint_list)
        np.random.set_state(state)
        np.random.shuffle(temp_pad_list)
        return filename_label,temp_startpoint_list,temp_endpoint_list,temp_pad_list

猜你喜欢

转载自blog.csdn.net/weixin_41202834/article/details/121173754