PyTorch - Dataset 迭代数据接口 __getitem__ 异常处理

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/133378772

Dataset

在模型训练的过程中,加载数据部分,极其容易出现异常,以及不可控的因素,需要通过异常捕获的方式,及时处理,常用方式就是使用 collate_fn,除此之外,还可以直接跳过错误样本,运行下一个样本进行补充。

PyTorch Dataset 类是一个抽象类,用于表示一个数据集,可以将数据和标签封装成一个可迭代的对象。要使用 Dataset 类,我们需要继承它,并实现两个方法:

  • __getitem__(self, index):根据给定的索引,返回数据集中的一个样本和对应的标签。
  • __len__(self):返回数据集中的样本数量。

即:

  1. 将数据获取封装成单独函数。
  2. 使用 while True 持续监控,如果运行正确,即 break 跳过。
  3. 如果运行失败,则打印日志,选择下一个样本运行,即 idx += 1
  4. 注意,索引不要溢出。

源码如下:

    def __getitem__(self, idx):
        # TODO: 解决数据异常问题,KeyError,尽量保持数据干净
        while True:
            try:
                feats = self.getitem_wrapper(idx)
                break
            except Exception as e:
                name = self.idx_to_chain_id(idx)
                logger.error(f"err sample: {
      
      name} !!!")
                idx += 1
                idx = idx % len(self._chain_ids)  # 避免溢出
        return feats

猜你喜欢

转载自blog.csdn.net/u012515223/article/details/133378772