问题:
主要特指在医学领域,尤其是3d分割。由于每个数据可能有好几百张的dicom。用patch分割,如果每次都去重新加载如此大的数据集会十分的费时。要么一次性将数据与标签都载入内存,那就需要几十到几百G的内存。显然一般情况下是不存在如此大的内存。
解决方法:
利用python的迭代器,可以完美的解决数据无法一次加载问题。
迭代器是一个可以记住遍历的位置的对象。
迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。
迭代器有两个基本的方法:iter() 和 next()。
重写Pytorch中的Dataset。
基本思路:将数据分为n份,每次先加载一份进行训练,当训练完成时。删除内存,加载新的一份数据。
import numpy as np
from torch.utils.data import DataLoader,Dataset
class train_test_dataset(Dataset):
def __init__(self,data,fen):
self.data = data
self.index = iter(np.arange(self.data.shape[0]))
self.jj = 0
self.fen = fen
def datagen(self):
print("datagen")
data2 = self.data[self.m]
return data2
def __getitem__(self, item):
hh = self.data_gen1[item]
return hh
def __len__(self):
if self.jj<self.fen-1:
print("len")
self.m = next(self.index)
self.data_gen1 = self.datagen()
#文件加载要在shuffle前完成
np.random.shuffle(self.data_gen1)
print(self.data_gen1.shape[0])
self.jj+=1
return self.data_gen1.shape[0]
else:
self.jj = 0
print("len")
self.m = next(self.index)
self.data_gen1 = self.datagen()
# 文件加载要在shuffle前完成
np.random.shuffle(self.data_gen1)
print(self.data_gen1.shape[0])
self.index = iter(np.arange(self.data.shape[0]))
return self.data_gen1.shape[0]
if __name__ == "__main__":
data = np.arange(80).reshape(4,20)
fen = 4
test_data = train_test_dataset(data,fen)
datalader = DataLoader(test_data, batch_size=3,)
for epoch in range(51):
for i in range(fen):#分为4份
for _,step in enumerate(datalader):
print(step)
真实例子(不知要加载多少个病例,只需将第一次用到的病例都加载进来。后续删除即可)
def load_data(self):
#加载部分数据并且返回(将数据分为10份
if self.m<len(self.every_num_lis)-1:
temp_volume_file = self.volume_file[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
temp_segmentation_file = self.segmentation_file[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
temp_startpoint_list = self.startpoint_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
temp_endpoint_list = self.endpoint_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
temp_pad_list = self.pad_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
else:#当取到最后一轮数据的时候
temp_volume_file = self.volume_file[self.every_num_lis[self.m]:]
temp_segmentation_file = self.segmentation_file[self.every_num_lis[self.m]:]
temp_startpoint_list = self.startpoint_list[self.every_num_lis[self.m]:]
temp_endpoint_list = self.endpoint_list[self.every_num_lis[self.m]:]
temp_pad_list = self.pad_list[self.every_num_lis[self.m]:]
#清空之前的数据
global_key = []
for key, value in globals().items():
if "vvolume_" in key or "ssegmentation_" in key:
global_key.append(key)
for i in global_key:
del globals()[i]
name_cv = "hh"
name2 = 0
filename_label = []
for i in range(len(temp_volume_file)):
if name_cv != temp_volume_file[i]:
name_cv = temp_volume_file[i]
name2 += 1
print(name2)
globals()["vvolume_" + str(name2)] = sitk.GetArrayFromImage(sitk.ReadImage(name_cv, sitk.sitkInt16))
globals()["ssegmentation_" + str(name2)] = sitk.GetArrayFromImage(sitk.ReadImage(temp_segmentation_file[i], sitk.sitkUInt8))
filename_label.append(name2)
#shuffle
state = np.random.get_state()
np.random.shuffle(filename_label)
np.random.set_state(state)
np.random.shuffle(temp_startpoint_list)
np.random.set_state(state)
np.random.shuffle(temp_endpoint_list)
np.random.set_state(state)
np.random.shuffle(temp_pad_list)
return filename_label,temp_startpoint_list,temp_endpoint_list,temp_pad_list