RuntimeError: each element in list of batch should be of equal size (使用collate_fn解决)

前边在PyTorch训练时,在加载数据集的时候报错:

RuntimeError: each element in list of batch should be of equal size

这个错误的原因是同一个mini-batch中,数据的大小不一样。

getitem返回类型为Tensor、Tuple、List

这里先以简单的图像分类举例,讲解问题、原理、解决方案,dataset类和读取数据的代码应该大致如下:

from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
	def __init__(self):
		# 初始化..., 把image路径和label信息都写入到 total_data_list 中
		self.total_data_list = ......
	def __len__(self):
		return len(self.total_data_list)
	def __getitem__(self, index):
		image, label = self.total_data_list[index]
		return image, label

dataset = ImageDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# 两种读取数据的方式
# 1. sample获取的是整个batch的数据,即:下面的[[img_0, ...],[lbl_0, ...]]
for index, sample in enumerate(dataloader):
	....
# 2. 直接从原数据中unpack,得到image是下面的[img_0, ...], label是下面的[lbl_0, ...]
for image, data in dataloader:
	....

每次从dataloader中取出一个mini-batch的数据,默认的格式为:(batchsize为4)

[ [img_0, img_1, img_2, img_3], [lbl_0, lbl_1, lbl_2, lbl_3] ]

每取出一批数据,就返回一个list,该list中有若干的list,每个list的长度都是batchsize,一开始提到的错误就是若干个list中,存在list中元素大小不相等的情况

但是我们在 __getitem__() 返回的是一个元组 (image, label),dataloader在读取4个数据后做了怎样的操作?
一开始,dataloader是直接读取4个数据,即

[ (img_0, lbl_0), (img_1, lbl_1), (img_2, lbl_2), (img_3, lbl_3) ]

dataloader先在获取到的数据列表中,每个元素都在第一个维度前扩充一个维度,即:
原来3x640x512的image,变成1x3x640x512
原来为实数的label,变成1维的tensor

然后再沿着batchsize维度,分别将image和label连接起来,形成3x3x640x512 和 3x1 的两个tensor,然后再合并为一个list,就形成上面默认的数据格式。

代码如下,其中x就是 [ (img_0, lbl_0), (img_1, lbl_1), (img_2, lbl_2), (img_3, lbl_3) ]

collate_fn=lambda x:(
	torch.cat(
   		[x[i][j].unsqueeze(0) for i in range(len(x))], 0
   	) for j in range(len(x[0]))
)

注意到这里使用了 collate_fn ,这个就是我们主要用来解决该错误的DataLoader的一个参数,表示该DataLoader是如何取样本的,可以使用该参数定义自己的函数,来实现更加准确的取样功能。

如果在定义DataLoader时,不指定collate_fn,就是默认的取样方式,等价于指定collate_fn为上面的代码。同时也可以进行自定义,如:

collate_fn=lambda x:x
collate_fn=lambda x:torch.utils.data.dataloader.default_collate(list(filter(lambda x: x is not None, batch))) # 限制x不为空

getitem返回类型为Dict

上面是__getitem__中返回tensor、list或tuple的情况,如果返回的是dict类型数据,如:

class ImageDataset(Dataset):
	def __init__(self):
		# 初始化..., 把image路径和label信息都写入到 total_data_list 中
		self.total_data_list = ......
	def __len__(self):
		return len(self.total_data_list)
	def __getitem__(self, index):
		image, label = self.total_data_list[index]
		return {
    
    
			"image": image,
			"label": label
		}

在dataloader中,处理样本dict中value的方式,与上面介绍的相同,先扩充维度,再将batchsize个数据合在一起。不同的地方在于,将处理好后的包括batchsize个数据的新value,对应于原始的key,存放起来,最终返回的不是list,而是一个dict类型的数据。

{
	'image': [img_0, img_1, img_2, img_3], 
	'label': [lbl_0, lbl_1, lbl_2, lbl_3]
}

其实是非常类似的,只不过是将value部分从dict中按照key取出来,处理完再放回dict中对应的位置中。

正式介绍我遇到的问题,__getitem__返回的格式如下:

return {
    
    
	
    "images": images,          # List[Tensor]: [N][3,Hi,Wi], N is number of images
    "intrinsics": intrinsics,  # Tensor: [N,3,3]
    "extrinsics": extrinsics,  # Tensor: [N,4,4]
    "depth_min": depth_min,    # Tensor: [1]
    "depth_max": depth_max,    # Tensor: [1]
    "depth_gt": depth_gt,      # Tensor: [1,H0,W0] if exists
    "mask": mask,              # Tensor: [1,H0,W0] if exists
    "filename": os.path.join(scan, "{:0>8}".format(view_ids[0]) + "{}")
}

N等于1张ref图像+N-1张src图像,当N=6时,getitem返回的数据情况如下:

{
	'images' : [torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640])]
	'intrinsics' : torch.Size([6, 3, 3])
	'extrinsics' : torch.Size([6, 4, 4])
	'depth_min' : torch.Size([1])
	'depth_max' : torch.Size([1])
	'depth_gt' : torch.Size([1, 480, 640])
	'mask' : torch.Size([1, 480, 640])
	'filename' : '57f8d9bbe73f6760f10e916a/00000182{}'
}

正常情况下,batchsize=4时,会将4个上面的数据,按照之前介绍的规则,组合在一起,输出一个batch的数据,格式如下:

{
	'images' : [torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640])]
	'intrinsics' : torch.Size([4, 6, 3, 3])
	'extrinsics' : torch.Size([4, 6, 4, 4])
	'depth_min' : torch.Size([4])
	'depth_max' : torch.Size([4])
	'depth_gt' : torch.Size([4, 1, 480, 640])
	'mask' : torch.Size([4, 1, 480, 640])
	'filename' : ['57f8d9bbe73f6760f10e916a/00000182{}', '57f8d9bbe73f6760f10e916a/00000204{}', '57f8d9bbe73f6760f10e916a/00000235{}', '57f8d9bbe73f6760f10e916a/00000170{}']
}

ref图像的数量是确定的,只有一个。
训练数据的pairs.txt中,在指定ref下,会有m个src图像,按照score从大到小排列,在训练时,需要指定使用的src图像数量,也就是N-1的值,当m >= N-1时,也就是说一定可以取出设定的N-1个src图像,可以取score最大的N-1个,也可以为了训练的鲁棒,随机取N-1个。
但如果某个ref的源图像数量,m < N-1, 也就是说最多也就取m个,那么在按照mini-batch的方法取数据时,就可能出现取出的4个数据中,图像的总数(ref+src)出现 6,6,6,5的情况,这就导致在将四个数据进行合并时,出现问题,因为不可以将 [1x6xCxHxW][1x5xCxHxW] 的两个数据cat在一起。所以就出现了上面的错误:RuntimeError: each element in list of batch should be of equal size

解决方案

解决方案:
1.设置batchsize为1,第一种不应该叫做解决方案,只能说避免报错。
2.自定义collate_fn函数,更改默认的取样方法

在第二种方法中,最简单直接的:

dataloader = DataLoader(dataset, batch_size=4, collate_fn=lambda x:x)

即返回未处理的原格式:[ (img_0, lbl_0), (img_1, lbl_1), (img_2, lbl_2), (img_3, lbl_3) ]
如果是dict,返回的就是:[ {'image':img_0, 'label': lbl_0}, {'image':img_1, 'label': lbl_1}, {'image':img_2, 'label': lbl_2}, {'image':img_3, 'label': lbl_3} ]

然后在下一步读取dataloader数据后,改变处理数据的方式即可。

但是如果不准备更改后面的代码(比如使用别人的代码,已经写好且非常多非常完善),那么就只好在自定义collate_fn中实现一些更复杂的逻辑,来让数据变得正常。

目前我的解决方法是,将mini-batch中与其他数据大小不同的数据,剔除出去,返回剩下的数据。
具体到该例子,我将 m < N-1 的数据剔除掉,直接不使用,因为在pairs.txt中,这样的数据是极少数,对训练的影响不大。具体做法是获取dict中image的长度,看是否小于N,代码如下:

# src_nums + 1 为N的值,1是指ref图像
src_nums = input_args.num_views
def is_data_ok(data):
    if len(data['images']) < src_nums + 1:
        return False
    else:
        return True

# batch是原始的数据 [ {'image':img_0, 'label': lbl_0}, {'image':img_1, 'label': lbl_1}, {'image':img_2, 'label': lbl_2}, {'image':img_3, 'label': lbl_3} ]
def collate_fn(batch):
	# 对batch中的每个数据进行判断,is_data_ok(batch[i])为true则保留,否则丢弃
    batch_new = list(filter(is_data_ok, batch))
    # default_collate(batch_new): 对于剩下的数据batch_new,按照DataLoader默认的方式进行处理(即上面介绍的扩展维度、合并等)
    return  torch.utils.data.dataloader.default_collate(batch_new)

# 设置drop_last 为 False:假设过滤掉了1个数据,此时返回的数据只有batchsize-1个,drop_last=True会将这样的数据丢弃不要
train_loader = DataLoader(train_dataset, input_args.batch_size, shuffle=True, num_workers=8, drop_last=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, input_args.batch_size, shuffle=False, num_workers=4, drop_last=False, collate_fn=collate_fn)

我认为更科学的方法,是将这些数据保留起来,碰到大小一样的数据,再合并使用,目前还未写出代码,后期更新,知道怎么实现的小伙伴可以提提建议

参考

error solution:
http://www.manongjc.com/detail/25-hkwccoiwsencpam.html
https://www.cnblogs.com/vase/p/15354331.html
https://blog.csdn.net/weixin_44799217/article/details/115137820

collate_fn:
https://zhuanlan.zhihu.com/p/361830892
https://github.com/pytorch/pytorch/issues/57429
https://github.com/pytorch/pytorch/issues/67419
https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html#generate-data-batch-and-iterator

drop_last:
https://blog.csdn.net/xijuezhu8128/article/details/107954141

lambda:
https://blog.csdn.net/zagfai/article/details/8972618

猜你喜欢

转载自blog.csdn.net/qq_41340996/article/details/123156330