PyTorch Lightning - Exception Handling with try-except in LightningModule Training Logic (training_step)

Follow me on CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/133673820

LightningModule

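When training a model with LightningModule, errors caused by bad data seriously hurt training stability, so they need to be caught promptly with try-except. That is, when an error occurs, training_step catches the exception and returns None, which makes Lightning skip that batch under automatic optimization. on_before_zero_grad also needs exception handling, so that it can cope with training_step having returned None.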

validation_step can be handled the same way.

The source code is as follows:

import logging

import pytorch_lightning as pl

# Assumption: the original does not show how `logger` is created;
# a standard-library logger is used here.
logger = logging.getLogger(__name__)


class MyObject(pl.LightningModule):
    def __init__(self, config, args):
        super().__init__()
        # ... (the elided setup is expected to create self.model and self.ema)

    def training_step_wrapper(self, batch, batch_idx, log_interval=10):
        # train key process
        ...

    def training_step(self, batch, batch_idx, log_interval=10):
        """
        Typically, each step costs 50 seconds.
        Reference: https://github.com/Lightning-AI/lightning/pull/3566
        """
        try:
            res = self.training_step_wrapper(batch, batch_idx, log_interval)
            return res
        except Exception as e:
            logger.info(f"[CL] training_step, exception: {e}")
            # Returning None skips this batch (automatic optimization)
            return None

    def on_before_zero_grad(self, *args, **kwargs):
        try:
            self.ema.update(self.model)
        except Exception as e:
            # Supports training_step returning None
            logger.info(f"[CL] on_before_zero_grad, exception: {e}")
            return

    def validation_step_wrapper(self, batch, batch_idx):
        # val key process
        ...

    def validation_step(self, batch, batch_idx):
        try:
            self.validation_step_wrapper(batch, batch_idx)
        except Exception as e:
            logger.info(f"[CL] validation_step, exception: {e}")
            return
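The same try-except pattern is repeated in training_step and validation_step, so it could be factored into a decorator. The following is only a minimal sketch of that idea, not the author's implementation: return_none_on_error and ToyModule are hypothetical names, and ToyModule stands in for a pl.LightningModule so that the example runs without Lightning installed.

```python
import functools
import logging

logger = logging.getLogger(__name__)


def return_none_on_error(step_fn):
    """Wrap a step method so that any Exception is logged and None is returned."""

    @functools.wraps(step_fn)
    def wrapper(self, batch, batch_idx, *args, **kwargs):
        try:
            return step_fn(self, batch, batch_idx, *args, **kwargs)
        except Exception as e:
            logger.info(f"[CL] {step_fn.__name__}, batch_idx={batch_idx}, exception: {e}")
            return None

    return wrapper


class ToyModule:
    """Hypothetical stand-in for a pl.LightningModule."""

    @return_none_on_error
    def training_step(self, batch, batch_idx):
        # Raises KeyError when the field is missing from the batch.
        return float(batch["seq_length"])


module = ToyModule()
good = module.training_step({"seq_length": 8}, 0)  # normal result
bad = module.training_step({}, 1)                  # exception swallowed, None returned
```

Logging batch_idx alongside the exception makes it easier to trace a skipped batch back to the offending sample afterwards.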

Common errors are listed below.

Index out of bounds:

index 0 is out of bounds for dimension 0 with size 0

Missing dictionary key:

num_res = int(np_example["seq_length"])
KeyError: 'seq_length'

Empty input to a numerical computation:

V, _, W = torch.linalg.svd(C)
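The three data-related errors above (index out of bounds, missing key, empty input) can often be detected before a sample ever reaches training_step. The sketch below illustrates such a pre-check in plain Python under stated assumptions: find_problem and the "coords" field are hypothetical and only illustrate the idea, while "seq_length" is the key from the KeyError above. A real pipeline would check its own fields and tensor shapes.

```python
import math
from typing import Any, Mapping, Optional, Sequence

# "seq_length" comes from the KeyError above; "coords" is a hypothetical field.
REQUIRED_KEYS = ("seq_length", "coords")


def find_problem(
    example: Mapping[str, Any],
    required_keys: Sequence[str] = REQUIRED_KEYS,
) -> Optional[str]:
    """Return a description of the first problem found, or None if the example looks usable."""
    # Missing fields would later surface as KeyError.
    for key in required_keys:
        if key not in example:
            return f"missing key: {key!r}"

    # Empty inputs would later surface as index errors or failures in linear algebra routines.
    coords = example["coords"]
    if len(coords) == 0:
        return "empty input: 'coords' has no rows"

    # NaN or infinite values can also break numerical routines.
    for row in coords:
        if not all(math.isfinite(v) for v in row):
            return "non-finite value in 'coords'"

    return None
```

Such a check could run in the dataset or collate function so that bad samples are dropped and logged early, leaving the try-except in training_step as a last line of defense.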

free() error (reported by glibc, which then aborts the process, so a Python try-except cannot catch it):

free(): invalid next size (fast)

munmap_chunk() invalid pointer (also a glibc abort rather than a Python exception):

munmap_chunk(): invalid pointer
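The last two messages come from glibc detecting heap corruption and aborting the process, so they never become Python exceptions. The sketch below is only a demonstration of that limitation, using the standard library: a child interpreter calls os.abort() inside a try-except, and the parent inspects how it exited. CHILD_CODE is a hypothetical stand-in for native code that crashes.

```python
import subprocess
import sys

# Hypothetical stand-in for a native crash: os.abort() raises SIGABRT,
# the same signal glibc raises when it detects heap corruption.
CHILD_CODE = """
import os
try:
    os.abort()
except Exception:
    print("caught")
"""

result = subprocess.run(
    [sys.executable, "-c", CHILD_CODE],
    capture_output=True,
    text=True,
)

# The except branch never runs, so "caught" is not printed.
caught = "caught" in result.stdout
# A non-zero return code shows that the child process died abnormally.
crashed = result.returncode != 0
```

Because the parent process survives, running crash-prone steps in a separate process (for example, data loading in worker processes) is one way to keep such failures from ending the whole run.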


Reprinted from blog.csdn.net/u012515223/article/details/133673820